"""
AI Client for oAI.
This module provides a high-level client for interacting with AI models
through the provider abstraction layer.
"""
import asyncio
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING, OLLAMA_DEFAULT_URL
from oai.providers.base import (
    AIProvider,
    ChatMessage,
    ChatResponse,
    ModelInfo,
    StreamChunk,
    ToolCall,
    UsageStats,
)
from oai.utils.logging import get_logger


class AIClient:
"""
High-level AI client for chat interactions.
Provides a simplified interface for sending chat requests,
handling streaming, and managing tool calls.
Attributes:
provider: The underlying AI provider
provider_name: Name of the current provider
default_model: Default model ID to use
http_headers: Custom HTTP headers for requests
"""
def __init__(
self,
provider_name: str = "openrouter",
provider_api_keys: Optional[Dict[str, str]] = None,
ollama_base_url: str = OLLAMA_DEFAULT_URL,
app_name: str = APP_NAME,
app_url: str = APP_URL,
):
"""
Initialize the AI client with specified provider.
Args:
provider_name: Provider to use ("openrouter", "anthropic", "openai", "ollama")
provider_api_keys: Dict mapping provider names to API keys
ollama_base_url: Base URL for Ollama server
app_name: Application name for headers
app_url: Application URL for headers
Raises:
ValueError: If provider is invalid or not configured
"""
from oai.providers.registry import get_provider_class
self.provider_name = provider_name
self.provider_api_keys = provider_api_keys or {}
self.ollama_base_url = ollama_base_url
self.app_name = app_name
self.app_url = app_url
# Get provider class
provider_class = get_provider_class(provider_name)
if not provider_class:
raise ValueError(f"Unknown provider: {provider_name}")
# Get API key for this provider
api_key = self.provider_api_keys.get(provider_name, "")
# Initialize provider with appropriate parameters
if provider_name == "ollama":
self.provider: AIProvider = provider_class(
api_key=api_key,
base_url=ollama_base_url,
)
else:
self.provider: AIProvider = provider_class(
api_key=api_key,
app_name=app_name,
app_url=app_url,
)
self.default_model: Optional[str] = None
self.logger = get_logger()
self.logger.info(f"Initialized {provider_name} provider")
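
    # Example construction (an illustrative sketch; the key and model ID below
    # are placeholders, not values defined anywhere in this module):
    #
    #     client = AIClient(
    #         provider_name="openrouter",
    #         provider_api_keys={"openrouter": "sk-or-..."},  # placeholder key
    #     )
    #     client.set_default_model("vendor/model-name")  # hypothetical model ID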

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Get available models.

        Args:
            filter_text_only: Whether to exclude video-only models

        Returns:
            List of ModelInfo objects
        """
        return self.provider.list_models(filter_text_only=filter_text_only)

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: Model identifier

        Returns:
            ModelInfo or None if not found
        """
        return self.provider.get_model(model_id)

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for provider-specific fields.

        Args:
            model_id: Model identifier

        Returns:
            Raw model dictionary or None
        """
        if hasattr(self.provider, "get_raw_model"):
            return self.provider.get_raw_model(model_id)
        return None
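
    # Example model discovery (illustrative sketch; it assumes ModelInfo
    # exposes an ``id`` attribute, which is defined in oai.providers.base and
    # not shown here):
    #
    #     models = client.list_models()
    #     print(f"{len(models)} models available")
    #     first = client.get_model(models[0].id) if models else None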

    def chat(
        self,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        system_prompt: Optional[str] = None,
        online: bool = False,
        transforms: Optional[List[str]] = None,
        enable_web_search: bool = False,
        web_search_config: Optional[Dict[str, Any]] = None,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat request.

        Args:
            messages: List of message dictionaries
            model: Model ID (uses default if not specified)
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: Tool definitions for function calling
            tool_choice: Tool selection mode
            system_prompt: System prompt to prepend
            online: Whether to enable online mode
            transforms: List of transforms (e.g., ["middle-out"])
            enable_web_search: Enable native web search (Anthropic only)
            web_search_config: Web search configuration (Anthropic only)

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming

        Raises:
            ValueError: If no model specified and no default set
        """
        model_id = model or self.default_model
        if not model_id:
            raise ValueError("No model specified and no default set")

        # Apply online mode suffix
        if online and hasattr(self.provider, "get_effective_model_id"):
            model_id = self.provider.get_effective_model_id(model_id, True)

        # Convert dict messages to ChatMessage objects
        chat_messages = []

        # Add system prompt if provided
        if system_prompt:
            chat_messages.append(ChatMessage(role="system", content=system_prompt))

        # Convert message dicts
        for msg in messages:
            # Convert tool_calls dicts to ToolCall objects if present
            tool_calls_data = msg.get("tool_calls")
            tool_calls_obj = None
            if tool_calls_data:
                from oai.providers.base import ToolCall, ToolFunction

                tool_calls_obj = []
                for tc in tool_calls_data:
                    # Handle both ToolCall objects and dicts
                    if isinstance(tc, ToolCall):
                        tool_calls_obj.append(tc)
                    elif isinstance(tc, dict):
                        func_data = tc.get("function", {})
                        tool_calls_obj.append(
                            ToolCall(
                                id=tc.get("id", ""),
                                type=tc.get("type", "function"),
                                function=ToolFunction(
                                    name=func_data.get("name", ""),
                                    arguments=func_data.get("arguments", "{}"),
                                ),
                            )
                        )

            chat_messages.append(
                ChatMessage(
                    role=msg.get("role", "user"),
                    content=msg.get("content"),
                    tool_calls=tool_calls_obj,
                    tool_call_id=msg.get("tool_call_id"),
                )
            )

        self.logger.debug(
            f"Sending chat request: model={model_id}, "
            f"messages={len(chat_messages)}, stream={stream}"
        )

        return self.provider.chat(
            model=model_id,
            messages=chat_messages,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            transforms=transforms,
            enable_web_search=enable_web_search,
            web_search_config=web_search_config or {},
        )
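
    # Example non-streaming request (illustrative sketch; ``content`` and
    # ``usage`` are ChatResponse fields already used elsewhere in this module,
    # while the model ID is a placeholder):
    #
    #     reply = client.chat(
    #         messages=[{"role": "user", "content": "Summarize RFC 2616 in one line."}],
    #         model="vendor/model-name",  # hypothetical model ID
    #         max_tokens=128,
    #         temperature=0.2,
    #     )
    #     print(reply.content)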

    def chat_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]],
        model: Optional[str] = None,
        max_loops: int = 5,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        on_tool_call: Optional[Callable[[ToolCall], None]] = None,
        on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None,
    ) -> ChatResponse:
        """
        Send a chat request with automatic tool call handling.

        Executes tool calls returned by the model and continues
        the conversation until no more tool calls are requested.

        Args:
            messages: Initial messages
            tools: Tool definitions
            tool_executor: Function to execute tool calls
            model: Model ID
            max_loops: Maximum tool call iterations
            max_tokens: Maximum response tokens
            system_prompt: System prompt
            on_tool_call: Callback when tool is called
            on_tool_result: Callback when tool returns result

        Returns:
            Final ChatResponse after all tool calls complete

        Raises:
            ValueError: If no model is specified and no default is set
        """
        model_id = model or self.default_model
        if not model_id:
            raise ValueError("No model specified and no default set")

        # Build initial messages
        chat_messages = []
        if system_prompt:
            chat_messages.append({"role": "system", "content": system_prompt})
        chat_messages.extend(messages)

        loop_count = 0
        current_response: Optional[ChatResponse] = None

        while loop_count < max_loops:
            # Send request
            response = self.chat(
                messages=chat_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
            )
            if not isinstance(response, ChatResponse):
                raise ValueError("Expected non-streaming response")

            current_response = response

            # Check for tool calls
            tool_calls = response.tool_calls
            if not tool_calls:
                break

            self.logger.info(f"Model requested {len(tool_calls)} tool call(s)")

            # Process each tool call
            tool_results = []
            for tc in tool_calls:
                if on_tool_call:
                    on_tool_call(tc)

                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    result = {"error": f"Invalid arguments: {e}"}
                else:
                    result = tool_executor(tc.function.name, args)

                if on_tool_result:
                    on_tool_result(tc.function.name, result)

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls
            assistant_msg = {
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            }
            chat_messages.append(assistant_msg)
            chat_messages.extend(tool_results)

            loop_count += 1

        if loop_count >= max_loops:
            self.logger.warning(f"Reached max tool call loops ({max_loops})")

        return current_response
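
    # Example tool loop (illustrative sketch; the tool schema and executor are
    # placeholders whose shapes merely match this method's signature):
    #
    #     def run_tool(name: str, args: Dict[str, Any]) -> Dict[str, Any]:
    #         if name == "get_time":
    #             return {"time": "2024-01-01T00:00:00Z"}
    #         return {"error": f"unknown tool: {name}"}
    #
    #     tools = [{
    #         "type": "function",
    #         "function": {
    #             "name": "get_time",
    #             "description": "Return the current UTC time",
    #             "parameters": {"type": "object", "properties": {}},
    #         },
    #     }]
    #
    #     final = client.chat_with_tools(
    #         messages=[{"role": "user", "content": "What time is it?"}],
    #         tools=tools,
    #         tool_executor=run_tool,
    #     )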

    def stream_chat(
        self,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        online: bool = False,
        on_chunk: Optional[Callable[[StreamChunk], None]] = None,
    ) -> tuple[str, Optional[UsageStats]]:
        """
        Stream a chat response and collect the full text.

        Args:
            messages: Chat messages
            model: Model ID
            max_tokens: Maximum tokens
            system_prompt: System prompt
            online: Online mode
            on_chunk: Optional callback for each chunk

        Returns:
            Tuple of (full_response_text, usage_stats)
        """
        response = self.chat(
            messages=messages,
            model=model,
            stream=True,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            online=online,
        )
        if isinstance(response, ChatResponse):
            # Not actually streaming
            return response.content or "", response.usage

        full_text = ""
        usage: Optional[UsageStats] = None

        for chunk in response:
            if chunk.error:
                self.logger.error(f"Stream error: {chunk.error}")
                break
            if chunk.delta_content:
                full_text += chunk.delta_content
            if on_chunk:
                on_chunk(chunk)
            if chunk.usage:
                usage = chunk.usage

        return full_text, usage
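
    # Example streaming call (illustrative sketch; assumes a default model has
    # already been set on the client):
    #
    #     def print_chunk(chunk: StreamChunk) -> None:
    #         if chunk.delta_content:
    #             print(chunk.delta_content, end="", flush=True)
    #
    #     text, usage = client.stream_chat(
    #         messages=[{"role": "user", "content": "Write a haiku about caching."}],
    #         on_chunk=print_chunk,
    #     )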

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credit information.

        Returns:
            Credit info dict or None if unavailable
        """
        return self.provider.get_credits()

    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate cost for a completion.

        Args:
            model_id: Model ID
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        if hasattr(self.provider, "estimate_cost"):
            return self.provider.estimate_cost(model_id, input_tokens, output_tokens)

        # Fallback to default pricing
        input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
        output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
        return input_cost + output_cost
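
    # Worked fallback example (the pricing values are assumptions; the code
    # above implies MODEL_PRICING is expressed in USD per million tokens):
    # with MODEL_PRICING = {"input": 3.0, "output": 15.0}, a request using
    # 2,000 input tokens and 500 output tokens costs
    # 3.0 * 2000 / 1_000_000 + 15.0 * 500 / 1_000_000 = 0.006 + 0.0075 = $0.0135.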

    def set_default_model(self, model_id: str) -> None:
        """
        Set the default model.

        Args:
            model_id: Model ID to use as default
        """
        self.default_model = model_id
        self.logger.info(f"Default model set to: {model_id}")

    def clear_cache(self) -> None:
        """Clear the provider's model cache."""
        if hasattr(self.provider, "clear_cache"):
            self.provider.clear_cache()

    def switch_provider(self, provider_name: str, ollama_base_url: Optional[str] = None) -> None:
        """
        Switch to a different provider.

        Args:
            provider_name: Provider to switch to
            ollama_base_url: Optional Ollama base URL (if switching to Ollama)

        Raises:
            ValueError: If provider is invalid or not configured
        """
        from oai.providers.registry import get_provider_class

        # Get provider class
        provider_class = get_provider_class(provider_name)
        if not provider_class:
            raise ValueError(f"Unknown provider: {provider_name}")

        # Get API key
        api_key = self.provider_api_keys.get(provider_name, "")

        # Check API key requirement
        if provider_name != "ollama" and not api_key:
            raise ValueError(f"No API key configured for {provider_name}")

        # Initialize new provider
        if provider_name == "ollama":
            base_url = ollama_base_url or self.ollama_base_url
            self.provider = provider_class(
                api_key=api_key,
                base_url=base_url,
            )
            self.ollama_base_url = base_url
        else:
            self.provider = provider_class(
                api_key=api_key,
                app_name=self.app_name,
                app_url=self.app_url,
            )

        self.provider_name = provider_name
        self.logger.info(f"Switched to {provider_name} provider")

        # Reset the default model when switching providers
        self.default_model = None
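
    # Example provider switch (illustrative sketch; the URL is the usual local
    # Ollama default rather than a value taken from this module, and the model
    # name is a placeholder):
    #
    #     client.switch_provider("ollama", ollama_base_url="http://localhost:11434")
    #     client.set_default_model("llama3")  # default_model is reset on switch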