Lots of changes. None breaking. v3.0.0-b3
README.md | 141
@@ -1,19 +1,23 @@
# oAI - OpenRouter AI Chat Client
# oAI - Open AI Chat Client

A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
A powerful, modern **Textual TUI** chat client with **multi-provider support** (OpenRouter, Anthropic, OpenAI, Ollama) and **MCP (Model Context Protocol)** integration, enabling AI to access local files and query SQLite databases.

## Features

### Core Features
- 🖥️ **Modern Textual TUI** with async streaming and beautiful interface
- 🤖 Interactive chat with 300+ AI models via OpenRouter
- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local)
- 🤖 Interactive chat with 300+ AI models across providers
- 🔍 Model selection with search, filtering, and capability icons
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
- 📎 File attachments (images, PDFs, code files)
- 💰 Real-time cost tracking and credit monitoring
- 💰 Real-time cost tracking and credit monitoring (OpenRouter)
- 🎨 Dark theme with syntax highlighting and Markdown rendering
- 📝 Command history navigation (Up/Down arrows)
- 🌐 Online mode (web search capabilities)
- 🌐 **Universal Online Mode** - Web search for ALL providers:
  - **Anthropic Native** - Built-in search with automatic citations ($0.01/search)
  - **DuckDuckGo** - Free web scraping (all providers)
  - **Google Custom Search** - Premium search option
- 🧠 Conversation memory toggle
- ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats)
@@ -36,7 +40,11 @@ A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Mo
## Requirements

- Python 3.10-3.13
- OpenRouter API key ([get one here](https://openrouter.ai))
- API key for your chosen provider:
  - **OpenRouter**: [openrouter.ai](https://openrouter.ai)
  - **Anthropic**: [console.anthropic.com](https://console.anthropic.com)
  - **OpenAI**: [platform.openai.com](https://platform.openai.com)
  - **Ollama**: No API key needed (local server)

## Installation
@@ -83,28 +91,111 @@ pip install -e .
# Start oAI (launches TUI)
oai

# Start with specific provider
oai --provider anthropic
oai --provider openai
oai --provider ollama

# Or with options
oai --model gpt-4o --online --mcp
oai --provider openrouter --model gpt-4o --online --mcp

# Show version
oai version
```

On first run, you'll be prompted for your OpenRouter API key.
On first run, you'll be prompted for your API key. Configure additional providers anytime with `/config`.

### Enable Web Search (All Providers)

```bash
# Using Anthropic native search (best quality, automatic citations)
/config search_provider anthropic_native
/online on

# Using free DuckDuckGo (works with all providers)
/config search_provider duckduckgo
/online on

# Using Google Custom Search (requires API key)
/config search_provider google
/config google_api_key YOUR_KEY
/config google_search_engine_id YOUR_ID
/online on
```

### Basic Commands

```bash
# In the TUI interface:
/provider            # Show current provider or switch
/provider anthropic  # Switch to Anthropic (Claude)
/provider openai     # Switch to OpenAI (ChatGPT)
/provider ollama     # Switch to Ollama (local)
/model               # Select AI model (or press F2)
/online on           # Enable web search
/help                # Show all commands (or press F1)
/mcp on              # Enable file/database access
/stats               # View session statistics (or press Ctrl+S)
/config              # View configuration settings
/credits             # Check account credits
/credits             # Check account credits (shows API balance or console link)
Ctrl+Q               # Quit
```

## Web Search

oAI provides universal web search for all AI providers, with three options:

### Anthropic Native Search (Recommended for Anthropic)

Anthropic's built-in web search API with automatic citations:

```bash
/config search_provider anthropic_native
/online on

# Now ask questions requiring current information
What are the latest developments in quantum computing?
```

**Features:**
- ✅ **Automatic citations** - Claude cites its sources
- ✅ **Smart searching** - Claude decides when to search
- ✅ **Progressive searches** - Multiple searches for complex queries
- ✅ **Best quality** - Professional-grade results

**Pricing:** $10 per 1,000 searches ($0.01 per search) + token costs

**Note:** Only works with the Anthropic provider (Claude models)

### DuckDuckGo Search (Default - Free)

Free web scraping that works with ALL providers:

```bash
/config search_provider duckduckgo   # Default
/online on

# Works with Anthropic, OpenAI, Ollama, and OpenRouter
```

**Features:**
- ✅ **Free** - No API key or costs
- ✅ **Universal** - Works with all providers
- ✅ **Privacy-friendly** - Uses DuckDuckGo
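
For the curious: the free mode needs nothing more than an HTML endpoint and a parser. A minimal sketch of the idea (this is not oAI's actual `perform_web_search` implementation; the endpoint serves results without JavaScript, but the CSS class names here are assumptions):

```python
# Minimal sketch of free DuckDuckGo scraping (illustrative, not oAI's exact code).
# Requires: pip install requests beautifulsoup4
import requests
from bs4 import BeautifulSoup

def duckduckgo_search(query: str, num_results: int = 5) -> list[dict]:
    """Scrape DuckDuckGo's HTML endpoint and return title/url/snippet dicts."""
    resp = requests.post(
        "https://html.duckduckgo.com/html/",    # JS-free results page
        data={"q": query},
        headers={"User-Agent": "Mozilla/5.0"},  # plain UA avoids bot blocks
        timeout=10,
    )
    resp.raise_for_status()
    soup = BeautifulSoup(resp.text, "html.parser")
    results = []
    for hit in soup.select("div.result")[:num_results]:  # CSS classes are assumptions
        link = hit.select_one("a.result__a")
        snippet = hit.select_one("a.result__snippet")
        if link:
            results.append({
                "title": link.get_text(strip=True),
                "url": link.get("href", ""),
                "snippet": snippet.get_text(strip=True) if snippet else "",
            })
    return results
```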

### Google Custom Search (Premium Option)

Google's Custom Search API for high-quality results:

```bash
/config search_provider google
/config google_api_key YOUR_GOOGLE_API_KEY
/config google_search_engine_id YOUR_SEARCH_ENGINE_ID
/online on
```

Get your API key: [Google Custom Search API](https://developers.google.com/custom-search/v1/overview)
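
Under the hood this is a single GET against the Custom Search JSON API; a sketch (the endpoint and the `key`/`cx`/`q`/`num` parameters follow Google's public API, the helper itself is illustrative):

```python
# Sketch: query the Google Custom Search JSON API directly.
import requests

def google_search(query: str, api_key: str, engine_id: str, num: int = 5) -> list[dict]:
    """Return title/url/snippet dicts from Google Custom Search."""
    resp = requests.get(
        "https://www.googleapis.com/customsearch/v1",
        params={"key": api_key, "cx": engine_id, "q": query, "num": num},
        timeout=10,
    )
    resp.raise_for_status()
    return [
        {"title": i.get("title", ""), "url": i.get("link", ""), "snippet": i.get("snippet", "")}
        for i in resp.json().get("items", [])
    ]
```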

## MCP (Model Context Protocol)

MCP allows the AI to interact with your local files and databases.

@@ -184,11 +275,18 @@ MCP allows the AI to interact with your local files and databases.
| Command | Description |
|---------|-------------|
| `/config` | View settings |
| `/config api` | Set API key |
| `/config provider <name>` | Set default provider |
| `/config openrouter_api_key` | Set OpenRouter API key |
| `/config anthropic_api_key` | Set Anthropic API key |
| `/config openai_api_key` | Set OpenAI API key |
| `/config ollama_base_url` | Set Ollama server URL |
| `/config search_provider <provider>` | Set search provider (anthropic_native/duckduckgo/google) |
| `/config google_api_key` | Set Google API key (for Google search) |
| `/config online on\|off` | Set default online mode |
| `/config model <id>` | Set default model |
| `/config stream on\|off` | Toggle streaming |
| `/stats` | Session statistics |
| `/credits` | Check credits |
| `/credits` | Check credits (OpenRouter) |

## CLI Options

@@ -196,9 +294,10 @@ MCP allows the AI to interact with your local files and databases.
oai [OPTIONS]

Options:
  -p, --provider TEXT  Provider to use (openrouter/anthropic/openai/ollama)
  -m, --model TEXT     Model ID to use
  -s, --system TEXT    System prompt
  -o, --online         Enable online mode
  -o, --online         Enable online mode (OpenRouter only)
  --mcp                Enable MCP server
  -v, --version        Show version
  --help               Show help

@@ -280,7 +379,23 @@ pip install -e . --force-reinstall
## Version History

### v3.0.0 (Current)
### v3.0.0-b3 (Current - Beta 3)
- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local)
- 🔌 **Provider Switching** - Switch between providers mid-session with `/provider`
- 🧠 **Provider Model Memory** - Remembers last used model per provider
- ⚙️ **Provider Configuration** - Separate API keys for each provider
- 🌐 **Universal Web Search** - Three search options for all providers:
  - **Anthropic Native** - Built-in search with citations ($0.01/search)
  - **DuckDuckGo** - Free web scraping (default)
  - **Google Custom Search** - Premium option
- 🎨 **Enhanced Header** - Shows current provider and model
- 📊 **Updated Config Screen** - Shows provider-specific settings
- 🔧 **Command Dropdown** - Added `/provider` commands for discoverability
- 🎯 **Auto-Scrolling** - Chat window auto-scrolls during streaming responses
- 🎨 **Refined UI** - Thinner scrollbar, cleaner styling
- 🔧 **Updated Models** - Claude 4.5 (Sonnet, Haiku, Opus)

### v3.0.0 (Beta)
- 🎨 **Complete migration to Textual TUI** - Modern async terminal interface
- 🗑️ **Removed CLI interface** - TUI-only for cleaner codebase (11.6% smaller)
- 🖱️ **Modal screens** - Help, stats, config, credits, model selector

oai/__init__.py

@@ -1,16 +1,16 @@
"""
oAI - OpenRouter AI Chat Client
oAI - Open AI Chat Client

A feature-rich terminal-based chat application that provides an interactive CLI
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
A feature-rich terminal-based chat application with multi-provider support
(OpenRouter, Anthropic, OpenAI, Ollama) and advanced Model Context Protocol (MCP)
integration for filesystem and database access.

Author: Rune
License: MIT
"""

__version__ = "3.0.0-b2"
__author__ = "Rune"
__version__ = "3.0.0-b3"
__author__ = "Rune Olsen"
__license__ = "MIT"

# Lazy imports to avoid circular dependencies and improve startup time

oai/cli.py | 43

@@ -21,7 +21,7 @@ from oai.utils.logging import LoggingManager, get_logger
# Create Typer app
app = typer.Typer(
    name="oai",
    help=f"oAI - OpenRouter AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
    help=f"oAI - Open AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
    add_completion=False,
    epilog="For more information, visit: " + APP_URL,
)

@@ -60,6 +60,12 @@ def main_callback(
        "--mcp",
        help="Enable MCP server",
    ),
    provider: Optional[str] = typer.Option(
        None,
        "--provider",
        "-p",
        help="AI provider to use (openrouter, anthropic, openai, ollama)",
    ),
) -> None:
    """Main callback - launches TUI by default."""
    if version_flag:

@@ -68,7 +74,7 @@ def main_callback(
    # If no subcommand provided, launch TUI
    if ctx.invoked_subcommand is None:
        _launch_tui(model, system, online, mcp)
        _launch_tui(model, system, online, mcp, provider)


def _launch_tui(
@@ -76,8 +82,11 @@ def _launch_tui(
    system: Optional[str] = None,
    online: bool = False,
    mcp: bool = False,
    provider: Optional[str] = None,
) -> None:
    """Launch the Textual TUI interface."""
    from oai.constants import VALID_PROVIDERS

    # Setup logging
    logging_manager = LoggingManager()
    logging_manager.setup()

@@ -86,17 +95,35 @@ def _launch_tui(
    # Load settings
    settings = Settings.load()

    # Check API key
    if not settings.api_key:
        typer.echo("Error: No API key configured", err=True)
        typer.echo("Run: oai config api to set your API key", err=True)
    # Determine provider
    selected_provider = provider or settings.default_provider

    # Validate provider
    if selected_provider not in VALID_PROVIDERS:
        typer.echo(f"Error: Invalid provider: {selected_provider}", err=True)
        typer.echo(f"Valid providers: {', '.join(VALID_PROVIDERS)}", err=True)
        raise typer.Exit(1)

    # Build provider API keys dict
    provider_api_keys = {
        "openrouter": settings.openrouter_api_key,
        "anthropic": settings.anthropic_api_key,
        "openai": settings.openai_api_key,
    }

    # Check if provider is configured (except Ollama, which doesn't need an API key)
    if selected_provider != "ollama":
        if not provider_api_keys.get(selected_provider):
            typer.echo(f"Error: No API key configured for {selected_provider}", err=True)
            typer.echo(f"Set it with: oai config {selected_provider}_api_key <key>", err=True)
            raise typer.Exit(1)

    # Initialize client
    try:
        client = AIClient(
            api_key=settings.api_key,
            base_url=settings.base_url,
            provider_name=selected_provider,
            provider_api_keys=provider_api_keys,
            ollama_base_url=settings.ollama_base_url,
        )
    except Exception as e:
        typer.echo(f"Error: Failed to initialize client: {e}", err=True)

@@ -275,7 +275,7 @@ class OnlineCommand(Command):
            ("Enable web search", "/online on"),
            ("Disable web search", "/online off"),
        ],
        notes="Not all models support online mode.",
        notes="OpenRouter: Native :online suffix. Other providers: DuckDuckGo search results injected into context.",
    )

    def execute(self, args: str, context: CommandContext) -> CommandResult:

@@ -820,7 +820,56 @@ class ConfigCommand(Command):
        from oai.constants import DEFAULT_SYSTEM_PROMPT

        table = Table("Setting", "Value", show_header=True, header_style="bold magenta")
        table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set")

        # Current session provider info (if available)
        if context.session and context.session.client:
            client = context.session.client
            current_provider = client.provider_name

            table.add_row(
                "[bold cyan]Current Provider[/]",
                f"[bold green]{current_provider}[/]"
            )

            # Show current provider's base URL
            if hasattr(client.provider, 'base_url'):
                provider_url = client.provider.base_url
                table.add_row(" Provider URL", provider_url)

            # Show API key status for current provider
            if current_provider == "ollama":
                table.add_row(" API Key", "[dim]Not required (local)[/]")
            else:
                current_key = settings.get_provider_api_key(current_provider)
                if current_key:
                    masked = "***" + current_key[-4:] if len(current_key) > 4 else "***"
                    table.add_row(" API Key", f"[green]{masked}[/]")
                else:
                    table.add_row(" API Key", "[red]Not set[/]")

            # Show current model
            if context.session.selected_model:
                model_name = context.session.selected_model.get("name", "Unknown")
                table.add_row(" Current Model", model_name)

            # Add separator
            table.add_row("", "")

        # Provider configuration settings
        table.add_row("[bold]Default Provider[/]", settings.default_provider)

        # Provider API keys (masked)
        def mask_key(key: Optional[str]) -> str:
            if not key:
                return "[red]Not set[/]"
            return "***" + key[-4:] if len(key) > 4 else "***"

        table.add_row("OpenRouter API Key", mask_key(settings.openrouter_api_key))
        table.add_row("Anthropic API Key", mask_key(settings.anthropic_api_key))
        table.add_row("OpenAI API Key", mask_key(settings.openai_api_key))
        table.add_row("Ollama Base URL", settings.ollama_base_url)

        # Legacy/general settings
        table.add_row("Base URL", settings.base_url)
        table.add_row("Default Model", settings.default_model or "Not set")

@@ -845,7 +894,62 @@ class ConfigCommand(Command):
        setting = parts[0].lower()
        value = parts[1] if len(parts) > 1 else None

        if setting == "api":
        # Provider selection
        if setting == "provider":
            from oai.constants import VALID_PROVIDERS

            if value:
                if value.lower() in VALID_PROVIDERS:
                    settings.set_default_provider(value.lower())
                    return CommandResult.success()
                else:
                    return CommandResult.error(f"Invalid provider. Valid: {', '.join(VALID_PROVIDERS)}")
            else:
                return CommandResult.error("Usage: /config provider <provider_name>")

        # Provider-specific API keys
        elif setting in ["openrouter_api_key", "anthropic_api_key", "openai_api_key"]:
            if value:
                provider = setting.replace("_api_key", "")
                settings.set_provider_api_key(provider, value)
                return CommandResult.success()
            else:
                return CommandResult.success(data={"show_api_key_input": True, "provider": setting})

        elif setting == "ollama_base_url":
            if value:
                settings.set_ollama_base_url(value)
                return CommandResult.success()
            else:
                return CommandResult.error("Usage: /config ollama_base_url <url>")

        # Web search settings
        elif setting == "search_provider":
            if value:
                valid_providers = ["anthropic_native", "duckduckgo", "google"]
                if value.lower() in valid_providers:
                    settings.set_search_provider(value.lower())
                    return CommandResult.success()
                else:
                    return CommandResult.error(f"Invalid search provider. Valid: {', '.join(valid_providers)}")
            else:
                return CommandResult.error("Usage: /config search_provider <provider>")

        elif setting == "google_api_key":
            if value:
                settings.set_google_api_key(value)
                return CommandResult.success()
            else:
                return CommandResult.error("Usage: /config google_api_key <key>")

        elif setting == "google_search_engine_id":
            if value:
                settings.set_google_search_engine_id(value)
                return CommandResult.success()
            else:
                return CommandResult.error("Usage: /config google_search_engine_id <id>")

        elif setting == "api":
            if value:
                settings.set_api_key(value)
            else:

@@ -1388,6 +1492,104 @@ class MCPCommand(Command):
        message = f"Unknown MCP command: {cmd}\nAvailable: on, off, status, add, remove, list, db, files, write, gitignore"
        return CommandResult.error(message)


class ProviderCommand(Command):
    """Switch or display current provider."""

    @property
    def name(self) -> str:
        return "/provider"

    @property
    def help(self) -> CommandHelp:
        return CommandHelp(
            description="Switch AI provider or show current provider.",
            usage="/provider [provider_name]",
            examples=[
                ("Show current provider", "/provider"),
                ("Switch to Anthropic", "/provider anthropic"),
                ("Switch to OpenAI", "/provider openai"),
                ("Switch to Ollama", "/provider ollama"),
                ("Switch to OpenRouter", "/provider openrouter"),
            ],
        )

    def execute(self, args: str, context: CommandContext) -> CommandResult:
        from oai.constants import VALID_PROVIDERS

        provider_name = args.strip().lower()

        if not context.session:
            return CommandResult.error("No active session")

        if not provider_name:
            # Show current provider and capabilities
            current = context.session.client.provider_name
            capabilities = context.session.client.provider.capabilities

            message = f"\n[bold]Current Provider:[/] {current}\n"
            message += f" Streaming: {capabilities.streaming}\n"
            message += f" Tools: {capabilities.tools}\n"
            message += f" Images: {capabilities.images}\n"
            message += f" Online: {capabilities.online}"

            return CommandResult.success(message=message)

        # Validate provider
        if provider_name not in VALID_PROVIDERS:
            return CommandResult.error(
                f"Invalid provider: {provider_name}\nValid: {', '.join(VALID_PROVIDERS)}"
            )

        # Check API key (except for Ollama)
        if provider_name != "ollama":
            key = context.settings.get_provider_api_key(provider_name)
            if not key:
                return CommandResult.error(
                    f"No API key for {provider_name}. Set with /config {provider_name}_api_key <key>"
                )

        # Switch provider
        try:
            context.session.client.switch_provider(provider_name)

            # Try to restore the last used model for this provider
            last_model_id = context.settings.get_provider_model(provider_name)
            restored_model = False

            if last_model_id:
                # Try to get the model from the new provider
                raw_model = context.session.client.get_raw_model(last_model_id)
                if raw_model:
                    context.session.set_model(raw_model)
                    restored_model = True
                    message = f"Switched to {provider_name} provider (model: {raw_model.get('name', last_model_id)})"
                    return CommandResult.success(message=message)

            # If no stored model or the model was not found, select a default
            if not restored_model:
                models = context.session.client.list_models()
                if models:
                    # Select a sensible default (first model)
                    default_model = models[0]
                    raw_model = context.session.client.get_raw_model(default_model.id)
                    if raw_model:
                        context.session.set_model(raw_model)
                        # Save this as the default for this provider
                        context.settings.set_provider_model(provider_name, default_model.id)
                        message = f"Switched to {provider_name} provider (auto-selected: {default_model.name})"
                        return CommandResult.success(message=message)
                else:
                    # No models available, clear selection
                    context.session.selected_model = None
                    message = f"Switched to {provider_name} provider (no models available)"
                    return CommandResult.success(message=message)

            return CommandResult.success(message=f"Switched to {provider_name} provider")
        except Exception as e:
            return CommandResult.error(f"Failed to switch provider: {e}")


class PasteCommand(Command):
    """Paste from clipboard."""

@@ -1468,6 +1670,7 @@ def register_all_commands() -> None:
        DeleteCommand(),
        InfoCommand(),
        MCPCommand(),
        ProviderCommand(),
        PasteCommand(),
        ModelCommand(),
    ]

@@ -80,6 +80,7 @@ class CommandContext:
        online_enabled: Whether online mode is enabled
        session_tokens: Session token counts
        session_cost: Session cost total
        session: Reference to the ChatSession (for provider switching, etc.)
    """

    settings: Optional["Settings"] = None

@@ -100,6 +101,7 @@ class CommandContext:
    message_count: int = 0
    is_tui: bool = False  # Flag for TUI mode
    current_index: int = 0
    session: Optional[Any] = None  # Reference to ChatSession (avoid circular import)


@dataclass

@@ -6,8 +6,9 @@ configuration with type safety, validation, and persistence.
"""

from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Dict
from pathlib import Path
import json

from oai.constants import (
    DEFAULT_BASE_URL,

@@ -20,6 +21,8 @@ from oai.constants import (
    DEFAULT_LOG_LEVEL,
    DEFAULT_SYSTEM_PROMPT,
    VALID_LOG_LEVELS,
    DEFAULT_PROVIDER,
    OLLAMA_DEFAULT_URL,
)
from oai.config.database import get_database

@@ -34,7 +37,7 @@ class Settings:
    initialization and can be persisted back.

    Attributes:
        api_key: OpenRouter API key
        api_key: Legacy OpenRouter API key (deprecated, use openrouter_api_key)
        base_url: API base URL
        default_model: Default model ID to use
        default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)

@@ -45,9 +48,33 @@ class Settings:
        log_max_size_mb: Maximum log file size in MB
        log_backup_count: Number of log file backups to keep
        log_level: Logging level (debug/info/warning/error/critical)

        # Provider-specific settings
        default_provider: Default AI provider to use
        openrouter_api_key: OpenRouter API key
        anthropic_api_key: Anthropic API key
        openai_api_key: OpenAI API key
        ollama_base_url: Ollama server URL
    """

    # Legacy field (kept for backward compatibility)
    api_key: Optional[str] = None

    # Provider configuration
    default_provider: str = DEFAULT_PROVIDER
    openrouter_api_key: Optional[str] = None
    anthropic_api_key: Optional[str] = None
    openai_api_key: Optional[str] = None
    ollama_base_url: str = OLLAMA_DEFAULT_URL
    provider_models: Dict[str, str] = field(default_factory=dict)  # provider -> last_model_id

    # Web search configuration (for online mode with non-OpenRouter providers)
    search_provider: str = "duckduckgo"  # "anthropic_native", "duckduckgo", or "google"
    google_api_key: Optional[str] = None
    google_search_engine_id: Optional[str] = None
    search_num_results: int = 5

    # General settings
    base_url: str = DEFAULT_BASE_URL
    default_model: Optional[str] = None
    default_system_prompt: Optional[str] = None

@@ -134,8 +161,43 @@ class Settings:
        # Get system prompt from DB: None means not set (use default), "" means explicitly blank
        system_prompt_value = db.get_config("default_system_prompt")

        # Migration: copy legacy api_key to openrouter_api_key if not already set
        legacy_api_key = db.get_config("api_key")
        openrouter_key = db.get_config("openrouter_api_key")

        if legacy_api_key and not openrouter_key:
            db.set_config("openrouter_api_key", legacy_api_key)
            openrouter_key = legacy_api_key
            # Note: We keep the legacy api_key in DB for backward compatibility

        # Load provider-model mapping
        provider_models_json = db.get_config("provider_models")
        provider_models = {}
        if provider_models_json:
            try:
                provider_models = json.loads(provider_models_json)
            except json.JSONDecodeError:
                provider_models = {}

        return cls(
            api_key=db.get_config("api_key"),
            # Legacy field
            api_key=legacy_api_key,

            # Provider configuration
            default_provider=db.get_config("default_provider") or DEFAULT_PROVIDER,
            openrouter_api_key=openrouter_key,
            anthropic_api_key=db.get_config("anthropic_api_key"),
            openai_api_key=db.get_config("openai_api_key"),
            ollama_base_url=db.get_config("ollama_base_url") or OLLAMA_DEFAULT_URL,
            provider_models=provider_models,

            # Web search configuration
            search_provider=db.get_config("search_provider") or "duckduckgo",
            google_api_key=db.get_config("google_api_key"),
            google_search_engine_id=db.get_config("google_search_engine_id"),
            search_num_results=parse_int(db.get_config("search_num_results"), 5),

            # General settings
            base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
            default_model=db.get_config("default_model"),
            default_system_prompt=system_prompt_value,

@@ -331,6 +393,155 @@ class Settings:
        self.log_max_size_mb = min(size_mb, 100)
        get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))

    def set_provider_api_key(self, provider: str, api_key: str) -> None:
        """
        Set and persist an API key for a specific provider.

        Args:
            provider: Provider name ("openrouter", "anthropic", "openai")
            api_key: The API key to set

        Raises:
            ValueError: If provider is invalid
        """
        provider = provider.lower()
        api_key = api_key.strip()

        if provider == "openrouter":
            self.openrouter_api_key = api_key
            get_database().set_config("openrouter_api_key", api_key)
        elif provider == "anthropic":
            self.anthropic_api_key = api_key
            get_database().set_config("anthropic_api_key", api_key)
        elif provider == "openai":
            self.openai_api_key = api_key
            get_database().set_config("openai_api_key", api_key)
        else:
            raise ValueError(f"Invalid provider: {provider}")

    def get_provider_api_key(self, provider: str) -> Optional[str]:
        """
        Get the API key for a specific provider.

        Args:
            provider: Provider name ("openrouter", "anthropic", "openai", "ollama")

        Returns:
            API key or None if not set

        Raises:
            ValueError: If provider is invalid
        """
        provider = provider.lower()

        if provider == "openrouter":
            return self.openrouter_api_key
        elif provider == "anthropic":
            return self.anthropic_api_key
        elif provider == "openai":
            return self.openai_api_key
        elif provider == "ollama":
            return ""  # Ollama doesn't require an API key
        else:
            raise ValueError(f"Invalid provider: {provider}")
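
Together, the two accessors form a symmetric per-provider key store. An illustrative round trip (not part of the commit; the `Settings` import path is assumed):

```python
# Illustrative round trip through the per-provider key store.
from oai.config.settings import Settings  # import path assumed

settings = Settings.load()
settings.set_provider_api_key("anthropic", "sk-ant-example")  # persists via get_database()
assert settings.get_provider_api_key("anthropic") == "sk-ant-example"
assert settings.get_provider_api_key("ollama") == ""          # Ollama never needs a key

try:
    settings.set_provider_api_key("mistral", "x")
except ValueError:
    pass  # unknown providers are rejected on write
```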

    def set_default_provider(self, provider: str) -> None:
        """
        Set and persist the default provider.

        Args:
            provider: Provider name

        Raises:
            ValueError: If provider is invalid
        """
        from oai.constants import VALID_PROVIDERS

        provider = provider.lower()
        if provider not in VALID_PROVIDERS:
            raise ValueError(
                f"Invalid provider: {provider}. "
                f"Valid providers: {', '.join(VALID_PROVIDERS)}"
            )

        self.default_provider = provider
        get_database().set_config("default_provider", provider)

    def set_ollama_base_url(self, url: str) -> None:
        """
        Set and persist the Ollama base URL.

        Args:
            url: Ollama server URL
        """
        self.ollama_base_url = url.strip()
        get_database().set_config("ollama_base_url", self.ollama_base_url)

    def set_search_provider(self, provider: str) -> None:
        """
        Set and persist the web search provider.

        Args:
            provider: Search provider ("anthropic_native", "duckduckgo", "google")

        Raises:
            ValueError: If provider is invalid
        """
        valid_providers = ["anthropic_native", "duckduckgo", "google"]
        provider = provider.lower()
        if provider not in valid_providers:
            raise ValueError(
                f"Invalid search provider: {provider}. "
                f"Valid providers: {', '.join(valid_providers)}"
            )

        self.search_provider = provider
        get_database().set_config("search_provider", provider)

    def set_google_api_key(self, api_key: str) -> None:
        """
        Set and persist the Google API key for Google Custom Search.

        Args:
            api_key: The Google API key
        """
        self.google_api_key = api_key.strip()
        get_database().set_config("google_api_key", self.google_api_key)

    def set_google_search_engine_id(self, engine_id: str) -> None:
        """
        Set and persist the Google Custom Search Engine ID.

        Args:
            engine_id: The Google Search Engine ID
        """
        self.google_search_engine_id = engine_id.strip()
        get_database().set_config("google_search_engine_id", self.google_search_engine_id)

    def get_provider_model(self, provider: str) -> Optional[str]:
        """
        Get the last used model for a provider.

        Args:
            provider: Provider name

        Returns:
            Model ID or None if not set
        """
        return self.provider_models.get(provider)

    def set_provider_model(self, provider: str, model_id: str) -> None:
        """
        Set and persist the last used model for a provider.

        Args:
            provider: Provider name
            model_id: Model ID to remember
        """
        self.provider_models[provider] = model_id
        # Save to database as JSON
        get_database().set_config("provider_models", json.dumps(self.provider_models))


# Global settings instance
_settings: Optional[Settings] = None

@@ -20,7 +20,7 @@ from oai import __version__
APP_NAME = "oAI"
APP_VERSION = __version__  # Single source of truth in oai/__init__.py
APP_URL = "https://iurl.no/oai"
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"
APP_DESCRIPTION = "Open AI Chat Client with Multi-Provider Support"

# =============================================================================
# FILE PATHS
# =============================================================================

@@ -42,6 +42,26 @@ DEFAULT_STREAM_ENABLED = True
DEFAULT_MAX_TOKENS = 100_000
DEFAULT_ONLINE_MODE = False

# =============================================================================
# PROVIDER CONFIGURATION
# =============================================================================

# Provider names
PROVIDER_OPENROUTER = "openrouter"
PROVIDER_ANTHROPIC = "anthropic"
PROVIDER_OPENAI = "openai"
PROVIDER_OLLAMA = "ollama"

VALID_PROVIDERS = [PROVIDER_OPENROUTER, PROVIDER_ANTHROPIC, PROVIDER_OPENAI, PROVIDER_OLLAMA]

# Provider base URLs
ANTHROPIC_BASE_URL = "https://api.anthropic.com"
OPENAI_BASE_URL = "https://api.openai.com/v1"
OLLAMA_DEFAULT_URL = "http://localhost:11434"

# Default provider
DEFAULT_PROVIDER = PROVIDER_OPENROUTER

# =============================================================================
# DEFAULT SYSTEM PROMPT
# =============================================================================

@@ -9,7 +9,7 @@ import asyncio
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING, OLLAMA_DEFAULT_URL
from oai.providers.base import (
    AIProvider,
    ChatMessage,

@@ -19,7 +19,6 @@ from oai.providers.base import (
    ToolCall,
    UsageStats,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.utils.logging import get_logger

@@ -32,37 +31,66 @@ class AIClient:

    Attributes:
        provider: The underlying AI provider
        provider_name: Name of the current provider
        default_model: Default model ID to use
        http_headers: Custom HTTP headers for requests
    """

    def __init__(
        self,
        api_key: str,
        base_url: Optional[str] = None,
        provider_class: type = OpenRouterProvider,
        provider_name: str = "openrouter",
        provider_api_keys: Optional[Dict[str, str]] = None,
        ollama_base_url: str = OLLAMA_DEFAULT_URL,
        app_name: str = APP_NAME,
        app_url: str = APP_URL,
    ):
        """
        Initialize the AI client.
        Initialize the AI client with the specified provider.

        Args:
            api_key: API key for authentication
            base_url: Optional custom base URL
            provider_class: Provider class to use (default: OpenRouterProvider)
            provider_name: Provider to use ("openrouter", "anthropic", "openai", "ollama")
            provider_api_keys: Dict mapping provider names to API keys
            ollama_base_url: Base URL for the Ollama server
            app_name: Application name for headers
            app_url: Application URL for headers

        Raises:
            ValueError: If provider is invalid or not configured
        """
        self.provider: AIProvider = provider_class(
            api_key=api_key,
            base_url=base_url,
            app_name=app_name,
            app_url=app_url,
        )
        from oai.providers.registry import get_provider_class

        self.provider_name = provider_name
        self.provider_api_keys = provider_api_keys or {}
        self.ollama_base_url = ollama_base_url
        self.app_name = app_name
        self.app_url = app_url

        # Get provider class
        provider_class = get_provider_class(provider_name)
        if not provider_class:
            raise ValueError(f"Unknown provider: {provider_name}")

        # Get API key for this provider
        api_key = self.provider_api_keys.get(provider_name, "")

        # Initialize provider with appropriate parameters
        if provider_name == "ollama":
            self.provider: AIProvider = provider_class(
                api_key=api_key,
                base_url=ollama_base_url,
            )
        else:
            self.provider: AIProvider = provider_class(
                api_key=api_key,
                app_name=app_name,
                app_url=app_url,
            )

        self.default_model: Optional[str] = None
        self.logger = get_logger()

        self.logger.info(f"Initialized {provider_name} provider")

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Get available models.

@@ -420,3 +448,49 @@ class AIClient:
        """Clear the provider's model cache."""
        if hasattr(self.provider, "clear_cache"):
            self.provider.clear_cache()

    def switch_provider(self, provider_name: str, ollama_base_url: Optional[str] = None) -> None:
        """
        Switch to a different provider.

        Args:
            provider_name: Provider to switch to
            ollama_base_url: Optional Ollama base URL (if switching to Ollama)

        Raises:
            ValueError: If provider is invalid or not configured
        """
        from oai.providers.registry import get_provider_class

        # Get provider class
        provider_class = get_provider_class(provider_name)
        if not provider_class:
            raise ValueError(f"Unknown provider: {provider_name}")

        # Get API key
        api_key = self.provider_api_keys.get(provider_name, "")

        # Check API key requirement
        if provider_name != "ollama" and not api_key:
            raise ValueError(f"No API key configured for {provider_name}")

        # Initialize new provider
        if provider_name == "ollama":
            base_url = ollama_base_url or self.ollama_base_url
            self.provider = provider_class(
                api_key=api_key,
                base_url=base_url,
            )
            self.ollama_base_url = base_url
        else:
            self.provider = provider_class(
                api_key=api_key,
                app_name=self.app_name,
                app_url=self.app_url,
            )

        self.provider_name = provider_name
        self.logger.info(f"Switched to {provider_name} provider")

        # Reset the default model when switching providers
        self.default_model = None
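
Typical lifecycle, as wired up by `_launch_tui` and the `/provider` command (illustrative values; the `Settings` import path is assumed):

```python
# Illustrative: construct once, then hot-swap providers mid-session.
from oai.config.settings import Settings  # import path assumed
from oai.core.client import AIClient

settings = Settings.load()
client = AIClient(
    api_key=settings.api_key,  # legacy OpenRouter key
    base_url=settings.base_url,
    provider_name="openrouter",
    provider_api_keys={
        "openrouter": settings.openrouter_api_key,
        "anthropic": settings.anthropic_api_key,
        "openai": settings.openai_api_key,
    },
    ollama_base_url=settings.ollama_base_url,
)

client.switch_provider("anthropic")  # re-instantiates the provider, keeps the keys dict
client.switch_provider("ollama", ollama_base_url="http://192.168.1.10:11434")
```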

@@ -23,6 +23,9 @@ from oai.core.client import AIClient
from oai.mcp.manager import MCPManager
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
from oai.utils.logging import get_logger
from oai.utils.web_search import perform_web_search, format_search_results

logger = get_logger()


@dataclass
@@ -164,6 +167,7 @@ class ChatSession:
            total_cost=self.stats.total_cost,
            message_count=self.stats.message_count,
            current_index=self.current_index,
            session=self,
        )

    def set_model(self, model: Dict[str, Any]) -> None:

@@ -290,6 +294,44 @@ class ChatSession:
        # Build request parameters
        model_id = self.selected_model["id"]

        # Handle online mode
        enable_web_search = False
        web_search_config = {}

        if self.online_enabled:
            # OpenRouter handles online mode natively with the :online suffix
            if self.client.provider_name == "openrouter":
                if hasattr(self.client.provider, "get_effective_model_id"):
                    model_id = self.client.provider.get_effective_model_id(model_id, True)
            # Anthropic has native web search when the search provider is set to anthropic_native
            elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
                enable_web_search = True
                web_search_config = {
                    "max_uses": self.settings.search_num_results or 5
                }
                logger.info("Using Anthropic native web search")
            else:
                # For other providers, perform a web search and inject the results
                logger.info(f"Performing web search for: {user_input}")
                search_results = perform_web_search(
                    user_input,
                    num_results=self.settings.search_num_results,
                    provider=self.settings.search_provider,
                    google_api_key=self.settings.google_api_key,
                    google_search_engine_id=self.settings.google_search_engine_id
                )

                if search_results:
                    # Inject search results into messages
                    formatted_results = format_search_results(search_results)
                    search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."

                    # Add search results to the last user message
                    if messages and messages[-1]["role"] == "user":
                        messages[-1]["content"] += search_context

                    logger.info(f"Injected {len(search_results)} search results into context")
        if self.online_enabled:
            if hasattr(self.client.provider, "get_effective_model_id"):
                model_id = self.client.provider.get_effective_model_id(model_id, True)
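
`perform_web_search` and `format_search_results` come from `oai/utils/web_search.py`, which is not part of this diff. Judging from the call sites above, the formatter can be as small as this sketch (the result-dict keys are assumptions):

```python
# Sketch of format_search_results as implied by the call sites above.
from typing import Any, Dict, List

def format_search_results(results: List[Dict[str, Any]]) -> str:
    """Render search hits as a plain-text block suitable for prompt injection."""
    lines = ["Web search results:"]
    for i, r in enumerate(results, 1):
        lines.append(f"{i}. {r.get('title', '')} ({r.get('url', '')})")
        if r.get("snippet"):
            lines.append(f"   {r['snippet']}")
    return "\n".join(lines)

# The resulting block is appended to the last user message, so the model sees
# the hits inline with the question:
print(format_search_results([
    {"title": "Quantum computing news", "url": "https://example.com", "snippet": "..."},
]))
```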

@@ -320,6 +362,8 @@ class ChatSession:
                max_tokens=max_tokens,
                transforms=transforms,
                on_chunk=on_stream_chunk,
                enable_web_search=enable_web_search,
                web_search_config=web_search_config,
            )
            response_time = time.time() - start_time
            return full_text, usage, response_time

@@ -455,6 +499,8 @@ class ChatSession:
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
        on_chunk: Optional[Callable[[str], None]] = None,
        enable_web_search: bool = False,
        web_search_config: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Optional[UsageStats]]:
        """
        Stream a response with live display.

@@ -465,6 +511,8 @@ class ChatSession:
            max_tokens: Max tokens
            transforms: Transforms
            on_chunk: Callback for chunks
            enable_web_search: Whether to enable Anthropic native web search
            web_search_config: Web search configuration

        Returns:
            Tuple of (full_text, usage)

@@ -475,6 +523,8 @@ class ChatSession:
            stream=True,
            max_tokens=max_tokens,
            transforms=transforms,
            enable_web_search=enable_web_search,
            web_search_config=web_search_config or {},
        )

        if isinstance(response, ChatResponse):

@@ -530,10 +580,45 @@ class ChatSession:
            # Disable streaming when tools are present
            stream = False

        # Handle online mode
        model_id = self.selected_model["id"]
        enable_web_search = False
        web_search_config = {}

        if self.online_enabled:
            if hasattr(self.client.provider, "get_effective_model_id"):
                model_id = self.client.provider.get_effective_model_id(model_id, True)
            # OpenRouter handles online mode natively with the :online suffix
            if self.client.provider_name == "openrouter":
                if hasattr(self.client.provider, "get_effective_model_id"):
                    model_id = self.client.provider.get_effective_model_id(model_id, True)
            # Anthropic has native web search when the search provider is set to anthropic_native
            elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
                enable_web_search = True
                web_search_config = {
                    "max_uses": self.settings.search_num_results or 5
                }
                logger.info("Using Anthropic native web search")
            else:
                # For other providers, perform a web search and inject the results
                logger.info(f"Performing web search for: {user_input}")
                search_results = await asyncio.to_thread(
                    perform_web_search,
                    user_input,
                    num_results=self.settings.search_num_results,
                    provider=self.settings.search_provider,
                    google_api_key=self.settings.google_api_key,
                    google_search_engine_id=self.settings.google_search_engine_id
                )

                if search_results:
                    # Inject search results into messages
                    formatted_results = format_search_results(search_results)
                    search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."

                    # Add search results to the last user message
                    if messages and messages[-1]["role"] == "user":
                        messages[-1]["content"] += search_context

                    logger.info(f"Injected {len(search_results)} search results into context")

        transforms = ["middle-out"] if self.middle_out_enabled else None
        max_tokens = None

@@ -557,6 +642,8 @@ class ChatSession:
                model_id=model_id,
                max_tokens=max_tokens,
                transforms=transforms,
                enable_web_search=enable_web_search,
                web_search_config=web_search_config,
            ):
                yield chunk
        else:

@@ -733,6 +820,8 @@ class ChatSession:
        model_id: str,
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
        enable_web_search: bool = False,
        web_search_config: Optional[Dict[str, Any]] = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Async version of _stream_response for the TUI.

@@ -742,6 +831,8 @@ class ChatSession:
            model_id: Model ID
            max_tokens: Max tokens
            transforms: Transforms
            enable_web_search: Whether to enable Anthropic native web search
            web_search_config: Web search configuration

        Yields:
            StreamChunk objects

@@ -752,6 +843,8 @@ class ChatSession:
            stream=True,
            max_tokens=max_tokens,
            transforms=transforms,
            enable_web_search=enable_web_search,
            web_search_config=web_search_config or {},
        )

        if isinstance(response, ChatResponse):

@@ -16,6 +16,16 @@ from oai.providers.base import (
    ProviderCapabilities,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.providers.anthropic import AnthropicProvider
from oai.providers.openai import OpenAIProvider
from oai.providers.ollama import OllamaProvider
from oai.providers.registry import register_provider

# Register all providers
register_provider("openrouter", OpenRouterProvider)
register_provider("anthropic", AnthropicProvider)
register_provider("openai", OpenAIProvider)
register_provider("ollama", OllamaProvider)

__all__ = [
    # Base classes and types

@@ -29,4 +39,7 @@ __all__ = [
    "ProviderCapabilities",
    # Provider implementations
    "OpenRouterProvider",
    "AnthropicProvider",
    "OpenAIProvider",
    "OllamaProvider",
]
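
`oai/providers/registry.py` itself is not included in this diff; given how `register_provider` and `get_provider_class` are used here and in `AIClient`, a module-level dict is all it takes (sketch, not the commit's actual code):

```python
# Sketch of oai/providers/registry.py (the module is not shown in this diff).
from typing import Dict, Optional, Type

from oai.providers.base import AIProvider

_REGISTRY: Dict[str, Type[AIProvider]] = {}

def register_provider(name: str, provider_class: Type[AIProvider]) -> None:
    """Map a provider name ("openrouter", "anthropic", ...) to its class."""
    _REGISTRY[name.lower()] = provider_class

def get_provider_class(name: str) -> Optional[Type[AIProvider]]:
    """Return the registered class, or None for unknown names (callers raise)."""
    return _REGISTRY.get(name.lower())
```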

oai/providers/anthropic.py | 673 (new file)

@@ -0,0 +1,673 @@

"""
Anthropic provider for Claude models.

This provider connects to Anthropic's API for accessing Claude models.
"""

import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import anthropic
from anthropic.types import Message, MessageStreamEvent

from oai.constants import ANTHROPIC_BASE_URL
from oai.providers.base import (
    AIProvider,
    ChatMessage,
    ChatResponse,
    ChatResponseChoice,
    ModelInfo,
    ProviderCapabilities,
    StreamChunk,
    ToolCall,
    ToolFunction,
    UsageStats,
)
from oai.utils.logging import get_logger

logger = get_logger()


# Model name aliases
MODEL_ALIASES = {
    "claude-sonnet": "claude-sonnet-4-5-20250929",
    "claude-haiku": "claude-haiku-4-5-20251001",
    "claude-opus": "claude-opus-4-5-20251101",
    # Legacy aliases
    "claude-3-haiku": "claude-3-haiku-20240307",
    "claude-3-7-sonnet": "claude-3-7-sonnet-20250219",
}


class AnthropicProvider(AIProvider):
    """
    Anthropic API provider.

    Provides access to the Claude 4.5 family (Sonnet, Haiku, Opus) and legacy Claude models.
    """

    def __init__(
        self,
        api_key: str,
        base_url: Optional[str] = None,
        app_name: str = "oAI",
        app_url: str = "",
        **kwargs: Any,
    ):
        """
        Initialize the Anthropic provider.

        Args:
            api_key: Anthropic API key
            base_url: Optional custom base URL
            app_name: Application name (for headers)
            app_url: Application URL (for headers)
            **kwargs: Additional arguments
        """
        super().__init__(api_key, base_url or ANTHROPIC_BASE_URL)
        self.client = anthropic.Anthropic(api_key=api_key)
        self.async_client = anthropic.AsyncAnthropic(api_key=api_key)
        self._models_cache: Optional[List[ModelInfo]] = None

    def _create_web_search_tool(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create the Anthropic native web search tool definition.

        Args:
            config: Optional configuration for web search (max_uses, allowed_domains, etc.)

        Returns:
            Tool definition dict
        """
        tool: Dict[str, Any] = {
            "type": "web_search_20250305",
            "name": "web_search",
        }

        # Add optional parameters if provided
        if "max_uses" in config:
            tool["max_uses"] = config["max_uses"]
        else:
            tool["max_uses"] = 5  # Default

        if "allowed_domains" in config:
            tool["allowed_domains"] = config["allowed_domains"]

        if "blocked_domains" in config:
            tool["blocked_domains"] = config["blocked_domains"]

        if "user_location" in config:
            tool["user_location"] = config["user_location"]

        return tool
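
With the session's settings applied, the method yields the server-side tool block that gets placed in `params["tools"]`; an illustrative call (not part of the commit):

```python
# Illustrative: the dict handed to params["tools"] for native web search.
provider = AnthropicProvider(api_key="sk-ant-example")
tool = provider._create_web_search_tool({"max_uses": 5, "allowed_domains": ["arxiv.org"]})
# tool == {"type": "web_search_20250305", "name": "web_search",
#          "max_uses": 5, "allowed_domains": ["arxiv.org"]}
# Anthropic runs the searches server-side and returns citations in the response.
```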

    @property
    def name(self) -> str:
        """Get provider name."""
        return "Anthropic"

    @property
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        return ProviderCapabilities(
            streaming=True,
            tools=True,
            images=True,
            online=True,  # Web search via DuckDuckGo/Google
            max_context=200000,
        )

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        List available Anthropic models.

        Args:
            filter_text_only: Whether to filter for text models only

        Returns:
            List of ModelInfo objects
        """
        if self._models_cache:
            return self._models_cache

        # Anthropic doesn't have a models list API, so we hardcode the available models
        models = [
            # Current Claude 4.5 models
            ModelInfo(
                id="claude-sonnet-4-5-20250929",
                name="Claude Sonnet 4.5",
                description="Smart model for complex agents and coding (recommended)",
                context_length=200000,
                pricing={"input": 3.0, "output": 15.0},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="claude-haiku-4-5-20251001",
                name="Claude Haiku 4.5",
                description="Fastest model with near-frontier intelligence",
                context_length=200000,
                pricing={"input": 1.0, "output": 5.0},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="claude-opus-4-5-20251101",
                name="Claude Opus 4.5",
                description="Premium model with maximum intelligence",
                context_length=200000,
                pricing={"input": 5.0, "output": 25.0},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            # Legacy models (still available)
            ModelInfo(
                id="claude-3-7-sonnet-20250219",
                name="Claude Sonnet 3.7",
                description="Legacy model - recommend migrating to 4.5",
                context_length=200000,
                pricing={"input": 3.0, "output": 15.0},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="claude-3-haiku-20240307",
                name="Claude 3 Haiku",
                description="Legacy fast model - recommend migrating to 4.5",
                context_length=200000,
                pricing={"input": 0.25, "output": 1.25},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
        ]

        self._models_cache = models
        logger.info(f"Loaded {len(models)} Anthropic models")
        return models

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: Model identifier

        Returns:
            ModelInfo or None
        """
        # Resolve alias
        resolved_id = MODEL_ALIASES.get(model_id, model_id)

        models = self.list_models()
        for model in models:
            if model.id == resolved_id or model.id == model_id:
                return model
        return None
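
Aliases make the short names usable anywhere a model ID is accepted; illustrative lookups (relying on the cached model list above):

```python
# Illustrative alias resolution via get_model.
provider = AnthropicProvider(api_key="sk-ant-example")
m = provider.get_model("claude-haiku")  # alias -> claude-haiku-4-5-20251001
assert m is not None and m.name == "Claude Haiku 4.5"
# Full IDs resolve to the same cached ModelInfo object:
assert provider.get_model("claude-haiku-4-5-20251001") is m
```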
|
||||
|
||||
    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send chat completion request to Anthropic.

        Args:
            model: Model ID
            messages: Chat messages
            stream: Whether to stream response
            max_tokens: Maximum tokens
            temperature: Sampling temperature
            tools: Tool definitions
            tool_choice: Tool selection mode
            **kwargs: Additional parameters (including enable_web_search)

        Returns:
            ChatResponse or Iterator[StreamChunk]
        """
        # Resolve model alias
        model_id = MODEL_ALIASES.get(model, model)

        # Extract system message (Anthropic requires it separate from messages)
        system_prompt, anthropic_messages = self._convert_messages(messages)

        # Build request parameters
        params: Dict[str, Any] = {
            "model": model_id,
            "messages": anthropic_messages,
            "max_tokens": max_tokens or 4096,
        }

        if system_prompt:
            params["system"] = system_prompt

        if temperature is not None:
            params["temperature"] = temperature

        # Prepare tools list
        tools_list = []

        # Add web search tool if requested via kwargs
        if kwargs.get("enable_web_search", False):
            web_search_config = kwargs.get("web_search_config", {})
            tools_list.append(self._create_web_search_tool(web_search_config))
            logger.info("Added Anthropic native web search tool")

        # Add user-provided tools
        if tools:
            # Convert tools to Anthropic format
            converted_tools = self._convert_tools(tools)
            tools_list.extend(converted_tools)

        if tools_list:
            params["tools"] = tools_list

        if tool_choice and tool_choice != "auto":
            # Anthropic uses a different tool_choice format
            if tool_choice == "none":
                pass  # Leave tool_choice unset (tools stay available but optional)
            elif tool_choice == "required":
                params["tool_choice"] = {"type": "any"}
            else:
                params["tool_choice"] = {"type": "tool", "name": tool_choice}

        logger.debug(f"Anthropic request: model={model_id}, messages={len(anthropic_messages)}")

        try:
            if stream:
                return self._stream_chat(params)
            else:
                return self._sync_chat(params)

        except Exception as e:
            logger.error(f"Anthropic request failed: {e}")
            return ChatResponse(
                id="error",
                choices=[
                    ChatResponseChoice(
                        index=0,
                        message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
                        finish_reason="error",
                    )
                ],
            )
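As a usage sketch, online mode simply threads a kwarg through `chat()`. The model id and the `web_search_config` shape below are illustrative assumptions; `_create_web_search_tool` is defined elsewhere in this file:

```python
# Hedged sketch; model id and config keys are illustrative, not confirmed.
response = provider.chat(
    model="claude-sonnet-4-5",
    messages=[ChatMessage(role="user", content="What changed in Textual this year?")],
    enable_web_search=True,               # appends the native web search tool
    web_search_config={"max_uses": 3},    # assumed shape, passed through as-is
)
```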
    def _convert_messages(self, messages: List[ChatMessage]) -> tuple[str, List[Dict[str, Any]]]:
        """
        Convert messages to Anthropic format.

        Anthropic requires system messages to be separate from the conversation.

        Args:
            messages: List of ChatMessage objects

        Returns:
            Tuple of (system_prompt, anthropic_messages)
        """
        system_prompt = ""
        anthropic_messages = []

        for msg in messages:
            if msg.role == "system":
                # Accumulate system messages
                if system_prompt:
                    system_prompt += "\n\n"
                system_prompt += msg.content or ""
            else:
                # Convert to Anthropic format
                message_dict: Dict[str, Any] = {"role": msg.role}

                # Handle content
                if msg.content:
                    message_dict["content"] = msg.content

                # Handle tool calls (assistant messages)
                if msg.tool_calls:
                    # Anthropic format for tool use
                    content_blocks = []
                    if msg.content:
                        content_blocks.append({"type": "text", "text": msg.content})

                    for tc in msg.tool_calls:
                        content_blocks.append({
                            "type": "tool_use",
                            "id": tc.id,
                            "name": tc.function.name,
                            "input": json.loads(tc.function.arguments),
                        })

                    message_dict["content"] = content_blocks

                # Handle tool results (tool messages)
                if msg.role == "tool" and msg.tool_call_id:
                    # Convert to Anthropic's tool_result format
                    anthropic_messages.append({
                        "role": "user",
                        "content": [{
                            "type": "tool_result",
                            "tool_use_id": msg.tool_call_id,
                            "content": msg.content or "",
                        }]
                    })
                    continue

                anthropic_messages.append(message_dict)

        return system_prompt, anthropic_messages
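A minimal sketch of what this conversion produces, assuming the `ChatMessage` fields from `oai.providers.base`: system text is pulled out, and a `tool` message becomes a user-role `tool_result` block (the tool id below is made up):

```python
msgs = [
    ChatMessage(role="system", content="Be terse."),
    ChatMessage(role="user", content="Hi"),
    ChatMessage(role="tool", tool_call_id="toolu_01", content="42"),  # id illustrative
]
system, converted = provider._convert_messages(msgs)
# system    == "Be terse."
# converted == [
#   {"role": "user", "content": "Hi"},
#   {"role": "user", "content": [
#       {"type": "tool_result", "tool_use_id": "toolu_01", "content": "42"}]},
# ]
```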
    def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert OpenAI-style tools to Anthropic format.

        Args:
            tools: OpenAI tool definitions

        Returns:
            Anthropic tool definitions
        """
        anthropic_tools = []

        for tool in tools:
            if tool.get("type") == "function":
                func = tool.get("function", {})
                anthropic_tools.append({
                    "name": func.get("name"),
                    "description": func.get("description", ""),
                    "input_schema": func.get("parameters", {}),
                })

        return anthropic_tools
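Concretely, `function.parameters` is renamed to `input_schema` while `name` and `description` carry over unchanged; a sketch with a made-up tool:

```python
openai_style = {
    "type": "function",
    "function": {
        "name": "read_file",                # hypothetical tool
        "description": "Read a local file",
        "parameters": {
            "type": "object",
            "properties": {"path": {"type": "string"}},
            "required": ["path"],
        },
    },
}
# provider._convert_tools([openai_style]) ->
# [{"name": "read_file", "description": "Read a local file",
#   "input_schema": {"type": "object", "properties": {...}, "required": ["path"]}}]
```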
    def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
        """
        Send synchronous chat request.

        Args:
            params: Request parameters

        Returns:
            ChatResponse
        """
        message: Message = self.client.messages.create(**params)

        # Extract content
        content = ""
        tool_calls = []

        for block in message.content:
            if block.type == "text":
                content += block.text
            elif block.type == "tool_use":
                # Convert to ToolCall format
                tool_calls.append(
                    ToolCall(
                        id=block.id,
                        type="function",
                        function=ToolFunction(
                            name=block.name,
                            arguments=json.dumps(block.input),
                        ),
                    )
                )

        # Build ChatMessage
        chat_message = ChatMessage(
            role="assistant",
            content=content if content else None,
            tool_calls=tool_calls if tool_calls else None,
        )

        # Extract usage
        usage = None
        if message.usage:
            usage = UsageStats(
                prompt_tokens=message.usage.input_tokens,
                completion_tokens=message.usage.output_tokens,
                total_tokens=message.usage.input_tokens + message.usage.output_tokens,
            )

        return ChatResponse(
            id=message.id,
            choices=[
                ChatResponseChoice(
                    index=0,
                    message=chat_message,
                    finish_reason=message.stop_reason,
                )
            ],
            usage=usage,
            model=message.model,
        )

    def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
        """
        Stream chat response from Anthropic.

        Args:
            params: Request parameters

        Yields:
            StreamChunk objects
        """
        stream = self.client.messages.stream(**params)

        with stream as event_stream:
            for event in event_stream:
                event_data: MessageStreamEvent = event

                # Handle different event types
                if event_data.type == "content_block_delta":
                    delta = event_data.delta
                    if hasattr(delta, "text"):
                        yield StreamChunk(
                            id="stream",
                            delta_content=delta.text,
                        )

                elif event_data.type == "message_stop":
                    # Final event; usage arrives via message_delta instead
                    pass

                elif event_data.type == "message_delta":
                    # Contains stop reason and usage
                    usage = None
                    if hasattr(event_data, "usage"):
                        usage_data = event_data.usage
                        usage = UsageStats(
                            prompt_tokens=0,
                            completion_tokens=usage_data.output_tokens,
                            total_tokens=usage_data.output_tokens,
                        )

                    yield StreamChunk(
                        id="stream",
                        finish_reason=event_data.delta.stop_reason if hasattr(event_data.delta, "stop_reason") else None,
                        usage=usage,
                    )
    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send async chat request to Anthropic.

        Args:
            model: Model ID
            messages: Chat messages
            stream: Whether to stream
            max_tokens: Max tokens
            temperature: Temperature
            tools: Tool definitions
            tool_choice: Tool choice
            **kwargs: Additional args

        Returns:
            ChatResponse or AsyncIterator[StreamChunk]
        """
        # Resolve model alias
        model_id = MODEL_ALIASES.get(model, model)

        # Convert messages
        system_prompt, anthropic_messages = self._convert_messages(messages)

        # Build params
        params: Dict[str, Any] = {
            "model": model_id,
            "messages": anthropic_messages,
            "max_tokens": max_tokens or 4096,
        }

        if system_prompt:
            params["system"] = system_prompt
        if temperature is not None:
            params["temperature"] = temperature
        if tools:
            params["tools"] = self._convert_tools(tools)

        if stream:
            return self._stream_chat_async(params)
        else:
            message = await self.async_client.messages.create(**params)
            return self._convert_message(message)

    async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
        """
        Stream async chat response.

        Args:
            params: Request parameters

        Yields:
            StreamChunk objects
        """
        # messages.stream() returns an async context manager; it must not be awaited
        async with self.async_client.messages.stream(**params) as event_stream:
            async for event in event_stream:
                if event.type == "content_block_delta":
                    delta = event.delta
                    if hasattr(delta, "text"):
                        yield StreamChunk(
                            id="stream",
                            delta_content=delta.text,
                        )
    def _convert_message(self, message: Message) -> ChatResponse:
        """Helper to convert Anthropic message to ChatResponse."""
        content = ""
        tool_calls = []

        for block in message.content:
            if block.type == "text":
                content += block.text
            elif block.type == "tool_use":
                tool_calls.append(
                    ToolCall(
                        id=block.id,
                        type="function",
                        function=ToolFunction(
                            name=block.name,
                            arguments=json.dumps(block.input),
                        ),
                    )
                )

        chat_message = ChatMessage(
            role="assistant",
            content=content if content else None,
            tool_calls=tool_calls if tool_calls else None,
        )

        usage = None
        if message.usage:
            usage = UsageStats(
                prompt_tokens=message.usage.input_tokens,
                completion_tokens=message.usage.output_tokens,
                total_tokens=message.usage.input_tokens + message.usage.output_tokens,
            )

        return ChatResponse(
            id=message.id,
            choices=[
                ChatResponseChoice(
                    index=0,
                    message=chat_message,
                    finish_reason=message.stop_reason,
                )
            ],
            usage=usage,
            model=message.model,
        )

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credits from Anthropic.

        Note: Anthropic does not currently provide a public API endpoint
        for checking account credits/balance. This information is only
        available through the Anthropic Console web interface.

        Returns:
            None (credits API not available)
        """
        # Anthropic doesn't provide a public credits API endpoint;
        # users must check their balance at console.anthropic.com
        return None

    def clear_cache(self) -> None:
        """Clear model cache."""
        self._models_cache = None

    def get_raw_models(self) -> List[Dict[str, Any]]:
        """
        Get raw model data as dictionaries.

        Returns:
            List of model dictionaries
        """
        models = self.list_models()
        return [
            {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "context_length": model.context_length,
                "pricing": model.pricing,
            }
            for model in models
        ]

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for a specific model.

        Args:
            model_id: Model identifier

        Returns:
            Model dictionary or None
        """
        model = self.get_model(model_id)
        if model:
            return {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "context_length": model.context_length,
                "pricing": model.pricing,
            }
        return None
423
oai/providers/ollama.py
Normal file
@@ -0,0 +1,423 @@
"""
|
||||
Ollama provider for local AI model serving.
|
||||
|
||||
This provider connects to a local Ollama server for running models
|
||||
locally without API keys or external dependencies.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from oai.constants import OLLAMA_DEFAULT_URL
|
||||
from oai.providers.base import (
|
||||
AIProvider,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseChoice,
|
||||
ModelInfo,
|
||||
ProviderCapabilities,
|
||||
StreamChunk,
|
||||
UsageStats,
|
||||
)
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class OllamaProvider(AIProvider):
|
||||
"""
|
||||
Ollama local model provider.
|
||||
|
||||
Connects to a local Ollama server for running models locally.
|
||||
No API key required.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "",
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Initialize Ollama provider.
|
||||
|
||||
Args:
|
||||
api_key: Not used (Ollama doesn't require API keys)
|
||||
base_url: Ollama server URL (default: http://localhost:11434)
|
||||
**kwargs: Additional arguments (ignored)
|
||||
"""
|
||||
super().__init__(api_key or "", base_url)
|
||||
self.base_url = base_url or OLLAMA_DEFAULT_URL
|
||||
self._check_server_available()
|
||||
|
||||
def _check_server_available(self) -> bool:
|
||||
"""
|
||||
Check if Ollama server is accessible.
|
||||
|
||||
Returns:
|
||||
True if server is accessible
|
||||
"""
|
||||
try:
|
||||
response = requests.get(f"{self.base_url}/api/tags", timeout=2)
|
||||
if response.ok:
|
||||
logger.info(f"Ollama server accessible at {self.base_url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Ollama server returned status {response.status_code}")
|
||||
return False
|
||||
except requests.RequestException as e:
|
||||
logger.warning(f"Ollama server not accessible at {self.base_url}: {e}")
|
||||
return False
|
||||
|
||||
    @property
    def name(self) -> str:
        """Get provider name."""
        return "Ollama"

    @property
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        return ProviderCapabilities(
            streaming=True,
            tools=False,  # Tool support varies by model
            images=False,  # Image support varies by model
            online=True,  # Web search via DuckDuckGo/Google
            max_context=8192,  # Varies by model
        )

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        List models from local Ollama installation.

        Args:
            filter_text_only: Ignored for Ollama

        Returns:
            List of available models
        """
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            response.raise_for_status()
            data = response.json()

            models = []
            for model_data in data.get("models", []):
                models.append(self._parse_model(model_data))

            logger.info(f"Found {len(models)} Ollama models")
            return models

        except requests.RequestException as e:
            logger.error(f"Failed to list Ollama models: {e}")
            return []

    def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
        """
        Parse Ollama model data into ModelInfo.

        Args:
            model_data: Raw model data from Ollama API

        Returns:
            ModelInfo object
        """
        model_name = model_data.get("name", "unknown")
        size_bytes = model_data.get("size", 0)
        size_gb = size_bytes / (1024 ** 3) if size_bytes else 0

        return ModelInfo(
            id=model_name,
            name=model_name,
            description=f"Size: {size_gb:.1f}GB",
            context_length=8192,  # Default, varies by model
            pricing={},  # Local models are free
            supported_parameters=["stream", "temperature", "max_tokens"],
        )
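For context, `/api/tags` returns a JSON object of roughly the shape below; the model names and sizes are made up, but the field names match what `list_models` and `_parse_model` read:

```python
# Illustrative /api/tags response; only "name" and "size" are consumed above.
tags_response = {
    "models": [
        {"name": "llama3:latest", "size": 4_661_224_676},
        {"name": "mistral:7b", "size": 4_109_865_159},
    ]
}
# _parse_model turns the first entry into ModelInfo(id="llama3:latest",
# description="Size: 4.3GB", context_length=8192, pricing={})
```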
    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: Model identifier

        Returns:
            ModelInfo or None if not found
        """
        models = self.list_models()
        for model in models:
            if model.id == model_id:
                return model
        return None

    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send chat request to Ollama.

        Args:
            model: Model name
            messages: Chat messages
            stream: Whether to stream response
            max_tokens: Maximum tokens (Ollama calls this num_predict)
            temperature: Sampling temperature
            tools: Not supported
            tool_choice: Not supported
            **kwargs: Additional parameters

        Returns:
            ChatResponse or Iterator[StreamChunk]
        """
        # Convert messages to Ollama format
        ollama_messages = []
        for msg in messages:
            ollama_messages.append({
                "role": msg.role,
                "content": msg.content or "",
            })

        # Build request payload
        payload: Dict[str, Any] = {
            "model": model,
            "messages": ollama_messages,
            "stream": stream,
        }

        # Add optional parameters
        options = {}
        if temperature is not None:
            options["temperature"] = temperature
        if max_tokens is not None:
            options["num_predict"] = max_tokens

        if options:
            payload["options"] = options

        logger.debug(f"Ollama request: model={model}, messages={len(ollama_messages)}")

        try:
            if stream:
                return self._stream_chat(payload)
            else:
                return self._sync_chat(payload)

        except requests.RequestException as e:
            logger.error(f"Ollama request failed: {e}")
            # Return error response
            return ChatResponse(
                id="error",
                choices=[
                    ChatResponseChoice(
                        index=0,
                        message=ChatMessage(
                            role="assistant",
                            content=f"Error: {str(e)}",
                        ),
                        finish_reason="error",
                    )
                ],
            )
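For reference, the non-streaming path above amounts to a plain HTTP POST; a standalone sketch (the model name is illustrative, and a local Ollama server must be running):

```python
import requests

payload = {
    "model": "llama3:latest",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
    "options": {"temperature": 0.7, "num_predict": 256},  # Ollama's name for max_tokens
}
resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["message"]["content"])
```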
    def _sync_chat(self, payload: Dict[str, Any]) -> ChatResponse:
        """
        Send synchronous chat request.

        Args:
            payload: Request payload

        Returns:
            ChatResponse
        """
        response = requests.post(
            f"{self.base_url}/api/chat",
            json=payload,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Parse response
        message_data = data.get("message", {})
        content = message_data.get("content", "")

        # Extract token usage if available
        usage = None
        if "prompt_eval_count" in data or "eval_count" in data:
            usage = UsageStats(
                prompt_tokens=data.get("prompt_eval_count", 0),
                completion_tokens=data.get("eval_count", 0),
                total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
                total_cost_usd=0.0,  # Local models are free
            )

        return ChatResponse(
            id=str(time.time()),
            choices=[
                ChatResponseChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=content),
                    finish_reason="stop",
                )
            ],
            usage=usage,
            model=data.get("model"),
        )

    def _stream_chat(self, payload: Dict[str, Any]) -> Iterator[StreamChunk]:
        """
        Stream chat response from Ollama.

        Args:
            payload: Request payload

        Yields:
            StreamChunk objects
        """
        response = requests.post(
            f"{self.base_url}/api/chat",
            json=payload,
            stream=True,
            timeout=120,
        )
        response.raise_for_status()

        total_prompt_tokens = 0
        total_completion_tokens = 0

        for line in response.iter_lines():
            if not line:
                continue

            try:
                data = json.loads(line)

                # Extract content delta
                message_data = data.get("message", {})
                content = message_data.get("content", "")

                # Check if done
                done = data.get("done", False)
                finish_reason = "stop" if done else None

                # Extract usage if available
                usage = None
                if done and ("prompt_eval_count" in data or "eval_count" in data):
                    total_prompt_tokens = data.get("prompt_eval_count", 0)
                    total_completion_tokens = data.get("eval_count", 0)
                    usage = UsageStats(
                        prompt_tokens=total_prompt_tokens,
                        completion_tokens=total_completion_tokens,
                        total_tokens=total_prompt_tokens + total_completion_tokens,
                        total_cost_usd=0.0,
                    )

                yield StreamChunk(
                    id=str(time.time()),
                    delta_content=content if content else None,
                    finish_reason=finish_reason,
                    usage=usage,
                )

            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse Ollama stream chunk: {e}")
                yield StreamChunk(
                    id="error",
                    error=f"Parse error: {e}",
                )
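Since Ollama streams newline-delimited JSON, the generator above is consumed like any other provider stream; a usage sketch (model name illustrative):

```python
provider = OllamaProvider()
chunks = provider.chat(
    "llama3:latest",
    [ChatMessage(role="user", content="Hi")],
    stream=True,
)
for chunk in chunks:
    if chunk.delta_content:
        print(chunk.delta_content, end="", flush=True)
    if chunk.usage:  # populated only on the final "done" chunk
        print(f"\n[{chunk.usage.total_tokens} tokens, $0.00]")
```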
    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Async chat not implemented for Ollama.

        Args:
            model: Model name
            messages: Chat messages
            stream: Whether to stream
            max_tokens: Max tokens
            temperature: Temperature
            tools: Tools (not supported)
            tool_choice: Tool choice (not supported)
            **kwargs: Additional args

        Returns:
            ChatResponse or AsyncIterator[StreamChunk]

        Raises:
            NotImplementedError: Async not implemented
        """
        raise NotImplementedError("Async chat not implemented for Ollama provider")

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credits.

        Returns:
            None (Ollama is local and free)
        """
        return None

    def clear_cache(self) -> None:
        """Clear model cache (no-op for Ollama)."""
        pass

    def get_raw_models(self) -> List[Dict[str, Any]]:
        """
        Get raw model data as dictionaries.

        Returns:
            List of model dictionaries
        """
        models = self.list_models()
        return [
            {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "context_length": model.context_length,
                "pricing": model.pricing,
            }
            for model in models
        ]

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for a specific model.

        Args:
            model_id: Model identifier

        Returns:
            Model dictionary or None
        """
        model = self.get_model(model_id)
        if model:
            return {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "context_length": model.context_length,
                "pricing": model.pricing,
            }
        return None
630
oai/providers/openai.py
Normal file
@@ -0,0 +1,630 @@
"""
|
||||
OpenAI provider for GPT models.
|
||||
|
||||
This provider connects to OpenAI's API for accessing GPT-4, GPT-3.5, and other OpenAI models.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from oai.constants import OPENAI_BASE_URL
|
||||
from oai.providers.base import (
|
||||
AIProvider,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseChoice,
|
||||
ModelInfo,
|
||||
ProviderCapabilities,
|
||||
StreamChunk,
|
||||
ToolCall,
|
||||
ToolFunction,
|
||||
UsageStats,
|
||||
)
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
# Model aliases for convenience
|
||||
MODEL_ALIASES = {
|
||||
"gpt-4": "gpt-4-turbo",
|
||||
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
||||
"gpt-4o": "gpt-4o-2024-11-20",
|
||||
"gpt-4o-mini": "gpt-4o-mini-2024-07-18",
|
||||
"gpt-3.5": "gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
||||
"o1": "o1-2024-12-17",
|
||||
"o1-mini": "o1-mini-2024-09-12",
|
||||
"o1-preview": "o1-preview-2024-09-12",
|
||||
}
|
||||
|
||||
|
||||
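Note that resolution is a single dictionary lookup, not recursive, so `gpt-4` resolves to `gpt-4-turbo` rather than chaining through to the dated snapshot:

```python
MODEL_ALIASES.get("gpt-4", "gpt-4")                 # -> "gpt-4-turbo" (one hop only)
MODEL_ALIASES.get("gpt-4-turbo", "gpt-4-turbo")     # -> "gpt-4-turbo-2024-04-09"
MODEL_ALIASES.get("my-custom-id", "my-custom-id")   # unknown ids pass through unchanged
```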
class OpenAIProvider(AIProvider):
    """
    OpenAI API provider.

    Provides access to GPT-4, GPT-3.5, o1, and other OpenAI models.
    """

    def __init__(
        self,
        api_key: str,
        base_url: Optional[str] = None,
        app_name: str = "oAI",
        app_url: str = "",
        **kwargs: Any,
    ):
        """
        Initialize OpenAI provider.

        Args:
            api_key: OpenAI API key
            base_url: Optional custom base URL
            app_name: Application name (for headers)
            app_url: Application URL (for headers)
            **kwargs: Additional arguments
        """
        super().__init__(api_key, base_url or OPENAI_BASE_URL)
        self.client = OpenAI(api_key=api_key, base_url=self.base_url)
        self.async_client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
        self._models_cache: Optional[List[ModelInfo]] = None

    @property
    def name(self) -> str:
        """Get provider name."""
        return "OpenAI"

    @property
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        return ProviderCapabilities(
            streaming=True,
            tools=True,
            images=True,
            online=True,  # Web search via DuckDuckGo/Google
            max_context=128000,
        )
    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        List available OpenAI models.

        Args:
            filter_text_only: Whether to filter for text models only

        Returns:
            List of ModelInfo objects
        """
        if self._models_cache:
            return self._models_cache

        try:
            models_response = self.client.models.list()
            models = []

            for model in models_response.data:
                # Filter for chat models
                if "gpt" in model.id or "o1" in model.id:
                    models.append(self._parse_model(model))

            # Sort by name
            models.sort(key=lambda m: m.name)
            self._models_cache = models

            logger.info(f"Loaded {len(models)} OpenAI models")
            return models

        except Exception as e:
            logger.error(f"Failed to list OpenAI models: {e}")
            return self._get_fallback_models()

    def _get_fallback_models(self) -> List[ModelInfo]:
        """
        Get fallback list of common OpenAI models.

        Returns:
            List of common models
        """
        return [
            ModelInfo(
                id="gpt-4o",
                name="GPT-4o",
                description="Most capable GPT-4 model",
                context_length=128000,
                pricing={"input": 5.0, "output": 15.0},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="gpt-4o-mini",
                name="GPT-4o Mini",
                description="Affordable and fast GPT-4 class model",
                context_length=128000,
                pricing={"input": 0.15, "output": 0.6},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="gpt-4-turbo",
                name="GPT-4 Turbo",
                description="GPT-4 Turbo with vision",
                context_length=128000,
                pricing={"input": 10.0, "output": 30.0},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="gpt-3.5-turbo",
                name="GPT-3.5 Turbo",
                description="Fast and affordable model",
                context_length=16384,
                pricing={"input": 0.5, "output": 1.5},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
            ),
            ModelInfo(
                id="o1",
                name="o1",
                description="Advanced reasoning model",
                context_length=200000,
                pricing={"input": 15.0, "output": 60.0},
                supported_parameters=["max_tokens"],
            ),
            ModelInfo(
                id="o1-mini",
                name="o1-mini",
                description="Fast reasoning model",
                context_length=128000,
                pricing={"input": 3.0, "output": 12.0},
                supported_parameters=["max_tokens"],
            ),
        ]
    def _parse_model(self, model: Any) -> ModelInfo:
        """
        Parse OpenAI model into ModelInfo.

        Args:
            model: OpenAI model object

        Returns:
            ModelInfo object
        """
        model_id = model.id

        # Determine context length
        context_length = 8192  # Default
        if "gpt-4o" in model_id or "gpt-4-turbo" in model_id:
            context_length = 128000
        elif "gpt-4" in model_id:
            context_length = 8192
        elif "gpt-3.5-turbo" in model_id:
            context_length = 16384
        elif "o1" in model_id:
            context_length = 128000

        # Determine pricing (approximate)
        pricing = {}
        if "gpt-4o-mini" in model_id:
            pricing = {"input": 0.15, "output": 0.6}
        elif "gpt-4o" in model_id:
            pricing = {"input": 5.0, "output": 15.0}
        elif "gpt-4-turbo" in model_id:
            pricing = {"input": 10.0, "output": 30.0}
        elif "gpt-4" in model_id:
            pricing = {"input": 30.0, "output": 60.0}
        elif "gpt-3.5" in model_id:
            pricing = {"input": 0.5, "output": 1.5}
        elif "o1" in model_id and "mini" not in model_id:
            pricing = {"input": 15.0, "output": 60.0}
        elif "o1-mini" in model_id:
            pricing = {"input": 3.0, "output": 12.0}

        return ModelInfo(
            id=model_id,
            name=model_id,
            description="",
            context_length=context_length,
            pricing=pricing,
            supported_parameters=["temperature", "max_tokens", "stream", "tools"],
        )

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: Model identifier

        Returns:
            ModelInfo or None
        """
        # Resolve alias
        resolved_id = MODEL_ALIASES.get(model_id, model_id)

        models = self.list_models()
        for model in models:
            if model.id == resolved_id or model.id == model_id:
                return model

        # Try to fetch directly
        try:
            model = self.client.models.retrieve(resolved_id)
            return self._parse_model(model)
        except Exception:
            return None
    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send chat completion request to OpenAI.

        Args:
            model: Model ID
            messages: Chat messages
            stream: Whether to stream response
            max_tokens: Maximum tokens
            temperature: Sampling temperature
            tools: Tool definitions
            tool_choice: Tool selection mode
            **kwargs: Additional parameters

        Returns:
            ChatResponse or Iterator[StreamChunk]
        """
        # Resolve model alias
        model_id = MODEL_ALIASES.get(model, model)

        # Convert messages to OpenAI format
        openai_messages = []
        for msg in messages:
            message_dict = {"role": msg.role, "content": msg.content or ""}

            if msg.tool_calls:
                message_dict["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in msg.tool_calls
                ]

            if msg.tool_call_id:
                message_dict["tool_call_id"] = msg.tool_call_id

            openai_messages.append(message_dict)

        # Build request parameters
        params: Dict[str, Any] = {
            "model": model_id,
            "messages": openai_messages,
            "stream": stream,
        }

        # Add optional parameters
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        if temperature is not None and "o1" not in model_id:
            # o1 models don't support temperature
            params["temperature"] = temperature
        if tools:
            params["tools"] = tools
        if tool_choice:
            params["tool_choice"] = tool_choice

        logger.debug(f"OpenAI request: model={model_id}, messages={len(openai_messages)}")

        try:
            if stream:
                return self._stream_chat(params)
            else:
                return self._sync_chat(params)

        except Exception as e:
            logger.error(f"OpenAI request failed: {e}")
            return ChatResponse(
                id="error",
                choices=[
                    ChatResponseChoice(
                        index=0,
                        message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
                        finish_reason="error",
                    )
                ],
            )
    def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
        """
        Send synchronous chat request.

        Args:
            params: Request parameters

        Returns:
            ChatResponse
        """
        completion: ChatCompletion = self.client.chat.completions.create(**params)

        # Convert to our format
        choices = []
        for choice in completion.choices:
            # Convert tool calls if present
            tool_calls = None
            if choice.message.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id,
                        type=tc.type,
                        function=ToolFunction(
                            name=tc.function.name,
                            arguments=tc.function.arguments,
                        ),
                    )
                    for tc in choice.message.tool_calls
                ]

            choices.append(
                ChatResponseChoice(
                    index=choice.index,
                    message=ChatMessage(
                        role=choice.message.role,
                        content=choice.message.content,
                        tool_calls=tool_calls,
                    ),
                    finish_reason=choice.finish_reason,
                )
            )

        # Convert usage
        usage = None
        if completion.usage:
            usage = UsageStats(
                prompt_tokens=completion.usage.prompt_tokens,
                completion_tokens=completion.usage.completion_tokens,
                total_tokens=completion.usage.total_tokens,
            )

        return ChatResponse(
            id=completion.id,
            choices=choices,
            usage=usage,
            model=completion.model,
            created=completion.created,
        )

    def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
        """
        Stream chat response from OpenAI.

        Args:
            params: Request parameters

        Yields:
            StreamChunk objects
        """
        stream = self.client.chat.completions.create(**params)

        for chunk in stream:
            chunk_data: ChatCompletionChunk = chunk

            if not chunk_data.choices:
                continue

            choice = chunk_data.choices[0]
            delta = choice.delta

            # Extract content
            content = delta.content if delta.content else None

            # Extract finish reason
            finish_reason = choice.finish_reason

            # Extract usage (usually in last chunk)
            usage = None
            if hasattr(chunk_data, "usage") and chunk_data.usage:
                usage = UsageStats(
                    prompt_tokens=chunk_data.usage.prompt_tokens,
                    completion_tokens=chunk_data.usage.completion_tokens,
                    total_tokens=chunk_data.usage.total_tokens,
                )

            yield StreamChunk(
                id=chunk_data.id,
                delta_content=content,
                finish_reason=finish_reason,
                usage=usage,
            )
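One caveat worth noting: with the Chat Completions API, streamed responses only carry usage on the final chunk when the request opts in via `stream_options`. The code above tolerates its absence; callers who need token counts while streaming could add (to the best of my knowledge of the current API):

```python
params["stream_options"] = {"include_usage": True}  # final chunk then includes usage
```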
    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send async chat request to OpenAI.

        Args:
            model: Model ID
            messages: Chat messages
            stream: Whether to stream
            max_tokens: Max tokens
            temperature: Temperature
            tools: Tool definitions
            tool_choice: Tool choice
            **kwargs: Additional args

        Returns:
            ChatResponse or AsyncIterator[StreamChunk]
        """
        # Resolve model alias
        model_id = MODEL_ALIASES.get(model, model)

        # Convert messages
        openai_messages = [msg.to_dict() for msg in messages]

        # Build params
        params: Dict[str, Any] = {
            "model": model_id,
            "messages": openai_messages,
            "stream": stream,
        }

        if max_tokens:
            params["max_tokens"] = max_tokens
        if temperature is not None and "o1" not in model_id:
            params["temperature"] = temperature
        if tools:
            params["tools"] = tools
        if tool_choice:
            params["tool_choice"] = tool_choice

        if stream:
            return self._stream_chat_async(params)
        else:
            completion = await self.async_client.chat.completions.create(**params)
            # Convert to ChatResponse (similar to _sync_chat)
            return self._convert_completion(completion)

    async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
        """
        Stream async chat response.

        Args:
            params: Request parameters

        Yields:
            StreamChunk objects
        """
        stream = await self.async_client.chat.completions.create(**params)

        async for chunk in stream:
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            yield StreamChunk(
                id=chunk.id,
                delta_content=delta.content,
                finish_reason=choice.finish_reason,
            )
    def _convert_completion(self, completion: ChatCompletion) -> ChatResponse:
        """Helper to convert OpenAI completion to ChatResponse."""
        choices = []
        for choice in completion.choices:
            tool_calls = None
            if choice.message.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id,
                        type=tc.type,
                        function=ToolFunction(
                            name=tc.function.name,
                            arguments=tc.function.arguments,
                        ),
                    )
                    for tc in choice.message.tool_calls
                ]

            choices.append(
                ChatResponseChoice(
                    index=choice.index,
                    message=ChatMessage(
                        role=choice.message.role,
                        content=choice.message.content,
                        tool_calls=tool_calls,
                    ),
                    finish_reason=choice.finish_reason,
                )
            )

        usage = None
        if completion.usage:
            usage = UsageStats(
                prompt_tokens=completion.usage.prompt_tokens,
                completion_tokens=completion.usage.completion_tokens,
                total_tokens=completion.usage.total_tokens,
            )

        return ChatResponse(
            id=completion.id,
            choices=choices,
            usage=usage,
            model=completion.model,
            created=completion.created,
        )

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credits.

        Returns:
            None (OpenAI doesn't provide a credits API)
        """
        return None

    def clear_cache(self) -> None:
        """Clear model cache."""
        self._models_cache = None

    def get_raw_models(self) -> List[Dict[str, Any]]:
        """
        Get raw model data as dictionaries.

        Returns:
            List of model dictionaries
        """
        models = self.list_models()
        return [
            {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "context_length": model.context_length,
                "pricing": model.pricing,
            }
            for model in models
        ]

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for a specific model.

        Args:
            model_id: Model identifier

        Returns:
            Model dictionary or None
        """
        model = self.get_model(model_id)
        if model:
            return {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "context_length": model.context_length,
                "pricing": model.pricing,
            }
        return None
60
oai/providers/registry.py
Normal file
@@ -0,0 +1,60 @@
"""
|
||||
Provider registry for AI model providers.
|
||||
|
||||
This module maintains a central registry of all available AI providers,
|
||||
allowing dynamic provider lookup and registration.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from oai.providers.base import AIProvider
|
||||
|
||||
# Global provider registry
|
||||
PROVIDER_REGISTRY: Dict[str, Type[AIProvider]] = {}
|
||||
|
||||
|
||||
def register_provider(name: str, provider_class: Type[AIProvider]) -> None:
|
||||
"""
|
||||
Register a provider class with the given name.
|
||||
|
||||
Args:
|
||||
name: Provider identifier (e.g., "openrouter", "anthropic")
|
||||
provider_class: The provider class to register
|
||||
"""
|
||||
PROVIDER_REGISTRY[name] = provider_class
|
||||
|
||||
|
||||
def get_provider_class(name: str) -> Optional[Type[AIProvider]]:
|
||||
"""
|
||||
Get a provider class by name.
|
||||
|
||||
Args:
|
||||
name: Provider identifier
|
||||
|
||||
Returns:
|
||||
Provider class or None if not found
|
||||
"""
|
||||
return PROVIDER_REGISTRY.get(name)
|
||||
|
||||
|
||||
def list_providers() -> List[str]:
|
||||
"""
|
||||
List all registered provider names.
|
||||
|
||||
Returns:
|
||||
List of provider identifiers
|
||||
"""
|
||||
return list(PROVIDER_REGISTRY.keys())
|
||||
|
||||
|
||||
def is_provider_registered(name: str) -> bool:
|
||||
"""
|
||||
Check if a provider is registered.
|
||||
|
||||
Args:
|
||||
name: Provider identifier
|
||||
|
||||
Returns:
|
||||
True if provider is registered
|
||||
"""
|
||||
return name in PROVIDER_REGISTRY
|
||||
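A usage sketch, assuming each provider module (or the package `__init__`) performs the registration:

```python
from oai.providers.registry import (
    get_provider_class,
    is_provider_registered,
    list_providers,
    register_provider,
)
from oai.providers.ollama import OllamaProvider

register_provider("ollama", OllamaProvider)

if is_provider_registered("ollama"):
    provider_cls = get_provider_class("ollama")
    provider = provider_cls()            # Ollama needs no API key
print(list_providers())                  # e.g. ["ollama", ...]
```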
@@ -16,6 +16,7 @@ from oai.core.client import AIClient
 from oai.core.session import ChatSession
 from oai.tui.screens import (
     AlertDialog,
+    CommandsScreen,
     ConfirmDialog,
     ConfigScreen,
     ConversationSelectorScreen,
@@ -69,7 +70,8 @@ class oAIChatApp(App):
         """Compose the TUI layout."""
         model_name = self.session.selected_model.get("name", "") if self.session.selected_model else ""
         model_info = self.session.selected_model if self.session.selected_model else None
-        yield Header(version=__version__, model=model_name, model_info=model_info)
+        provider_name = self.session.client.provider_name if self.session.client else ""
+        yield Header(version=__version__, model=model_name, model_info=model_info, provider=provider_name)
         yield ChatDisplay()
         yield InputBar()
         yield CommandDropdown()
@@ -158,10 +160,10 @@
             else:
                 # Command is complete, submit it directly
                 dropdown.hide()
+                chat_input.value = ""  # Clear immediately
                 # Process the command directly
                 async def submit_command():
                     await self._process_submitted_input(selected)
-                    chat_input.value = ""
                 self.call_later(submit_command)
                 return
         elif event.key == "escape":
@@ -232,12 +234,12 @@
         if not user_input:
             return
 
-        # Process the input
-        await self._process_submitted_input(user_input)
-
-        # Clear input field
+        # Clear input field immediately
         event.input.value = ""
 
+        # Process the input (async, will wait for AI response)
+        await self._process_submitted_input(user_input)
+
     async def _process_submitted_input(self, user_input: str) -> None:
         """Process submitted input (command or message).
@@ -280,6 +282,10 @@
             await self.push_screen(HelpScreen())
             return
 
+        if cmd_word == "commands":
+            await self.push_screen(CommandsScreen())
+            return
+
         if cmd_word == "stats":
             await self.push_screen(StatsScreen(self.session))
             return
@@ -288,7 +294,7 @@
         # Check if there are any arguments
         args = command_text.split(maxsplit=1)
         if len(args) == 1:  # No arguments, just "/config"
-            await self.push_screen(ConfigScreen(self.settings))
+            await self.push_screen(ConfigScreen(self.settings, self.session))
             return
         # If there are arguments, fall through to normal command handler
@@ -447,9 +453,11 @@
         # Update header if model changed
         if self.session.selected_model:
             header = self.query_one(Header)
+            provider_name = self.session.client.provider_name if self.session.client else ""
             header.update_model(
                 self.session.selected_model.get("name", ""),
-                self.session.selected_model
+                self.session.selected_model,
+                provider_name
             )
 
         # Update MCP status indicator in input bar
@@ -472,7 +480,7 @@
 
         # Create assistant message widget with loading indicator
         model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
-        assistant_widget = AssistantMessageWidget(model_name)
+        assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
         await chat_display.add_message(assistant_widget)
 
         # Show loading indicator immediately
@@ -851,7 +859,13 @@
         if selected:
             self.session.set_model(selected)
             header = self.query_one(Header)
-            header.update_model(selected.get("name", ""), selected)
+            provider_name = self.session.client.provider_name if self.session.client else ""
+            header.update_model(selected.get("name", ""), selected, provider_name)
+
+            # Save this model as the last used for this provider
+            model_id = selected.get("id")
+            if model_id and provider_name:
+                self.settings.set_provider_model(provider_name, model_id)
 
             # Save as default if requested
             if set_as_default:
@@ -1,5 +1,6 @@
 """TUI screens for oAI."""
 
+from oai.tui.screens.commands_screen import CommandsScreen
 from oai.tui.screens.config_screen import ConfigScreen
 from oai.tui.screens.conversation_selector import ConversationSelectorScreen
 from oai.tui.screens.credits_screen import CreditsScreen
@@ -10,6 +11,7 @@ from oai.tui.screens.stats_screen import StatsScreen
 
 __all__ = [
     "AlertDialog",
+    "CommandsScreen",
     "ConfirmDialog",
     "ConfigScreen",
     "ConversationSelectorScreen",
172
oai/tui/screens/commands_screen.py
Normal file
@@ -0,0 +1,172 @@
"""Commands reference screen for oAI TUI."""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical, VerticalScroll
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Static
|
||||
|
||||
|
||||
class CommandsScreen(ModalScreen[None]):
|
||||
"""Modal screen showing all available commands."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
CommandsScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
CommandsScreen > Container {
|
||||
width: 90;
|
||||
height: 40;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
CommandsScreen .header {
|
||||
dock: top;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 2;
|
||||
}
|
||||
|
||||
CommandsScreen .content {
|
||||
width: 100%;
|
||||
height: 1fr;
|
||||
background: #1e1e1e;
|
||||
padding: 2;
|
||||
color: #cccccc;
|
||||
overflow-y: auto;
|
||||
scrollbar-background: #1e1e1e;
|
||||
scrollbar-color: #555555;
|
||||
scrollbar-size: 1 1;
|
||||
}
|
||||
|
||||
CommandsScreen .footer {
|
||||
dock: bottom;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
padding: 1 2;
|
||||
align: center middle;
|
||||
}
|
||||
"""
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the commands screen."""
|
||||
with Container():
|
||||
yield Static("[bold]Commands Reference[/]", classes="header")
|
||||
with VerticalScroll(classes="content"):
|
||||
yield Static(self._get_commands_text(), markup=True)
|
||||
with Vertical(classes="footer"):
|
||||
yield Button("Close", id="close-button", variant="primary")
|
||||
|
||||
def _get_commands_text(self) -> str:
|
||||
"""Generate formatted commands text."""
|
||||
return """[bold cyan]General Commands[/]
|
||||
|
||||
[green]/help[/] - Show help screen with keyboard shortcuts
|
||||
[green]/commands[/] - Show this commands reference
|
||||
[green]/model[/] - Open model selector (or press F2)
|
||||
[green]/stats[/] - Show session statistics (or press Ctrl+S)
|
||||
[green]/credits[/] - Check account credits (OpenRouter) or view console link
|
||||
[green]/clear[/] - Clear chat display
|
||||
[green]/reset[/] - Reset conversation history
|
||||
[green]/retry[/] - Retry last prompt
|
||||
[green]/paste[/] - Paste from clipboard
|
||||
|
||||
[bold cyan]Provider Commands[/]
|
||||
|
||||
[green]/provider[/] - Show current provider
|
||||
[green]/provider openrouter[/] - Switch to OpenRouter
|
||||
[green]/provider anthropic[/] - Switch to Anthropic (Claude)
|
||||
[green]/provider openai[/] - Switch to OpenAI (ChatGPT)
|
||||
[green]/provider ollama[/] - Switch to Ollama (local)
|
||||
|
||||
[bold cyan]Online Mode (Web Search)[/]
|
||||
|
||||
[green]/online[/] - Show online mode status
|
||||
[green]/online on[/] - Enable web search
|
||||
[green]/online off[/] - Disable web search
|
||||
|
||||
[dim]Search Providers:[/]
|
||||
• [yellow]anthropic_native[/] - Anthropic's native search with citations ($0.01/search)
|
||||
• [yellow]duckduckgo[/] - Free web scraping (default, works with all providers)
|
||||
• [yellow]google[/] - Google Custom Search (requires API key)
|
||||
|
||||
[bold cyan]Configuration Commands[/]
|
||||
|
||||
[green]/config[/] - View all settings
|
||||
[green]/config provider <name>[/] - Set default provider
|
||||
[green]/config search_provider <provider>[/] - Set search provider (anthropic_native/duckduckgo/google)
|
||||
[green]/config openrouter_api_key <key>[/] - Set OpenRouter API key
|
||||
[green]/config anthropic_api_key <key>[/] - Set Anthropic API key
|
||||
[green]/config openai_api_key <key>[/] - Set OpenAI API key
|
||||
[green]/config ollama_base_url <url>[/] - Set Ollama server URL
|
||||
[green]/config google_api_key <key>[/] - Set Google API key (for Google search)
|
||||
[green]/config google_search_engine_id <id>[/] - Set Google Search Engine ID
|
||||
[green]/config online on|off[/] - Set default online mode
|
||||
[green]/config stream on|off[/] - Toggle streaming
|
||||
[green]/config model <id>[/] - Set default model
|
||||
[green]/config system <prompt>[/] - Set system prompt
|
||||
[green]/config maxtoken <num>[/] - Set token limit
|
||||
|
||||
[bold cyan]Memory & Context[/]
|
||||
|
||||
[green]/memory on[/] - Enable conversation memory
|
||||
[green]/memory off[/] - Disable memory (fresh context each message)
|
||||
|
||||
[bold cyan]Conversation Management[/]
|
||||
|
||||
[green]/save <name>[/] - Save current conversation
|
||||
[green]/load <name>[/] - Load saved conversation
|
||||
[green]/list[/] - List all saved conversations
|
||||
[green]/delete <name>[/] - Delete a conversation
|
||||
[green]/prev[/] - Show previous message
|
||||
[green]/next[/] - Show next message
|
||||
|
||||
[bold cyan]Export Commands[/]
|
||||
|
||||
[green]/export md <file>[/] - Export conversation as Markdown
|
||||
[green]/export json <file>[/] - Export as JSON
|
||||
[green]/export html <file>[/] - Export as HTML
|
||||
|
||||
[bold cyan]MCP (Model Context Protocol)[/]
|
||||
|
||||
[green]/mcp on[/] - Enable MCP file access
|
||||
[green]/mcp off[/] - Disable MCP
|
||||
[green]/mcp status[/] - Show MCP status
|
||||
[green]/mcp add <path>[/] - Add folder for file access
|
||||
[green]/mcp add db <path>[/] - Add SQLite database
|
||||
[green]/mcp remove <path>[/] - Remove folder/database
|
||||
[green]/mcp list[/] - List allowed folders
|
||||
[green]/mcp db list[/] - List added databases
|
||||
[green]/mcp db <n>[/] - Switch to database mode
|
||||
[green]/mcp files[/] - Switch to file mode
|
||||
[green]/mcp write on[/] - Enable write mode (allows file modifications)
|
||||
[green]/mcp write off[/] - Disable write mode
|
||||
|
||||
[bold cyan]System Prompt[/]
|
||||
|
||||
[green]/system <prompt>[/] - Set custom system prompt for session
|
||||
[green]/config system <prompt>[/] - Set default system prompt
|
||||
|
||||
[bold cyan]Keyboard Shortcuts[/]
|
||||
|
||||
• [yellow]F1[/] - Help screen
|
||||
• [yellow]F2[/] - Model selector
|
||||
• [yellow]Ctrl+S[/] - Statistics
|
||||
• [yellow]Ctrl+Q[/] - Quit
|
||||
• [yellow]Ctrl+Y[/] - Copy latest reply in Markdown
|
||||
• [yellow]Up/Down[/] - Command history
|
||||
• [yellow]Tab[/] - Command completion
|
||||
"""
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
self.dismiss()
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key in ("escape", "enter"):
|
||||
self.dismiss()
|
||||
@@ -50,9 +50,10 @@ class ConfigScreen(ModalScreen[None]):
     }
     """
 
-    def __init__(self, settings: Settings):
+    def __init__(self, settings: Settings, session=None):
         super().__init__()
         self.settings = settings
+        self.session = session
 
     def compose(self) -> ComposeResult:
         """Compose the screen."""
@@ -67,35 +68,90 @@ class ConfigScreen(ModalScreen[None]):
        """Generate the configuration text."""
        from oai.constants import DEFAULT_SYSTEM_PROMPT

        # API Key display
        api_key_display = "***" + self.settings.api_key[-4:] if self.settings.api_key else "Not set"
        config_lines = ["[bold cyan]═══ CONFIGURATION ═══[/]\n"]

        # Current Session Info
        if self.session and self.session.client:
            config_lines.append("[bold yellow]Current Session:[/]")

            provider = self.session.client.provider_name
            config_lines.append(f"[bold]Provider:[/] [green]{provider}[/]")

            # Provider URL
            if hasattr(self.session.client.provider, 'base_url'):
                provider_url = self.session.client.provider.base_url
                config_lines.append(f"[bold]Provider URL:[/] {provider_url}")

            # API Key status
            if provider == "ollama":
                config_lines.append(f"[bold]API Key:[/] [dim]Not required (local)[/]")
            else:
                api_key = self.settings.get_provider_api_key(provider)
                if api_key:
                    key_display = "***" + api_key[-4:] if len(api_key) > 4 else "***"
                    config_lines.append(f"[bold]API Key:[/] [green]{key_display}[/]")
                else:
                    config_lines.append(f"[bold]API Key:[/] [red]Not set[/]")

            # Current model
            if self.session.selected_model:
                model_name = self.session.selected_model.get("name", "Unknown")
                config_lines.append(f"[bold]Current Model:[/] {model_name}")

            config_lines.append("")

        # Default Settings
        config_lines.append("[bold yellow]Default Settings:[/]")
        config_lines.append(f"[bold]Default Provider:[/] {self.settings.default_provider}")

        # Mask helper
        def mask_key(key):
            if not key:
                return "[red]Not set[/]"
            return "***" + key[-4:] if len(key) > 4 else "***"

        config_lines.append(f"[bold]OpenRouter Key:[/] {mask_key(self.settings.openrouter_api_key)}")
        config_lines.append(f"[bold]Anthropic Key:[/] {mask_key(self.settings.anthropic_api_key)}")
        config_lines.append(f"[bold]OpenAI Key:[/] {mask_key(self.settings.openai_api_key)}")
        config_lines.append(f"[bold]Ollama URL:[/] {self.settings.ollama_base_url}")
        config_lines.append(f"[bold]Default Model:[/] {self.settings.default_model or 'Not set'}")

        # System prompt display
        if self.settings.default_system_prompt is None:
            system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
            system_prompt_display = f"[dim][default][/] {DEFAULT_SYSTEM_PROMPT[:40]}..."
        elif self.settings.default_system_prompt == "":
            system_prompt_display = "[blank]"
            system_prompt_display = "[dim][blank][/]"
        else:
            prompt = self.settings.default_system_prompt
            system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt

        return f"""
[bold cyan]═══ CONFIGURATION ═══[/]
        config_lines.append(f"[bold]System Prompt:[/] {system_prompt_display}")
        config_lines.append("")

[bold]API Key:[/] {api_key_display}
[bold]Base URL:[/] {self.settings.base_url}
[bold]Default Model:[/] {self.settings.default_model or "Not set"}
        # Web Search Configuration
        config_lines.append("[bold yellow]Web Search Configuration:[/]")
        config_lines.append(f"[bold]Search Provider:[/] {self.settings.search_provider}")

[bold]System Prompt:[/] {system_prompt_display}
        # Show API key status based on search provider
        if self.settings.search_provider == "google":
            config_lines.append(f"[bold]Google API Key:[/] {mask_key(self.settings.google_api_key)}")
            config_lines.append(f"[bold]Search Engine ID:[/] {mask_key(self.settings.google_search_engine_id)}")
        elif self.settings.search_provider == "duckduckgo":
            config_lines.append("[dim]  DuckDuckGo requires no configuration (free)[/]")
        elif self.settings.search_provider == "anthropic_native":
            config_lines.append("[dim]  Uses Anthropic API ($0.01 per search)[/]")

[bold]Streaming:[/] {"on" if self.settings.stream_enabled else "off"}
[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}
[bold]Max Tokens:[/] {self.settings.max_tokens}
[bold]Default Online:[/] {"on" if self.settings.default_online_mode else "off"}
[bold]Log Level:[/] {self.settings.log_level}
        config_lines.append("")

[dim]Use /config [setting] [value] to modify settings[/]
"""
        # Other settings
        config_lines.append("[bold]Streaming:[/] " + ("on" if self.settings.stream_enabled else "off"))
        config_lines.append(f"[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}")
        config_lines.append(f"[bold]Max Tokens:[/] {self.settings.max_tokens}")
        config_lines.append(f"[bold]Default Online:[/] " + ("on" if self.settings.default_online_mode else "off"))
        config_lines.append(f"[bold]Log Level:[/] {self.settings.log_level}")
        config_lines.append("\n[dim]Use /config [setting] [value] to modify settings[/]")

        return "\n".join(config_lines)

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
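For reference, the `mask_key` helper introduced above only ever exposes the last four characters of a key; a minimal standalone sketch (the sample key is made up for illustration):

```python
# Sketch of the mask_key behavior used by the config screen; sample key is fake.
def mask_key(key):
    if not key:
        return "[red]Not set[/]"
    return "***" + key[-4:] if len(key) > 4 else "***"

print(mask_key("sk-or-v1-abcd1234"))  # -> ***1234 (only the last 4 chars shown)
print(mask_key("abc"))                # -> ***     (too short to show a suffix)
print(mask_key(None))                 # -> [red]Not set[/]
```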
@@ -83,37 +83,70 @@ class CreditsScreen(ModalScreen[None]):
    def _get_credits_text(self) -> str:
        """Generate the credits text."""
        if not self.credits_data:
            return "[yellow]No credit information available[/]"
            # Provider-specific message when credits aren't available
            if self.client.provider_name == "anthropic":
                return """[yellow]Credit information not available via API[/]

        total = self.credits_data.get("total_credits", 0)
        used = self.credits_data.get("used_credits", 0)
Anthropic does not provide a public API endpoint for checking
account credits.

To view your account balance and usage:
[cyan]→ Visit console.anthropic.com[/]
[cyan]→ Navigate to Settings → Billing[/]
"""
            elif self.client.provider_name == "openai":
                return """[yellow]Credit information not available via API[/]

OpenAI does not provide credit balance through their API.

To view your account usage:
[cyan]→ Visit platform.openai.com[/]
[cyan]→ Navigate to Usage[/]
"""
            elif self.client.provider_name == "ollama":
                return """[yellow]Credit information not applicable[/]

Ollama runs locally on your machine and does not use credits.
"""
            else:
                return "[yellow]No credit information available[/]"

        provider_name = self.client.provider_name.upper()
        total = self.credits_data.get("total_credits")
        used = self.credits_data.get("used_credits")
        remaining = self.credits_data.get("credits_left", 0)

        # Calculate percentage used
        if total > 0:
            percent_used = (used / total) * 100
            percent_remaining = (remaining / total) * 100
        else:
            percent_used = 0
            percent_remaining = 0

        # Color code based on remaining credits
        if percent_remaining > 50:
        # Determine color based on absolute remaining amount
        if remaining > 10:
            remaining_color = "green"
        elif percent_remaining > 20:
        elif remaining > 2:
            remaining_color = "yellow"
        else:
            remaining_color = "red"

        return f"""
[bold cyan]═══ OPENROUTER CREDITS ═══[/]
        lines = [f"[bold cyan]═══ {provider_name} CREDITS ═══[/]\n"]

[bold]Total Credits:[/] ${total:.2f}
[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]
[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]
        # If we have total/used info (OpenRouter)
        if total is not None and used is not None:
            percent_used = (used / total) * 100 if total > 0 else 0
            percent_remaining = (remaining / total) * 100 if total > 0 else 0

[dim]Visit openrouter.ai to add more credits[/]
"""
            lines.append(f"[bold]Total Credits:[/] ${total:.2f}")
            lines.append(f"[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]")
            lines.append(f"[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]")
        else:
            # Anthropic - only shows balance
            lines.append(f"[bold]Current Balance:[/] [{remaining_color}]${remaining:.2f}[/]")

        lines.append("")

        # Provider-specific footer
        if self.client.provider_name == "openrouter":
            lines.append("[dim]Visit openrouter.ai to add more credits[/]")
        elif self.client.provider_name == "anthropic":
            lines.append("[dim]Visit console.anthropic.com to manage billing[/]")

        return "\n".join(lines)

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
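Note that the color thresholds now key off absolute dollars remaining rather than percentages. A small standalone sketch of that mapping, with illustrative balances:

```python
# Sketch of the remaining-credit color logic above (absolute dollars, not percent).
def remaining_color(remaining: float) -> str:
    if remaining > 10:
        return "green"
    elif remaining > 2:
        return "yellow"
    return "red"

assert remaining_color(25.00) == "green"   # plenty left
assert remaining_color(5.50) == "yellow"   # getting low
assert remaining_color(0.75) == "red"      # nearly out
```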
@@ -17,9 +17,11 @@ Header {
ChatDisplay {
    background: $background;
    border: none;
    padding: 1;
    padding: 1 0 1 1;  /* top right bottom left - no right padding for scrollbar */
    scrollbar-background: $background;
    scrollbar-color: $primary;
    scrollbar-size: 1 1;
    scrollbar-gutter: stable;  /* Reserve space for scrollbar */
    overflow-y: auto;
}

@@ -43,7 +45,7 @@ SystemMessageWidget {
AssistantMessageWidget {
    margin: 0 0 1 0;
    padding: 1;
    background: $panel;
    background: $background;
    border-left: thick $accent;
    height: auto;
}

@@ -59,6 +61,9 @@ AssistantMessageWidget {
    color: #cccccc;
    link-color: #888888;
    link-style: none;
    border: none;
    scrollbar-background: transparent;
    scrollbar-color: #555555;
}

InputBar {
@@ -59,9 +59,15 @@ class CommandDropdown(VerticalScroll):
        # Get base commands with descriptions
        base_commands = [
            ("/help", "Show help screen"),
            ("/commands", "Show all commands"),
            ("/model", "Select AI model"),
            ("/provider", "Switch AI provider"),
            ("/provider openrouter", "Switch to OpenRouter"),
            ("/provider anthropic", "Switch to Anthropic (Claude)"),
            ("/provider openai", "Switch to OpenAI (GPT)"),
            ("/provider ollama", "Switch to Ollama (local)"),
            ("/stats", "Show session statistics"),
            ("/credits", "Check account credits"),
            ("/credits", "Check account credits or view console link"),
            ("/clear", "Clear chat display"),
            ("/reset", "Reset conversation history"),
            ("/memory on", "Enable conversation memory"),
@@ -78,7 +84,19 @@ class CommandDropdown(VerticalScroll):
            ("/prev", "Show previous message"),
            ("/next", "Show next message"),
            ("/config", "View configuration"),
            ("/config api", "Set API key"),
            ("/config provider", "Set default provider"),
            ("/config search_provider", "Set search provider (anthropic_native/duckduckgo/google)"),
            ("/config openrouter_api_key", "Set OpenRouter API key"),
            ("/config anthropic_api_key", "Set Anthropic API key"),
            ("/config openai_api_key", "Set OpenAI API key"),
            ("/config ollama_base_url", "Set Ollama server URL"),
            ("/config google_api_key", "Set Google API key"),
            ("/config google_search_engine_id", "Set Google Search Engine ID"),
            ("/config online", "Set default online mode (on/off)"),
            ("/config stream", "Toggle streaming (on/off)"),
            ("/config model", "Set default model"),
            ("/config system", "Set system prompt"),
            ("/config maxtoken", "Set token limit"),
            ("/system", "Set system prompt"),
            ("/maxtoken", "Set token limit"),
            ("/retry", "Retry last prompt"),
@@ -117,12 +135,30 @@ class CommandDropdown(VerticalScroll):
        # Remove the leading slash for filtering
        filter_without_slash = filter_text[1:].lower()

        # Filter commands - show if filter text is contained anywhere in the command
        # Filter commands
        if filter_without_slash:
            matching = [
                (cmd, desc) for cmd, desc in self._all_commands
                if filter_without_slash in cmd[1:].lower()  # Skip the / in command for matching
            ]
            matching = []

            # Check if user typed a parent command (e.g., "/config")
            # If so, show all sub-commands that start with it
            parent_command = "/" + filter_without_slash
            has_subcommands = any(
                cmd.startswith(parent_command + " ")
                for cmd, _ in self._all_commands
            )

            if has_subcommands and parent_command in [cmd for cmd, _ in self._all_commands]:
                # Show the parent command and all its sub-commands
                matching = [
                    (cmd, desc) for cmd, desc in self._all_commands
                    if cmd == parent_command or cmd.startswith(parent_command + " ")
                ]
            else:
                # Regular filtering - show if filter text is contained anywhere in the command
                matching = [
                    (cmd, desc) for cmd, desc in self._all_commands
                    if filter_without_slash in cmd[1:].lower()  # Skip the / in command for matching
                ]
        else:
            # Show all commands when just "/" is typed
            matching = self._all_commands
@@ -131,8 +167,8 @@ class CommandDropdown(VerticalScroll):
            self.remove_class("visible")
            return

        # Add options - limit to 10 results
        for cmd, desc in matching[:10]:
        # Add options - limit to 15 results (increased from 10 for sub-commands)
        for cmd, desc in matching[:15]:
            # Format: command in white, description in gray, separated by spaces
            label = f"{cmd} [dim]{desc}[/]" if desc else cmd
            option_list.add_option(Option(label, id=cmd))
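To see what the parent-command expansion buys over plain substring matching, here is a standalone sketch of the filter over an abbreviated command list (trimmed from the one above):

```python
# Standalone sketch of the CommandDropdown filtering, abbreviated command list.
commands = [
    ("/config", "View configuration"),
    ("/config api", "Set API key"),
    ("/config model", "Set default model"),
    ("/model", "Select AI model"),
]

def match(filter_text: str):
    f = filter_text[1:].lower()
    parent = "/" + f
    has_subs = any(cmd.startswith(parent + " ") for cmd, _ in commands)
    if has_subs and parent in [cmd for cmd, _ in commands]:
        # Typing "/config" lists /config plus every "/config ..." sub-command
        return [c for c in commands if c[0] == parent or c[0].startswith(parent + " ")]
    # Otherwise substring matching, e.g. "/mod" matches /model and /config model
    return [c for c in commands if f in c[0][1:].lower()]

print([c for c, _ in match("/config")])  # ['/config', '/config api', '/config model']
print([c for c, _ in match("/mod")])     # ['/config model', '/model']
```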
@@ -8,11 +8,12 @@ from typing import Optional, Dict, Any
class Header(Static):
    """Header displaying app title, version, current model, and capabilities."""

    def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None):
    def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None, provider: str = ""):
        super().__init__()
        self.version = version
        self.model = model
        self.model_info = model_info or {}
        self.provider = provider

    def compose(self) -> ComposeResult:
        """Compose the header."""
@@ -51,15 +52,32 @@ class Header(Static):

    def _format_header(self) -> str:
        """Format the header text."""
        model_text = f" | {self.model}" if self.model else ""
        # Show provider : model format
        if self.provider and self.model:
            provider_model = f"[bold cyan]{self.provider}[/] [dim]:[/] [bold]{self.model}[/]"
        elif self.provider:
            provider_model = f"[bold cyan]{self.provider}[/]"
        elif self.model:
            provider_model = f"[bold]{self.model}[/]"
        else:
            provider_model = ""

        capabilities = self._format_capabilities()
        capabilities_text = f" {capabilities}" if capabilities else ""
        return f"[bold cyan]oAI[/] [dim]v{self.version}[/]{model_text}{capabilities_text}"

    def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None) -> None:
        # Format: oAI v{version} | provider : model capabilities
        version_text = f"[bold cyan]oAI[/] [dim]v{self.version}[/]"
        if provider_model:
            return f"{version_text} [dim]|[/] {provider_model}{capabilities_text}"
        else:
            return version_text

    def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None, provider: Optional[str] = None) -> None:
        """Update the displayed model and capabilities."""
        self.model = model
        if model_info:
            self.model_info = model_info
        if provider is not None:
            self.provider = provider
        content = self.query_one("#header-content", Static)
        content.update(self._format_header())
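Stripped of Rich markup, the new header renders as `oAI v{version} | provider : model` followed by capability icons; a toy reproduction (all values illustrative):

```python
# Toy reproduction of the new header format; values are illustrative.
version, provider, model, caps = "3.0.1", "anthropic", "claude-sonnet-4", "🔧 📷"
provider_model = f"{provider} : {model}" if provider and model else provider or model
header = f"oAI v{version} | {provider_model} {caps}" if provider_model else f"oAI v{version}"
print(header)  # oAI v3.0.1 | anthropic : claude-sonnet-4 🔧 📷
```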
@@ -53,10 +53,11 @@ class SystemMessageWidget(Static):
class AssistantMessageWidget(Static):
    """Widget for displaying assistant responses with streaming support."""

    def __init__(self, model_name: str = "Assistant"):
    def __init__(self, model_name: str = "Assistant", chat_display=None):
        super().__init__()
        self.model_name = model_name
        self.full_text = ""
        self.chat_display = chat_display

    def compose(self) -> ComposeResult:
        """Compose the assistant message."""
@@ -77,6 +78,11 @@ class AssistantMessageWidget(Static):
        md = Markdown(self.full_text, code_theme="github-dark", inline_code_theme="github-dark")
        log.write(md)

        # Auto-scroll to keep the latest content visible
        # Use call_after_refresh to ensure scroll happens after layout update
        if self.chat_display:
            self.chat_display.call_after_refresh(self.chat_display.scroll_end, animate=False)

        if hasattr(chunk, "usage") and chunk.usage:
            usage = chunk.usage
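The `call_after_refresh` dance above is the usual Textual way to defer a scroll until after the next layout pass (scrolling immediately would use stale geometry). A minimal self-contained sketch of the same pattern; the app and widget names here are hypothetical, not oAI's actual classes:

```python
# Minimal Textual sketch of the defer-scroll-until-after-layout pattern.
from textual.app import App, ComposeResult
from textual.containers import VerticalScroll
from textual.widgets import Static

class StreamDemo(App):
    def compose(self) -> ComposeResult:
        yield VerticalScroll(id="chat")

    def on_mount(self) -> None:
        chat = self.query_one("#chat", VerticalScroll)
        chat.mount(Static("new chunk of streamed text"))
        # Defer the scroll so it runs after the mount has been laid out.
        chat.call_after_refresh(chat.scroll_end, animate=False)

if __name__ == "__main__":
    StreamDemo().run()
```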
oai/utils/web_search.py (new file, 247 lines)
@@ -0,0 +1,247 @@
"""
Web search utilities for oAI.

Provides web search capabilities for all providers (not just OpenRouter).
Uses DuckDuckGo by default (no API key needed).
"""

import json
import re
from typing import Dict, List, Optional
from urllib.parse import quote_plus

import requests

from oai.utils.logging import get_logger

logger = get_logger()


class WebSearchResult:
    """Container for a single search result."""

    def __init__(self, title: str, url: str, snippet: str):
        self.title = title
        self.url = url
        self.snippet = snippet

    def __repr__(self) -> str:
        return f"WebSearchResult(title='{self.title}', url='{self.url}')"


class WebSearchProvider:
    """Base class for web search providers."""

    def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
        """
        Perform a web search.

        Args:
            query: Search query
            num_results: Number of results to return

        Returns:
            List of search results
        """
        raise NotImplementedError


class DuckDuckGoSearch(WebSearchProvider):
    """DuckDuckGo search provider (no API key needed)."""

    def __init__(self):
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36'
        })

    def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
        """
        Search using DuckDuckGo HTML interface.

        Args:
            query: Search query
            num_results: Number of results to return (default: 5)

        Returns:
            List of search results
        """
        try:
            # Use DuckDuckGo HTML search
            url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"
            response = self.session.get(url, timeout=10)
            response.raise_for_status()

            results = []
            html = response.text

            # Parse results using regex (simple HTML parsing)
            # Find all result blocks - they end at next result or end of results section
            result_blocks = re.findall(
                r'<div class="result results_links.*?(?=<div class="result results_links|<div id="links")',
                html,
                re.DOTALL
            )

            for block in result_blocks[:num_results]:
                # Extract title and URL - look for result__a class
                title_match = re.search(r'<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>([^<]+)</a>', block)
                # Extract snippet - look for result__snippet class
                snippet_match = re.search(r'<a[^>]*class="result__snippet"[^>]*>([^<]+)</a>', block)

                if title_match:
                    url_raw = title_match.group(1)
                    title = title_match.group(2).strip()

                    # Decode HTML entities in title
                    import html as html_module
                    title = html_module.unescape(title)

                    snippet = ""
                    if snippet_match:
                        snippet = snippet_match.group(1).strip()
                        snippet = html_module.unescape(snippet)

                    # Clean up URL (DDG uses redirect links)
                    if 'uddg=' in url_raw:
                        # Extract actual URL from redirect
                        actual_url_match = re.search(r'uddg=([^&]+)', url_raw)
                        if actual_url_match:
                            from urllib.parse import unquote
                            url_raw = unquote(actual_url_match.group(1))

                    results.append(WebSearchResult(
                        title=title,
                        url=url_raw,
                        snippet=snippet
                    ))

            logger.info(f"DuckDuckGo search: found {len(results)} results for '{query}'")
            return results

        except requests.RequestException as e:
            logger.error(f"DuckDuckGo search failed: {e}")
            return []
        except Exception as e:
            logger.error(f"Error parsing DuckDuckGo results: {e}")
            return []


class GoogleCustomSearch(WebSearchProvider):
    """Google Custom Search API provider (requires API key)."""

    def __init__(self, api_key: str, search_engine_id: str):
        """
        Initialize Google Custom Search.

        Args:
            api_key: Google API key
            search_engine_id: Custom Search Engine ID
        """
        self.api_key = api_key
        self.search_engine_id = search_engine_id

    def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
        """
        Search using Google Custom Search API.

        Args:
            query: Search query
            num_results: Number of results to return

        Returns:
            List of search results
        """
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': self.api_key,
                'cx': self.search_engine_id,
                'q': query,
                'num': min(num_results, 10)  # Google allows max 10
            }

            response = requests.get(url, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()

            results = []
            for item in data.get('items', []):
                results.append(WebSearchResult(
                    title=item.get('title', ''),
                    url=item.get('link', ''),
                    snippet=item.get('snippet', '')
                ))

            logger.info(f"Google Custom Search: found {len(results)} results for '{query}'")
            return results

        except requests.RequestException as e:
            logger.error(f"Google Custom Search failed: {e}")
            return []


def perform_web_search(
    query: str,
    num_results: int = 5,
    provider: str = "duckduckgo",
    **kwargs
) -> List[WebSearchResult]:
    """
    Perform a web search using the specified provider.

    Args:
        query: Search query
        num_results: Number of results to return (default: 5)
        provider: Search provider ("duckduckgo" or "google")
        **kwargs: Provider-specific arguments (e.g., api_key for Google)

    Returns:
        List of search results
    """
    if provider == "google":
        api_key = kwargs.get("google_api_key")
        search_engine_id = kwargs.get("google_search_engine_id")
        if not api_key or not search_engine_id:
            logger.warning("Google search requires api_key and search_engine_id, falling back to DuckDuckGo")
            provider = "duckduckgo"

    if provider == "google":
        search_provider = GoogleCustomSearch(api_key, search_engine_id)
    else:
        search_provider = DuckDuckGoSearch()

    return search_provider.search(query, num_results)


def format_search_results(results: List[WebSearchResult], max_length: int = 2000) -> str:
    """
    Format search results for inclusion in AI prompt.

    Args:
        results: List of search results
        max_length: Maximum total length of formatted results

    Returns:
        Formatted string with search results
    """
    if not results:
        return "No search results found."

    formatted = "**Web Search Results:**\n\n"

    for i, result in enumerate(results, 1):
        result_text = f"{i}. **{result.title}**\n"
        result_text += f"   URL: {result.url}\n"
        if result.snippet:
            result_text += f"   {result.snippet}\n"
        result_text += "\n"

        # Check if adding this result would exceed max_length
        if len(formatted) + len(result_text) > max_length:
            formatted += f"... ({len(results) - i + 1} more results truncated)\n"
            break

        formatted += result_text

    return formatted.strip()
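Taken together, the module is meant to be called roughly like this (the query string and placeholder keys are illustrative; DuckDuckGo needs no configuration, Google needs both kwargs or it falls back):

```python
# Illustrative use of the new web_search module; query and keys are made up.
from oai.utils.web_search import perform_web_search, format_search_results

# Free DuckDuckGo search - no configuration required
results = perform_web_search("textual tui framework", num_results=3)

# Google Custom Search - falls back to DuckDuckGo if either kwarg is missing
results = perform_web_search(
    "textual tui framework",
    num_results=3,
    provider="google",
    google_api_key="YOUR_KEY",
    google_search_engine_id="YOUR_ID",
)

# Inject into the prompt; output is capped at max_length characters
print(format_search_results(results, max_length=2000))
```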
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"

[project]
name = "oai"
version = "3.0.0-b2"  # MUST match oai/__init__.py __version__
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
version = "3.0.0-b3"  # MUST match oai/__init__.py __version__
description = "Open AI Chat Client - Multi-provider terminal chat with MCP support"
readme = "README.md"
license = {text = "MIT"}
authors = [
@@ -39,9 +39,11 @@ classifiers = [
requires-python = ">=3.10"
dependencies = [
    "anyio>=4.0.0",
    "anthropic>=0.40.0",
    "click>=8.0.0",
    "httpx>=0.24.0",
    "markdown-it-py>=3.0.0",
    "openai>=1.59.0",
    "openrouter>=0.0.19",
    "packaging>=21.0",
    "pyperclip>=1.8.0",
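Since the version comment insists on staying in sync with `oai/__init__.py`, a tiny guard test could enforce it. This is a sketch; it assumes Python 3.11+ for `tomllib` and that the package exposes `__version__`:

```python
# Sketch of a version-sync guard; assumes Python 3.11+ (tomllib)
# and that oai/__init__.py exposes __version__.
import tomllib
import oai

def test_version_matches_pyproject():
    with open("pyproject.toml", "rb") as f:
        pyproject = tomllib.load(f)
    assert pyproject["project"]["version"] == oai.__version__
```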