"""
database.py — PostgreSQL database with asyncpg connection pool.

Application-level AES-256-GCM encryption for credentials (unchanged from SQLite era).
The pool is initialised once at startup via init_db() and closed via close_db().
All store methods are async — callers must await them.
"""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
from urllib.parse import urlparse
|
|
|
|
import asyncpg
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
|
|
from .config import settings
|
|
|
|
# ─── Encryption ───────────────────────────────────────────────────────────────
|
|
# Unchanged from SQLite version — encrypted blobs are stored as base64 TEXT.
|
|
|
|
# PBKDF2 parameters for deriving the AES key from the master password.
# NOTE(review): the salt is a fixed module constant, so the same master
# password always derives the same key — fine for a single-tenant store,
# but worth confirming if the deployment model ever changes.
_SALT = b"aide-credential-store-v1"
_ITERATIONS = 480_000
|
|
|
|
|
|
def _derive_key(password: str) -> bytes:
    """Derive the 32-byte AES-256 key from *password* via PBKDF2-HMAC-SHA256."""
    return PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=_SALT,
        iterations=_ITERATIONS,
    ).derive(password.encode())
|
|
|
|
|
|
# Derived once at import time; all credential blobs are sealed with this key.
_ENCRYPTION_KEY = _derive_key(settings.db_master_password)
|
|
|
|
|
|
def _encrypt(plaintext: str) -> str:
    """Encrypt a string value, return base64-encoded ciphertext (nonce + tag + data).

    A fresh 12-byte random nonce is generated per call and prepended to the
    AES-GCM output so _decrypt() can recover it.
    """
    nonce = os.urandom(12)
    sealed = AESGCM(_ENCRYPTION_KEY).encrypt(nonce, plaintext.encode(), None)
    return base64.b64encode(nonce + sealed).decode()
|
|
|
|
|
|
def _decrypt(encoded: str) -> str:
    """Decrypt a base64-encoded ciphertext produced by _encrypt(), return plaintext."""
    raw = base64.b64decode(encoded)
    # Layout: first 12 bytes are the nonce, the rest is ciphertext + GCM tag.
    nonce, body = raw[:12], raw[12:]
    return AESGCM(_ENCRYPTION_KEY).decrypt(nonce, body, None).decode()
|
|
|
|
|
|
# ─── Connection Pool ──────────────────────────────────────────────────────────
|
|
|
|
# Shared pool, created by init_db() and torn down by close_db(); None until then.
_pool: asyncpg.Pool | None = None
|
|
|
|
|
|
async def get_pool() -> asyncpg.Pool:
    """Return the shared connection pool. Must call init_db() first.

    Raises:
        RuntimeError: if the pool has not been initialised.  (Previously this
        was an ``assert``, which is silently stripped under ``python -O`` and
        would let a None pool propagate as a confusing AttributeError.)
    """
    if _pool is None:
        raise RuntimeError("Database not initialised — call init_db() first")
    return _pool
|
|
|
|
|
|
# ─── Migrations ───────────────────────────────────────────────────────────────
|
|
# Each migration is a list of SQL statements (asyncpg runs one statement at a time).
|
|
# All migrations are idempotent (IF NOT EXISTS / ADD COLUMN IF NOT EXISTS / ON CONFLICT DO NOTHING).
|
|
|
|
# NOTE: this list is append-only — never edit an already-shipped entry;
# _run_migrations() records list position + 1 as the applied version number.
_MIGRATIONS: list[list[str]] = [
    # v1 — initial schema
    [
        """CREATE TABLE IF NOT EXISTS schema_version (
            version INTEGER PRIMARY KEY
        )""",
        """CREATE TABLE IF NOT EXISTS credentials (
            key TEXT PRIMARY KEY,
            value_enc TEXT NOT NULL,
            description TEXT,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS audit_log (
            id BIGSERIAL PRIMARY KEY,
            timestamp TEXT NOT NULL,
            session_id TEXT,
            tool_name TEXT NOT NULL,
            arguments JSONB,
            result_summary TEXT,
            confirmed BOOLEAN NOT NULL DEFAULT FALSE,
            task_id TEXT
        )""",
        "CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp)",
        "CREATE INDEX IF NOT EXISTS idx_audit_session ON audit_log(session_id)",
        "CREATE INDEX IF NOT EXISTS idx_audit_tool ON audit_log(tool_name)",
        """CREATE TABLE IF NOT EXISTS scheduled_tasks (
            id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            description TEXT,
            schedule TEXT,
            prompt TEXT NOT NULL,
            allowed_tools JSONB,
            enabled BOOLEAN NOT NULL DEFAULT TRUE,
            last_run TEXT,
            last_status TEXT,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS conversations (
            id TEXT PRIMARY KEY,
            started_at TEXT NOT NULL,
            ended_at TEXT,
            messages JSONB NOT NULL,
            task_id TEXT
        )""",
    ],
    # v2 — email whitelist, agents, agent_runs
    [
        """CREATE TABLE IF NOT EXISTS email_whitelist (
            email TEXT PRIMARY KEY,
            daily_limit INTEGER NOT NULL DEFAULT 0,
            created_at TEXT NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS agents (
            id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            description TEXT,
            prompt TEXT NOT NULL,
            model TEXT NOT NULL,
            can_create_subagents BOOLEAN NOT NULL DEFAULT FALSE,
            allowed_tools JSONB,
            schedule TEXT,
            enabled BOOLEAN NOT NULL DEFAULT TRUE,
            parent_agent_id TEXT REFERENCES agents(id),
            created_by TEXT NOT NULL DEFAULT 'user',
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS agent_runs (
            id TEXT PRIMARY KEY,
            agent_id TEXT NOT NULL REFERENCES agents(id),
            started_at TEXT NOT NULL,
            ended_at TEXT,
            status TEXT NOT NULL DEFAULT 'running',
            input_tokens INTEGER NOT NULL DEFAULT 0,
            output_tokens INTEGER NOT NULL DEFAULT 0,
            cost_usd REAL,
            result TEXT,
            error TEXT
        )""",
        "CREATE INDEX IF NOT EXISTS idx_agent_runs_agent_id ON agent_runs(agent_id)",
        "CREATE INDEX IF NOT EXISTS idx_agent_runs_started_at ON agent_runs(started_at)",
        "CREATE INDEX IF NOT EXISTS idx_agent_runs_status ON agent_runs(status)",
    ],
    # v3 — web domain whitelist (seeded with a default set of safe domains)
    [
        """CREATE TABLE IF NOT EXISTS web_whitelist (
            domain TEXT PRIMARY KEY,
            note TEXT NOT NULL DEFAULT '',
            created_at TEXT NOT NULL
        )""",
        "INSERT INTO web_whitelist (domain, note, created_at) VALUES ('duckduckgo.com', 'DuckDuckGo search', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
        "INSERT INTO web_whitelist (domain, note, created_at) VALUES ('wikipedia.org', 'Wikipedia', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
        "INSERT INTO web_whitelist (domain, note, created_at) VALUES ('weather.met.no', 'Norwegian Meteorological Institute', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
        "INSERT INTO web_whitelist (domain, note, created_at) VALUES ('api.met.no', 'Norwegian Meteorological API', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
        "INSERT INTO web_whitelist (domain, note, created_at) VALUES ('yr.no', 'Yr weather service', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
        "INSERT INTO web_whitelist (domain, note, created_at) VALUES ('timeanddate.com', 'Time and Date', '2024-01-01T00:00:00+00:00') ON CONFLICT DO NOTHING",
    ],
    # v4 — filesystem sandbox whitelist
    [
        """CREATE TABLE IF NOT EXISTS filesystem_whitelist (
            path TEXT PRIMARY KEY,
            note TEXT NOT NULL DEFAULT '',
            created_at TEXT NOT NULL
        )""",
    ],
    # v5 — optional agent assignment for scheduled tasks
    [
        "ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS agent_id TEXT REFERENCES agents(id)",
    ],
    # v6 — per-agent max_tool_calls override
    [
        "ALTER TABLE agents ADD COLUMN IF NOT EXISTS max_tool_calls INTEGER",
    ],
    # v7 — email inbox trigger rules
    [
        """CREATE TABLE IF NOT EXISTS email_triggers (
            id TEXT PRIMARY KEY,
            trigger_word TEXT NOT NULL,
            agent_id TEXT NOT NULL,
            description TEXT NOT NULL DEFAULT '',
            enabled BOOLEAN NOT NULL DEFAULT TRUE,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )""",
    ],
    # v8 — Telegram bot integration
    [
        """CREATE TABLE IF NOT EXISTS telegram_whitelist (
            chat_id TEXT PRIMARY KEY,
            label TEXT NOT NULL DEFAULT '',
            created_at TEXT NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS telegram_triggers (
            id TEXT PRIMARY KEY,
            trigger_word TEXT NOT NULL,
            agent_id TEXT NOT NULL,
            description TEXT NOT NULL DEFAULT '',
            enabled BOOLEAN NOT NULL DEFAULT TRUE,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )""",
    ],
    # v9 — agent prompt_mode column
    [
        "ALTER TABLE agents ADD COLUMN IF NOT EXISTS prompt_mode TEXT NOT NULL DEFAULT 'combined'",
    ],
    # v10 — (was SQLite re-apply of v9; no-op here, kept so version numbers line up)
    [],
    # v11 — MCP client server configurations
    [
        """CREATE TABLE IF NOT EXISTS mcp_servers (
            id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            url TEXT NOT NULL,
            transport TEXT NOT NULL DEFAULT 'sse',
            api_key_enc TEXT,
            headers_enc TEXT,
            enabled BOOLEAN NOT NULL DEFAULT TRUE,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )""",
    ],
    # v12 — users table for multi-user support (Part 2)
    [
        """CREATE TABLE IF NOT EXISTS users (
            id TEXT PRIMARY KEY,
            username TEXT NOT NULL UNIQUE,
            password_hash TEXT NOT NULL,
            role TEXT NOT NULL DEFAULT 'user',
            is_active BOOLEAN NOT NULL DEFAULT TRUE,
            totp_secret TEXT,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )""",
        "CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)",
        "ALTER TABLE agents ADD COLUMN IF NOT EXISTS owner_user_id TEXT REFERENCES users(id)",
        "ALTER TABLE conversations ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
        "ALTER TABLE audit_log ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
    ],
    # v13 — add email column to users
    [
        "ALTER TABLE users ADD COLUMN IF NOT EXISTS email TEXT",
    ],
    # v14 — per-user settings table + user_id columns on multi-tenant tables
    [
        """CREATE TABLE IF NOT EXISTS user_settings (
            user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
            key TEXT NOT NULL,
            value TEXT,
            PRIMARY KEY (user_id, key)
        )""",
        "ALTER TABLE email_triggers ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
        "ALTER TABLE telegram_triggers ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
        "ALTER TABLE telegram_whitelist ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
        "ALTER TABLE mcp_servers ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES users(id)",
    ],
    # v15 — fix telegram_whitelist unique constraint to allow (chat_id, user_id) pairs
    # Uses NULLS NOT DISTINCT (PostgreSQL 15+) so (chat_id, NULL) is unique per global entry
    [
        # Drop old primary key constraint so chat_id alone no longer enforces uniqueness
        """DO $$ BEGIN
            IF EXISTS (
                SELECT 1 FROM pg_constraint
                WHERE conname = 'telegram_whitelist_pkey' AND conrelid = 'telegram_whitelist'::regclass
            ) THEN
                ALTER TABLE telegram_whitelist DROP CONSTRAINT telegram_whitelist_pkey;
            END IF;
        END $$""",
        # Add a surrogate UUID primary key
        "ALTER TABLE telegram_whitelist ADD COLUMN IF NOT EXISTS id UUID DEFAULT gen_random_uuid()",
        # Make it not null and set primary key (only if not already set)
        """DO $$ BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_constraint
                WHERE conname = 'telegram_whitelist_pk' AND conrelid = 'telegram_whitelist'::regclass
            ) THEN
                ALTER TABLE telegram_whitelist ADD CONSTRAINT telegram_whitelist_pk PRIMARY KEY (id);
            END IF;
        END $$""",
        # Create unique index on (chat_id, user_id) NULLS NOT DISTINCT
        """CREATE UNIQUE INDEX IF NOT EXISTS telegram_whitelist_chat_user_idx
            ON telegram_whitelist (chat_id, user_id) NULLS NOT DISTINCT""",
    ],
    # v16 — email_accounts table for multi-account email handling
    [
        """CREATE TABLE IF NOT EXISTS email_accounts (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT REFERENCES users(id),
            label TEXT NOT NULL,
            account_type TEXT NOT NULL DEFAULT 'handling',
            imap_host TEXT NOT NULL,
            imap_port INTEGER NOT NULL DEFAULT 993,
            imap_username TEXT NOT NULL,
            imap_password TEXT NOT NULL,
            smtp_host TEXT,
            smtp_port INTEGER,
            smtp_username TEXT,
            smtp_password TEXT,
            agent_id TEXT REFERENCES agents(id),
            enabled BOOLEAN NOT NULL DEFAULT TRUE,
            initial_load_done BOOLEAN NOT NULL DEFAULT FALSE,
            initial_load_limit INTEGER NOT NULL DEFAULT 200,
            monitored_folders TEXT NOT NULL DEFAULT '[\"INBOX\"]',
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )""",
        "ALTER TABLE email_triggers ADD COLUMN IF NOT EXISTS account_id UUID REFERENCES email_accounts(id)",
    ],
    # v17 — convert audit_log.arguments from TEXT to JSONB (SQLite-migrated DBs have TEXT)
    # and agents/scheduled_tasks allowed_tools from TEXT to JSONB if not already
    [
        """DO $$
        BEGIN
            IF (SELECT data_type FROM information_schema.columns
                WHERE table_name='audit_log' AND column_name='arguments') = 'text' THEN
                ALTER TABLE audit_log
                    ALTER COLUMN arguments TYPE JSONB
                    USING CASE WHEN arguments IS NULL OR arguments = '' THEN NULL
                               ELSE arguments::jsonb END;
            END IF;
        END $$""",
        """DO $$
        BEGIN
            IF (SELECT data_type FROM information_schema.columns
                WHERE table_name='agents' AND column_name='allowed_tools') = 'text' THEN
                ALTER TABLE agents
                    ALTER COLUMN allowed_tools TYPE JSONB
                    USING CASE WHEN allowed_tools IS NULL OR allowed_tools = '' THEN NULL
                               ELSE allowed_tools::jsonb END;
            END IF;
        END $$""",
        """DO $$
        BEGIN
            IF (SELECT data_type FROM information_schema.columns
                WHERE table_name='scheduled_tasks' AND column_name='allowed_tools') = 'text' THEN
                ALTER TABLE scheduled_tasks
                    ALTER COLUMN allowed_tools TYPE JSONB
                    USING CASE WHEN allowed_tools IS NULL OR allowed_tools = '' THEN NULL
                               ELSE allowed_tools::jsonb END;
            END IF;
        END $$""",
    ],
    # v18 — MFA challenge table for TOTP second-factor login
    [
        """CREATE TABLE IF NOT EXISTS mfa_challenges (
            token TEXT PRIMARY KEY,
            user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
            next_url TEXT NOT NULL DEFAULT '/',
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            expires_at TIMESTAMPTZ NOT NULL,
            attempts INTEGER NOT NULL DEFAULT 0
        )""",
        "CREATE INDEX IF NOT EXISTS idx_mfa_challenges_expires ON mfa_challenges(expires_at)",
    ],
    # v19 — display name for users (editable, separate from username)
    [
        "ALTER TABLE users ADD COLUMN IF NOT EXISTS display_name TEXT",
    ],
    # v20 — extra notification tools for handling email accounts
    [
        "ALTER TABLE email_accounts ADD COLUMN IF NOT EXISTS extra_tools JSONB DEFAULT '[]'",
    ],
    # v21 — bound Telegram chat_id for email handling accounts
    [
        "ALTER TABLE email_accounts ADD COLUMN IF NOT EXISTS telegram_chat_id TEXT",
    ],
    # v22 — Telegram keyword routing + pause flag for email handling accounts
    [
        "ALTER TABLE email_accounts ADD COLUMN IF NOT EXISTS telegram_keyword TEXT",
        "ALTER TABLE email_accounts ADD COLUMN IF NOT EXISTS paused BOOLEAN DEFAULT FALSE",
    ],
    # v23 — Conversation title for chat history UI
    [
        "ALTER TABLE conversations ADD COLUMN IF NOT EXISTS title TEXT",
    ],
    # v24 — Store model ID used in each conversation
    [
        "ALTER TABLE conversations ADD COLUMN IF NOT EXISTS model TEXT",
    ],
]
|
|
|
|
|
|
async def _run_migrations(conn: asyncpg.Connection) -> None:
    """Apply pending migrations idempotently, each version in its own transaction."""
    # Bootstrap the version table so MAX(version) works on a fresh database.
    await conn.execute(
        "CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY)"
    )
    applied: int = await conn.fetchval(
        "SELECT COALESCE(MAX(version), 0) FROM schema_version"
    ) or 0

    for version, statements in enumerate(_MIGRATIONS, start=1):
        if version <= applied:
            continue
        async with conn.transaction():
            for statement in statements:
                statement = statement.strip()
                if statement:
                    await conn.execute(statement)
            # Record the version inside the same transaction as its statements.
            await conn.execute(
                "INSERT INTO schema_version (version) VALUES ($1) ON CONFLICT DO NOTHING", version
            )
        print(f"[aide] Applied database migration v{version}")
|
|
|
|
|
|
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
def _utcnow() -> str:
|
|
return datetime.now(timezone.utc).isoformat()
|
|
|
|
|
|
def _jsonify(obj: Any) -> Any:
|
|
"""Return a JSON-safe version of obj (converts non-serializable values to strings)."""
|
|
if obj is None:
|
|
return None
|
|
return json.loads(json.dumps(obj, default=str))
|
|
|
|
|
|
def _rowcount(status: str) -> int:
|
|
"""Parse asyncpg execute() status string like 'DELETE 3' → 3."""
|
|
try:
|
|
return int(status.split()[-1])
|
|
except (ValueError, IndexError):
|
|
return 0
|
|
|
|
|
|
# ─── Credential Store ─────────────────────────────────────────────────────────
|
|
|
|
class CredentialStore:
    """Encrypted key-value store for sensitive credentials.

    Values are AES-256-GCM encrypted before they reach the database; only
    the key, description and timestamps are stored in the clear.
    """

    async def get(self, key: str) -> str | None:
        """Return the decrypted value for *key*, or None when absent."""
        db = await get_pool()
        record = await db.fetchrow(
            "SELECT value_enc FROM credentials WHERE key = $1", key
        )
        return _decrypt(record["value_enc"]) if record is not None else None

    async def set(self, key: str, value: str, description: str = "") -> None:
        """Insert or overwrite a credential; the value is stored encrypted."""
        db = await get_pool()
        stamp = _utcnow()
        await db.execute(
            """
            INSERT INTO credentials (key, value_enc, description, created_at, updated_at)
            VALUES ($1, $2, $3, $4, $5)
            ON CONFLICT (key) DO UPDATE SET
                value_enc = EXCLUDED.value_enc,
                description = EXCLUDED.description,
                updated_at = EXCLUDED.updated_at
            """,
            key, _encrypt(value), description, stamp, stamp,
        )

    async def delete(self, key: str) -> bool:
        """Remove a credential; True if a row was actually deleted."""
        db = await get_pool()
        result = await db.execute("DELETE FROM credentials WHERE key = $1", key)
        return _rowcount(result) > 0

    async def list_keys(self) -> list[dict]:
        """Return metadata (never decrypted values) for all stored credentials."""
        db = await get_pool()
        records = await db.fetch(
            "SELECT key, description, created_at, updated_at FROM credentials ORDER BY key"
        )
        return [dict(record) for record in records]

    async def require(self, key: str) -> str:
        """Like get(), but raise RuntimeError when the credential is missing or empty."""
        value = await self.get(key)
        if value:
            return value
        raise RuntimeError(
            f"Credential '{key}' is not configured. Add it via /settings."
        )
|
|
|
|
|
|
# Module-level singleton — import and share this instance; it holds no state.
credential_store = CredentialStore()
|
|
|
|
|
|
# ─── User Settings Store ──────────────────────────────────────────────────────
|
|
|
|
class UserSettingsStore:
    """Per-user key/value settings. Values are plaintext (not encrypted)."""

    async def get(self, user_id: str, key: str) -> str | None:
        """Return the user's value for *key*, or None when unset."""
        db = await get_pool()
        return await db.fetchval(
            "SELECT value FROM user_settings WHERE user_id = $1 AND key = $2",
            user_id, key,
        )

    async def set(self, user_id: str, key: str, value: str) -> None:
        """Create or overwrite a per-user setting."""
        db = await get_pool()
        await db.execute(
            """
            INSERT INTO user_settings (user_id, key, value)
            VALUES ($1, $2, $3)
            ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value
            """,
            user_id, key, value,
        )

    async def delete(self, user_id: str, key: str) -> bool:
        """Remove a setting; True if a row was deleted."""
        db = await get_pool()
        result = await db.execute(
            "DELETE FROM user_settings WHERE user_id = $1 AND key = $2", user_id, key
        )
        return _rowcount(result) > 0

    async def get_with_global_fallback(self, user_id: str, key: str, global_key: str) -> str | None:
        """Try user-specific setting, fall back to global credential_store key."""
        user_value = await self.get(user_id, key)
        if user_value:
            return user_value
        # Empty string counts as unset — same falsiness check as the original.
        return await credential_store.get(global_key)
|
|
|
|
|
|
# Module-level singleton — import and share this instance; it holds no state.
user_settings_store = UserSettingsStore()
|
|
|
|
|
|
# ─── Email Whitelist Store ────────────────────────────────────────────────────
|
|
|
|
class EmailWhitelistStore:
    """Manage allowed email recipients with optional per-address daily rate limits.

    Addresses are normalised (stripped, lower-cased) on every write and lookup.
    A daily_limit of 0 means unlimited sends.
    """

    async def list(self) -> list[dict]:
        """All whitelist entries, ordered by address."""
        db = await get_pool()
        records = await db.fetch(
            "SELECT email, daily_limit, created_at FROM email_whitelist ORDER BY email"
        )
        return [dict(record) for record in records]

    async def add(self, email: str, daily_limit: int = 0) -> None:
        """Whitelist an address, or update the limit of an existing entry."""
        db = await get_pool()
        await db.execute(
            """
            INSERT INTO email_whitelist (email, daily_limit, created_at)
            VALUES ($1, $2, $3)
            ON CONFLICT (email) DO UPDATE SET daily_limit = EXCLUDED.daily_limit
            """,
            email.strip().lower(), daily_limit, _utcnow(),
        )

    async def remove(self, email: str) -> bool:
        """Delete an address; True if it was present."""
        db = await get_pool()
        result = await db.execute(
            "DELETE FROM email_whitelist WHERE email = $1", email.strip().lower()
        )
        return _rowcount(result) > 0

    async def get(self, email: str) -> dict | None:
        """Fetch one entry, or None when the address is not whitelisted."""
        db = await get_pool()
        record = await db.fetchrow(
            "SELECT email, daily_limit, created_at FROM email_whitelist WHERE email = $1",
            email.strip().lower(),
        )
        return dict(record) if record else None

    async def check_rate_limit(self, email: str) -> tuple[bool, int, int]:
        """
        Check whether sending to this address is within the daily limit.

        Returns (allowed, count_today, limit). limit=0 means unlimited.
        """
        entry = await self.get(email)
        if entry is None:
            # Not whitelisted at all → never allowed.
            return False, 0, 0

        limit = entry["daily_limit"]
        if limit == 0:
            # Unlimited — skip counting the audit log entirely.
            return True, 0, 0

        # Start of today (UTC) as ISO8601, comparable against the TEXT timestamp column.
        midnight = (
            datetime.now(timezone.utc)
            .replace(hour=0, minute=0, second=0, microsecond=0)
            .isoformat()
        )
        db = await get_pool()
        # Count today's successful (or not-explicitly-failed) sends to this address.
        sent_today: int = await db.fetchval(
            """
            SELECT COUNT(*) FROM audit_log
            WHERE tool_name = 'email'
              AND arguments->>'operation' = 'send_email'
              AND arguments->>'to' = $1
              AND timestamp >= $2
              AND (result_summary IS NULL OR result_summary NOT LIKE '%"success": false%')
            """,
            email.strip().lower(),
            midnight,
        ) or 0

        return sent_today < limit, sent_today, limit
|
|
|
|
|
|
# Module-level singleton — import and share this instance; it holds no state.
email_whitelist_store = EmailWhitelistStore()
|
|
|
|
|
|
# ─── Web Whitelist Store ──────────────────────────────────────────────────────
|
|
|
|
class WebWhitelistStore:
    """Manage Tier-1 always-allowed web domains.

    Domains are normalised to bare lowercase hostnames; is_allowed() matches
    both the exact domain and any of its subdomains.
    """

    async def list(self) -> list[dict]:
        """All whitelisted domains, ordered alphabetically."""
        db = await get_pool()
        records = await db.fetch(
            "SELECT domain, note, created_at FROM web_whitelist ORDER BY domain"
        )
        return [dict(record) for record in records]

    async def add(self, domain: str, note: str = "") -> None:
        """Whitelist a domain (normalised to bare hostname) or update its note."""
        db = await get_pool()
        await db.execute(
            """
            INSERT INTO web_whitelist (domain, note, created_at)
            VALUES ($1, $2, $3)
            ON CONFLICT (domain) DO UPDATE SET note = EXCLUDED.note
            """,
            _normalize_domain(domain), note, _utcnow(),
        )

    async def remove(self, domain: str) -> bool:
        """Delete a domain; True if it was present."""
        db = await get_pool()
        result = await db.execute(
            "DELETE FROM web_whitelist WHERE domain = $1", _normalize_domain(domain)
        )
        return _rowcount(result) > 0

    async def is_allowed(self, url: str) -> bool:
        """Return True if the URL's hostname matches a whitelisted domain or subdomain."""
        try:
            hostname = urlparse(url).hostname or ""
        except Exception:
            return False
        if not hostname:
            return False
        return any(
            hostname == entry["domain"] or hostname.endswith("." + entry["domain"])
            for entry in await self.list()
        )
|
|
|
|
|
|
def _normalize_domain(domain: str) -> str:
|
|
"""Strip scheme and path, return lowercase hostname only."""
|
|
d = domain.strip().lower()
|
|
if "://" not in d:
|
|
d = "https://" + d
|
|
parsed = urlparse(d)
|
|
return parsed.hostname or domain.strip().lower()
|
|
|
|
|
|
# Module-level singleton — import and share this instance; it holds no state.
web_whitelist_store = WebWhitelistStore()
|
|
|
|
|
|
# ─── Filesystem Whitelist Store ───────────────────────────────────────────────
|
|
|
|
class FilesystemWhitelistStore:
    """Manage allowed filesystem sandbox directories.

    Paths are stored fully resolved (symlinks followed, made absolute) so
    containment checks in is_allowed() are reliable.
    """

    async def list(self) -> list[dict]:
        """All sandbox roots, ordered by path."""
        pool = await get_pool()
        rows = await pool.fetch(
            "SELECT path, note, created_at FROM filesystem_whitelist ORDER BY path"
        )
        return [dict(r) for r in rows]

    async def add(self, path: str, note: str = "") -> None:
        """Whitelist a directory (stored resolved) or update its note."""
        from pathlib import Path as _Path
        normalized = str(_Path(path).resolve())
        now = _utcnow()
        pool = await get_pool()
        await pool.execute(
            """
            INSERT INTO filesystem_whitelist (path, note, created_at)
            VALUES ($1, $2, $3)
            ON CONFLICT (path) DO UPDATE SET note = EXCLUDED.note
            """,
            normalized, note, now,
        )

    async def remove(self, path: str) -> bool:
        """Delete a sandbox entry; True if a row was removed."""
        from pathlib import Path as _Path
        normalized = str(_Path(path).resolve())
        pool = await get_pool()
        status = await pool.execute(
            "DELETE FROM filesystem_whitelist WHERE path = $1", normalized
        )
        if _rowcount(status) == 0:
            # Fallback: try exact match without resolving
            status = await pool.execute(
                "DELETE FROM filesystem_whitelist WHERE path = $1", path
            )
        return _rowcount(status) > 0

    async def is_allowed(self, path: Any) -> tuple[bool, str]:
        """
        Check if path is inside any whitelisted directory.

        Returns (allowed, resolved_path_str).

        Raises:
            ValueError: if the path cannot be resolved; the original
            exception is chained as __cause__ (the old code dropped it).
        """
        from pathlib import Path as _Path
        try:
            resolved = _Path(path).resolve()
        except Exception as e:
            # Chain the underlying error so diagnostics keep the real cause.
            raise ValueError(f"Invalid path: {e}") from e

        sandboxes = await self.list()
        for entry in sandboxes:
            try:
                # relative_to() raises ValueError when resolved lies outside entry.
                resolved.relative_to(_Path(entry["path"]).resolve())
                return True, str(resolved)
            except ValueError:
                continue
        return False, str(resolved)
|
|
|
|
|
|
# Module-level singleton — import and share this instance; it holds no state.
filesystem_whitelist_store = FilesystemWhitelistStore()
|
|
|
|
|
|
# ─── Initialisation ───────────────────────────────────────────────────────────
|
|
|
|
async def _init_connection(conn: asyncpg.Connection) -> None:
    """Register codecs on every new connection so asyncpg handles JSON/JSONB ↔ dict."""
    for pg_type in ("jsonb", "json"):
        await conn.set_type_codec(
            pg_type,
            encoder=json.dumps,
            decoder=json.loads,
            schema="pg_catalog",
        )
|
|
|
|
|
|
async def init_db() -> None:
    """Create the connection pool and run migrations. Call once at startup."""
    global _pool
    _pool = await asyncpg.create_pool(
        settings.aide_db_url,
        min_size=2,
        max_size=10,
        init=_init_connection,  # installs JSON codecs on each new connection
    )
    async with _pool.acquire() as conn:
        await _run_migrations(conn)
    # Print only the part after '@' — the prefix of the URL may hold credentials.
    print(f"[aide] Database ready: {settings.aide_db_url.split('@')[-1]}")
|
|
|
|
|
|
async def close_db() -> None:
    """Close the connection pool. Call at shutdown; safe to call when never initialised."""
    global _pool
    if _pool is None:
        return
    await _pool.close()
    _pool = None
|