"""
|
|
AI Client for oAI.
|
|
|
|
This module provides a high-level client for interacting with AI models
|
|
through the provider abstraction layer.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
|
|
|
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING, OLLAMA_DEFAULT_URL
|
|
from oai.providers.base import (
|
|
AIProvider,
|
|
ChatMessage,
|
|
ChatResponse,
|
|
ModelInfo,
|
|
StreamChunk,
|
|
ToolCall,
|
|
UsageStats,
|
|
)
|
|
from oai.utils.logging import get_logger
|
|
|
|
|
|
class AIClient:
    """
    High-level AI client for chat interactions.

    Provides a simplified interface for sending chat requests,
    handling streaming, and managing tool calls.

    Attributes:
        provider: The underlying AI provider
        provider_name: Name of the current provider
        default_model: Default model ID to use
        provider_api_keys: Mapping of provider names to API keys
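
    Example (illustrative; the API key and model ID are placeholders):

        client = AIClient(
            provider_name="openrouter",
            provider_api_keys={"openrouter": "sk-or-..."},
        )
        client.set_default_model("openai/gpt-4o-mini")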
    """

    def __init__(
        self,
        provider_name: str = "openrouter",
        provider_api_keys: Optional[Dict[str, str]] = None,
        ollama_base_url: str = OLLAMA_DEFAULT_URL,
        app_name: str = APP_NAME,
        app_url: str = APP_URL,
    ):
        """
        Initialize the AI client with the specified provider.

        Args:
            provider_name: Provider to use ("openrouter", "anthropic", "openai", "ollama")
            provider_api_keys: Dict mapping provider names to API keys
            ollama_base_url: Base URL for the Ollama server
            app_name: Application name for request headers
            app_url: Application URL for request headers

        Raises:
            ValueError: If the provider is unknown
        """
        from oai.providers.registry import get_provider_class

        self.provider_name = provider_name
        self.provider_api_keys = provider_api_keys or {}
        self.ollama_base_url = ollama_base_url
        self.app_name = app_name
        self.app_url = app_url

        # Get provider class
        provider_class = get_provider_class(provider_name)
        if not provider_class:
            raise ValueError(f"Unknown provider: {provider_name}")

        # Get API key for this provider
        api_key = self.provider_api_keys.get(provider_name, "")

        # Initialize provider with appropriate parameters
        self.provider: AIProvider
        if provider_name == "ollama":
            self.provider = provider_class(
                api_key=api_key,
                base_url=ollama_base_url,
            )
        else:
            self.provider = provider_class(
                api_key=api_key,
                app_name=app_name,
                app_url=app_url,
            )

        self.default_model: Optional[str] = None
        self.logger = get_logger()

        self.logger.info(f"Initialized {provider_name} provider")

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Get available models.

        Args:
            filter_text_only: Whether to exclude video-only models

        Returns:
            List of ModelInfo objects
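
        Example (illustrative; assumes ModelInfo exposes an ``id`` attribute):

            for model in client.list_models():
                print(model.id)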
        """
        return self.provider.list_models(filter_text_only=filter_text_only)

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: Model identifier

        Returns:
            ModelInfo or None if not found
        """
        return self.provider.get_model(model_id)

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for provider-specific fields.

        Args:
            model_id: Model identifier

        Returns:
            Raw model dictionary or None
        """
        if hasattr(self.provider, "get_raw_model"):
            return self.provider.get_raw_model(model_id)
        return None

    def chat(
        self,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        system_prompt: Optional[str] = None,
        online: bool = False,
        transforms: Optional[List[str]] = None,
        enable_web_search: bool = False,
        web_search_config: Optional[Dict[str, Any]] = None,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat request.

        Args:
            messages: List of message dictionaries
            model: Model ID (uses default if not specified)
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: Tool definitions for function calling
            tool_choice: Tool selection mode
            system_prompt: System prompt to prepend
            online: Whether to enable online mode
            transforms: List of transforms (e.g., ["middle-out"])
            enable_web_search: Enable native web search (Anthropic only)
            web_search_config: Web search configuration (Anthropic only)

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming

        Raises:
            ValueError: If no model is specified and no default is set
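
        Example (illustrative; the model ID is a placeholder):

            response = client.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o-mini",
            )
            print(response.content)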
        """
        model_id = model or self.default_model
        if not model_id:
            raise ValueError("No model specified and no default set")

        # Apply online mode suffix
        if online and hasattr(self.provider, "get_effective_model_id"):
            model_id = self.provider.get_effective_model_id(model_id, True)

        # Convert dict messages to ChatMessage objects
        chat_messages = []

        # Add system prompt if provided
        if system_prompt:
            chat_messages.append(ChatMessage(role="system", content=system_prompt))

        # Convert message dicts
        for msg in messages:
            # Convert tool_calls dicts to ToolCall objects if present
            tool_calls_data = msg.get("tool_calls")
            tool_calls_obj = None
            if tool_calls_data:
                tool_calls_obj = []
                for tc in tool_calls_data:
                    # Handle both ToolCall objects and dicts
                    if isinstance(tc, ToolCall):
                        tool_calls_obj.append(tc)
                    elif isinstance(tc, dict):
                        func_data = tc.get("function", {})
                        tool_calls_obj.append(
                            ToolCall(
                                id=tc.get("id", ""),
                                type=tc.get("type", "function"),
                                function=ToolFunction(
                                    name=func_data.get("name", ""),
                                    arguments=func_data.get("arguments", "{}"),
                                ),
                            )
                        )

            chat_messages.append(
                ChatMessage(
                    role=msg.get("role", "user"),
                    content=msg.get("content"),
                    tool_calls=tool_calls_obj,
                    tool_call_id=msg.get("tool_call_id"),
                )
            )

        self.logger.debug(
            f"Sending chat request: model={model_id}, "
            f"messages={len(chat_messages)}, stream={stream}"
        )

        return self.provider.chat(
            model=model_id,
            messages=chat_messages,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            transforms=transforms,
            enable_web_search=enable_web_search,
            web_search_config=web_search_config or {},
        )

    def chat_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]],
        model: Optional[str] = None,
        max_loops: int = 5,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        on_tool_call: Optional[Callable[[ToolCall], None]] = None,
        on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None,
    ) -> Optional[ChatResponse]:
        """
        Send a chat request with automatic tool call handling.

        Executes tool calls returned by the model and continues
        the conversation until no more tool calls are requested
        or max_loops is reached.

        Args:
            messages: Initial messages
            tools: Tool definitions
            tool_executor: Function to execute tool calls; receives the tool
                name and parsed arguments, returns a JSON-serializable result
            model: Model ID
            max_loops: Maximum tool call iterations
            max_tokens: Maximum response tokens
            system_prompt: System prompt
            on_tool_call: Callback when a tool is called
            on_tool_result: Callback when a tool returns a result

        Returns:
            Final ChatResponse after all tool calls complete, or None if
            no request was made (max_loops < 1)

        Raises:
            ValueError: If no model is specified and no default is set
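
        Example (illustrative sketch; the tool name, schema, and executor
        are placeholders):

            def execute(name: str, args: Dict[str, Any]) -> Dict[str, Any]:
                if name == "get_time":
                    return {"time": "12:00"}
                return {"error": f"unknown tool: {name}"}

            response = client.chat_with_tools(
                messages=[{"role": "user", "content": "What time is it?"}],
                tools=[{
                    "type": "function",
                    "function": {
                        "name": "get_time",
                        "description": "Get the current time",
                        "parameters": {"type": "object", "properties": {}},
                    },
                }],
                tool_executor=execute,
            )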
        """
        model_id = model or self.default_model
        if not model_id:
            raise ValueError("No model specified and no default set")

        # Build initial messages
        chat_messages = []
        if system_prompt:
            chat_messages.append({"role": "system", "content": system_prompt})
        chat_messages.extend(messages)

        loop_count = 0
        current_response: Optional[ChatResponse] = None

        while loop_count < max_loops:
            # Send request
            response = self.chat(
                messages=chat_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
            )

            if not isinstance(response, ChatResponse):
                raise ValueError("Expected non-streaming response")

            current_response = response

            # Check for tool calls
            tool_calls = response.tool_calls
            if not tool_calls:
                break

            self.logger.info(f"Model requested {len(tool_calls)} tool call(s)")

            # Process each tool call
            tool_results = []
            for tc in tool_calls:
                if on_tool_call:
                    on_tool_call(tc)

                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    result = {"error": f"Invalid arguments: {e}"}
                else:
                    result = tool_executor(tc.function.name, args)

                if on_tool_result:
                    on_tool_result(tc.function.name, result)

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls, then the tool results
            assistant_msg = {
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            }
            chat_messages.append(assistant_msg)
            chat_messages.extend(tool_results)

            loop_count += 1

        if loop_count >= max_loops:
            self.logger.warning(f"Reached max tool call loops ({max_loops})")

        return current_response

    def stream_chat(
        self,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        online: bool = False,
        on_chunk: Optional[Callable[[StreamChunk], None]] = None,
    ) -> tuple[str, Optional[UsageStats]]:
        """
        Stream a chat response and collect the full text.

        Args:
            messages: Chat messages
            model: Model ID
            max_tokens: Maximum tokens
            system_prompt: System prompt
            online: Online mode
            on_chunk: Optional callback invoked for each content chunk

        Returns:
            Tuple of (full_response_text, usage_stats)
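
        Example (illustrative; prints tokens as they arrive):

            text, usage = client.stream_chat(
                messages=[{"role": "user", "content": "Hello"}],
                on_chunk=lambda chunk: print(chunk.delta_content, end=""),
            )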
        """
        response = self.chat(
            messages=messages,
            model=model,
            stream=True,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            online=online,
        )

        if isinstance(response, ChatResponse):
            # Not actually streaming
            return response.content or "", response.usage

        full_text = ""
        usage: Optional[UsageStats] = None

        for chunk in response:
            if chunk.error:
                self.logger.error(f"Stream error: {chunk.error}")
                break

            if chunk.delta_content:
                full_text += chunk.delta_content
                if on_chunk:
                    on_chunk(chunk)

            if chunk.usage:
                usage = chunk.usage

        return full_text, usage

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credit information.

        Returns:
            Credit info dict or None if unavailable
        """
        return self.provider.get_credits()

    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate cost for a completion.

        Args:
            model_id: Model ID
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD
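
        Example (illustrative; assumes the MODEL_PRICING fallback charges
        $1.00 per million tokens for both input and output):

            # 10,000 input + 2,000 output tokens:
            # 1.00 * 10_000 / 1_000_000 + 1.00 * 2_000 / 1_000_000 = $0.012
            cost = client.estimate_cost("some/model", 10_000, 2_000)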
        """
        if hasattr(self.provider, "estimate_cost"):
            return self.provider.estimate_cost(model_id, input_tokens, output_tokens)

        # Fallback to default pricing
        input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
        output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
        return input_cost + output_cost

    def set_default_model(self, model_id: str) -> None:
        """
        Set the default model.

        Args:
            model_id: Model ID to use as default
        """
        self.default_model = model_id
        self.logger.info(f"Default model set to: {model_id}")

    def clear_cache(self) -> None:
        """Clear the provider's model cache."""
        if hasattr(self.provider, "clear_cache"):
            self.provider.clear_cache()

    def switch_provider(self, provider_name: str, ollama_base_url: Optional[str] = None) -> None:
        """
        Switch to a different provider.

        Args:
            provider_name: Provider to switch to
            ollama_base_url: Optional Ollama base URL (if switching to Ollama)

        Raises:
            ValueError: If the provider is unknown or has no API key configured
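
        Example (illustrative; assumes an "anthropic" API key was supplied at
        construction time, and the model ID is a placeholder):

            client.switch_provider("anthropic")
            client.set_default_model("some-anthropic-model")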
        """
        from oai.providers.registry import get_provider_class

        # Get provider class
        provider_class = get_provider_class(provider_name)
        if not provider_class:
            raise ValueError(f"Unknown provider: {provider_name}")

        # Get API key
        api_key = self.provider_api_keys.get(provider_name, "")

        # Check API key requirement (Ollama runs locally and needs none)
        if provider_name != "ollama" and not api_key:
            raise ValueError(f"No API key configured for {provider_name}")

        # Initialize new provider
        if provider_name == "ollama":
            base_url = ollama_base_url or self.ollama_base_url
            self.provider = provider_class(
                api_key=api_key,
                base_url=base_url,
            )
            self.ollama_base_url = base_url
        else:
            self.provider = provider_class(
                api_key=api_key,
                app_name=self.app_name,
                app_url=self.app_url,
            )

        self.provider_name = provider_name
        self.logger.info(f"Switched to {provider_name} provider")

        # Reset the default model; model IDs differ between providers
        self.default_model = None