From 06a3c898d389e05c9062c5dc2c9308c5dfc346ac Mon Sep 17 00:00:00 2001
From: Rune Olsen
Date: Thu, 5 Feb 2026 11:21:22 +0100
Subject: [PATCH] Lots of changes, none breaking. v3.0.0-b3

---
 README.md                           | 141 +++++-
 oai/__init__.py                     |  10 +-
 oai/cli.py                          |  43 +-
 oai/commands/handlers.py            | 209 ++++++++-
 oai/commands/registry.py            |   2 +
 oai/config/settings.py              | 217 ++++++++-
 oai/constants.py                    |  22 +-
 oai/core/client.py                  | 104 ++++-
 oai/core/session.py                 |  97 +++-
 oai/providers/__init__.py           |  13 +
 oai/providers/anthropic.py          | 673 ++++++++++++++++++++++++++++
 oai/providers/ollama.py             | 423 +++++++++++++++++
 oai/providers/openai.py             | 630 ++++++++++++++++++++++++++
 oai/providers/registry.py           |  60 +++
 oai/tui/app.py                      |  34 +-
 oai/tui/screens/__init__.py         |   2 +
 oai/tui/screens/commands_screen.py  | 172 +++++++
 oai/tui/screens/config_screen.py    |  92 +++-
 oai/tui/screens/credits_screen.py   |  75 +++-
 oai/tui/styles.tcss                 |   9 +-
 oai/tui/widgets/command_dropdown.py |  54 ++-
 oai/tui/widgets/header.py           |  26 +-
 oai/tui/widgets/message.py          |   8 +-
 oai/utils/web_search.py             | 247 ++++++++++
 pyproject.toml                      |   6 +-
 25 files changed, 3252 insertions(+), 117 deletions(-)
 create mode 100644 oai/providers/anthropic.py
 create mode 100644 oai/providers/ollama.py
 create mode 100644 oai/providers/openai.py
 create mode 100644 oai/providers/registry.py
 create mode 100644 oai/tui/screens/commands_screen.py
 create mode 100644 oai/utils/web_search.py

diff --git a/README.md b/README.md
index 86e5cd5..e8076f5 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,23 @@
-# oAI - OpenRouter AI Chat Client
+# oAI - Open AI Chat Client
 
-A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
+A powerful, modern **Textual TUI** chat client with **multi-provider support** (OpenRouter, Anthropic, OpenAI, Ollama) and **MCP (Model Context Protocol)** integration, enabling AI to access local files and query SQLite databases.
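+
+A minimal first run (assuming an API key is already configured) looks like:
+
+```bash
+oai --provider anthropic   # launch the TUI with Claude
+```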
## Features ### Core Features - 🖥️ **Modern Textual TUI** with async streaming and beautiful interface -- 🤖 Interactive chat with 300+ AI models via OpenRouter +- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local) +- 🤖 Interactive chat with 300+ AI models across providers - 🔍 Model selection with search, filtering, and capability icons - 💾 Conversation save/load/export (Markdown, JSON, HTML) - 📎 File attachments (images, PDFs, code files) -- 💰 Real-time cost tracking and credit monitoring +- 💰 Real-time cost tracking and credit monitoring (OpenRouter) - 🎨 Dark theme with syntax highlighting and Markdown rendering - 📝 Command history navigation (Up/Down arrows) -- 🌐 Online mode (web search capabilities) +- 🌐 **Universal Online Mode** - Web search for ALL providers: + - **Anthropic Native** - Built-in search with automatic citations ($0.01/search) + - **DuckDuckGo** - Free web scraping (all providers) + - **Google Custom Search** - Premium search option - 🧠 Conversation memory toggle - ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats) @@ -36,7 +40,11 @@ A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Mo ## Requirements - Python 3.10-3.13 -- OpenRouter API key ([get one here](https://openrouter.ai)) +- API key for your chosen provider: + - **OpenRouter**: [openrouter.ai](https://openrouter.ai) + - **Anthropic**: [console.anthropic.com](https://console.anthropic.com) + - **OpenAI**: [platform.openai.com](https://platform.openai.com) + - **Ollama**: No API key needed (local server) ## Installation @@ -83,28 +91,111 @@ pip install -e . # Start oAI (launches TUI) oai +# Start with specific provider +oai --provider anthropic +oai --provider openai +oai --provider ollama + # Or with options -oai --model gpt-4o --online --mcp +oai --provider openrouter --model gpt-4o --online --mcp # Show version oai version ``` -On first run, you'll be prompted for your OpenRouter API key. +On first run, you'll be prompted for your API key. Configure additional providers anytime with `/config`. + +### Enable Web Search (All Providers) + +```bash +# Using Anthropic native search (best quality, automatic citations) +/config search_provider anthropic_native +/online on + +# Using free DuckDuckGo (works with all providers) +/config search_provider duckduckgo +/online on + +# Using Google Custom Search (requires API key) +/config search_provider google +/config google_api_key YOUR_KEY +/config google_search_engine_id YOUR_ID +/online on +``` ### Basic Commands ```bash # In the TUI interface: +/provider # Show current provider or switch +/provider anthropic # Switch to Anthropic (Claude) +/provider openai # Switch to OpenAI (ChatGPT) +/provider ollama # Switch to Ollama (local) /model # Select AI model (or press F2) +/online on # Enable web search /help # Show all commands (or press F1) /mcp on # Enable file/database access /stats # View session statistics (or press Ctrl+S) /config # View configuration settings -/credits # Check account credits +/credits # Check account credits (shows API balance or console link) Ctrl+Q # Quit ``` +## Web Search + +oAI provides universal web search capabilities for all AI providers with three options: + +### Anthropic Native Search (Recommended for Anthropic) + +Anthropic's built-in web search API with automatic citations: + +```bash +/config search_provider anthropic_native +/online on + +# Now ask questions requiring current information +What are the latest developments in quantum computing? 
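+
+# Claude decides when to search and cites its sources in the reply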
+``` + +**Features:** +- ✅ **Automatic citations** - Claude cites its sources +- ✅ **Smart searching** - Claude decides when to search +- ✅ **Progressive searches** - Multiple searches for complex queries +- ✅ **Best quality** - Professional-grade results + +**Pricing:** $10 per 1,000 searches ($0.01 per search) + token costs + +**Note:** Only works with Anthropic provider (Claude models) + +### DuckDuckGo Search (Default - Free) + +Free web scraping that works with ALL providers: + +```bash +/config search_provider duckduckgo # Default +/online on + +# Works with Anthropic, OpenAI, Ollama, and OpenRouter +``` + +**Features:** +- ✅ **Free** - No API key or costs +- ✅ **Universal** - Works with all providers +- ✅ **Privacy-friendly** - Uses DuckDuckGo + +### Google Custom Search (Premium Option) + +Google's Custom Search API for high-quality results: + +```bash +/config search_provider google +/config google_api_key YOUR_GOOGLE_API_KEY +/config google_search_engine_id YOUR_SEARCH_ENGINE_ID +/online on +``` + +Get your API key: [Google Custom Search API](https://developers.google.com/custom-search/v1/overview) + ## MCP (Model Context Protocol) MCP allows the AI to interact with your local files and databases. @@ -184,11 +275,18 @@ MCP allows the AI to interact with your local files and databases. | Command | Description | |---------|-------------| | `/config` | View settings | -| `/config api` | Set API key | +| `/config provider ` | Set default provider | +| `/config openrouter_api_key` | Set OpenRouter API key | +| `/config anthropic_api_key` | Set Anthropic API key | +| `/config openai_api_key` | Set OpenAI API key | +| `/config ollama_base_url` | Set Ollama server URL | +| `/config search_provider ` | Set search provider (anthropic_native/duckduckgo/google) | +| `/config google_api_key` | Set Google API key (for Google search) | +| `/config online on\|off` | Set default online mode | | `/config model ` | Set default model | | `/config stream on\|off` | Toggle streaming | | `/stats` | Session statistics | -| `/credits` | Check credits | +| `/credits` | Check credits (OpenRouter) | ## CLI Options @@ -196,9 +294,10 @@ MCP allows the AI to interact with your local files and databases. oai [OPTIONS] Options: + -p, --provider TEXT Provider to use (openrouter/anthropic/openai/ollama) -m, --model TEXT Model ID to use -s, --system TEXT System prompt - -o, --online Enable online mode + -o, --online Enable online mode (OpenRouter only) --mcp Enable MCP server -v, --version Show version --help Show help @@ -280,7 +379,23 @@ pip install -e . 
--force-reinstall ## Version History -### v3.0.0 (Current) +### v3.0.0-b3 (Current - Beta 3) +- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local) +- 🔌 **Provider Switching** - Switch between providers mid-session with `/provider` +- 🧠 **Provider Model Memory** - Remembers last used model per provider +- ⚙️ **Provider Configuration** - Separate API keys for each provider +- 🌐 **Universal Web Search** - Three search options for all providers: + - **Anthropic Native** - Built-in search with citations ($0.01/search) + - **DuckDuckGo** - Free web scraping (default) + - **Google Custom Search** - Premium option +- 🎨 **Enhanced Header** - Shows current provider and model +- 📊 **Updated Config Screen** - Shows provider-specific settings +- 🔧 **Command Dropdown** - Added `/provider` commands for discoverability +- 🎯 **Auto-Scrolling** - Chat window auto-scrolls during streaming responses +- 🎨 **Refined UI** - Thinner scrollbar, cleaner styling +- 🔧 **Updated Models** - Claude 4.5 (Sonnet, Haiku, Opus) + +### v3.0.0 (Beta) - 🎨 **Complete migration to Textual TUI** - Modern async terminal interface - 🗑️ **Removed CLI interface** - TUI-only for cleaner codebase (11.6% smaller) - 🖱️ **Modal screens** - Help, stats, config, credits, model selector diff --git a/oai/__init__.py b/oai/__init__.py index b6f4403..ffbf92c 100644 --- a/oai/__init__.py +++ b/oai/__init__.py @@ -1,16 +1,16 @@ """ -oAI - OpenRouter AI Chat Client +oAI - Open AI Chat Client -A feature-rich terminal-based chat application that provides an interactive CLI -interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP) +A feature-rich terminal-based chat application with multi-provider support +(OpenRouter, Anthropic, OpenAI, Ollama) and advanced Model Context Protocol (MCP) integration for filesystem and database access. 
Author: Rune License: MIT """ -__version__ = "3.0.0-b2" -__author__ = "Rune" +__version__ = "3.0.0-b3" +__author__ = "Rune Olsen" __license__ = "MIT" # Lazy imports to avoid circular dependencies and improve startup time diff --git a/oai/cli.py b/oai/cli.py index 702e0c2..71faae4 100644 --- a/oai/cli.py +++ b/oai/cli.py @@ -21,7 +21,7 @@ from oai.utils.logging import LoggingManager, get_logger # Create Typer app app = typer.Typer( name="oai", - help=f"oAI - OpenRouter AI Chat Client (TUI)\n\nVersion: {APP_VERSION}", + help=f"oAI - Open AI Chat Client (TUI)\n\nVersion: {APP_VERSION}", add_completion=False, epilog="For more information, visit: " + APP_URL, ) @@ -60,6 +60,12 @@ def main_callback( "--mcp", help="Enable MCP server", ), + provider: Optional[str] = typer.Option( + None, + "--provider", + "-p", + help="AI provider to use (openrouter, anthropic, openai, ollama)", + ), ) -> None: """Main callback - launches TUI by default.""" if version_flag: @@ -68,7 +74,7 @@ def main_callback( # If no subcommand provided, launch TUI if ctx.invoked_subcommand is None: - _launch_tui(model, system, online, mcp) + _launch_tui(model, system, online, mcp, provider) def _launch_tui( @@ -76,8 +82,11 @@ def _launch_tui( system: Optional[str] = None, online: bool = False, mcp: bool = False, + provider: Optional[str] = None, ) -> None: """Launch the Textual TUI interface.""" + from oai.constants import VALID_PROVIDERS + # Setup logging logging_manager = LoggingManager() logging_manager.setup() @@ -86,17 +95,35 @@ def _launch_tui( # Load settings settings = Settings.load() - # Check API key - if not settings.api_key: - typer.echo("Error: No API key configured", err=True) - typer.echo("Run: oai config api to set your API key", err=True) + # Determine provider + selected_provider = provider or settings.default_provider + + # Validate provider + if selected_provider not in VALID_PROVIDERS: + typer.echo(f"Error: Invalid provider: {selected_provider}", err=True) + typer.echo(f"Valid providers: {', '.join(VALID_PROVIDERS)}", err=True) raise typer.Exit(1) + # Build provider API keys dict + provider_api_keys = { + "openrouter": settings.openrouter_api_key, + "anthropic": settings.anthropic_api_key, + "openai": settings.openai_api_key, + } + + # Check if provider is configured (except Ollama which doesn't need API key) + if selected_provider != "ollama": + if not provider_api_keys.get(selected_provider): + typer.echo(f"Error: No API key configured for {selected_provider}", err=True) + typer.echo(f"Set it with: oai config {selected_provider}_api_key ", err=True) + raise typer.Exit(1) + # Initialize client try: client = AIClient( - api_key=settings.api_key, - base_url=settings.base_url, + provider_name=selected_provider, + provider_api_keys=provider_api_keys, + ollama_base_url=settings.ollama_base_url, ) except Exception as e: typer.echo(f"Error: Failed to initialize client: {e}", err=True) diff --git a/oai/commands/handlers.py b/oai/commands/handlers.py index e9fb451..b413488 100644 --- a/oai/commands/handlers.py +++ b/oai/commands/handlers.py @@ -275,7 +275,7 @@ class OnlineCommand(Command): ("Enable web search", "/online on"), ("Disable web search", "/online off"), ], - notes="Not all models support online mode.", + notes="OpenRouter: Native :online suffix. 
Other providers: DuckDuckGo search results injected into context.", ) def execute(self, args: str, context: CommandContext) -> CommandResult: @@ -820,7 +820,56 @@ class ConfigCommand(Command): from oai.constants import DEFAULT_SYSTEM_PROMPT table = Table("Setting", "Value", show_header=True, header_style="bold magenta") - table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set") + + # Current Session Provider Info (if available) + if context.session and context.session.client: + client = context.session.client + current_provider = client.provider_name + + table.add_row( + "[bold cyan]Current Provider[/]", + f"[bold green]{current_provider}[/]" + ) + + # Show current provider's base URL + if hasattr(client.provider, 'base_url'): + provider_url = client.provider.base_url + table.add_row(" Provider URL", provider_url) + + # Show API key status for current provider + if current_provider == "ollama": + table.add_row(" API Key", "[dim]Not required (local)[/]") + else: + current_key = settings.get_provider_api_key(current_provider) + if current_key: + masked = "***" + current_key[-4:] if len(current_key) > 4 else "***" + table.add_row(" API Key", f"[green]{masked}[/]") + else: + table.add_row(" API Key", "[red]Not set[/]") + + # Show current model + if context.session.selected_model: + model_name = context.session.selected_model.get("name", "Unknown") + table.add_row(" Current Model", model_name) + + # Add separator + table.add_row("", "") + + # Provider configuration settings + table.add_row("[bold]Default Provider[/]", settings.default_provider) + + # Provider API keys (masked) + def mask_key(key: Optional[str]) -> str: + if not key: + return "[red]Not set[/]" + return "***" + key[-4:] if len(key) > 4 else "***" + + table.add_row("OpenRouter API Key", mask_key(settings.openrouter_api_key)) + table.add_row("Anthropic API Key", mask_key(settings.anthropic_api_key)) + table.add_row("OpenAI API Key", mask_key(settings.openai_api_key)) + table.add_row("Ollama Base URL", settings.ollama_base_url) + + # Legacy/general settings table.add_row("Base URL", settings.base_url) table.add_row("Default Model", settings.default_model or "Not set") @@ -845,7 +894,62 @@ class ConfigCommand(Command): setting = parts[0].lower() value = parts[1] if len(parts) > 1 else None - if setting == "api": + # Provider selection + if setting == "provider": + from oai.constants import VALID_PROVIDERS + + if value: + if value.lower() in VALID_PROVIDERS: + settings.set_default_provider(value.lower()) + return CommandResult.success() + else: + return CommandResult.error(f"Invalid provider. 
Valid: {', '.join(VALID_PROVIDERS)}")
+            else:
+                return CommandResult.error("Usage: /config provider <name>")
+
+        # Provider-specific API keys
+        elif setting in ["openrouter_api_key", "anthropic_api_key", "openai_api_key"]:
+            if value:
+                provider = setting.replace("_api_key", "")
+                settings.set_provider_api_key(provider, value)
+                return CommandResult.success()
+            else:
+                return CommandResult.success(data={"show_api_key_input": True, "provider": setting})
+
+        elif setting == "ollama_base_url":
+            if value:
+                settings.set_ollama_base_url(value)
+                return CommandResult.success()
+            else:
+                return CommandResult.error("Usage: /config ollama_base_url <url>")
+
+        # Web search settings
+        elif setting == "search_provider":
+            if value:
+                valid_providers = ["anthropic_native", "duckduckgo", "google"]
+                if value.lower() in valid_providers:
+                    settings.set_search_provider(value.lower())
+                    return CommandResult.success()
+                else:
+                    return CommandResult.error(f"Invalid search provider. Valid: {', '.join(valid_providers)}")
+            else:
+                return CommandResult.error("Usage: /config search_provider <name>")
+
+        elif setting == "google_api_key":
+            if value:
+                settings.set_google_api_key(value)
+                return CommandResult.success()
+            else:
+                return CommandResult.error("Usage: /config google_api_key <key>")
+
+        elif setting == "google_search_engine_id":
+            if value:
+                settings.set_google_search_engine_id(value)
+                return CommandResult.success()
+            else:
+                return CommandResult.error("Usage: /config google_search_engine_id <id>")
+
+        elif setting == "api":
             if value:
                 settings.set_api_key(value)
             else:
@@ -1388,6 +1492,104 @@ class MCPCommand(Command):
         message = f"Unknown MCP command: {cmd}\nAvailable: on, off, status, add, remove, list, db, files, write, gitignore"
         return CommandResult.error(message)
 
+
+class ProviderCommand(Command):
+    """Switch or display current provider."""
+
+    @property
+    def name(self) -> str:
+        return "/provider"
+
+    @property
+    def help(self) -> CommandHelp:
+        return CommandHelp(
+            description="Switch AI provider or show current provider.",
+            usage="/provider [provider_name]",
+            examples=[
+                ("Show current provider", "/provider"),
+                ("Switch to Anthropic", "/provider anthropic"),
+                ("Switch to OpenAI", "/provider openai"),
+                ("Switch to Ollama", "/provider ollama"),
+                ("Switch to OpenRouter", "/provider openrouter"),
+            ],
+        )
+
+    def execute(self, args: str, context: CommandContext) -> CommandResult:
+        from oai.constants import VALID_PROVIDERS
+
+        provider_name = args.strip().lower()
+
+        if not context.session:
+            return CommandResult.error("No active session")
+
+        if not provider_name:
+            # Show current provider and capabilities
+            current = context.session.client.provider_name
+            capabilities = context.session.client.provider.capabilities
+
+            message = f"\n[bold]Current Provider:[/] {current}\n"
+            message += f" Streaming: {capabilities.streaming}\n"
+            message += f" Tools: {capabilities.tools}\n"
+            message += f" Images: {capabilities.images}\n"
+            message += f" Online: {capabilities.online}"
+
+            return CommandResult.success(message=message)
+
+        # Validate provider
+        if provider_name not in VALID_PROVIDERS:
+            return CommandResult.error(
+                f"Invalid provider: {provider_name}\nValid: {', '.join(VALID_PROVIDERS)}"
+            )
+
+        # Check API key (except for Ollama)
+        if provider_name != "ollama":
+            key = context.settings.get_provider_api_key(provider_name)
+            if not key:
+                return CommandResult.error(
+                    f"No API key for {provider_name}. 
Set with /config {provider_name}_api_key " + ) + + # Switch provider + try: + context.session.client.switch_provider(provider_name) + + # Try to restore last used model for this provider + last_model_id = context.settings.get_provider_model(provider_name) + restored_model = False + + if last_model_id: + # Try to get the model from the new provider + raw_model = context.session.client.get_raw_model(last_model_id) + if raw_model: + context.session.set_model(raw_model) + restored_model = True + message = f"Switched to {provider_name} provider (model: {raw_model.get('name', last_model_id)})" + return CommandResult.success(message=message) + + # If no stored model or model not found, select a default + if not restored_model: + models = context.session.client.list_models() + if models: + # Select a sensible default (first model) + default_model = models[0] + raw_model = context.session.client.get_raw_model(default_model.id) + if raw_model: + context.session.set_model(raw_model) + # Save this as the default for this provider + context.settings.set_provider_model(provider_name, default_model.id) + message = f"Switched to {provider_name} provider (auto-selected: {default_model.name})" + return CommandResult.success(message=message) + else: + # No models available, clear selection + context.session.selected_model = None + message = f"Switched to {provider_name} provider (no models available)" + return CommandResult.success(message=message) + + return CommandResult.success(message=f"Switched to {provider_name} provider") + except Exception as e: + return CommandResult.error(f"Failed to switch provider: {e}") + + class PasteCommand(Command): """Paste from clipboard.""" @@ -1468,6 +1670,7 @@ def register_all_commands() -> None: DeleteCommand(), InfoCommand(), MCPCommand(), + ProviderCommand(), PasteCommand(), ModelCommand(), ] diff --git a/oai/commands/registry.py b/oai/commands/registry.py index 097e2d5..39b2279 100644 --- a/oai/commands/registry.py +++ b/oai/commands/registry.py @@ -80,6 +80,7 @@ class CommandContext: online_enabled: Whether online mode is enabled session_tokens: Session token counts session_cost: Session cost total + session: Reference to the ChatSession (for provider switching, etc.) """ settings: Optional["Settings"] = None @@ -100,6 +101,7 @@ class CommandContext: message_count: int = 0 is_tui: bool = False # Flag for TUI mode current_index: int = 0 + session: Optional[Any] = None # Reference to ChatSession (avoid circular import) @dataclass diff --git a/oai/config/settings.py b/oai/config/settings.py index 907a75a..16723f8 100644 --- a/oai/config/settings.py +++ b/oai/config/settings.py @@ -6,8 +6,9 @@ configuration with type safety, validation, and persistence. """ from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, Dict from pathlib import Path +import json from oai.constants import ( DEFAULT_BASE_URL, @@ -20,6 +21,8 @@ from oai.constants import ( DEFAULT_LOG_LEVEL, DEFAULT_SYSTEM_PROMPT, VALID_LOG_LEVELS, + DEFAULT_PROVIDER, + OLLAMA_DEFAULT_URL, ) from oai.config.database import get_database @@ -34,7 +37,7 @@ class Settings: initialization and can be persisted back. 
Attributes: - api_key: OpenRouter API key + api_key: Legacy OpenRouter API key (deprecated, use openrouter_api_key) base_url: API base URL default_model: Default model ID to use default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank) @@ -45,9 +48,33 @@ class Settings: log_max_size_mb: Maximum log file size in MB log_backup_count: Number of log file backups to keep log_level: Logging level (debug/info/warning/error/critical) + + # Provider-specific settings + default_provider: Default AI provider to use + openrouter_api_key: OpenRouter API key + anthropic_api_key: Anthropic API key + openai_api_key: OpenAI API key + ollama_base_url: Ollama server URL """ + # Legacy field (kept for backward compatibility) api_key: Optional[str] = None + + # Provider configuration + default_provider: str = DEFAULT_PROVIDER + openrouter_api_key: Optional[str] = None + anthropic_api_key: Optional[str] = None + openai_api_key: Optional[str] = None + ollama_base_url: str = OLLAMA_DEFAULT_URL + provider_models: Dict[str, str] = field(default_factory=dict) # provider -> last_model_id + + # Web search configuration (for online mode with non-OpenRouter providers) + search_provider: str = "duckduckgo" # "duckduckgo" or "google" + google_api_key: Optional[str] = None + google_search_engine_id: Optional[str] = None + search_num_results: int = 5 + + # General settings base_url: str = DEFAULT_BASE_URL default_model: Optional[str] = None default_system_prompt: Optional[str] = None @@ -134,8 +161,43 @@ class Settings: # Get system prompt from DB: None means not set (use default), "" means explicitly blank system_prompt_value = db.get_config("default_system_prompt") + # Migration: copy legacy api_key to openrouter_api_key if not already set + legacy_api_key = db.get_config("api_key") + openrouter_key = db.get_config("openrouter_api_key") + + if legacy_api_key and not openrouter_key: + db.set_config("openrouter_api_key", legacy_api_key) + openrouter_key = legacy_api_key + # Note: We keep the legacy api_key in DB for backward compatibility + + # Load provider-model mapping + provider_models_json = db.get_config("provider_models") + provider_models = {} + if provider_models_json: + try: + provider_models = json.loads(provider_models_json) + except json.JSONDecodeError: + provider_models = {} + return cls( - api_key=db.get_config("api_key"), + # Legacy field + api_key=legacy_api_key, + + # Provider configuration + default_provider=db.get_config("default_provider") or DEFAULT_PROVIDER, + openrouter_api_key=openrouter_key, + anthropic_api_key=db.get_config("anthropic_api_key"), + openai_api_key=db.get_config("openai_api_key"), + ollama_base_url=db.get_config("ollama_base_url") or OLLAMA_DEFAULT_URL, + provider_models=provider_models, + + # Web search configuration + search_provider=db.get_config("search_provider") or "duckduckgo", + google_api_key=db.get_config("google_api_key"), + google_search_engine_id=db.get_config("google_search_engine_id"), + search_num_results=parse_int(db.get_config("search_num_results"), 5), + + # General settings base_url=db.get_config("base_url") or DEFAULT_BASE_URL, default_model=db.get_config("default_model"), default_system_prompt=system_prompt_value, @@ -331,6 +393,155 @@ class Settings: self.log_max_size_mb = min(size_mb, 100) get_database().set_config("log_max_size_mb", str(self.log_max_size_mb)) + def set_provider_api_key(self, provider: str, api_key: str) -> None: + """ + Set and persist an API key for a specific provider. 
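+
+        Example (key value is a placeholder):
+            settings.set_provider_api_key("anthropic", "sk-ant-...")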
+ + Args: + provider: Provider name ("openrouter", "anthropic", "openai") + api_key: The API key to set + + Raises: + ValueError: If provider is invalid + """ + provider = provider.lower() + api_key = api_key.strip() + + if provider == "openrouter": + self.openrouter_api_key = api_key + get_database().set_config("openrouter_api_key", api_key) + elif provider == "anthropic": + self.anthropic_api_key = api_key + get_database().set_config("anthropic_api_key", api_key) + elif provider == "openai": + self.openai_api_key = api_key + get_database().set_config("openai_api_key", api_key) + else: + raise ValueError(f"Invalid provider: {provider}") + + def get_provider_api_key(self, provider: str) -> Optional[str]: + """ + Get the API key for a specific provider. + + Args: + provider: Provider name ("openrouter", "anthropic", "openai", "ollama") + + Returns: + API key or None if not set + + Raises: + ValueError: If provider is invalid + """ + provider = provider.lower() + + if provider == "openrouter": + return self.openrouter_api_key + elif provider == "anthropic": + return self.anthropic_api_key + elif provider == "openai": + return self.openai_api_key + elif provider == "ollama": + return "" # Ollama doesn't require an API key + else: + raise ValueError(f"Invalid provider: {provider}") + + def set_default_provider(self, provider: str) -> None: + """ + Set and persist the default provider. + + Args: + provider: Provider name + + Raises: + ValueError: If provider is invalid + """ + from oai.constants import VALID_PROVIDERS + + provider = provider.lower() + if provider not in VALID_PROVIDERS: + raise ValueError( + f"Invalid provider: {provider}. " + f"Valid providers: {', '.join(VALID_PROVIDERS)}" + ) + + self.default_provider = provider + get_database().set_config("default_provider", provider) + + def set_ollama_base_url(self, url: str) -> None: + """ + Set and persist the Ollama base URL. + + Args: + url: Ollama server URL + """ + self.ollama_base_url = url.strip() + get_database().set_config("ollama_base_url", self.ollama_base_url) + + def set_search_provider(self, provider: str) -> None: + """ + Set and persist the web search provider. + + Args: + provider: Search provider ("anthropic_native", "duckduckgo", "google") + + Raises: + ValueError: If provider is invalid + """ + valid_providers = ["anthropic_native", "duckduckgo", "google"] + provider = provider.lower() + if provider not in valid_providers: + raise ValueError( + f"Invalid search provider: {provider}. " + f"Valid providers: {', '.join(valid_providers)}" + ) + + self.search_provider = provider + get_database().set_config("search_provider", provider) + + def set_google_api_key(self, api_key: str) -> None: + """ + Set and persist the Google API key for Google Custom Search. + + Args: + api_key: The Google API key + """ + self.google_api_key = api_key.strip() + get_database().set_config("google_api_key", self.google_api_key) + + def set_google_search_engine_id(self, engine_id: str) -> None: + """ + Set and persist the Google Custom Search Engine ID. + + Args: + engine_id: The Google Search Engine ID + """ + self.google_search_engine_id = engine_id.strip() + get_database().set_config("google_search_engine_id", self.google_search_engine_id) + + def get_provider_model(self, provider: str) -> Optional[str]: + """ + Get the last used model for a provider. 
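+
+        Example:
+            settings.get_provider_model("anthropic")  # e.g. "claude-sonnet-4-5-20250929"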
+ + Args: + provider: Provider name + + Returns: + Model ID or None if not set + """ + return self.provider_models.get(provider) + + def set_provider_model(self, provider: str, model_id: str) -> None: + """ + Set and persist the last used model for a provider. + + Args: + provider: Provider name + model_id: Model ID to remember + """ + self.provider_models[provider] = model_id + # Save to database as JSON + get_database().set_config("provider_models", json.dumps(self.provider_models)) + # Global settings instance _settings: Optional[Settings] = None diff --git a/oai/constants.py b/oai/constants.py index f218680..172746d 100644 --- a/oai/constants.py +++ b/oai/constants.py @@ -20,7 +20,7 @@ from oai import __version__ APP_NAME = "oAI" APP_VERSION = __version__ # Single source of truth in oai/__init__.py APP_URL = "https://iurl.no/oai" -APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration" +APP_DESCRIPTION = "Open AI Chat Client with Multi-Provider Support" # ============================================================================= # FILE PATHS @@ -42,6 +42,26 @@ DEFAULT_STREAM_ENABLED = True DEFAULT_MAX_TOKENS = 100_000 DEFAULT_ONLINE_MODE = False +# ============================================================================= +# PROVIDER CONFIGURATION +# ============================================================================= + +# Provider names +PROVIDER_OPENROUTER = "openrouter" +PROVIDER_ANTHROPIC = "anthropic" +PROVIDER_OPENAI = "openai" +PROVIDER_OLLAMA = "ollama" + +VALID_PROVIDERS = [PROVIDER_OPENROUTER, PROVIDER_ANTHROPIC, PROVIDER_OPENAI, PROVIDER_OLLAMA] + +# Provider base URLs +ANTHROPIC_BASE_URL = "https://api.anthropic.com" +OPENAI_BASE_URL = "https://api.openai.com/v1" +OLLAMA_DEFAULT_URL = "http://localhost:11434" + +# Default provider +DEFAULT_PROVIDER = PROVIDER_OPENROUTER + # ============================================================================= # DEFAULT SYSTEM PROMPT # ============================================================================= diff --git a/oai/core/client.py b/oai/core/client.py index 5bcdbec..7eae911 100644 --- a/oai/core/client.py +++ b/oai/core/client.py @@ -9,7 +9,7 @@ import asyncio import json from typing import Any, Callable, Dict, Iterator, List, Optional, Union -from oai.constants import APP_NAME, APP_URL, MODEL_PRICING +from oai.constants import APP_NAME, APP_URL, MODEL_PRICING, OLLAMA_DEFAULT_URL from oai.providers.base import ( AIProvider, ChatMessage, @@ -19,7 +19,6 @@ from oai.providers.base import ( ToolCall, UsageStats, ) -from oai.providers.openrouter import OpenRouterProvider from oai.utils.logging import get_logger @@ -32,37 +31,66 @@ class AIClient: Attributes: provider: The underlying AI provider + provider_name: Name of the current provider default_model: Default model ID to use http_headers: Custom HTTP headers for requests """ def __init__( self, - api_key: str, - base_url: Optional[str] = None, - provider_class: type = OpenRouterProvider, + provider_name: str = "openrouter", + provider_api_keys: Optional[Dict[str, str]] = None, + ollama_base_url: str = OLLAMA_DEFAULT_URL, app_name: str = APP_NAME, app_url: str = APP_URL, ): """ - Initialize the AI client. + Initialize the AI client with specified provider. 
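+        Example (key value is a placeholder):
+            client = AIClient(
+                provider_name="anthropic",
+                provider_api_keys={"anthropic": "sk-ant-..."},
+            )
+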
Args: - api_key: API key for authentication - base_url: Optional custom base URL - provider_class: Provider class to use (default: OpenRouterProvider) + provider_name: Provider to use ("openrouter", "anthropic", "openai", "ollama") + provider_api_keys: Dict mapping provider names to API keys + ollama_base_url: Base URL for Ollama server app_name: Application name for headers app_url: Application URL for headers + + Raises: + ValueError: If provider is invalid or not configured """ - self.provider: AIProvider = provider_class( - api_key=api_key, - base_url=base_url, - app_name=app_name, - app_url=app_url, - ) + from oai.providers.registry import get_provider_class + + self.provider_name = provider_name + self.provider_api_keys = provider_api_keys or {} + self.ollama_base_url = ollama_base_url + self.app_name = app_name + self.app_url = app_url + + # Get provider class + provider_class = get_provider_class(provider_name) + if not provider_class: + raise ValueError(f"Unknown provider: {provider_name}") + + # Get API key for this provider + api_key = self.provider_api_keys.get(provider_name, "") + + # Initialize provider with appropriate parameters + if provider_name == "ollama": + self.provider: AIProvider = provider_class( + api_key=api_key, + base_url=ollama_base_url, + ) + else: + self.provider: AIProvider = provider_class( + api_key=api_key, + app_name=app_name, + app_url=app_url, + ) + self.default_model: Optional[str] = None self.logger = get_logger() + self.logger.info(f"Initialized {provider_name} provider") + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: """ Get available models. @@ -420,3 +448,49 @@ class AIClient: """Clear the provider's model cache.""" if hasattr(self.provider, "clear_cache"): self.provider.clear_cache() + + def switch_provider(self, provider_name: str, ollama_base_url: Optional[str] = None) -> None: + """ + Switch to a different provider. 
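+
+        Example:
+            client.switch_provider("ollama", ollama_base_url="http://localhost:11434")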
+ + Args: + provider_name: Provider to switch to + ollama_base_url: Optional Ollama base URL (if switching to Ollama) + + Raises: + ValueError: If provider is invalid or not configured + """ + from oai.providers.registry import get_provider_class + + # Get provider class + provider_class = get_provider_class(provider_name) + if not provider_class: + raise ValueError(f"Unknown provider: {provider_name}") + + # Get API key + api_key = self.provider_api_keys.get(provider_name, "") + + # Check API key requirement + if provider_name != "ollama" and not api_key: + raise ValueError(f"No API key configured for {provider_name}") + + # Initialize new provider + if provider_name == "ollama": + base_url = ollama_base_url or self.ollama_base_url + self.provider = provider_class( + api_key=api_key, + base_url=base_url, + ) + self.ollama_base_url = base_url + else: + self.provider = provider_class( + api_key=api_key, + app_name=self.app_name, + app_url=self.app_url, + ) + + self.provider_name = provider_name + self.logger.info(f"Switched to {provider_name} provider") + + # Clear model cache when switching providers + self.default_model = None diff --git a/oai/core/session.py b/oai/core/session.py index aa3f381..536cf9b 100644 --- a/oai/core/session.py +++ b/oai/core/session.py @@ -23,6 +23,9 @@ from oai.core.client import AIClient from oai.mcp.manager import MCPManager from oai.providers.base import ChatResponse, StreamChunk, UsageStats from oai.utils.logging import get_logger +from oai.utils.web_search import perform_web_search, format_search_results + +logger = get_logger() @dataclass @@ -164,6 +167,7 @@ class ChatSession: total_cost=self.stats.total_cost, message_count=self.stats.message_count, current_index=self.current_index, + session=self, ) def set_model(self, model: Dict[str, Any]) -> None: @@ -290,6 +294,44 @@ class ChatSession: # Build request parameters model_id = self.selected_model["id"] + + # Handle online mode + enable_web_search = False + web_search_config = {} + + if self.online_enabled: + # OpenRouter handles online mode natively with :online suffix + if self.client.provider_name == "openrouter": + if hasattr(self.client.provider, "get_effective_model_id"): + model_id = self.client.provider.get_effective_model_id(model_id, True) + # Anthropic has native web search when search provider is set to anthropic_native + elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native": + enable_web_search = True + web_search_config = { + "max_uses": self.settings.search_num_results or 5 + } + logger.info("Using Anthropic native web search") + else: + # For other providers, perform web search and inject results + logger.info(f"Performing web search for: {user_input}") + search_results = perform_web_search( + user_input, + num_results=self.settings.search_num_results, + provider=self.settings.search_provider, + google_api_key=self.settings.google_api_key, + google_search_engine_id=self.settings.google_search_engine_id + ) + + if search_results: + # Inject search results into messages + formatted_results = format_search_results(search_results) + search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question." 
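+                    # Results are injected as plain text rather than as tool
+                    # calls, since not all providers support tools (the Ollama
+                    # provider, for example, reports tools=False).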
+
+                    # Add search results to the last user message
+                    if messages and messages[-1]["role"] == "user":
+                        messages[-1]["content"] += search_context
+
+                    logger.info(f"Injected {len(search_results)} search results into context")
 
-        if self.online_enabled:
-            if hasattr(self.client.provider, "get_effective_model_id"):
-                model_id = self.client.provider.get_effective_model_id(model_id, True)
@@ -320,6 +362,8 @@
                 max_tokens=max_tokens,
                 transforms=transforms,
                 on_chunk=on_stream_chunk,
+                enable_web_search=enable_web_search,
+                web_search_config=web_search_config,
             )
             response_time = time.time() - start_time
             return full_text, usage, response_time
@@ -455,6 +499,8 @@
         max_tokens: Optional[int] = None,
         transforms: Optional[List[str]] = None,
         on_chunk: Optional[Callable[[str], None]] = None,
+        enable_web_search: bool = False,
+        web_search_config: Optional[Dict[str, Any]] = None,
     ) -> Tuple[str, Optional[UsageStats]]:
         """
         Stream a response with live display.
@@ -465,6 +511,8 @@
             max_tokens: Max tokens
             transforms: Transforms
             on_chunk: Callback for chunks
+            enable_web_search: Whether to enable Anthropic native web search
+            web_search_config: Web search configuration
 
         Returns:
             Tuple of (full_text, usage)
@@ -475,6 +523,8 @@
             stream=True,
             max_tokens=max_tokens,
             transforms=transforms,
+            enable_web_search=enable_web_search,
+            web_search_config=web_search_config or {},
         )
 
         if isinstance(response, ChatResponse):
@@ -530,10 +580,45 @@
             # Disable streaming when tools are present
             stream = False
 
+        # Handle online mode
         model_id = self.selected_model["id"]
+        enable_web_search = False
+        web_search_config = {}
+
         if self.online_enabled:
-            if hasattr(self.client.provider, "get_effective_model_id"):
-                model_id = self.client.provider.get_effective_model_id(model_id, True)
+            # OpenRouter handles online mode natively with :online suffix
+            if self.client.provider_name == "openrouter":
+                if hasattr(self.client.provider, "get_effective_model_id"):
+                    model_id = self.client.provider.get_effective_model_id(model_id, True)
+            # Anthropic has native web search when search provider is set to anthropic_native
+            elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
+                enable_web_search = True
+                web_search_config = {
+                    "max_uses": self.settings.search_num_results or 5
+                }
+                logger.info("Using Anthropic native web search")
+            else:
+                # For other providers, perform web search and inject results
+                logger.info(f"Performing web search for: {user_input}")
+                search_results = await asyncio.to_thread(
+                    perform_web_search,
+                    user_input,
+                    num_results=self.settings.search_num_results,
+                    provider=self.settings.search_provider,
+                    google_api_key=self.settings.google_api_key,
+                    google_search_engine_id=self.settings.google_search_engine_id
+                )
+
+                if search_results:
+                    # Inject search results into messages
+                    formatted_results = format_search_results(search_results)
+                    search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
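+                    # Mirrors the sync path above; perform_web_search runs via
+                    # asyncio.to_thread so the synchronous search call doesn't
+                    # stall the TUI event loop.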
+ + # Add search results to the last user message + if messages and messages[-1]["role"] == "user": + messages[-1]["content"] += search_context + + logger.info(f"Injected {len(search_results)} search results into context") transforms = ["middle-out"] if self.middle_out_enabled else None max_tokens = None @@ -557,6 +642,8 @@ class ChatSession: model_id=model_id, max_tokens=max_tokens, transforms=transforms, + enable_web_search=enable_web_search, + web_search_config=web_search_config, ): yield chunk else: @@ -733,6 +820,8 @@ class ChatSession: model_id: str, max_tokens: Optional[int] = None, transforms: Optional[List[str]] = None, + enable_web_search: bool = False, + web_search_config: Optional[Dict[str, Any]] = None, ) -> AsyncIterator[StreamChunk]: """ Async version of _stream_response for TUI. @@ -742,6 +831,8 @@ class ChatSession: model_id: Model ID max_tokens: Max tokens transforms: Transforms + enable_web_search: Whether to enable Anthropic native web search + web_search_config: Web search configuration Yields: StreamChunk objects @@ -752,6 +843,8 @@ class ChatSession: stream=True, max_tokens=max_tokens, transforms=transforms, + enable_web_search=enable_web_search, + web_search_config=web_search_config or {}, ) if isinstance(response, ChatResponse): diff --git a/oai/providers/__init__.py b/oai/providers/__init__.py index 93df1e5..ce43819 100644 --- a/oai/providers/__init__.py +++ b/oai/providers/__init__.py @@ -16,6 +16,16 @@ from oai.providers.base import ( ProviderCapabilities, ) from oai.providers.openrouter import OpenRouterProvider +from oai.providers.anthropic import AnthropicProvider +from oai.providers.openai import OpenAIProvider +from oai.providers.ollama import OllamaProvider +from oai.providers.registry import register_provider + +# Register all providers +register_provider("openrouter", OpenRouterProvider) +register_provider("anthropic", AnthropicProvider) +register_provider("openai", OpenAIProvider) +register_provider("ollama", OllamaProvider) __all__ = [ # Base classes and types @@ -29,4 +39,7 @@ __all__ = [ "ProviderCapabilities", # Provider implementations "OpenRouterProvider", + "AnthropicProvider", + "OpenAIProvider", + "OllamaProvider", ] diff --git a/oai/providers/anthropic.py b/oai/providers/anthropic.py new file mode 100644 index 0000000..af5542a --- /dev/null +++ b/oai/providers/anthropic.py @@ -0,0 +1,673 @@ +""" +Anthropic provider for Claude models. + +This provider connects to Anthropic's API for accessing Claude models. +""" + +import json +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +import anthropic +from anthropic.types import Message, MessageStreamEvent + +from oai.constants import ANTHROPIC_BASE_URL +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ChatResponseChoice, + ModelInfo, + ProviderCapabilities, + StreamChunk, + ToolCall, + ToolFunction, + UsageStats, +) +from oai.utils.logging import get_logger + +logger = get_logger() + + +# Model name aliases +MODEL_ALIASES = { + "claude-sonnet": "claude-sonnet-4-5-20250929", + "claude-haiku": "claude-haiku-4-5-20251001", + "claude-opus": "claude-opus-4-5-20251101", + # Legacy aliases + "claude-3-haiku": "claude-3-haiku-20240307", + "claude-3-7-sonnet": "claude-3-7-sonnet-20250219", +} + + +class AnthropicProvider(AIProvider): + """ + Anthropic API provider. + + Provides access to Claude 3.5 Sonnet, Claude 3 Opus, and other Anthropic models. 
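+    The bundled model list covers the Claude 4.5 family plus a few legacy
+    3.x models; see MODEL_ALIASES and list_models() below.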
+ """ + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + app_name: str = "oAI", + app_url: str = "", + **kwargs: Any, + ): + """ + Initialize Anthropic provider. + + Args: + api_key: Anthropic API key + base_url: Optional custom base URL + app_name: Application name (for headers) + app_url: Application URL (for headers) + **kwargs: Additional arguments + """ + super().__init__(api_key, base_url or ANTHROPIC_BASE_URL) + self.client = anthropic.Anthropic(api_key=api_key) + self.async_client = anthropic.AsyncAnthropic(api_key=api_key) + self._models_cache: Optional[List[ModelInfo]] = None + + def _create_web_search_tool(self, config: Dict[str, Any]) -> Dict[str, Any]: + """ + Create Anthropic native web search tool definition. + + Args: + config: Optional configuration for web search (max_uses, allowed_domains, etc.) + + Returns: + Tool definition dict + """ + tool: Dict[str, Any] = { + "type": "web_search_20250305", + "name": "web_search", + } + + # Add optional parameters if provided + if "max_uses" in config: + tool["max_uses"] = config["max_uses"] + else: + tool["max_uses"] = 5 # Default + + if "allowed_domains" in config: + tool["allowed_domains"] = config["allowed_domains"] + + if "blocked_domains" in config: + tool["blocked_domains"] = config["blocked_domains"] + + if "user_location" in config: + tool["user_location"] = config["user_location"] + + return tool + + @property + def name(self) -> str: + """Get provider name.""" + return "Anthropic" + + @property + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + return ProviderCapabilities( + streaming=True, + tools=True, + images=True, + online=True, # Web search via DuckDuckGo/Google + max_context=200000, + ) + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + List available Anthropic models. 
+ + Args: + filter_text_only: Whether to filter for text models only + + Returns: + List of ModelInfo objects + """ + if self._models_cache: + return self._models_cache + + # Anthropic doesn't have a models list API, so we hardcode the available models + models = [ + # Current Claude 4.5 models + ModelInfo( + id="claude-sonnet-4-5-20250929", + name="Claude Sonnet 4.5", + description="Smart model for complex agents and coding (recommended)", + context_length=200000, + pricing={"input": 3.0, "output": 15.0}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + input_modalities=["text", "image"], + ), + ModelInfo( + id="claude-haiku-4-5-20251001", + name="Claude Haiku 4.5", + description="Fastest model with near-frontier intelligence", + context_length=200000, + pricing={"input": 1.0, "output": 5.0}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + input_modalities=["text", "image"], + ), + ModelInfo( + id="claude-opus-4-5-20251101", + name="Claude Opus 4.5", + description="Premium model with maximum intelligence", + context_length=200000, + pricing={"input": 5.0, "output": 25.0}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + input_modalities=["text", "image"], + ), + # Legacy models (still available) + ModelInfo( + id="claude-3-7-sonnet-20250219", + name="Claude Sonnet 3.7", + description="Legacy model - recommend migrating to 4.5", + context_length=200000, + pricing={"input": 3.0, "output": 15.0}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + input_modalities=["text", "image"], + ), + ModelInfo( + id="claude-3-haiku-20240307", + name="Claude 3 Haiku", + description="Legacy fast model - recommend migrating to 4.5", + context_length=200000, + pricing={"input": 0.25, "output": 1.25}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + input_modalities=["text", "image"], + ), + ] + + self._models_cache = models + logger.info(f"Loaded {len(models)} Anthropic models") + return models + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: Model identifier + + Returns: + ModelInfo or None + """ + # Resolve alias + resolved_id = MODEL_ALIASES.get(model_id, model_id) + + models = self.list_models() + for model in models: + if model.id == resolved_id or model.id == model_id: + return model + return None + + def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send chat completion request to Anthropic. 
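+
+        Example (minimal sketch; "claude-sonnet" is an alias defined above):
+            provider.chat(
+                model="claude-sonnet",
+                messages=[ChatMessage(role="user", content="Hello")],
+            )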
+ + Args: + model: Model ID + messages: Chat messages + stream: Whether to stream response + max_tokens: Maximum tokens + temperature: Sampling temperature + tools: Tool definitions + tool_choice: Tool selection mode + **kwargs: Additional parameters (including enable_web_search) + + Returns: + ChatResponse or Iterator[StreamChunk] + """ + # Resolve model alias + model_id = MODEL_ALIASES.get(model, model) + + # Extract system message (Anthropic requires it separate from messages) + system_prompt, anthropic_messages = self._convert_messages(messages) + + # Build request parameters + params: Dict[str, Any] = { + "model": model_id, + "messages": anthropic_messages, + "max_tokens": max_tokens or 4096, + } + + if system_prompt: + params["system"] = system_prompt + + if temperature is not None: + params["temperature"] = temperature + + # Prepare tools list + tools_list = [] + + # Add web search tool if requested via kwargs + if kwargs.get("enable_web_search", False): + web_search_config = kwargs.get("web_search_config", {}) + tools_list.append(self._create_web_search_tool(web_search_config)) + logger.info("Added Anthropic native web search tool") + + # Add user-provided tools + if tools: + # Convert tools to Anthropic format + converted_tools = self._convert_tools(tools) + tools_list.extend(converted_tools) + + if tools_list: + params["tools"] = tools_list + + if tool_choice and tool_choice != "auto": + # Anthropic uses different tool_choice format + if tool_choice == "none": + pass # Don't include tools + elif tool_choice == "required": + params["tool_choice"] = {"type": "any"} + else: + params["tool_choice"] = {"type": "tool", "name": tool_choice} + + logger.debug(f"Anthropic request: model={model_id}, messages={len(anthropic_messages)}") + + try: + if stream: + return self._stream_chat(params) + else: + return self._sync_chat(params) + + except Exception as e: + logger.error(f"Anthropic request failed: {e}") + return ChatResponse( + id="error", + choices=[ + ChatResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=f"Error: {str(e)}"), + finish_reason="error", + ) + ], + ) + + def _convert_messages(self, messages: List[ChatMessage]) -> tuple[str, List[Dict[str, Any]]]: + """ + Convert messages to Anthropic format. + + Anthropic requires system messages to be separate from the conversation. 
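+
+        For example, a [system, user] history becomes a system prompt string
+        plus a one-element messages list.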
+ + Args: + messages: List of ChatMessage objects + + Returns: + Tuple of (system_prompt, anthropic_messages) + """ + system_prompt = "" + anthropic_messages = [] + + for msg in messages: + if msg.role == "system": + # Accumulate system messages + if system_prompt: + system_prompt += "\n\n" + system_prompt += msg.content or "" + else: + # Convert to Anthropic format + message_dict: Dict[str, Any] = {"role": msg.role} + + # Handle content + if msg.content: + message_dict["content"] = msg.content + + # Handle tool calls (assistant messages) + if msg.tool_calls: + # Anthropic format for tool use + content_blocks = [] + if msg.content: + content_blocks.append({"type": "text", "text": msg.content}) + + for tc in msg.tool_calls: + content_blocks.append({ + "type": "tool_use", + "id": tc.id, + "name": tc.function.name, + "input": json.loads(tc.function.arguments), + }) + + message_dict["content"] = content_blocks + + # Handle tool results (tool messages) + if msg.role == "tool" and msg.tool_call_id: + # Convert to Anthropic's tool_result format + anthropic_messages.append({ + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": msg.tool_call_id, + "content": msg.content or "", + }] + }) + continue + + anthropic_messages.append(message_dict) + + return system_prompt, anthropic_messages + + def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Convert OpenAI-style tools to Anthropic format. + + Args: + tools: OpenAI tool definitions + + Returns: + Anthropic tool definitions + """ + anthropic_tools = [] + + for tool in tools: + if tool.get("type") == "function": + func = tool.get("function", {}) + anthropic_tools.append({ + "name": func.get("name"), + "description": func.get("description", ""), + "input_schema": func.get("parameters", {}), + }) + + return anthropic_tools + + def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse: + """ + Send synchronous chat request. + + Args: + params: Request parameters + + Returns: + ChatResponse + """ + message: Message = self.client.messages.create(**params) + + # Extract content + content = "" + tool_calls = [] + + for block in message.content: + if block.type == "text": + content += block.text + elif block.type == "tool_use": + # Convert to ToolCall format + tool_calls.append( + ToolCall( + id=block.id, + type="function", + function=ToolFunction( + name=block.name, + arguments=json.dumps(block.input), + ), + ) + ) + + # Build ChatMessage + chat_message = ChatMessage( + role="assistant", + content=content if content else None, + tool_calls=tool_calls if tool_calls else None, + ) + + # Extract usage + usage = None + if message.usage: + usage = UsageStats( + prompt_tokens=message.usage.input_tokens, + completion_tokens=message.usage.output_tokens, + total_tokens=message.usage.input_tokens + message.usage.output_tokens, + ) + + return ChatResponse( + id=message.id, + choices=[ + ChatResponseChoice( + index=0, + message=chat_message, + finish_reason=message.stop_reason, + ) + ], + usage=usage, + model=message.model, + ) + + def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]: + """ + Stream chat response from Anthropic. 
+ + Args: + params: Request parameters + + Yields: + StreamChunk objects + """ + stream = self.client.messages.stream(**params) + + with stream as event_stream: + for event in event_stream: + event_data: MessageStreamEvent = event + + # Handle different event types + if event_data.type == "content_block_delta": + delta = event_data.delta + if hasattr(delta, "text"): + yield StreamChunk( + id="stream", + delta_content=delta.text, + ) + + elif event_data.type == "message_stop": + # Final event with usage + pass + + elif event_data.type == "message_delta": + # Contains stop reason and usage + usage = None + if hasattr(event_data, "usage"): + usage_data = event_data.usage + usage = UsageStats( + prompt_tokens=0, + completion_tokens=usage_data.output_tokens, + total_tokens=usage_data.output_tokens, + ) + + yield StreamChunk( + id="stream", + finish_reason=event_data.delta.stop_reason if hasattr(event_data.delta, "stop_reason") else None, + usage=usage, + ) + + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Send async chat request to Anthropic. + + Args: + model: Model ID + messages: Chat messages + stream: Whether to stream + max_tokens: Max tokens + temperature: Temperature + tools: Tool definitions + tool_choice: Tool choice + **kwargs: Additional args + + Returns: + ChatResponse or AsyncIterator[StreamChunk] + """ + # Resolve model alias + model_id = MODEL_ALIASES.get(model, model) + + # Convert messages + system_prompt, anthropic_messages = self._convert_messages(messages) + + # Build params + params: Dict[str, Any] = { + "model": model_id, + "messages": anthropic_messages, + "max_tokens": max_tokens or 4096, + } + + if system_prompt: + params["system"] = system_prompt + if temperature is not None: + params["temperature"] = temperature + if tools: + params["tools"] = self._convert_tools(tools) + + if stream: + return self._stream_chat_async(params) + else: + message = await self.async_client.messages.create(**params) + return self._convert_message(message) + + async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]: + """ + Stream async chat response. 
+
+        Args:
+            params: Request parameters
+
+        Yields:
+            StreamChunk objects
+        """
+        # messages.stream() returns an async context manager; it must not
+        # be awaited directly.
+        stream = self.async_client.messages.stream(**params)
+
+        async with stream as event_stream:
+            async for event in event_stream:
+                if event.type == "content_block_delta":
+                    delta = event.delta
+                    if hasattr(delta, "text"):
+                        yield StreamChunk(
+                            id="stream",
+                            delta_content=delta.text,
+                        )
+
+    def _convert_message(self, message: Message) -> ChatResponse:
+        """Helper to convert Anthropic message to ChatResponse."""
+        content = ""
+        tool_calls = []
+
+        for block in message.content:
+            if block.type == "text":
+                content += block.text
+            elif block.type == "tool_use":
+                tool_calls.append(
+                    ToolCall(
+                        id=block.id,
+                        type="function",
+                        function=ToolFunction(
+                            name=block.name,
+                            arguments=json.dumps(block.input),
+                        ),
+                    )
+                )
+
+        chat_message = ChatMessage(
+            role="assistant",
+            content=content if content else None,
+            tool_calls=tool_calls if tool_calls else None,
+        )
+
+        usage = None
+        if message.usage:
+            usage = UsageStats(
+                prompt_tokens=message.usage.input_tokens,
+                completion_tokens=message.usage.output_tokens,
+                total_tokens=message.usage.input_tokens + message.usage.output_tokens,
+            )
+
+        return ChatResponse(
+            id=message.id,
+            choices=[
+                ChatResponseChoice(
+                    index=0,
+                    message=chat_message,
+                    finish_reason=message.stop_reason,
+                )
+            ],
+            usage=usage,
+            model=message.model,
+        )
+
+    def get_credits(self) -> Optional[Dict[str, Any]]:
+        """
+        Get account credits from Anthropic.
+
+        Note: Anthropic does not currently provide a public API endpoint
+        for checking account credits/balance. This information is only
+        available through the Anthropic Console web interface.
+
+        Returns:
+            None (credits API not available)
+        """
+        # Anthropic doesn't provide a public credits API endpoint
+        # Users must check their balance at console.anthropic.com
+        return None
+
+    def clear_cache(self) -> None:
+        """Clear model cache."""
+        self._models_cache = None
+
+    def get_raw_models(self) -> List[Dict[str, Any]]:
+        """
+        Get raw model data as dictionaries.
+
+        Returns:
+            List of model dictionaries
+        """
+        models = self.list_models()
+        return [
+            {
+                "id": model.id,
+                "name": model.name,
+                "description": model.description,
+                "context_length": model.context_length,
+                "pricing": model.pricing,
+            }
+            for model in models
+        ]
+
+    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
+        """
+        Get raw model data for a specific model.
+
+        Args:
+            model_id: Model identifier
+
+        Returns:
+            Model dictionary or None
+        """
+        model = self.get_model(model_id)
+        if model:
+            return {
+                "id": model.id,
+                "name": model.name,
+                "description": model.description,
+                "context_length": model.context_length,
+                "pricing": model.pricing,
+            }
+        return None
diff --git a/oai/providers/ollama.py b/oai/providers/ollama.py
new file mode 100644
index 0000000..9d0de6a
--- /dev/null
+++ b/oai/providers/ollama.py
@@ -0,0 +1,423 @@
+"""
+Ollama provider for local AI model serving.
+
+This provider connects to a local Ollama server for running models
+locally without API keys or external dependencies.
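+
+Example (illustrative sketch; assumes a local server with a pulled model,
+e.g. after running `ollama pull llama3`):
+
+    from oai.providers.base import ChatMessage
+    from oai.providers.ollama import OllamaProvider
+
+    provider = OllamaProvider()  # defaults to http://localhost:11434
+    response = provider.chat("llama3", [ChatMessage(role="user", content="Hi")])
+    print(response.choices[0].message.content)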
+""" + +import json +import time +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +import requests + +from oai.constants import OLLAMA_DEFAULT_URL +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ChatResponseChoice, + ModelInfo, + ProviderCapabilities, + StreamChunk, + UsageStats, +) +from oai.utils.logging import get_logger + +logger = get_logger() + + +class OllamaProvider(AIProvider): + """ + Ollama local model provider. + + Connects to a local Ollama server for running models locally. + No API key required. + """ + + def __init__( + self, + api_key: str = "", + base_url: Optional[str] = None, + **kwargs: Any, + ): + """ + Initialize Ollama provider. + + Args: + api_key: Not used (Ollama doesn't require API keys) + base_url: Ollama server URL (default: http://localhost:11434) + **kwargs: Additional arguments (ignored) + """ + super().__init__(api_key or "", base_url) + self.base_url = base_url or OLLAMA_DEFAULT_URL + self._check_server_available() + + def _check_server_available(self) -> bool: + """ + Check if Ollama server is accessible. + + Returns: + True if server is accessible + """ + try: + response = requests.get(f"{self.base_url}/api/tags", timeout=2) + if response.ok: + logger.info(f"Ollama server accessible at {self.base_url}") + return True + else: + logger.warning(f"Ollama server returned status {response.status_code}") + return False + except requests.RequestException as e: + logger.warning(f"Ollama server not accessible at {self.base_url}: {e}") + return False + + @property + def name(self) -> str: + """Get provider name.""" + return "Ollama" + + @property + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + return ProviderCapabilities( + streaming=True, + tools=False, # Tool support varies by model + images=False, # Image support varies by model + online=True, # Web search via DuckDuckGo/Google + max_context=8192, # Varies by model + ) + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + List models from local Ollama installation. + + Args: + filter_text_only: Ignored for Ollama + + Returns: + List of available models + """ + try: + response = requests.get(f"{self.base_url}/api/tags", timeout=5) + response.raise_for_status() + data = response.json() + + models = [] + for model_data in data.get("models", []): + models.append(self._parse_model(model_data)) + + logger.info(f"Found {len(models)} Ollama models") + return models + + except requests.RequestException as e: + logger.error(f"Failed to list Ollama models: {e}") + return [] + + def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo: + """ + Parse Ollama model data into ModelInfo. + + Args: + model_data: Raw model data from Ollama API + + Returns: + ModelInfo object + """ + model_name = model_data.get("name", "unknown") + size_bytes = model_data.get("size", 0) + size_gb = size_bytes / (1024 ** 3) if size_bytes else 0 + + return ModelInfo( + id=model_name, + name=model_name, + description=f"Size: {size_gb:.1f}GB", + context_length=8192, # Default, varies by model + pricing={}, # Local models are free + supported_parameters=["stream", "temperature", "max_tokens"], + ) + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. 
+ + Args: + model_id: Model identifier + + Returns: + ModelInfo or None if not found + """ + models = self.list_models() + for model in models: + if model.id == model_id: + return model + return None + + def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send chat request to Ollama. + + Args: + model: Model name + messages: Chat messages + stream: Whether to stream response + max_tokens: Maximum tokens (Ollama calls this num_predict) + temperature: Sampling temperature + tools: Not supported + tool_choice: Not supported + **kwargs: Additional parameters + + Returns: + ChatResponse or Iterator[StreamChunk] + """ + # Convert messages to Ollama format + ollama_messages = [] + for msg in messages: + ollama_messages.append({ + "role": msg.role, + "content": msg.content or "", + }) + + # Build request payload + payload: Dict[str, Any] = { + "model": model, + "messages": ollama_messages, + "stream": stream, + } + + # Add optional parameters + options = {} + if temperature is not None: + options["temperature"] = temperature + if max_tokens is not None: + options["num_predict"] = max_tokens + + if options: + payload["options"] = options + + logger.debug(f"Ollama request: model={model}, messages={len(ollama_messages)}") + + try: + if stream: + return self._stream_chat(payload) + else: + return self._sync_chat(payload) + + except requests.RequestException as e: + logger.error(f"Ollama request failed: {e}") + # Return error response + return ChatResponse( + id="error", + choices=[ + ChatResponseChoice( + index=0, + message=ChatMessage( + role="assistant", + content=f"Error: {str(e)}", + ), + finish_reason="error", + ) + ], + ) + + def _sync_chat(self, payload: Dict[str, Any]) -> ChatResponse: + """ + Send synchronous chat request. + + Args: + payload: Request payload + + Returns: + ChatResponse + """ + response = requests.post( + f"{self.base_url}/api/chat", + json=payload, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + # Parse response + message_data = data.get("message", {}) + content = message_data.get("content", "") + + # Extract token usage if available + usage = None + if "prompt_eval_count" in data or "eval_count" in data: + usage = UsageStats( + prompt_tokens=data.get("prompt_eval_count", 0), + completion_tokens=data.get("eval_count", 0), + total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0), + total_cost_usd=0.0, # Local models are free + ) + + return ChatResponse( + id=str(time.time()), + choices=[ + ChatResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=content), + finish_reason="stop", + ) + ], + usage=usage, + model=data.get("model"), + ) + + def _stream_chat(self, payload: Dict[str, Any]) -> Iterator[StreamChunk]: + """ + Stream chat response from Ollama. 
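+
+        Ollama streams newline-delimited JSON; each line is an object of
+        roughly this shape (illustrative):
+
+            {"model": "llama3", "message": {"role": "assistant",
+             "content": "..."}, "done": false}
+
+        with the final object carrying "done": true and the token counts
+        (prompt_eval_count / eval_count) read below.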
+ + Args: + payload: Request payload + + Yields: + StreamChunk objects + """ + response = requests.post( + f"{self.base_url}/api/chat", + json=payload, + stream=True, + timeout=120, + ) + response.raise_for_status() + + total_prompt_tokens = 0 + total_completion_tokens = 0 + + for line in response.iter_lines(): + if not line: + continue + + try: + data = json.loads(line) + + # Extract content delta + message_data = data.get("message", {}) + content = message_data.get("content", "") + + # Check if done + done = data.get("done", False) + finish_reason = "stop" if done else None + + # Extract usage if available + usage = None + if done and ("prompt_eval_count" in data or "eval_count" in data): + total_prompt_tokens = data.get("prompt_eval_count", 0) + total_completion_tokens = data.get("eval_count", 0) + usage = UsageStats( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, + total_cost_usd=0.0, + ) + + yield StreamChunk( + id=str(time.time()), + delta_content=content if content else None, + finish_reason=finish_reason, + usage=usage, + ) + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse Ollama stream chunk: {e}") + yield StreamChunk( + id="error", + error=f"Parse error: {e}", + ) + + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Async chat not implemented for Ollama. + + Args: + model: Model name + messages: Chat messages + stream: Whether to stream + max_tokens: Max tokens + temperature: Temperature + tools: Tools (not supported) + tool_choice: Tool choice (not supported) + **kwargs: Additional args + + Returns: + ChatResponse or AsyncIterator[StreamChunk] + + Raises: + NotImplementedError: Async not implemented + """ + raise NotImplementedError("Async chat not implemented for Ollama provider") + + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get account credits. + + Returns: + None (Ollama is local and free) + """ + return None + + def clear_cache(self) -> None: + """Clear model cache (no-op for Ollama).""" + pass + + def get_raw_models(self) -> List[Dict[str, Any]]: + """ + Get raw model data as dictionaries. + + Returns: + List of model dictionaries + """ + models = self.list_models() + return [ + { + "id": model.id, + "name": model.name, + "description": model.description, + "context_length": model.context_length, + "pricing": model.pricing, + } + for model in models + ] + + def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]: + """ + Get raw model data for a specific model. + + Args: + model_id: Model identifier + + Returns: + Model dictionary or None + """ + model = self.get_model(model_id) + if model: + return { + "id": model.id, + "name": model.name, + "description": model.description, + "context_length": model.context_length, + "pricing": model.pricing, + } + return None diff --git a/oai/providers/openai.py b/oai/providers/openai.py new file mode 100644 index 0000000..0ec2139 --- /dev/null +++ b/oai/providers/openai.py @@ -0,0 +1,630 @@ +""" +OpenAI provider for GPT models. + +This provider connects to OpenAI's API for accessing GPT-4, GPT-3.5, and other OpenAI models. 
+""" + +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +from openai import OpenAI, AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionChunk + +from oai.constants import OPENAI_BASE_URL +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ChatResponseChoice, + ModelInfo, + ProviderCapabilities, + StreamChunk, + ToolCall, + ToolFunction, + UsageStats, +) +from oai.utils.logging import get_logger + +logger = get_logger() + + +# Model aliases for convenience +MODEL_ALIASES = { + "gpt-4": "gpt-4-turbo", + "gpt-4-turbo": "gpt-4-turbo-2024-04-09", + "gpt-4o": "gpt-4o-2024-11-20", + "gpt-4o-mini": "gpt-4o-mini-2024-07-18", + "gpt-3.5": "gpt-3.5-turbo", + "gpt-3.5-turbo": "gpt-3.5-turbo-0125", + "o1": "o1-2024-12-17", + "o1-mini": "o1-mini-2024-09-12", + "o1-preview": "o1-preview-2024-09-12", +} + + +class OpenAIProvider(AIProvider): + """ + OpenAI API provider. + + Provides access to GPT-4, GPT-3.5, o1, and other OpenAI models. + """ + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + app_name: str = "oAI", + app_url: str = "", + **kwargs: Any, + ): + """ + Initialize OpenAI provider. + + Args: + api_key: OpenAI API key + base_url: Optional custom base URL + app_name: Application name (for headers) + app_url: Application URL (for headers) + **kwargs: Additional arguments + """ + super().__init__(api_key, base_url or OPENAI_BASE_URL) + self.client = OpenAI(api_key=api_key, base_url=self.base_url) + self.async_client = AsyncOpenAI(api_key=api_key, base_url=self.base_url) + self._models_cache: Optional[List[ModelInfo]] = None + + @property + def name(self) -> str: + """Get provider name.""" + return "OpenAI" + + @property + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + return ProviderCapabilities( + streaming=True, + tools=True, + images=True, + online=True, # Web search via DuckDuckGo/Google + max_context=128000, + ) + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + List available OpenAI models. + + Args: + filter_text_only: Whether to filter for text models only + + Returns: + List of ModelInfo objects + """ + if self._models_cache: + return self._models_cache + + try: + models_response = self.client.models.list() + models = [] + + for model in models_response.data: + # Filter for chat models + if "gpt" in model.id or "o1" in model.id: + models.append(self._parse_model(model)) + + # Sort by name + models.sort(key=lambda m: m.name) + self._models_cache = models + + logger.info(f"Loaded {len(models)} OpenAI models") + return models + + except Exception as e: + logger.error(f"Failed to list OpenAI models: {e}") + return self._get_fallback_models() + + def _get_fallback_models(self) -> List[ModelInfo]: + """ + Get fallback list of common OpenAI models. 
+ + Returns: + List of common models + """ + return [ + ModelInfo( + id="gpt-4o", + name="GPT-4o", + description="Most capable GPT-4 model", + context_length=128000, + pricing={"input": 5.0, "output": 15.0}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + input_modalities=["text", "image"], + ), + ModelInfo( + id="gpt-4o-mini", + name="GPT-4o Mini", + description="Affordable and fast GPT-4 class model", + context_length=128000, + pricing={"input": 0.15, "output": 0.6}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + input_modalities=["text", "image"], + ), + ModelInfo( + id="gpt-4-turbo", + name="GPT-4 Turbo", + description="GPT-4 Turbo with vision", + context_length=128000, + pricing={"input": 10.0, "output": 30.0}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + input_modalities=["text", "image"], + ), + ModelInfo( + id="gpt-3.5-turbo", + name="GPT-3.5 Turbo", + description="Fast and affordable model", + context_length=16384, + pricing={"input": 0.5, "output": 1.5}, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + ), + ModelInfo( + id="o1", + name="o1", + description="Advanced reasoning model", + context_length=200000, + pricing={"input": 15.0, "output": 60.0}, + supported_parameters=["max_tokens"], + ), + ModelInfo( + id="o1-mini", + name="o1-mini", + description="Fast reasoning model", + context_length=128000, + pricing={"input": 3.0, "output": 12.0}, + supported_parameters=["max_tokens"], + ), + ] + + def _parse_model(self, model: Any) -> ModelInfo: + """ + Parse OpenAI model into ModelInfo. + + Args: + model: OpenAI model object + + Returns: + ModelInfo object + """ + model_id = model.id + + # Determine context length + context_length = 8192 # Default + if "gpt-4o" in model_id or "gpt-4-turbo" in model_id: + context_length = 128000 + elif "gpt-4" in model_id: + context_length = 8192 + elif "gpt-3.5-turbo" in model_id: + context_length = 16384 + elif "o1" in model_id: + context_length = 128000 + + # Determine pricing (approximate) + pricing = {} + if "gpt-4o-mini" in model_id: + pricing = {"input": 0.15, "output": 0.6} + elif "gpt-4o" in model_id: + pricing = {"input": 5.0, "output": 15.0} + elif "gpt-4-turbo" in model_id: + pricing = {"input": 10.0, "output": 30.0} + elif "gpt-4" in model_id: + pricing = {"input": 30.0, "output": 60.0} + elif "gpt-3.5" in model_id: + pricing = {"input": 0.5, "output": 1.5} + elif "o1" in model_id and "mini" not in model_id: + pricing = {"input": 15.0, "output": 60.0} + elif "o1-mini" in model_id: + pricing = {"input": 3.0, "output": 12.0} + + return ModelInfo( + id=model_id, + name=model_id, + description="", + context_length=context_length, + pricing=pricing, + supported_parameters=["temperature", "max_tokens", "stream", "tools"], + ) + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. 
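+
+        Aliases are resolved first, so get_model("gpt-4o") matches the
+        dated id from MODEL_ALIASES ("gpt-4o-2024-11-20") as well as the
+        bare alias.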
+
+        Args:
+            model_id: Model identifier
+
+        Returns:
+            ModelInfo or None
+        """
+        # Resolve alias
+        resolved_id = MODEL_ALIASES.get(model_id, model_id)
+
+        models = self.list_models()
+        for model in models:
+            if model.id == resolved_id or model.id == model_id:
+                return model
+
+        # Try to fetch directly
+        try:
+            model = self.client.models.retrieve(resolved_id)
+            return self._parse_model(model)
+        except Exception:
+            return None
+
+    def chat(
+        self,
+        model: str,
+        messages: List[ChatMessage],
+        stream: bool = False,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
+        """
+        Send chat completion request to OpenAI.
+
+        Args:
+            model: Model ID
+            messages: Chat messages
+            stream: Whether to stream response
+            max_tokens: Maximum tokens
+            temperature: Sampling temperature
+            tools: Tool definitions
+            tool_choice: Tool selection mode
+            **kwargs: Additional parameters
+
+        Returns:
+            ChatResponse or Iterator[StreamChunk]
+        """
+        # Resolve model alias
+        model_id = MODEL_ALIASES.get(model, model)
+
+        # Convert messages to OpenAI format
+        openai_messages = []
+        for msg in messages:
+            message_dict = {"role": msg.role, "content": msg.content or ""}
+
+            if msg.tool_calls:
+                message_dict["tool_calls"] = [
+                    {
+                        "id": tc.id,
+                        "type": tc.type,
+                        "function": {
+                            "name": tc.function.name,
+                            "arguments": tc.function.arguments,
+                        },
+                    }
+                    for tc in msg.tool_calls
+                ]
+
+            if msg.tool_call_id:
+                message_dict["tool_call_id"] = msg.tool_call_id
+
+            openai_messages.append(message_dict)
+
+        # Build request parameters
+        params: Dict[str, Any] = {
+            "model": model_id,
+            "messages": openai_messages,
+            "stream": stream,
+        }
+
+        # Add optional parameters
+        if max_tokens is not None:
+            if "o1" in model_id:
+                # o1 reasoning models reject max_tokens and expect
+                # max_completion_tokens instead
+                params["max_completion_tokens"] = max_tokens
+            else:
+                params["max_tokens"] = max_tokens
+        if temperature is not None and "o1" not in model_id:
+            # o1 models don't support temperature
+            params["temperature"] = temperature
+        if tools:
+            params["tools"] = tools
+        if tool_choice:
+            params["tool_choice"] = tool_choice
+
+        logger.debug(f"OpenAI request: model={model_id}, messages={len(openai_messages)}")
+
+        try:
+            if stream:
+                return self._stream_chat(params)
+            else:
+                return self._sync_chat(params)
+
+        except Exception as e:
+            logger.error(f"OpenAI request failed: {e}")
+            return ChatResponse(
+                id="error",
+                choices=[
+                    ChatResponseChoice(
+                        index=0,
+                        message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
+                        finish_reason="error",
+                    )
+                ],
+            )
+
+    def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
+        """
+        Send synchronous chat request.
+ + Args: + params: Request parameters + + Returns: + ChatResponse + """ + completion: ChatCompletion = self.client.chat.completions.create(**params) + + # Convert to our format + choices = [] + for choice in completion.choices: + # Convert tool calls if present + tool_calls = None + if choice.message.tool_calls: + tool_calls = [ + ToolCall( + id=tc.id, + type=tc.type, + function=ToolFunction( + name=tc.function.name, + arguments=tc.function.arguments, + ), + ) + for tc in choice.message.tool_calls + ] + + choices.append( + ChatResponseChoice( + index=choice.index, + message=ChatMessage( + role=choice.message.role, + content=choice.message.content, + tool_calls=tool_calls, + ), + finish_reason=choice.finish_reason, + ) + ) + + # Convert usage + usage = None + if completion.usage: + usage = UsageStats( + prompt_tokens=completion.usage.prompt_tokens, + completion_tokens=completion.usage.completion_tokens, + total_tokens=completion.usage.total_tokens, + ) + + return ChatResponse( + id=completion.id, + choices=choices, + usage=usage, + model=completion.model, + created=completion.created, + ) + + def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]: + """ + Stream chat response from OpenAI. + + Args: + params: Request parameters + + Yields: + StreamChunk objects + """ + stream = self.client.chat.completions.create(**params) + + for chunk in stream: + chunk_data: ChatCompletionChunk = chunk + + if not chunk_data.choices: + continue + + choice = chunk_data.choices[0] + delta = choice.delta + + # Extract content + content = delta.content if delta.content else None + + # Extract finish reason + finish_reason = choice.finish_reason + + # Extract usage (usually in last chunk) + usage = None + if hasattr(chunk_data, "usage") and chunk_data.usage: + usage = UsageStats( + prompt_tokens=chunk_data.usage.prompt_tokens, + completion_tokens=chunk_data.usage.completion_tokens, + total_tokens=chunk_data.usage.total_tokens, + ) + + yield StreamChunk( + id=chunk_data.id, + delta_content=content, + finish_reason=finish_reason, + usage=usage, + ) + + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Send async chat request to OpenAI. 
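+
+        Example (illustrative):
+
+            result = await provider.chat_async("gpt-4o", messages, stream=True)
+            async for chunk in result:
+                if chunk.delta_content:
+                    print(chunk.delta_content, end="")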
+
+        Args:
+            model: Model ID
+            messages: Chat messages
+            stream: Whether to stream
+            max_tokens: Max tokens
+            temperature: Temperature
+            tools: Tool definitions
+            tool_choice: Tool choice
+            **kwargs: Additional args
+
+        Returns:
+            ChatResponse or AsyncIterator[StreamChunk]
+        """
+        # Resolve model alias
+        model_id = MODEL_ALIASES.get(model, model)
+
+        # Convert messages
+        openai_messages = [msg.to_dict() for msg in messages]
+
+        # Build params
+        params: Dict[str, Any] = {
+            "model": model_id,
+            "messages": openai_messages,
+            "stream": stream,
+        }
+
+        if max_tokens is not None:
+            if "o1" in model_id:
+                # o1 reasoning models expect max_completion_tokens
+                params["max_completion_tokens"] = max_tokens
+            else:
+                params["max_tokens"] = max_tokens
+        if temperature is not None and "o1" not in model_id:
+            params["temperature"] = temperature
+        if tools:
+            params["tools"] = tools
+        if tool_choice:
+            params["tool_choice"] = tool_choice
+
+        if stream:
+            return self._stream_chat_async(params)
+        else:
+            completion = await self.async_client.chat.completions.create(**params)
+            # Convert to ChatResponse (similar to _sync_chat)
+            return self._convert_completion(completion)
+
+    async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
+        """
+        Stream async chat response.
+
+        Args:
+            params: Request parameters
+
+        Yields:
+            StreamChunk objects
+        """
+        stream = await self.async_client.chat.completions.create(**params)
+
+        async for chunk in stream:
+            if not chunk.choices:
+                continue
+
+            choice = chunk.choices[0]
+            delta = choice.delta
+
+            yield StreamChunk(
+                id=chunk.id,
+                delta_content=delta.content,
+                finish_reason=choice.finish_reason,
+            )
+
+    def _convert_completion(self, completion: ChatCompletion) -> ChatResponse:
+        """Helper to convert OpenAI completion to ChatResponse."""
+        choices = []
+        for choice in completion.choices:
+            tool_calls = None
+            if choice.message.tool_calls:
+                tool_calls = [
+                    ToolCall(
+                        id=tc.id,
+                        type=tc.type,
+                        function=ToolFunction(
+                            name=tc.function.name,
+                            arguments=tc.function.arguments,
+                        ),
+                    )
+                    for tc in choice.message.tool_calls
+                ]
+
+            choices.append(
+                ChatResponseChoice(
+                    index=choice.index,
+                    message=ChatMessage(
+                        role=choice.message.role,
+                        content=choice.message.content,
+                        tool_calls=tool_calls,
+                    ),
+                    finish_reason=choice.finish_reason,
+                )
+            )
+
+        usage = None
+        if completion.usage:
+            usage = UsageStats(
+                prompt_tokens=completion.usage.prompt_tokens,
+                completion_tokens=completion.usage.completion_tokens,
+                total_tokens=completion.usage.total_tokens,
+            )
+
+        return ChatResponse(
+            id=completion.id,
+            choices=choices,
+            usage=usage,
+            model=completion.model,
+            created=completion.created,
+        )
+
+    def get_credits(self) -> Optional[Dict[str, Any]]:
+        """
+        Get account credits.
+
+        Returns:
+            None (OpenAI doesn't provide credit API)
+        """
+        return None
+
+    def clear_cache(self) -> None:
+        """Clear model cache."""
+        self._models_cache = None
+
+    def get_raw_models(self) -> List[Dict[str, Any]]:
+        """
+        Get raw model data as dictionaries.
+
+        Returns:
+            List of model dictionaries
+        """
+        models = self.list_models()
+        return [
+            {
+                "id": model.id,
+                "name": model.name,
+                "description": model.description,
+                "context_length": model.context_length,
+                "pricing": model.pricing,
+            }
+            for model in models
+        ]
+
+    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
+        """
+        Get raw model data for a specific model.
+ + Args: + model_id: Model identifier + + Returns: + Model dictionary or None + """ + model = self.get_model(model_id) + if model: + return { + "id": model.id, + "name": model.name, + "description": model.description, + "context_length": model.context_length, + "pricing": model.pricing, + } + return None diff --git a/oai/providers/registry.py b/oai/providers/registry.py new file mode 100644 index 0000000..4fa4b21 --- /dev/null +++ b/oai/providers/registry.py @@ -0,0 +1,60 @@ +""" +Provider registry for AI model providers. + +This module maintains a central registry of all available AI providers, +allowing dynamic provider lookup and registration. +""" + +from typing import Dict, List, Optional, Type + +from oai.providers.base import AIProvider + +# Global provider registry +PROVIDER_REGISTRY: Dict[str, Type[AIProvider]] = {} + + +def register_provider(name: str, provider_class: Type[AIProvider]) -> None: + """ + Register a provider class with the given name. + + Args: + name: Provider identifier (e.g., "openrouter", "anthropic") + provider_class: The provider class to register + """ + PROVIDER_REGISTRY[name] = provider_class + + +def get_provider_class(name: str) -> Optional[Type[AIProvider]]: + """ + Get a provider class by name. + + Args: + name: Provider identifier + + Returns: + Provider class or None if not found + """ + return PROVIDER_REGISTRY.get(name) + + +def list_providers() -> List[str]: + """ + List all registered provider names. + + Returns: + List of provider identifiers + """ + return list(PROVIDER_REGISTRY.keys()) + + +def is_provider_registered(name: str) -> bool: + """ + Check if a provider is registered. + + Args: + name: Provider identifier + + Returns: + True if provider is registered + """ + return name in PROVIDER_REGISTRY diff --git a/oai/tui/app.py b/oai/tui/app.py index 29b0bbb..76e1810 100644 --- a/oai/tui/app.py +++ b/oai/tui/app.py @@ -16,6 +16,7 @@ from oai.core.client import AIClient from oai.core.session import ChatSession from oai.tui.screens import ( AlertDialog, + CommandsScreen, ConfirmDialog, ConfigScreen, ConversationSelectorScreen, @@ -69,7 +70,8 @@ class oAIChatApp(App): """Compose the TUI layout.""" model_name = self.session.selected_model.get("name", "") if self.session.selected_model else "" model_info = self.session.selected_model if self.session.selected_model else None - yield Header(version=__version__, model=model_name, model_info=model_info) + provider_name = self.session.client.provider_name if self.session.client else "" + yield Header(version=__version__, model=model_name, model_info=model_info, provider=provider_name) yield ChatDisplay() yield InputBar() yield CommandDropdown() @@ -158,10 +160,10 @@ class oAIChatApp(App): else: # Command is complete, submit it directly dropdown.hide() + chat_input.value = "" # Clear immediately # Process the command directly async def submit_command(): await self._process_submitted_input(selected) - chat_input.value = "" self.call_later(submit_command) return elif event.key == "escape": @@ -232,12 +234,12 @@ class oAIChatApp(App): if not user_input: return - # Process the input - await self._process_submitted_input(user_input) - - # Clear input field + # Clear input field immediately event.input.value = "" + # Process the input (async, will wait for AI response) + await self._process_submitted_input(user_input) + async def _process_submitted_input(self, user_input: str) -> None: """Process submitted input (command or message). 
@@ -280,6 +282,10 @@ class oAIChatApp(App): await self.push_screen(HelpScreen()) return + if cmd_word == "commands": + await self.push_screen(CommandsScreen()) + return + if cmd_word == "stats": await self.push_screen(StatsScreen(self.session)) return @@ -288,7 +294,7 @@ class oAIChatApp(App): # Check if there are any arguments args = command_text.split(maxsplit=1) if len(args) == 1: # No arguments, just "/config" - await self.push_screen(ConfigScreen(self.settings)) + await self.push_screen(ConfigScreen(self.settings, self.session)) return # If there are arguments, fall through to normal command handler @@ -447,9 +453,11 @@ class oAIChatApp(App): # Update header if model changed if self.session.selected_model: header = self.query_one(Header) + provider_name = self.session.client.provider_name if self.session.client else "" header.update_model( self.session.selected_model.get("name", ""), - self.session.selected_model + self.session.selected_model, + provider_name ) # Update MCP status indicator in input bar @@ -472,7 +480,7 @@ class oAIChatApp(App): # Create assistant message widget with loading indicator model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant" - assistant_widget = AssistantMessageWidget(model_name) + assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display) await chat_display.add_message(assistant_widget) # Show loading indicator immediately @@ -851,7 +859,13 @@ class oAIChatApp(App): if selected: self.session.set_model(selected) header = self.query_one(Header) - header.update_model(selected.get("name", ""), selected) + provider_name = self.session.client.provider_name if self.session.client else "" + header.update_model(selected.get("name", ""), selected, provider_name) + + # Save this model as the last used for this provider + model_id = selected.get("id") + if model_id and provider_name: + self.settings.set_provider_model(provider_name, model_id) # Save as default if requested if set_as_default: diff --git a/oai/tui/screens/__init__.py b/oai/tui/screens/__init__.py index 9b5efbe..7c8e981 100644 --- a/oai/tui/screens/__init__.py +++ b/oai/tui/screens/__init__.py @@ -1,5 +1,6 @@ """TUI screens for oAI.""" +from oai.tui.screens.commands_screen import CommandsScreen from oai.tui.screens.config_screen import ConfigScreen from oai.tui.screens.conversation_selector import ConversationSelectorScreen from oai.tui.screens.credits_screen import CreditsScreen @@ -10,6 +11,7 @@ from oai.tui.screens.stats_screen import StatsScreen __all__ = [ "AlertDialog", + "CommandsScreen", "ConfirmDialog", "ConfigScreen", "ConversationSelectorScreen", diff --git a/oai/tui/screens/commands_screen.py b/oai/tui/screens/commands_screen.py new file mode 100644 index 0000000..2790e58 --- /dev/null +++ b/oai/tui/screens/commands_screen.py @@ -0,0 +1,172 @@ +"""Commands reference screen for oAI TUI.""" + +from textual.app import ComposeResult +from textual.containers import Container, Vertical, VerticalScroll +from textual.screen import ModalScreen +from textual.widgets import Button, Static + + +class CommandsScreen(ModalScreen[None]): + """Modal screen showing all available commands.""" + + DEFAULT_CSS = """ + CommandsScreen { + align: center middle; + } + + CommandsScreen > Container { + width: 90; + height: 40; + background: #1e1e1e; + border: solid #555555; + } + + CommandsScreen .header { + dock: top; + width: 100%; + height: auto; + background: #2d2d2d; + color: #cccccc; + padding: 0 2; + } + + CommandsScreen 
.content {
+        width: 100%;
+        height: 1fr;
+        background: #1e1e1e;
+        padding: 2;
+        color: #cccccc;
+        overflow-y: auto;
+        scrollbar-background: #1e1e1e;
+        scrollbar-color: #555555;
+        scrollbar-size: 1 1;
+    }
+
+    CommandsScreen .footer {
+        dock: bottom;
+        width: 100%;
+        height: auto;
+        background: #2d2d2d;
+        padding: 1 2;
+        align: center middle;
+    }
+    """
+
+    def compose(self) -> ComposeResult:
+        """Compose the commands screen."""
+        with Container():
+            yield Static("[bold]Commands Reference[/]", classes="header")
+            with VerticalScroll(classes="content"):
+                yield Static(self._get_commands_text(), markup=True)
+            with Vertical(classes="footer"):
+                yield Button("Close", id="close-button", variant="primary")
+
+    def _get_commands_text(self) -> str:
+        """Generate formatted commands text."""
+        return """[bold cyan]General Commands[/]
+
+[green]/help[/] - Show help screen with keyboard shortcuts
+[green]/commands[/] - Show this commands reference
+[green]/model[/] - Open model selector (or press F2)
+[green]/stats[/] - Show session statistics (or press Ctrl+S)
+[green]/credits[/] - Check account credits (OpenRouter) or view console link
+[green]/clear[/] - Clear chat display
+[green]/reset[/] - Reset conversation history
+[green]/retry[/] - Retry last prompt
+[green]/paste[/] - Paste from clipboard
+
+[bold cyan]Provider Commands[/]
+
+[green]/provider[/] - Show current provider
+[green]/provider openrouter[/] - Switch to OpenRouter
+[green]/provider anthropic[/] - Switch to Anthropic (Claude)
+[green]/provider openai[/] - Switch to OpenAI (ChatGPT)
+[green]/provider ollama[/] - Switch to Ollama (local)
+
+[bold cyan]Online Mode (Web Search)[/]
+
+[green]/online[/] - Show online mode status
+[green]/online on[/] - Enable web search
+[green]/online off[/] - Disable web search
+
+[dim]Search Providers:[/]
+• [yellow]anthropic_native[/] - Anthropic's native search with citations ($0.01/search)
+• [yellow]duckduckgo[/] - Free web scraping (default, works with all providers)
+• [yellow]google[/] - Google Custom Search (requires API key)
+
+[bold cyan]Configuration Commands[/]
+
+[green]/config[/] - View all settings
+[green]/config provider <provider>[/] - Set default provider
+[green]/config search_provider <provider>[/] - Set search provider (anthropic_native/duckduckgo/google)
+[green]/config openrouter_api_key <key>[/] - Set OpenRouter API key
+[green]/config anthropic_api_key <key>[/] - Set Anthropic API key
+[green]/config openai_api_key <key>[/] - Set OpenAI API key
+[green]/config ollama_base_url <url>[/] - Set Ollama server URL
+[green]/config google_api_key <key>[/] - Set Google API key (for Google search)
+[green]/config google_search_engine_id <id>[/] - Set Google Search Engine ID
+[green]/config online on|off[/] - Set default online mode
+[green]/config stream on|off[/] - Toggle streaming
+[green]/config model <model>[/] - Set default model
+[green]/config system <prompt>[/] - Set system prompt
+[green]/config maxtoken <number>[/] - Set token limit
+
+[bold cyan]Memory & Context[/]
+
+[green]/memory on[/] - Enable conversation memory
+[green]/memory off[/] - Disable memory (fresh context each message)
+
+[bold cyan]Conversation Management[/]
+
+[green]/save <name>[/] - Save current conversation
+[green]/load <name>[/] - Load saved conversation
+[green]/list[/] - List all saved conversations
+[green]/delete <name>[/] - Delete a conversation
+[green]/prev[/] - Show previous message
+[green]/next[/] - Show next message
+
+[bold cyan]Export Commands[/]
+
+[green]/export md <filename>[/] - Export conversation as Markdown
+[green]/export json <filename>[/] - Export as JSON
+[green]/export html <filename>[/] - Export 
as HTML
+
+[bold cyan]MCP (Model Context Protocol)[/]
+
+[green]/mcp on[/] - Enable MCP file access
+[green]/mcp off[/] - Disable MCP
+[green]/mcp status[/] - Show MCP status
+[green]/mcp add <folder>[/] - Add folder for file access
+[green]/mcp add db <path>[/] - Add SQLite database
+[green]/mcp remove <path>[/] - Remove folder/database
+[green]/mcp list[/] - List allowed folders
+[green]/mcp db list[/] - List added databases
+[green]/mcp db <name>[/] - Switch to database mode
+[green]/mcp files[/] - Switch to file mode
+[green]/mcp write on[/] - Enable write mode (allows file modifications)
+[green]/mcp write off[/] - Disable write mode
+
+[bold cyan]System Prompt[/]
+
+[green]/system <prompt>[/] - Set custom system prompt for session
+[green]/config system <prompt>[/] - Set default system prompt
+
+[bold cyan]Keyboard Shortcuts[/]
+
+• [yellow]F1[/] - Help screen
+• [yellow]F2[/] - Model selector
+• [yellow]Ctrl+S[/] - Statistics
+• [yellow]Ctrl+Q[/] - Quit
+• [yellow]Ctrl+Y[/] - Copy latest reply in Markdown
+• [yellow]Up/Down[/] - Command history
+• [yellow]Tab[/] - Command completion
+"""
+
+    def on_button_pressed(self, event: Button.Pressed) -> None:
+        """Handle button press."""
+        self.dismiss()
+
+    def on_key(self, event) -> None:
+        """Handle keyboard shortcuts."""
+        if event.key in ("escape", "enter"):
+            self.dismiss()
diff --git a/oai/tui/screens/config_screen.py b/oai/tui/screens/config_screen.py
index 042c4b6..d5d197d 100644
--- a/oai/tui/screens/config_screen.py
+++ b/oai/tui/screens/config_screen.py
@@ -50,9 +50,10 @@ class ConfigScreen(ModalScreen[None]):
     }
     """
 
-    def __init__(self, settings: Settings):
+    def __init__(self, settings: Settings, session=None):
         super().__init__()
         self.settings = settings
+        self.session = session
 
     def compose(self) -> ComposeResult:
         """Compose the screen."""
@@ -67,35 +68,90 @@ class ConfigScreen(ModalScreen[None]):
         """Generate the configuration text."""
         from oai.constants import DEFAULT_SYSTEM_PROMPT
 
-        # API Key display
-        api_key_display = "***" + self.settings.api_key[-4:] if self.settings.api_key else "Not set"
+        config_lines = ["[bold cyan]═══ CONFIGURATION ═══[/]\n"]
+
+        # Current Session Info
+        if self.session and self.session.client:
+            config_lines.append("[bold yellow]Current Session:[/]")
+
+            provider = self.session.client.provider_name
+            config_lines.append(f"[bold]Provider:[/] [green]{provider}[/]")
+
+            # Provider URL
+            if hasattr(self.session.client.provider, 'base_url'):
+                provider_url = self.session.client.provider.base_url
+                config_lines.append(f"[bold]Provider URL:[/] {provider_url}")
+
+            # API Key status
+            if provider == "ollama":
+                config_lines.append(f"[bold]API Key:[/] [dim]Not required (local)[/]")
+            else:
+                api_key = self.settings.get_provider_api_key(provider)
+                if api_key:
+                    key_display = "***" + api_key[-4:] if len(api_key) > 4 else "***"
+                    config_lines.append(f"[bold]API Key:[/] [green]{key_display}[/]")
+                else:
+                    config_lines.append(f"[bold]API Key:[/] [red]Not set[/]")
+
+            # Current model
+            if self.session.selected_model:
+                model_name = self.session.selected_model.get("name", "Unknown")
+                config_lines.append(f"[bold]Current Model:[/] {model_name}")
+
+            config_lines.append("")
+
+        # Default Settings
+        config_lines.append("[bold yellow]Default Settings:[/]")
+        config_lines.append(f"[bold]Default Provider:[/] {self.settings.default_provider}")
+
+        # Mask helper
+        def mask_key(key):
+            if not key:
+                return "[red]Not set[/]"
+            return "***" + key[-4:] if len(key) > 4 else "***"
+
+        config_lines.append(f"[bold]OpenRouter Key:[/] 
{mask_key(self.settings.openrouter_api_key)}") + config_lines.append(f"[bold]Anthropic Key:[/] {mask_key(self.settings.anthropic_api_key)}") + config_lines.append(f"[bold]OpenAI Key:[/] {mask_key(self.settings.openai_api_key)}") + config_lines.append(f"[bold]Ollama URL:[/] {self.settings.ollama_base_url}") + config_lines.append(f"[bold]Default Model:[/] {self.settings.default_model or 'Not set'}") # System prompt display if self.settings.default_system_prompt is None: - system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..." + system_prompt_display = f"[dim][default][/] {DEFAULT_SYSTEM_PROMPT[:40]}..." elif self.settings.default_system_prompt == "": - system_prompt_display = "[blank]" + system_prompt_display = "[dim][blank][/]" else: prompt = self.settings.default_system_prompt system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt - return f""" -[bold cyan]═══ CONFIGURATION ═══[/] + config_lines.append(f"[bold]System Prompt:[/] {system_prompt_display}") + config_lines.append("") -[bold]API Key:[/] {api_key_display} -[bold]Base URL:[/] {self.settings.base_url} -[bold]Default Model:[/] {self.settings.default_model or "Not set"} + # Web Search Configuration + config_lines.append("[bold yellow]Web Search Configuration:[/]") + config_lines.append(f"[bold]Search Provider:[/] {self.settings.search_provider}") -[bold]System Prompt:[/] {system_prompt_display} + # Show API key status based on search provider + if self.settings.search_provider == "google": + config_lines.append(f"[bold]Google API Key:[/] {mask_key(self.settings.google_api_key)}") + config_lines.append(f"[bold]Search Engine ID:[/] {mask_key(self.settings.google_search_engine_id)}") + elif self.settings.search_provider == "duckduckgo": + config_lines.append("[dim] DuckDuckGo requires no configuration (free)[/]") + elif self.settings.search_provider == "anthropic_native": + config_lines.append("[dim] Uses Anthropic API ($0.01 per search)[/]") -[bold]Streaming:[/] {"on" if self.settings.stream_enabled else "off"} -[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f} -[bold]Max Tokens:[/] {self.settings.max_tokens} -[bold]Default Online:[/] {"on" if self.settings.default_online_mode else "off"} -[bold]Log Level:[/] {self.settings.log_level} + config_lines.append("") -[dim]Use /config [setting] [value] to modify settings[/] -""" + # Other settings + config_lines.append("[bold]Streaming:[/] " + ("on" if self.settings.stream_enabled else "off")) + config_lines.append(f"[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}") + config_lines.append(f"[bold]Max Tokens:[/] {self.settings.max_tokens}") + config_lines.append(f"[bold]Default Online:[/] " + ("on" if self.settings.default_online_mode else "off")) + config_lines.append(f"[bold]Log Level:[/] {self.settings.log_level}") + config_lines.append("\n[dim]Use /config [setting] [value] to modify settings[/]") + + return "\n".join(config_lines) def on_button_pressed(self, event: Button.Pressed) -> None: """Handle button press.""" diff --git a/oai/tui/screens/credits_screen.py b/oai/tui/screens/credits_screen.py index f3bed9c..862e0c3 100644 --- a/oai/tui/screens/credits_screen.py +++ b/oai/tui/screens/credits_screen.py @@ -83,37 +83,70 @@ class CreditsScreen(ModalScreen[None]): def _get_credits_text(self) -> str: """Generate the credits text.""" if not self.credits_data: - return "[yellow]No credit information available[/]" + # Provider-specific message when credits aren't available + if self.client.provider_name == "anthropic": + 
return """[yellow]Credit information not available via API[/] - total = self.credits_data.get("total_credits", 0) - used = self.credits_data.get("used_credits", 0) +Anthropic does not provide a public API endpoint for checking +account credits. + +To view your account balance and usage: +[cyan]→ Visit console.anthropic.com[/] +[cyan]→ Navigate to Settings → Billing[/] +""" + elif self.client.provider_name == "openai": + return """[yellow]Credit information not available via API[/] + +OpenAI does not provide credit balance through their API. + +To view your account usage: +[cyan]→ Visit platform.openai.com[/] +[cyan]→ Navigate to Usage[/] +""" + elif self.client.provider_name == "ollama": + return """[yellow]Credit information not applicable[/] + +Ollama runs locally on your machine and does not use credits. +""" + else: + return "[yellow]No credit information available[/]" + + provider_name = self.client.provider_name.upper() + total = self.credits_data.get("total_credits") + used = self.credits_data.get("used_credits") remaining = self.credits_data.get("credits_left", 0) - # Calculate percentage used - if total > 0: - percent_used = (used / total) * 100 - percent_remaining = (remaining / total) * 100 - else: - percent_used = 0 - percent_remaining = 0 - - # Color code based on remaining credits - if percent_remaining > 50: + # Determine color based on absolute remaining amount + if remaining > 10: remaining_color = "green" - elif percent_remaining > 20: + elif remaining > 2: remaining_color = "yellow" else: remaining_color = "red" - return f""" -[bold cyan]═══ OPENROUTER CREDITS ═══[/] + lines = [f"[bold cyan]═══ {provider_name} CREDITS ═══[/]\n"] -[bold]Total Credits:[/] ${total:.2f} -[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/] -[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/] + # If we have total/used info (OpenRouter) + if total is not None and used is not None: + percent_used = (used / total) * 100 if total > 0 else 0 + percent_remaining = (remaining / total) * 100 if total > 0 else 0 -[dim]Visit openrouter.ai to add more credits[/] -""" + lines.append(f"[bold]Total Credits:[/] ${total:.2f}") + lines.append(f"[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]") + lines.append(f"[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]") + else: + # Anthropic - only shows balance + lines.append(f"[bold]Current Balance:[/] [{remaining_color}]${remaining:.2f}[/]") + + lines.append("") + + # Provider-specific footer + if self.client.provider_name == "openrouter": + lines.append("[dim]Visit openrouter.ai to add more credits[/]") + elif self.client.provider_name == "anthropic": + lines.append("[dim]Visit console.anthropic.com to manage billing[/]") + + return "\n".join(lines) def on_button_pressed(self, event: Button.Pressed) -> None: """Handle button press.""" diff --git a/oai/tui/styles.tcss b/oai/tui/styles.tcss index 7f04a9c..a854872 100644 --- a/oai/tui/styles.tcss +++ b/oai/tui/styles.tcss @@ -17,9 +17,11 @@ Header { ChatDisplay { background: $background; border: none; - padding: 1; + padding: 1 0 1 1; /* top right bottom left - no right padding for scrollbar */ scrollbar-background: $background; scrollbar-color: $primary; + scrollbar-size: 1 1; + scrollbar-gutter: stable; /* Reserve space for scrollbar */ overflow-y: auto; } @@ -43,7 +45,7 @@ SystemMessageWidget { AssistantMessageWidget { margin: 0 0 1 0; padding: 1; - background: $panel; + background: $background; border-left: thick 
$accent; height: auto; } @@ -59,6 +61,9 @@ AssistantMessageWidget { color: #cccccc; link-color: #888888; link-style: none; + border: none; + scrollbar-background: transparent; + scrollbar-color: #555555; } InputBar { diff --git a/oai/tui/widgets/command_dropdown.py b/oai/tui/widgets/command_dropdown.py index 938d3c8..9f76da9 100644 --- a/oai/tui/widgets/command_dropdown.py +++ b/oai/tui/widgets/command_dropdown.py @@ -59,9 +59,15 @@ class CommandDropdown(VerticalScroll): # Get base commands with descriptions base_commands = [ ("/help", "Show help screen"), + ("/commands", "Show all commands"), ("/model", "Select AI model"), + ("/provider", "Switch AI provider"), + ("/provider openrouter", "Switch to OpenRouter"), + ("/provider anthropic", "Switch to Anthropic (Claude)"), + ("/provider openai", "Switch to OpenAI (GPT)"), + ("/provider ollama", "Switch to Ollama (local)"), ("/stats", "Show session statistics"), - ("/credits", "Check account credits"), + ("/credits", "Check account credits or view console link"), ("/clear", "Clear chat display"), ("/reset", "Reset conversation history"), ("/memory on", "Enable conversation memory"), @@ -78,7 +84,19 @@ class CommandDropdown(VerticalScroll): ("/prev", "Show previous message"), ("/next", "Show next message"), ("/config", "View configuration"), - ("/config api", "Set API key"), + ("/config provider", "Set default provider"), + ("/config search_provider", "Set search provider (anthropic_native/duckduckgo/google)"), + ("/config openrouter_api_key", "Set OpenRouter API key"), + ("/config anthropic_api_key", "Set Anthropic API key"), + ("/config openai_api_key", "Set OpenAI API key"), + ("/config ollama_base_url", "Set Ollama server URL"), + ("/config google_api_key", "Set Google API key"), + ("/config google_search_engine_id", "Set Google Search Engine ID"), + ("/config online", "Set default online mode (on/off)"), + ("/config stream", "Toggle streaming (on/off)"), + ("/config model", "Set default model"), + ("/config system", "Set system prompt"), + ("/config maxtoken", "Set token limit"), ("/system", "Set system prompt"), ("/maxtoken", "Set token limit"), ("/retry", "Retry last prompt"), @@ -117,12 +135,30 @@ class CommandDropdown(VerticalScroll): # Remove the leading slash for filtering filter_without_slash = filter_text[1:].lower() - # Filter commands - show if filter text is contained anywhere in the command + # Filter commands if filter_without_slash: - matching = [ - (cmd, desc) for cmd, desc in self._all_commands - if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching - ] + matching = [] + + # Check if user typed a parent command (e.g., "/config") + # If so, show all sub-commands that start with it + parent_command = "/" + filter_without_slash + has_subcommands = any( + cmd.startswith(parent_command + " ") + for cmd, _ in self._all_commands + ) + + if has_subcommands and parent_command in [cmd for cmd, _ in self._all_commands]: + # Show the parent command and all its sub-commands + matching = [ + (cmd, desc) for cmd, desc in self._all_commands + if cmd == parent_command or cmd.startswith(parent_command + " ") + ] + else: + # Regular filtering - show if filter text is contained anywhere in the command + matching = [ + (cmd, desc) for cmd, desc in self._all_commands + if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching + ] else: # Show all commands when just "/" is typed matching = self._all_commands @@ -131,8 +167,8 @@ class CommandDropdown(VerticalScroll): self.remove_class("visible") 
return - # Add options - limit to 10 results - for cmd, desc in matching[:10]: + # Add options - limit to 15 results (increased from 10 for sub-commands) + for cmd, desc in matching[:15]: # Format: command in white, description in gray, separated by spaces label = f"{cmd} [dim]{desc}[/]" if desc else cmd option_list.add_option(Option(label, id=cmd)) diff --git a/oai/tui/widgets/header.py b/oai/tui/widgets/header.py index cfbd09a..ef09a68 100644 --- a/oai/tui/widgets/header.py +++ b/oai/tui/widgets/header.py @@ -8,11 +8,12 @@ from typing import Optional, Dict, Any class Header(Static): """Header displaying app title, version, current model, and capabilities.""" - def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None): + def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None, provider: str = ""): super().__init__() self.version = version self.model = model self.model_info = model_info or {} + self.provider = provider def compose(self) -> ComposeResult: """Compose the header.""" @@ -51,15 +52,32 @@ class Header(Static): def _format_header(self) -> str: """Format the header text.""" - model_text = f" | {self.model}" if self.model else "" + # Show provider : model format + if self.provider and self.model: + provider_model = f"[bold cyan]{self.provider}[/] [dim]:[/] [bold]{self.model}[/]" + elif self.provider: + provider_model = f"[bold cyan]{self.provider}[/]" + elif self.model: + provider_model = f"[bold]{self.model}[/]" + else: + provider_model = "" + capabilities = self._format_capabilities() capabilities_text = f" {capabilities}" if capabilities else "" - return f"[bold cyan]oAI[/] [dim]v{self.version}[/]{model_text}{capabilities_text}" - def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None) -> None: + # Format: oAI v{version} | provider : model capabilities + version_text = f"[bold cyan]oAI[/] [dim]v{self.version}[/]" + if provider_model: + return f"{version_text} [dim]|[/] {provider_model}{capabilities_text}" + else: + return version_text + + def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None, provider: Optional[str] = None) -> None: """Update the displayed model and capabilities.""" self.model = model if model_info: self.model_info = model_info + if provider is not None: + self.provider = provider content = self.query_one("#header-content", Static) content.update(self._format_header()) diff --git a/oai/tui/widgets/message.py b/oai/tui/widgets/message.py index 98d6226..727aeb8 100644 --- a/oai/tui/widgets/message.py +++ b/oai/tui/widgets/message.py @@ -53,10 +53,11 @@ class SystemMessageWidget(Static): class AssistantMessageWidget(Static): """Widget for displaying assistant responses with streaming support.""" - def __init__(self, model_name: str = "Assistant"): + def __init__(self, model_name: str = "Assistant", chat_display=None): super().__init__() self.model_name = model_name self.full_text = "" + self.chat_display = chat_display def compose(self) -> ComposeResult: """Compose the assistant message.""" @@ -77,6 +78,11 @@ class AssistantMessageWidget(Static): md = Markdown(self.full_text, code_theme="github-dark", inline_code_theme="github-dark") log.write(md) + # Auto-scroll to keep the latest content visible + # Use call_after_refresh to ensure scroll happens after layout update + if self.chat_display: + self.chat_display.call_after_refresh(self.chat_display.scroll_end, animate=False) + if hasattr(chunk, "usage") and chunk.usage: 
usage = chunk.usage diff --git a/oai/utils/web_search.py b/oai/utils/web_search.py new file mode 100644 index 0000000..f6f49fb --- /dev/null +++ b/oai/utils/web_search.py @@ -0,0 +1,247 @@ +""" +Web search utilities for oAI. + +Provides web search capabilities for all providers (not just OpenRouter). +Uses DuckDuckGo by default (no API key needed). +""" + +import json +import re +from typing import Dict, List, Optional +from urllib.parse import quote_plus + +import requests + +from oai.utils.logging import get_logger + +logger = get_logger() + + +class WebSearchResult: + """Container for a single search result.""" + + def __init__(self, title: str, url: str, snippet: str): + self.title = title + self.url = url + self.snippet = snippet + + def __repr__(self) -> str: + return f"WebSearchResult(title='{self.title}', url='{self.url}')" + + +class WebSearchProvider: + """Base class for web search providers.""" + + def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]: + """ + Perform a web search. + + Args: + query: Search query + num_results: Number of results to return + + Returns: + List of search results + """ + raise NotImplementedError + + +class DuckDuckGoSearch(WebSearchProvider): + """DuckDuckGo search provider (no API key needed).""" + + def __init__(self): + self.session = requests.Session() + self.session.headers.update({ + 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36' + }) + + def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]: + """ + Search using DuckDuckGo HTML interface. + + Args: + query: Search query + num_results: Number of results to return (default: 5) + + Returns: + List of search results + """ + try: + # Use DuckDuckGo HTML search + url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}" + response = self.session.get(url, timeout=10) + response.raise_for_status() + + results = [] + html = response.text + + # Parse results using regex (simple HTML parsing) + # Find all result blocks - they end at next result or end of results section + result_blocks = re.findall( + r'