Initial commit
This commit is contained in:
0
server/mcp_client/__init__.py
Normal file
0
server/mcp_client/__init__.py
Normal file
BIN
server/mcp_client/__pycache__/__init__.cpython-314.pyc
Normal file
BIN
server/mcp_client/__pycache__/__init__.cpython-314.pyc
Normal file
Binary file not shown.
BIN
server/mcp_client/__pycache__/store.cpython-314.pyc
Normal file
BIN
server/mcp_client/__pycache__/store.cpython-314.pyc
Normal file
Binary file not shown.
228
server/mcp_client/manager.py
Normal file
228
server/mcp_client/manager.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
mcp_client/manager.py — MCP tool discovery and per-call execution.
|
||||
|
||||
Uses per-call connections: each discover_tools() and call_tool() opens
|
||||
a fresh connection, does its work, and closes. Simpler than persistent
|
||||
sessions and perfectly adequate for a personal agent.
|
||||
"""
|
||||
from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..agent.tool_registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _open_session(url: str, transport: str, headers: dict):
|
||||
"""Async context manager that yields an initialized MCP ClientSession."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
if transport == "streamable_http":
|
||||
async with streamablehttp_client(url, headers=headers) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
else: # default: sse
|
||||
async with sse_client(url, headers=headers) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
|
||||
async def discover_tools(server: dict) -> list[dict]:
    """
    Connect to an MCP server and enumerate its tools.

    Opens a fresh per-call connection, issues list_tools(), and returns
    tool-descriptor dicts of the form {tool_name, description, input_schema}.
    Any failure (unreachable server, protocol error, missing `mcp` package)
    is logged at WARNING level and yields [].
    """
    url = server["url"]
    transport = server.get("transport", "sse")
    headers = _build_headers(server)
    try:
        # Imported inside the try so a missing `mcp` package degrades to [].
        from mcp import ClientSession
        from mcp.client.sse import sse_client
        from mcp.client.streamable_http import streamablehttp_client

        if transport == "streamable_http":
            async with streamablehttp_client(url, headers=headers) as (read, write, _):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    listing = await session.list_tools()
                    return _parse_tools(listing.tools)
        else:
            async with sse_client(url, headers=headers) as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    listing = await session.list_tools()
                    return _parse_tools(listing.tools)
    except Exception as exc:
        logger.warning("[mcp-client] discover_tools failed for %s (%s): %s", server["name"], url, exc)
        return []
|
||||
|
||||
|
||||
async def call_tool(server: dict, tool_name: str, arguments: dict) -> dict:
    """
    Invoke a single MCP tool over a fresh connection.

    Returns a ToolResult with {success, data, error}; `data` carries the
    concatenated text content of the MCP response.  Failures (connection
    errors, tool-level isError flags) come back as a failed ToolResult
    rather than an exception.
    """
    from ..tools.base import ToolResult

    url = server["url"]
    transport = server.get("transport", "sse")
    headers = _build_headers(server)
    try:
        from mcp import ClientSession
        from mcp.client.sse import sse_client
        from mcp.client.streamable_http import streamablehttp_client

        if transport == "streamable_http":
            async with streamablehttp_client(url, headers=headers) as (read, write, _):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    outcome = await session.call_tool(tool_name, arguments)
        else:
            async with sse_client(url, headers=headers) as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    outcome = await session.call_tool(tool_name, arguments)

        # Flatten every text-bearing content part into a single string.
        pieces = [part.text for part in outcome.content if hasattr(part, "text")]
        text = "\n".join(pieces)
        if outcome.isError:
            return ToolResult(success=False, error=text or "MCP tool returned an error")
        return ToolResult(success=True, data=text)
    except Exception as e:
        logger.error("[mcp-client] call_tool failed: %s.%s: %s", server["name"], tool_name, e)
        return ToolResult(success=False, error=f"MCP call failed: {e}")
|
||||
|
||||
|
||||
def _build_headers(server: dict) -> dict:
|
||||
headers = {}
|
||||
if server.get("api_key"):
|
||||
headers["Authorization"] = f"Bearer {server['api_key']}"
|
||||
if server.get("headers"):
|
||||
headers.update(server["headers"])
|
||||
return headers
|
||||
|
||||
|
||||
def _parse_tools(tools) -> list[dict]:
|
||||
result = []
|
||||
for t in tools:
|
||||
schema = t.inputSchema if hasattr(t, "inputSchema") else {}
|
||||
if not isinstance(schema, dict):
|
||||
schema = {}
|
||||
result.append({
|
||||
"tool_name": t.name,
|
||||
"description": t.description or "",
|
||||
"input_schema": schema,
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
async def discover_and_register_mcp_tools(registry: ToolRegistry) -> None:
    """
    Discover and register global MCP tools at startup.

    Called from lifespan() after build_registry().  Discovers tools from all
    enabled global MCP servers (user_id IS NULL) and registers McpProxyTool
    instances into the registry via _register_server_tools().

    Fix: dropped the unused local `McpProxyTool` import — proxy construction
    happens inside _register_server_tools(), not here.
    """
    from .store import list_servers

    servers = await list_servers(include_secrets=True, user_id="GLOBAL")
    for server in servers:
        if not server["enabled"]:
            continue
        tools = await discover_tools(server)
        _register_server_tools(registry, server, tools)
        logger.info(
            "[mcp-client] Registered %d tools from '%s'", len(tools), server["name"]
        )
|
||||
|
||||
|
||||
async def discover_user_mcp_tools(user_id: str) -> list:
    """
    Discover MCP tools from one user's personal MCP servers.

    Returns McpProxyTool instances without touching the global registry;
    callers pass them as extra_tools to agent.run() for the session's
    lifetime.  Disabled servers are skipped.
    """
    from .store import list_servers
    from ..tools.mcp_proxy_tool import McpProxyTool

    collected: list = []
    for srv in await list_servers(include_secrets=True, user_id=user_id):
        if not srv["enabled"]:
            continue
        for descriptor in await discover_tools(srv):
            collected.append(
                McpProxyTool(
                    server_id=srv["id"],
                    server_name=srv["name"],
                    server=srv,
                    tool_name=descriptor["tool_name"],
                    description=descriptor["description"],
                    input_schema=descriptor["input_schema"],
                )
            )
    if collected:
        logger.info(
            "[mcp-client] Discovered %d user MCP tools for user_id=%s",
            len(collected), user_id,
        )
    return collected
|
||||
|
||||
|
||||
def reload_server_tools(registry: ToolRegistry, server_id: str | None = None) -> None:
    """
    Schedule an async refresh of MCP proxy tools after a config change.

    Route handlers are synchronous, so we cannot await discovery here;
    the refresh runs as a fire-and-forget task on the running event loop.
    With no running loop (startup context) this is a deliberate no-op.
    """
    import asyncio

    try:
        asyncio.get_running_loop().create_task(_reload_async(registry, server_id))
    except RuntimeError:
        # No running event loop — startup context; initial discovery covers it.
        pass
|
||||
|
||||
|
||||
async def _reload_async(registry: ToolRegistry, server_id: str | None) -> None:
    """
    Drop every registered MCP proxy tool and re-discover from all enabled
    global servers (user_id IS NULL).

    NOTE(review): `server_id` is accepted but currently unused — a full
    reload is always performed, even when only one server changed.

    Fix: removed the unused local imports `get_server` and `McpProxyTool`
    (registration goes through _register_server_tools()).
    """
    from .store import list_servers

    # Remove existing MCP proxy tools; all proxies share the "mcp__" prefix.
    for name in list(registry._tools.keys()):
        if name.startswith("mcp__"):
            registry.deregister(name)

    # Re-register all enabled global servers.
    servers = await list_servers(include_secrets=True, user_id="GLOBAL")
    for server in servers:
        if not server["enabled"]:
            continue
        tools = await discover_tools(server)
        _register_server_tools(registry, server, tools)
        logger.info("[mcp-client] Reloaded %d tools from '%s'", len(tools), server["name"])
|
||||
|
||||
|
||||
def _register_server_tools(registry: ToolRegistry, server: dict, tools: list[dict]) -> None:
    """Wrap each discovered tool descriptor in an McpProxyTool and register it.

    Descriptors whose generated proxy name already exists in the registry
    are skipped with a warning instead of overwriting the existing tool.
    """
    from ..tools.mcp_proxy_tool import McpProxyTool

    for descriptor in tools:
        proxy = McpProxyTool(
            server_id=server["id"],
            server_name=server["name"],
            server=server,
            tool_name=descriptor["tool_name"],
            description=descriptor["description"],
            input_schema=descriptor["input_schema"],
        )
        if proxy.name in registry._tools:
            logger.warning("[mcp-client] Tool name collision, skipping: %s", proxy.name)
            continue
        registry.register(proxy)
|
||||
144
server/mcp_client/store.py
Normal file
144
server/mcp_client/store.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
mcp_client/store.py — CRUD for mcp_servers table (async).
|
||||
|
||||
API keys and extra headers are encrypted at rest using the same
|
||||
AES-256-GCM helpers as the credentials table.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..database import _decrypt, _encrypt, _rowcount, get_pool
|
||||
|
||||
|
||||
def _now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _row_to_dict(row, include_secrets: bool = False) -> dict:
|
||||
d = dict(row)
|
||||
# Decrypt api_key
|
||||
if d.get("api_key_enc"):
|
||||
d["api_key"] = _decrypt(d["api_key_enc"]) if include_secrets else None
|
||||
d["has_api_key"] = True
|
||||
else:
|
||||
d["api_key"] = None
|
||||
d["has_api_key"] = False
|
||||
del d["api_key_enc"]
|
||||
|
||||
# Decrypt headers JSON
|
||||
if d.get("headers_enc"):
|
||||
try:
|
||||
d["headers"] = json.loads(_decrypt(d["headers_enc"])) if include_secrets else None
|
||||
except Exception:
|
||||
d["headers"] = None
|
||||
d["has_headers"] = True
|
||||
else:
|
||||
d["headers"] = None
|
||||
d["has_headers"] = False
|
||||
del d["headers_enc"]
|
||||
|
||||
# enabled is already Python bool from BOOLEAN column
|
||||
return d
|
||||
|
||||
|
||||
async def list_servers(
    include_secrets: bool = False,
    user_id: str | None = "GLOBAL",
) -> list[dict]:
    """
    List MCP servers, scoped by user_id:

    - "GLOBAL" (default): global servers only (user_id IS NULL)
    - None: every server regardless of owner (admin use)
    - "<uuid>": servers owned by that user
    """
    pool = await get_pool()
    if user_id == "GLOBAL":
        query = "SELECT * FROM mcp_servers WHERE user_id IS NULL ORDER BY name"
        args: tuple = ()
    elif user_id is None:
        query = "SELECT * FROM mcp_servers ORDER BY name"
        args = ()
    else:
        query = "SELECT * FROM mcp_servers WHERE user_id = $1 ORDER BY name"
        args = (user_id,)
    rows = await pool.fetch(query, *args)
    return [_row_to_dict(row, include_secrets) for row in rows]
|
||||
|
||||
|
||||
async def get_server(server_id: str, include_secrets: bool = False) -> dict | None:
    """Fetch a single MCP server by id; returns None when not found."""
    pool = await get_pool()
    row = await pool.fetchrow("SELECT * FROM mcp_servers WHERE id = $1", server_id)
    if not row:
        return None
    return _row_to_dict(row, include_secrets)
|
||||
|
||||
|
||||
async def create_server(
    name: str,
    url: str,
    transport: str = "sse",
    api_key: str = "",
    headers: dict | None = None,
    enabled: bool = True,
    user_id: str | None = None,
) -> dict:
    """
    Insert a new MCP server row and return it (without secrets).

    api_key and headers are encrypted at rest; empty/absent values are
    stored as NULL.  user_id=None creates a global server.
    """
    server_id = str(uuid.uuid4())
    created = _now()
    encrypted_key = _encrypt(api_key) if api_key else None
    encrypted_headers = _encrypt(json.dumps(headers)) if headers else None

    pool = await get_pool()
    await pool.execute(
        """
        INSERT INTO mcp_servers
        (id, name, url, transport, api_key_enc, headers_enc, enabled, user_id, created_at, updated_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
        """,
        server_id, name, url, transport, encrypted_key, encrypted_headers,
        enabled, user_id, created, created,
    )
    return await get_server(server_id)
|
||||
|
||||
|
||||
async def update_server(server_id: str, **fields) -> dict | None:
    """
    Partially update an MCP server; returns the updated row, or None when
    the id is unknown.

    Only recognized keys in `fields` are applied; api_key / headers are
    re-encrypted (or cleared to NULL when falsy).  With no recognized
    fields the existing row (secrets included) is returned unchanged.
    """
    current = await get_server(server_id, include_secrets=True)
    if not current:
        return None
    now = _now()

    updates: dict[str, Any] = {}
    # Plain columns copy straight through.
    for plain in ("name", "url", "transport"):
        if plain in fields:
            updates[plain] = fields[plain]
    # Secret columns are encrypted before storage.
    if "api_key" in fields:
        updates["api_key_enc"] = _encrypt(fields["api_key"]) if fields["api_key"] else None
    if "headers" in fields:
        updates["headers_enc"] = _encrypt(json.dumps(fields["headers"])) if fields["headers"] else None
    if "enabled" in fields:
        updates["enabled"] = fields["enabled"]
    if not updates:
        return current

    # Build "col = $N" assignments; updated_at and id take the last two slots.
    set_parts = [f"{col} = ${i}" for i, col in enumerate(updates, start=1)]
    values: list[Any] = list(updates.values())
    n = len(updates) + 1
    values.extend([now, server_id])

    pool = await get_pool()
    await pool.execute(
        f"UPDATE mcp_servers SET {', '.join(set_parts)}, updated_at = ${n} WHERE id = ${n + 1}",
        *values,
    )
    return await get_server(server_id)
|
||||
|
||||
|
||||
async def delete_server(server_id: str) -> bool:
    """Delete an MCP server row; True when a row was actually removed."""
    pool = await get_pool()
    status = await pool.execute("DELETE FROM mcp_servers WHERE id = $1", server_id)
    deleted = _rowcount(status)
    return deleted > 0
|
||||
Reference in New Issue
Block a user