"""
|
|
providers/openai_provider.py — Direct OpenAI provider.
|
|
|
|
Uses the official openai SDK pointing at api.openai.com (default base URL).
|
|
Tool schema conversion reuses the same Anthropic→OpenAI format translation
|
|
as the OpenRouter provider (they share the same wire format).
|
|
"""
|
|
from __future__ import annotations

import json
import logging
from typing import Any

from openai import OpenAI, AsyncOpenAI

from .base import AIProvider, ProviderResponse, ToolCallResult, UsageStats
# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)

# Model used when a caller passes an empty ``model`` argument.
DEFAULT_MODEL = "gpt-4o"

# Models that use max_completion_tokens instead of max_tokens, and don't support
# tool_choice="auto" (reasoning models use implicit tool choice).
_REASONING_MODELS = frozenset({"o1", "o1-mini", "o1-preview"})
def _convert_content_blocks(blocks: list[dict]) -> list[dict]:
|
|
"""Convert Anthropic-native content blocks to OpenAI image_url format."""
|
|
result = []
|
|
for block in blocks:
|
|
if block.get("type") == "image":
|
|
src = block.get("source", {})
|
|
if src.get("type") == "base64":
|
|
data_url = f"data:{src['media_type']};base64,{src['data']}"
|
|
result.append({"type": "image_url", "image_url": {"url": data_url}})
|
|
else:
|
|
result.append(block)
|
|
return result
|
|
|
|
|
|
class OpenAIProvider(AIProvider):
    """AIProvider implementation backed by the official ``openai`` SDK.

    Holds both a sync and an async client so ``chat`` and ``chat_async`` can
    share the same request-building (``_build_params``) and response-parsing
    (``_parse_response``) logic.
    """

    def __init__(self, api_key: str) -> None:
        """Create sync and async clients against the SDK's default base URL."""
        self._client = OpenAI(api_key=api_key)
        self._async_client = AsyncOpenAI(api_key=api_key)

    @property
    def name(self) -> str:
        """Human-readable provider name."""
        return "OpenAI"

    @property
    def default_model(self) -> str:
        """Model used when callers pass an empty ``model`` argument."""
        return DEFAULT_MODEL

    # ── Public interface ──────────────────────────────────────────────────────

    def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        system: str = "",
        model: str = "",
        max_tokens: int = 4096,
    ) -> ProviderResponse:
        """Send a synchronous chat-completion request.

        Any exception from the SDK is caught at this boundary and returned as
        a ProviderResponse with finish_reason="error" rather than raised.
        """
        params = self._build_params(messages, tools, system, model, max_tokens)
        try:
            response = self._client.chat.completions.create(**params)
            return self._parse_response(response)
        except Exception as e:
            logger.error(f"OpenAI chat error: {e}")
            return ProviderResponse(text=f"Error: {e}", finish_reason="error")

    async def chat_async(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        system: str = "",
        model: str = "",
        max_tokens: int = 4096,
    ) -> ProviderResponse:
        """Async twin of chat(); same request building and error contract."""
        params = self._build_params(messages, tools, system, model, max_tokens)
        try:
            response = await self._async_client.chat.completions.create(**params)
            return self._parse_response(response)
        except Exception as e:
            logger.error(f"OpenAI async chat error: {e}")
            return ProviderResponse(text=f"Error: {e}", finish_reason="error")

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _build_params(
        self,
        messages: list[dict],
        tools: list[dict] | None,
        system: str,
        model: str,
        max_tokens: int,
    ) -> dict:
        """Assemble the kwargs dict for ``chat.completions.create``.

        Reasoning models (see _REASONING_MODELS) take ``max_completion_tokens``
        instead of ``max_tokens`` and reject an explicit ``tool_choice``.
        """
        model = model or self.default_model
        openai_messages = self._convert_messages(messages, system, model)
        params: dict = {
            "model": model,
            "messages": openai_messages,
        }

        is_reasoning = model in _REASONING_MODELS
        if is_reasoning:
            params["max_completion_tokens"] = max_tokens
        else:
            params["max_tokens"] = max_tokens

        if tools:
            params["tools"] = [self._to_openai_tool(t) for t in tools]
            if not is_reasoning:
                params["tool_choice"] = "auto"

        return params

    def _convert_messages(self, messages: list[dict], system: str, model: str) -> list[dict]:
        """Convert aide's internal message list to OpenAI format."""
        result: list[dict] = []

        # Reasoning models (o1, o1-mini) don't support system role — use user role instead
        is_reasoning = model in _REASONING_MODELS
        if system:
            if is_reasoning:
                result.append({"role": "user", "content": f"[System instructions]\n{system}"})
            else:
                result.append({"role": "system", "content": system})

        # Manual index walk (not a for-loop): the "tool" branch consumes a run
        # of consecutive messages in a single iteration.
        i = 0
        while i < len(messages):
            msg = messages[i]
            role = msg["role"]

            if role == "system":
                i += 1
                continue  # Already prepended above

            if role == "assistant" and msg.get("tool_calls"):
                # Assistant turn that invoked tools: re-encode each call in
                # OpenAI's {"type": "function"} envelope; arguments must be a
                # JSON string on the OpenAI wire format.
                openai_tool_calls = []
                for tc in msg["tool_calls"]:
                    openai_tool_calls.append({
                        "id": tc["id"],
                        "type": "function",
                        "function": {
                            "name": tc["name"],
                            "arguments": json.dumps(tc["arguments"]),
                        },
                    })
                out: dict[str, Any] = {"role": "assistant", "tool_calls": openai_tool_calls}
                if msg.get("content"):
                    out["content"] = msg["content"]
                result.append(out)

            elif role == "tool":
                # Group consecutive tool results; collect image blocks for injection
                pending_images: list[dict] = []
                while i < len(messages) and messages[i]["role"] == "tool":
                    t = messages[i]
                    content = t.get("content", "")
                    if isinstance(content, list):
                        # OpenAI tool messages carry text only: flatten the text
                        # blocks and hold images for a follow-up user turn.
                        text = " ".join(b.get("text", "") for b in content if b.get("type") == "text") or "[image]"
                        pending_images.extend(b for b in content if b.get("type") == "image")
                        content = text
                    result.append({"role": "tool", "tool_call_id": t["tool_call_id"], "content": content})
                    i += 1
                if pending_images:
                    result.append({"role": "user", "content": _convert_content_blocks(pending_images)})
                continue  # i already advanced

            else:
                # Plain user/assistant text; convert Anthropic-style block lists.
                content = msg.get("content", "")
                if isinstance(content, list):
                    content = _convert_content_blocks(content)
                result.append({"role": role, "content": content})

            i += 1

        return result

    @staticmethod
    def _to_openai_tool(aide_tool: dict) -> dict:
        """Convert aide's Anthropic-native tool schema to OpenAI function-calling format."""
        return {
            "type": "function",
            "function": {
                "name": aide_tool["name"],
                "description": aide_tool.get("description", ""),
                # Anthropic's input_schema is already JSON Schema, which is what
                # OpenAI expects under "parameters".
                "parameters": aide_tool.get("input_schema", {"type": "object", "properties": {}}),
            },
        }

    def _parse_response(self, response) -> ProviderResponse:
        """Map an SDK ChatCompletion object onto a ProviderResponse.

        Returns an error-flavored ProviderResponse when the response carries
        no choices at all.
        """
        choice = response.choices[0] if response.choices else None
        if not choice:
            return ProviderResponse(text=None, finish_reason="error")

        message = choice.message
        text = message.content or None
        tool_calls: list[ToolCallResult] = []

        if message.tool_calls:
            for tc in message.tool_calls:
                try:
                    arguments = json.loads(tc.function.arguments)
                except json.JSONDecodeError:
                    # Keep malformed argument JSON visible to the caller
                    # instead of dropping the tool call.
                    arguments = {"_raw": tc.function.arguments}
                tool_calls.append(ToolCallResult(
                    id=tc.id,
                    name=tc.function.name,
                    arguments=arguments,
                ))

        usage = UsageStats()
        if response.usage:
            usage = UsageStats(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            )

        finish_reason = choice.finish_reason or "stop"
        if tool_calls:
            # Normalize to Anthropic-style "tool_use" so callers can branch
            # uniformly across providers.
            finish_reason = "tool_use"

        return ProviderResponse(
            text=text,
            tool_calls=tool_calls,
            usage=usage,
            finish_reason=finish_reason,
            model=response.model,
        )