145 lines
4.4 KiB
Python
145 lines
4.4 KiB
Python
"""
mcp_client/store.py — async CRUD helpers for the mcp_servers table.

API keys and extra headers are encrypted at rest using the same
AES-256-GCM helpers as the credentials table.
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from ..database import _decrypt, _encrypt, _rowcount, get_pool
|
|
|
|
|
|
def _now() -> str:
|
|
return datetime.now(timezone.utc).isoformat()
|
|
|
|
|
|
def _row_to_dict(row, include_secrets: bool = False) -> dict:
|
|
d = dict(row)
|
|
# Decrypt api_key
|
|
if d.get("api_key_enc"):
|
|
d["api_key"] = _decrypt(d["api_key_enc"]) if include_secrets else None
|
|
d["has_api_key"] = True
|
|
else:
|
|
d["api_key"] = None
|
|
d["has_api_key"] = False
|
|
del d["api_key_enc"]
|
|
|
|
# Decrypt headers JSON
|
|
if d.get("headers_enc"):
|
|
try:
|
|
d["headers"] = json.loads(_decrypt(d["headers_enc"])) if include_secrets else None
|
|
except Exception:
|
|
d["headers"] = None
|
|
d["has_headers"] = True
|
|
else:
|
|
d["headers"] = None
|
|
d["has_headers"] = False
|
|
del d["headers_enc"]
|
|
|
|
# enabled is already Python bool from BOOLEAN column
|
|
return d
|
|
|
|
|
|
async def list_servers(
    include_secrets: bool = False,
    user_id: str | None = "GLOBAL",
) -> list[dict]:
    """
    Return MCP server records ordered by name.

    The scope is chosen by ``user_id``:

    - ``"GLOBAL"`` (default): only global servers (``user_id IS NULL``).
    - ``None``: every server regardless of owner (admin use).
    - any other string: servers owned by that user.

    Secrets are decrypted into the results only when ``include_secrets``
    is true.
    """
    pool = await get_pool()

    # Pick the query and its bind arguments first, then issue a single fetch.
    if user_id is None:
        query, args = "SELECT * FROM mcp_servers ORDER BY name", ()
    elif user_id == "GLOBAL":
        query, args = "SELECT * FROM mcp_servers WHERE user_id IS NULL ORDER BY name", ()
    else:
        query, args = "SELECT * FROM mcp_servers WHERE user_id = $1 ORDER BY name", (user_id,)

    records = await pool.fetch(query, *args)
    return [_row_to_dict(record, include_secrets) for record in records]
|
|
|
|
|
|
async def get_server(server_id: str, include_secrets: bool = False) -> dict | None:
    """Fetch a single MCP server by id, or return ``None`` if no row matches.

    Secrets are decrypted into the result only when ``include_secrets`` is true.
    """
    pool = await get_pool()
    record = await pool.fetchrow("SELECT * FROM mcp_servers WHERE id = $1", server_id)
    if not record:
        return None
    return _row_to_dict(record, include_secrets)
|
|
|
|
|
|
async def create_server(
    name: str,
    url: str,
    transport: str = "sse",
    api_key: str = "",
    headers: dict | None = None,
    enabled: bool = True,
    user_id: str | None = None,
) -> dict:
    """
    Insert a new MCP server row and return it as read back by ``get_server``.

    ``api_key`` and ``headers`` are encrypted before storage; empty values
    are stored as NULL. ``user_id=None`` stores a global server (the rows
    that ``list_servers`` returns by default). ``created_at`` and
    ``updated_at`` are both set to the same current UTC timestamp.
    """
    new_id = str(uuid.uuid4())
    timestamp = _now()
    encrypted_key = _encrypt(api_key) if api_key else None
    encrypted_headers = _encrypt(json.dumps(headers)) if headers else None

    pool = await get_pool()
    await pool.execute(
        """
        INSERT INTO mcp_servers
            (id, name, url, transport, api_key_enc, headers_enc, enabled, user_id, created_at, updated_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
        """,
        new_id,
        name,
        url,
        transport,
        encrypted_key,
        encrypted_headers,
        enabled,
        user_id,
        timestamp,
        timestamp,
    )
    return await get_server(new_id)
|
|
|
|
|
|
async def update_server(server_id: str, **fields) -> dict | None:
    """
    Apply a partial update to an MCP server and return the resulting record.

    Recognized keyword fields are ``name``, ``url``, ``transport``,
    ``api_key``, ``headers`` and ``enabled``; any other keys are ignored.
    ``api_key`` and ``headers`` are re-encrypted, and a falsy value clears
    them (stored as NULL). ``updated_at`` is refreshed whenever at least one
    recognized field is supplied.

    Returns:
        The server as returned by ``get_server`` (secrets not included), or
        ``None`` if no server with ``server_id`` exists.
    """
    # Existence check only, so don't decrypt secrets. This record is returned
    # unchanged when there is nothing to update, and it must not expose
    # plaintext api_key/headers, because the normal path below never does.
    current = await get_server(server_id)
    if not current:
        return None

    updates: dict[str, Any] = {
        col: fields[col] for col in ("name", "url", "transport") if col in fields
    }
    if "api_key" in fields:
        api_key = fields["api_key"]
        updates["api_key_enc"] = _encrypt(api_key) if api_key else None
    if "headers" in fields:
        headers = fields["headers"]
        updates["headers_enc"] = _encrypt(json.dumps(headers)) if headers else None
    if "enabled" in fields:
        updates["enabled"] = fields["enabled"]

    if not updates:
        return current

    # Column names come only from the fixed whitelist above (never from
    # caller-supplied keys), so interpolating them into the SQL is safe;
    # every value is passed as a bound parameter.
    set_parts = [f"{col} = ${i}" for i, col in enumerate(updates, start=1)]
    n = len(updates) + 1  # placeholder index for updated_at; id follows it
    values: list[Any] = [*updates.values(), _now(), server_id]

    pool = await get_pool()
    await pool.execute(
        f"UPDATE mcp_servers SET {', '.join(set_parts)}, updated_at = ${n} WHERE id = ${n + 1}",
        *values,
    )
    return await get_server(server_id)
|
|
|
|
|
|
async def delete_server(server_id: str) -> bool:
    """Delete the MCP server with ``server_id``.

    Returns True when ``_rowcount`` reports a positive count for the DELETE
    status, i.e. when a row was removed.
    """
    pool = await get_pool()
    result = await pool.execute("DELETE FROM mcp_servers WHERE id = $1", server_id)
    deleted = _rowcount(result)
    return deleted > 0
|