Lots of changes. None breaking. v3.0.0-b3

This commit is contained in:
2026-02-05 11:21:22 +01:00
parent ecc2489eef
commit 06a3c898d3
25 changed files with 3252 additions and 117 deletions

141
README.md
View File

@@ -1,19 +1,23 @@
# oAI - OpenRouter AI Chat Client
# oAI - Open AI Chat Client
A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
A powerful, modern **Textual TUI** chat client with **multi-provider support** (OpenRouter, Anthropic, OpenAI, Ollama) and **MCP (Model Context Protocol)** integration, enabling AI to access local files and query SQLite databases.
## Features
### Core Features
- 🖥️ **Modern Textual TUI** with async streaming and beautiful interface
- 🤖 Interactive chat with 300+ AI models via OpenRouter
- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local)
- 🤖 Interactive chat with 300+ AI models across providers
- 🔍 Model selection with search, filtering, and capability icons
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
- 📎 File attachments (images, PDFs, code files)
- 💰 Real-time cost tracking and credit monitoring
- 💰 Real-time cost tracking and credit monitoring (OpenRouter)
- 🎨 Dark theme with syntax highlighting and Markdown rendering
- 📝 Command history navigation (Up/Down arrows)
- 🌐 Online mode (web search capabilities)
- 🌐 **Universal Online Mode** - Web search for ALL providers:
- **Anthropic Native** - Built-in search with automatic citations ($0.01/search)
- **DuckDuckGo** - Free web scraping (all providers)
- **Google Custom Search** - Premium search option
- 🧠 Conversation memory toggle
- ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats)
@@ -36,7 +40,11 @@ A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Mo
## Requirements
- Python 3.10-3.13
- OpenRouter API key ([get one here](https://openrouter.ai))
- API key for your chosen provider:
- **OpenRouter**: [openrouter.ai](https://openrouter.ai)
- **Anthropic**: [console.anthropic.com](https://console.anthropic.com)
- **OpenAI**: [platform.openai.com](https://platform.openai.com)
- **Ollama**: No API key needed (local server)
## Installation
@@ -83,28 +91,111 @@ pip install -e .
# Start oAI (launches TUI)
oai
# Start with specific provider
oai --provider anthropic
oai --provider openai
oai --provider ollama
# Or with options
oai --model gpt-4o --online --mcp
oai --provider openrouter --model gpt-4o --online --mcp
# Show version
oai version
```
On first run, you'll be prompted for your OpenRouter API key.
On first run, you'll be prompted for your API key. Configure additional providers anytime with `/config`.
### Enable Web Search (All Providers)
```bash
# Using Anthropic native search (best quality, automatic citations)
/config search_provider anthropic_native
/online on
# Using free DuckDuckGo (works with all providers)
/config search_provider duckduckgo
/online on
# Using Google Custom Search (requires API key)
/config search_provider google
/config google_api_key YOUR_KEY
/config google_search_engine_id YOUR_ID
/online on
```
### Basic Commands
```bash
# In the TUI interface:
/provider # Show current provider or switch
/provider anthropic # Switch to Anthropic (Claude)
/provider openai # Switch to OpenAI (ChatGPT)
/provider ollama # Switch to Ollama (local)
/model # Select AI model (or press F2)
/online on # Enable web search
/help # Show all commands (or press F1)
/mcp on # Enable file/database access
/stats # View session statistics (or press Ctrl+S)
/config # View configuration settings
/credits # Check account credits
/credits # Check account credits (shows API balance or console link)
Ctrl+Q # Quit
```
## Web Search
oAI provides universal web search for all AI providers, with three options:
### Anthropic Native Search (Recommended for Anthropic)
Anthropic's built-in web search API with automatic citations:
```bash
/config search_provider anthropic_native
/online on
# Now ask questions requiring current information
What are the latest developments in quantum computing?
```
**Features:**
- **Automatic citations** - Claude cites its sources
- **Smart searching** - Claude decides when to search
- **Progressive searches** - Multiple searches for complex queries
- **Best quality** - Professional-grade results
**Pricing:** $10 per 1,000 searches ($0.01 per search) + token costs
**Note:** Only works with Anthropic provider (Claude models)
### DuckDuckGo Search (Default - Free)
Free web scraping that works with ALL providers:
```bash
/config search_provider duckduckgo # Default
/online on
# Works with Anthropic, OpenAI, Ollama, and OpenRouter
```
**Features:**
- **Free** - No API key or costs
- **Universal** - Works with all providers
- **Privacy-friendly** - Uses DuckDuckGo
### Google Custom Search (Premium Option)
Google's Custom Search API for high-quality results:
```bash
/config search_provider google
/config google_api_key YOUR_GOOGLE_API_KEY
/config google_search_engine_id YOUR_SEARCH_ENGINE_ID
/online on
```
Get your API key: [Google Custom Search API](https://developers.google.com/custom-search/v1/overview)
## MCP (Model Context Protocol)
MCP allows the AI to interact with your local files and databases.
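For example, in the TUI (subcommand names as implemented in the MCP command further down in this diff):
```bash
/mcp on        # Enable the MCP server
/mcp status    # Show MCP status
/mcp files     # File access subcommand
/mcp db        # Database access subcommand
/mcp off       # Disable MCP
```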
@@ -184,11 +275,18 @@ MCP allows the AI to interact with your local files and databases.
| Command | Description |
|---------|-------------|
| `/config` | View settings |
| `/config api` | Set API key |
| `/config provider <name>` | Set default provider |
| `/config openrouter_api_key` | Set OpenRouter API key |
| `/config anthropic_api_key` | Set Anthropic API key |
| `/config openai_api_key` | Set OpenAI API key |
| `/config ollama_base_url` | Set Ollama server URL |
| `/config search_provider <provider>` | Set search provider (anthropic_native/duckduckgo/google) |
| `/config google_api_key` | Set Google API key (for Google search) |
| `/config online on\|off` | Set default online mode |
| `/config model <id>` | Set default model |
| `/config stream on\|off` | Toggle streaming |
| `/stats` | Session statistics |
| `/credits` | Check credits |
| `/credits` | Check credits (OpenRouter) |
## CLI Options
@@ -196,9 +294,10 @@ MCP allows the AI to interact with your local files and databases.
oai [OPTIONS]
Options:
-p, --provider TEXT Provider to use (openrouter/anthropic/openai/ollama)
-m, --model TEXT Model ID to use
-s, --system TEXT System prompt
-o, --online Enable online mode
-o, --online Enable online mode (OpenRouter only)
--mcp Enable MCP server
-v, --version Show version
--help Show help
@@ -280,7 +379,23 @@ pip install -e . --force-reinstall
## Version History
### v3.0.0 (Current)
### v3.0.0-b3 (Current - Beta 3)
- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local)
- 🔌 **Provider Switching** - Switch between providers mid-session with `/provider`
- 🧠 **Provider Model Memory** - Remembers last used model per provider
- ⚙️ **Provider Configuration** - Separate API keys for each provider
- 🌐 **Universal Web Search** - Three search options for all providers:
- **Anthropic Native** - Built-in search with citations ($0.01/search)
- **DuckDuckGo** - Free web scraping (default)
- **Google Custom Search** - Premium option
- 🎨 **Enhanced Header** - Shows current provider and model
- 📊 **Updated Config Screen** - Shows provider-specific settings
- 🔧 **Command Dropdown** - Added `/provider` commands for discoverability
- 🎯 **Auto-Scrolling** - Chat window auto-scrolls during streaming responses
- 🎨 **Refined UI** - Thinner scrollbar, cleaner styling
- 🔧 **Updated Models** - Claude 4.5 (Sonnet, Haiku, Opus)
### v3.0.0 (Beta)
- 🎨 **Complete migration to Textual TUI** - Modern async terminal interface
- 🗑️ **Removed CLI interface** - TUI-only for cleaner codebase (11.6% smaller)
- 🖱️ **Modal screens** - Help, stats, config, credits, model selector

View File

@@ -1,16 +1,16 @@
"""
oAI - OpenRouter AI Chat Client
oAI - Open AI Chat Client
A feature-rich terminal-based chat application that provides an interactive CLI
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
A feature-rich terminal-based chat application with multi-provider support
(OpenRouter, Anthropic, OpenAI, Ollama) and advanced Model Context Protocol (MCP)
integration for filesystem and database access.
Author: Rune
License: MIT
"""
__version__ = "3.0.0-b2"
__author__ = "Rune"
__version__ = "3.0.0-b3"
__author__ = "Rune Olsen"
__license__ = "MIT"
# Lazy imports to avoid circular dependencies and improve startup time

View File

@@ -21,7 +21,7 @@ from oai.utils.logging import LoggingManager, get_logger
# Create Typer app
app = typer.Typer(
name="oai",
help=f"oAI - OpenRouter AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
help=f"oAI - Open AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
add_completion=False,
epilog="For more information, visit: " + APP_URL,
)
@@ -60,6 +60,12 @@ def main_callback(
"--mcp",
help="Enable MCP server",
),
provider: Optional[str] = typer.Option(
None,
"--provider",
"-p",
help="AI provider to use (openrouter, anthropic, openai, ollama)",
),
) -> None:
"""Main callback - launches TUI by default."""
if version_flag:
@@ -68,7 +74,7 @@ def main_callback(
# If no subcommand provided, launch TUI
if ctx.invoked_subcommand is None:
_launch_tui(model, system, online, mcp)
_launch_tui(model, system, online, mcp, provider)
def _launch_tui(
@@ -76,8 +82,11 @@ def _launch_tui(
system: Optional[str] = None,
online: bool = False,
mcp: bool = False,
provider: Optional[str] = None,
) -> None:
"""Launch the Textual TUI interface."""
from oai.constants import VALID_PROVIDERS
# Setup logging
logging_manager = LoggingManager()
logging_manager.setup()
@@ -86,17 +95,35 @@ def _launch_tui(
# Load settings
settings = Settings.load()
# Check API key
if not settings.api_key:
typer.echo("Error: No API key configured", err=True)
typer.echo("Run: oai config api to set your API key", err=True)
# Determine provider
selected_provider = provider or settings.default_provider
# Validate provider
if selected_provider not in VALID_PROVIDERS:
typer.echo(f"Error: Invalid provider: {selected_provider}", err=True)
typer.echo(f"Valid providers: {', '.join(VALID_PROVIDERS)}", err=True)
raise typer.Exit(1)
# Build provider API keys dict
provider_api_keys = {
"openrouter": settings.openrouter_api_key,
"anthropic": settings.anthropic_api_key,
"openai": settings.openai_api_key,
}
# Check if provider is configured (except Ollama which doesn't need API key)
if selected_provider != "ollama":
if not provider_api_keys.get(selected_provider):
typer.echo(f"Error: No API key configured for {selected_provider}", err=True)
typer.echo(f"Set it with: oai config {selected_provider}_api_key <key>", err=True)
raise typer.Exit(1)
# Initialize client
try:
client = AIClient(
api_key=settings.api_key,
base_url=settings.base_url,
provider_name=selected_provider,
provider_api_keys=provider_api_keys,
ollama_base_url=settings.ollama_base_url,
)
except Exception as e:
typer.echo(f"Error: Failed to initialize client: {e}", err=True)

View File

@@ -275,7 +275,7 @@ class OnlineCommand(Command):
("Enable web search", "/online on"),
("Disable web search", "/online off"),
],
notes="Not all models support online mode.",
notes="OpenRouter: Native :online suffix. Other providers: DuckDuckGo search results injected into context.",
)
def execute(self, args: str, context: CommandContext) -> CommandResult:
@@ -820,7 +820,56 @@ class ConfigCommand(Command):
from oai.constants import DEFAULT_SYSTEM_PROMPT
table = Table("Setting", "Value", show_header=True, header_style="bold magenta")
table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set")
# Current Session Provider Info (if available)
if context.session and context.session.client:
client = context.session.client
current_provider = client.provider_name
table.add_row(
"[bold cyan]Current Provider[/]",
f"[bold green]{current_provider}[/]"
)
# Show current provider's base URL
if hasattr(client.provider, 'base_url'):
provider_url = client.provider.base_url
table.add_row(" Provider URL", provider_url)
# Show API key status for current provider
if current_provider == "ollama":
table.add_row(" API Key", "[dim]Not required (local)[/]")
else:
current_key = settings.get_provider_api_key(current_provider)
if current_key:
masked = "***" + current_key[-4:] if len(current_key) > 4 else "***"
table.add_row(" API Key", f"[green]{masked}[/]")
else:
table.add_row(" API Key", "[red]Not set[/]")
# Show current model
if context.session.selected_model:
model_name = context.session.selected_model.get("name", "Unknown")
table.add_row(" Current Model", model_name)
# Add separator
table.add_row("", "")
# Provider configuration settings
table.add_row("[bold]Default Provider[/]", settings.default_provider)
# Provider API keys (masked)
def mask_key(key: Optional[str]) -> str:
if not key:
return "[red]Not set[/]"
return "***" + key[-4:] if len(key) > 4 else "***"
table.add_row("OpenRouter API Key", mask_key(settings.openrouter_api_key))
table.add_row("Anthropic API Key", mask_key(settings.anthropic_api_key))
table.add_row("OpenAI API Key", mask_key(settings.openai_api_key))
table.add_row("Ollama Base URL", settings.ollama_base_url)
# Legacy/general settings
table.add_row("Base URL", settings.base_url)
table.add_row("Default Model", settings.default_model or "Not set")
@@ -845,7 +894,62 @@ class ConfigCommand(Command):
setting = parts[0].lower()
value = parts[1] if len(parts) > 1 else None
if setting == "api":
# Provider selection
if setting == "provider":
from oai.constants import VALID_PROVIDERS
if value:
if value.lower() in VALID_PROVIDERS:
settings.set_default_provider(value.lower())
return CommandResult.success()
else:
return CommandResult.error(f"Invalid provider. Valid: {', '.join(VALID_PROVIDERS)}")
else:
return CommandResult.error("Usage: /config provider <provider_name>")
# Provider-specific API keys
elif setting in ["openrouter_api_key", "anthropic_api_key", "openai_api_key"]:
if value:
provider = setting.replace("_api_key", "")
settings.set_provider_api_key(provider, value)
return CommandResult.success()
else:
return CommandResult.success(data={"show_api_key_input": True, "provider": setting})
elif setting == "ollama_base_url":
if value:
settings.set_ollama_base_url(value)
return CommandResult.success()
else:
return CommandResult.error("Usage: /config ollama_base_url <url>")
# Web search settings
elif setting == "search_provider":
if value:
valid_providers = ["anthropic_native", "duckduckgo", "google"]
if value.lower() in valid_providers:
settings.set_search_provider(value.lower())
return CommandResult.success()
else:
return CommandResult.error(f"Invalid search provider. Valid: {', '.join(valid_providers)}")
else:
return CommandResult.error("Usage: /config search_provider <provider>")
elif setting == "google_api_key":
if value:
settings.set_google_api_key(value)
return CommandResult.success()
else:
return CommandResult.error("Usage: /config google_api_key <key>")
elif setting == "google_search_engine_id":
if value:
settings.set_google_search_engine_id(value)
return CommandResult.success()
else:
return CommandResult.error("Usage: /config google_search_engine_id <id>")
elif setting == "api":
if value:
settings.set_api_key(value)
else:
@@ -1388,6 +1492,104 @@ class MCPCommand(Command):
message = f"Unknown MCP command: {cmd}\nAvailable: on, off, status, add, remove, list, db, files, write, gitignore"
return CommandResult.error(message)
class ProviderCommand(Command):
"""Switch or display current provider."""
@property
def name(self) -> str:
return "/provider"
@property
def help(self) -> CommandHelp:
return CommandHelp(
description="Switch AI provider or show current provider.",
usage="/provider [provider_name]",
examples=[
("Show current provider", "/provider"),
("Switch to Anthropic", "/provider anthropic"),
("Switch to OpenAI", "/provider openai"),
("Switch to Ollama", "/provider ollama"),
("Switch to OpenRouter", "/provider openrouter"),
],
)
def execute(self, args: str, context: CommandContext) -> CommandResult:
from oai.constants import VALID_PROVIDERS
provider_name = args.strip().lower()
if not context.session:
return CommandResult.error("No active session")
if not provider_name:
# Show current provider and capabilities
current = context.session.client.provider_name
capabilities = context.session.client.provider.capabilities
message = f"\n[bold]Current Provider:[/] {current}\n"
message += f" Streaming: {capabilities.streaming}\n"
message += f" Tools: {capabilities.tools}\n"
message += f" Images: {capabilities.images}\n"
message += f" Online: {capabilities.online}"
return CommandResult.success(message=message)
# Validate provider
if provider_name not in VALID_PROVIDERS:
return CommandResult.error(
f"Invalid provider: {provider_name}\nValid: {', '.join(VALID_PROVIDERS)}"
)
# Check API key (except for Ollama)
if provider_name != "ollama":
key = context.settings.get_provider_api_key(provider_name)
if not key:
return CommandResult.error(
f"No API key for {provider_name}. Set with /config {provider_name}_api_key <key>"
)
# Switch provider
try:
context.session.client.switch_provider(provider_name)
# Try to restore last used model for this provider
last_model_id = context.settings.get_provider_model(provider_name)
restored_model = False
if last_model_id:
# Try to get the model from the new provider
raw_model = context.session.client.get_raw_model(last_model_id)
if raw_model:
context.session.set_model(raw_model)
restored_model = True
message = f"Switched to {provider_name} provider (model: {raw_model.get('name', last_model_id)})"
return CommandResult.success(message=message)
# If no stored model or model not found, select a default
if not restored_model:
models = context.session.client.list_models()
if models:
# Select a sensible default (first model)
default_model = models[0]
raw_model = context.session.client.get_raw_model(default_model.id)
if raw_model:
context.session.set_model(raw_model)
# Save this as the default for this provider
context.settings.set_provider_model(provider_name, default_model.id)
message = f"Switched to {provider_name} provider (auto-selected: {default_model.name})"
return CommandResult.success(message=message)
else:
# No models available, clear selection
context.session.selected_model = None
message = f"Switched to {provider_name} provider (no models available)"
return CommandResult.success(message=message)
return CommandResult.success(message=f"Switched to {provider_name} provider")
except Exception as e:
return CommandResult.error(f"Failed to switch provider: {e}")
class PasteCommand(Command):
"""Paste from clipboard."""
@@ -1468,6 +1670,7 @@ def register_all_commands() -> None:
DeleteCommand(),
InfoCommand(),
MCPCommand(),
ProviderCommand(),
PasteCommand(),
ModelCommand(),
]

View File

@@ -80,6 +80,7 @@ class CommandContext:
online_enabled: Whether online mode is enabled
session_tokens: Session token counts
session_cost: Session cost total
session: Reference to the ChatSession (for provider switching, etc.)
"""
settings: Optional["Settings"] = None
@@ -100,6 +101,7 @@ class CommandContext:
message_count: int = 0
is_tui: bool = False # Flag for TUI mode
current_index: int = 0
session: Optional[Any] = None # Reference to ChatSession (avoid circular import)
@dataclass

View File

@@ -6,8 +6,9 @@ configuration with type safety, validation, and persistence.
"""
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Dict
from pathlib import Path
import json
from oai.constants import (
DEFAULT_BASE_URL,
@@ -20,6 +21,8 @@ from oai.constants import (
DEFAULT_LOG_LEVEL,
DEFAULT_SYSTEM_PROMPT,
VALID_LOG_LEVELS,
DEFAULT_PROVIDER,
OLLAMA_DEFAULT_URL,
)
from oai.config.database import get_database
@@ -34,7 +37,7 @@ class Settings:
initialization and can be persisted back.
Attributes:
api_key: OpenRouter API key
api_key: Legacy OpenRouter API key (deprecated, use openrouter_api_key)
base_url: API base URL
default_model: Default model ID to use
default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
@@ -45,9 +48,33 @@ class Settings:
log_max_size_mb: Maximum log file size in MB
log_backup_count: Number of log file backups to keep
log_level: Logging level (debug/info/warning/error/critical)
# Provider-specific settings
default_provider: Default AI provider to use
openrouter_api_key: OpenRouter API key
anthropic_api_key: Anthropic API key
openai_api_key: OpenAI API key
ollama_base_url: Ollama server URL
"""
# Legacy field (kept for backward compatibility)
api_key: Optional[str] = None
# Provider configuration
default_provider: str = DEFAULT_PROVIDER
openrouter_api_key: Optional[str] = None
anthropic_api_key: Optional[str] = None
openai_api_key: Optional[str] = None
ollama_base_url: str = OLLAMA_DEFAULT_URL
provider_models: Dict[str, str] = field(default_factory=dict) # provider -> last_model_id
# Web search configuration (for online mode with non-OpenRouter providers)
search_provider: str = "duckduckgo" # "anthropic_native", "duckduckgo", or "google"
google_api_key: Optional[str] = None
google_search_engine_id: Optional[str] = None
search_num_results: int = 5
# General settings
base_url: str = DEFAULT_BASE_URL
default_model: Optional[str] = None
default_system_prompt: Optional[str] = None
@@ -134,8 +161,43 @@ class Settings:
# Get system prompt from DB: None means not set (use default), "" means explicitly blank
system_prompt_value = db.get_config("default_system_prompt")
# Migration: copy legacy api_key to openrouter_api_key if not already set
legacy_api_key = db.get_config("api_key")
openrouter_key = db.get_config("openrouter_api_key")
if legacy_api_key and not openrouter_key:
db.set_config("openrouter_api_key", legacy_api_key)
openrouter_key = legacy_api_key
# Note: We keep the legacy api_key in DB for backward compatibility
# Load provider-model mapping
provider_models_json = db.get_config("provider_models")
provider_models = {}
if provider_models_json:
try:
provider_models = json.loads(provider_models_json)
except json.JSONDecodeError:
provider_models = {}
return cls(
api_key=db.get_config("api_key"),
# Legacy field
api_key=legacy_api_key,
# Provider configuration
default_provider=db.get_config("default_provider") or DEFAULT_PROVIDER,
openrouter_api_key=openrouter_key,
anthropic_api_key=db.get_config("anthropic_api_key"),
openai_api_key=db.get_config("openai_api_key"),
ollama_base_url=db.get_config("ollama_base_url") or OLLAMA_DEFAULT_URL,
provider_models=provider_models,
# Web search configuration
search_provider=db.get_config("search_provider") or "duckduckgo",
google_api_key=db.get_config("google_api_key"),
google_search_engine_id=db.get_config("google_search_engine_id"),
search_num_results=parse_int(db.get_config("search_num_results"), 5),
# General settings
base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
default_model=db.get_config("default_model"),
default_system_prompt=system_prompt_value,
@@ -331,6 +393,155 @@ class Settings:
self.log_max_size_mb = min(size_mb, 100)
get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
def set_provider_api_key(self, provider: str, api_key: str) -> None:
"""
Set and persist an API key for a specific provider.
Args:
provider: Provider name ("openrouter", "anthropic", "openai")
api_key: The API key to set
Raises:
ValueError: If provider is invalid
"""
provider = provider.lower()
api_key = api_key.strip()
if provider == "openrouter":
self.openrouter_api_key = api_key
get_database().set_config("openrouter_api_key", api_key)
elif provider == "anthropic":
self.anthropic_api_key = api_key
get_database().set_config("anthropic_api_key", api_key)
elif provider == "openai":
self.openai_api_key = api_key
get_database().set_config("openai_api_key", api_key)
else:
raise ValueError(f"Invalid provider: {provider}")
def get_provider_api_key(self, provider: str) -> Optional[str]:
"""
Get the API key for a specific provider.
Args:
provider: Provider name ("openrouter", "anthropic", "openai", "ollama")
Returns:
API key or None if not set
Raises:
ValueError: If provider is invalid
"""
provider = provider.lower()
if provider == "openrouter":
return self.openrouter_api_key
elif provider == "anthropic":
return self.anthropic_api_key
elif provider == "openai":
return self.openai_api_key
elif provider == "ollama":
return "" # Ollama doesn't require an API key
else:
raise ValueError(f"Invalid provider: {provider}")
def set_default_provider(self, provider: str) -> None:
"""
Set and persist the default provider.
Args:
provider: Provider name
Raises:
ValueError: If provider is invalid
"""
from oai.constants import VALID_PROVIDERS
provider = provider.lower()
if provider not in VALID_PROVIDERS:
raise ValueError(
f"Invalid provider: {provider}. "
f"Valid providers: {', '.join(VALID_PROVIDERS)}"
)
self.default_provider = provider
get_database().set_config("default_provider", provider)
def set_ollama_base_url(self, url: str) -> None:
"""
Set and persist the Ollama base URL.
Args:
url: Ollama server URL
"""
self.ollama_base_url = url.strip()
get_database().set_config("ollama_base_url", self.ollama_base_url)
def set_search_provider(self, provider: str) -> None:
"""
Set and persist the web search provider.
Args:
provider: Search provider ("anthropic_native", "duckduckgo", "google")
Raises:
ValueError: If provider is invalid
"""
valid_providers = ["anthropic_native", "duckduckgo", "google"]
provider = provider.lower()
if provider not in valid_providers:
raise ValueError(
f"Invalid search provider: {provider}. "
f"Valid providers: {', '.join(valid_providers)}"
)
self.search_provider = provider
get_database().set_config("search_provider", provider)
def set_google_api_key(self, api_key: str) -> None:
"""
Set and persist the Google API key for Google Custom Search.
Args:
api_key: The Google API key
"""
self.google_api_key = api_key.strip()
get_database().set_config("google_api_key", self.google_api_key)
def set_google_search_engine_id(self, engine_id: str) -> None:
"""
Set and persist the Google Custom Search Engine ID.
Args:
engine_id: The Google Search Engine ID
"""
self.google_search_engine_id = engine_id.strip()
get_database().set_config("google_search_engine_id", self.google_search_engine_id)
def get_provider_model(self, provider: str) -> Optional[str]:
"""
Get the last used model for a provider.
Args:
provider: Provider name
Returns:
Model ID or None if not set
"""
return self.provider_models.get(provider)
def set_provider_model(self, provider: str, model_id: str) -> None:
"""
Set and persist the last used model for a provider.
Args:
provider: Provider name
model_id: Model ID to remember
"""
self.provider_models[provider] = model_id
# Save to database as JSON
get_database().set_config("provider_models", json.dumps(self.provider_models))
# Global settings instance
_settings: Optional[Settings] = None
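A quick sketch of how the new per-provider settings round-trip (API exactly as defined above; the key and model strings are placeholders):
```python
# Hypothetical usage of the per-provider settings API from this diff;
# key/model strings are placeholders.
settings = Settings.load()
settings.set_default_provider("anthropic")
settings.set_provider_api_key("anthropic", "sk-ant-...")
settings.set_provider_model("anthropic", "claude-sonnet-4-5-20250929")

assert settings.get_provider_api_key("anthropic") == "sk-ant-..."
assert settings.get_provider_model("anthropic") == "claude-sonnet-4-5-20250929"
```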

View File

@@ -20,7 +20,7 @@ from oai import __version__
APP_NAME = "oAI"
APP_VERSION = __version__ # Single source of truth in oai/__init__.py
APP_URL = "https://iurl.no/oai"
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"
APP_DESCRIPTION = "Open AI Chat Client with Multi-Provider Support"
# =============================================================================
# FILE PATHS
@@ -42,6 +42,26 @@ DEFAULT_STREAM_ENABLED = True
DEFAULT_MAX_TOKENS = 100_000
DEFAULT_ONLINE_MODE = False
# =============================================================================
# PROVIDER CONFIGURATION
# =============================================================================
# Provider names
PROVIDER_OPENROUTER = "openrouter"
PROVIDER_ANTHROPIC = "anthropic"
PROVIDER_OPENAI = "openai"
PROVIDER_OLLAMA = "ollama"
VALID_PROVIDERS = [PROVIDER_OPENROUTER, PROVIDER_ANTHROPIC, PROVIDER_OPENAI, PROVIDER_OLLAMA]
# Provider base URLs
ANTHROPIC_BASE_URL = "https://api.anthropic.com"
OPENAI_BASE_URL = "https://api.openai.com/v1"
OLLAMA_DEFAULT_URL = "http://localhost:11434"
# Default provider
DEFAULT_PROVIDER = PROVIDER_OPENROUTER
# =============================================================================
# DEFAULT SYSTEM PROMPT
# =============================================================================

View File

@@ -9,7 +9,7 @@ import asyncio
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING, OLLAMA_DEFAULT_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
@@ -19,7 +19,6 @@ from oai.providers.base import (
ToolCall,
UsageStats,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.utils.logging import get_logger
@@ -32,37 +31,66 @@ class AIClient:
Attributes:
provider: The underlying AI provider
provider_name: Name of the current provider
default_model: Default model ID to use
http_headers: Custom HTTP headers for requests
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
provider_class: type = OpenRouterProvider,
provider_name: str = "openrouter",
provider_api_keys: Optional[Dict[str, str]] = None,
ollama_base_url: str = OLLAMA_DEFAULT_URL,
app_name: str = APP_NAME,
app_url: str = APP_URL,
):
"""
Initialize the AI client.
Initialize the AI client with specified provider.
Args:
api_key: API key for authentication
base_url: Optional custom base URL
provider_class: Provider class to use (default: OpenRouterProvider)
provider_name: Provider to use ("openrouter", "anthropic", "openai", "ollama")
provider_api_keys: Dict mapping provider names to API keys
ollama_base_url: Base URL for Ollama server
app_name: Application name for headers
app_url: Application URL for headers
Raises:
ValueError: If provider is invalid or not configured
"""
self.provider: AIProvider = provider_class(
api_key=api_key,
base_url=base_url,
app_name=app_name,
app_url=app_url,
)
from oai.providers.registry import get_provider_class
self.provider_name = provider_name
self.provider_api_keys = provider_api_keys or {}
self.ollama_base_url = ollama_base_url
self.app_name = app_name
self.app_url = app_url
# Get provider class
provider_class = get_provider_class(provider_name)
if not provider_class:
raise ValueError(f"Unknown provider: {provider_name}")
# Get API key for this provider
api_key = self.provider_api_keys.get(provider_name, "")
# Initialize provider with appropriate parameters
if provider_name == "ollama":
self.provider: AIProvider = provider_class(
api_key=api_key,
base_url=ollama_base_url,
)
else:
self.provider: AIProvider = provider_class(
api_key=api_key,
app_name=app_name,
app_url=app_url,
)
self.default_model: Optional[str] = None
self.logger = get_logger()
self.logger.info(f"Initialized {provider_name} provider")
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
Get available models.
@@ -420,3 +448,49 @@ class AIClient:
"""Clear the provider's model cache."""
if hasattr(self.provider, "clear_cache"):
self.provider.clear_cache()
def switch_provider(self, provider_name: str, ollama_base_url: Optional[str] = None) -> None:
"""
Switch to a different provider.
Args:
provider_name: Provider to switch to
ollama_base_url: Optional Ollama base URL (if switching to Ollama)
Raises:
ValueError: If provider is invalid or not configured
"""
from oai.providers.registry import get_provider_class
# Get provider class
provider_class = get_provider_class(provider_name)
if not provider_class:
raise ValueError(f"Unknown provider: {provider_name}")
# Get API key
api_key = self.provider_api_keys.get(provider_name, "")
# Check API key requirement
if provider_name != "ollama" and not api_key:
raise ValueError(f"No API key configured for {provider_name}")
# Initialize new provider
if provider_name == "ollama":
base_url = ollama_base_url or self.ollama_base_url
self.provider = provider_class(
api_key=api_key,
base_url=base_url,
)
self.ollama_base_url = base_url
else:
self.provider = provider_class(
api_key=api_key,
app_name=self.app_name,
app_url=self.app_url,
)
self.provider_name = provider_name
self.logger.info(f"Switched to {provider_name} provider")
# Clear model cache when switching providers
self.default_model = None
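Taken together, a client can now be constructed and re-pointed like this (a sketch using only the parameters shown above; the keys are placeholders):
```python
# Hypothetical: construct a multi-provider client and switch mid-session,
# mirroring the CLI launch code and switch_provider() above.
client = AIClient(
    provider_name="openrouter",
    provider_api_keys={"openrouter": "sk-or-...", "anthropic": "sk-ant-..."},
    ollama_base_url="http://localhost:11434",
)
client.switch_provider("anthropic")  # raises ValueError if no key is configured
client.switch_provider("ollama")     # no API key required for Ollama
```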

View File

@@ -23,6 +23,9 @@ from oai.core.client import AIClient
from oai.mcp.manager import MCPManager
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
from oai.utils.logging import get_logger
from oai.utils.web_search import perform_web_search, format_search_results
logger = get_logger()
@dataclass
@@ -164,6 +167,7 @@ class ChatSession:
total_cost=self.stats.total_cost,
message_count=self.stats.message_count,
current_index=self.current_index,
session=self,
)
def set_model(self, model: Dict[str, Any]) -> None:
@@ -290,6 +294,44 @@ class ChatSession:
# Build request parameters
model_id = self.selected_model["id"]
# Handle online mode
enable_web_search = False
web_search_config = {}
if self.online_enabled:
# OpenRouter handles online mode natively with :online suffix
if self.client.provider_name == "openrouter":
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
# Anthropic has native web search when search provider is set to anthropic_native
elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
enable_web_search = True
web_search_config = {
"max_uses": self.settings.search_num_results or 5
}
logger.info("Using Anthropic native web search")
else:
# For other providers, perform web search and inject results
logger.info(f"Performing web search for: {user_input}")
search_results = perform_web_search(
user_input,
num_results=self.settings.search_num_results,
provider=self.settings.search_provider,
google_api_key=self.settings.google_api_key,
google_search_engine_id=self.settings.google_search_engine_id
)
if search_results:
# Inject search results into messages
formatted_results = format_search_results(search_results)
search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
# Add search results to the last user message
if messages and messages[-1]["role"] == "user":
messages[-1]["content"] += search_context
logger.info(f"Injected {len(search_results)} search results into context")
if self.online_enabled:
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
@@ -320,6 +362,8 @@ class ChatSession:
max_tokens=max_tokens,
transforms=transforms,
on_chunk=on_stream_chunk,
enable_web_search=enable_web_search,
web_search_config=web_search_config,
)
response_time = time.time() - start_time
return full_text, usage, response_time
@@ -455,6 +499,8 @@ class ChatSession:
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
on_chunk: Optional[Callable[[str], None]] = None,
enable_web_search: bool = False,
web_search_config: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Optional[UsageStats]]:
"""
Stream a response with live display.
@@ -465,6 +511,8 @@ class ChatSession:
max_tokens: Max tokens
transforms: Transforms
on_chunk: Callback for chunks
enable_web_search: Whether to enable Anthropic native web search
web_search_config: Web search configuration
Returns:
Tuple of (full_text, usage)
@@ -475,6 +523,8 @@ class ChatSession:
stream=True,
max_tokens=max_tokens,
transforms=transforms,
enable_web_search=enable_web_search,
web_search_config=web_search_config or {},
)
if isinstance(response, ChatResponse):
@@ -530,10 +580,45 @@ class ChatSession:
# Disable streaming when tools are present
stream = False
# Handle online mode
model_id = self.selected_model["id"]
enable_web_search = False
web_search_config = {}
if self.online_enabled:
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
# OpenRouter handles online mode natively with :online suffix
if self.client.provider_name == "openrouter":
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
# Anthropic has native web search when search provider is set to anthropic_native
elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
enable_web_search = True
web_search_config = {
"max_uses": self.settings.search_num_results or 5
}
logger.info("Using Anthropic native web search")
else:
# For other providers, perform web search and inject results
logger.info(f"Performing web search for: {user_input}")
search_results = await asyncio.to_thread(
perform_web_search,
user_input,
num_results=self.settings.search_num_results,
provider=self.settings.search_provider,
google_api_key=self.settings.google_api_key,
google_search_engine_id=self.settings.google_search_engine_id
)
if search_results:
# Inject search results into messages
formatted_results = format_search_results(search_results)
search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
# Add search results to the last user message
if messages and messages[-1]["role"] == "user":
messages[-1]["content"] += search_context
logger.info(f"Injected {len(search_results)} search results into context")
transforms = ["middle-out"] if self.middle_out_enabled else None
max_tokens = None
@@ -557,6 +642,8 @@ class ChatSession:
model_id=model_id,
max_tokens=max_tokens,
transforms=transforms,
enable_web_search=enable_web_search,
web_search_config=web_search_config,
):
yield chunk
else:
@@ -733,6 +820,8 @@ class ChatSession:
model_id: str,
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
enable_web_search: bool = False,
web_search_config: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[StreamChunk]:
"""
Async version of _stream_response for TUI.
@@ -742,6 +831,8 @@ class ChatSession:
model_id: Model ID
max_tokens: Max tokens
transforms: Transforms
enable_web_search: Whether to enable Anthropic native web search
web_search_config: Web search configuration
Yields:
StreamChunk objects
@@ -752,6 +843,8 @@ class ChatSession:
stream=True,
max_tokens=max_tokens,
transforms=transforms,
enable_web_search=enable_web_search,
web_search_config=web_search_config or {},
)
if isinstance(response, ChatResponse):
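The `oai.utils.web_search` helpers imported at the top of this file are not part of this diff; a minimal sketch consistent with the call sites above, assuming the `duckduckgo_search` package for the default branch:
```python
# Hypothetical sketch of oai/utils/web_search.py (module not shown in this
# diff); matches the perform_web_search/format_search_results call sites.
from typing import Dict, List, Optional


def perform_web_search(
    query: str,
    num_results: int = 5,
    provider: str = "duckduckgo",
    google_api_key: Optional[str] = None,
    google_search_engine_id: Optional[str] = None,
) -> List[Dict[str, str]]:
    """Return a list of {title, url, snippet} dicts, or [] on failure."""
    if provider == "duckduckgo":
        from duckduckgo_search import DDGS  # assumed dependency
        with DDGS() as ddgs:
            return [
                {"title": r["title"], "url": r["href"], "snippet": r["body"]}
                for r in ddgs.text(query, max_results=num_results)
            ]
    # Google Custom Search branch omitted in this sketch
    return []


def format_search_results(results: List[Dict[str, str]]) -> str:
    """Format results as a plain-text block for context injection."""
    lines = ["Web search results:"]
    for i, r in enumerate(results, 1):
        lines.append(f"{i}. {r['title']} ({r['url']})\n   {r['snippet']}")
    return "\n".join(lines)
```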

View File

@@ -16,6 +16,16 @@ from oai.providers.base import (
ProviderCapabilities,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.providers.anthropic import AnthropicProvider
from oai.providers.openai import OpenAIProvider
from oai.providers.ollama import OllamaProvider
from oai.providers.registry import register_provider
# Register all providers
register_provider("openrouter", OpenRouterProvider)
register_provider("anthropic", AnthropicProvider)
register_provider("openai", OpenAIProvider)
register_provider("ollama", OllamaProvider)
__all__ = [
# Base classes and types
@@ -29,4 +39,7 @@ __all__ = [
"ProviderCapabilities",
# Provider implementations
"OpenRouterProvider",
"AnthropicProvider",
"OpenAIProvider",
"OllamaProvider",
]
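The `oai.providers.registry` module used here and in `AIClient` is not included in this diff; a minimal sketch consistent with `register_provider`/`get_provider_class` as called above:
```python
# Hypothetical sketch of oai/providers/registry.py (module not shown in
# this diff); get_provider_class returns None for unknown names, matching
# the checks in AIClient.
from typing import Dict, Optional, Type

from oai.providers.base import AIProvider

_REGISTRY: Dict[str, Type[AIProvider]] = {}


def register_provider(name: str, provider_class: Type[AIProvider]) -> None:
    """Register a provider class under a lowercase name."""
    _REGISTRY[name.lower()] = provider_class


def get_provider_class(name: str) -> Optional[Type[AIProvider]]:
    """Look up a registered provider class, or None if unknown."""
    return _REGISTRY.get(name.lower())
```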

673
oai/providers/anthropic.py Normal file
View File

@@ -0,0 +1,673 @@
"""
Anthropic provider for Claude models.
This provider connects to Anthropic's API for accessing Claude models.
"""
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import anthropic
from anthropic.types import Message, MessageStreamEvent
from oai.constants import ANTHROPIC_BASE_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
ToolCall,
ToolFunction,
UsageStats,
)
from oai.utils.logging import get_logger
logger = get_logger()
# Model name aliases
MODEL_ALIASES = {
"claude-sonnet": "claude-sonnet-4-5-20250929",
"claude-haiku": "claude-haiku-4-5-20251001",
"claude-opus": "claude-opus-4-5-20251101",
# Legacy aliases
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-3-7-sonnet": "claude-3-7-sonnet-20250219",
}
class AnthropicProvider(AIProvider):
"""
Anthropic API provider.
Provides access to Claude Sonnet 4.5, Claude Haiku 4.5, Claude Opus 4.5, and legacy Claude models.
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: str = "oAI",
app_url: str = "",
**kwargs: Any,
):
"""
Initialize Anthropic provider.
Args:
api_key: Anthropic API key
base_url: Optional custom base URL
app_name: Application name (for headers)
app_url: Application URL (for headers)
**kwargs: Additional arguments
"""
super().__init__(api_key, base_url or ANTHROPIC_BASE_URL)
self.client = anthropic.Anthropic(api_key=api_key)
self.async_client = anthropic.AsyncAnthropic(api_key=api_key)
self._models_cache: Optional[List[ModelInfo]] = None
def _create_web_search_tool(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""
Create Anthropic native web search tool definition.
Args:
config: Optional configuration for web search (max_uses, allowed_domains, etc.)
Returns:
Tool definition dict
"""
tool: Dict[str, Any] = {
"type": "web_search_20250305",
"name": "web_search",
}
# Add optional parameters if provided
if "max_uses" in config:
tool["max_uses"] = config["max_uses"]
else:
tool["max_uses"] = 5 # Default
if "allowed_domains" in config:
tool["allowed_domains"] = config["allowed_domains"]
if "blocked_domains" in config:
tool["blocked_domains"] = config["blocked_domains"]
if "user_location" in config:
tool["user_location"] = config["user_location"]
return tool
@property
def name(self) -> str:
"""Get provider name."""
return "Anthropic"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=True,
images=True,
online=True, # Native web search, or DuckDuckGo/Google injection
max_context=200000,
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
List available Anthropic models.
Args:
filter_text_only: Whether to filter for text models only
Returns:
List of ModelInfo objects
"""
if self._models_cache:
return self._models_cache
# Anthropic doesn't have a models list API, so we hardcode the available models
models = [
# Current Claude 4.5 models
ModelInfo(
id="claude-sonnet-4-5-20250929",
name="Claude Sonnet 4.5",
description="Smart model for complex agents and coding (recommended)",
context_length=200000,
pricing={"input": 3.0, "output": 15.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="claude-haiku-4-5-20251001",
name="Claude Haiku 4.5",
description="Fastest model with near-frontier intelligence",
context_length=200000,
pricing={"input": 1.0, "output": 5.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="claude-opus-4-5-20251101",
name="Claude Opus 4.5",
description="Premium model with maximum intelligence",
context_length=200000,
pricing={"input": 5.0, "output": 25.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
# Legacy models (still available)
ModelInfo(
id="claude-3-7-sonnet-20250219",
name="Claude Sonnet 3.7",
description="Legacy model - recommend migrating to 4.5",
context_length=200000,
pricing={"input": 3.0, "output": 15.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="claude-3-haiku-20240307",
name="Claude 3 Haiku",
description="Legacy fast model - recommend migrating to 4.5",
context_length=200000,
pricing={"input": 0.25, "output": 1.25},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
]
self._models_cache = models
logger.info(f"Loaded {len(models)} Anthropic models")
return models
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None
"""
# Resolve alias
resolved_id = MODEL_ALIASES.get(model_id, model_id)
models = self.list_models()
for model in models:
if model.id == resolved_id or model.id == model_id:
return model
return None
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send chat completion request to Anthropic.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream response
max_tokens: Maximum tokens
temperature: Sampling temperature
tools: Tool definitions
tool_choice: Tool selection mode
**kwargs: Additional parameters (including enable_web_search)
Returns:
ChatResponse or Iterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Extract system message (Anthropic requires it separate from messages)
system_prompt, anthropic_messages = self._convert_messages(messages)
# Build request parameters
params: Dict[str, Any] = {
"model": model_id,
"messages": anthropic_messages,
"max_tokens": max_tokens or 4096,
}
if system_prompt:
params["system"] = system_prompt
if temperature is not None:
params["temperature"] = temperature
# Prepare tools list
tools_list = []
# Add web search tool if requested via kwargs
if kwargs.get("enable_web_search", False):
web_search_config = kwargs.get("web_search_config", {})
tools_list.append(self._create_web_search_tool(web_search_config))
logger.info("Added Anthropic native web search tool")
# Add user-provided tools
if tools:
# Convert tools to Anthropic format
converted_tools = self._convert_tools(tools)
tools_list.extend(converted_tools)
if tools_list:
params["tools"] = tools_list
if tool_choice and tool_choice != "auto":
# Anthropic uses different tool_choice format
if tool_choice == "none":
pass # Don't include tools
elif tool_choice == "required":
params["tool_choice"] = {"type": "any"}
else:
params["tool_choice"] = {"type": "tool", "name": tool_choice}
logger.debug(f"Anthropic request: model={model_id}, messages={len(anthropic_messages)}")
try:
if stream:
return self._stream_chat(params)
else:
return self._sync_chat(params)
except Exception as e:
logger.error(f"Anthropic request failed: {e}")
return ChatResponse(
id="error",
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
finish_reason="error",
)
],
)
def _convert_messages(self, messages: List[ChatMessage]) -> tuple[str, List[Dict[str, Any]]]:
"""
Convert messages to Anthropic format.
Anthropic requires system messages to be separate from the conversation.
Args:
messages: List of ChatMessage objects
Returns:
Tuple of (system_prompt, anthropic_messages)
"""
system_prompt = ""
anthropic_messages = []
for msg in messages:
if msg.role == "system":
# Accumulate system messages
if system_prompt:
system_prompt += "\n\n"
system_prompt += msg.content or ""
else:
# Convert to Anthropic format
message_dict: Dict[str, Any] = {"role": msg.role}
# Handle content
if msg.content:
message_dict["content"] = msg.content
# Handle tool calls (assistant messages)
if msg.tool_calls:
# Anthropic format for tool use
content_blocks = []
if msg.content:
content_blocks.append({"type": "text", "text": msg.content})
for tc in msg.tool_calls:
content_blocks.append({
"type": "tool_use",
"id": tc.id,
"name": tc.function.name,
"input": json.loads(tc.function.arguments),
})
message_dict["content"] = content_blocks
# Handle tool results (tool messages)
if msg.role == "tool" and msg.tool_call_id:
# Convert to Anthropic's tool_result format
anthropic_messages.append({
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": msg.tool_call_id,
"content": msg.content or "",
}]
})
continue
anthropic_messages.append(message_dict)
return system_prompt, anthropic_messages
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Convert OpenAI-style tools to Anthropic format.
Args:
tools: OpenAI tool definitions
Returns:
Anthropic tool definitions
"""
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
anthropic_tools.append({
"name": func.get("name"),
"description": func.get("description", ""),
"input_schema": func.get("parameters", {}),
})
return anthropic_tools
def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
"""
Send synchronous chat request.
Args:
params: Request parameters
Returns:
ChatResponse
"""
message: Message = self.client.messages.create(**params)
# Extract content
content = ""
tool_calls = []
for block in message.content:
if block.type == "text":
content += block.text
elif block.type == "tool_use":
# Convert to ToolCall format
tool_calls.append(
ToolCall(
id=block.id,
type="function",
function=ToolFunction(
name=block.name,
arguments=json.dumps(block.input),
),
)
)
# Build ChatMessage
chat_message = ChatMessage(
role="assistant",
content=content if content else None,
tool_calls=tool_calls if tool_calls else None,
)
# Extract usage
usage = None
if message.usage:
usage = UsageStats(
prompt_tokens=message.usage.input_tokens,
completion_tokens=message.usage.output_tokens,
total_tokens=message.usage.input_tokens + message.usage.output_tokens,
)
return ChatResponse(
id=message.id,
choices=[
ChatResponseChoice(
index=0,
message=chat_message,
finish_reason=message.stop_reason,
)
],
usage=usage,
model=message.model,
)
def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
"""
Stream chat response from Anthropic.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
stream = self.client.messages.stream(**params)
with stream as event_stream:
for event in event_stream:
event_data: MessageStreamEvent = event
# Handle different event types
if event_data.type == "content_block_delta":
delta = event_data.delta
if hasattr(delta, "text"):
yield StreamChunk(
id="stream",
delta_content=delta.text,
)
elif event_data.type == "message_stop":
# Final event with usage
pass
elif event_data.type == "message_delta":
# Contains stop reason and usage
usage = None
if hasattr(event_data, "usage"):
usage_data = event_data.usage
usage = UsageStats(
prompt_tokens=0,
completion_tokens=usage_data.output_tokens,
total_tokens=usage_data.output_tokens,
)
yield StreamChunk(
id="stream",
finish_reason=event_data.delta.stop_reason if hasattr(event_data.delta, "stop_reason") else None,
usage=usage,
)
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send async chat request to Anthropic.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream
max_tokens: Max tokens
temperature: Temperature
tools: Tool definitions
tool_choice: Tool choice
**kwargs: Additional args
Returns:
ChatResponse or AsyncIterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Convert messages
system_prompt, anthropic_messages = self._convert_messages(messages)
# Build params
params: Dict[str, Any] = {
"model": model_id,
"messages": anthropic_messages,
"max_tokens": max_tokens or 4096,
}
if system_prompt:
params["system"] = system_prompt
if temperature is not None:
params["temperature"] = temperature
if tools:
params["tools"] = self._convert_tools(tools)
if stream:
return self._stream_chat_async(params)
else:
message = await self.async_client.messages.create(**params)
return self._convert_message(message)
async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
"""
Stream async chat response.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
stream = self.async_client.messages.stream(**params)
async with stream as event_stream:
async for event in event_stream:
if event.type == "content_block_delta":
delta = event.delta
if hasattr(delta, "text"):
yield StreamChunk(
id="stream",
delta_content=delta.text,
)
def _convert_message(self, message: Message) -> ChatResponse:
"""Helper to convert Anthropic message to ChatResponse."""
content = ""
tool_calls = []
for block in message.content:
if block.type == "text":
content += block.text
elif block.type == "tool_use":
tool_calls.append(
ToolCall(
id=block.id,
type="function",
function=ToolFunction(
name=block.name,
arguments=json.dumps(block.input),
),
)
)
chat_message = ChatMessage(
role="assistant",
content=content if content else None,
tool_calls=tool_calls if tool_calls else None,
)
usage = None
if message.usage:
usage = UsageStats(
prompt_tokens=message.usage.input_tokens,
completion_tokens=message.usage.output_tokens,
total_tokens=message.usage.input_tokens + message.usage.output_tokens,
)
return ChatResponse(
id=message.id,
choices=[
ChatResponseChoice(
index=0,
message=chat_message,
finish_reason=message.stop_reason,
)
],
usage=usage,
model=message.model,
)
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credits from Anthropic.
Note: Anthropic does not currently provide a public API endpoint
for checking account credits/balance. This information is only
available through the Anthropic Console web interface.
Returns:
None (credits API not available)
"""
# Anthropic doesn't provide a public credits API endpoint
# Users must check their balance at console.anthropic.com
return None
def clear_cache(self) -> None:
"""Clear model cache."""
self._models_cache = None
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as dictionaries.
Returns:
List of model dictionaries
"""
models = self.list_models()
return [
{
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
for model in models
]
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: Model identifier
Returns:
Model dictionary or None
"""
model = self.get_model(model_id)
if model:
return {
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
return None
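A sketch of how the native web search path is exercised end to end (kwargs as inspected in `chat()` above; the API key is a placeholder):
```python
# Hypothetical: enable Anthropic native web search via the kwargs that
# chat() handles above; the API key is a placeholder.
provider = AnthropicProvider(api_key="sk-ant-...")
response = provider.chat(
    model="claude-sonnet",  # alias, resolved via MODEL_ALIASES
    messages=[ChatMessage(role="user", content="Any Claude news this week?")],
    enable_web_search=True,
    web_search_config={"max_uses": 3},
)
```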

423
oai/providers/ollama.py Normal file
View File

@@ -0,0 +1,423 @@
"""
Ollama provider for local AI model serving.
This provider connects to a local Ollama server for running models
locally without API keys or external dependencies.
"""
import json
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import requests
from oai.constants import OLLAMA_DEFAULT_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
UsageStats,
)
from oai.utils.logging import get_logger
logger = get_logger()
class OllamaProvider(AIProvider):
"""
Ollama local model provider.
Connects to a local Ollama server for running models locally.
No API key required.
"""
def __init__(
self,
api_key: str = "",
base_url: Optional[str] = None,
**kwargs: Any,
):
"""
Initialize Ollama provider.
Args:
api_key: Not used (Ollama doesn't require API keys)
base_url: Ollama server URL (default: http://localhost:11434)
**kwargs: Additional arguments (ignored)
"""
super().__init__(api_key or "", base_url)
self.base_url = base_url or OLLAMA_DEFAULT_URL
self._check_server_available()
def _check_server_available(self) -> bool:
"""
Check if Ollama server is accessible.
Returns:
True if server is accessible
"""
try:
response = requests.get(f"{self.base_url}/api/tags", timeout=2)
if response.ok:
logger.info(f"Ollama server accessible at {self.base_url}")
return True
else:
logger.warning(f"Ollama server returned status {response.status_code}")
return False
except requests.RequestException as e:
logger.warning(f"Ollama server not accessible at {self.base_url}: {e}")
return False
@property
def name(self) -> str:
"""Get provider name."""
return "Ollama"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=False, # Tool support varies by model
images=False, # Image support varies by model
online=True, # Web search via DuckDuckGo/Google
max_context=8192, # Varies by model
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
List models from local Ollama installation.
Args:
filter_text_only: Ignored for Ollama
Returns:
List of available models
"""
try:
response = requests.get(f"{self.base_url}/api/tags", timeout=5)
response.raise_for_status()
data = response.json()
models = []
for model_data in data.get("models", []):
models.append(self._parse_model(model_data))
logger.info(f"Found {len(models)} Ollama models")
return models
except requests.RequestException as e:
logger.error(f"Failed to list Ollama models: {e}")
return []
def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
"""
Parse Ollama model data into ModelInfo.
Args:
model_data: Raw model data from Ollama API
Returns:
ModelInfo object
"""
model_name = model_data.get("name", "unknown")
size_bytes = model_data.get("size", 0)
size_gb = size_bytes / (1024 ** 3) if size_bytes else 0
return ModelInfo(
id=model_name,
name=model_name,
description=f"Size: {size_gb:.1f}GB",
context_length=8192, # Default, varies by model
pricing={}, # Local models are free
supported_parameters=["stream", "temperature", "max_tokens"],
)
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None if not found
"""
models = self.list_models()
for model in models:
if model.id == model_id:
return model
return None
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send chat request to Ollama.
Args:
model: Model name
messages: Chat messages
stream: Whether to stream response
max_tokens: Maximum tokens (Ollama calls this num_predict)
temperature: Sampling temperature
tools: Not supported
tool_choice: Not supported
**kwargs: Additional parameters
Returns:
ChatResponse or Iterator[StreamChunk]
"""
# Convert messages to Ollama format
ollama_messages = []
for msg in messages:
ollama_messages.append({
"role": msg.role,
"content": msg.content or "",
})
# Build request payload
payload: Dict[str, Any] = {
"model": model,
"messages": ollama_messages,
"stream": stream,
}
# Add optional parameters
options = {}
if temperature is not None:
options["temperature"] = temperature
if max_tokens is not None:
options["num_predict"] = max_tokens
if options:
payload["options"] = options
logger.debug(f"Ollama request: model={model}, messages={len(ollama_messages)}")
try:
if stream:
return self._stream_chat(payload)
else:
return self._sync_chat(payload)
except requests.RequestException as e:
logger.error(f"Ollama request failed: {e}")
# Return error response
return ChatResponse(
id="error",
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
content=f"Error: {str(e)}",
),
finish_reason="error",
)
],
)
def _sync_chat(self, payload: Dict[str, Any]) -> ChatResponse:
"""
Send synchronous chat request.
Args:
payload: Request payload
Returns:
ChatResponse
"""
response = requests.post(
f"{self.base_url}/api/chat",
json=payload,
timeout=120,
)
response.raise_for_status()
data = response.json()
# Parse response
message_data = data.get("message", {})
content = message_data.get("content", "")
# Extract token usage if available
usage = None
if "prompt_eval_count" in data or "eval_count" in data:
usage = UsageStats(
prompt_tokens=data.get("prompt_eval_count", 0),
completion_tokens=data.get("eval_count", 0),
total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
total_cost_usd=0.0, # Local models are free
)
return ChatResponse(
id=str(time.time()),
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=content),
finish_reason="stop",
)
],
usage=usage,
model=data.get("model"),
)
def _stream_chat(self, payload: Dict[str, Any]) -> Iterator[StreamChunk]:
"""
Stream chat response from Ollama.
Args:
payload: Request payload
Yields:
StreamChunk objects
"""
response = requests.post(
f"{self.base_url}/api/chat",
json=payload,
stream=True,
timeout=120,
)
response.raise_for_status()
total_prompt_tokens = 0
total_completion_tokens = 0
for line in response.iter_lines():
if not line:
continue
try:
data = json.loads(line)
# Extract content delta
message_data = data.get("message", {})
content = message_data.get("content", "")
# Check if done
done = data.get("done", False)
finish_reason = "stop" if done else None
# Extract usage if available
usage = None
if done and ("prompt_eval_count" in data or "eval_count" in data):
total_prompt_tokens = data.get("prompt_eval_count", 0)
total_completion_tokens = data.get("eval_count", 0)
usage = UsageStats(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
total_cost_usd=0.0,
)
yield StreamChunk(
id=str(time.time()),
delta_content=content if content else None,
finish_reason=finish_reason,
usage=usage,
)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse Ollama stream chunk: {e}")
yield StreamChunk(
id="error",
error=f"Parse error: {e}",
)
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Async chat not implemented for Ollama.
Args:
model: Model name
messages: Chat messages
stream: Whether to stream
max_tokens: Max tokens
temperature: Temperature
tools: Tools (not supported)
tool_choice: Tool choice (not supported)
**kwargs: Additional args
Returns:
ChatResponse or AsyncIterator[StreamChunk]
Raises:
NotImplementedError: Async not implemented
"""
raise NotImplementedError("Async chat not implemented for Ollama provider")
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credits.
Returns:
None (Ollama is local and free)
"""
return None
def clear_cache(self) -> None:
"""Clear model cache (no-op for Ollama)."""
pass
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as dictionaries.
Returns:
List of model dictionaries
"""
models = self.list_models()
return [
{
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
for model in models
]
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: Model identifier
Returns:
Model dictionary or None
"""
model = self.get_model(model_id)
if model:
return {
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
return None
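For reference, a minimal usage sketch for the provider above, assuming the oai package is installed, an Ollama server is running on the default port, and "llama3" stands in for whatever model has been pulled locally:

```python
# Sketch: exercise OllamaProvider synchronously and with streaming.
# Assumes a local Ollama server; "llama3" is a placeholder model name.
from oai.providers.base import ChatMessage
from oai.providers.ollama import OllamaProvider

provider = OllamaProvider()  # no API key needed

# List locally installed models
for model in provider.list_models():
    print(model.id, model.description)

messages = [ChatMessage(role="user", content="Say hello in one sentence.")]

# Non-streaming: returns a ChatResponse
response = provider.chat(model="llama3", messages=messages, stream=False)
print(response.choices[0].message.content)

# Streaming: yields StreamChunk objects until done=True
for chunk in provider.chat(model="llama3", messages=messages, stream=True):
    if chunk.delta_content:
        print(chunk.delta_content, end="", flush=True)
```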

oai/providers/openai.py (new file, 630 lines)
@@ -0,0 +1,630 @@
"""
OpenAI provider for GPT models.
This provider connects to OpenAI's API for accessing GPT-4, GPT-3.5, and other OpenAI models.
"""
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from openai import OpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from oai.constants import OPENAI_BASE_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
ToolCall,
ToolFunction,
UsageStats,
)
from oai.utils.logging import get_logger
logger = get_logger()
# Model aliases for convenience
MODEL_ALIASES = {
"gpt-4": "gpt-4-turbo",
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
"gpt-4o": "gpt-4o-2024-11-20",
"gpt-4o-mini": "gpt-4o-mini-2024-07-18",
"gpt-3.5": "gpt-3.5-turbo",
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
"o1": "o1-2024-12-17",
"o1-mini": "o1-mini-2024-09-12",
"o1-preview": "o1-preview-2024-09-12",
}
class OpenAIProvider(AIProvider):
"""
OpenAI API provider.
Provides access to GPT-4, GPT-3.5, o1, and other OpenAI models.
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: str = "oAI",
app_url: str = "",
**kwargs: Any,
):
"""
Initialize OpenAI provider.
Args:
api_key: OpenAI API key
base_url: Optional custom base URL
app_name: Application name (for headers)
app_url: Application URL (for headers)
**kwargs: Additional arguments
"""
super().__init__(api_key, base_url or OPENAI_BASE_URL)
self.client = OpenAI(api_key=api_key, base_url=self.base_url)
self.async_client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
self._models_cache: Optional[List[ModelInfo]] = None
@property
def name(self) -> str:
"""Get provider name."""
return "OpenAI"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=True,
images=True,
online=True, # Web search via DuckDuckGo/Google
max_context=128000,
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
List available OpenAI models.
Args:
filter_text_only: Whether to filter for text models only
Returns:
List of ModelInfo objects
"""
if self._models_cache:
return self._models_cache
try:
models_response = self.client.models.list()
models = []
for model in models_response.data:
# Filter for chat models
if "gpt" in model.id or "o1" in model.id:
models.append(self._parse_model(model))
# Sort by name
models.sort(key=lambda m: m.name)
self._models_cache = models
logger.info(f"Loaded {len(models)} OpenAI models")
return models
except Exception as e:
logger.error(f"Failed to list OpenAI models: {e}")
return self._get_fallback_models()
def _get_fallback_models(self) -> List[ModelInfo]:
"""
Get fallback list of common OpenAI models.
Returns:
List of common models
"""
return [
ModelInfo(
id="gpt-4o",
name="GPT-4o",
description="Most capable GPT-4 model",
context_length=128000,
pricing={"input": 5.0, "output": 15.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-4o-mini",
name="GPT-4o Mini",
description="Affordable and fast GPT-4 class model",
context_length=128000,
pricing={"input": 0.15, "output": 0.6},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-4-turbo",
name="GPT-4 Turbo",
description="GPT-4 Turbo with vision",
context_length=128000,
pricing={"input": 10.0, "output": 30.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-3.5-turbo",
name="GPT-3.5 Turbo",
description="Fast and affordable model",
context_length=16384,
pricing={"input": 0.5, "output": 1.5},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
),
ModelInfo(
id="o1",
name="o1",
description="Advanced reasoning model",
context_length=200000,
pricing={"input": 15.0, "output": 60.0},
supported_parameters=["max_tokens"],
),
ModelInfo(
id="o1-mini",
name="o1-mini",
description="Fast reasoning model",
context_length=128000,
pricing={"input": 3.0, "output": 12.0},
supported_parameters=["max_tokens"],
),
]
def _parse_model(self, model: Any) -> ModelInfo:
"""
Parse OpenAI model into ModelInfo.
Args:
model: OpenAI model object
Returns:
ModelInfo object
"""
model_id = model.id
# Determine context length
context_length = 8192 # Default
if "gpt-4o" in model_id or "gpt-4-turbo" in model_id:
context_length = 128000
elif "gpt-4" in model_id:
context_length = 8192
elif "gpt-3.5-turbo" in model_id:
context_length = 16384
elif "o1" in model_id:
context_length = 128000
# Determine pricing (approximate)
pricing = {}
if "gpt-4o-mini" in model_id:
pricing = {"input": 0.15, "output": 0.6}
elif "gpt-4o" in model_id:
pricing = {"input": 5.0, "output": 15.0}
elif "gpt-4-turbo" in model_id:
pricing = {"input": 10.0, "output": 30.0}
elif "gpt-4" in model_id:
pricing = {"input": 30.0, "output": 60.0}
elif "gpt-3.5" in model_id:
pricing = {"input": 0.5, "output": 1.5}
elif "o1" in model_id and "mini" not in model_id:
pricing = {"input": 15.0, "output": 60.0}
elif "o1-mini" in model_id:
pricing = {"input": 3.0, "output": 12.0}
return ModelInfo(
id=model_id,
name=model_id,
description="",
context_length=context_length,
pricing=pricing,
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
)
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None
"""
# Resolve alias
resolved_id = MODEL_ALIASES.get(model_id, model_id)
models = self.list_models()
for model in models:
if model.id == resolved_id or model.id == model_id:
return model
# Try to fetch directly
try:
model = self.client.models.retrieve(resolved_id)
return self._parse_model(model)
except Exception:
return None
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send chat completion request to OpenAI.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream response
max_tokens: Maximum tokens
temperature: Sampling temperature
tools: Tool definitions
tool_choice: Tool selection mode
**kwargs: Additional parameters
Returns:
ChatResponse or Iterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Convert messages to OpenAI format
openai_messages = []
for msg in messages:
message_dict = {"role": msg.role, "content": msg.content or ""}
if msg.tool_calls:
message_dict["tool_calls"] = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in msg.tool_calls
]
if msg.tool_call_id:
message_dict["tool_call_id"] = msg.tool_call_id
openai_messages.append(message_dict)
# Build request parameters
params: Dict[str, Any] = {
"model": model_id,
"messages": openai_messages,
"stream": stream,
}
# Add optional parameters
if max_tokens is not None:
params["max_tokens"] = max_tokens
if temperature is not None and "o1" not in model_id:
# o1 models don't support temperature
params["temperature"] = temperature
if tools:
params["tools"] = tools
if tool_choice:
params["tool_choice"] = tool_choice
logger.debug(f"OpenAI request: model={model_id}, messages={len(openai_messages)}")
try:
if stream:
return self._stream_chat(params)
else:
return self._sync_chat(params)
except Exception as e:
logger.error(f"OpenAI request failed: {e}")
return ChatResponse(
id="error",
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
finish_reason="error",
)
],
)
def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
"""
Send synchronous chat request.
Args:
params: Request parameters
Returns:
ChatResponse
"""
completion: ChatCompletion = self.client.chat.completions.create(**params)
# Convert to our format
choices = []
for choice in completion.choices:
# Convert tool calls if present
tool_calls = None
if choice.message.tool_calls:
tool_calls = [
ToolCall(
id=tc.id,
type=tc.type,
function=ToolFunction(
name=tc.function.name,
arguments=tc.function.arguments,
),
)
for tc in choice.message.tool_calls
]
choices.append(
ChatResponseChoice(
index=choice.index,
message=ChatMessage(
role=choice.message.role,
content=choice.message.content,
tool_calls=tool_calls,
),
finish_reason=choice.finish_reason,
)
)
# Convert usage
usage = None
if completion.usage:
usage = UsageStats(
prompt_tokens=completion.usage.prompt_tokens,
completion_tokens=completion.usage.completion_tokens,
total_tokens=completion.usage.total_tokens,
)
return ChatResponse(
id=completion.id,
choices=choices,
usage=usage,
model=completion.model,
created=completion.created,
)
def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
"""
Stream chat response from OpenAI.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
stream = self.client.chat.completions.create(**params)
for chunk in stream:
chunk_data: ChatCompletionChunk = chunk
if not chunk_data.choices:
continue
choice = chunk_data.choices[0]
delta = choice.delta
# Extract content
content = delta.content if delta.content else None
# Extract finish reason
finish_reason = choice.finish_reason
# Extract usage (usually in last chunk)
usage = None
if hasattr(chunk_data, "usage") and chunk_data.usage:
usage = UsageStats(
prompt_tokens=chunk_data.usage.prompt_tokens,
completion_tokens=chunk_data.usage.completion_tokens,
total_tokens=chunk_data.usage.total_tokens,
)
yield StreamChunk(
id=chunk_data.id,
delta_content=content,
finish_reason=finish_reason,
usage=usage,
)
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send async chat request to OpenAI.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream
max_tokens: Max tokens
temperature: Temperature
tools: Tool definitions
tool_choice: Tool choice
**kwargs: Additional args
Returns:
ChatResponse or AsyncIterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Convert messages
openai_messages = [msg.to_dict() for msg in messages]
# Build params
params: Dict[str, Any] = {
"model": model_id,
"messages": openai_messages,
"stream": stream,
}
if max_tokens is not None:
params["max_tokens"] = max_tokens
if temperature is not None and "o1" not in model_id:
params["temperature"] = temperature
if tools:
params["tools"] = tools
if tool_choice:
params["tool_choice"] = tool_choice
if stream:
return self._stream_chat_async(params)
else:
completion = await self.async_client.chat.completions.create(**params)
# Convert to ChatResponse (similar to _sync_chat)
return self._convert_completion(completion)
async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
"""
Stream async chat response.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
stream = await self.async_client.chat.completions.create(**params)
async for chunk in stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
delta = choice.delta
yield StreamChunk(
id=chunk.id,
delta_content=delta.content,
finish_reason=choice.finish_reason,
)
def _convert_completion(self, completion: ChatCompletion) -> ChatResponse:
"""Helper to convert OpenAI completion to ChatResponse."""
choices = []
for choice in completion.choices:
tool_calls = None
if choice.message.tool_calls:
tool_calls = [
ToolCall(
id=tc.id,
type=tc.type,
function=ToolFunction(
name=tc.function.name,
arguments=tc.function.arguments,
),
)
for tc in choice.message.tool_calls
]
choices.append(
ChatResponseChoice(
index=choice.index,
message=ChatMessage(
role=choice.message.role,
content=choice.message.content,
tool_calls=tool_calls,
),
finish_reason=choice.finish_reason,
)
)
usage = None
if completion.usage:
usage = UsageStats(
prompt_tokens=completion.usage.prompt_tokens,
completion_tokens=completion.usage.completion_tokens,
total_tokens=completion.usage.total_tokens,
)
return ChatResponse(
id=completion.id,
choices=choices,
usage=usage,
model=completion.model,
created=completion.created,
)
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credits.
Returns:
None (OpenAI doesn't provide credit API)
"""
return None
def clear_cache(self) -> None:
"""Clear model cache."""
self._models_cache = None
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as dictionaries.
Returns:
List of model dictionaries
"""
models = self.list_models()
return [
{
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
for model in models
]
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: Model identifier
Returns:
Model dictionary or None
"""
model = self.get_model(model_id)
if model:
return {
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
return None
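A similar sketch for the OpenAI provider, assuming a valid API key (the key below is a placeholder). Note how "gpt-4o" resolves through MODEL_ALIASES to a dated snapshot before the request is built:

```python
# Sketch: basic OpenAIProvider call with alias resolution.
# The API key is a placeholder; never hard-code real keys.
from oai.providers.base import ChatMessage
from oai.providers.openai import OpenAIProvider

provider = OpenAIProvider(api_key="sk-...")  # placeholder key

messages = [ChatMessage(role="user", content="One-line summary of HTTP/2?")]
response = provider.chat(model="gpt-4o", messages=messages, max_tokens=100)

print(response.model)  # resolved model id, e.g. gpt-4o-2024-11-20
print(response.choices[0].message.content)
if response.usage:
    print("tokens:", response.usage.total_tokens)
```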

oai/providers/registry.py (new file, 60 lines)
@@ -0,0 +1,60 @@
"""
Provider registry for AI model providers.
This module maintains a central registry of all available AI providers,
allowing dynamic provider lookup and registration.
"""
from typing import Dict, List, Optional, Type
from oai.providers.base import AIProvider
# Global provider registry
PROVIDER_REGISTRY: Dict[str, Type[AIProvider]] = {}
def register_provider(name: str, provider_class: Type[AIProvider]) -> None:
"""
Register a provider class with the given name.
Args:
name: Provider identifier (e.g., "openrouter", "anthropic")
provider_class: The provider class to register
"""
PROVIDER_REGISTRY[name] = provider_class
def get_provider_class(name: str) -> Optional[Type[AIProvider]]:
"""
Get a provider class by name.
Args:
name: Provider identifier
Returns:
Provider class or None if not found
"""
return PROVIDER_REGISTRY.get(name)
def list_providers() -> List[str]:
"""
List all registered provider names.
Returns:
List of provider identifiers
"""
return list(PROVIDER_REGISTRY.keys())
def is_provider_registered(name: str) -> bool:
"""
Check if a provider is registered.
Args:
name: Provider identifier
Returns:
True if provider is registered
"""
return name in PROVIDER_REGISTRY
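A sketch of how the registry could be wired up; the registration call sites are not part of this hunk, so the wiring shown here is an assumption based on the provider modules above:

```python
# Sketch: registering and resolving providers by name.
# Registration points are assumed; only the registry API is shown above.
from oai.providers.registry import (
    register_provider,
    get_provider_class,
    list_providers,
    is_provider_registered,
)
from oai.providers.ollama import OllamaProvider
from oai.providers.openai import OpenAIProvider

register_provider("ollama", OllamaProvider)
register_provider("openai", OpenAIProvider)

print(list_providers())                  # ['ollama', 'openai']
print(is_provider_registered("ollama"))  # True

cls = get_provider_class("ollama")
provider = cls() if cls is not None else None
```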


@@ -16,6 +16,7 @@ from oai.core.client import AIClient
from oai.core.session import ChatSession
from oai.tui.screens import (
AlertDialog,
CommandsScreen,
ConfirmDialog,
ConfigScreen,
ConversationSelectorScreen,
@@ -69,7 +70,8 @@ class oAIChatApp(App):
"""Compose the TUI layout."""
model_name = self.session.selected_model.get("name", "") if self.session.selected_model else ""
model_info = self.session.selected_model if self.session.selected_model else None
yield Header(version=__version__, model=model_name, model_info=model_info)
provider_name = self.session.client.provider_name if self.session.client else ""
yield Header(version=__version__, model=model_name, model_info=model_info, provider=provider_name)
yield ChatDisplay()
yield InputBar()
yield CommandDropdown()
@@ -158,10 +160,10 @@ class oAIChatApp(App):
else:
# Command is complete, submit it directly
dropdown.hide()
chat_input.value = "" # Clear immediately
# Process the command directly
async def submit_command():
await self._process_submitted_input(selected)
chat_input.value = ""
self.call_later(submit_command)
return
elif event.key == "escape":
@@ -232,12 +234,12 @@ class oAIChatApp(App):
if not user_input:
return
# Process the input
await self._process_submitted_input(user_input)
# Clear input field
# Clear input field immediately
event.input.value = ""
# Process the input (async, will wait for AI response)
await self._process_submitted_input(user_input)
async def _process_submitted_input(self, user_input: str) -> None:
"""Process submitted input (command or message).
@@ -280,6 +282,10 @@ class oAIChatApp(App):
await self.push_screen(HelpScreen())
return
if cmd_word == "commands":
await self.push_screen(CommandsScreen())
return
if cmd_word == "stats":
await self.push_screen(StatsScreen(self.session))
return
@@ -288,7 +294,7 @@ class oAIChatApp(App):
# Check if there are any arguments
args = command_text.split(maxsplit=1)
if len(args) == 1: # No arguments, just "/config"
await self.push_screen(ConfigScreen(self.settings))
await self.push_screen(ConfigScreen(self.settings, self.session))
return
# If there are arguments, fall through to normal command handler
@@ -447,9 +453,11 @@ class oAIChatApp(App):
# Update header if model changed
if self.session.selected_model:
header = self.query_one(Header)
provider_name = self.session.client.provider_name if self.session.client else ""
header.update_model(
self.session.selected_model.get("name", ""),
self.session.selected_model
self.session.selected_model,
provider_name
)
# Update MCP status indicator in input bar
@@ -472,7 +480,7 @@ class oAIChatApp(App):
# Create assistant message widget with loading indicator
model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
assistant_widget = AssistantMessageWidget(model_name)
assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
await chat_display.add_message(assistant_widget)
# Show loading indicator immediately
@@ -851,7 +859,13 @@ class oAIChatApp(App):
if selected:
self.session.set_model(selected)
header = self.query_one(Header)
header.update_model(selected.get("name", ""), selected)
provider_name = self.session.client.provider_name if self.session.client else ""
header.update_model(selected.get("name", ""), selected, provider_name)
# Save this model as the last used for this provider
model_id = selected.get("id")
if model_id and provider_name:
self.settings.set_provider_model(provider_name, model_id)
# Save as default if requested
if set_as_default:


@@ -1,5 +1,6 @@
"""TUI screens for oAI."""
from oai.tui.screens.commands_screen import CommandsScreen
from oai.tui.screens.config_screen import ConfigScreen
from oai.tui.screens.conversation_selector import ConversationSelectorScreen
from oai.tui.screens.credits_screen import CreditsScreen
@@ -10,6 +11,7 @@ from oai.tui.screens.stats_screen import StatsScreen
__all__ = [
"AlertDialog",
"CommandsScreen",
"ConfirmDialog",
"ConfigScreen",
"ConversationSelectorScreen",


@@ -0,0 +1,172 @@
"""Commands reference screen for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Container, Vertical, VerticalScroll
from textual.screen import ModalScreen
from textual.widgets import Button, Static
class CommandsScreen(ModalScreen[None]):
"""Modal screen showing all available commands."""
DEFAULT_CSS = """
CommandsScreen {
align: center middle;
}
CommandsScreen > Container {
width: 90;
height: 40;
background: #1e1e1e;
border: solid #555555;
}
CommandsScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
CommandsScreen .content {
width: 100%;
height: 1fr;
background: #1e1e1e;
padding: 2;
color: #cccccc;
overflow-y: auto;
scrollbar-background: #1e1e1e;
scrollbar-color: #555555;
scrollbar-size: 1 1;
}
CommandsScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def compose(self) -> ComposeResult:
"""Compose the commands screen."""
with Container():
yield Static("[bold]Commands Reference[/]", classes="header")
with VerticalScroll(classes="content"):
yield Static(self._get_commands_text(), markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close-button", variant="primary")
def _get_commands_text(self) -> str:
"""Generate formatted commands text."""
return """[bold cyan]General Commands[/]
[green]/help[/] - Show help screen with keyboard shortcuts
[green]/commands[/] - Show this commands reference
[green]/model[/] - Open model selector (or press F2)
[green]/stats[/] - Show session statistics (or press Ctrl+S)
[green]/credits[/] - Check account credits (OpenRouter) or view console link
[green]/clear[/] - Clear chat display
[green]/reset[/] - Reset conversation history
[green]/retry[/] - Retry last prompt
[green]/paste[/] - Paste from clipboard
[bold cyan]Provider Commands[/]
[green]/provider[/] - Show current provider
[green]/provider openrouter[/] - Switch to OpenRouter
[green]/provider anthropic[/] - Switch to Anthropic (Claude)
[green]/provider openai[/] - Switch to OpenAI (ChatGPT)
[green]/provider ollama[/] - Switch to Ollama (local)
[bold cyan]Online Mode (Web Search)[/]
[green]/online[/] - Show online mode status
[green]/online on[/] - Enable web search
[green]/online off[/] - Disable web search
[dim]Search Providers:[/]
• [yellow]anthropic_native[/] - Anthropic's native search with citations ($0.01/search)
• [yellow]duckduckgo[/] - Free web scraping (default, works with all providers)
• [yellow]google[/] - Google Custom Search (requires API key)
[bold cyan]Configuration Commands[/]
[green]/config[/] - View all settings
[green]/config provider <name>[/] - Set default provider
[green]/config search_provider <provider>[/] - Set search provider (anthropic_native/duckduckgo/google)
[green]/config openrouter_api_key <key>[/] - Set OpenRouter API key
[green]/config anthropic_api_key <key>[/] - Set Anthropic API key
[green]/config openai_api_key <key>[/] - Set OpenAI API key
[green]/config ollama_base_url <url>[/] - Set Ollama server URL
[green]/config google_api_key <key>[/] - Set Google API key (for Google search)
[green]/config google_search_engine_id <id>[/] - Set Google Search Engine ID
[green]/config online on|off[/] - Set default online mode
[green]/config stream on|off[/] - Toggle streaming
[green]/config model <id>[/] - Set default model
[green]/config system <prompt>[/] - Set system prompt
[green]/config maxtoken <num>[/] - Set token limit
[bold cyan]Memory & Context[/]
[green]/memory on[/] - Enable conversation memory
[green]/memory off[/] - Disable memory (fresh context each message)
[bold cyan]Conversation Management[/]
[green]/save <name>[/] - Save current conversation
[green]/load <name>[/] - Load saved conversation
[green]/list[/] - List all saved conversations
[green]/delete <name>[/] - Delete a conversation
[green]/prev[/] - Show previous message
[green]/next[/] - Show next message
[bold cyan]Export Commands[/]
[green]/export md <file>[/] - Export conversation as Markdown
[green]/export json <file>[/] - Export as JSON
[green]/export html <file>[/] - Export as HTML
[bold cyan]MCP (Model Context Protocol)[/]
[green]/mcp on[/] - Enable MCP file access
[green]/mcp off[/] - Disable MCP
[green]/mcp status[/] - Show MCP status
[green]/mcp add <path>[/] - Add folder for file access
[green]/mcp add db <path>[/] - Add SQLite database
[green]/mcp remove <path>[/] - Remove folder/database
[green]/mcp list[/] - List allowed folders
[green]/mcp db list[/] - List added databases
[green]/mcp db <n>[/] - Switch to database mode
[green]/mcp files[/] - Switch to file mode
[green]/mcp write on[/] - Enable write mode (allows file modifications)
[green]/mcp write off[/] - Disable write mode
[bold cyan]System Prompt[/]
[green]/system <prompt>[/] - Set custom system prompt for session
[green]/config system <prompt>[/] - Set default system prompt
[bold cyan]Keyboard Shortcuts[/]
• [yellow]F1[/] - Help screen
• [yellow]F2[/] - Model selector
• [yellow]Ctrl+S[/] - Statistics
• [yellow]Ctrl+Q[/] - Quit
• [yellow]Ctrl+Y[/] - Copy latest reply in Markdown
• [yellow]Up/Down[/] - Command history
• [yellow]Tab[/] - Command completion
"""
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()


@@ -50,9 +50,10 @@ class ConfigScreen(ModalScreen[None]):
}
"""
def __init__(self, settings: Settings):
def __init__(self, settings: Settings, session=None):
super().__init__()
self.settings = settings
self.session = session
def compose(self) -> ComposeResult:
"""Compose the screen."""
@@ -67,35 +68,90 @@ class ConfigScreen(ModalScreen[None]):
"""Generate the configuration text."""
from oai.constants import DEFAULT_SYSTEM_PROMPT
# API Key display
api_key_display = "***" + self.settings.api_key[-4:] if self.settings.api_key else "Not set"
config_lines = ["[bold cyan]═══ CONFIGURATION ═══[/]\n"]
# Current Session Info
if self.session and self.session.client:
config_lines.append("[bold yellow]Current Session:[/]")
provider = self.session.client.provider_name
config_lines.append(f"[bold]Provider:[/] [green]{provider}[/]")
# Provider URL
if hasattr(self.session.client.provider, 'base_url'):
provider_url = self.session.client.provider.base_url
config_lines.append(f"[bold]Provider URL:[/] {provider_url}")
# API Key status
if provider == "ollama":
config_lines.append(f"[bold]API Key:[/] [dim]Not required (local)[/]")
else:
api_key = self.settings.get_provider_api_key(provider)
if api_key:
key_display = "***" + api_key[-4:] if len(api_key) > 4 else "***"
config_lines.append(f"[bold]API Key:[/] [green]{key_display}[/]")
else:
config_lines.append(f"[bold]API Key:[/] [red]Not set[/]")
# Current model
if self.session.selected_model:
model_name = self.session.selected_model.get("name", "Unknown")
config_lines.append(f"[bold]Current Model:[/] {model_name}")
config_lines.append("")
# Default Settings
config_lines.append("[bold yellow]Default Settings:[/]")
config_lines.append(f"[bold]Default Provider:[/] {self.settings.default_provider}")
# Mask helper
def mask_key(key):
if not key:
return "[red]Not set[/]"
return "***" + key[-4:] if len(key) > 4 else "***"
config_lines.append(f"[bold]OpenRouter Key:[/] {mask_key(self.settings.openrouter_api_key)}")
config_lines.append(f"[bold]Anthropic Key:[/] {mask_key(self.settings.anthropic_api_key)}")
config_lines.append(f"[bold]OpenAI Key:[/] {mask_key(self.settings.openai_api_key)}")
config_lines.append(f"[bold]Ollama URL:[/] {self.settings.ollama_base_url}")
config_lines.append(f"[bold]Default Model:[/] {self.settings.default_model or 'Not set'}")
# System prompt display
if self.settings.default_system_prompt is None:
system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
system_prompt_display = f"[dim][default][/] {DEFAULT_SYSTEM_PROMPT[:40]}..."
elif self.settings.default_system_prompt == "":
system_prompt_display = "[blank]"
system_prompt_display = "[dim][blank][/]"
else:
prompt = self.settings.default_system_prompt
system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt
return f"""
[bold cyan]═══ CONFIGURATION ═══[/]
config_lines.append(f"[bold]System Prompt:[/] {system_prompt_display}")
config_lines.append("")
[bold]API Key:[/] {api_key_display}
[bold]Base URL:[/] {self.settings.base_url}
[bold]Default Model:[/] {self.settings.default_model or "Not set"}
# Web Search Configuration
config_lines.append("[bold yellow]Web Search Configuration:[/]")
config_lines.append(f"[bold]Search Provider:[/] {self.settings.search_provider}")
[bold]System Prompt:[/] {system_prompt_display}
# Show API key status based on search provider
if self.settings.search_provider == "google":
config_lines.append(f"[bold]Google API Key:[/] {mask_key(self.settings.google_api_key)}")
config_lines.append(f"[bold]Search Engine ID:[/] {mask_key(self.settings.google_search_engine_id)}")
elif self.settings.search_provider == "duckduckgo":
config_lines.append("[dim] DuckDuckGo requires no configuration (free)[/]")
elif self.settings.search_provider == "anthropic_native":
config_lines.append("[dim] Uses Anthropic API ($0.01 per search)[/]")
[bold]Streaming:[/] {"on" if self.settings.stream_enabled else "off"}
[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}
[bold]Max Tokens:[/] {self.settings.max_tokens}
[bold]Default Online:[/] {"on" if self.settings.default_online_mode else "off"}
[bold]Log Level:[/] {self.settings.log_level}
config_lines.append("")
[dim]Use /config [setting] [value] to modify settings[/]
"""
# Other settings
config_lines.append("[bold]Streaming:[/] " + ("on" if self.settings.stream_enabled else "off"))
config_lines.append(f"[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}")
config_lines.append(f"[bold]Max Tokens:[/] {self.settings.max_tokens}")
config_lines.append(f"[bold]Default Online:[/] " + ("on" if self.settings.default_online_mode else "off"))
config_lines.append(f"[bold]Log Level:[/] {self.settings.log_level}")
config_lines.append("\n[dim]Use /config [setting] [value] to modify settings[/]")
return "\n".join(config_lines)
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""


@@ -83,37 +83,70 @@ class CreditsScreen(ModalScreen[None]):
def _get_credits_text(self) -> str:
"""Generate the credits text."""
if not self.credits_data:
return "[yellow]No credit information available[/]"
# Provider-specific message when credits aren't available
if self.client.provider_name == "anthropic":
return """[yellow]Credit information not available via API[/]
total = self.credits_data.get("total_credits", 0)
used = self.credits_data.get("used_credits", 0)
Anthropic does not provide a public API endpoint for checking
account credits.
To view your account balance and usage:
[cyan]→ Visit console.anthropic.com[/]
[cyan]→ Navigate to Settings → Billing[/]
"""
elif self.client.provider_name == "openai":
return """[yellow]Credit information not available via API[/]
OpenAI does not provide credit balance through their API.
To view your account usage:
[cyan]→ Visit platform.openai.com[/]
[cyan]→ Navigate to Usage[/]
"""
elif self.client.provider_name == "ollama":
return """[yellow]Credit information not applicable[/]
Ollama runs locally on your machine and does not use credits.
"""
else:
return "[yellow]No credit information available[/]"
provider_name = self.client.provider_name.upper()
total = self.credits_data.get("total_credits")
used = self.credits_data.get("used_credits")
remaining = self.credits_data.get("credits_left", 0)
# Calculate percentage used
if total > 0:
percent_used = (used / total) * 100
percent_remaining = (remaining / total) * 100
else:
percent_used = 0
percent_remaining = 0
# Color code based on remaining credits
if percent_remaining > 50:
# Determine color based on absolute remaining amount
if remaining > 10:
remaining_color = "green"
elif percent_remaining > 20:
elif remaining > 2:
remaining_color = "yellow"
else:
remaining_color = "red"
return f"""
[bold cyan]═══ OPENROUTER CREDITS ═══[/]
lines = [f"[bold cyan]═══ {provider_name} CREDITS ═══[/]\n"]
[bold]Total Credits:[/] ${total:.2f}
[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]
[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]
# If we have total/used info (OpenRouter)
if total is not None and used is not None:
percent_used = (used / total) * 100 if total > 0 else 0
percent_remaining = (remaining / total) * 100 if total > 0 else 0
[dim]Visit openrouter.ai to add more credits[/]
"""
lines.append(f"[bold]Total Credits:[/] ${total:.2f}")
lines.append(f"[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]")
lines.append(f"[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]")
else:
# Anthropic - only shows balance
lines.append(f"[bold]Current Balance:[/] [{remaining_color}]${remaining:.2f}[/]")
lines.append("")
# Provider-specific footer
if self.client.provider_name == "openrouter":
lines.append("[dim]Visit openrouter.ai to add more credits[/]")
elif self.client.provider_name == "anthropic":
lines.append("[dim]Visit console.anthropic.com to manage billing[/]")
return "\n".join(lines)
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""


@@ -17,9 +17,11 @@ Header {
ChatDisplay {
background: $background;
border: none;
padding: 1;
padding: 1 0 1 1; /* top right bottom left - no right padding for scrollbar */
scrollbar-background: $background;
scrollbar-color: $primary;
scrollbar-size: 1 1;
scrollbar-gutter: stable; /* Reserve space for scrollbar */
overflow-y: auto;
}
@@ -43,7 +45,7 @@ SystemMessageWidget {
AssistantMessageWidget {
margin: 0 0 1 0;
padding: 1;
background: $panel;
background: $background;
border-left: thick $accent;
height: auto;
}
@@ -59,6 +61,9 @@ AssistantMessageWidget {
color: #cccccc;
link-color: #888888;
link-style: none;
border: none;
scrollbar-background: transparent;
scrollbar-color: #555555;
}
InputBar {


@@ -59,9 +59,15 @@ class CommandDropdown(VerticalScroll):
# Get base commands with descriptions
base_commands = [
("/help", "Show help screen"),
("/commands", "Show all commands"),
("/model", "Select AI model"),
("/provider", "Switch AI provider"),
("/provider openrouter", "Switch to OpenRouter"),
("/provider anthropic", "Switch to Anthropic (Claude)"),
("/provider openai", "Switch to OpenAI (GPT)"),
("/provider ollama", "Switch to Ollama (local)"),
("/stats", "Show session statistics"),
("/credits", "Check account credits"),
("/credits", "Check account credits or view console link"),
("/clear", "Clear chat display"),
("/reset", "Reset conversation history"),
("/memory on", "Enable conversation memory"),
@@ -78,7 +84,19 @@ class CommandDropdown(VerticalScroll):
("/prev", "Show previous message"),
("/next", "Show next message"),
("/config", "View configuration"),
("/config api", "Set API key"),
("/config provider", "Set default provider"),
("/config search_provider", "Set search provider (anthropic_native/duckduckgo/google)"),
("/config openrouter_api_key", "Set OpenRouter API key"),
("/config anthropic_api_key", "Set Anthropic API key"),
("/config openai_api_key", "Set OpenAI API key"),
("/config ollama_base_url", "Set Ollama server URL"),
("/config google_api_key", "Set Google API key"),
("/config google_search_engine_id", "Set Google Search Engine ID"),
("/config online", "Set default online mode (on/off)"),
("/config stream", "Toggle streaming (on/off)"),
("/config model", "Set default model"),
("/config system", "Set system prompt"),
("/config maxtoken", "Set token limit"),
("/system", "Set system prompt"),
("/maxtoken", "Set token limit"),
("/retry", "Retry last prompt"),
@@ -117,12 +135,30 @@ class CommandDropdown(VerticalScroll):
# Remove the leading slash for filtering
filter_without_slash = filter_text[1:].lower()
# Filter commands - show if filter text is contained anywhere in the command
# Filter commands
if filter_without_slash:
matching = [
(cmd, desc) for cmd, desc in self._all_commands
if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching
]
matching = []
# Check if user typed a parent command (e.g., "/config")
# If so, show all sub-commands that start with it
parent_command = "/" + filter_without_slash
has_subcommands = any(
cmd.startswith(parent_command + " ")
for cmd, _ in self._all_commands
)
if has_subcommands and parent_command in [cmd for cmd, _ in self._all_commands]:
# Show the parent command and all its sub-commands
matching = [
(cmd, desc) for cmd, desc in self._all_commands
if cmd == parent_command or cmd.startswith(parent_command + " ")
]
else:
# Regular filtering - show if filter text is contained anywhere in the command
matching = [
(cmd, desc) for cmd, desc in self._all_commands
if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching
]
else:
# Show all commands when just "/" is typed
matching = self._all_commands
@@ -131,8 +167,8 @@ class CommandDropdown(VerticalScroll):
self.remove_class("visible")
return
# Add options - limit to 10 results
for cmd, desc in matching[:10]:
# Add options - limit to 15 results (increased from 10 for sub-commands)
for cmd, desc in matching[:15]:
# Format: command in white, description in gray, separated by spaces
label = f"{cmd} [dim]{desc}[/]" if desc else cmd
option_list.add_option(Option(label, id=cmd))


@@ -8,11 +8,12 @@ from typing import Optional, Dict, Any
class Header(Static):
"""Header displaying app title, version, current model, and capabilities."""
def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None):
def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None, provider: str = ""):
super().__init__()
self.version = version
self.model = model
self.model_info = model_info or {}
self.provider = provider
def compose(self) -> ComposeResult:
"""Compose the header."""
@@ -51,15 +52,32 @@ class Header(Static):
def _format_header(self) -> str:
"""Format the header text."""
model_text = f" | {self.model}" if self.model else ""
# Show provider : model format
if self.provider and self.model:
provider_model = f"[bold cyan]{self.provider}[/] [dim]:[/] [bold]{self.model}[/]"
elif self.provider:
provider_model = f"[bold cyan]{self.provider}[/]"
elif self.model:
provider_model = f"[bold]{self.model}[/]"
else:
provider_model = ""
capabilities = self._format_capabilities()
capabilities_text = f" {capabilities}" if capabilities else ""
return f"[bold cyan]oAI[/] [dim]v{self.version}[/]{model_text}{capabilities_text}"
def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None) -> None:
# Format: oAI v{version} | provider : model capabilities
version_text = f"[bold cyan]oAI[/] [dim]v{self.version}[/]"
if provider_model:
return f"{version_text} [dim]|[/] {provider_model}{capabilities_text}"
else:
return version_text
def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None, provider: Optional[str] = None) -> None:
"""Update the displayed model and capabilities."""
self.model = model
if model_info:
self.model_info = model_info
if provider is not None:
self.provider = provider
content = self.query_one("#header-content", Static)
content.update(self._format_header())


@@ -53,10 +53,11 @@ class SystemMessageWidget(Static):
class AssistantMessageWidget(Static):
"""Widget for displaying assistant responses with streaming support."""
def __init__(self, model_name: str = "Assistant"):
def __init__(self, model_name: str = "Assistant", chat_display=None):
super().__init__()
self.model_name = model_name
self.full_text = ""
self.chat_display = chat_display
def compose(self) -> ComposeResult:
"""Compose the assistant message."""
@@ -77,6 +78,11 @@ class AssistantMessageWidget(Static):
md = Markdown(self.full_text, code_theme="github-dark", inline_code_theme="github-dark")
log.write(md)
# Auto-scroll to keep the latest content visible
# Use call_after_refresh to ensure scroll happens after layout update
if self.chat_display:
self.chat_display.call_after_refresh(self.chat_display.scroll_end, animate=False)
if hasattr(chunk, "usage") and chunk.usage:
usage = chunk.usage

oai/utils/web_search.py (new file, 247 lines)
@@ -0,0 +1,247 @@
"""
Web search utilities for oAI.
Provides web search capabilities for all providers (not just OpenRouter).
Uses DuckDuckGo by default (no API key needed).
"""
import json
import re
from typing import Dict, List, Optional
from urllib.parse import quote_plus
import requests
from oai.utils.logging import get_logger
logger = get_logger()
class WebSearchResult:
"""Container for a single search result."""
def __init__(self, title: str, url: str, snippet: str):
self.title = title
self.url = url
self.snippet = snippet
def __repr__(self) -> str:
return f"WebSearchResult(title='{self.title}', url='{self.url}')"
class WebSearchProvider:
"""Base class for web search providers."""
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
"""
Perform a web search.
Args:
query: Search query
num_results: Number of results to return
Returns:
List of search results
"""
raise NotImplementedError
class DuckDuckGoSearch(WebSearchProvider):
"""DuckDuckGo search provider (no API key needed)."""
def __init__(self):
self.session = requests.Session()
self.session.headers.update({
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36'
})
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
"""
Search using DuckDuckGo HTML interface.
Args:
query: Search query
num_results: Number of results to return (default: 5)
Returns:
List of search results
"""
try:
# Use DuckDuckGo HTML search
url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"
response = self.session.get(url, timeout=10)
response.raise_for_status()
results = []
html = response.text
# Parse results using regex (simple HTML parsing)
# Find all result blocks - they end at next result or end of results section
result_blocks = re.findall(
r'<div class="result results_links.*?(?=<div class="result results_links|<div id="links")',
html,
re.DOTALL
)
for block in result_blocks[:num_results]:
# Extract title and URL - look for result__a class
title_match = re.search(r'<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>([^<]+)</a>', block)
# Extract snippet - look for result__snippet class
snippet_match = re.search(r'<a[^>]*class="result__snippet"[^>]*>([^<]+)</a>', block)
if title_match:
url_raw = title_match.group(1)
title = title_match.group(2).strip()
# Decode HTML entities in title
import html as html_module
title = html_module.unescape(title)
snippet = ""
if snippet_match:
snippet = snippet_match.group(1).strip()
snippet = html_module.unescape(snippet)
# Clean up URL (DDG uses redirect links)
if 'uddg=' in url_raw:
# Extract actual URL from redirect
actual_url_match = re.search(r'uddg=([^&]+)', url_raw)
if actual_url_match:
from urllib.parse import unquote
url_raw = unquote(actual_url_match.group(1))
results.append(WebSearchResult(
title=title,
url=url_raw,
snippet=snippet
))
logger.info(f"DuckDuckGo search: found {len(results)} results for '{query}'")
return results
except requests.RequestException as e:
logger.error(f"DuckDuckGo search failed: {e}")
return []
except Exception as e:
logger.error(f"Error parsing DuckDuckGo results: {e}")
return []
class GoogleCustomSearch(WebSearchProvider):
"""Google Custom Search API provider (requires API key)."""
def __init__(self, api_key: str, search_engine_id: str):
"""
Initialize Google Custom Search.
Args:
api_key: Google API key
search_engine_id: Custom Search Engine ID
"""
self.api_key = api_key
self.search_engine_id = search_engine_id
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
"""
Search using Google Custom Search API.
Args:
query: Search query
num_results: Number of results to return
Returns:
List of search results
"""
try:
url = "https://www.googleapis.com/customsearch/v1"
params = {
'key': self.api_key,
'cx': self.search_engine_id,
'q': query,
'num': min(num_results, 10) # Google allows max 10
}
response = requests.get(url, params=params, timeout=10)
response.raise_for_status()
data = response.json()
results = []
for item in data.get('items', []):
results.append(WebSearchResult(
title=item.get('title', ''),
url=item.get('link', ''),
snippet=item.get('snippet', '')
))
logger.info(f"Google Custom Search: found {len(results)} results for '{query}'")
return results
except requests.RequestException as e:
logger.error(f"Google Custom Search failed: {e}")
return []
def perform_web_search(
query: str,
num_results: int = 5,
provider: str = "duckduckgo",
**kwargs
) -> List[WebSearchResult]:
"""
Perform a web search using the specified provider.
Args:
query: Search query
num_results: Number of results to return (default: 5)
provider: Search provider ("duckduckgo" or "google")
**kwargs: Provider-specific arguments (e.g., api_key for Google)
Returns:
List of search results
"""
if provider == "google":
api_key = kwargs.get("google_api_key")
search_engine_id = kwargs.get("google_search_engine_id")
if not api_key or not search_engine_id:
logger.warning("Google search requires api_key and search_engine_id, falling back to DuckDuckGo")
provider = "duckduckgo"
if provider == "google":
search_provider = GoogleCustomSearch(api_key, search_engine_id)
else:
search_provider = DuckDuckGoSearch()
return search_provider.search(query, num_results)
def format_search_results(results: List[WebSearchResult], max_length: int = 2000) -> str:
"""
Format search results for inclusion in AI prompt.
Args:
results: List of search results
max_length: Maximum total length of formatted results
Returns:
Formatted string with search results
"""
if not results:
return "No search results found."
formatted = "**Web Search Results:**\n\n"
for i, result in enumerate(results, 1):
result_text = f"{i}. **{result.title}**\n"
result_text += f" URL: {result.url}\n"
if result.snippet:
result_text += f" {result.snippet}\n"
result_text += "\n"
# Check if adding this result would exceed max_length
if len(formatted) + len(result_text) > max_length:
formatted += f"... ({len(results) - i + 1} more results truncated)\n"
break
formatted += result_text
return formatted.strip()
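A usage sketch for the search utilities, assuming network access; the Google credentials are placeholders, and missing credentials fall back to DuckDuckGo as implemented above:

```python
# Sketch: free DuckDuckGo search, formatted for injection into a prompt.
from oai.utils.web_search import perform_web_search, format_search_results

results = perform_web_search("textual tui python", num_results=3)
print(format_search_results(results, max_length=1500))

# Google Custom Search needs both credentials (placeholders below);
# if either is missing, perform_web_search falls back to DuckDuckGo.
results = perform_web_search(
    "textual tui python",
    provider="google",
    google_api_key="YOUR_KEY",
    google_search_engine_id="YOUR_ENGINE_ID",
)
```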

pyproject.toml

@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project]
name = "oai"
version = "3.0.0-b2" # MUST match oai/__init__.py __version__
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
version = "3.0.0-b3" # MUST match oai/__init__.py __version__
description = "Open AI Chat Client - Multi-provider terminal chat with MCP support"
readme = "README.md"
license = {text = "MIT"}
authors = [
@@ -39,9 +39,11 @@ classifiers = [
requires-python = ">=3.10"
dependencies = [
"anyio>=4.0.0",
"anthropic>=0.40.0",
"click>=8.0.0",
"httpx>=0.24.0",
"markdown-it-py>=3.0.0",
"openai>=1.59.0",
"openrouter>=0.0.19",
"packaging>=21.0",
"pyperclip>=1.8.0",