""" mcp_client/store.py — CRUD for mcp_servers table (async). API keys and extra headers are encrypted at rest using the same AES-256-GCM helpers as the credentials table. """ from __future__ import annotations import json import uuid from datetime import datetime, timezone from typing import Any from ..database import _decrypt, _encrypt, _rowcount, get_pool def _now() -> str: return datetime.now(timezone.utc).isoformat() def _row_to_dict(row, include_secrets: bool = False) -> dict: d = dict(row) # Decrypt api_key if d.get("api_key_enc"): d["api_key"] = _decrypt(d["api_key_enc"]) if include_secrets else None d["has_api_key"] = True else: d["api_key"] = None d["has_api_key"] = False del d["api_key_enc"] # Decrypt headers JSON if d.get("headers_enc"): try: d["headers"] = json.loads(_decrypt(d["headers_enc"])) if include_secrets else None except Exception: d["headers"] = None d["has_headers"] = True else: d["headers"] = None d["has_headers"] = False del d["headers_enc"] # enabled is already Python bool from BOOLEAN column return d async def list_servers( include_secrets: bool = False, user_id: str | None = "GLOBAL", ) -> list[dict]: """ List MCP servers. - user_id="GLOBAL" (default): global servers (user_id IS NULL) - user_id=None: ALL servers (admin use) - user_id="": servers owned by that user """ pool = await get_pool() if user_id == "GLOBAL": rows = await pool.fetch( "SELECT * FROM mcp_servers WHERE user_id IS NULL ORDER BY name" ) elif user_id is None: rows = await pool.fetch("SELECT * FROM mcp_servers ORDER BY name") else: rows = await pool.fetch( "SELECT * FROM mcp_servers WHERE user_id = $1 ORDER BY name", user_id ) return [_row_to_dict(r, include_secrets) for r in rows] async def get_server(server_id: str, include_secrets: bool = False) -> dict | None: pool = await get_pool() row = await pool.fetchrow("SELECT * FROM mcp_servers WHERE id = $1", server_id) return _row_to_dict(row, include_secrets) if row else None async def create_server( name: str, url: str, transport: str = "sse", api_key: str = "", headers: dict | None = None, enabled: bool = True, user_id: str | None = None, ) -> dict: server_id = str(uuid.uuid4()) now = _now() api_key_enc = _encrypt(api_key) if api_key else None headers_enc = _encrypt(json.dumps(headers)) if headers else None pool = await get_pool() await pool.execute( """ INSERT INTO mcp_servers (id, name, url, transport, api_key_enc, headers_enc, enabled, user_id, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) """, server_id, name, url, transport, api_key_enc, headers_enc, enabled, user_id, now, now, ) return await get_server(server_id) async def update_server(server_id: str, **fields) -> dict | None: row = await get_server(server_id, include_secrets=True) if not row: return None now = _now() updates: dict[str, Any] = {} if "name" in fields: updates["name"] = fields["name"] if "url" in fields: updates["url"] = fields["url"] if "transport" in fields: updates["transport"] = fields["transport"] if "api_key" in fields: updates["api_key_enc"] = _encrypt(fields["api_key"]) if fields["api_key"] else None if "headers" in fields: updates["headers_enc"] = _encrypt(json.dumps(fields["headers"])) if fields["headers"] else None if "enabled" in fields: updates["enabled"] = fields["enabled"] if not updates: return row set_parts = [] values: list[Any] = [] for i, (k, v) in enumerate(updates.items(), start=1): set_parts.append(f"{k} = ${i}") values.append(v) n = len(updates) + 1 values.extend([now, server_id]) pool = await get_pool() await pool.execute( f"UPDATE mcp_servers SET {', '.join(set_parts)}, updated_at = ${n} WHERE id = ${n + 1}", *values, ) return await get_server(server_id) async def delete_server(server_id: str) -> bool: pool = await get_pool() status = await pool.execute("DELETE FROM mcp_servers WHERE id = $1", server_id) return _rowcount(status) > 0