229 lines
8.4 KiB
Python
229 lines
8.4 KiB
Python
"""
|
|
mcp_client/manager.py — MCP tool discovery and per-call execution.
|
|
|
|
Uses per-call connections: each discover_tools() and call_tool() opens
|
|
a fresh connection, does its work, and closes. Simpler than persistent
|
|
sessions and perfectly adequate for a personal agent.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from ..agent.tool_registry import ToolRegistry
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def _open_session(url: str, transport: str, headers: dict):
|
|
"""Async context manager that yields an initialized MCP ClientSession."""
|
|
from mcp import ClientSession
|
|
from mcp.client.sse import sse_client
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
|
|
if transport == "streamable_http":
|
|
async with streamablehttp_client(url, headers=headers) as (read, write, _):
|
|
async with ClientSession(read, write) as session:
|
|
await session.initialize()
|
|
yield session
|
|
else: # default: sse
|
|
async with sse_client(url, headers=headers) as (read, write):
|
|
async with ClientSession(read, write) as session:
|
|
await session.initialize()
|
|
yield session
|
|
|
|
|
|
async def discover_tools(server: dict) -> list[dict]:
    """List the tools exposed by a single MCP server.

    A fresh connection is opened for this call and closed before returning.

    Args:
        server: server config dict. ``url`` is required; ``transport``
            defaults to ``"sse"``; ``name`` is used in the failure log line.
            Auth/extra headers come from ``_build_headers(server)``.

    Returns:
        Descriptor dicts with keys ``tool_name``, ``description`` and
        ``input_schema``, or ``[]`` if importing the client, connecting,
        initializing or listing raises (the error is logged as a warning).
    """
    url = server["url"]
    transport = server.get("transport", "sse")
    headers = _build_headers(server)

    try:
        from mcp import ClientSession
        from mcp.client.sse import sse_client
        from mcp.client.streamable_http import streamablehttp_client

        async def _list_via(read, write) -> list[dict]:
            # Runs inside whichever transport context was opened below.
            async with ClientSession(read, write) as session:
                await session.initialize()
                listing = await session.list_tools()
                return _parse_tools(listing.tools)

        if transport != "streamable_http":
            async with sse_client(url, headers=headers) as (read, write):
                return await _list_via(read, write)
        async with streamablehttp_client(url, headers=headers) as (read, write, _):
            return await _list_via(read, write)
    except Exception as exc:
        logger.warning("[mcp-client] discover_tools failed for %s (%s): %s", server["name"], url, exc)
        return []
|
|
|
|
|
|
async def call_tool(server: dict, tool_name: str, arguments: dict) -> dict:
    """Invoke one tool on an MCP server over a fresh, short-lived connection.

    Args:
        server: server config dict (``url``, optional ``transport`` defaulting
            to ``"sse"``, ``name`` for logging, plus header/auth fields read by
            ``_build_headers``).
        tool_name: name of the tool as reported by the server.
        arguments: arguments forwarded verbatim to the tool.

    Returns:
        A ``ToolResult`` (success, data, error). The text parts of the tool's
        content are joined with newlines; if the server flags the result as
        an error, that text becomes the error message. Any exception while
        connecting or calling yields ``success=False`` with an
        ``"MCP call failed: ..."`` message.
    """
    from ..tools.base import ToolResult

    url = server["url"]
    transport = server.get("transport", "sse")
    headers = _build_headers(server)

    try:
        from mcp import ClientSession
        from mcp.client.sse import sse_client
        from mcp.client.streamable_http import streamablehttp_client

        async def _invoke(read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                return await session.call_tool(tool_name, arguments)

        if transport == "streamable_http":
            async with streamablehttp_client(url, headers=headers) as (read, write, _):
                outcome = await _invoke(read, write)
        else:
            async with sse_client(url, headers=headers) as (read, write):
                outcome = await _invoke(read, write)

        # Only content items carrying a ``text`` attribute contribute.
        text_parts = [part.text for part in outcome.content if hasattr(part, "text")]
        text = "\n".join(text_parts)

        if not outcome.isError:
            return ToolResult(success=True, data=text)
        return ToolResult(success=False, error=text or "MCP tool returned an error")
    except Exception as exc:
        logger.error("[mcp-client] call_tool failed: %s.%s: %s", server["name"], tool_name, exc)
        return ToolResult(success=False, error=f"MCP call failed: {exc}")
|
|
|
|
|
|
def _build_headers(server: dict) -> dict:
|
|
headers = {}
|
|
if server.get("api_key"):
|
|
headers["Authorization"] = f"Bearer {server['api_key']}"
|
|
if server.get("headers"):
|
|
headers.update(server["headers"])
|
|
return headers
|
|
|
|
|
|
def _parse_tools(tools) -> list[dict]:
|
|
result = []
|
|
for t in tools:
|
|
schema = t.inputSchema if hasattr(t, "inputSchema") else {}
|
|
if not isinstance(schema, dict):
|
|
schema = {}
|
|
result.append({
|
|
"tool_name": t.name,
|
|
"description": t.description or "",
|
|
"input_schema": schema,
|
|
})
|
|
return result
|
|
|
|
|
|
async def discover_and_register_mcp_tools(registry: ToolRegistry) -> None:
    """Register proxy tools for every enabled global MCP server.

    Called from lifespan() after build_registry(). Loads the global server
    configs (``user_id="GLOBAL"``, i.e. user_id IS NULL) and discovers each
    enabled server's tools. The results are registered as McpProxyTool
    instances via ``_register_server_tools``.

    A server whose discovery fails contributes no tools, because
    ``discover_tools`` returns ``[]``. Tools whose names collide with already
    registered ones are skipped by ``_register_server_tools``.
    """
    from .store import list_servers

    servers = await list_servers(include_secrets=True, user_id="GLOBAL")
    for server in servers:
        if not server["enabled"]:
            continue
        tools = await discover_tools(server)

        # Colliding names are skipped during registration, so len(tools) can
        # overstate what was added. Log the real change in registry size.
        before = len(registry._tools)
        _register_server_tools(registry, server, tools)
        registered = len(registry._tools) - before

        logger.info(
            "[mcp-client] Registered %d tools from '%s'", registered, server["name"]
        )
|
|
|
|
|
|
async def discover_user_mcp_tools(user_id: str) -> list:
    """Build proxy tools for one user's personal MCP servers.

    The McpProxyTool instances returned here are NOT added to the global
    registry. They are meant to be passed to agent.run() as extra_tools for
    the duration of the session.

    Args:
        user_id: owner of the MCP server configs to load.

    Returns:
        One McpProxyTool per tool discovered on each enabled server, in
        server order. A server whose discovery fails contributes nothing,
        since ``discover_tools`` returns ``[]`` on error.
    """
    from .store import list_servers
    from ..tools.mcp_proxy_tool import McpProxyTool

    servers = await list_servers(include_secrets=True, user_id=user_id)
    enabled_servers = (srv for srv in servers if srv["enabled"])

    user_tools: list = []
    for srv in enabled_servers:
        descriptors = await discover_tools(srv)
        user_tools.extend(
            McpProxyTool(
                server_id=srv["id"],
                server_name=srv["name"],
                server=srv,
                tool_name=d["tool_name"],
                description=d["description"],
                input_schema=d["input_schema"],
            )
            for d in descriptors
        )

    if user_tools:
        logger.info(
            "[mcp-client] Discovered %d user MCP tools for user_id=%s",
            len(user_tools), user_id,
        )
    return user_tools
|
|
|
|
|
|
# Strong references to in-flight reload tasks. The event loop keeps only weak
# references to tasks, so a task nobody references can be garbage-collected
# before it finishes (see the asyncio.create_task documentation).
_reload_tasks: set = set()


def _on_reload_done(task) -> None:
    """Drop a finished reload task and log any exception it raised."""
    _reload_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("[mcp-client] MCP tool reload failed: %s", exc, exc_info=exc)


def reload_server_tools(registry: ToolRegistry, server_id: str | None = None) -> None:
    """Schedule an asynchronous rebuild of the global MCP proxy tools.

    Called after adding/updating/deleting an MCP server config. It cannot
    await, so it schedules ``_reload_async`` as a task on the event loop
    running in the current thread and returns immediately.

    Args:
        registry: tool registry whose MCP proxy tools are rebuilt.
        server_id: forwarded to ``_reload_async``.

    If no event loop is running in the calling thread, nothing is scheduled
    and a warning is logged. NOTE(review): a sync route handler executed in a
    worker thread would hit this path. Confirm where this is called from.
    """
    import asyncio

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Previously swallowed silently; make the skipped reload visible.
        logger.warning(
            "[mcp-client] No running event loop in this thread; MCP tool reload not scheduled"
        )
        return

    task = loop.create_task(_reload_async(registry, server_id))
    # Keep the task alive until done, and surface its exceptions instead of
    # leaving them unretrieved.
    _reload_tasks.add(task)
    task.add_done_callback(_on_reload_done)
|
|
|
|
|
|
async def _reload_async(registry: ToolRegistry, server_id: str | None) -> None:
    """Rebuild the registry's global MCP proxy tools from current server configs.

    Every tool whose registry name starts with ``"mcp__"`` is removed. Tools
    from every enabled global server (user_id IS NULL) are then registered
    again.

    Args:
        registry: tool registry to update in place.
        server_id: accepted for interface compatibility but currently unused;
            every enabled global server is rediscovered.
    """
    from .store import list_servers

    # Do all awaited I/O first, while the previous tools stay registered.
    # This way the registry is never empty of MCP tools during a slow
    # discovery, and an exception from list_servers() leaves the old
    # registration intact.
    servers = await list_servers(include_secrets=True, user_id="GLOBAL")
    discovered: list[tuple[dict, list[dict]]] = []
    for server in servers:
        if not server["enabled"]:
            continue
        discovered.append((server, await discover_tools(server)))

    # Swap with no await between removal and re-registration, so other
    # coroutines on this loop never observe a half-rebuilt registry.
    stale = [name for name in registry._tools if name.startswith("mcp__")]
    for name in stale:
        registry.deregister(name)

    for server, tools in discovered:
        _register_server_tools(registry, server, tools)
        logger.info("[mcp-client] Reloaded %d tools from '%s'", len(tools), server["name"])
|
|
|
|
|
|
def _register_server_tools(registry: ToolRegistry, server: dict, tools: list[dict]) -> None:
    """Wrap each tool descriptor from *server* in an McpProxyTool and register it.

    Args:
        registry: tool registry to add proxies to.
        server: server config dict (``id`` and ``name`` are read; the whole
            dict is handed to each proxy).
        tools: descriptor dicts with ``tool_name``, ``description`` and
            ``input_schema`` keys.

    A proxy whose ``name`` is already present in the registry is not
    registered; a collision warning is logged and the existing entry is kept.
    """
    from ..tools.mcp_proxy_tool import McpProxyTool

    for descriptor in tools:
        proxy = McpProxyTool(
            server_id=server["id"],
            server_name=server["name"],
            server=server,
            tool_name=descriptor["tool_name"],
            description=descriptor["description"],
            input_schema=descriptor["input_schema"],
        )
        if proxy.name in registry._tools:
            logger.warning("[mcp-client] Tool name collision, skipping: %s", proxy.name)
            continue
        registry.register(proxy)
|