"""
|
|
providers/anthropic_provider.py — Anthropic Claude provider.
|
|
|
|
Uses the official `anthropic` Python SDK.
|
|
Tool schemas are already in Anthropic's native format, so no conversion needed.
|
|
Messages are converted from the OpenAI-style format used internally by aide.
|
|
"""
from __future__ import annotations
import json
import logging
import anthropic
from .base import AIProvider, ProviderResponse, ToolCallResult, UsageStats
logger = logging.getLogger(__name__)
DEFAULT_MODEL = "claude-sonnet-4-6"
class AnthropicProvider(AIProvider):
    """Provider backed by Anthropic's Messages API.

    Holds both a sync and an async SDK client so ``chat`` and ``chat_async``
    share one request-building (``_build_params``) and response-parsing
    (``_parse_response``) path.
    """

    def __init__(self, api_key: str) -> None:
        # Separate clients: the sync client serves chat(), the async client
        # serves chat_async().
        self._client = anthropic.Anthropic(api_key=api_key)
        self._async_client = anthropic.AsyncAnthropic(api_key=api_key)

    @property
    def name(self) -> str:
        """Human-readable provider name."""
        return "Anthropic"

    @property
    def default_model(self) -> str:
        """Model used when a request does not name one explicitly."""
        return DEFAULT_MODEL

    # ── Public interface ──────────────────────────────────────────────────────

    def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        system: str = "",
        model: str = "",
        max_tokens: int = 4096,
    ) -> ProviderResponse:
        """Send a conversation to Claude synchronously.

        Never raises: any conversion or SDK failure is logged and returned as
        a ``ProviderResponse`` with ``finish_reason="error"``.
        """
        try:
            # Build params inside the try so message-conversion errors are
            # reported the same way as API errors instead of propagating.
            params = self._build_params(messages, tools, system, model, max_tokens)
            response = self._client.messages.create(**params)
            return self._parse_response(response)
        except Exception as e:
            logger.error("Anthropic chat error: %s", e)
            return ProviderResponse(text=f"Error: {e}", finish_reason="error")

    async def chat_async(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        system: str = "",
        model: str = "",
        max_tokens: int = 4096,
    ) -> ProviderResponse:
        """Async counterpart of :meth:`chat`; same never-raise contract."""
        try:
            # Same ordering rationale as chat(): keep conversion errors
            # inside the error-response path.
            params = self._build_params(messages, tools, system, model, max_tokens)
            response = await self._async_client.messages.create(**params)
            return self._parse_response(response)
        except Exception as e:
            logger.error("Anthropic async chat error: %s", e)
            return ProviderResponse(text=f"Error: {e}", finish_reason="error")

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _build_params(
        self,
        messages: list[dict],
        tools: list[dict] | None,
        system: str,
        model: str,
        max_tokens: int,
    ) -> dict:
        """Assemble the keyword arguments for ``messages.create``.

        ``system`` and ``tools`` are only included when non-empty, since the
        API rejects/ignores empty values for them.
        """
        anthropic_messages = self._convert_messages(messages)
        params: dict = {
            "model": model or self.default_model,
            "messages": anthropic_messages,
            "max_tokens": max_tokens,
        }
        if system:
            params["system"] = system
        if tools:
            # aide tool schemas ARE Anthropic format — pass through directly
            params["tools"] = tools
            params["tool_choice"] = {"type": "auto"}
        return params

    def _convert_messages(self, messages: list[dict]) -> list[dict]:
        """Convert aide's internal message list to Anthropic format.

        aide uses an OpenAI-style internal format:
            {"role": "user", "content": "..."}
            {"role": "assistant", "content": "...", "tool_calls": [...]}
            {"role": "tool", "tool_call_id": "...", "content": "..."}

        Anthropic requires:
            - tool calls embedded in content blocks (tool_use type)
            - tool results as user messages with tool_result content blocks
        """
        result: list[dict] = []
        i = 0
        while i < len(messages):
            msg = messages[i]
            role = msg["role"]

            if role == "system":
                i += 1
                continue  # Already handled via system= param

            if role == "assistant" and msg.get("tool_calls"):
                # Convert assistant tool calls to Anthropic content blocks
                blocks: list[dict] = []
                if msg.get("content"):
                    blocks.append({"type": "text", "text": msg["content"]})
                for tc in msg["tool_calls"]:
                    blocks.append({
                        "type": "tool_use",
                        "id": tc["id"],
                        "name": tc["name"],
                        # OpenAI-style histories may carry arguments as a
                        # JSON string; Anthropic's "input" must be an object.
                        "input": self._coerce_arguments(tc["arguments"]),
                    })
                result.append({"role": "assistant", "content": blocks})

            elif role == "tool":
                # Group consecutive tool results into one user message —
                # Anthropic expects all results for a turn together.
                tool_results: list[dict] = []
                while i < len(messages) and messages[i]["role"] == "tool":
                    t = messages[i]
                    tool_results.append({
                        "type": "tool_result",
                        "tool_use_id": t["tool_call_id"],
                        "content": t["content"],
                    })
                    i += 1
                result.append({"role": "user", "content": tool_results})
                continue  # i already advanced

            else:
                # content may be a string (plain text) or a list of blocks (multimodal)
                result.append({"role": role, "content": msg.get("content", "")})

            i += 1

        return result

    @staticmethod
    def _coerce_arguments(arguments) -> dict:
        """Return tool-call arguments as a dict.

        Dicts pass through untouched (backward-compatible). JSON-encoded
        strings are decoded; undecodable strings degrade to ``{}`` with a
        warning rather than aborting the whole request.
        """
        if isinstance(arguments, str):
            try:
                return json.loads(arguments)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                logger.warning("Could not decode tool arguments: %r", arguments)
                return {}
        return arguments

    def _parse_response(self, response) -> ProviderResponse:
        """Flatten an SDK response into a ``ProviderResponse``.

        Concatenates all text blocks, collects every tool_use block, and
        normalizes the finish reason ("tool_use" whenever tool calls exist).
        """
        text = ""
        tool_calls: list[ToolCallResult] = []

        for block in response.content:
            if block.type == "text":
                text += block.text
            elif block.type == "tool_use":
                tool_calls.append(ToolCallResult(
                    id=block.id,
                    name=block.name,
                    arguments=block.input,
                ))

        usage = UsageStats(
            input_tokens=response.usage.input_tokens,
            output_tokens=response.usage.output_tokens,
        ) if response.usage else UsageStats()

        finish_reason = response.stop_reason or "stop"
        if tool_calls:
            finish_reason = "tool_use"

        return ProviderResponse(
            text=text or None,  # None (not "") when the reply was tool-only
            tool_calls=tool_calls,
            usage=usage,
            finish_reason=finish_reason,
            model=response.model,
        )