oai/oai/providers/openrouter.py
Rune Olsen b0cf88704e 2.1 (#2)
Final release of version 2.1.

Highlights:

### Core Features
- 🤖 Interactive chat with 300+ AI models via OpenRouter
- 🔍 Model selection with search and filtering
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
- 📎 File attachments (images, PDFs, code files)
- 💰 Real-time cost tracking and credit monitoring
- 🎨 Rich terminal UI with syntax highlighting
- 📝 Persistent command history with search (Ctrl+R)
- 🌐 Online mode (web search capabilities)
- 🧠 Conversation memory toggle

### MCP Integration
- 🔧 **File Mode**: AI can read, search, and list local files
  - Automatic .gitignore filtering
  - Virtual environment exclusion
  - Large file handling (auto-truncates >50KB)

- ✍️ **Write Mode**: AI can modify files with permission
  - Create, edit, delete files
  - Move, copy, organize files
  - Always requires explicit opt-in

- 🗄️ **Database Mode**: AI can query SQLite databases
  - Read-only access (safe)
  - Schema inspection
  - Full SQL query support

Reviewed-on: #2
Co-authored-by: Rune Olsen <rune@rune.pm>
Co-committed-by: Rune Olsen <rune@rune.pm>
2026-02-03 09:02:44 +01:00
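
A minimal usage sketch of the provider class in this file (the model ID and API key below are placeholders, not values shipped with the release):

```python
from oai.providers.base import ChatMessage
from oai.providers.openrouter import OpenRouterProvider

provider = OpenRouterProvider(api_key="sk-or-...")  # your OpenRouter key
response = provider.chat(
    model="openai/gpt-4o",  # any ID from provider.list_models()
    messages=[ChatMessage(role="user", content="Hello!")],
)
print(response.choices[0].message.content)
```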


"""
OpenRouter provider implementation.
This module implements the AIProvider interface for OpenRouter,
supporting chat completions, streaming, and function calling.
"""
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import requests
from openrouter import OpenRouter
from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
ToolCall,
ToolFunction,
UsageStats,
)
from oai.utils.logging import get_logger


class OpenRouterProvider(AIProvider):
    """
    OpenRouter API provider implementation.

    Provides access to multiple AI models through OpenRouter's unified API,
    supporting chat completions, streaming responses, and function calling.

    Attributes:
        client: The underlying OpenRouter client
        _models_cache: Cached list of available models
    """

    def __init__(
        self,
        api_key: str,
        base_url: Optional[str] = None,
        app_name: str = APP_NAME,
        app_url: str = APP_URL,
    ):
        """
        Initialize the OpenRouter provider.

        Args:
            api_key: OpenRouter API key
            base_url: Optional custom base URL
            app_name: Application name for API headers
            app_url: Application URL for API headers
        """
        super().__init__(api_key, base_url or DEFAULT_BASE_URL)
        self.app_name = app_name
        self.app_url = app_url
        self.client = OpenRouter(api_key=api_key)
        self._models_cache: Optional[List[ModelInfo]] = None
        self._raw_models_cache: Optional[List[Dict[str, Any]]] = None
        self.logger = get_logger()
        self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}")

    @property
    def name(self) -> str:
        """Get the provider name."""
        return "OpenRouter"

    @property
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        return ProviderCapabilities(
            streaming=True,
            tools=True,
            images=True,
            online=True,
            max_context=2_000_000,  # upper bound; a few models offer up to 2M-token contexts
        )

    def _get_headers(self) -> Dict[str, str]:
        """Get standard HTTP headers for API requests."""
        headers = {
            "HTTP-Referer": self.app_url,
            "X-Title": self.app_name,
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers
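    # For reference, a populated header set looks roughly like this
    # (values are placeholders):
    #
    #   {
    #       "HTTP-Referer": "https://example.com",
    #       "X-Title": "my-app",
    #       "Authorization": "Bearer sk-or-...",
    #   }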

    def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
        """
        Parse raw model data into ModelInfo.

        Args:
            model_data: Raw model data from API

        Returns:
            Parsed ModelInfo object
        """
        architecture = model_data.get("architecture", {})
        pricing_data = model_data.get("pricing", {})
        # Parse pricing (convert from string to float if needed)
        pricing = {}
        for key in ["prompt", "completion"]:
            value = pricing_data.get(key)
            if value is not None:
                try:
                    # Convert from per-token to per-million-tokens
                    pricing[key] = float(value) * 1_000_000
                except (ValueError, TypeError):
                    pricing[key] = 0.0
        return ModelInfo(
            id=model_data.get("id", ""),
            name=model_data.get("name", model_data.get("id", "")),
            description=model_data.get("description", ""),
            context_length=model_data.get("context_length", 0),
            pricing=pricing,
            supported_parameters=model_data.get("supported_parameters", []),
            input_modalities=architecture.get("input_modalities", ["text"]),
            output_modalities=architecture.get("output_modalities", ["text"]),
        )
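    # Pricing arithmetic, concretely: the API reports per-token USD strings,
    # so e.g. "prompt": "0.000003" is stored as 0.000003 * 1_000_000 = 3.0
    # (USD per million prompt tokens). Example value only.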

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Fetch available models from OpenRouter.

        Args:
            filter_text_only: If True, exclude video-only models

        Returns:
            List of available models

        Raises:
            Exception: If API request fails
        """
        if self._models_cache is not None:
            return self._models_cache
        try:
            response = requests.get(
                f"{self.base_url}/models",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()
            raw_models = response.json().get("data", [])
            self._raw_models_cache = raw_models
            models = []
            for model_data in raw_models:
                # Optionally filter out video-only models. Modalities are
                # reported under "architecture", matching _parse_model above.
                if filter_text_only:
                    modalities = model_data.get("architecture", {}).get(
                        "input_modalities", []
                    )
                    if modalities and "video" in modalities and "text" not in modalities:
                        continue
                models.append(self._parse_model(model_data))
            self._models_cache = models
            self.logger.info(f"Fetched {len(models)} models from OpenRouter")
            return models
        except requests.RequestException as e:
            self.logger.error(f"Failed to fetch models: {e}")
            raise

    def get_raw_models(self) -> List[Dict[str, Any]]:
        """
        Get raw model data as returned by the API.

        Useful for accessing provider-specific fields not in ModelInfo.

        Returns:
            List of raw model dictionaries
        """
        if self._raw_models_cache is None:
            self.list_models()
        return self._raw_models_cache or []

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Model information or None if not found
        """
        models = self.list_models()
        for model in models:
            if model.id == model_id:
                return model
        return None

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Raw model dictionary or None if not found
        """
        raw_models = self.get_raw_models()
        for model in raw_models:
            if model.get("id") == model_id:
                return model
        return None

    def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """
        Convert ChatMessage objects to API format.

        Args:
            messages: List of ChatMessage objects

        Returns:
            List of message dictionaries for the API
        """
        return [msg.to_dict() for msg in messages]

    def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]:
        """
        Parse usage data from API response.

        Args:
            usage_data: Raw usage data from API

        Returns:
            Parsed UsageStats or None
        """
        if not usage_data:
            return None
        # Handle both attribute and dict access
        prompt_tokens = 0
        completion_tokens = 0
        total_cost = None
        if hasattr(usage_data, "prompt_tokens"):
            prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0
        elif isinstance(usage_data, dict):
            prompt_tokens = usage_data.get("prompt_tokens", 0) or 0
        if hasattr(usage_data, "completion_tokens"):
            completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0
        elif isinstance(usage_data, dict):
            completion_tokens = usage_data.get("completion_tokens", 0) or 0
        # Try alternative naming (input_tokens/output_tokens)
        if prompt_tokens == 0:
            if hasattr(usage_data, "input_tokens"):
                prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0
            elif isinstance(usage_data, dict):
                prompt_tokens = usage_data.get("input_tokens", 0) or 0
        if completion_tokens == 0:
            if hasattr(usage_data, "output_tokens"):
                completion_tokens = getattr(usage_data, "output_tokens", 0) or 0
            elif isinstance(usage_data, dict):
                completion_tokens = usage_data.get("output_tokens", 0) or 0
        # Get cost if available
        if hasattr(usage_data, "total_cost_usd"):
            total_cost = getattr(usage_data, "total_cost_usd", None)
        elif isinstance(usage_data, dict):
            total_cost = usage_data.get("total_cost_usd")
        return UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            # Preserve a legitimate $0.00 cost; only a missing value maps to None
            total_cost_usd=float(total_cost) if total_cost is not None else None,
        )
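    # Example with a hypothetical payload: {"prompt_tokens": 12,
    # "completion_tokens": 34, "total_cost_usd": 0.0007} parses to
    # UsageStats(prompt_tokens=12, completion_tokens=34, total_tokens=46,
    # total_cost_usd=0.0007).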

    def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]:
        """
        Parse tool calls from API response.

        Args:
            tool_calls_data: Raw tool calls data

        Returns:
            List of ToolCall objects or None
        """
        if not tool_calls_data:
            return None
        tool_calls = []
        for tc in tool_calls_data:
            # Handle both attribute and dict access
            if hasattr(tc, "id"):
                tc_id = tc.id
                tc_type = getattr(tc, "type", "function")
                func = tc.function
                func_name = func.name
                func_args = func.arguments
            else:
                tc_id = tc.get("id", "")
                tc_type = tc.get("type", "function")
                func = tc.get("function", {})
                func_name = func.get("name", "")
                func_args = func.get("arguments", "{}")
            tool_calls.append(
                ToolCall(
                    id=tc_id,
                    type=tc_type,
                    function=ToolFunction(name=func_name, arguments=func_args),
                )
            )
        return tool_calls if tool_calls else None
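    # The dict branch above handles OpenAI-style tool calls, shaped like
    # (identifier and function name are illustrative):
    #
    #   {"id": "call_abc", "type": "function",
    #    "function": {"name": "read_file", "arguments": "{\"path\": \"a.py\"}"}}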

    def _parse_response(self, response: Any) -> ChatResponse:
        """
        Parse API response into ChatResponse.

        Args:
            response: Raw API response

        Returns:
            Parsed ChatResponse
        """
        choices = []
        for choice in response.choices:
            msg = choice.message
            message = ChatMessage(
                role=getattr(msg, "role", "assistant"),
                content=getattr(msg, "content", None),
                tool_calls=self._parse_tool_calls(getattr(msg, "tool_calls", None)),
            )
            choices.append(
                ChatResponseChoice(
                    index=getattr(choice, "index", 0),
                    message=message,
                    finish_reason=getattr(choice, "finish_reason", None),
                )
            )
        return ChatResponse(
            id=getattr(response, "id", ""),
            choices=choices,
            usage=self._parse_usage(getattr(response, "usage", None)),
            model=getattr(response, "model", None),
            created=getattr(response, "created", None),
        )

    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        transforms: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat completion request to OpenRouter.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature (0-2)
            tools: List of tool definitions for function calling
            tool_choice: How to handle tool selection ("auto", "none", etc.)
            transforms: List of transforms (e.g., ["middle-out"])
            **kwargs: Additional parameters

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
        """
        # Build request parameters
        params: Dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages(messages),
            "stream": stream,
            "http_headers": self._get_headers(),
        }
        # Request usage stats in streaming responses
        if stream:
            params["stream_options"] = {"include_usage": True}
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        if temperature is not None:
            params["temperature"] = temperature
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice or "auto"
        if transforms:
            params["transforms"] = transforms
        # Add any additional parameters
        params.update(kwargs)
        self.logger.debug(f"Sending chat request to model {model}")
        try:
            response = self.client.chat.send(**params)
            if stream:
                return self._stream_response(response)
            else:
                return self._parse_response(response)
        except Exception as e:
            self.logger.error(f"Chat request failed: {e}")
            raise
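    # Usage sketch for the streaming path (model ID is illustrative):
    #
    #   for chunk in provider.chat(
    #       model="openai/gpt-4o",
    #       messages=msgs,
    #       stream=True,
    #   ):
    #       if chunk.error:
    #           break
    #       if chunk.delta_content:
    #           print(chunk.delta_content, end="", flush=True)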

    def _stream_response(self, response: Any) -> Iterator[StreamChunk]:
        """
        Process a streaming response.

        Args:
            response: Streaming response from API

        Yields:
            StreamChunk objects
        """
        last_usage = None
        try:
            for chunk in response:
                # Check for errors
                if hasattr(chunk, "error") and chunk.error:
                    yield StreamChunk(
                        id=getattr(chunk, "id", ""),
                        error=(
                            chunk.error.message
                            if hasattr(chunk.error, "message")
                            else str(chunk.error)
                        ),
                    )
                    return
                # Extract delta content
                delta_content = None
                finish_reason = None
                if hasattr(chunk, "choices") and chunk.choices:
                    choice = chunk.choices[0]
                    if hasattr(choice, "delta"):
                        delta = choice.delta
                        if hasattr(delta, "content") and delta.content:
                            delta_content = delta.content
                    finish_reason = getattr(choice, "finish_reason", None)
                # Track usage from last chunk
                if hasattr(chunk, "usage") and chunk.usage:
                    last_usage = self._parse_usage(chunk.usage)
                yield StreamChunk(
                    id=getattr(chunk, "id", ""),
                    delta_content=delta_content,
                    finish_reason=finish_reason,
                    usage=last_usage if finish_reason else None,
                )
        except Exception as e:
            self.logger.error(f"Stream error: {e}")
            yield StreamChunk(id="", error=str(e))

    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send an async chat completion request.

        Note: Currently wraps the sync implementation.
        TODO: Implement true async support when the OpenRouter SDK supports it.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: List of tool definitions
            tool_choice: Tool selection mode
            **kwargs: Additional parameters

        Returns:
            ChatResponse for non-streaming, AsyncIterator for streaming
        """
        # For now, delegate to the sync implementation
        result = self.chat(
            model=model,
            messages=messages,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            **kwargs,
        )
        if stream and isinstance(result, Iterator):
            # Wrap the sync iterator in an async generator
            async def async_iter() -> AsyncIterator[StreamChunk]:
                for chunk in result:
                    yield chunk

            return async_iter()
        return result
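    # Async usage sketch: the streamed variant is consumed with "async for".
    #
    #   result = await provider.chat_async(model=model_id, messages=msgs, stream=True)
    #   async for chunk in result:
    #       ...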

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get OpenRouter account credit information.

        Returns:
            Dict with credit info, or None if no API key is set or the
            request fails:
                - total_credits: Total credits purchased
                - used_credits: Credits used
                - credits_left: Remaining credits
        """
        if not self.api_key:
            return None
        try:
            response = requests.get(
                f"{self.base_url}/credits",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()
            data = response.json().get("data", {})
            total_credits = float(data.get("total_credits", 0))
            total_usage = float(data.get("total_usage", 0))
            credits_left = total_credits - total_usage
            return {
                "total_credits": total_credits,
                "used_credits": total_usage,
                "credits_left": credits_left,
                "total_credits_formatted": f"${total_credits:.2f}",
                "used_credits_formatted": f"${total_usage:.2f}",
                "credits_left_formatted": f"${credits_left:.2f}",
            }
        except Exception as e:
            self.logger.error(f"Failed to fetch credits: {e}")
            return None
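    # Example return value (illustrative numbers):
    #
    #   {"total_credits": 10.0, "used_credits": 2.5, "credits_left": 7.5,
    #    "total_credits_formatted": "$10.00", "used_credits_formatted": "$2.50",
    #    "credits_left_formatted": "$7.50"}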

    def clear_cache(self) -> None:
        """Clear the models cache to force a refresh."""
        self._models_cache = None
        self._raw_models_cache = None
        self.logger.debug("Models cache cleared")

    def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str:
        """
        Get the effective model ID with online suffix if needed.

        Args:
            model_id: Base model ID
            online_enabled: Whether online mode is enabled

        Returns:
            Model ID with :online suffix if applicable
        """
        if online_enabled and not model_id.endswith(":online"):
            return f"{model_id}:online"
        return model_id
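    # E.g. get_effective_model_id("openai/gpt-4o", online_enabled=True) returns
    # "openai/gpt-4o:online"; an ID that already carries the suffix is returned
    # unchanged. (Model ID is illustrative.)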

    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate the cost for a completion.

        Args:
            model_id: Model ID
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        model = self.get_model(model_id)
        if model and model.pricing:
            input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000
            output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000
            return input_cost + output_cost
        # Fall back to default pricing if the model is not found
        from oai.constants import MODEL_PRICING

        input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
        output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
        return input_cost + output_cost
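    # Worked example (hypothetical prices): with pricing of $3.00 per 1M prompt
    # tokens and $15.00 per 1M completion tokens, estimate_cost(model_id, 1_000, 500)
    # yields 3.0 * 1_000 / 1_000_000 + 15.0 * 500 / 1_000_000
    # = 0.003 + 0.0075 = $0.0105.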