110 lines
4.3 KiB
Python
110 lines
4.3 KiB
Python
"""
agent/tool_registry.py — Central tool registry.

Tools register themselves here. The agent loop asks the registry for
schemas (to send to the AI) and dispatches tool calls through it.
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import traceback
|
|
|
|
from ..tools.base import BaseTool, ToolResult
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ToolRegistry:
    """Central registry mapping tool names to tool instances.

    The agent loop asks the registry for tool schemas (to send to the AI)
    and dispatches tool calls through :meth:`dispatch`.
    """

    def __init__(self) -> None:
        # Registered tools keyed by ``tool.name``. Dicts keep insertion order,
        # so schema lists come out in registration order.
        self._tools: dict[str, BaseTool] = {}

    def register(self, tool: BaseTool) -> None:
        """Register a tool instance under ``tool.name``.

        Raises:
            ValueError: if a tool with the same name is already registered.
        """
        if tool.name in self._tools:
            raise ValueError(f"Tool '{tool.name}' is already registered")
        self._tools[tool.name] = tool
        logger.debug("Registered tool: %s", tool.name)

    def deregister(self, name: str) -> None:
        """Remove a tool by name. No-op if not registered."""
        if self._tools.pop(name, None) is None:
            # Nothing was registered under this name; don't log a removal.
            return
        logger.debug("Deregistered tool: %s", name)

    def get(self, name: str) -> BaseTool | None:
        """Return the tool registered under *name*, or ``None``."""
        return self._tools.get(name)

    def all_tools(self) -> list[BaseTool]:
        """Return a new list of every registered tool."""
        return list(self._tools.values())

    # ── Declaration matching ──────────────────────────────────────────────────

    @staticmethod
    def _is_server_wildcard(entry: str) -> bool:
        """Return True if *entry* is a server-level wildcard like ``"mcp__servername"``."""
        # "mcp__server__tool" names a single tool; "mcp__server" (exactly one
        # "__" separator, no third segment) stands for every tool on that server.
        return entry.startswith("mcp__") and entry.count("__") == 1

    @classmethod
    def _is_declared(cls, name: str, allowed_tools: list[str]) -> bool:
        """Return True if *name* is declared in *allowed_tools*.

        Uses the same rules as :meth:`get_schemas_for_task`: a wildcard entry
        covers every ``<entry>__*`` tool name, and any other entry must match
        *name* exactly.
        """
        for entry in allowed_tools:
            if cls._is_server_wildcard(entry):
                if name.startswith(entry + "__"):
                    return True
            elif entry == name:
                return True
        return False

    # ── Schema generation ─────────────────────────────────────────────────────

    def get_schemas(self) -> list[dict]:
        """All tool schemas — used for interactive sessions."""
        return [t.get_schema() for t in self._tools.values()]

    def get_schemas_for_task(self, allowed_tools: list[str]) -> list[dict]:
        """
        Filtered schemas for a scheduled task or agent.

        Only tools explicitly declared in ``allowed_tools`` are included.
        Supports server-level wildcards: "mcp__servername" includes all tools
        from that server. Tools whose ``allowed_in_scheduled_tasks`` is false
        are skipped whether they are named explicitly or matched by a
        wildcard. Unknown names and wildcards that match nothing are logged
        and ignored. Each tool appears at most once, in first-declared order.

        This only controls what the AI is offered. To also reject undeclared
        tools at call time, pass the same ``allowed_tools`` to :meth:`dispatch`.
        """
        schemas: list[dict] = []
        seen: set[str] = set()

        for name in allowed_tools:
            if self._is_server_wildcard(name):
                prefix = name + "__"
                matched = False
                for tool_name, tool in self._tools.items():
                    if not tool_name.startswith(prefix):
                        continue
                    matched = True
                    if tool_name in seen:
                        continue
                    # Apply the same policy as explicitly named tools; without
                    # this a wildcard would expose disallowed tools.
                    if not tool.allowed_in_scheduled_tasks:
                        logger.warning(
                            "Tool %r is not allowed in scheduled tasks — skipped",
                            tool_name,
                        )
                        continue
                    seen.add(tool_name)
                    schemas.append(tool.get_schema())
                if not matched:
                    logger.warning("Wildcard %r matched no registered tools", name)
                continue

            if name in seen:
                continue
            tool = self._tools.get(name)
            if tool is None:
                logger.warning("Requested unknown tool: %r", name)
                continue
            if not tool.allowed_in_scheduled_tasks:
                logger.warning(
                    "Tool %r is not allowed in scheduled tasks — skipped", name
                )
                continue
            seen.add(name)
            schemas.append(tool.get_schema())

        return schemas

    # ── Dispatch ──────────────────────────────────────────────────────────────

    async def dispatch(
        self,
        name: str,
        arguments: dict,
        task_id: str | None = None,
        allowed_tools: list[str] | None = None,
    ) -> ToolResult:
        """
        Execute a tool by name. Never raises into the agent loop —
        all ``Exception`` subclasses are caught and returned as
        ToolResult(success=False). ``BaseException`` subclasses that are not
        ``Exception`` (e.g. task cancellation) still propagate.

        Args:
            name: registered tool name.
            arguments: keyword arguments forwarded to ``tool.execute``.
            task_id: when truthy, the call is treated as part of a scheduled
                task, and tools whose ``allowed_in_scheduled_tasks`` is false
                are rejected.
            allowed_tools: optional declaration list in the same format as
                :meth:`get_schemas_for_task`. When given, a registered tool
                that is not declared there is rejected. ``None`` (the
                default) skips this check.
        """
        tool = self._tools.get(name)
        # Unknown and undeclared tools get the same answer so the caller
        # cannot tell which tools exist outside its declaration.
        if tool is None or (
            allowed_tools is not None and not self._is_declared(name, allowed_tools)
        ):
            msg = f"Tool '{name}' is not available in this context."
            logger.warning("Dispatch rejected: %s", msg)
            return ToolResult(success=False, error=msg)

        if task_id and not tool.allowed_in_scheduled_tasks:
            msg = f"Tool '{name}' is not allowed in scheduled tasks."
            logger.warning("Dispatch rejected: %s", msg)
            return ToolResult(success=False, error=msg)

        try:
            return await tool.execute(**arguments)
        except Exception:
            # logger.exception attaches exc_info to the record, so handlers
            # receive the exception itself rather than a pre-formatted string.
            logger.exception("Tool '%s' raised unexpectedly:", name)
            return ToolResult(
                success=False,
                error=f"Tool '{name}' encountered an internal error.",
            )