Initial commit
This commit is contained in:
786
server/database.py
Normal file
786
server/database.py
Normal file
@@ -0,0 +1,786 @@
|
||||
"""
|
||||
database.py — PostgreSQL database with asyncpg connection pool.
|
||||
|
||||
Application-level AES-256-GCM encryption for credentials (unchanged from SQLite era).
|
||||
The pool is initialised once at startup via init_db() and closed via close_db().
|
||||
All store methods are async — callers must await them.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import asyncpg
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from .config import settings
|
||||
|
||||
# ─── Encryption ───────────────────────────────────────────────────────────────
|
||||
# Unchanged from SQLite version — encrypted blobs are stored as base64 TEXT.
|
||||
|
||||
_SALT = b"aide-credential-store-v1"
|
||||
_ITERATIONS = 480_000
|
||||
|
||||
|
||||
def _derive_key(password: str) -> bytes:
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=_SALT,
|
||||
iterations=_ITERATIONS,
|
||||
)
|
||||
return kdf.derive(password.encode())
|
||||
|
||||
|
||||
_ENCRYPTION_KEY = _derive_key(settings.db_master_password)
|
||||
|
||||
|
||||
def _encrypt(plaintext: str) -> str:
|
||||
"""Encrypt a string value, return base64-encoded ciphertext (nonce + tag + data)."""
|
||||
aesgcm = AESGCM(_ENCRYPTION_KEY)
|
||||
nonce = os.urandom(12)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext.encode(), None)
|
||||
return base64.b64encode(nonce + ciphertext).decode()
|
||||
|
||||
|
||||
def _decrypt(encoded: str) -> str:
|
||||
"""Decrypt a base64-encoded ciphertext, return plaintext string."""
|
||||
data = base64.b64decode(encoded)
|
||||
nonce, ciphertext = data[:12], data[12:]
|
||||
aesgcm = AESGCM(_ENCRYPTION_KEY)
|
||||
return aesgcm.decrypt(nonce, ciphertext, None).decode()
|
||||
|
||||
|
||||
# ─── Connection Pool ──────────────────────────────────────────────────────────
|
||||
|
||||
_pool: asyncpg.Pool | None = None
|
||||
|
||||
|
||||
async def get_pool() -> asyncpg.Pool:
|
||||
"""Return the shared connection pool. Must call init_db() first."""
|
||||
assert _pool is not None, "Database not initialised — call init_db() first"
|
||||
return _pool
|
||||
|
||||
|
||||
# ─── Migrations ───────────────────────────────────────────────────────────────
|
||||
# Each migration is a list of SQL statements (asyncpg runs one statement at a time).
|
||||
# All migrations are idempotent (IF NOT EXISTS / ADD COLUMN IF NOT EXISTS / ON CONFLICT DO NOTHING).
|
||||
|
||||
_MIGRATIONS: list[list[str]] = [
|
||||
# v1 — initial schema
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER PRIMARY KEY
|
||||
)""",
|
||||
"""CREATE TABLE IF NOT EXISTS credentials (
|
||||
key TEXT PRIMARY KEY,
|
||||
value_enc TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)""",
|
||||
"""CREATE TABLE IF NOT EXISTS audit_log (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
timestamp TEXT NOT NULL,
|
||||
session_id TEXT,
|
||||
tool_name TEXT NOT NULL,
|
||||
arguments JSONB,
|
||||
result_summary TEXT,
|
||||
confirmed BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
task_id TEXT
|
||||
)""",
|
||||
"CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_audit_session ON audit_log(session_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_audit_tool ON audit_log(tool_name)",
|
||||
"""CREATE TABLE IF NOT EXISTS scheduled_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
schedule TEXT,
|
||||
prompt TEXT NOT NULL,
|
||||
allowed_tools JSONB,
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
last_run TEXT,
|
||||
last_status TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)""",
|
||||
"""CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
started_at TEXT NOT NULL,
|
||||
ended_at TEXT,
|
||||
messages JSONB NOT NULL,
|
||||
task_id TEXT
|
||||
)""",
|
||||
],
|
||||
# v2 — email whitelist, agents, agent_runs
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS email_whitelist (
|
||||
email TEXT PRIMARY KEY,
|
||||
daily_limit INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL
|
||||
)""",
|
||||
"""CREATE TABLE IF NOT EXISTS agents (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
prompt TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
can_create_subagents BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
allowed_tools JSONB,
|
||||
schedule TEXT,
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
parent_agent_id TEXT REFERENCES agents(id),
|
||||
created_by TEXT NOT NULL DEFAULT 'user',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)""",
|
||||
"""CREATE TABLE IF NOT EXISTS agent_runs (
|
||||
id TEXT PRIMARY KEY,
|
||||
agent_id TEXT NOT NULL REFERENCES agents(id),
|
||||
started_at TEXT NOT NULL,
|
||||
ended_at TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'running',
|
||||
input_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
output_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
cost_usd REAL,
|
||||
result TEXT,
|
||||
error TEXT
|
||||
)""",
|
||||
"CREATE INDEX IF NOT EXISTS idx_agent_runs_agent_id ON agent_runs(agent_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_agent_runs_started_at ON agent_runs(started_at)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_agent_runs_status ON agent_runs(status)",
|
||||
],
|
||||
# v3 — web domain whitelist
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS web_whitelist (
|
||||
domain TEXT PRIMARY KEY,
|
||||
note TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL
|
||||
)""",
|
||||
"INSERT INTO web_whitelist (domain, note, created_at) VALUES ('duckduckgo.com', 'DuckDuckGo search', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
|
||||
"INSERT INTO web_whitelist (domain, note, created_at) VALUES ('wikipedia.org', 'Wikipedia', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
|
||||
"INSERT INTO web_whitelist (domain, note, created_at) VALUES ('weather.met.no', 'Norwegian Meteorological Institute', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
|
||||
"INSERT INTO web_whitelist (domain, note, created_at) VALUES ('api.met.no', 'Norwegian Meteorological API', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
|
||||
"INSERT INTO web_whitelist (domain, note, created_at) VALUES ('yr.no', 'Yr weather service', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
|
||||
"INSERT INTO web_whitelist (domain, note, created_at) VALUES ('timeanddate.com', 'Time and Date', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
|
||||
],
|
||||
# v4 — filesystem sandbox whitelist
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS filesystem_whitelist (
|
||||
path TEXT PRIMARY KEY,
|
||||
note TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL
|
||||
)""",
|
||||
],
|
||||
# v5 — optional agent assignment for scheduled tasks
|
||||
[
|
||||
"ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS agent_id TEXT REFERENCES agents(id)",
|
||||
],
|
||||
# v6 — per-agent max_tool_calls override
|
||||
[
|
||||
"ALTER TABLE agents ADD COLUMN IF NOT EXISTS max_tool_calls INTEGER",
|
||||
],
|
||||
# v7 — email inbox trigger rules
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS email_triggers (
|
||||
id TEXT PRIMARY KEY,
|
||||
trigger_word TEXT NOT NULL,
|
||||
agent_id TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)""",
|
||||
],
|
||||
# v8 — Telegram bot integration
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS telegram_whitelist (
|
||||
chat_id TEXT PRIMARY KEY,
|
||||
label TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL
|
||||
)""",
|
||||
"""CREATE TABLE IF NOT EXISTS telegram_triggers (
|
||||
id TEXT PRIMARY KEY,
|
||||
trigger_word TEXT NOT NULL,
|
||||
agent_id TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)""",
|
||||
],
|
||||
# v9 — agent prompt_mode column
|
||||
[
|
||||
"ALTER TABLE agents ADD COLUMN IF NOT EXISTS prompt_mode TEXT NOT NULL DEFAULT 'combined'",
|
||||
],
|
||||
# v10 — (was SQLite re-apply of v9; no-op here)
|
||||
[],
|
||||
# v11 — MCP client server configurations
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS mcp_servers (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
url TEXT NOT NULL,
|
||||
transport TEXT NOT NULL DEFAULT 'sse',
|
||||
api_key_enc TEXT,
|
||||
headers_enc TEXT,
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)""",
|
||||
],
|
||||
# v12 — users table for multi-user support (Part 2)
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
role TEXT NOT NULL DEFAULT 'user',
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
totp_secret TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)""",
|
||||
"CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)",
|
||||
"ALTER TABLE agents ADD COLUMN IF NOT EXISTS owner_user_id TEXT REFERENCES users(id)",
|
||||
"ALTER TABLE conversations ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
|
||||
"ALTER TABLE audit_log ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
|
||||
],
|
||||
# v13 — add email column to users
|
||||
[
|
||||
"ALTER TABLE users ADD COLUMN IF NOT EXISTS email TEXT",
|
||||
],
|
||||
# v14 — per-user settings table + user_id columns on multi-tenant tables
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS user_settings (
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT,
|
||||
PRIMARY KEY (user_id, key)
|
||||
)""",
|
||||
"ALTER TABLE email_triggers ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
|
||||
"ALTER TABLE telegram_triggers ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
|
||||
"ALTER TABLE telegram_whitelist ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
|
||||
"ALTER TABLE mcp_servers ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
|
||||
],
|
||||
# v15 — fix telegram_whitelist unique constraint to allow (chat_id, user_id) pairs
|
||||
# Uses NULLS NOT DISTINCT (PostgreSQL 15+) so (chat_id, NULL) is unique per global entry
|
||||
[
|
||||
# Drop old primary key constraint so chat_id alone no longer enforces uniqueness
|
||||
"""DO $$ BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM pg_constraint
|
||||
WHERE conname = 'telegram_whitelist_pkey' AND conrelid = 'telegram_whitelist'::regclass
|
||||
) THEN
|
||||
ALTER TABLE telegram_whitelist DROP CONSTRAINT telegram_whitelist_pkey;
|
||||
END IF;
|
||||
END $$""",
|
||||
# Add a surrogate UUID primary key
|
||||
"ALTER TABLE telegram_whitelist ADD COLUMN IF NOT EXISTS id UUID DEFAULT gen_random_uuid()",
|
||||
# Make it not null and set primary key (only if not already set)
|
||||
"""DO $$ BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_constraint
|
||||
WHERE conname = 'telegram_whitelist_pk' AND conrelid = 'telegram_whitelist'::regclass
|
||||
) THEN
|
||||
ALTER TABLE telegram_whitelist ADD CONSTRAINT telegram_whitelist_pk PRIMARY KEY (id);
|
||||
END IF;
|
||||
END $$""",
|
||||
# Create unique index on (chat_id, user_id) NULLS NOT DISTINCT
|
||||
"""CREATE UNIQUE INDEX IF NOT EXISTS telegram_whitelist_chat_user_idx
|
||||
ON telegram_whitelist (chat_id, user_id) NULLS NOT DISTINCT""",
|
||||
],
|
||||
# v16 — email_accounts table for multi-account email handling
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS email_accounts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT REFERENCES users(id),
|
||||
label TEXT NOT NULL,
|
||||
account_type TEXT NOT NULL DEFAULT 'handling',
|
||||
imap_host TEXT NOT NULL,
|
||||
imap_port INTEGER NOT NULL DEFAULT 993,
|
||||
imap_username TEXT NOT NULL,
|
||||
imap_password TEXT NOT NULL,
|
||||
smtp_host TEXT,
|
||||
smtp_port INTEGER,
|
||||
smtp_username TEXT,
|
||||
smtp_password TEXT,
|
||||
agent_id TEXT REFERENCES agents(id),
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
initial_load_done BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
initial_load_limit INTEGER NOT NULL DEFAULT 200,
|
||||
monitored_folders TEXT NOT NULL DEFAULT '[\"INBOX\"]',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)""",
|
||||
"ALTER TABLE email_triggers ADD COLUMN IF NOT EXISTS account_id UUID REFERENCES email_accounts(id)",
|
||||
],
|
||||
# v17 — convert audit_log.arguments from TEXT to JSONB (SQLite-migrated DBs have TEXT)
|
||||
# and agents/scheduled_tasks allowed_tools from TEXT to JSONB if not already
|
||||
[
|
||||
"""DO $$
|
||||
BEGIN
|
||||
IF (SELECT data_type FROM information_schema.columns
|
||||
WHERE table_name='audit_log' AND column_name='arguments') = 'text' THEN
|
||||
ALTER TABLE audit_log
|
||||
ALTER COLUMN arguments TYPE JSONB
|
||||
USING CASE WHEN arguments IS NULL OR arguments = '' THEN NULL
|
||||
ELSE arguments::jsonb END;
|
||||
END IF;
|
||||
END $$""",
|
||||
"""DO $$
|
||||
BEGIN
|
||||
IF (SELECT data_type FROM information_schema.columns
|
||||
WHERE table_name='agents' AND column_name='allowed_tools') = 'text' THEN
|
||||
ALTER TABLE agents
|
||||
ALTER COLUMN allowed_tools TYPE JSONB
|
||||
USING CASE WHEN allowed_tools IS NULL OR allowed_tools = '' THEN NULL
|
||||
ELSE allowed_tools::jsonb END;
|
||||
END IF;
|
||||
END $$""",
|
||||
"""DO $$
|
||||
BEGIN
|
||||
IF (SELECT data_type FROM information_schema.columns
|
||||
WHERE table_name='scheduled_tasks' AND column_name='allowed_tools') = 'text' THEN
|
||||
ALTER TABLE scheduled_tasks
|
||||
ALTER COLUMN allowed_tools TYPE JSONB
|
||||
USING CASE WHEN allowed_tools IS NULL OR allowed_tools = '' THEN NULL
|
||||
ELSE allowed_tools::jsonb END;
|
||||
END IF;
|
||||
END $$""",
|
||||
],
|
||||
# v18 — MFA challenge table for TOTP second-factor login
|
||||
[
|
||||
"""CREATE TABLE IF NOT EXISTS mfa_challenges (
|
||||
token TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
next_url TEXT NOT NULL DEFAULT '/',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
attempts INTEGER NOT NULL DEFAULT 0
|
||||
)""",
|
||||
"CREATE INDEX IF NOT EXISTS idx_mfa_challenges_expires ON mfa_challenges(expires_at)",
|
||||
],
|
||||
# v19 — display name for users (editable, separate from username)
|
||||
[
|
||||
"ALTER TABLE users ADD COLUMN IF NOT EXISTS display_name TEXT",
|
||||
],
|
||||
# v20 — extra notification tools for handling email accounts
|
||||
[
|
||||
"ALTER TABLE email_accounts ADD COLUMN IF NOT EXISTS extra_tools JSONB DEFAULT '[]'",
|
||||
],
|
||||
# v21 — bound Telegram chat_id for email handling accounts
|
||||
[
|
||||
"ALTER TABLE email_accounts ADD COLUMN IF NOT EXISTS telegram_chat_id TEXT",
|
||||
],
|
||||
# v22 — Telegram keyword routing + pause flag for email handling accounts
|
||||
[
|
||||
"ALTER TABLE email_accounts ADD COLUMN IF NOT EXISTS telegram_keyword TEXT",
|
||||
"ALTER TABLE email_accounts ADD COLUMN IF NOT EXISTS paused BOOLEAN DEFAULT FALSE",
|
||||
],
|
||||
# v23 — Conversation title for chat history UI
|
||||
[
|
||||
"ALTER TABLE conversations ADD COLUMN IF NOT EXISTS title TEXT",
|
||||
],
|
||||
# v24 — Store model ID used in each conversation
|
||||
[
|
||||
"ALTER TABLE conversations ADD COLUMN IF NOT EXISTS model TEXT",
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
async def _run_migrations(conn: asyncpg.Connection) -> None:
|
||||
"""Apply pending migrations idempotently, each in its own transaction."""
|
||||
await conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY)"
|
||||
)
|
||||
current: int = await conn.fetchval(
|
||||
"SELECT COALESCE(MAX(version), 0) FROM schema_version"
|
||||
) or 0
|
||||
|
||||
for i, statements in enumerate(_MIGRATIONS, start=1):
|
||||
if i <= current:
|
||||
continue
|
||||
async with conn.transaction():
|
||||
for sql in statements:
|
||||
sql = sql.strip()
|
||||
if sql:
|
||||
await conn.execute(sql)
|
||||
await conn.execute(
|
||||
"INSERT INTO schema_version (version) VALUES ($1) ON CONFLICT DO NOTHING", i
|
||||
)
|
||||
print(f"[aide] Applied database migration v{i}")
|
||||
|
||||
|
||||
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _utcnow() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _jsonify(obj: Any) -> Any:
|
||||
"""Return a JSON-safe version of obj (converts non-serializable values to strings)."""
|
||||
if obj is None:
|
||||
return None
|
||||
return json.loads(json.dumps(obj, default=str))
|
||||
|
||||
|
||||
def _rowcount(status: str) -> int:
|
||||
"""Parse asyncpg execute() status string like 'DELETE 3' → 3."""
|
||||
try:
|
||||
return int(status.split()[-1])
|
||||
except (ValueError, IndexError):
|
||||
return 0
|
||||
|
||||
|
||||
# ─── Credential Store ─────────────────────────────────────────────────────────
|
||||
|
||||
class CredentialStore:
|
||||
"""Encrypted key-value store for sensitive credentials."""
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
pool = await get_pool()
|
||||
row = await pool.fetchrow(
|
||||
"SELECT value_enc FROM credentials WHERE key = $1", key
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return _decrypt(row["value_enc"])
|
||||
|
||||
async def set(self, key: str, value: str, description: str = "") -> None:
|
||||
now = _utcnow()
|
||||
encrypted = _encrypt(value)
|
||||
pool = await get_pool()
|
||||
await pool.execute(
|
||||
"""
|
||||
INSERT INTO credentials (key, value_enc, description, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (key) DO UPDATE SET
|
||||
value_enc = EXCLUDED.value_enc,
|
||||
description = EXCLUDED.description,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
""",
|
||||
key, encrypted, description, now, now,
|
||||
)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
pool = await get_pool()
|
||||
status = await pool.execute("DELETE FROM credentials WHERE key = $1", key)
|
||||
return _rowcount(status) > 0
|
||||
|
||||
async def list_keys(self) -> list[dict]:
|
||||
pool = await get_pool()
|
||||
rows = await pool.fetch(
|
||||
"SELECT key, description, created_at, updated_at FROM credentials ORDER BY key"
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
async def require(self, key: str) -> str:
|
||||
value = await self.get(key)
|
||||
if not value:
|
||||
raise RuntimeError(
|
||||
f"Credential '{key}' is not configured. Add it via /settings."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
credential_store = CredentialStore()
|
||||
|
||||
|
||||
# ─── User Settings Store ──────────────────────────────────────────────────────
|
||||
|
||||
class UserSettingsStore:
|
||||
"""Per-user key/value settings. Values are plaintext (not encrypted)."""
|
||||
|
||||
async def get(self, user_id: str, key: str) -> str | None:
|
||||
pool = await get_pool()
|
||||
return await pool.fetchval(
|
||||
"SELECT value FROM user_settings WHERE user_id = $1 AND key = $2",
|
||||
user_id, key,
|
||||
)
|
||||
|
||||
async def set(self, user_id: str, key: str, value: str) -> None:
|
||||
pool = await get_pool()
|
||||
await pool.execute(
|
||||
"""
|
||||
INSERT INTO user_settings (user_id, key, value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
|
||||
""",
|
||||
user_id, key, value,
|
||||
)
|
||||
|
||||
async def delete(self, user_id: str, key: str) -> bool:
|
||||
pool = await get_pool()
|
||||
status = await pool.execute(
|
||||
"DELETE FROM user_settings WHERE user_id = $1 AND key = $2", user_id, key
|
||||
)
|
||||
return _rowcount(status) > 0
|
||||
|
||||
async def get_with_global_fallback(self, user_id: str, key: str, global_key: str) -> str | None:
|
||||
"""Try user-specific setting, fall back to global credential_store key."""
|
||||
val = await self.get(user_id, key)
|
||||
if val:
|
||||
return val
|
||||
return await credential_store.get(global_key)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
user_settings_store = UserSettingsStore()
|
||||
|
||||
|
||||
# ─── Email Whitelist Store ────────────────────────────────────────────────────
|
||||
|
||||
class EmailWhitelistStore:
|
||||
"""Manage allowed email recipients with optional per-address daily rate limits."""
|
||||
|
||||
async def list(self) -> list[dict]:
|
||||
pool = await get_pool()
|
||||
rows = await pool.fetch(
|
||||
"SELECT email, daily_limit, created_at FROM email_whitelist ORDER BY email"
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
async def add(self, email: str, daily_limit: int = 0) -> None:
|
||||
now = _utcnow()
|
||||
normalized = email.strip().lower()
|
||||
pool = await get_pool()
|
||||
await pool.execute(
|
||||
"""
|
||||
INSERT INTO email_whitelist (email, daily_limit, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (email) DO UPDATE SET daily_limit = EXCLUDED.daily_limit
|
||||
""",
|
||||
normalized, daily_limit, now,
|
||||
)
|
||||
|
||||
async def remove(self, email: str) -> bool:
|
||||
normalized = email.strip().lower()
|
||||
pool = await get_pool()
|
||||
status = await pool.execute(
|
||||
"DELETE FROM email_whitelist WHERE email = $1", normalized
|
||||
)
|
||||
return _rowcount(status) > 0
|
||||
|
||||
async def get(self, email: str) -> dict | None:
|
||||
normalized = email.strip().lower()
|
||||
pool = await get_pool()
|
||||
row = await pool.fetchrow(
|
||||
"SELECT email, daily_limit, created_at FROM email_whitelist WHERE email = $1",
|
||||
normalized,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def check_rate_limit(self, email: str) -> tuple[bool, int, int]:
|
||||
"""
|
||||
Check whether sending to this address is within the daily limit.
|
||||
Returns (allowed, count_today, limit). limit=0 means unlimited.
|
||||
"""
|
||||
entry = await self.get(email)
|
||||
if entry is None:
|
||||
return False, 0, 0
|
||||
|
||||
limit = entry["daily_limit"]
|
||||
if limit == 0:
|
||||
return True, 0, 0
|
||||
|
||||
# Compute start of today in UTC as ISO8601 string for TEXT comparison
|
||||
today_start = (
|
||||
datetime.now(timezone.utc)
|
||||
.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
.isoformat()
|
||||
)
|
||||
pool = await get_pool()
|
||||
count: int = await pool.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*) FROM audit_log
|
||||
WHERE tool_name = 'email'
|
||||
AND arguments->>'operation' = 'send_email'
|
||||
AND arguments->>'to' = $1
|
||||
AND timestamp >= $2
|
||||
AND (result_summary IS NULL OR result_summary NOT LIKE '%"success": false%')
|
||||
""",
|
||||
email.strip().lower(),
|
||||
today_start,
|
||||
) or 0
|
||||
|
||||
return count < limit, count, limit
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
email_whitelist_store = EmailWhitelistStore()
|
||||
|
||||
|
||||
# ─── Web Whitelist Store ──────────────────────────────────────────────────────
|
||||
|
||||
class WebWhitelistStore:
|
||||
"""Manage Tier-1 always-allowed web domains."""
|
||||
|
||||
async def list(self) -> list[dict]:
|
||||
pool = await get_pool()
|
||||
rows = await pool.fetch(
|
||||
"SELECT domain, note, created_at FROM web_whitelist ORDER BY domain"
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
async def add(self, domain: str, note: str = "") -> None:
|
||||
normalized = _normalize_domain(domain)
|
||||
now = _utcnow()
|
||||
pool = await get_pool()
|
||||
await pool.execute(
|
||||
"""
|
||||
INSERT INTO web_whitelist (domain, note, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (domain) DO UPDATE SET note = EXCLUDED.note
|
||||
""",
|
||||
normalized, note, now,
|
||||
)
|
||||
|
||||
async def remove(self, domain: str) -> bool:
|
||||
normalized = _normalize_domain(domain)
|
||||
pool = await get_pool()
|
||||
status = await pool.execute(
|
||||
"DELETE FROM web_whitelist WHERE domain = $1", normalized
|
||||
)
|
||||
return _rowcount(status) > 0
|
||||
|
||||
async def is_allowed(self, url: str) -> bool:
|
||||
"""Return True if the URL's hostname matches a whitelisted domain or subdomain."""
|
||||
try:
|
||||
hostname = urlparse(url).hostname or ""
|
||||
except Exception:
|
||||
return False
|
||||
if not hostname:
|
||||
return False
|
||||
domains = await self.list()
|
||||
for entry in domains:
|
||||
d = entry["domain"]
|
||||
if hostname == d or hostname.endswith("." + d):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_domain(domain: str) -> str:
|
||||
"""Strip scheme and path, return lowercase hostname only."""
|
||||
d = domain.strip().lower()
|
||||
if "://" not in d:
|
||||
d = "https://" + d
|
||||
parsed = urlparse(d)
|
||||
return parsed.hostname or domain.strip().lower()
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
web_whitelist_store = WebWhitelistStore()
|
||||
|
||||
|
||||
# ─── Filesystem Whitelist Store ───────────────────────────────────────────────
|
||||
|
||||
class FilesystemWhitelistStore:
|
||||
"""Manage allowed filesystem sandbox directories."""
|
||||
|
||||
async def list(self) -> list[dict]:
|
||||
pool = await get_pool()
|
||||
rows = await pool.fetch(
|
||||
"SELECT path, note, created_at FROM filesystem_whitelist ORDER BY path"
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
async def add(self, path: str, note: str = "") -> None:
|
||||
from pathlib import Path as _Path
|
||||
normalized = str(_Path(path).resolve())
|
||||
now = _utcnow()
|
||||
pool = await get_pool()
|
||||
await pool.execute(
|
||||
"""
|
||||
INSERT INTO filesystem_whitelist (path, note, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (path) DO UPDATE SET note = EXCLUDED.note
|
||||
""",
|
||||
normalized, note, now,
|
||||
)
|
||||
|
||||
async def remove(self, path: str) -> bool:
|
||||
from pathlib import Path as _Path
|
||||
normalized = str(_Path(path).resolve())
|
||||
pool = await get_pool()
|
||||
status = await pool.execute(
|
||||
"DELETE FROM filesystem_whitelist WHERE path = $1", normalized
|
||||
)
|
||||
if _rowcount(status) == 0:
|
||||
# Fallback: try exact match without resolving
|
||||
status = await pool.execute(
|
||||
"DELETE FROM filesystem_whitelist WHERE path = $1", path
|
||||
)
|
||||
return _rowcount(status) > 0
|
||||
|
||||
async def is_allowed(self, path: Any) -> tuple[bool, str]:
|
||||
"""
|
||||
Check if path is inside any whitelisted directory.
|
||||
Returns (allowed, resolved_path_str).
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
try:
|
||||
resolved = _Path(path).resolve()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid path: {e}")
|
||||
|
||||
sandboxes = await self.list()
|
||||
for entry in sandboxes:
|
||||
try:
|
||||
resolved.relative_to(_Path(entry["path"]).resolve())
|
||||
return True, str(resolved)
|
||||
except ValueError:
|
||||
continue
|
||||
return False, str(resolved)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
filesystem_whitelist_store = FilesystemWhitelistStore()
|
||||
|
||||
|
||||
# ─── Initialisation ───────────────────────────────────────────────────────────
|
||||
|
||||
async def _init_connection(conn: asyncpg.Connection) -> None:
|
||||
"""Register codecs on every new connection so asyncpg handles JSONB ↔ dict."""
|
||||
await conn.set_type_codec(
|
||||
"jsonb",
|
||||
encoder=json.dumps,
|
||||
decoder=json.loads,
|
||||
schema="pg_catalog",
|
||||
)
|
||||
await conn.set_type_codec(
|
||||
"json",
|
||||
encoder=json.dumps,
|
||||
decoder=json.loads,
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialise the connection pool and run migrations. Call once at startup."""
|
||||
global _pool
|
||||
_pool = await asyncpg.create_pool(
|
||||
settings.aide_db_url,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
init=_init_connection,
|
||||
)
|
||||
async with _pool.acquire() as conn:
|
||||
await _run_migrations(conn)
|
||||
print(f"[aide] Database ready: {settings.aide_db_url.split('@')[-1]}")
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""Close the connection pool. Call at shutdown."""
|
||||
global _pool
|
||||
if _pool:
|
||||
await _pool.close()
|
||||
_pool = None
|
||||
Reference in New Issue
Block a user