4 Commits
v2.1...3.0

42 changed files with 6884 additions and 1937 deletions

.gitignore (vendored)

@@ -45,3 +45,4 @@ b0.sh
 requirements.txt
 system_prompt.txt
 CLAUDE*
+SESSION*_COMPLETE.md

README.md

@@ -1,19 +1,25 @@
-# oAI - OpenRouter AI Chat Client
+# oAI - Open AI Chat Client

-A powerful, extensible terminal-based chat client for the OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
+A powerful, modern **Textual TUI** chat client with **multi-provider support** (OpenRouter, Anthropic, OpenAI, Ollama) and **MCP (Model Context Protocol)** integration, enabling AI to access local files and query SQLite databases.

## Features

### Core Features

-- 🤖 Interactive chat with 300+ AI models via OpenRouter
-- 🔍 Model selection with search and filtering
+- 🖥️ **Modern Textual TUI** with async streaming and a beautiful interface
+- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local)
+- 🤖 Interactive chat with 300+ AI models across providers
+- 🔍 Model selection with search, filtering, and capability icons
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
- 📎 File attachments (images, PDFs, code files)
-- 💰 Real-time cost tracking and credit monitoring
+- 💰 Real-time cost tracking and credit monitoring (OpenRouter)
-- 🎨 Rich terminal UI with syntax highlighting
+- 🎨 Dark theme with syntax highlighting and Markdown rendering
-- 📝 Persistent command history with search (Ctrl+R)
+- 📝 Command history navigation (Up/Down arrows)
-- 🌐 Online mode (web search capabilities)
+- 🌐 **Universal Online Mode** - Web search for ALL providers:
+  - **Anthropic Native** - Built-in search with automatic citations ($0.01/search)
+  - **DuckDuckGo** - Free web scraping (all providers)
+  - **Google Custom Search** - Premium search option
- 🧠 Conversation memory toggle
+- ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats)

### MCP Integration

- 🔧 **File Mode**: AI can read, search, and list local files
@@ -34,30 +40,23 @@ A powerful, extensible terminal-based chat client for OpenRouter API with **MCP

## Requirements

- Python 3.10-3.13
-- OpenRouter API key ([get one here](https://openrouter.ai))
+- API key for your chosen provider:
+  - **OpenRouter**: [openrouter.ai](https://openrouter.ai)
+  - **Anthropic**: [console.anthropic.com](https://console.anthropic.com)
+  - **OpenAI**: [platform.openai.com](https://platform.openai.com)
+  - **Ollama**: No API key needed (local server)

## Installation

-### Option 1: Install from Source (Recommended)
-
-```bash
-# Clone the repository
-git clone https://gitlab.pm/rune/oai.git
-cd oai
-
-# Install with pip
-pip install -e .
-```
-
-### Option 2: Pre-built Binary (macOS/Linux)
+### Option 1: Pre-built Binary (macOS/Linux, Recommended)

Download from [Releases](https://gitlab.pm/rune/oai/releases):

-- **macOS (Apple Silicon)**: `oai_v2.1.0_mac_arm64.zip`
+- **macOS (Apple Silicon)**: `oai_v3.0.0_mac_arm64.zip`
-- **Linux (x86_64)**: `oai_v2.1.0_linux_x86_64.zip`
+- **Linux (x86_64)**: `oai_v3.0.0_linux_x86_64.zip`

```bash
# Extract and install
-unzip oai_v2.1.0_*.zip
+unzip oai_v3.0.0_*.zip
mkdir -p ~/.local/bin
mv oai ~/.local/bin/
@@ -73,29 +72,130 @@ xattr -cr ~/.local/bin/oai

export PATH="$HOME/.local/bin:$PATH"
```

+### Option 2: Install from Source
+
+```bash
+# Clone the repository
+git clone https://gitlab.pm/rune/oai.git
+cd oai
+
+# Install with pip
+pip install -e .
+```
+
## Quick Start

```bash
-# Start the chat client
-oai chat
+# Start oAI (launches the TUI)
+oai
+
+# Start with a specific provider
+oai --provider anthropic
+oai --provider openai
+oai --provider ollama

# Or with options
-oai chat --model gpt-4o --mcp
+oai --provider openrouter --model gpt-4o --online --mcp
+
+# Show version
+oai version
```

-On first run, you'll be prompted for your OpenRouter API key.
+On first run, you'll be prompted for your API key. Configure additional providers anytime with `/config`.

+### Enable Web Search (All Providers)
+
+```bash
+# Using Anthropic native search (best quality, automatic citations)
+/config search_provider anthropic_native
+/online on
+
+# Using free DuckDuckGo (works with all providers)
+/config search_provider duckduckgo
+/online on
+
+# Using Google Custom Search (requires an API key)
+/config search_provider google
+/config google_api_key YOUR_KEY
+/config google_search_engine_id YOUR_ID
+/online on
+```
+
### Basic Commands

```bash
-# In the chat interface:
-/model    # Select AI model
-/help     # Show all commands
-/stats    # View session statistics
-exit      # Quit
+# In the TUI interface:
+/provider            # Show current provider or switch
+/provider anthropic  # Switch to Anthropic (Claude)
+/provider openai     # Switch to OpenAI (ChatGPT)
+/provider ollama     # Switch to Ollama (local)
+/model               # Select AI model (or press F2)
+/online on           # Enable web search
+/help                # Show all commands (or press F1)
/mcp on              # Enable file/database access
+/stats               # View session statistics (or press Ctrl+S)
+/config              # View configuration settings
+/credits             # Check account credits (shows API balance or console link)
+Ctrl+Q               # Quit
```
+## Web Search
+
+oAI provides universal web search for all AI providers, with three options:
+
+### Anthropic Native Search (Recommended for Anthropic)
+
+Anthropic's built-in web search API with automatic citations:
+
+```bash
+/config search_provider anthropic_native
+/online on
+
+# Now ask questions requiring current information
+What are the latest developments in quantum computing?
+```
+
+**Features:**
+- **Automatic citations** - Claude cites its sources
+- **Smart searching** - Claude decides when to search
+- **Progressive searches** - Multiple searches for complex queries
+- **Best quality** - Professional-grade results
+
+**Pricing:** $10 per 1,000 searches ($0.01 per search) + token costs
+
+**Note:** Only works with the Anthropic provider (Claude models)
+
+### DuckDuckGo Search (Default - Free)
+
+Free web scraping that works with ALL providers:
+
+```bash
+/config search_provider duckduckgo   # Default
+/online on
+
+# Works with Anthropic, OpenAI, Ollama, and OpenRouter
+```
+
+**Features:**
+- **Free** - No API key or costs
+- **Universal** - Works with all providers
+- **Privacy-friendly** - Uses DuckDuckGo
+
+### Google Custom Search (Premium Option)
+
+Google's Custom Search API for high-quality results:
+
+```bash
+/config search_provider google
+/config google_api_key YOUR_GOOGLE_API_KEY
+/config google_search_engine_id YOUR_SEARCH_ENGINE_ID
+/online on
+```
+
+Get your API key: [Google Custom Search API](https://developers.google.com/custom-search/v1/overview)
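
The three options above all feed the same injection path. Below is a minimal, self-contained sketch of what a provider-agnostic `perform_web_search()` helper can look like; the name and parameters mirror the signature visible in `oai/core/session.py` later in this changeset, but the request details and response parsing here are assumptions, not the shipped implementation:

```python
# Sketch of a provider-agnostic search helper (illustrative, not oAI's actual code).
from typing import Optional

import requests

def perform_web_search(
    query: str,
    num_results: int = 5,
    provider: str = "duckduckgo",
    google_api_key: Optional[str] = None,
    google_search_engine_id: Optional[str] = None,
) -> list[dict]:
    """Return a list of {"title", "url", "snippet"} dicts from the chosen engine."""
    if provider == "google":
        # Google Custom Search JSON API (requires a key and an engine ID)
        resp = requests.get(
            "https://www.googleapis.com/customsearch/v1",
            params={
                "key": google_api_key,
                "cx": google_search_engine_id,
                "q": query,
                "num": num_results,
            },
            timeout=10,
        )
        resp.raise_for_status()
        return [
            {"title": i.get("title", ""), "url": i.get("link", ""), "snippet": i.get("snippet", "")}
            for i in resp.json().get("items", [])[:num_results]
        ]
    # Default: DuckDuckGo's free Instant Answer endpoint, no API key required
    resp = requests.get(
        "https://api.duckduckgo.com/",
        params={"q": query, "format": "json", "no_html": 1},
        timeout=10,
    )
    resp.raise_for_status()
    topics = resp.json().get("RelatedTopics", [])
    return [
        {"title": t.get("Text", ""), "url": t.get("FirstURL", ""), "snippet": t.get("Text", "")}
        for t in topics
        if isinstance(t, dict) and t.get("FirstURL")
    ][:num_results]
```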
## MCP (Model Context Protocol)

MCP allows the AI to interact with your local files and databases.
@@ -175,30 +275,39 @@ MCP allows the AI to interact with your local files and databases.
| Command | Description |
|---------|-------------|
| `/config` | View settings |
-| `/config api` | Set API key |
+| `/config provider <name>` | Set default provider |
+| `/config openrouter_api_key` | Set OpenRouter API key |
+| `/config anthropic_api_key` | Set Anthropic API key |
+| `/config openai_api_key` | Set OpenAI API key |
+| `/config ollama_base_url` | Set Ollama server URL |
+| `/config search_provider <provider>` | Set search provider (anthropic_native/duckduckgo/google) |
+| `/config google_api_key` | Set Google API key (for Google search) |
+| `/config online on\|off` | Set default online mode |
| `/config model <id>` | Set default model |
| `/config stream on\|off` | Toggle streaming |
| `/stats` | Session statistics |
-| `/credits` | Check credits |
+| `/credits` | Check credits (OpenRouter) |
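
For example, a first-time multi-provider setup from inside the TUI might look like this (key values are placeholders):

```bash
# Store keys for two providers, then make Anthropic the default
/config openrouter_api_key sk-or-XXXX
/config anthropic_api_key sk-ant-XXXX
/config provider anthropic

# Route online mode through the free search backend by default
/config search_provider duckduckgo
/config online on
```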
## CLI Options

```bash
-oai chat [OPTIONS]
+oai [OPTIONS]

Options:
+  -p, --provider TEXT  Provider to use (openrouter/anthropic/openai/ollama)
  -m, --model TEXT     Model ID to use
  -s, --system TEXT    System prompt
-  -o, --online         Enable online mode
+  -o, --online         Enable online mode (web search)
  --mcp                Enable MCP server
+  -v, --version        Show version
  --help               Show help
```

-Other commands:
+Commands:

```bash
-oai config [setting] [value]  # Configure settings
-oai version                   # Show version
-oai credits                   # Check credits
+oai          # Launch TUI (default)
+oai version  # Show version information
+oai --help   # Show help message
```
## Configuration

@@ -218,14 +327,18 @@ oai/
├── oai/
│   ├── __init__.py
│   ├── __main__.py      # Entry point for python -m oai
-│   ├── cli.py           # Main CLI interface
+│   ├── cli.py           # Main CLI entry point
│   ├── constants.py     # Configuration constants
│   ├── commands/        # Slash command handlers
│   ├── config/          # Settings and database
│   ├── core/            # Chat client and session
│   ├── mcp/             # MCP server and tools
│   ├── providers/       # AI provider abstraction
-│   ├── ui/              # Terminal UI utilities
+│   ├── tui/             # Textual TUI interface
+│   │   ├── app.py       # Main TUI application
+│   │   ├── widgets/     # Custom widgets
+│   │   ├── screens/     # Modal screens
+│   │   └── styles.tcss  # TUI styling
│   └── utils/           # Logging, export, etc.
├── pyproject.toml       # Package configuration
├── build.sh             # Binary build script
@@ -266,7 +379,34 @@ pip install -e . --force-reinstall

## Version History

-### v2.1.0 (Current)
+### v3.0.0-b3 (Current - Beta 3)
+- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local)
+- 🔌 **Provider Switching** - Switch between providers mid-session with `/provider`
+- 🧠 **Provider Model Memory** - Remembers the last used model per provider
+- ⚙️ **Provider Configuration** - Separate API keys for each provider
+- 🌐 **Universal Web Search** - Three search options for all providers:
+  - **Anthropic Native** - Built-in search with citations ($0.01/search)
+  - **DuckDuckGo** - Free web scraping (default)
+  - **Google Custom Search** - Premium option
+- 🎨 **Enhanced Header** - Shows current provider and model
+- 📊 **Updated Config Screen** - Shows provider-specific settings
+- 🔧 **Command Dropdown** - Added `/provider` commands for discoverability
+- 🎯 **Auto-Scrolling** - Chat window auto-scrolls during streaming responses
+- 🎨 **Refined UI** - Thinner scrollbar, cleaner styling
+- 🔧 **Updated Models** - Claude 4.5 (Sonnet, Haiku, Opus)
+
+### v3.0.0 (Beta)
+- 🎨 **Complete migration to Textual TUI** - Modern async terminal interface
+- 🗑️ **Removed CLI interface** - TUI-only for a cleaner codebase (11.6% smaller)
+- 🖱️ **Modal screens** - Help, stats, config, credits, model selector
+- ⌨️ **Keyboard shortcuts** - F1 (help), F2 (models), Ctrl+S (stats), etc.
+- 🎯 **Capability indicators** - Visual icons for model features (vision, tools, online)
+- 🎨 **Consistent dark theme** - Professional styling throughout
+- 📊 **Enhanced model selector** - Search, filter, capability columns
+- 🚀 **Default command** - Just run `oai` to launch the TUI
+- 🧹 **Code cleanup** - Removed 1,300+ lines of CLI code
+
+### v2.1.0
- 🏗️ Complete codebase refactoring to modular package structure
- 🔌 Extensible provider architecture for adding new AI providers
- 📦 Proper Python packaging with pyproject.toml


@@ -1,16 +1,16 @@
""" """
oAI - OpenRouter AI Chat Client oAI - Open AI Chat Client
A feature-rich terminal-based chat application that provides an interactive CLI A feature-rich terminal-based chat application with multi-provider support
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP) (OpenRouter, Anthropic, OpenAI, Ollama) and advanced Model Context Protocol (MCP)
integration for filesystem and database access. integration for filesystem and database access.
Author: Rune Author: Rune
License: MIT License: MIT
""" """
__version__ = "2.1.0" __version__ = "3.0.0-b3"
__author__ = "Rune" __author__ = "Rune Olsen"
__license__ = "MIT" __license__ = "MIT"
# Lazy imports to avoid circular dependencies and improve startup time # Lazy imports to avoid circular dependencies and improve startup time
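
The lazy-import comment refers to the module-level `__getattr__` hook (PEP 562). A minimal sketch of that pattern, with `AIClient` and `ChatSession` chosen here because they appear elsewhere in this changeset (the actual set of lazily exported names is an assumption):

```python
# Sketch: PEP 562 lazy re-exports; imports only happen on first attribute access.
def __getattr__(name: str):
    if name == "AIClient":
        from oai.core.client import AIClient
        return AIClient
    if name == "ChatSession":
        from oai.core.session import ChatSession
        return ChatSession
    raise AttributeError(f"module 'oai' has no attribute {name!r}")
```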


@@ -1,55 +1,27 @@
""" """
Main CLI entry point for oAI. Main CLI entry point for oAI.
This module provides the command-line interface for the oAI chat application. This module provides the command-line interface for the oAI TUI application.
""" """
import os
import sys import sys
from pathlib import Path
from typing import Optional from typing import Optional
import typer import typer
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.history import FileHistory
from rich.markdown import Markdown
from rich.panel import Panel
from oai import __version__ from oai import __version__
from oai.commands import register_all_commands, registry from oai.commands import register_all_commands
from oai.commands.registry import CommandContext, CommandStatus
from oai.config.database import Database
from oai.config.settings import Settings from oai.config.settings import Settings
from oai.constants import ( from oai.constants import APP_URL, APP_VERSION
APP_NAME,
APP_URL,
APP_VERSION,
CONFIG_DIR,
HISTORY_FILE,
VALID_COMMANDS,
)
from oai.core.client import AIClient from oai.core.client import AIClient
from oai.core.session import ChatSession from oai.core.session import ChatSession
from oai.mcp.manager import MCPManager from oai.mcp.manager import MCPManager
from oai.providers.base import UsageStats
from oai.providers.openrouter import OpenRouterProvider
from oai.ui.console import (
clear_screen,
console,
display_panel,
print_error,
print_info,
print_success,
print_warning,
)
from oai.ui.tables import create_model_table, display_paginated_table
from oai.utils.logging import LoggingManager, get_logger from oai.utils.logging import LoggingManager, get_logger
# Create Typer app # Create Typer app
app = typer.Typer( app = typer.Typer(
name="oai", name="oai",
help=f"oAI - OpenRouter AI Chat Client\n\nVersion: {APP_VERSION}", help=f"oAI - Open AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
add_completion=False, add_completion=False,
epilog="For more information, visit: " + APP_URL, epilog="For more information, visit: " + APP_URL,
) )
@@ -65,374 +37,151 @@ def main_callback(
help="Show version information", help="Show version information",
is_flag=True, is_flag=True,
), ),
model: Optional[str] = typer.Option(
None,
"--model",
"-m",
help="Model ID to use",
),
system: Optional[str] = typer.Option(
None,
"--system",
"-s",
help="System prompt",
),
online: bool = typer.Option(
False,
"--online",
"-o",
help="Enable online mode",
),
mcp: bool = typer.Option(
False,
"--mcp",
help="Enable MCP server",
),
provider: Optional[str] = typer.Option(
None,
"--provider",
"-p",
help="AI provider to use (openrouter, anthropic, openai, ollama)",
),
) -> None:
-    """Main callback to handle global options."""
-    # Show version with update check if --version flag
+    """Main callback - launches TUI by default."""
    if version_flag:
-        version_info = check_for_updates(APP_VERSION)
-        console.print(version_info)
+        typer.echo(f"oAI version {APP_VERSION}")
        raise typer.Exit()

-    # Show version with update check when --help is requested
-    if "--help" in sys.argv or "-h" in sys.argv:
-        version_info = check_for_updates(APP_VERSION)
-        console.print(f"\n{version_info}\n")
-
-    # Continue to subcommand if provided
+    # If no subcommand provided, launch TUI
    if ctx.invoked_subcommand is None:
-        return
+        _launch_tui(model, system, online, mcp, provider)
+def _launch_tui(
+    model: Optional[str] = None,
+    system: Optional[str] = None,
+    online: bool = False,
+    mcp: bool = False,
+    provider: Optional[str] = None,
-def check_for_updates(current_version: str) -> str:
-    """Check for available updates."""
-    import requests
-    from packaging import version as pkg_version
-    try:
response = requests.get(
"https://gitlab.pm/api/v1/repos/rune/oai/releases/latest",
headers={"Content-Type": "application/json"},
timeout=1.0,
)
response.raise_for_status()
data = response.json()
version_online = data.get("tag_name", "").lstrip("v")
if not version_online:
return f"[bold green]oAI version {current_version}[/]"
current = pkg_version.parse(current_version)
latest = pkg_version.parse(version_online)
if latest > current:
return (
f"[bold green]oAI version {current_version}[/] "
f"[bold red](Update available: {current_version}{version_online})[/]"
)
return f"[bold green]oAI version {current_version} (up to date)[/]"
except Exception:
return f"[bold green]oAI version {current_version}[/]"
def show_welcome(settings: Settings, version_info: str) -> None:
"""Display welcome message."""
console.print(Panel.fit(
f"{version_info}\n\n"
"[bold cyan]Commands:[/] /help for commands, /model to select model\n"
"[bold cyan]MCP:[/] /mcp on to enable file/database access\n"
"[bold cyan]Exit:[/] Type 'exit', 'quit', or 'bye'",
title=f"[bold green]Welcome to {APP_NAME}[/]",
border_style="green",
))
def select_model(client: AIClient, search_term: Optional[str] = None) -> Optional[dict]:
"""Display model selection interface."""
try:
models = client.provider.get_raw_models()
if not models:
print_error("No models available")
return None
# Filter by search term if provided
if search_term:
search_lower = search_term.lower()
models = [m for m in models if search_lower in m.get("id", "").lower()]
if not models:
print_error(f"No models found matching '{search_term}'")
return None
# Create and display table
table = create_model_table(models)
display_paginated_table(
table,
f"[bold green]Available Models ({len(models)})[/]",
)
# Prompt for selection
console.print("")
try:
choice = input("Enter model number (or press Enter to cancel): ").strip()
except (EOFError, KeyboardInterrupt):
return None
if not choice:
return None
try:
index = int(choice) - 1
if 0 <= index < len(models):
selected = models[index]
print_success(f"Selected model: {selected['id']}")
return selected
except ValueError:
pass
print_error("Invalid selection")
return None
except Exception as e:
print_error(f"Failed to fetch models: {e}")
return None
-def run_chat_loop(
-    session: ChatSession,
-    prompt_session: PromptSession,
-    settings: Settings,
) -> None:
-    """Run the main chat loop."""
+    """Launch the Textual TUI interface."""
+    from oai.constants import VALID_PROVIDERS

    # Setup logging
    logging_manager = LoggingManager()
    logging_manager.setup()
    logger = get_logger()

-    mcp_manager = session.mcp_manager
-    while True:
+    # Load settings
settings = Settings.load()
# Determine provider
selected_provider = provider or settings.default_provider
# Validate provider
if selected_provider not in VALID_PROVIDERS:
typer.echo(f"Error: Invalid provider: {selected_provider}", err=True)
typer.echo(f"Valid providers: {', '.join(VALID_PROVIDERS)}", err=True)
raise typer.Exit(1)
# Build provider API keys dict
provider_api_keys = {
"openrouter": settings.openrouter_api_key,
"anthropic": settings.anthropic_api_key,
"openai": settings.openai_api_key,
}
# Check if provider is configured (except Ollama which doesn't need API key)
if selected_provider != "ollama":
if not provider_api_keys.get(selected_provider):
typer.echo(f"Error: No API key configured for {selected_provider}", err=True)
typer.echo(f"Set it with: oai config {selected_provider}_api_key <key>", err=True)
raise typer.Exit(1)
# Initialize client
try:
client = AIClient(
provider_name=selected_provider,
provider_api_keys=provider_api_keys,
ollama_base_url=settings.ollama_base_url,
)
except Exception as e:
typer.echo(f"Error: Failed to initialize client: {e}", err=True)
raise typer.Exit(1)
# Register commands
register_all_commands()
# Initialize MCP manager (always create it, even if not enabled)
mcp_manager = MCPManager()
if mcp:
        try:
+            result = mcp_manager.enable()
+            if result["success"]:
+                logger.info("MCP server enabled in files mode")
-            # Build prompt prefix
-            prefix = "You> "
-            if mcp_manager and mcp_manager.enabled:
if mcp_manager.mode == "files":
if mcp_manager.write_enabled:
prefix = "[🔧✍️ MCP: Files+Write] You> "
else:
prefix = "[🔧 MCP: Files] You> "
elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None:
prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> "
# Get user input
user_input = prompt_session.prompt(
prefix,
auto_suggest=AutoSuggestFromHistory(),
).strip()
if not user_input:
continue
# Handle escape sequence
if user_input.startswith("//"):
user_input = user_input[1:]
# Check for exit
if user_input.lower() in ["exit", "quit", "bye"]:
console.print(
f"\n[bold yellow]Goodbye![/]\n"
f"[dim]Session: {session.stats.total_tokens:,} tokens, "
f"${session.stats.total_cost:.4f}[/]"
)
logger.info(
f"Session ended. Messages: {session.stats.message_count}, "
f"Tokens: {session.stats.total_tokens}, "
f"Cost: ${session.stats.total_cost:.4f}"
)
return
# Check for unknown commands
if user_input.startswith("/"):
cmd_word = user_input.split()[0].lower()
if not registry.is_command(user_input):
# Check if it's a valid command prefix
is_valid = any(cmd_word.startswith(cmd) for cmd in VALID_COMMANDS)
if not is_valid:
print_error(f"Unknown command: {cmd_word}")
print_info("Type /help to see available commands.")
continue
# Try to execute as command
context = session.get_context()
result = registry.execute(user_input, context)
if result:
# Update session state from context
session.memory_enabled = context.memory_enabled
session.memory_start_index = context.memory_start_index
session.online_enabled = context.online_enabled
session.middle_out_enabled = context.middle_out_enabled
session.session_max_token = context.session_max_token
session.current_index = context.current_index
session.system_prompt = context.session_system_prompt
if result.status == CommandStatus.EXIT:
return
# Handle special results
if result.data:
# Retry - resend last prompt
if "retry_prompt" in result.data:
user_input = result.data["retry_prompt"]
# Fall through to send message
# Paste - send clipboard content
elif "paste_prompt" in result.data:
user_input = result.data["paste_prompt"]
# Fall through to send message
# Model selection
elif "show_model_selector" in result.data:
search = result.data.get("search", "")
model = select_model(session.client, search if search else None)
if model:
session.set_model(model)
# If this came from /config model, also save as default
if result.data.get("set_as_default"):
settings.set_default_model(model["id"])
print_success(f"Default model set to: {model['id']}")
continue
# Load conversation
elif "load_conversation" in result.data:
history = result.data.get("history", [])
session.history.clear()
from oai.core.session import HistoryEntry
for entry in history:
session.history.append(HistoryEntry(
prompt=entry.get("prompt", ""),
response=entry.get("response", ""),
prompt_tokens=entry.get("prompt_tokens", 0),
completion_tokens=entry.get("completion_tokens", 0),
msg_cost=entry.get("msg_cost", 0.0),
))
session.current_index = len(session.history) - 1
continue
else:
# Normal command completed
continue
else:
# Command completed with no special data
continue
# Ensure model is selected
if not session.selected_model:
print_warning("Please select a model first with /model")
continue
# Send message
stream = settings.stream_enabled
if mcp_manager and mcp_manager.enabled:
tools = session.get_mcp_tools()
if tools:
stream = False # Disable streaming with tools
if stream:
console.print(
"[bold green]Streaming response...[/] "
"[dim](Press Ctrl+C to cancel)[/]"
)
if session.online_enabled:
console.print("[dim cyan]🌐 Online mode active[/]")
console.print("")
try:
response_text, usage, response_time = session.send_message(
user_input,
stream=stream,
)
except Exception as e:
print_error(f"Error: {e}")
logger.error(f"Message error: {e}")
continue
if not response_text:
print_error("No response received")
continue
# Display non-streaming response
if not stream:
console.print()
display_panel(
Markdown(response_text),
title="[bold green]AI Response[/]",
border_style="green",
)
# Calculate cost and tokens
cost = 0.0
tokens = 0
estimated = False
if usage and (usage.prompt_tokens > 0 or usage.completion_tokens > 0):
tokens = usage.total_tokens
if usage.total_cost_usd:
cost = usage.total_cost_usd
else:
cost = session.client.estimate_cost(
session.selected_model["id"],
usage.prompt_tokens,
usage.completion_tokens,
)
+        else:
+            logger.warning(f"MCP: {result.get('error', 'Failed to enable')}")
+    except Exception as e:
+        logger.warning(f"Failed to enable MCP: {e}")
-            else:
-                # Estimate tokens when usage not available (streaming fallback)
-                # Rough estimate: ~4 characters per token for English text
-                est_input_tokens = len(user_input) // 4 + 1
est_output_tokens = len(response_text) // 4 + 1
tokens = est_input_tokens + est_output_tokens
cost = session.client.estimate_cost(
session.selected_model["id"],
est_input_tokens,
est_output_tokens,
)
# Create estimated usage for session tracking
usage = UsageStats(
prompt_tokens=est_input_tokens,
completion_tokens=est_output_tokens,
total_tokens=tokens,
)
estimated = True
+    # Create session with MCP manager
+    session = ChatSession(
+        client=client,
+        settings=settings,
+        mcp_manager=mcp_manager,
+    )
+
+    # Set system prompt if provided
+    if system:
+        session.set_system_prompt(system)
-            # Add to history
-            session.add_to_history(user_input, response_text, usage, cost)
-
-            # Display metrics
-            est_marker = "~" if estimated else ""
-            context_info = ""
if session.memory_enabled:
context_count = len(session.history) - session.memory_start_index
if context_count > 1:
context_info = f", Context: {context_count} msg(s)"
else:
context_info = ", Memory: OFF"
-            online_emoji = " 🌐" if session.online_enabled else ""
-            mcp_emoji = ""
-            if mcp_manager and mcp_manager.enabled:
-                if mcp_manager.mode == "files":
-                    mcp_emoji = " 🔧"
-                elif mcp_manager.mode == "database":
-                    mcp_emoji = " 🗄️"
-            console.print(
-                f"\n[dim blue]📊 {est_marker}{tokens} tokens | {est_marker}${cost:.4f} | {response_time:.2f}s"
-                f"{context_info}{online_emoji}{mcp_emoji} | "
-                f"Session: {est_marker}{session.stats.total_tokens:,} tokens | "
-                f"{est_marker}${session.stats.total_cost:.4f}[/]"
-            )
+    # Enable online mode if requested
+    if online:
+        session.online_enabled = True
+
+    # Set model if specified, otherwise use default
+    if model:
+        raw_model = client.get_raw_model(model)
+        if raw_model:
+            session.set_model(raw_model)
+        else:
+            logger.warning(f"Model '{model}' not found")
+    elif settings.default_model:
+        raw_model = client.get_raw_model(settings.default_model)
+        if raw_model:
+            session.set_model(raw_model)
+        else:
+            logger.warning(f"Default model '{settings.default_model}' not available")
+    # Run Textual app
+    from oai.tui.app import oAIChatApp
+
+    app_instance = oAIChatApp(session, settings, model)
+    app_instance.run()
-            # Check warnings
-            warnings = session.check_warnings()
-            for warning in warnings:
-                print_warning(warning)
-
-            # Offer to copy
-            console.print("")
try:
from oai.ui.prompts import prompt_copy_response
prompt_copy_response(response_text)
except Exception:
pass
console.print("")
except KeyboardInterrupt:
console.print("\n[dim]Input cancelled[/]")
continue
except EOFError:
console.print("\n[bold yellow]Goodbye![/]")
return
@app.command()
-def chat(
+def tui(
    model: Optional[str] = typer.Option(
        None,
        "--model",
@@ -457,261 +206,19 @@ def chat(
help="Enable MCP server", help="Enable MCP server",
), ),
) -> None: ) -> None:
"""Start an interactive chat session.""" """Start Textual TUI interface (alias for just running 'oai')."""
# Setup logging _launch_tui(model, system, online, mcp)
logging_manager = LoggingManager()
logging_manager.setup()
logger = get_logger()
# Clear screen
clear_screen()
# Load settings
settings = Settings.load()
# Check API key
if not settings.api_key:
print_error("No API key configured")
print_info("Run: oai --config api to set your API key")
raise typer.Exit(1)
# Initialize client
try:
client = AIClient(
api_key=settings.api_key,
base_url=settings.base_url,
)
except Exception as e:
print_error(f"Failed to initialize client: {e}")
raise typer.Exit(1)
# Register commands
register_all_commands()
# Check for updates and show welcome
version_info = check_for_updates(APP_VERSION)
show_welcome(settings, version_info)
# Initialize MCP manager
mcp_manager = MCPManager()
if mcp:
result = mcp_manager.enable()
if result["success"]:
print_success("MCP enabled")
else:
print_warning(f"MCP: {result.get('error', 'Failed to enable')}")
# Create session
session = ChatSession(
client=client,
settings=settings,
mcp_manager=mcp_manager,
)
# Set system prompt
if system:
session.system_prompt = system
print_info(f"System prompt: {system}")
# Set online mode
if online:
session.online_enabled = True
print_info("Online mode enabled")
# Select model
if model:
raw_model = client.get_raw_model(model)
if raw_model:
session.set_model(raw_model)
else:
print_warning(f"Model '{model}' not found")
elif settings.default_model:
raw_model = client.get_raw_model(settings.default_model)
if raw_model:
session.set_model(raw_model)
else:
print_warning(f"Default model '{settings.default_model}' not available")
# Setup prompt session
HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True)
prompt_session = PromptSession(
history=FileHistory(str(HISTORY_FILE)),
)
# Run chat loop
run_chat_loop(session, prompt_session, settings)
@app.command()
def config(
setting: Optional[str] = typer.Argument(
None,
help="Setting to configure (api, url, model, system, stream, costwarning, maxtoken, online, log, loglevel)",
),
value: Optional[str] = typer.Argument(
None,
help="Value to set",
),
) -> None:
"""View or modify configuration settings."""
settings = Settings.load()
if not setting:
# Show all settings
from rich.table import Table
from oai.constants import DEFAULT_SYSTEM_PROMPT
table = Table("Setting", "Value", show_header=True, header_style="bold magenta")
table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set")
table.add_row("Base URL", settings.base_url)
table.add_row("Default Model", settings.default_model or "Not set")
# Show system prompt status
if settings.default_system_prompt is None:
system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
elif settings.default_system_prompt == "":
system_prompt_display = "[blank]"
else:
system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt
table.add_row("System Prompt", system_prompt_display)
table.add_row("Streaming", "on" if settings.stream_enabled else "off")
table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}")
table.add_row("Max Tokens", str(settings.max_tokens))
table.add_row("Default Online", "on" if settings.default_online_mode else "off")
table.add_row("Log Level", settings.log_level)
display_panel(table, title="[bold green]Configuration[/]")
return
setting = setting.lower()
if setting == "api":
if value:
settings.set_api_key(value)
else:
from oai.ui.prompts import prompt_input
new_key = prompt_input("Enter API key", password=True)
if new_key:
settings.set_api_key(new_key)
print_success("API key updated")
elif setting == "url":
settings.set_base_url(value or "https://openrouter.ai/api/v1")
print_success(f"Base URL set to: {settings.base_url}")
elif setting == "model":
if value:
settings.set_default_model(value)
print_success(f"Default model set to: {value}")
else:
print_info(f"Current default model: {settings.default_model or 'Not set'}")
elif setting == "system":
from oai.constants import DEFAULT_SYSTEM_PROMPT
if value:
# Decode escape sequences like \n for newlines
value = value.encode().decode('unicode_escape')
settings.set_default_system_prompt(value)
if value:
print_success(f"Default system prompt set to: {value}")
else:
print_success("Default system prompt set to blank.")
else:
if settings.default_system_prompt is None:
print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...")
elif settings.default_system_prompt == "":
print_info("System prompt: [blank]")
else:
print_info(f"System prompt: {settings.default_system_prompt}")
elif setting == "stream":
if value and value.lower() in ["on", "off"]:
settings.set_stream_enabled(value.lower() == "on")
print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}")
else:
print_info("Usage: oai config stream [on|off]")
elif setting == "costwarning":
if value:
try:
threshold = float(value)
settings.set_cost_warning_threshold(threshold)
print_success(f"Cost warning threshold set to: ${threshold:.4f}")
except ValueError:
print_error("Please enter a valid number")
else:
print_info(f"Current threshold: ${settings.cost_warning_threshold:.4f}")
elif setting == "maxtoken":
if value:
try:
max_tok = int(value)
settings.set_max_tokens(max_tok)
print_success(f"Max tokens set to: {max_tok}")
except ValueError:
print_error("Please enter a valid number")
else:
print_info(f"Current max tokens: {settings.max_tokens}")
elif setting == "online":
if value and value.lower() in ["on", "off"]:
settings.set_default_online_mode(value.lower() == "on")
print_success(f"Default online mode {'enabled' if settings.default_online_mode else 'disabled'}")
else:
print_info("Usage: oai config online [on|off]")
elif setting == "loglevel":
valid_levels = ["debug", "info", "warning", "error", "critical"]
if value and value.lower() in valid_levels:
settings.set_log_level(value.lower())
print_success(f"Log level set to: {value.lower()}")
else:
print_info(f"Valid levels: {', '.join(valid_levels)}")
else:
print_error(f"Unknown setting: {setting}")
@app.command()
def version() -> None:
    """Show version information."""
+    typer.echo(f"oAI version {APP_VERSION}")
+    typer.echo(f"Visit {APP_URL} for more information")
-    version_info = check_for_updates(APP_VERSION)
-    console.print(version_info)
@app.command()
def credits() -> None:
"""Check account credits."""
settings = Settings.load()
if not settings.api_key:
print_error("No API key configured")
raise typer.Exit(1)
client = AIClient(api_key=settings.api_key, base_url=settings.base_url)
credits_data = client.get_credits()
if not credits_data:
print_error("Failed to fetch credits")
raise typer.Exit(1)
from rich.table import Table
table = Table("Metric", "Value", show_header=True, header_style="bold magenta")
table.add_row("Total Credits", credits_data.get("total_credits_formatted", "N/A"))
table.add_row("Used Credits", credits_data.get("used_credits_formatted", "N/A"))
table.add_row("Credits Left", credits_data.get("credits_left_formatted", "N/A"))
display_panel(table, title="[bold green]Account Credits[/]")
def main() -> None:
-    """Main entry point."""
-    # Default to 'chat' command if no arguments provided
-    if len(sys.argv) == 1:
-        sys.argv.append("chat")
+    """Entry point for the CLI."""
    app()
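
The old entry point injected a default `chat` subcommand into `sys.argv`; the new one relies on Typer's callback instead. A minimal standalone sketch of that "no subcommand launches the app" pattern (the `typer` API used here is real; the echoed strings are placeholders):

```python
# Sketch: run a default action when no subcommand is given.
import typer

app = typer.Typer(add_completion=False)

@app.callback(invoke_without_command=True)
def main_callback(ctx: typer.Context) -> None:
    """Launch the TUI when no subcommand is given."""
    if ctx.invoked_subcommand is None:
        typer.echo("launching TUI...")  # oAI calls _launch_tui(...) here

@app.command()
def version() -> None:
    typer.echo("example 3.0.0")

if __name__ == "__main__":
    app()
```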

(File diff suppressed because it is too large)


@@ -80,6 +80,7 @@ class CommandContext:
        online_enabled: Whether online mode is enabled
        session_tokens: Session token counts
        session_cost: Session cost total
+        session: Reference to the ChatSession (for provider switching, etc.)
    """

    settings: Optional["Settings"] = None
@@ -98,7 +99,9 @@ class CommandContext:
    total_output_tokens: int = 0
    total_cost: float = 0.0
    message_count: int = 0
+    is_tui: bool = False  # Flag for TUI mode
    current_index: int = 0
+    session: Optional[Any] = None  # Reference to ChatSession (avoid circular import)

@dataclass

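With the new `session` back-reference, a command handler can reach the live `ChatSession` directly. A hypothetical handler sketch; `CommandContext`, `CommandResult`, and `CommandStatus` are names from this changeset, but the decorator-free shape, the `message` field, and the `SUCCESS`/`ERROR` members are assumptions:

```python
# Hypothetical /provider handler built on CommandContext.session (illustrative only).
from oai.commands.registry import CommandContext, CommandResult, CommandStatus

def handle_provider(args: str, context: CommandContext) -> CommandResult:
    session = context.session
    if session is None:
        # Assumed CommandResult fields; the real dataclass may differ
        return CommandResult(status=CommandStatus.ERROR, message="No active session")
    name = args.strip().lower()
    # Delegates to AIClient.switch_provider(), added in this release
    session.client.switch_provider(name)
    return CommandResult(status=CommandStatus.SUCCESS, message=f"Switched to {name}")
```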

@@ -6,8 +6,9 @@ configuration with type safety, validation, and persistence.
"""
""" """
from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Dict
from pathlib import Path
+import json
from oai.constants import (
    DEFAULT_BASE_URL,
@@ -20,6 +21,8 @@ from oai.constants import (
    DEFAULT_LOG_LEVEL,
    DEFAULT_SYSTEM_PROMPT,
    VALID_LOG_LEVELS,
+    DEFAULT_PROVIDER,
+    OLLAMA_DEFAULT_URL,
)

from oai.config.database import get_database
@@ -34,7 +37,7 @@ class Settings:
    initialization and can be persisted back.

    Attributes:
-        api_key: OpenRouter API key
+        api_key: Legacy OpenRouter API key (deprecated, use openrouter_api_key)
        base_url: API base URL
        default_model: Default model ID to use
        default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
@@ -45,9 +48,33 @@ class Settings:
        log_max_size_mb: Maximum log file size in MB
        log_backup_count: Number of log file backups to keep
        log_level: Logging level (debug/info/warning/error/critical)
+
+        # Provider-specific settings
+        default_provider: Default AI provider to use
+        openrouter_api_key: OpenRouter API key
+        anthropic_api_key: Anthropic API key
+        openai_api_key: OpenAI API key
+        ollama_base_url: Ollama server URL
    """
+    # Legacy field (kept for backward compatibility)
    api_key: Optional[str] = None
# Provider configuration
default_provider: str = DEFAULT_PROVIDER
openrouter_api_key: Optional[str] = None
anthropic_api_key: Optional[str] = None
openai_api_key: Optional[str] = None
ollama_base_url: str = OLLAMA_DEFAULT_URL
provider_models: Dict[str, str] = field(default_factory=dict) # provider -> last_model_id
# Web search configuration (for online mode with non-OpenRouter providers)
search_provider: str = "duckduckgo" # "duckduckgo" or "google"
google_api_key: Optional[str] = None
google_search_engine_id: Optional[str] = None
search_num_results: int = 5
# General settings
    base_url: str = DEFAULT_BASE_URL
    default_model: Optional[str] = None
    default_system_prompt: Optional[str] = None
@@ -134,8 +161,43 @@ class Settings:
        # Get system prompt from DB: None means not set (use default), "" means explicitly blank
        system_prompt_value = db.get_config("default_system_prompt")
# Migration: copy legacy api_key to openrouter_api_key if not already set
legacy_api_key = db.get_config("api_key")
openrouter_key = db.get_config("openrouter_api_key")
if legacy_api_key and not openrouter_key:
db.set_config("openrouter_api_key", legacy_api_key)
openrouter_key = legacy_api_key
# Note: We keep the legacy api_key in DB for backward compatibility
# Load provider-model mapping
provider_models_json = db.get_config("provider_models")
provider_models = {}
if provider_models_json:
try:
provider_models = json.loads(provider_models_json)
except json.JSONDecodeError:
provider_models = {}
        return cls(
-            api_key=db.get_config("api_key"),
+            # Legacy field
+            api_key=legacy_api_key,
# Provider configuration
default_provider=db.get_config("default_provider") or DEFAULT_PROVIDER,
openrouter_api_key=openrouter_key,
anthropic_api_key=db.get_config("anthropic_api_key"),
openai_api_key=db.get_config("openai_api_key"),
ollama_base_url=db.get_config("ollama_base_url") or OLLAMA_DEFAULT_URL,
provider_models=provider_models,
# Web search configuration
search_provider=db.get_config("search_provider") or "duckduckgo",
google_api_key=db.get_config("google_api_key"),
google_search_engine_id=db.get_config("google_search_engine_id"),
search_num_results=parse_int(db.get_config("search_num_results"), 5),
# General settings
            base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
            default_model=db.get_config("default_model"),
            default_system_prompt=system_prompt_value,
@@ -331,6 +393,155 @@ class Settings:
        self.log_max_size_mb = min(size_mb, 100)
        get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
def set_provider_api_key(self, provider: str, api_key: str) -> None:
"""
Set and persist an API key for a specific provider.
Args:
provider: Provider name ("openrouter", "anthropic", "openai")
api_key: The API key to set
Raises:
ValueError: If provider is invalid
"""
provider = provider.lower()
api_key = api_key.strip()
if provider == "openrouter":
self.openrouter_api_key = api_key
get_database().set_config("openrouter_api_key", api_key)
elif provider == "anthropic":
self.anthropic_api_key = api_key
get_database().set_config("anthropic_api_key", api_key)
elif provider == "openai":
self.openai_api_key = api_key
get_database().set_config("openai_api_key", api_key)
else:
raise ValueError(f"Invalid provider: {provider}")
def get_provider_api_key(self, provider: str) -> Optional[str]:
"""
Get the API key for a specific provider.
Args:
provider: Provider name ("openrouter", "anthropic", "openai", "ollama")
Returns:
API key or None if not set
Raises:
ValueError: If provider is invalid
"""
provider = provider.lower()
if provider == "openrouter":
return self.openrouter_api_key
elif provider == "anthropic":
return self.anthropic_api_key
elif provider == "openai":
return self.openai_api_key
elif provider == "ollama":
return "" # Ollama doesn't require an API key
else:
raise ValueError(f"Invalid provider: {provider}")
def set_default_provider(self, provider: str) -> None:
"""
Set and persist the default provider.
Args:
provider: Provider name
Raises:
ValueError: If provider is invalid
"""
from oai.constants import VALID_PROVIDERS
provider = provider.lower()
if provider not in VALID_PROVIDERS:
raise ValueError(
f"Invalid provider: {provider}. "
f"Valid providers: {', '.join(VALID_PROVIDERS)}"
)
self.default_provider = provider
get_database().set_config("default_provider", provider)
def set_ollama_base_url(self, url: str) -> None:
"""
Set and persist the Ollama base URL.
Args:
url: Ollama server URL
"""
self.ollama_base_url = url.strip()
get_database().set_config("ollama_base_url", self.ollama_base_url)
def set_search_provider(self, provider: str) -> None:
"""
Set and persist the web search provider.
Args:
provider: Search provider ("anthropic_native", "duckduckgo", "google")
Raises:
ValueError: If provider is invalid
"""
valid_providers = ["anthropic_native", "duckduckgo", "google"]
provider = provider.lower()
if provider not in valid_providers:
raise ValueError(
f"Invalid search provider: {provider}. "
f"Valid providers: {', '.join(valid_providers)}"
)
self.search_provider = provider
get_database().set_config("search_provider", provider)
def set_google_api_key(self, api_key: str) -> None:
"""
Set and persist the Google API key for Google Custom Search.
Args:
api_key: The Google API key
"""
self.google_api_key = api_key.strip()
get_database().set_config("google_api_key", self.google_api_key)
def set_google_search_engine_id(self, engine_id: str) -> None:
"""
Set and persist the Google Custom Search Engine ID.
Args:
engine_id: The Google Search Engine ID
"""
self.google_search_engine_id = engine_id.strip()
get_database().set_config("google_search_engine_id", self.google_search_engine_id)
def get_provider_model(self, provider: str) -> Optional[str]:
"""
Get the last used model for a provider.
Args:
provider: Provider name
Returns:
Model ID or None if not set
"""
return self.provider_models.get(provider)
def set_provider_model(self, provider: str, model_id: str) -> None:
"""
Set and persist the last used model for a provider.
Args:
provider: Provider name
model_id: Model ID to remember
"""
self.provider_models[provider] = model_id
# Save to database as JSON
get_database().set_config("provider_models", json.dumps(self.provider_models))
# Global settings instance
_settings: Optional[Settings] = None
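
Taken together, these accessors form a small per-provider configuration API. A usage sketch, calling only the methods defined above (key and model values are placeholders):

```python
from oai.config.settings import Settings

settings = Settings.load()

# Store keys per provider, then pick a default provider
settings.set_provider_api_key("anthropic", "sk-ant-XXXX")
settings.set_default_provider("anthropic")

# Remember the last model used with each provider
settings.set_provider_model("anthropic", "claude-sonnet-4-5")
assert settings.get_provider_model("anthropic") == "claude-sonnet-4-5"

# Ollama needs no key; the accessor returns an empty string
assert settings.get_provider_api_key("ollama") == ""
```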


@@ -10,14 +10,17 @@ from pathlib import Path
from typing import Set, Dict, Any
import logging
# Import version from single source of truth
from oai import __version__
# =============================================================================
# APPLICATION METADATA
# =============================================================================

APP_NAME = "oAI"
-APP_VERSION = "2.1.0"
+APP_VERSION = __version__  # Single source of truth in oai/__init__.py
APP_URL = "https://iurl.no/oai"
-APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"
+APP_DESCRIPTION = "Open AI Chat Client with Multi-Provider Support"

# =============================================================================
# FILE PATHS
@@ -39,6 +42,26 @@ DEFAULT_STREAM_ENABLED = True
DEFAULT_MAX_TOKENS = 100_000
DEFAULT_ONLINE_MODE = False
# =============================================================================
# PROVIDER CONFIGURATION
# =============================================================================
# Provider names
PROVIDER_OPENROUTER = "openrouter"
PROVIDER_ANTHROPIC = "anthropic"
PROVIDER_OPENAI = "openai"
PROVIDER_OLLAMA = "ollama"
VALID_PROVIDERS = [PROVIDER_OPENROUTER, PROVIDER_ANTHROPIC, PROVIDER_OPENAI, PROVIDER_OLLAMA]
# Provider base URLs
ANTHROPIC_BASE_URL = "https://api.anthropic.com"
OPENAI_BASE_URL = "https://api.openai.com/v1"
OLLAMA_DEFAULT_URL = "http://localhost:11434"
# Default provider
DEFAULT_PROVIDER = PROVIDER_OPENROUTER
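
These constants centralize provider identity checks across the CLI and settings layers. A quick sketch of the validation they enable; `resolve_provider` is an illustrative helper, not a function in the codebase:

```python
from oai.constants import DEFAULT_PROVIDER, VALID_PROVIDERS

def resolve_provider(name: str | None) -> str:
    """Normalize a --provider value, falling back to the default."""
    provider = (name or DEFAULT_PROVIDER).lower()
    if provider not in VALID_PROVIDERS:
        raise ValueError(
            f"Invalid provider: {provider}. Valid providers: {', '.join(VALID_PROVIDERS)}"
        )
    return provider

print(resolve_provider(None))         # -> "openrouter"
print(resolve_provider("Anthropic"))  # -> "anthropic"
```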
# =============================================================================
# DEFAULT SYSTEM PROMPT
# =============================================================================


@@ -9,7 +9,7 @@ import asyncio
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

-from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
+from oai.constants import APP_NAME, APP_URL, MODEL_PRICING, OLLAMA_DEFAULT_URL
from oai.providers.base import (
    AIProvider,
    ChatMessage,
@@ -19,7 +19,6 @@ from oai.providers.base import (
    ToolCall,
    UsageStats,
)
-from oai.providers.openrouter import OpenRouterProvider
from oai.utils.logging import get_logger
@@ -32,37 +31,66 @@ class AIClient:
    Attributes:
        provider: The underlying AI provider
+        provider_name: Name of the current provider
        default_model: Default model ID to use
        http_headers: Custom HTTP headers for requests
    """
    def __init__(
        self,
-        api_key: str,
-        base_url: Optional[str] = None,
-        provider_class: type = OpenRouterProvider,
+        provider_name: str = "openrouter",
+        provider_api_keys: Optional[Dict[str, str]] = None,
+        ollama_base_url: str = OLLAMA_DEFAULT_URL,
        app_name: str = APP_NAME,
        app_url: str = APP_URL,
    ):
        """
-        Initialize the AI client.
+        Initialize the AI client with the specified provider.

        Args:
-            api_key: API key for authentication
-            base_url: Optional custom base URL
-            provider_class: Provider class to use (default: OpenRouterProvider)
+            provider_name: Provider to use ("openrouter", "anthropic", "openai", "ollama")
+            provider_api_keys: Dict mapping provider names to API keys
+            ollama_base_url: Base URL for Ollama server
            app_name: Application name for headers
            app_url: Application URL for headers
+
+        Raises:
+            ValueError: If provider is invalid or not configured
        """
-        self.provider: AIProvider = provider_class(
-            api_key=api_key,
-            base_url=base_url,
-            app_name=app_name,
-            app_url=app_url,
-        )
+        from oai.providers.registry import get_provider_class
+
+        self.provider_name = provider_name
+        self.provider_api_keys = provider_api_keys or {}
+        self.ollama_base_url = ollama_base_url
+        self.app_name = app_name
+        self.app_url = app_url
+
+        # Get provider class
+        provider_class = get_provider_class(provider_name)
+        if not provider_class:
+            raise ValueError(f"Unknown provider: {provider_name}")
+
+        # Get API key for this provider
+        api_key = self.provider_api_keys.get(provider_name, "")
+
+        # Initialize provider with appropriate parameters
+        if provider_name == "ollama":
+            self.provider: AIProvider = provider_class(
+                api_key=api_key,
+                base_url=ollama_base_url,
+            )
+        else:
+            self.provider: AIProvider = provider_class(
+                api_key=api_key,
+                app_name=app_name,
+                app_url=app_url,
+            )

        self.default_model: Optional[str] = None
        self.logger = get_logger()
+        self.logger.info(f"Initialized {provider_name} provider")
    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Get available models.
@@ -420,3 +448,49 @@ class AIClient:
"""Clear the provider's model cache.""" """Clear the provider's model cache."""
if hasattr(self.provider, "clear_cache"): if hasattr(self.provider, "clear_cache"):
self.provider.clear_cache() self.provider.clear_cache()
def switch_provider(self, provider_name: str, ollama_base_url: Optional[str] = None) -> None:
"""
Switch to a different provider.
Args:
provider_name: Provider to switch to
ollama_base_url: Optional Ollama base URL (if switching to Ollama)
Raises:
ValueError: If provider is invalid or not configured
"""
from oai.providers.registry import get_provider_class
# Get provider class
provider_class = get_provider_class(provider_name)
if not provider_class:
raise ValueError(f"Unknown provider: {provider_name}")
# Get API key
api_key = self.provider_api_keys.get(provider_name, "")
# Check API key requirement
if provider_name != "ollama" and not api_key:
raise ValueError(f"No API key configured for {provider_name}")
# Initialize new provider
if provider_name == "ollama":
base_url = ollama_base_url or self.ollama_base_url
self.provider = provider_class(
api_key=api_key,
base_url=base_url,
)
self.ollama_base_url = base_url
else:
self.provider = provider_class(
api_key=api_key,
app_name=self.app_name,
app_url=self.app_url,
)
self.provider_name = provider_name
self.logger.info(f"Switched to {provider_name} provider")
# Clear model cache when switching providers
self.default_model = None
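
A usage sketch of the new constructor and `switch_provider()`, with arguments exactly as defined above (key values are placeholders):

```python
from oai.core.client import AIClient

client = AIClient(
    provider_name="openrouter",
    provider_api_keys={
        "openrouter": "sk-or-XXXX",
        "anthropic": "sk-ant-XXXX",
    },
    ollama_base_url="http://localhost:11434",
)

# Mid-session switch; raises ValueError if the target has no key configured
client.switch_provider("anthropic")

# Ollama needs no API key, only a reachable server
client.switch_provider("ollama", ollama_base_url="http://localhost:11434")
```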


@@ -9,10 +9,7 @@ import asyncio
import json
import time
from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple

-from rich.live import Live
-from rich.markdown import Markdown

from oai.commands.registry import CommandContext, CommandResult, registry
from oai.config.database import Database
@@ -25,17 +22,10 @@ from oai.constants import (
from oai.core.client import AIClient
from oai.mcp.manager import MCPManager
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
-from oai.ui.console import (
-    console,
-    display_markdown,
-    display_panel,
-    print_error,
-    print_info,
-    print_success,
-    print_warning,
-)
-from oai.ui.prompts import prompt_copy_response
from oai.utils.logging import get_logger
+from oai.utils.web_search import perform_web_search, format_search_results
+
+logger = get_logger()
@dataclass
@@ -177,6 +167,7 @@ class ChatSession:
            total_cost=self.stats.total_cost,
            message_count=self.stats.message_count,
            current_index=self.current_index,
+            session=self,
        )

    def set_model(self, model: Dict[str, Any]) -> None:
@@ -303,6 +294,44 @@ class ChatSession:
        # Build request parameters
        model_id = self.selected_model["id"]
# Handle online mode
enable_web_search = False
web_search_config = {}
if self.online_enabled:
# OpenRouter handles online mode natively with :online suffix
if self.client.provider_name == "openrouter":
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
# Anthropic has native web search when search provider is set to anthropic_native
elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
enable_web_search = True
web_search_config = {
"max_uses": self.settings.search_num_results or 5
}
logger.info("Using Anthropic native web search")
else:
# For other providers, perform web search and inject results
logger.info(f"Performing web search for: {user_input}")
search_results = perform_web_search(
user_input,
num_results=self.settings.search_num_results,
provider=self.settings.search_provider,
google_api_key=self.settings.google_api_key,
google_search_engine_id=self.settings.google_search_engine_id
)
if search_results:
# Inject search results into messages
formatted_results = format_search_results(search_results)
search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
# Add search results to the last user message
if messages and messages[-1]["role"] == "user":
messages[-1]["content"] += search_context
logger.info(f"Injected {len(search_results)} search results into context")
        if self.online_enabled:
            if hasattr(self.client.provider, "get_effective_model_id"):
                model_id = self.client.provider.get_effective_model_id(model_id, True)
@@ -333,6 +362,8 @@ class ChatSession:
            max_tokens=max_tokens,
            transforms=transforms,
            on_chunk=on_stream_chunk,
+            enable_web_search=enable_web_search,
+            web_search_config=web_search_config,
        )

        response_time = time.time() - start_time
        return full_text, usage, response_time
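
`format_search_results()` is imported above but its body is not part of this changeset. A plausible sketch of what the injected context could look like; the exact layout is an assumption about `oai/utils/web_search.py`, not its actual contents:

```python
# Hypothetical implementation of format_search_results(); the shipped
# helper may format the injected context differently.
def format_search_results(results: list[dict]) -> str:
    """Render search hits as a numbered block for prompt injection."""
    lines = ["Web search results:"]
    for i, r in enumerate(results, start=1):
        lines.append(f"{i}. {r.get('title', '')}")
        lines.append(f"   URL: {r.get('url', '')}")
        lines.append(f"   {r.get('snippet', '')}")
    return "\n".join(lines)
```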
@@ -396,7 +427,7 @@ class ChatSession:
        if not tool_calls:
            return response

-        console.print(f"\n[dim yellow]🔧 AI requesting {len(tool_calls)} tool call(s)...[/]")
+        # Tool calls requested by AI

        tool_results = []
        for tc in tool_calls:
@@ -417,15 +448,17 @@ class ChatSession:
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}" f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
for k, v in args.items() for k, v in args.items()
) )
console.print(f"[dim cyan] → {tc.function.name}({args_display})[/]") # Executing tool: {tc.function.name}
# Execute tool # Execute tool
result = asyncio.run(self.execute_tool(tc.function.name, args)) result = asyncio.run(self.execute_tool(tc.function.name, args))
if "error" in result: if "error" in result:
console.print(f"[dim red] ✗ Error: {result['error']}[/]") # Tool execution error logged
pass
else: else:
self._display_tool_success(tc.function.name, result) # Tool execution successful
pass
tool_results.append({ tool_results.append({
"tool_call_id": tc.id, "tool_call_id": tc.id,
@@ -452,38 +485,12 @@ class ChatSession:
})
api_messages.extend(tool_results)
-console.print("\n[dim cyan]💭 Processing tool results...[/]")
# Processing tool results
loop_count += 1
self.logger.warning(f"Reached max tool loops ({max_loops})")
-console.print(f"[bold yellow]⚠️ Reached maximum tool calls ({max_loops})[/]")
return response
-def _display_tool_success(self, tool_name: str, result: Dict[str, Any]) -> None:
-"""Display a success message for a tool call."""
-if tool_name == "search_files":
-count = result.get("count", 0)
-console.print(f"[dim green] ✓ Found {count} file(s)[/]")
-elif tool_name == "read_file":
-size = result.get("size", 0)
-truncated = " (truncated)" if result.get("truncated") else ""
-console.print(f"[dim green] ✓ Read {size} bytes{truncated}[/]")
-elif tool_name == "list_directory":
-count = result.get("count", 0)
-console.print(f"[dim green] ✓ Listed {count} item(s)[/]")
-elif tool_name == "inspect_database":
-if "table" in result:
-console.print(f"[dim green] ✓ Inspected table: {result['table']}[/]")
-else:
-console.print(f"[dim green] ✓ Inspected database ({result.get('table_count', 0)} tables)[/]")
-elif tool_name == "search_database":
-count = result.get("count", 0)
-console.print(f"[dim green] ✓ Found {count} match(es)[/]")
-elif tool_name == "query_database":
-count = result.get("count", 0)
-console.print(f"[dim green] ✓ Query returned {count} row(s)[/]")
-else:
-console.print("[dim green] ✓ Success[/]")
def _stream_response(
self,
@@ -492,6 +499,8 @@ class ChatSession:
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
on_chunk: Optional[Callable[[str], None]] = None,
enable_web_search: bool = False,
web_search_config: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Optional[UsageStats]]:
"""
Stream a response with live display.
@@ -502,6 +511,8 @@ class ChatSession:
max_tokens: Max tokens
transforms: Transforms
on_chunk: Callback for chunks
enable_web_search: Whether to enable Anthropic native web search
web_search_config: Web search configuration
Returns:
Tuple of (full_text, usage)
@@ -512,6 +523,8 @@ class ChatSession:
stream=True,
max_tokens=max_tokens,
transforms=transforms,
enable_web_search=enable_web_search,
web_search_config=web_search_config or {},
)
if isinstance(response, ChatResponse):
@@ -521,27 +534,339 @@ class ChatSession:
usage: Optional[UsageStats] = None
try:
-with Live("", console=console, refresh_per_second=10) as live:
-for chunk in response:
-if chunk.error:
-console.print(f"\n[bold red]Stream error: {chunk.error}[/]")
-break
-if chunk.delta_content:
-full_text += chunk.delta_content
-live.update(Markdown(full_text))
-if on_chunk:
-on_chunk(chunk.delta_content)
-if chunk.usage:
-usage = chunk.usage
for chunk in response:
if chunk.error:
self.logger.error(f"Stream error: {chunk.error}")
break
if chunk.delta_content:
full_text += chunk.delta_content
if on_chunk:
on_chunk(chunk.delta_content)
if chunk.usage:
usage = chunk.usage
except KeyboardInterrupt:
-console.print("\n[bold yellow]⚠️ Streaming interrupted[/]")
self.logger.info("Streaming interrupted")
return "", None
return full_text, usage
# ========== ASYNC METHODS FOR TUI ==========
async def send_message_async(
self,
user_input: str,
stream: bool = True,
) -> AsyncIterator[StreamChunk]:
"""
Async version of send_message for Textual TUI.
Args:
user_input: User's input text
stream: Whether to stream the response
Yields:
StreamChunk objects for progressive display
"""
if not self.selected_model:
raise ValueError("No model selected")
messages = self.build_api_messages(user_input)
tools = self.get_mcp_tools()
if tools:
# Disable streaming when tools are present
stream = False
# Handle online mode
model_id = self.selected_model["id"]
enable_web_search = False
web_search_config = {}
if self.online_enabled:
# OpenRouter handles online mode natively with :online suffix
if self.client.provider_name == "openrouter":
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
# Anthropic has native web search when search provider is set to anthropic_native
elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
enable_web_search = True
web_search_config = {
"max_uses": self.settings.search_num_results or 5
}
logger.info("Using Anthropic native web search")
else:
# For other providers, perform web search and inject results
logger.info(f"Performing web search for: {user_input}")
search_results = await asyncio.to_thread(
perform_web_search,
user_input,
num_results=self.settings.search_num_results,
provider=self.settings.search_provider,
google_api_key=self.settings.google_api_key,
google_search_engine_id=self.settings.google_search_engine_id
)
if search_results:
# Inject search results into messages
formatted_results = format_search_results(search_results)
search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
# Add search results to the last user message
if messages and messages[-1]["role"] == "user":
messages[-1]["content"] += search_context
logger.info(f"Injected {len(search_results)} search results into context")
transforms = ["middle-out"] if self.middle_out_enabled else None
max_tokens = None
if self.session_max_token > 0:
max_tokens = self.session_max_token
if tools:
# Use async tool handling flow
async for chunk in self._send_with_tools_async(
messages=messages,
model_id=model_id,
tools=tools,
max_tokens=max_tokens,
transforms=transforms,
):
yield chunk
elif stream:
# Use async streaming flow
async for chunk in self._stream_response_async(
messages=messages,
model_id=model_id,
max_tokens=max_tokens,
transforms=transforms,
enable_web_search=enable_web_search,
web_search_config=web_search_config,
):
yield chunk
else:
# Non-streaming request
response = self.client.chat(
messages=messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
transforms=transforms,
)
if isinstance(response, ChatResponse):
# Yield single chunk with complete response
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
async def _send_with_tools_async(
self,
messages: List[Dict[str, Any]],
model_id: str,
tools: List[Dict[str, Any]],
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
) -> AsyncIterator[StreamChunk]:
"""
Async version of _send_with_tools for TUI.
Args:
messages: API messages
model_id: Model ID
tools: Tool definitions
max_tokens: Max tokens
transforms: Transforms list
Yields:
StreamChunk objects including tool call notifications
"""
max_loops = 5
loop_count = 0
api_messages = list(messages)
while loop_count < max_loops:
response = self.client.chat(
messages=api_messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
transforms=transforms,
)
if not isinstance(response, ChatResponse):
raise ValueError("Expected ChatResponse")
tool_calls = response.tool_calls
if not tool_calls:
# Final response, yield it
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
return
# Yield notification about tool calls
tool_notification = f"\n🔧 AI requesting {len(tool_calls)} tool call(s)...\n"
yield StreamChunk(id="", delta_content=tool_notification, usage=None, error=None)
tool_results = []
for tc in tool_calls:
try:
args = json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse tool arguments: {e}")
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
})
continue
# Yield tool call display
args_display = ", ".join(
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
for k, v in args.items()
)
tool_display = f"{tc.function.name}({args_display})\n"
yield StreamChunk(id="", delta_content=tool_display, usage=None, error=None)
# Execute tool (await instead of asyncio.run)
result = await self.execute_tool(tc.function.name, args)
if "error" in result:
error_msg = f" ✗ Error: {result['error']}\n"
yield StreamChunk(id="", delta_content=error_msg, usage=None, error=None)
else:
success_msg = self._format_tool_success(tc.function.name, result)
yield StreamChunk(id="", delta_content=success_msg, usage=None, error=None)
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps(result),
})
# Add assistant message with tool calls
api_messages.append({
"role": "assistant",
"content": response.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
})
# Add tool results
api_messages.extend(tool_results)
loop_count += 1
# Max loops reached
yield StreamChunk(
id="",
delta_content="\n⚠️ Maximum tool call loops reached\n",
usage=None,
error="Max loops reached"
)
def _format_tool_success(self, tool_name: str, result: Dict[str, Any]) -> str:
"""Format a success message for a tool call."""
if tool_name == "search_files":
count = result.get("count", 0)
return f" ✓ Found {count} file(s)\n"
elif tool_name == "read_file":
size = result.get("size", 0)
truncated = " (truncated)" if result.get("truncated") else ""
return f" ✓ Read {size} bytes{truncated}\n"
elif tool_name == "list_directory":
count = result.get("count", 0)
return f" ✓ Listed {count} item(s)\n"
elif tool_name == "inspect_database":
if "table" in result:
return f" ✓ Inspected table: {result['table']}\n"
else:
return f" ✓ Inspected database ({result.get('table_count', 0)} tables)\n"
elif tool_name == "search_database":
count = result.get("count", 0)
return f" ✓ Found {count} match(es)\n"
elif tool_name == "query_database":
count = result.get("count", 0)
return f" ✓ Query returned {count} row(s)\n"
else:
return " ✓ Success\n"
async def _stream_response_async(
self,
messages: List[Dict[str, Any]],
model_id: str,
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
enable_web_search: bool = False,
web_search_config: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[StreamChunk]:
"""
Async version of _stream_response for TUI.
Args:
messages: API messages
model_id: Model ID
max_tokens: Max tokens
transforms: Transforms
enable_web_search: Whether to enable Anthropic native web search
web_search_config: Web search configuration
Yields:
StreamChunk objects
"""
response = self.client.chat(
messages=messages,
model=model_id,
stream=True,
max_tokens=max_tokens,
transforms=transforms,
enable_web_search=enable_web_search,
web_search_config=web_search_config or {},
)
if isinstance(response, ChatResponse):
# Non-streaming response
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
return
# Stream the response
for chunk in response:
if chunk.error:
yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
break
yield chunk
# ========== END ASYNC METHODS ==========
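A minimal sketch of how an asyncio caller, such as the Textual TUI, might drive `send_message_async`; the `session` object and `print`-based rendering here are placeholders for the real app wiring:

```python
# Consumer sketch; `session` is assumed to be an initialized ChatSession.
import asyncio

async def collect_reply(session, user_input: str) -> str:
    """Accumulate a streamed reply while rendering chunks as they arrive."""
    full_text = ""
    async for chunk in session.send_message_async(user_input):
        if chunk.error:
            break  # the error is surfaced on the chunk itself
        if chunk.delta_content:
            full_text += chunk.delta_content
            print(chunk.delta_content, end="", flush=True)
    return full_text

# asyncio.run(collect_reply(session, "Hello"))
```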
def add_to_history(
self,
prompt: str,

oai/providers/__init__.py
@@ -16,6 +16,16 @@ from oai.providers.base import (
ProviderCapabilities,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.providers.anthropic import AnthropicProvider
from oai.providers.openai import OpenAIProvider
from oai.providers.ollama import OllamaProvider
from oai.providers.registry import register_provider
# Register all providers
register_provider("openrouter", OpenRouterProvider)
register_provider("anthropic", AnthropicProvider)
register_provider("openai", OpenAIProvider)
register_provider("ollama", OllamaProvider)
__all__ = [
# Base classes and types
@@ -29,4 +39,7 @@ __all__ = [
"ProviderCapabilities", "ProviderCapabilities",
# Provider implementations # Provider implementations
"OpenRouterProvider", "OpenRouterProvider",
"AnthropicProvider",
"OpenAIProvider",
"OllamaProvider",
]

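With the registrations above in place, callers can resolve providers dynamically instead of importing concrete classes (a sketch; the choice of Ollama is arbitrary):

```python
# Sketch: dynamic provider lookup through the registry.
from oai.providers.registry import get_provider_class, list_providers

print(list_providers())  # ["openrouter", "anthropic", "openai", "ollama"]

provider_cls = get_provider_class("ollama")
if provider_cls is None:
    raise ValueError("unknown provider")
provider = provider_cls()  # Ollama accepts an empty API key
print(provider.name)       # "Ollama"
```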
oai/providers/anthropic.py (new file, 673 lines)
@@ -0,0 +1,673 @@
"""
Anthropic provider for Claude models.
This provider connects to Anthropic's API for accessing Claude models.
"""
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import anthropic
from anthropic.types import Message, MessageStreamEvent
from oai.constants import ANTHROPIC_BASE_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
ToolCall,
ToolFunction,
UsageStats,
)
from oai.utils.logging import get_logger
logger = get_logger()
# Model name aliases
MODEL_ALIASES = {
"claude-sonnet": "claude-sonnet-4-5-20250929",
"claude-haiku": "claude-haiku-4-5-20251001",
"claude-opus": "claude-opus-4-5-20251101",
# Legacy aliases
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-3-7-sonnet": "claude-3-7-sonnet-20250219",
}
class AnthropicProvider(AIProvider):
"""
Anthropic API provider.
Provides access to the Claude 4.5 family (Sonnet, Haiku, Opus) and legacy Claude 3.x models.
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: str = "oAI",
app_url: str = "",
**kwargs: Any,
):
"""
Initialize Anthropic provider.
Args:
api_key: Anthropic API key
base_url: Optional custom base URL
app_name: Application name (for headers)
app_url: Application URL (for headers)
**kwargs: Additional arguments
"""
super().__init__(api_key, base_url or ANTHROPIC_BASE_URL)
self.client = anthropic.Anthropic(api_key=api_key)
self.async_client = anthropic.AsyncAnthropic(api_key=api_key)
self._models_cache: Optional[List[ModelInfo]] = None
def _create_web_search_tool(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""
Create Anthropic native web search tool definition.
Args:
config: Optional configuration for web search (max_uses, allowed_domains, etc.)
Returns:
Tool definition dict
"""
tool: Dict[str, Any] = {
"type": "web_search_20250305",
"name": "web_search",
}
# Add optional parameters if provided
if "max_uses" in config:
tool["max_uses"] = config["max_uses"]
else:
tool["max_uses"] = 5 # Default
if "allowed_domains" in config:
tool["allowed_domains"] = config["allowed_domains"]
if "blocked_domains" in config:
tool["blocked_domains"] = config["blocked_domains"]
if "user_location" in config:
tool["user_location"] = config["user_location"]
return tool
@property
def name(self) -> str:
"""Get provider name."""
return "Anthropic"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=True,
images=True,
online=True, # Native Anthropic web search, plus DuckDuckGo/Google fallback
max_context=200000,
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
List available Anthropic models.
Args:
filter_text_only: Whether to filter for text models only
Returns:
List of ModelInfo objects
"""
if self._models_cache:
return self._models_cache
# Anthropic doesn't have a models list API, so we hardcode the available models
models = [
# Current Claude 4.5 models
ModelInfo(
id="claude-sonnet-4-5-20250929",
name="Claude Sonnet 4.5",
description="Smart model for complex agents and coding (recommended)",
context_length=200000,
pricing={"input": 3.0, "output": 15.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="claude-haiku-4-5-20251001",
name="Claude Haiku 4.5",
description="Fastest model with near-frontier intelligence",
context_length=200000,
pricing={"input": 1.0, "output": 5.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="claude-opus-4-5-20251101",
name="Claude Opus 4.5",
description="Premium model with maximum intelligence",
context_length=200000,
pricing={"input": 5.0, "output": 25.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
# Legacy models (still available)
ModelInfo(
id="claude-3-7-sonnet-20250219",
name="Claude Sonnet 3.7",
description="Legacy model - recommend migrating to 4.5",
context_length=200000,
pricing={"input": 3.0, "output": 15.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="claude-3-haiku-20240307",
name="Claude 3 Haiku",
description="Legacy fast model - recommend migrating to 4.5",
context_length=200000,
pricing={"input": 0.25, "output": 1.25},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
]
self._models_cache = models
logger.info(f"Loaded {len(models)} Anthropic models")
return models
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None
"""
# Resolve alias
resolved_id = MODEL_ALIASES.get(model_id, model_id)
models = self.list_models()
for model in models:
if model.id == resolved_id or model.id == model_id:
return model
return None
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send chat completion request to Anthropic.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream response
max_tokens: Maximum tokens
temperature: Sampling temperature
tools: Tool definitions
tool_choice: Tool selection mode
**kwargs: Additional parameters (including enable_web_search)
Returns:
ChatResponse or Iterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Extract system message (Anthropic requires it separate from messages)
system_prompt, anthropic_messages = self._convert_messages(messages)
# Build request parameters
params: Dict[str, Any] = {
"model": model_id,
"messages": anthropic_messages,
"max_tokens": max_tokens or 4096,
}
if system_prompt:
params["system"] = system_prompt
if temperature is not None:
params["temperature"] = temperature
# Prepare tools list
tools_list = []
# Add web search tool if requested via kwargs
if kwargs.get("enable_web_search", False):
web_search_config = kwargs.get("web_search_config", {})
tools_list.append(self._create_web_search_tool(web_search_config))
logger.info("Added Anthropic native web search tool")
# Add user-provided tools
if tools:
# Convert tools to Anthropic format
converted_tools = self._convert_tools(tools)
tools_list.extend(converted_tools)
if tools_list:
params["tools"] = tools_list
if tool_choice and tool_choice != "auto":
# Anthropic uses different tool_choice format
if tool_choice == "none":
pass # Don't include tools
elif tool_choice == "required":
params["tool_choice"] = {"type": "any"}
else:
params["tool_choice"] = {"type": "tool", "name": tool_choice}
logger.debug(f"Anthropic request: model={model_id}, messages={len(anthropic_messages)}")
try:
if stream:
return self._stream_chat(params)
else:
return self._sync_chat(params)
except Exception as e:
logger.error(f"Anthropic request failed: {e}")
return ChatResponse(
id="error",
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
finish_reason="error",
)
],
)
def _convert_messages(self, messages: List[ChatMessage]) -> tuple[str, List[Dict[str, Any]]]:
"""
Convert messages to Anthropic format.
Anthropic requires system messages to be separate from the conversation.
Args:
messages: List of ChatMessage objects
Returns:
Tuple of (system_prompt, anthropic_messages)
"""
system_prompt = ""
anthropic_messages = []
for msg in messages:
if msg.role == "system":
# Accumulate system messages
if system_prompt:
system_prompt += "\n\n"
system_prompt += msg.content or ""
else:
# Convert to Anthropic format
message_dict: Dict[str, Any] = {"role": msg.role}
# Handle content
if msg.content:
message_dict["content"] = msg.content
# Handle tool calls (assistant messages)
if msg.tool_calls:
# Anthropic format for tool use
content_blocks = []
if msg.content:
content_blocks.append({"type": "text", "text": msg.content})
for tc in msg.tool_calls:
content_blocks.append({
"type": "tool_use",
"id": tc.id,
"name": tc.function.name,
"input": json.loads(tc.function.arguments),
})
message_dict["content"] = content_blocks
# Handle tool results (tool messages)
if msg.role == "tool" and msg.tool_call_id:
# Convert to Anthropic's tool_result format
anthropic_messages.append({
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": msg.tool_call_id,
"content": msg.content or "",
}]
})
continue
anthropic_messages.append(message_dict)
return system_prompt, anthropic_messages
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Convert OpenAI-style tools to Anthropic format.
Args:
tools: OpenAI tool definitions
Returns:
Anthropic tool definitions
"""
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
anthropic_tools.append({
"name": func.get("name"),
"description": func.get("description", ""),
"input_schema": func.get("parameters", {}),
})
return anthropic_tools
def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
"""
Send synchronous chat request.
Args:
params: Request parameters
Returns:
ChatResponse
"""
message: Message = self.client.messages.create(**params)
# Extract content
content = ""
tool_calls = []
for block in message.content:
if block.type == "text":
content += block.text
elif block.type == "tool_use":
# Convert to ToolCall format
tool_calls.append(
ToolCall(
id=block.id,
type="function",
function=ToolFunction(
name=block.name,
arguments=json.dumps(block.input),
),
)
)
# Build ChatMessage
chat_message = ChatMessage(
role="assistant",
content=content if content else None,
tool_calls=tool_calls if tool_calls else None,
)
# Extract usage
usage = None
if message.usage:
usage = UsageStats(
prompt_tokens=message.usage.input_tokens,
completion_tokens=message.usage.output_tokens,
total_tokens=message.usage.input_tokens + message.usage.output_tokens,
)
return ChatResponse(
id=message.id,
choices=[
ChatResponseChoice(
index=0,
message=chat_message,
finish_reason=message.stop_reason,
)
],
usage=usage,
model=message.model,
)
def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
"""
Stream chat response from Anthropic.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
stream = self.client.messages.stream(**params)
with stream as event_stream:
for event in event_stream:
event_data: MessageStreamEvent = event
# Handle different event types
if event_data.type == "content_block_delta":
delta = event_data.delta
if hasattr(delta, "text"):
yield StreamChunk(
id="stream",
delta_content=delta.text,
)
elif event_data.type == "message_stop":
# Final event with usage
pass
elif event_data.type == "message_delta":
# Contains stop reason and usage
usage = None
if hasattr(event_data, "usage"):
usage_data = event_data.usage
usage = UsageStats(
prompt_tokens=0,
completion_tokens=usage_data.output_tokens,
total_tokens=usage_data.output_tokens,
)
yield StreamChunk(
id="stream",
finish_reason=event_data.delta.stop_reason if hasattr(event_data.delta, "stop_reason") else None,
usage=usage,
)
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send async chat request to Anthropic.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream
max_tokens: Max tokens
temperature: Temperature
tools: Tool definitions
tool_choice: Tool choice
**kwargs: Additional args
Returns:
ChatResponse or AsyncIterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Convert messages
system_prompt, anthropic_messages = self._convert_messages(messages)
# Build params
params: Dict[str, Any] = {
"model": model_id,
"messages": anthropic_messages,
"max_tokens": max_tokens or 4096,
}
if system_prompt:
params["system"] = system_prompt
if temperature is not None:
params["temperature"] = temperature
if tools:
params["tools"] = self._convert_tools(tools)
if stream:
return self._stream_chat_async(params)
else:
message = await self.async_client.messages.create(**params)
return self._convert_message(message)
async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
"""
Stream async chat response.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
stream = self.async_client.messages.stream(**params)  # returns an async context manager; not awaited
async with stream as event_stream:
async for event in event_stream:
if event.type == "content_block_delta":
delta = event.delta
if hasattr(delta, "text"):
yield StreamChunk(
id="stream",
delta_content=delta.text,
)
def _convert_message(self, message: Message) -> ChatResponse:
"""Helper to convert Anthropic message to ChatResponse."""
content = ""
tool_calls = []
for block in message.content:
if block.type == "text":
content += block.text
elif block.type == "tool_use":
tool_calls.append(
ToolCall(
id=block.id,
type="function",
function=ToolFunction(
name=block.name,
arguments=json.dumps(block.input),
),
)
)
chat_message = ChatMessage(
role="assistant",
content=content if content else None,
tool_calls=tool_calls if tool_calls else None,
)
usage = None
if message.usage:
usage = UsageStats(
prompt_tokens=message.usage.input_tokens,
completion_tokens=message.usage.output_tokens,
total_tokens=message.usage.input_tokens + message.usage.output_tokens,
)
return ChatResponse(
id=message.id,
choices=[
ChatResponseChoice(
index=0,
message=chat_message,
finish_reason=message.stop_reason,
)
],
usage=usage,
model=message.model,
)
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credits from Anthropic.
Note: Anthropic does not currently provide a public API endpoint
for checking account credits/balance. This information is only
available through the Anthropic Console web interface.
Returns:
None (credits API not available)
"""
# Anthropic doesn't provide a public credits API endpoint
# Users must check their balance at console.anthropic.com
return None
def clear_cache(self) -> None:
"""Clear model cache."""
self._models_cache = None
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as dictionaries.
Returns:
List of model dictionaries
"""
models = self.list_models()
return [
{
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
for model in models
]
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: Model identifier
Returns:
Model dictionary or None
"""
model = self.get_model(model_id)
if model:
return {
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
return None

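Putting the pieces together, a hedged sketch of driving this provider with native web search enabled; the API key and prompt are placeholders:

```python
# Sketch: Anthropic native web search through the provider (placeholder key).
from oai.providers.anthropic import AnthropicProvider
from oai.providers.base import ChatMessage

provider = AnthropicProvider(api_key="sk-ant-...")
response = provider.chat(
    model="claude-haiku",  # alias; resolves to claude-haiku-4-5-20251001
    messages=[ChatMessage(role="user", content="Summarize today's AI news.")],
    enable_web_search=True,
    web_search_config={"max_uses": 3},
)
print(response.choices[0].message.content)
```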
oai/providers/ollama.py (new file, 423 lines)
@@ -0,0 +1,423 @@
"""
Ollama provider for local AI model serving.
This provider connects to a local Ollama server for running models
locally without API keys or external dependencies.
"""
import json
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import requests
from oai.constants import OLLAMA_DEFAULT_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
UsageStats,
)
from oai.utils.logging import get_logger
logger = get_logger()
class OllamaProvider(AIProvider):
"""
Ollama local model provider.
Connects to a local Ollama server for running models locally.
No API key required.
"""
def __init__(
self,
api_key: str = "",
base_url: Optional[str] = None,
**kwargs: Any,
):
"""
Initialize Ollama provider.
Args:
api_key: Not used (Ollama doesn't require API keys)
base_url: Ollama server URL (default: http://localhost:11434)
**kwargs: Additional arguments (ignored)
"""
super().__init__(api_key or "", base_url)
self.base_url = base_url or OLLAMA_DEFAULT_URL
self._check_server_available()
def _check_server_available(self) -> bool:
"""
Check if Ollama server is accessible.
Returns:
True if server is accessible
"""
try:
response = requests.get(f"{self.base_url}/api/tags", timeout=2)
if response.ok:
logger.info(f"Ollama server accessible at {self.base_url}")
return True
else:
logger.warning(f"Ollama server returned status {response.status_code}")
return False
except requests.RequestException as e:
logger.warning(f"Ollama server not accessible at {self.base_url}: {e}")
return False
@property
def name(self) -> str:
"""Get provider name."""
return "Ollama"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=False, # Tool support varies by model
images=False, # Image support varies by model
online=True, # Web search via DuckDuckGo/Google
max_context=8192, # Varies by model
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
List models from local Ollama installation.
Args:
filter_text_only: Ignored for Ollama
Returns:
List of available models
"""
try:
response = requests.get(f"{self.base_url}/api/tags", timeout=5)
response.raise_for_status()
data = response.json()
models = []
for model_data in data.get("models", []):
models.append(self._parse_model(model_data))
logger.info(f"Found {len(models)} Ollama models")
return models
except requests.RequestException as e:
logger.error(f"Failed to list Ollama models: {e}")
return []
def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
"""
Parse Ollama model data into ModelInfo.
Args:
model_data: Raw model data from Ollama API
Returns:
ModelInfo object
"""
model_name = model_data.get("name", "unknown")
size_bytes = model_data.get("size", 0)
size_gb = size_bytes / (1024 ** 3) if size_bytes else 0
return ModelInfo(
id=model_name,
name=model_name,
description=f"Size: {size_gb:.1f}GB",
context_length=8192, # Default, varies by model
pricing={}, # Local models are free
supported_parameters=["stream", "temperature", "max_tokens"],
)
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None if not found
"""
models = self.list_models()
for model in models:
if model.id == model_id:
return model
return None
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send chat request to Ollama.
Args:
model: Model name
messages: Chat messages
stream: Whether to stream response
max_tokens: Maximum tokens (Ollama calls this num_predict)
temperature: Sampling temperature
tools: Not supported
tool_choice: Not supported
**kwargs: Additional parameters
Returns:
ChatResponse or Iterator[StreamChunk]
"""
# Convert messages to Ollama format
ollama_messages = []
for msg in messages:
ollama_messages.append({
"role": msg.role,
"content": msg.content or "",
})
# Build request payload
payload: Dict[str, Any] = {
"model": model,
"messages": ollama_messages,
"stream": stream,
}
# Add optional parameters
options = {}
if temperature is not None:
options["temperature"] = temperature
if max_tokens is not None:
options["num_predict"] = max_tokens
if options:
payload["options"] = options
logger.debug(f"Ollama request: model={model}, messages={len(ollama_messages)}")
try:
if stream:
return self._stream_chat(payload)
else:
return self._sync_chat(payload)
except requests.RequestException as e:
logger.error(f"Ollama request failed: {e}")
# Return error response
return ChatResponse(
id="error",
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(
role="assistant",
content=f"Error: {str(e)}",
),
finish_reason="error",
)
],
)
def _sync_chat(self, payload: Dict[str, Any]) -> ChatResponse:
"""
Send synchronous chat request.
Args:
payload: Request payload
Returns:
ChatResponse
"""
response = requests.post(
f"{self.base_url}/api/chat",
json=payload,
timeout=120,
)
response.raise_for_status()
data = response.json()
# Parse response
message_data = data.get("message", {})
content = message_data.get("content", "")
# Extract token usage if available
usage = None
if "prompt_eval_count" in data or "eval_count" in data:
usage = UsageStats(
prompt_tokens=data.get("prompt_eval_count", 0),
completion_tokens=data.get("eval_count", 0),
total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
total_cost_usd=0.0, # Local models are free
)
return ChatResponse(
id=str(time.time()),
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=content),
finish_reason="stop",
)
],
usage=usage,
model=data.get("model"),
)
def _stream_chat(self, payload: Dict[str, Any]) -> Iterator[StreamChunk]:
"""
Stream chat response from Ollama.
Args:
payload: Request payload
Yields:
StreamChunk objects
"""
response = requests.post(
f"{self.base_url}/api/chat",
json=payload,
stream=True,
timeout=120,
)
response.raise_for_status()
total_prompt_tokens = 0
total_completion_tokens = 0
for line in response.iter_lines():
if not line:
continue
try:
data = json.loads(line)
# Extract content delta
message_data = data.get("message", {})
content = message_data.get("content", "")
# Check if done
done = data.get("done", False)
finish_reason = "stop" if done else None
# Extract usage if available
usage = None
if done and ("prompt_eval_count" in data or "eval_count" in data):
total_prompt_tokens = data.get("prompt_eval_count", 0)
total_completion_tokens = data.get("eval_count", 0)
usage = UsageStats(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
total_cost_usd=0.0,
)
yield StreamChunk(
id=str(time.time()),
delta_content=content if content else None,
finish_reason=finish_reason,
usage=usage,
)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse Ollama stream chunk: {e}")
yield StreamChunk(
id="error",
error=f"Parse error: {e}",
)
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Async chat not implemented for Ollama.
Args:
model: Model name
messages: Chat messages
stream: Whether to stream
max_tokens: Max tokens
temperature: Temperature
tools: Tools (not supported)
tool_choice: Tool choice (not supported)
**kwargs: Additional args
Returns:
ChatResponse or AsyncIterator[StreamChunk]
Raises:
NotImplementedError: Async not implemented
"""
raise NotImplementedError("Async chat not implemented for Ollama provider")
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credits.
Returns:
None (Ollama is local and free)
"""
return None
def clear_cache(self) -> None:
"""Clear model cache (no-op for Ollama)."""
pass
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as dictionaries.
Returns:
List of model dictionaries
"""
models = self.list_models()
return [
{
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
for model in models
]
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: Model identifier
Returns:
Model dictionary or None
"""
model = self.get_model(model_id)
if model:
return {
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
return None

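And a matching sketch for the local provider; it assumes `ollama serve` is running and that a model such as `llama3` has been pulled (the model name is a placeholder):

```python
# Sketch: streaming from a local Ollama model (assumes the server is up).
from oai.providers.ollama import OllamaProvider
from oai.providers.base import ChatMessage

provider = OllamaProvider()  # no API key; defaults to http://localhost:11434
for chunk in provider.chat(
    model="llama3",  # placeholder; use any pulled model
    messages=[ChatMessage(role="user", content="Say hello")],
    stream=True,
):
    if chunk.delta_content:
        print(chunk.delta_content, end="", flush=True)
```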
oai/providers/openai.py (new file, 630 lines)
@@ -0,0 +1,630 @@
"""
OpenAI provider for GPT models.
This provider connects to OpenAI's API for accessing GPT-4, GPT-3.5, and other OpenAI models.
"""
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from openai import OpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from oai.constants import OPENAI_BASE_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
ToolCall,
ToolFunction,
UsageStats,
)
from oai.utils.logging import get_logger
logger = get_logger()
# Model aliases for convenience
MODEL_ALIASES = {
"gpt-4": "gpt-4-turbo",
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
"gpt-4o": "gpt-4o-2024-11-20",
"gpt-4o-mini": "gpt-4o-mini-2024-07-18",
"gpt-3.5": "gpt-3.5-turbo",
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
"o1": "o1-2024-12-17",
"o1-mini": "o1-mini-2024-09-12",
"o1-preview": "o1-preview-2024-09-12",
}
class OpenAIProvider(AIProvider):
"""
OpenAI API provider.
Provides access to GPT-4, GPT-3.5, o1, and other OpenAI models.
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: str = "oAI",
app_url: str = "",
**kwargs: Any,
):
"""
Initialize OpenAI provider.
Args:
api_key: OpenAI API key
base_url: Optional custom base URL
app_name: Application name (for headers)
app_url: Application URL (for headers)
**kwargs: Additional arguments
"""
super().__init__(api_key, base_url or OPENAI_BASE_URL)
self.client = OpenAI(api_key=api_key, base_url=self.base_url)
self.async_client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
self._models_cache: Optional[List[ModelInfo]] = None
@property
def name(self) -> str:
"""Get provider name."""
return "OpenAI"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=True,
images=True,
online=True, # Web search via DuckDuckGo/Google
max_context=128000,
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
List available OpenAI models.
Args:
filter_text_only: Whether to filter for text models only
Returns:
List of ModelInfo objects
"""
if self._models_cache:
return self._models_cache
try:
models_response = self.client.models.list()
models = []
for model in models_response.data:
# Filter for chat models
if "gpt" in model.id or "o1" in model.id:
models.append(self._parse_model(model))
# Sort by name
models.sort(key=lambda m: m.name)
self._models_cache = models
logger.info(f"Loaded {len(models)} OpenAI models")
return models
except Exception as e:
logger.error(f"Failed to list OpenAI models: {e}")
return self._get_fallback_models()
def _get_fallback_models(self) -> List[ModelInfo]:
"""
Get fallback list of common OpenAI models.
Returns:
List of common models
"""
return [
ModelInfo(
id="gpt-4o",
name="GPT-4o",
description="Most capable GPT-4 model",
context_length=128000,
pricing={"input": 5.0, "output": 15.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-4o-mini",
name="GPT-4o Mini",
description="Affordable and fast GPT-4 class model",
context_length=128000,
pricing={"input": 0.15, "output": 0.6},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-4-turbo",
name="GPT-4 Turbo",
description="GPT-4 Turbo with vision",
context_length=128000,
pricing={"input": 10.0, "output": 30.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-3.5-turbo",
name="GPT-3.5 Turbo",
description="Fast and affordable model",
context_length=16384,
pricing={"input": 0.5, "output": 1.5},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
),
ModelInfo(
id="o1",
name="o1",
description="Advanced reasoning model",
context_length=200000,
pricing={"input": 15.0, "output": 60.0},
supported_parameters=["max_tokens"],
),
ModelInfo(
id="o1-mini",
name="o1-mini",
description="Fast reasoning model",
context_length=128000,
pricing={"input": 3.0, "output": 12.0},
supported_parameters=["max_tokens"],
),
]
def _parse_model(self, model: Any) -> ModelInfo:
"""
Parse OpenAI model into ModelInfo.
Args:
model: OpenAI model object
Returns:
ModelInfo object
"""
model_id = model.id
# Determine context length
context_length = 8192 # Default
if "gpt-4o" in model_id or "gpt-4-turbo" in model_id:
context_length = 128000
elif "gpt-4" in model_id:
context_length = 8192
elif "gpt-3.5-turbo" in model_id:
context_length = 16384
elif "o1" in model_id:
context_length = 128000
# Determine pricing (approximate)
pricing = {}
if "gpt-4o-mini" in model_id:
pricing = {"input": 0.15, "output": 0.6}
elif "gpt-4o" in model_id:
pricing = {"input": 5.0, "output": 15.0}
elif "gpt-4-turbo" in model_id:
pricing = {"input": 10.0, "output": 30.0}
elif "gpt-4" in model_id:
pricing = {"input": 30.0, "output": 60.0}
elif "gpt-3.5" in model_id:
pricing = {"input": 0.5, "output": 1.5}
elif "o1" in model_id and "mini" not in model_id:
pricing = {"input": 15.0, "output": 60.0}
elif "o1-mini" in model_id:
pricing = {"input": 3.0, "output": 12.0}
return ModelInfo(
id=model_id,
name=model_id,
description="",
context_length=context_length,
pricing=pricing,
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
)
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None
"""
# Resolve alias
resolved_id = MODEL_ALIASES.get(model_id, model_id)
models = self.list_models()
for model in models:
if model.id == resolved_id or model.id == model_id:
return model
# Try to fetch directly
try:
model = self.client.models.retrieve(resolved_id)
return self._parse_model(model)
except Exception:
return None
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send chat completion request to OpenAI.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream response
max_tokens: Maximum tokens
temperature: Sampling temperature
tools: Tool definitions
tool_choice: Tool selection mode
**kwargs: Additional parameters
Returns:
ChatResponse or Iterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Convert messages to OpenAI format
openai_messages = []
for msg in messages:
message_dict = {"role": msg.role, "content": msg.content or ""}
if msg.tool_calls:
message_dict["tool_calls"] = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in msg.tool_calls
]
if msg.tool_call_id:
message_dict["tool_call_id"] = msg.tool_call_id
openai_messages.append(message_dict)
# Build request parameters
params: Dict[str, Any] = {
"model": model_id,
"messages": openai_messages,
"stream": stream,
}
# Add optional parameters
if max_tokens is not None:
params["max_tokens"] = max_tokens
if temperature is not None and "o1" not in model_id:
# o1 models don't support temperature
params["temperature"] = temperature
if tools:
params["tools"] = tools
if tool_choice:
params["tool_choice"] = tool_choice
logger.debug(f"OpenAI request: model={model_id}, messages={len(openai_messages)}")
try:
if stream:
return self._stream_chat(params)
else:
return self._sync_chat(params)
except Exception as e:
logger.error(f"OpenAI request failed: {e}")
return ChatResponse(
id="error",
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
finish_reason="error",
)
],
)
def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
"""
Send synchronous chat request.
Args:
params: Request parameters
Returns:
ChatResponse
"""
completion: ChatCompletion = self.client.chat.completions.create(**params)
# Convert to our format
choices = []
for choice in completion.choices:
# Convert tool calls if present
tool_calls = None
if choice.message.tool_calls:
tool_calls = [
ToolCall(
id=tc.id,
type=tc.type,
function=ToolFunction(
name=tc.function.name,
arguments=tc.function.arguments,
),
)
for tc in choice.message.tool_calls
]
choices.append(
ChatResponseChoice(
index=choice.index,
message=ChatMessage(
role=choice.message.role,
content=choice.message.content,
tool_calls=tool_calls,
),
finish_reason=choice.finish_reason,
)
)
# Convert usage
usage = None
if completion.usage:
usage = UsageStats(
prompt_tokens=completion.usage.prompt_tokens,
completion_tokens=completion.usage.completion_tokens,
total_tokens=completion.usage.total_tokens,
)
return ChatResponse(
id=completion.id,
choices=choices,
usage=usage,
model=completion.model,
created=completion.created,
)
def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
"""
Stream chat response from OpenAI.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
stream = self.client.chat.completions.create(**params)
for chunk in stream:
chunk_data: ChatCompletionChunk = chunk
if not chunk_data.choices:
continue
choice = chunk_data.choices[0]
delta = choice.delta
# Extract content
content = delta.content if delta.content else None
# Extract finish reason
finish_reason = choice.finish_reason
# Extract usage (usually in last chunk)
usage = None
if hasattr(chunk_data, "usage") and chunk_data.usage:
usage = UsageStats(
prompt_tokens=chunk_data.usage.prompt_tokens,
completion_tokens=chunk_data.usage.completion_tokens,
total_tokens=chunk_data.usage.total_tokens,
)
yield StreamChunk(
id=chunk_data.id,
delta_content=content,
finish_reason=finish_reason,
usage=usage,
)
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send async chat request to OpenAI.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream
max_tokens: Max tokens
temperature: Temperature
tools: Tool definitions
tool_choice: Tool choice
**kwargs: Additional args
Returns:
ChatResponse or AsyncIterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Convert messages
openai_messages = [msg.to_dict() for msg in messages]
# Build params
params: Dict[str, Any] = {
"model": model_id,
"messages": openai_messages,
"stream": stream,
}
if max_tokens:
params["max_tokens"] = max_tokens
if temperature is not None and "o1" not in model_id:
params["temperature"] = temperature
if tools:
params["tools"] = tools
if tool_choice:
params["tool_choice"] = tool_choice
if stream:
return self._stream_chat_async(params)
else:
completion = await self.async_client.chat.completions.create(**params)
# Convert to ChatResponse (similar to _sync_chat)
return self._convert_completion(completion)
async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
"""
Stream async chat response.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
stream = await self.async_client.chat.completions.create(**params)
async for chunk in stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
delta = choice.delta
yield StreamChunk(
id=chunk.id,
delta_content=delta.content,
finish_reason=choice.finish_reason,
)
def _convert_completion(self, completion: ChatCompletion) -> ChatResponse:
"""Helper to convert OpenAI completion to ChatResponse."""
choices = []
for choice in completion.choices:
tool_calls = None
if choice.message.tool_calls:
tool_calls = [
ToolCall(
id=tc.id,
type=tc.type,
function=ToolFunction(
name=tc.function.name,
arguments=tc.function.arguments,
),
)
for tc in choice.message.tool_calls
]
choices.append(
ChatResponseChoice(
index=choice.index,
message=ChatMessage(
role=choice.message.role,
content=choice.message.content,
tool_calls=tool_calls,
),
finish_reason=choice.finish_reason,
)
)
usage = None
if completion.usage:
usage = UsageStats(
prompt_tokens=completion.usage.prompt_tokens,
completion_tokens=completion.usage.completion_tokens,
total_tokens=completion.usage.total_tokens,
)
return ChatResponse(
id=completion.id,
choices=choices,
usage=usage,
model=completion.model,
created=completion.created,
)
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credits.
Returns:
None (OpenAI doesn't provide credit API)
"""
return None
def clear_cache(self) -> None:
"""Clear model cache."""
self._models_cache = None
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as dictionaries.
Returns:
List of model dictionaries
"""
models = self.list_models()
return [
{
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
for model in models
]
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: Model identifier
Returns:
Model dictionary or None
"""
model = self.get_model(model_id)
if model:
return {
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
return None
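For symmetry, a non-streaming sketch against this provider; the key is a placeholder and the alias resolves through MODEL_ALIASES above:

```python
# Sketch: non-streaming OpenAI request through the provider (placeholder key).
from oai.providers.openai import OpenAIProvider
from oai.providers.base import ChatMessage

provider = OpenAIProvider(api_key="sk-...")
resp = provider.chat(
    model="gpt-4o-mini",  # alias; resolves to gpt-4o-mini-2024-07-18
    messages=[ChatMessage(role="user", content="One-line summary of MCP?")],
    max_tokens=100,
)
print(resp.choices[0].message.content)
if resp.usage:
    print(resp.usage.total_tokens)
```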

oai/providers/openrouter.py
@@ -269,10 +269,17 @@ class OpenRouterProvider(AIProvider):
completion_tokens = usage_data.get("output_tokens", 0) or 0
# Get cost if available
# OpenRouter returns cost in different places:
# 1. As 'total_cost_usd' in usage object (rare)
# 2. As 'usage' at root level (common - this is the dollar amount)
total_cost = None
if hasattr(usage_data, "total_cost_usd"):
total_cost = getattr(usage_data, "total_cost_usd", None)
elif hasattr(usage_data, "usage"):
# OpenRouter puts cost as 'usage' field (dollar amount)
total_cost = getattr(usage_data, "usage", None)
elif isinstance(usage_data, dict):
-total_cost = usage_data.get("total_cost_usd")
total_cost = usage_data.get("total_cost_usd") or usage_data.get("usage")
return UsageStats(
prompt_tokens=prompt_tokens,

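The fallback order matters because OpenRouter reports the dollar cost under two different keys; a small sketch of the dict path above, with made-up numbers:

```python
# Sketch of the dict fallback above; payload shapes are illustrative.
usage_a = {"input_tokens": 10, "output_tokens": 5, "total_cost_usd": 0.0002}
usage_b = {"input_tokens": 10, "output_tokens": 5, "usage": 0.0002}  # cost under "usage"

for usage_data in (usage_a, usage_b):
    total_cost = usage_data.get("total_cost_usd") or usage_data.get("usage")
    print(total_cost)  # 0.0002 in both cases
```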
oai/providers/registry.py (new file, 60 lines)
@@ -0,0 +1,60 @@
"""
Provider registry for AI model providers.
This module maintains a central registry of all available AI providers,
allowing dynamic provider lookup and registration.
"""
from typing import Dict, List, Optional, Type
from oai.providers.base import AIProvider
# Global provider registry
PROVIDER_REGISTRY: Dict[str, Type[AIProvider]] = {}
def register_provider(name: str, provider_class: Type[AIProvider]) -> None:
"""
Register a provider class with the given name.
Args:
name: Provider identifier (e.g., "openrouter", "anthropic")
provider_class: The provider class to register
"""
PROVIDER_REGISTRY[name] = provider_class
def get_provider_class(name: str) -> Optional[Type[AIProvider]]:
"""
Get a provider class by name.
Args:
name: Provider identifier
Returns:
Provider class or None if not found
"""
return PROVIDER_REGISTRY.get(name)
def list_providers() -> List[str]:
"""
List all registered provider names.
Returns:
List of provider identifiers
"""
return list(PROVIDER_REGISTRY.keys())
def is_provider_registered(name: str) -> bool:
"""
Check if a provider is registered.
Args:
name: Provider identifier
Returns:
True if provider is registered
"""
return name in PROVIDER_REGISTRY

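Because the registry stores classes rather than instances, third-party code can register a provider before any configuration exists (a hypothetical sketch; `EchoProvider` is not part of the codebase):

```python
# Hypothetical sketch: registering a custom provider class.
from oai.providers.base import AIProvider
from oai.providers.registry import register_provider, is_provider_registered

class EchoProvider(AIProvider):
    """Toy placeholder; a real provider must implement the abstract API."""

register_provider("echo", EchoProvider)
assert is_provider_registered("echo")
```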
oai/tui/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""Textual TUI interface for oAI."""
from oai.tui.app import oAIChatApp
__all__ = ["oAIChatApp"]

oai/tui/app.py (new file, 1069 lines)

File diff suppressed because it is too large.

oai/tui/screens/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
"""TUI screens for oAI."""
from oai.tui.screens.commands_screen import CommandsScreen
from oai.tui.screens.config_screen import ConfigScreen
from oai.tui.screens.conversation_selector import ConversationSelectorScreen
from oai.tui.screens.credits_screen import CreditsScreen
from oai.tui.screens.dialogs import AlertDialog, ConfirmDialog, InputDialog
from oai.tui.screens.help_screen import HelpScreen
from oai.tui.screens.model_selector import ModelSelectorScreen
from oai.tui.screens.stats_screen import StatsScreen
__all__ = [
"AlertDialog",
"CommandsScreen",
"ConfirmDialog",
"ConfigScreen",
"ConversationSelectorScreen",
"CreditsScreen",
"InputDialog",
"HelpScreen",
"ModelSelectorScreen",
"StatsScreen",
]

oai/tui/screens/commands_screen.py (new file, 172 lines)
@@ -0,0 +1,172 @@
"""Commands reference screen for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Container, Vertical, VerticalScroll
from textual.screen import ModalScreen
from textual.widgets import Button, Static
class CommandsScreen(ModalScreen[None]):
"""Modal screen showing all available commands."""
DEFAULT_CSS = """
CommandsScreen {
align: center middle;
}
CommandsScreen > Container {
width: 90;
height: 40;
background: #1e1e1e;
border: solid #555555;
}
CommandsScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
CommandsScreen .content {
width: 100%;
height: 1fr;
background: #1e1e1e;
padding: 2;
color: #cccccc;
overflow-y: auto;
scrollbar-background: #1e1e1e;
scrollbar-color: #555555;
scrollbar-size: 1 1;
}
CommandsScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def compose(self) -> ComposeResult:
"""Compose the commands screen."""
with Container():
yield Static("[bold]Commands Reference[/]", classes="header")
with VerticalScroll(classes="content"):
yield Static(self._get_commands_text(), markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close-button", variant="primary")
def _get_commands_text(self) -> str:
"""Generate formatted commands text."""
return """[bold cyan]General Commands[/]
[green]/help[/] - Show help screen with keyboard shortcuts
[green]/commands[/] - Show this commands reference
[green]/model[/] - Open model selector (or press F2)
[green]/stats[/] - Show session statistics (or press Ctrl+S)
[green]/credits[/] - Check account credits (OpenRouter) or view console link
[green]/clear[/] - Clear chat display
[green]/reset[/] - Reset conversation history
[green]/retry[/] - Retry last prompt
[green]/paste[/] - Paste from clipboard
[bold cyan]Provider Commands[/]
[green]/provider[/] - Show current provider
[green]/provider openrouter[/] - Switch to OpenRouter
[green]/provider anthropic[/] - Switch to Anthropic (Claude)
[green]/provider openai[/] - Switch to OpenAI (ChatGPT)
[green]/provider ollama[/] - Switch to Ollama (local)
[bold cyan]Online Mode (Web Search)[/]
[green]/online[/] - Show online mode status
[green]/online on[/] - Enable web search
[green]/online off[/] - Disable web search
[dim]Search Providers:[/]
• [yellow]anthropic_native[/] - Anthropic's native search with citations ($0.01/search)
• [yellow]duckduckgo[/] - Free web scraping (default, works with all providers)
• [yellow]google[/] - Google Custom Search (requires API key)
[bold cyan]Configuration Commands[/]
[green]/config[/] - View all settings
[green]/config provider <name>[/] - Set default provider
[green]/config search_provider <provider>[/] - Set search provider (anthropic_native/duckduckgo/google)
[green]/config openrouter_api_key <key>[/] - Set OpenRouter API key
[green]/config anthropic_api_key <key>[/] - Set Anthropic API key
[green]/config openai_api_key <key>[/] - Set OpenAI API key
[green]/config ollama_base_url <url>[/] - Set Ollama server URL
[green]/config google_api_key <key>[/] - Set Google API key (for Google search)
[green]/config google_search_engine_id <id>[/] - Set Google Search Engine ID
[green]/config online on|off[/] - Set default online mode
[green]/config stream on|off[/] - Toggle streaming
[green]/config model <id>[/] - Set default model
[green]/config system <prompt>[/] - Set system prompt
[green]/config maxtoken <num>[/] - Set token limit
[bold cyan]Memory & Context[/]
[green]/memory on[/] - Enable conversation memory
[green]/memory off[/] - Disable memory (fresh context each message)
[bold cyan]Conversation Management[/]
[green]/save <name>[/] - Save current conversation
[green]/load <name>[/] - Load saved conversation
[green]/list[/] - List all saved conversations
[green]/delete <name>[/] - Delete a conversation
[green]/prev[/] - Show previous message
[green]/next[/] - Show next message
[bold cyan]Export Commands[/]
[green]/export md <file>[/] - Export conversation as Markdown
[green]/export json <file>[/] - Export as JSON
[green]/export html <file>[/] - Export as HTML
[bold cyan]MCP (Model Context Protocol)[/]
[green]/mcp on[/] - Enable MCP file access
[green]/mcp off[/] - Disable MCP
[green]/mcp status[/] - Show MCP status
[green]/mcp add <path>[/] - Add folder for file access
[green]/mcp add db <path>[/] - Add SQLite database
[green]/mcp remove <path>[/] - Remove folder/database
[green]/mcp list[/] - List allowed folders
[green]/mcp db list[/] - List added databases
[green]/mcp db <n>[/] - Switch to database mode
[green]/mcp files[/] - Switch to file mode
[green]/mcp write on[/] - Enable write mode (allows file modifications)
[green]/mcp write off[/] - Disable write mode
[bold cyan]System Prompt[/]
[green]/system <prompt>[/] - Set custom system prompt for session
[green]/config system <prompt>[/] - Set default system prompt
[bold cyan]Keyboard Shortcuts[/]
• [yellow]F1[/] - Help screen
• [yellow]F2[/] - Model selector
• [yellow]Ctrl+S[/] - Statistics
• [yellow]Ctrl+Q[/] - Quit
• [yellow]Ctrl+Y[/] - Copy latest reply in Markdown
• [yellow]Up/Down[/] - Command history
• [yellow]Tab[/] - Command completion
"""
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

163
oai/tui/screens/config_screen.py Normal file
View File

@@ -0,0 +1,163 @@
"""Configuration screen for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static
from oai.config.settings import Settings
class ConfigScreen(ModalScreen[None]):
"""Modal screen displaying configuration settings."""
DEFAULT_CSS = """
ConfigScreen {
align: center middle;
}
ConfigScreen > Container {
width: 70;
height: auto;
background: #1e1e1e;
border: solid #555555;
}
ConfigScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
ConfigScreen .content {
width: 100%;
height: auto;
background: #1e1e1e;
padding: 2;
color: #cccccc;
}
ConfigScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def __init__(self, settings: Settings, session=None):
super().__init__()
self.settings = settings
self.session = session
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static("[bold]Configuration[/]", classes="header")
with Vertical(classes="content"):
yield Static(self._get_config_text(), markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close", variant="primary")
def _get_config_text(self) -> str:
"""Generate the configuration text."""
from oai.constants import DEFAULT_SYSTEM_PROMPT
config_lines = ["[bold cyan]═══ CONFIGURATION ═══[/]\n"]
# Current Session Info
if self.session and self.session.client:
config_lines.append("[bold yellow]Current Session:[/]")
provider = self.session.client.provider_name
config_lines.append(f"[bold]Provider:[/] [green]{provider}[/]")
# Provider URL
if hasattr(self.session.client.provider, 'base_url'):
provider_url = self.session.client.provider.base_url
config_lines.append(f"[bold]Provider URL:[/] {provider_url}")
# API Key status
if provider == "ollama":
config_lines.append(f"[bold]API Key:[/] [dim]Not required (local)[/]")
else:
api_key = self.settings.get_provider_api_key(provider)
if api_key:
key_display = "***" + api_key[-4:] if len(api_key) > 4 else "***"
config_lines.append(f"[bold]API Key:[/] [green]{key_display}[/]")
else:
config_lines.append(f"[bold]API Key:[/] [red]Not set[/]")
# Current model
if self.session.selected_model:
model_name = self.session.selected_model.get("name", "Unknown")
config_lines.append(f"[bold]Current Model:[/] {model_name}")
config_lines.append("")
# Default Settings
config_lines.append("[bold yellow]Default Settings:[/]")
config_lines.append(f"[bold]Default Provider:[/] {self.settings.default_provider}")
# Mask helper
def mask_key(key):
if not key:
return "[red]Not set[/]"
return "***" + key[-4:] if len(key) > 4 else "***"
config_lines.append(f"[bold]OpenRouter Key:[/] {mask_key(self.settings.openrouter_api_key)}")
config_lines.append(f"[bold]Anthropic Key:[/] {mask_key(self.settings.anthropic_api_key)}")
config_lines.append(f"[bold]OpenAI Key:[/] {mask_key(self.settings.openai_api_key)}")
config_lines.append(f"[bold]Ollama URL:[/] {self.settings.ollama_base_url}")
config_lines.append(f"[bold]Default Model:[/] {self.settings.default_model or 'Not set'}")
# System prompt display
if self.settings.default_system_prompt is None:
system_prompt_display = f"[dim][default][/] {DEFAULT_SYSTEM_PROMPT[:40]}..."
elif self.settings.default_system_prompt == "":
system_prompt_display = "[dim][blank][/]"
else:
prompt = self.settings.default_system_prompt
system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt
config_lines.append(f"[bold]System Prompt:[/] {system_prompt_display}")
config_lines.append("")
# Web Search Configuration
config_lines.append("[bold yellow]Web Search Configuration:[/]")
config_lines.append(f"[bold]Search Provider:[/] {self.settings.search_provider}")
# Show API key status based on search provider
if self.settings.search_provider == "google":
config_lines.append(f"[bold]Google API Key:[/] {mask_key(self.settings.google_api_key)}")
config_lines.append(f"[bold]Search Engine ID:[/] {mask_key(self.settings.google_search_engine_id)}")
elif self.settings.search_provider == "duckduckgo":
config_lines.append("[dim] DuckDuckGo requires no configuration (free)[/]")
elif self.settings.search_provider == "anthropic_native":
config_lines.append("[dim] Uses Anthropic API ($0.01 per search)[/]")
config_lines.append("")
# Other settings
config_lines.append("[bold]Streaming:[/] " + ("on" if self.settings.stream_enabled else "off"))
config_lines.append(f"[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}")
config_lines.append(f"[bold]Max Tokens:[/] {self.settings.max_tokens}")
config_lines.append(f"[bold]Default Online:[/] " + ("on" if self.settings.default_online_mode else "off"))
config_lines.append(f"[bold]Log Level:[/] {self.settings.log_level}")
config_lines.append("\n[dim]Use /config [setting] [value] to modify settings[/]")
return "\n".join(config_lines)
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

205
oai/tui/screens/conversation_selector.py Normal file
View File

@@ -0,0 +1,205 @@
"""Conversation selector screen for oAI TUI."""
from typing import List, Optional
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, DataTable, Input, Static
class ConversationSelectorScreen(ModalScreen[Optional[dict]]):
"""Modal screen for selecting a saved conversation."""
DEFAULT_CSS = """
ConversationSelectorScreen {
align: center middle;
}
ConversationSelectorScreen > Container {
width: 80%;
height: 70%;
background: #1e1e1e;
border: solid #555555;
layout: vertical;
}
ConversationSelectorScreen .header {
height: 3;
width: 100%;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
content-align: center middle;
}
ConversationSelectorScreen .search-input {
height: 3;
width: 100%;
background: #2a2a2a;
border: solid #555555;
margin: 0 0 1 0;
}
ConversationSelectorScreen .search-input:focus {
border: solid #888888;
}
ConversationSelectorScreen DataTable {
height: 1fr;
width: 100%;
background: #1e1e1e;
border: solid #555555;
}
ConversationSelectorScreen .footer {
height: 5;
width: 100%;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
ConversationSelectorScreen Button {
margin: 0 1;
}
"""
def __init__(self, conversations: List[dict]):
super().__init__()
self.all_conversations = conversations
self.filtered_conversations = conversations
self.selected_conversation: Optional[dict] = None
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static(
f"[bold]Load Conversation[/] [dim]({len(self.all_conversations)} saved)[/]",
classes="header"
)
yield Input(placeholder="Search conversations...", id="search-input", classes="search-input")
yield DataTable(id="conv-table", cursor_type="row", show_header=True, zebra_stripes=True)
with Vertical(classes="footer"):
yield Button("Load", id="load", variant="success")
yield Button("Cancel", id="cancel", variant="error")
def on_mount(self) -> None:
"""Initialize the table when mounted."""
table = self.query_one("#conv-table", DataTable)
# Add columns
table.add_column("#", width=5)
table.add_column("Name", width=40)
table.add_column("Messages", width=12)
table.add_column("Last Saved", width=20)
# Populate table
self._populate_table()
# Focus table if list is small (fits on screen), otherwise focus search
if len(self.all_conversations) <= 10:
table.focus()
else:
search_input = self.query_one("#search-input", Input)
search_input.focus()
def _populate_table(self) -> None:
"""Populate the table with conversations."""
table = self.query_one("#conv-table", DataTable)
table.clear()
for idx, conv in enumerate(self.filtered_conversations, 1):
name = conv.get("name", "Unknown")
message_count = str(conv.get("message_count", 0))
last_saved = conv.get("last_saved", "Unknown")
# Format timestamp if it's a full datetime
if "T" in last_saved or len(last_saved) > 20:
try:
# Truncate to just date and time
last_saved = last_saved[:19].replace("T", " ")
except Exception:
pass
table.add_row(
str(idx),
name,
message_count,
last_saved,
key=str(idx)
)
def on_input_changed(self, event: Input.Changed) -> None:
"""Filter conversations based on search input."""
if event.input.id != "search-input":
return
search_term = event.value.lower()
if not search_term:
self.filtered_conversations = self.all_conversations
else:
self.filtered_conversations = [
c for c in self.all_conversations
if search_term in c.get("name", "").lower()
]
self._populate_table()
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle row selection (click)."""
try:
row_index = int(event.row_key.value) - 1
if 0 <= row_index < len(self.filtered_conversations):
self.selected_conversation = self.filtered_conversations[row_index]
except (ValueError, IndexError):
pass
def on_data_table_row_highlighted(self, event) -> None:
"""Handle row highlight (arrow key navigation)."""
try:
table = self.query_one("#conv-table", DataTable)
if table.cursor_row is not None:
row_data = table.get_row_at(table.cursor_row)
if row_data:
row_index = int(row_data[0]) - 1
if 0 <= row_index < len(self.filtered_conversations):
self.selected_conversation = self.filtered_conversations[row_index]
except (ValueError, IndexError, AttributeError):
pass
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
if event.button.id == "load":
if self.selected_conversation:
self.dismiss(self.selected_conversation)
else:
self.dismiss(None)
else:
self.dismiss(None)
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key == "escape":
self.dismiss(None)
elif event.key == "enter":
# If in search input, move to table
search_input = self.query_one("#search-input", Input)
if search_input.has_focus:
table = self.query_one("#conv-table", DataTable)
table.focus()
# If in table, select current row
else:
table = self.query_one("#conv-table", DataTable)
if table.cursor_row is not None:
try:
row_data = table.get_row_at(table.cursor_row)
if row_data:
row_index = int(row_data[0]) - 1
if 0 <= row_index < len(self.filtered_conversations):
selected = self.filtered_conversations[row_index]
self.dismiss(selected)
except (ValueError, IndexError, AttributeError):
if self.selected_conversation:
self.dismiss(self.selected_conversation)

158
oai/tui/screens/credits_screen.py Normal file
View File

@@ -0,0 +1,158 @@
"""Credits screen for oAI TUI."""
from typing import Optional, Dict, Any
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static
from oai.core.client import AIClient
class CreditsScreen(ModalScreen[None]):
"""Modal screen displaying account credits."""
DEFAULT_CSS = """
CreditsScreen {
align: center middle;
}
CreditsScreen > Container {
width: 60;
height: auto;
background: #1e1e1e;
border: solid #555555;
}
CreditsScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
CreditsScreen .content {
width: 100%;
height: auto;
background: #1e1e1e;
padding: 2;
color: #cccccc;
}
CreditsScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def __init__(self, client: AIClient):
super().__init__()
self.client = client
self.credits_data: Optional[Dict[str, Any]] = None
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static("[bold]Account Credits[/]", classes="header")
with Vertical(classes="content"):
yield Static("[dim]Loading...[/]", id="credits-content", markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close", variant="primary")
def on_mount(self) -> None:
"""Fetch credits when mounted."""
self.fetch_credits()
def fetch_credits(self) -> None:
"""Fetch and display credits information."""
try:
self.credits_data = self.client.provider.get_credits()
content = self.query_one("#credits-content", Static)
content.update(self._get_credits_text())
except Exception as e:
content = self.query_one("#credits-content", Static)
content.update(f"[red]Error fetching credits:[/]\n{str(e)}")
def _get_credits_text(self) -> str:
"""Generate the credits text."""
if not self.credits_data:
# Provider-specific message when credits aren't available
if self.client.provider_name == "anthropic":
return """[yellow]Credit information not available via API[/]
Anthropic does not provide a public API endpoint for checking
account credits.
To view your account balance and usage:
[cyan]→ Visit console.anthropic.com[/]
[cyan]→ Navigate to Settings → Billing[/]
"""
elif self.client.provider_name == "openai":
return """[yellow]Credit information not available via API[/]
OpenAI does not provide credit balance through their API.
To view your account usage:
[cyan]→ Visit platform.openai.com[/]
[cyan]→ Navigate to Usage[/]
"""
elif self.client.provider_name == "ollama":
return """[yellow]Credit information not applicable[/]
Ollama runs locally on your machine and does not use credits.
"""
else:
return "[yellow]No credit information available[/]"
provider_name = self.client.provider_name.upper()
total = self.credits_data.get("total_credits")
used = self.credits_data.get("used_credits")
remaining = self.credits_data.get("credits_left", 0)
# Determine color based on absolute remaining amount
if remaining > 10:
remaining_color = "green"
elif remaining > 2:
remaining_color = "yellow"
else:
remaining_color = "red"
lines = [f"[bold cyan]═══ {provider_name} CREDITS ═══[/]\n"]
# If we have total/used info (OpenRouter)
if total is not None and used is not None:
percent_used = (used / total) * 100 if total > 0 else 0
percent_remaining = (remaining / total) * 100 if total > 0 else 0
lines.append(f"[bold]Total Credits:[/] ${total:.2f}")
lines.append(f"[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]")
lines.append(f"[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]")
else:
# Anthropic - only shows balance
lines.append(f"[bold]Current Balance:[/] [{remaining_color}]${remaining:.2f}[/]")
lines.append("")
# Provider-specific footer
if self.client.provider_name == "openrouter":
lines.append("[dim]Visit openrouter.ai to add more credits[/]")
elif self.client.provider_name == "anthropic":
lines.append("[dim]Visit console.anthropic.com to manage billing[/]")
return "\n".join(lines)
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()
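For reference, `_get_credits_text` above only reads three keys from `provider.get_credits()`; an illustrative (made-up) OpenRouter-style payload looks like:

```python
# Illustrative values only; the keys mirror what _get_credits_text reads.
credits = {
    "total_credits": 25.00,   # total purchased (OpenRouter)
    "used_credits": 4.37,     # consumed so far (OpenRouter)
    "credits_left": 20.63,    # remaining; drives the green/yellow/red color
}
```

Providers that report only a balance (Anthropic) leave `total_credits` and `used_credits` as `None`, which routes to the balance-only branch.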

236
oai/tui/screens/dialogs.py Normal file
View File

@@ -0,0 +1,236 @@
"""Modal dialog screens for oAI TUI."""
from typing import Callable, Optional
from textual.app import ComposeResult
from textual.containers import Container, Horizontal, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Static
class ConfirmDialog(ModalScreen[bool]):
"""A confirmation dialog with Yes/No buttons."""
DEFAULT_CSS = """
ConfirmDialog {
align: center middle;
}
ConfirmDialog > Container {
width: 60;
height: auto;
background: #2d2d2d;
border: solid #555555;
padding: 2;
}
ConfirmDialog Label {
width: 100%;
content-align: center middle;
margin-bottom: 2;
color: #cccccc;
}
ConfirmDialog Horizontal {
width: 100%;
height: auto;
align: center middle;
}
ConfirmDialog Button {
margin: 0 1;
}
"""
def __init__(
self,
message: str,
title: str = "Confirm",
yes_label: str = "Yes",
no_label: str = "No",
):
super().__init__()
self.message = message
self.title = title
self.yes_label = yes_label
self.no_label = no_label
def compose(self) -> ComposeResult:
"""Compose the dialog."""
with Container():
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
yield Label(self.message)
with Horizontal():
yield Button(self.yes_label, id="yes", variant="success")
yield Button(self.no_label, id="no", variant="error")
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
if event.button.id == "yes":
self.dismiss(True)
else:
self.dismiss(False)
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key == "escape":
self.dismiss(False)
elif event.key == "enter":
self.dismiss(True)
class InputDialog(ModalScreen[Optional[str]]):
"""An input dialog for text entry."""
DEFAULT_CSS = """
InputDialog {
align: center middle;
}
InputDialog > Container {
width: 70;
height: auto;
background: #2d2d2d;
border: solid #555555;
padding: 2;
}
InputDialog Label {
width: 100%;
margin-bottom: 1;
color: #cccccc;
}
InputDialog Input {
width: 100%;
margin-bottom: 2;
background: #3a3a3a;
border: solid #555555;
}
InputDialog Input:focus {
border: solid #888888;
}
InputDialog Horizontal {
width: 100%;
height: auto;
align: center middle;
}
InputDialog Button {
margin: 0 1;
}
"""
def __init__(
self,
message: str,
title: str = "Input",
default: str = "",
placeholder: str = "",
):
super().__init__()
self.message = message
self.title = title
self.default = default
self.placeholder = placeholder
def compose(self) -> ComposeResult:
"""Compose the dialog."""
with Container():
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
yield Label(self.message)
yield Input(
value=self.default,
placeholder=self.placeholder,
id="input-field"
)
with Horizontal():
yield Button("OK", id="ok", variant="primary")
yield Button("Cancel", id="cancel")
def on_mount(self) -> None:
"""Focus the input field when mounted."""
input_field = self.query_one("#input-field", Input)
input_field.focus()
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
if event.button.id == "ok":
input_field = self.query_one("#input-field", Input)
self.dismiss(input_field.value)
else:
self.dismiss(None)
def on_input_submitted(self, event: Input.Submitted) -> None:
"""Handle Enter key in input field."""
self.dismiss(event.value)
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key == "escape":
self.dismiss(None)
class AlertDialog(ModalScreen[None]):
"""A simple alert/message dialog."""
DEFAULT_CSS = """
AlertDialog {
align: center middle;
}
AlertDialog > Container {
width: 60;
height: auto;
background: #2d2d2d;
border: solid #555555;
padding: 2;
}
AlertDialog Label {
width: 100%;
content-align: center middle;
margin-bottom: 2;
color: #cccccc;
}
AlertDialog Horizontal {
width: 100%;
height: auto;
align: center middle;
}
"""
def __init__(self, message: str, title: str = "Alert", variant: str = "default"):
super().__init__()
self.message = message
self.title = title
self.variant = variant
def compose(self) -> ComposeResult:
"""Compose the dialog."""
# Choose color based on variant (using design system)
color = "$primary"
if self.variant == "error":
color = "$error"
elif self.variant == "success":
color = "$success"
elif self.variant == "warning":
color = "$warning"
with Container():
yield Static(f"[bold {color}]{self.title}[/]", classes="dialog-title")
yield Label(self.message)
with Horizontal():
yield Button("OK", id="ok", variant="primary")
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()
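Because these dialogs are ordinary Textual `ModalScreen`s, a host app opens them with `push_screen` and receives the dismiss value in a callback. A minimal sketch (the app class and actions are illustrative, not part of this diff):

```python
from textual.app import App

class SketchApp(App):  # hypothetical host app
    def action_reset(self) -> None:
        def handle(confirmed: bool | None) -> None:
            if confirmed:
                self.notify("Conversation reset")

        # ConfirmDialog dismisses with True/False; push_screen routes it here.
        self.push_screen(ConfirmDialog("Reset conversation history?"), handle)

    def action_save(self) -> None:
        def handle(name: str | None) -> None:
            if name:
                self.notify(f"Saving as {name!r}")

        self.push_screen(InputDialog("Conversation name:", title="Save"), handle)
```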

140
oai/tui/screens/help_screen.py Normal file
View File

@@ -0,0 +1,140 @@
"""Help screen for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static
class HelpScreen(ModalScreen[None]):
"""Modal screen displaying help and commands."""
DEFAULT_CSS = """
HelpScreen {
align: center middle;
}
HelpScreen > Container {
width: 90%;
height: 85%;
background: #1e1e1e;
border: solid #555555;
}
HelpScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
HelpScreen .content {
height: 1fr;
background: #1e1e1e;
padding: 2;
overflow-y: auto;
color: #cccccc;
}
HelpScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static("[bold]oAI Help & Commands[/]", classes="header")
with Vertical(classes="content"):
yield Static(self._get_help_text(), markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close", variant="primary")
def _get_help_text(self) -> str:
"""Generate the help text."""
return """
[bold cyan]═══ KEYBOARD SHORTCUTS ═══[/]
[bold]F1[/] Show this help (Ctrl+H may not work)
[bold]F2[/] Open model selector (Ctrl+M may not work)
[bold]F3[/] Copy last AI response to clipboard
[bold]Ctrl+S[/] Show session statistics
[bold]Ctrl+L[/] Clear chat display
[bold]Ctrl+P[/] Show previous message
[bold]Ctrl+N[/] Show next message
[bold]Ctrl+Y[/] Copy last AI response (alternative to F3)
[bold]Ctrl+Q[/] Quit application
[bold]Up/Down[/] Navigate input history
[bold]ESC[/] Close dialogs
[dim]Note: Some Ctrl keys may be captured by your terminal[/]
[bold cyan]═══ SLASH COMMANDS ═══[/]
[bold yellow]Session Control:[/]
/reset Clear conversation history (with confirmation)
/clear Clear the chat display
/memory on/off Toggle conversation memory
/online on/off Toggle online search mode
/exit, /quit, /bye Exit the application
[bold yellow]Model & Configuration:[/]
/model [search] Open model selector with optional search
/config View configuration settings
/config api Set API key (prompts for input)
/config stream on Enable streaming responses
/system [prompt] Set session system prompt
/maxtoken [n] Set session token limit
[bold yellow]Conversation Management:[/]
/save [name] Save current conversation
/load [name] Load saved conversation (shows picker if no name)
/list List all saved conversations
/delete <name> Delete a saved conversation
[bold yellow]Export:[/]
/export md [file] Export as Markdown
/export json [file] Export as JSON
/export html [file] Export as HTML
[bold yellow]History Navigation:[/]
/prev Show previous message in history
/next Show next message in history
[bold yellow]MCP (Model Context Protocol):[/]
/mcp on Enable MCP file access
/mcp off Disable MCP
/mcp status Show MCP status
/mcp add <path> Add folder for file access
/mcp list List registered folders
/mcp write Toggle write permissions
[bold yellow]Information & Utilities:[/]
/help Show this help screen
/stats Show session statistics
/credits Check account credits
/retry Retry last prompt
/paste Paste from clipboard and send
[bold cyan]═══ TIPS ═══[/]
• Type [bold]/[/] to see command suggestions with [bold]Tab[/] to autocomplete
• Use [bold]Up/Down arrows[/] to navigate your input history
• Type [bold]//[/] at the start to escape commands (e.g., //help sends /help as a literal message)
• All messages support [bold]Markdown formatting[/] with syntax highlighting
• Responses stream in real-time for better interactivity
• Enable MCP to let AI access your local files and databases
• Use [bold]F1[/] or [bold]F2[/] if Ctrl shortcuts don't work in your terminal
"""
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

254
oai/tui/screens/model_selector.py Normal file
View File

@@ -0,0 +1,254 @@
"""Model selector screen for oAI TUI."""
from typing import List, Optional
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, DataTable, Input, Label, Static
class ModelSelectorScreen(ModalScreen[Optional[dict]]):
"""Modal screen for selecting an AI model."""
DEFAULT_CSS = """
ModelSelectorScreen {
align: center middle;
}
ModelSelectorScreen > Container {
width: 90%;
height: 85%;
background: #1e1e1e;
border: solid #555555;
layout: vertical;
}
ModelSelectorScreen .header {
height: 3;
width: 100%;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
content-align: center middle;
}
ModelSelectorScreen .search-input {
height: 3;
width: 100%;
background: #2a2a2a;
border: solid #555555;
margin: 0 0 1 0;
}
ModelSelectorScreen .search-input:focus {
border: solid #888888;
}
ModelSelectorScreen DataTable {
height: 1fr;
width: 100%;
background: #1e1e1e;
border: solid #555555;
}
ModelSelectorScreen .footer {
height: 5;
width: 100%;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
ModelSelectorScreen Button {
margin: 0 1;
}
"""
def __init__(self, models: List[dict], current_model: Optional[str] = None):
super().__init__()
self.all_models = models
self.filtered_models = models
self.current_model = current_model
self.selected_model: Optional[dict] = None
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static(
f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
classes="header"
)
yield Input(placeholder="Search to filter models...", id="search-input", classes="search-input")
yield DataTable(id="model-table", cursor_type="row", show_header=True, zebra_stripes=True)
with Vertical(classes="footer"):
yield Button("Select", id="select", variant="success")
yield Button("Cancel", id="cancel", variant="error")
def on_mount(self) -> None:
"""Initialize the table when mounted."""
table = self.query_one("#model-table", DataTable)
# Add columns
table.add_column("#", width=5)
table.add_column("Model ID", width=35)
table.add_column("Name", width=30)
table.add_column("Context", width=10)
table.add_column("Price", width=12)
table.add_column("Img", width=4)
table.add_column("Tools", width=6)
table.add_column("Online", width=7)
# Populate table
self._populate_table()
# Focus table if list is small (fits on screen), otherwise focus search
if len(self.filtered_models) <= 20:
table.focus()
else:
search_input = self.query_one("#search-input", Input)
search_input.focus()
def _populate_table(self) -> None:
"""Populate the table with models."""
table = self.query_one("#model-table", DataTable)
table.clear()
rows_added = 0
for idx, model in enumerate(self.filtered_models, 1):
try:
model_id = model.get("id", "")
name = model.get("name", "")
context = str(model.get("context_length", "N/A"))
# Format pricing
pricing = model.get("pricing", {})
prompt_price = pricing.get("prompt", "0")
completion_price = pricing.get("completion", "0")
# Convert to numbers and format
try:
prompt = float(prompt_price) * 1000000 # Convert to per 1M tokens
completion = float(completion_price) * 1000000
if prompt == 0 and completion == 0:
price = "Free"
else:
price = f"${prompt:.2f}/${completion:.2f}"
except (TypeError, ValueError):
price = "N/A"
# Check capabilities
architecture = model.get("architecture", {})
modality = architecture.get("modality", "")
supported_params = model.get("supported_parameters", [])
# Vision support: check if modality contains "image"
supports_vision = "image" in modality
# Tool support: check if "tools" or "tool_choice" in supported_parameters
supports_tools = "tools" in supported_params or "tool_choice" in supported_params
# Online support: check if model can use :online suffix (most models can)
# Models that already have :online in their ID support it
supports_online = ":online" in model_id or model_id not in ["openrouter/free"]
# Format capability indicators
img_indicator = "" if supports_vision else "-"
tools_indicator = "" if supports_tools else "-"
web_indicator = "" if supports_online else "-"
# Add row
table.add_row(
str(idx),
model_id,
name,
context,
price,
img_indicator,
tools_indicator,
web_indicator,
key=str(idx)
)
rows_added += 1
except Exception:
# Silently skip rows that fail to add
pass
def on_input_changed(self, event: Input.Changed) -> None:
"""Filter models based on search input."""
if event.input.id != "search-input":
return
search_term = event.value.lower()
if not search_term:
self.filtered_models = self.all_models
else:
self.filtered_models = [
m for m in self.all_models
if search_term in m.get("id", "").lower()
or search_term in m.get("name", "").lower()
]
self._populate_table()
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle row selection (click or arrow navigation)."""
try:
row_index = int(event.row_key.value) - 1
if 0 <= row_index < len(self.filtered_models):
self.selected_model = self.filtered_models[row_index]
except (ValueError, IndexError):
pass
def on_data_table_row_highlighted(self, event) -> None:
"""Handle row highlight (arrow key navigation)."""
try:
table = self.query_one("#model-table", DataTable)
if table.cursor_row is not None:
row_data = table.get_row_at(table.cursor_row)
if row_data:
row_index = int(row_data[0]) - 1
if 0 <= row_index < len(self.filtered_models):
self.selected_model = self.filtered_models[row_index]
except (ValueError, IndexError, AttributeError):
pass
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
if event.button.id == "select":
if self.selected_model:
self.dismiss(self.selected_model)
else:
# No selection, dismiss without result
self.dismiss(None)
else:
self.dismiss(None)
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key == "escape":
self.dismiss(None)
elif event.key == "enter":
# If in search input, move to table
search_input = self.query_one("#search-input", Input)
if search_input.has_focus:
table = self.query_one("#model-table", DataTable)
table.focus()
# If in table or anywhere else, select current row
else:
table = self.query_one("#model-table", DataTable)
# Get the currently highlighted row
if table.cursor_row is not None:
try:
row_data = table.get_row_at(table.cursor_row)
if row_data:
row_index = int(row_data[0]) - 1
if 0 <= row_index < len(self.filtered_models):
selected = self.filtered_models[row_index]
self.dismiss(selected)
except (ValueError, IndexError, AttributeError):
# Fall back to previously selected model
if self.selected_model:
self.dismiss(self.selected_model)
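Selection is delivered through the screen's dismiss value; a sketch of a hypothetical caller (attribute names such as `self.models` are assumptions):

```python
# Hypothetical app method; ModelSelectorScreen dismisses with the chosen
# model dict or None, which push_screen passes to the callback.
def open_model_selector(self) -> None:
    def apply(model: dict | None) -> None:
        if model is not None:
            self.selected_model = model
            self.notify(f"Model set to {model.get('name') or model.get('id')}")

    self.push_screen(ModelSelectorScreen(self.models), apply)
```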

129
oai/tui/screens/stats_screen.py Normal file
View File

@@ -0,0 +1,129 @@
"""Statistics screen for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static
from oai.core.session import ChatSession
class StatsScreen(ModalScreen[None]):
"""Modal screen displaying session statistics."""
DEFAULT_CSS = """
StatsScreen {
align: center middle;
}
StatsScreen > Container {
width: 70;
height: auto;
background: #1e1e1e;
border: solid #555555;
}
StatsScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
StatsScreen .content {
width: 100%;
height: auto;
background: #1e1e1e;
padding: 2;
color: #cccccc;
}
StatsScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def __init__(self, session: ChatSession):
super().__init__()
self.session = session
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static("[bold]Session Statistics[/]", classes="header")
with Vertical(classes="content"):
yield Static(self._get_stats_text(), markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close", variant="primary")
def _get_stats_text(self) -> str:
"""Generate the statistics text."""
stats = self.session.stats
# Calculate averages
avg_input = stats.total_input_tokens // stats.message_count if stats.message_count > 0 else 0
avg_output = stats.total_output_tokens // stats.message_count if stats.message_count > 0 else 0
avg_cost = stats.total_cost / stats.message_count if stats.message_count > 0 else 0
# Get model info
model_name = "None"
model_context = "N/A"
if self.session.selected_model:
model_name = self.session.selected_model.get("name", "Unknown")
model_context = str(self.session.selected_model.get("context_length", "N/A"))
# MCP status
mcp_status = "Disabled"
if self.session.mcp_manager and self.session.mcp_manager.enabled:
mode = self.session.mcp_manager.mode
if mode == "files":
write = " (Write)" if self.session.mcp_manager.write_enabled else ""
mcp_status = f"Enabled - Files{write}"
elif mode == "database":
db_idx = self.session.mcp_manager.selected_db_index
if db_idx is not None:
db_name = self.session.mcp_manager.databases[db_idx]["name"]
mcp_status = f"Enabled - Database ({db_name})"
return f"""
[bold cyan]═══ SESSION INFO ═══[/]
[bold]Messages:[/] {stats.message_count}
[bold]Current Model:[/] {model_name}
[bold]Context Length:[/] {model_context}
[bold]Memory:[/] {"Enabled" if self.session.memory_enabled else "Disabled"}
[bold]Online Mode:[/] {"Enabled" if self.session.online_enabled else "Disabled"}
[bold]MCP:[/] {mcp_status}
[bold cyan]═══ TOKEN USAGE ═══[/]
[bold]Input Tokens:[/] {stats.total_input_tokens:,}
[bold]Output Tokens:[/] {stats.total_output_tokens:,}
[bold]Total Tokens:[/] {stats.total_tokens:,}
[bold]Avg Input/Msg:[/] {avg_input:,}
[bold]Avg Output/Msg:[/] {avg_output:,}
[bold cyan]═══ COSTS ═══[/]
[bold]Total Cost:[/] ${stats.total_cost:.6f}
[bold]Avg Cost/Msg:[/] ${avg_cost:.6f}
[bold cyan]═══ HISTORY ═══[/]
[bold]History Size:[/] {len(self.session.history)} entries
[bold]Current Index:[/] {self.session.current_index + 1 if self.session.history else 0}
[bold]Memory Start:[/] {self.session.memory_start_index + 1}
"""
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

174
oai/tui/styles.tcss Normal file
View File

@@ -0,0 +1,174 @@
/* Textual CSS for oAI TUI - Using Textual Design System */
Screen {
background: $background;
overflow: hidden;
}
Header {
dock: top;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 1;
border-bottom: solid #555555;
}
ChatDisplay {
background: $background;
border: none;
padding: 1 0 1 1; /* top right bottom left - no right padding for scrollbar */
scrollbar-background: $background;
scrollbar-color: $primary;
scrollbar-size: 1 1;
scrollbar-gutter: stable; /* Reserve space for scrollbar */
overflow-y: auto;
}
UserMessageWidget {
margin: 0 0 1 0;
padding: 1;
background: $surface;
border-left: thick $success;
height: auto;
}
SystemMessageWidget {
margin: 0 0 1 0;
padding: 1;
background: #2a2a2a;
border-left: thick #888888;
height: auto;
color: #cccccc;
}
AssistantMessageWidget {
margin: 0 0 1 0;
padding: 1;
background: $background;
border-left: thick $accent;
height: auto;
}
#assistant-label {
margin-bottom: 1;
color: #cccccc;
}
#assistant-content {
height: auto;
max-height: 100%;
color: #cccccc;
link-color: #888888;
link-style: none;
border: none;
scrollbar-background: transparent;
scrollbar-color: #555555;
}
InputBar {
dock: bottom;
height: auto;
background: #2d2d2d;
align: center middle;
border-top: solid #555555;
padding: 1;
}
#input-prefix {
width: auto;
padding: 0 1;
content-align: center middle;
color: #888888;
}
#input-prefix.prefix-hidden {
display: none;
}
#chat-input {
width: 85%;
height: 5;
min-height: 5;
background: #3a3a3a;
border: none;
padding: 1 2;
color: #ffffff;
content-align: left top;
}
#chat-input:focus {
background: #404040;
}
#command-dropdown {
display: none;
dock: bottom;
offset-y: -5;
offset-x: 7.5%;
height: auto;
max-height: 12;
width: 85%;
background: #2d2d2d;
border: solid #555555;
padding: 0;
layer: overlay;
}
#command-dropdown.visible {
display: block;
}
#command-dropdown #command-list {
background: #2d2d2d;
border: none;
scrollbar-background: #2d2d2d;
scrollbar-color: #555555;
}
Footer {
dock: bottom;
height: auto;
background: #252525;
color: #888888;
padding: 0 1;
}
/* Button styles */
Button {
height: 3;
min-width: 10;
background: #3a3a3a;
color: #cccccc;
border: none;
}
Button:hover {
background: #4a4a4a;
}
Button:focus {
background: #505050;
}
Button.-primary {
background: #3a3a3a;
}
Button.-success {
background: #2d5016;
color: #90ee90;
}
Button.-success:hover {
background: #3a6b1e;
}
Button.-error {
background: #5a1a1a;
color: #ff6b6b;
}
Button.-error:hover {
background: #6e2222;
}

17
oai/tui/widgets/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
"""TUI widgets for oAI."""
from oai.tui.widgets.chat_display import ChatDisplay
from oai.tui.widgets.footer import Footer
from oai.tui.widgets.header import Header
from oai.tui.widgets.input_bar import InputBar
from oai.tui.widgets.message import AssistantMessageWidget, SystemMessageWidget, UserMessageWidget
__all__ = [
"ChatDisplay",
"Footer",
"Header",
"InputBar",
"UserMessageWidget",
"SystemMessageWidget",
"AssistantMessageWidget",
]

21
oai/tui/widgets/chat_display.py Normal file
View File

@@ -0,0 +1,21 @@
"""Chat display widget for oAI TUI."""
from textual.containers import ScrollableContainer
from textual.widgets import Static
class ChatDisplay(ScrollableContainer):
"""Scrollable container for chat messages."""
def __init__(self):
super().__init__(id="chat-display")
async def add_message(self, widget: Static) -> None:
"""Add a message widget to the display."""
await self.mount(widget)
self.scroll_end(animate=False)
def clear_messages(self) -> None:
"""Clear all messages from the display."""
for child in list(self.children):
child.remove()

214
oai/tui/widgets/command_dropdown.py Normal file
View File

@@ -0,0 +1,214 @@
"""Command dropdown menu for TUI input."""
from textual.app import ComposeResult
from textual.containers import VerticalScroll
from textual.widget import Widget
from textual.widgets import Label, OptionList
from textual.widgets.option_list import Option
from oai.commands import registry
class CommandDropdown(VerticalScroll):
"""Dropdown menu showing available commands."""
DEFAULT_CSS = """
CommandDropdown {
display: none;
height: auto;
max-height: 12;
width: 80;
background: #2d2d2d;
border: solid #555555;
padding: 0;
layer: overlay;
}
CommandDropdown.visible {
display: block;
}
CommandDropdown OptionList {
height: auto;
max-height: 12;
background: #2d2d2d;
border: none;
padding: 0;
}
CommandDropdown OptionList > .option-list--option {
padding: 0 2;
color: #cccccc;
background: transparent;
}
CommandDropdown OptionList > .option-list--option-highlighted {
background: #3e3e3e;
color: #ffffff;
}
"""
def __init__(self):
"""Initialize the command dropdown."""
super().__init__(id="command-dropdown")
self._all_commands = []
self._load_commands()
def _load_commands(self) -> None:
"""Load all available commands."""
# Get base commands with descriptions
base_commands = [
("/help", "Show help screen"),
("/commands", "Show all commands"),
("/model", "Select AI model"),
("/provider", "Switch AI provider"),
("/provider openrouter", "Switch to OpenRouter"),
("/provider anthropic", "Switch to Anthropic (Claude)"),
("/provider openai", "Switch to OpenAI (GPT)"),
("/provider ollama", "Switch to Ollama (local)"),
("/stats", "Show session statistics"),
("/credits", "Check account credits or view console link"),
("/clear", "Clear chat display"),
("/reset", "Reset conversation history"),
("/memory on", "Enable conversation memory"),
("/memory off", "Disable memory"),
("/online on", "Enable online search"),
("/online off", "Disable online search"),
("/save", "Save current conversation"),
("/load", "Load saved conversation"),
("/list", "List saved conversations"),
("/delete", "Delete a conversation"),
("/export md", "Export as Markdown"),
("/export json", "Export as JSON"),
("/export html", "Export as HTML"),
("/prev", "Show previous message"),
("/next", "Show next message"),
("/config", "View configuration"),
("/config provider", "Set default provider"),
("/config search_provider", "Set search provider (anthropic_native/duckduckgo/google)"),
("/config openrouter_api_key", "Set OpenRouter API key"),
("/config anthropic_api_key", "Set Anthropic API key"),
("/config openai_api_key", "Set OpenAI API key"),
("/config ollama_base_url", "Set Ollama server URL"),
("/config google_api_key", "Set Google API key"),
("/config google_search_engine_id", "Set Google Search Engine ID"),
("/config online", "Set default online mode (on/off)"),
("/config stream", "Toggle streaming (on/off)"),
("/config model", "Set default model"),
("/config system", "Set system prompt"),
("/config maxtoken", "Set token limit"),
("/system", "Set system prompt"),
("/maxtoken", "Set token limit"),
("/retry", "Retry last prompt"),
("/paste", "Paste from clipboard"),
("/mcp on", "Enable MCP file access"),
("/mcp off", "Disable MCP"),
("/mcp status", "Show MCP status"),
("/mcp add", "Add folder/database"),
("/mcp remove", "Remove folder/database"),
("/mcp list", "List folders"),
("/mcp write on", "Enable write mode"),
("/mcp write off", "Disable write mode"),
("/mcp files", "Switch to file mode"),
("/mcp db list", "List databases"),
]
self._all_commands = base_commands
def compose(self) -> ComposeResult:
"""Compose the dropdown."""
yield OptionList(id="command-list")
def show_commands(self, filter_text: str = "") -> None:
"""Show commands matching the filter.
Args:
filter_text: Text to filter commands by
"""
option_list = self.query_one("#command-list", OptionList)
option_list.clear_options()
if not filter_text.startswith("/"):
self.remove_class("visible")
return
# Remove the leading slash for filtering
filter_without_slash = filter_text[1:].lower()
# Filter commands
if filter_without_slash:
matching = []
# Check if user typed a parent command (e.g., "/config")
# If so, show all sub-commands that start with it
parent_command = "/" + filter_without_slash
has_subcommands = any(
cmd.startswith(parent_command + " ")
for cmd, _ in self._all_commands
)
if has_subcommands and parent_command in [cmd for cmd, _ in self._all_commands]:
# Show the parent command and all its sub-commands
matching = [
(cmd, desc) for cmd, desc in self._all_commands
if cmd == parent_command or cmd.startswith(parent_command + " ")
]
else:
# Regular filtering - show if filter text is contained anywhere in the command
matching = [
(cmd, desc) for cmd, desc in self._all_commands
if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching
]
else:
# Show all commands when just "/" is typed
matching = self._all_commands
if not matching:
self.remove_class("visible")
return
# Add options - limit to 15 results (increased from 10 for sub-commands)
for cmd, desc in matching[:15]:
# Format: command in white, description in gray, separated by spaces
label = f"{cmd} [dim]{desc}[/]" if desc else cmd
option_list.add_option(Option(label, id=cmd))
self.add_class("visible")
# Auto-select first option
if option_list.option_count > 0:
option_list.highlighted = 0
def hide(self) -> None:
"""Hide the dropdown."""
self.remove_class("visible")
def get_selected_command(self) -> str | None:
"""Get the currently selected command.
Returns:
Selected command text or None
"""
option_list = self.query_one("#command-list", OptionList)
if option_list.highlighted is not None:
option = option_list.get_option_at_index(option_list.highlighted)
return option.id
return None
def move_selection_up(self) -> None:
"""Move selection up in the list."""
option_list = self.query_one("#command-list", OptionList)
if option_list.option_count > 0:
if option_list.highlighted is None:
option_list.highlighted = option_list.option_count - 1
elif option_list.highlighted > 0:
option_list.highlighted -= 1
def move_selection_down(self) -> None:
"""Move selection down in the list."""
option_list = self.query_one("#command-list", OptionList)
if option_list.option_count > 0:
if option_list.highlighted is None:
option_list.highlighted = 0
elif option_list.highlighted < option_list.option_count - 1:
option_list.highlighted += 1
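The dropdown is passive: something must feed it input changes. Since app.py is suppressed in this diff, the wiring below is a sketch of the likely pattern (the import path matches the inferred file location):

```python
from textual.app import App, ComposeResult
from textual.widgets import Input

from oai.tui.widgets.command_dropdown import CommandDropdown  # path assumed

class SketchApp(App):
    def compose(self) -> ComposeResult:
        yield Input(id="chat-input")
        yield CommandDropdown()

    def on_input_changed(self, event: Input.Changed) -> None:
        dropdown = self.query_one(CommandDropdown)
        if event.value.startswith("/"):
            dropdown.show_commands(event.value)  # filter + highlight first match
        else:
            dropdown.hide()
```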

58
oai/tui/widgets/command_suggester.py Normal file
View File

@@ -0,0 +1,58 @@
"""Command suggester for TUI input."""
from typing import Iterable
from textual.suggester import Suggester
from oai.commands import registry
class CommandSuggester(Suggester):
"""Suggester that provides command completions."""
def __init__(self):
"""Initialize the command suggester."""
super().__init__(use_cache=False, case_sensitive=False)
# Get all command names from registry
self._commands = []
self._update_commands()
def _update_commands(self) -> None:
"""Update the list of available commands."""
# Get all registered command names
command_names = registry.get_all_names()
# Add common MCP subcommands for better UX
mcp_subcommands = [
"/mcp on",
"/mcp off",
"/mcp status",
"/mcp add",
"/mcp remove",
"/mcp list",
"/mcp write on",
"/mcp write off",
"/mcp files",
"/mcp db list",
]
self._commands = command_names + mcp_subcommands
async def get_suggestion(self, value: str) -> str | None:
"""Get a command suggestion based on the current input.
Args:
value: Current input value
Returns:
Suggested completion or None
"""
if not value or not value.startswith("/"):
return None
# Find the first command that starts with the input
value_lower = value.lower()
for cmd in self._commands:
if cmd.lower().startswith(value_lower) and cmd.lower() != value_lower:
# Return the full command; Textual's Input renders the untyped remainder as ghost text
return cmd
return None
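Textual's `Input` accepts a `suggester` argument, so attaching this class is a one-liner; as the user types a `/command`, the untyped remainder appears as ghost text:

```python
from textual.widgets import Input

# Sketch: wire the suggester into the chat input.
chat_input = Input(
    placeholder="Type a message or /command...",
    id="chat-input",
    suggester=CommandSuggester(),
)
```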

39
oai/tui/widgets/footer.py Normal file
View File

@@ -0,0 +1,39 @@
"""Footer widget for oAI TUI."""
from textual.app import ComposeResult
from textual.widgets import Static
class Footer(Static):
"""Footer displaying session metrics."""
def __init__(self):
super().__init__()
self.tokens_in = 0
self.tokens_out = 0
self.cost = 0.0
self.messages = 0
def compose(self) -> ComposeResult:
"""Compose the footer."""
yield Static(self._format_footer(), id="footer-content")
def _format_footer(self) -> str:
"""Format the footer text."""
return (
f"[dim]Messages: {self.messages} | "
f"Tokens: {self.tokens_in + self.tokens_out:,} "
f"({self.tokens_in:,} in, {self.tokens_out:,} out) | "
f"Cost: ${self.cost:.4f}[/]"
)
def update_stats(
self, tokens_in: int, tokens_out: int, cost: float, messages: int
) -> None:
"""Update the displayed statistics."""
self.tokens_in = tokens_in
self.tokens_out = tokens_out
self.cost = cost
self.messages = messages
content = self.query_one("#footer-content", Static)
content.update(self._format_footer())

83
oai/tui/widgets/header.py Normal file
View File

@@ -0,0 +1,83 @@
"""Header widget for oAI TUI."""
from textual.app import ComposeResult
from textual.widgets import Static
from typing import Optional, Dict, Any
class Header(Static):
"""Header displaying app title, version, current model, and capabilities."""
def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None, provider: str = ""):
super().__init__()
self.version = version
self.model = model
self.model_info = model_info or {}
self.provider = provider
def compose(self) -> ComposeResult:
"""Compose the header."""
yield Static(self._format_header(), id="header-content")
def _format_capabilities(self) -> str:
"""Format capability icons based on model info."""
if not self.model_info:
return ""
icons = []
# Check vision support
architecture = self.model_info.get("architecture", {})
modality = architecture.get("modality", "")
if "image" in modality:
icons.append("[bold cyan]👁️[/]") # Bright if supported
else:
icons.append("[dim]👁️[/]") # Dim if not supported
# Check tool support
supported_params = self.model_info.get("supported_parameters", [])
if "tools" in supported_params or "tool_choice" in supported_params:
icons.append("[bold cyan]🔧[/]")
else:
icons.append("[dim]🔧[/]")
# Check online support (most models support :online suffix)
model_id = self.model_info.get("id", "")
if ":online" in model_id or model_id not in ["openrouter/free"]:
icons.append("[bold cyan]🌐[/]")
else:
icons.append("[dim]🌐[/]")
return " ".join(icons) if icons else ""
def _format_header(self) -> str:
"""Format the header text."""
# Show provider : model format
if self.provider and self.model:
provider_model = f"[bold cyan]{self.provider}[/] [dim]:[/] [bold]{self.model}[/]"
elif self.provider:
provider_model = f"[bold cyan]{self.provider}[/]"
elif self.model:
provider_model = f"[bold]{self.model}[/]"
else:
provider_model = ""
capabilities = self._format_capabilities()
capabilities_text = f" {capabilities}" if capabilities else ""
# Format: oAI v{version} | provider : model capabilities
version_text = f"[bold cyan]oAI[/] [dim]v{self.version}[/]"
if provider_model:
return f"{version_text} [dim]|[/] {provider_model}{capabilities_text}"
else:
return version_text
def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None, provider: Optional[str] = None) -> None:
"""Update the displayed model and capabilities."""
self.model = model
if model_info:
self.model_info = model_info
if provider is not None:
self.provider = provider
content = self.query_one("#header-content", Static)
content.update(self._format_header())

49
oai/tui/widgets/input_bar.py Normal file
View File

@@ -0,0 +1,49 @@
"""Input bar widget for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Horizontal
from textual.widgets import Input, Static
class InputBar(Horizontal):
"""Input bar with prompt prefix and text input."""
def __init__(self):
super().__init__(id="input-bar")
self.mcp_status = ""
self.online_mode = False
def compose(self) -> ComposeResult:
"""Compose the input bar."""
yield Static(self._format_prefix(), id="input-prefix", classes="prefix-hidden" if not (self.mcp_status or self.online_mode) else "")
yield Input(
placeholder="Type a message or /command...",
id="chat-input"
)
def _format_prefix(self) -> str:
"""Format the input prefix with status indicators."""
indicators = []
if self.mcp_status:
indicators.append(f"[cyan]{self.mcp_status}[/]")
if self.online_mode:
indicators.append("[green]🌐[/]")
prefix = " ".join(indicators) + " " if indicators else ""
return f"{prefix}[bold]>[/]"
def update_mcp_status(self, status: str) -> None:
"""Update MCP status indicator."""
self.mcp_status = status
prefix = self.query_one("#input-prefix", Static)
prefix.update(self._format_prefix())
def update_online_mode(self, online: bool) -> None:
"""Update online mode indicator."""
self.online_mode = online
prefix = self.query_one("#input-prefix", Static)
prefix.update(self._format_prefix())
def get_input(self) -> Input:
"""Get the input widget."""
return self.query_one("#chat-input", Input)

98
oai/tui/widgets/message.py Normal file
View File

@@ -0,0 +1,98 @@
"""Message widgets for oAI TUI."""
from typing import Any, AsyncIterator, Tuple
from rich.console import Console
from rich.markdown import Markdown
from rich.style import Style
from rich.theme import Theme
from textual.app import ComposeResult
from textual.widgets import RichLog, Static
# Custom theme for Markdown rendering - neutral colors matching the dark theme
MARKDOWN_THEME = Theme({
"markdown.text": Style(color="#cccccc"),
"markdown.paragraph": Style(color="#cccccc"),
"markdown.code": Style(color="#e0e0e0", bgcolor="#2a2a2a"),
"markdown.code_block": Style(color="#e0e0e0", bgcolor="#2a2a2a"),
"markdown.heading": Style(color="#ffffff", bold=True),
"markdown.h1": Style(color="#ffffff", bold=True),
"markdown.h2": Style(color="#eeeeee", bold=True),
"markdown.h3": Style(color="#dddddd", bold=True),
"markdown.link": Style(color="#aaaaaa", underline=False),
"markdown.link_url": Style(color="#888888"),
"markdown.emphasis": Style(color="#cccccc", italic=True),
"markdown.strong": Style(color="#ffffff", bold=True),
})
class UserMessageWidget(Static):
"""Widget for displaying user messages."""
def __init__(self, content: str):
super().__init__()
self.content = content
def compose(self) -> ComposeResult:
"""Compose the user message."""
yield Static(f"[bold green]You:[/] {self.content}")
class SystemMessageWidget(Static):
"""Widget for displaying system/info messages without 'You:' prefix."""
def __init__(self, content: str):
super().__init__()
self.content = content
def compose(self) -> ComposeResult:
"""Compose the system message."""
yield Static(self.content)
class AssistantMessageWidget(Static):
"""Widget for displaying assistant responses with streaming support."""
def __init__(self, model_name: str = "Assistant", chat_display=None):
super().__init__()
self.model_name = model_name
self.full_text = ""
self.chat_display = chat_display
def compose(self) -> ComposeResult:
"""Compose the assistant message."""
yield Static(f"[bold]{self.model_name}:[/]", id="assistant-label")
yield RichLog(id="assistant-content", highlight=True, markup=True, wrap=True)
async def stream_response(self, response_iterator: AsyncIterator) -> Tuple[str, Any]:
"""Stream tokens progressively and return final text and usage."""
log = self.query_one("#assistant-content", RichLog)
self.full_text = ""
usage = None
async for chunk in response_iterator:
if hasattr(chunk, "delta_content") and chunk.delta_content:
self.full_text += chunk.delta_content
log.clear()
# Use neutral code theme for syntax highlighting
md = Markdown(self.full_text, code_theme="github-dark", inline_code_theme="github-dark")
log.write(md)
# Auto-scroll to keep the latest content visible
# Use call_after_refresh to ensure scroll happens after layout update
if self.chat_display:
self.chat_display.call_after_refresh(self.chat_display.scroll_end, animate=False)
if hasattr(chunk, "usage") and chunk.usage:
usage = chunk.usage
return self.full_text, usage
def set_content(self, content: str) -> None:
"""Set the complete content (non-streaming)."""
self.full_text = content
log = self.query_one("#assistant-content", RichLog)
log.clear()
# Use neutral code theme for syntax highlighting
md = Markdown(content, code_theme="github-dark", inline_code_theme="github-dark")
log.write(md)
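`stream_response` only depends on chunks exposing `delta_content` and `usage`, so it can be exercised with a fake iterator; a sketch (the `Chunk` class below is illustrative, not a type from this codebase):

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Chunk:  # stand-in for a provider streaming chunk
    delta_content: Optional[str] = None
    usage: Optional[Any] = None

async def fake_stream():
    for piece in ["Hello ", "from ", "**oAI**"]:
        await asyncio.sleep(0.05)  # simulate network latency
        yield Chunk(delta_content=piece)
    yield Chunk(usage={"input_tokens": 3, "output_tokens": 3})

# Inside a mounted app:
#   widget = AssistantMessageWidget("demo", chat_display=display)
#   await display.add_message(widget)
#   text, usage = await widget.stream_response(fake_stream())
```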

51
oai/ui/__init__.py
View File

@@ -1,51 +0,0 @@
"""
UI utilities for oAI.
This module provides rich terminal UI components and display helpers
for the chat application.
"""
from oai.ui.console import (
console,
clear_screen,
display_panel,
display_table,
display_markdown,
print_error,
print_warning,
print_success,
print_info,
)
from oai.ui.tables import (
create_model_table,
create_stats_table,
create_help_table,
display_paginated_table,
)
from oai.ui.prompts import (
prompt_confirm,
prompt_choice,
prompt_input,
)
__all__ = [
# Console utilities
"console",
"clear_screen",
"display_panel",
"display_table",
"display_markdown",
"print_error",
"print_warning",
"print_success",
"print_info",
# Table utilities
"create_model_table",
"create_stats_table",
"create_help_table",
"display_paginated_table",
# Prompt utilities
"prompt_confirm",
"prompt_choice",
"prompt_input",
]
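The re-export surface above meant callers imported everything from the package root. A typical call site looked like this (the message strings are illustrative):

```python
from oai.ui import console, display_panel, print_error

display_panel("Connected to OpenRouter", title="Status")
print_error("model not found", prefix="Lookup failed:")
console.print("[dim]done[/]")
```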

242
oai/ui/console.py (deleted)
View File

@@ -1,242 +0,0 @@
"""
Console utilities for oAI.
This module provides the Rich console instance and common display functions
for formatted terminal output.
"""
from typing import Any, Optional
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
# Global console instance for the application
console = Console()
def clear_screen() -> None:
"""
Clear the terminal screen.
Uses ANSI escape codes for fast clearing, with a fallback
for terminals that don't support them.
"""
try:
print("\033[H\033[J", end="", flush=True)
except Exception:
# Fallback: print many newlines
print("\n" * 100)
def display_panel(
content: Any,
title: Optional[str] = None,
subtitle: Optional[str] = None,
border_style: str = "green",
title_align: str = "left",
subtitle_align: str = "right",
) -> None:
"""
Display content in a bordered panel.
Args:
content: Content to display (string, Table, or Markdown)
title: Optional panel title
subtitle: Optional panel subtitle
border_style: Border color/style
title_align: Title alignment ("left", "center", "right")
subtitle_align: Subtitle alignment
"""
panel = Panel(
content,
title=title,
subtitle=subtitle,
border_style=border_style,
title_align=title_align,
subtitle_align=subtitle_align,
)
console.print(panel)
def display_table(
table: Table,
title: Optional[str] = None,
subtitle: Optional[str] = None,
) -> None:
"""
Display a table with optional title panel.
Args:
table: Rich Table to display
title: Optional panel title
subtitle: Optional panel subtitle
"""
if title:
display_panel(table, title=title, subtitle=subtitle)
else:
console.print(table)
def display_markdown(
content: str,
panel: bool = False,
title: Optional[str] = None,
) -> None:
"""
Display markdown-formatted content.
Args:
content: Markdown text to display
panel: Whether to wrap in a panel
title: Optional panel title (if panel=True)
"""
md = Markdown(content)
if panel:
display_panel(md, title=title)
else:
console.print(md)
def print_error(message: str, prefix: str = "Error:") -> None:
"""
Print an error message in red.
Args:
message: Error message to display
prefix: Prefix before the message (default: "Error:")
"""
console.print(f"[bold red]{prefix}[/] {message}")
def print_warning(message: str, prefix: str = "Warning:") -> None:
"""
Print a warning message in yellow.
Args:
message: Warning message to display
prefix: Prefix before the message (default: "Warning:")
"""
console.print(f"[bold yellow]{prefix}[/] {message}")
def print_success(message: str, prefix: str = "") -> None:
"""
Print a success message in green.
Args:
message: Success message to display
prefix: Prefix before the message (default: "")
"""
console.print(f"[bold green]{prefix}[/] {message}")
def print_info(message: str, dim: bool = False) -> None:
"""
Print an informational message in cyan.
Args:
message: Info message to display
dim: Whether to dim the message
"""
if dim:
console.print(f"[dim cyan]{message}[/]")
else:
console.print(f"[bold cyan]{message}[/]")
def print_metrics(
tokens: int,
cost: float,
time_seconds: float,
context_info: str = "",
online: bool = False,
mcp_mode: Optional[str] = None,
tool_loops: int = 0,
session_tokens: int = 0,
session_cost: float = 0.0,
) -> None:
"""
Print formatted metrics for a response.
Args:
tokens: Total tokens used
cost: Cost in USD
time_seconds: Response time
context_info: Context information string
online: Whether online mode is active
mcp_mode: MCP mode ("files", "database", or None)
tool_loops: Number of tool call loops
session_tokens: Total session tokens
session_cost: Total session cost
"""
parts = [
f"📊 Metrics: {tokens} tokens",
f"${cost:.4f}",
f"{time_seconds:.2f}s",
]
if context_info:
parts.append(context_info)
if online:
parts.append("🌐")
if mcp_mode == "files":
parts.append("🔧")
elif mcp_mode == "database":
parts.append("🗄️")
if tool_loops > 0:
parts.append(f"({tool_loops} tool loop(s))")
parts.append(f"Session: {session_tokens} tokens")
parts.append(f"${session_cost:.4f}")
console.print(f"\n[dim blue]{' | '.join(parts)}[/]")
def format_size(size_bytes: int) -> str:
"""
Format a size in bytes to a human-readable string.
Args:
size_bytes: Size in bytes
Returns:
Formatted size string (e.g., "1.5 MB")
"""
for unit in ["B", "KB", "MB", "GB", "TB"]:
if abs(size_bytes) < 1024.0:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.1f} PB"
def format_tokens(tokens: int) -> str:
"""
Format token count with thousands separators.
Args:
tokens: Number of tokens
Returns:
Formatted token string (e.g., "1,234,567")
"""
return f"{tokens:,}"
def format_cost(cost: float, precision: int = 4) -> str:
"""
Format cost in USD.
Args:
cost: Cost in dollars
precision: Decimal places
Returns:
Formatted cost string (e.g., "$0.0123")
"""
return f"${cost:.{precision}f}"

274
oai/ui/prompts.py (deleted)
View File

@@ -1,274 +0,0 @@
"""
Prompt utilities for oAI.
This module provides functions for gathering user input
through confirmations, choices, and text prompts.
"""
from typing import Callable, List, Optional, TypeVar
import typer
from oai.ui.console import console
T = TypeVar("T")
def prompt_confirm(
message: str,
default: bool = False,
abort: bool = False,
) -> bool:
"""
Prompt the user for a yes/no confirmation.
Args:
message: The question to ask
default: Default value if user presses Enter
abort: Whether to abort on "no" response
Returns:
True if user confirms, False otherwise
"""
try:
return typer.confirm(message, default=default, abort=abort)
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return False
def prompt_choice(
message: str,
choices: List[str],
default: Optional[str] = None,
) -> Optional[str]:
"""
Prompt the user to select from a list of choices.
Args:
message: The question to ask
choices: List of valid choices
default: Default choice if user presses Enter
Returns:
Selected choice or None if cancelled
"""
# Display choices
console.print(f"\n[bold cyan]{message}[/]")
for i, choice in enumerate(choices, 1):
default_marker = " [default]" if choice == default else ""
console.print(f" {i}. {choice}{default_marker}")
try:
response = input("\nEnter number or value: ").strip()
if not response and default:
return default
# Try as number first
try:
index = int(response) - 1
if 0 <= index < len(choices):
return choices[index]
except ValueError:
pass
# Try as exact match
if response in choices:
return response
# Try case-insensitive match
response_lower = response.lower()
for choice in choices:
if choice.lower() == response_lower:
return choice
console.print(f"[red]Invalid choice: {response}[/]")
return None
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return None
def prompt_input(
message: str,
default: Optional[str] = None,
password: bool = False,
required: bool = False,
) -> Optional[str]:
"""
Prompt the user for text input.
Args:
message: The prompt message
default: Default value if user presses Enter
password: Whether to hide input (for sensitive data)
required: Whether input is required (loops until provided)
Returns:
User input or default, None if cancelled
"""
prompt_text = message
if default:
prompt_text += f" [{default}]"
prompt_text += ": "
try:
while True:
if password:
import getpass
response = getpass.getpass(prompt_text)
else:
response = input(prompt_text).strip()
if not response:
if default:
return default
if required:
console.print("[yellow]Input required[/]")
continue
return None
return response
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return None
def prompt_number(
message: str,
min_value: Optional[int] = None,
max_value: Optional[int] = None,
default: Optional[int] = None,
) -> Optional[int]:
"""
Prompt the user for a numeric input.
Args:
message: The prompt message
min_value: Minimum allowed value
max_value: Maximum allowed value
default: Default value if user presses Enter
Returns:
Integer value or None if cancelled
"""
prompt_text = message
if default is not None:
prompt_text += f" [{default}]"
prompt_text += ": "
try:
while True:
response = input(prompt_text).strip()
if not response:
if default is not None:
return default
return None
try:
value = int(response)
except ValueError:
console.print("[red]Please enter a valid number[/]")
continue
if min_value is not None and value < min_value:
console.print(f"[red]Value must be at least {min_value}[/]")
continue
if max_value is not None and value > max_value:
console.print(f"[red]Value must be at most {max_value}[/]")
continue
return value
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return None
def prompt_selection(
items: List[T],
message: str = "Select an item",
display_func: Optional[Callable[[T], str]] = None,
allow_cancel: bool = True,
) -> Optional[T]:
"""
Prompt the user to select an item from a list.
Args:
items: List of items to choose from
message: The selection prompt
display_func: Function to convert item to display string
allow_cancel: Whether to allow cancellation
Returns:
Selected item or None if cancelled
"""
if not items:
console.print("[yellow]No items to select[/]")
return None
display = display_func or str
console.print(f"\n[bold cyan]{message}[/]")
for i, item in enumerate(items, 1):
console.print(f" {i}. {display(item)}")
if allow_cancel:
console.print(f" 0. Cancel")
try:
while True:
response = input("\nEnter number: ").strip()
try:
index = int(response)
except ValueError:
console.print("[red]Please enter a valid number[/]")
continue
if allow_cancel and index == 0:
return None
if 1 <= index <= len(items):
return items[index - 1]
console.print(f"[red]Please enter a number between 1 and {len(items)}[/]")
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return None
def prompt_copy_response(response: str) -> bool:
"""
Prompt user to copy a response to clipboard.
Args:
response: The response text
Returns:
True if copied, False otherwise
"""
try:
copy_choice = input("💾 Type 'c' to copy response, or press Enter to continue: ").strip().lower()
if copy_choice == "c":
try:
import pyperclip
pyperclip.copy(response)
console.print("[bold green]✅ Response copied to clipboard![/]")
return True
except ImportError:
console.print("[yellow]pyperclip not installed - cannot copy to clipboard[/]")
except Exception as e:
console.print(f"[red]Failed to copy: {e}[/]")
except (EOFError, KeyboardInterrupt):
pass
return False
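A short sketch of how these prompt helpers composed in practice (the choices and bounds are illustrative, not taken from the codebase):

```python
fmt = prompt_choice("Export format?", ["markdown", "json", "html"], default="markdown")
count = prompt_number("How many results?", min_value=1, max_value=50, default=10)
if fmt and count and prompt_confirm(f"Export {count} items as {fmt}?", default=True):
    ...  # perform the export
```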

373
oai/ui/tables.py (deleted)
View File

@@ -1,373 +0,0 @@
"""
Table utilities for oAI.
This module provides functions for creating and displaying
formatted tables with pagination support.
"""
import os
import sys
from typing import Any, Dict, List, Optional
from rich.panel import Panel
from rich.table import Table
from oai.ui.console import clear_screen, console
def create_model_table(
models: List[Dict[str, Any]],
show_capabilities: bool = True,
) -> Table:
"""
Create a table displaying available AI models.
Args:
models: List of model dictionaries
show_capabilities: Whether to show capability columns
Returns:
Rich Table with model information
"""
if show_capabilities:
table = Table(
"No.",
"Model ID",
"Context",
"Image",
"Online",
"Tools",
show_header=True,
header_style="bold magenta",
)
else:
table = Table(
"No.",
"Model ID",
"Context",
show_header=True,
header_style="bold magenta",
)
for i, model in enumerate(models, 1):
model_id = model.get("id", "Unknown")
context = model.get("context_length", 0)
context_str = f"{context:,}" if context else "-"
if show_capabilities:
# Get modalities and parameters
architecture = model.get("architecture", {})
input_modalities = architecture.get("input_modalities", [])
supported_params = model.get("supported_parameters", [])
has_image = "" if "image" in input_modalities else "-"
has_online = "" if "tools" in supported_params else "-"
has_tools = "" if "tools" in supported_params or "functions" in supported_params else "-"
table.add_row(
str(i),
model_id,
context_str,
has_image,
has_online,
has_tools,
)
else:
table.add_row(str(i), model_id, context_str)
return table
def create_stats_table(stats: Dict[str, Any]) -> Table:
"""
Create a table displaying session statistics.
Args:
stats: Dictionary with statistics data
Returns:
Rich Table with stats
"""
table = Table(
"Metric",
"Value",
show_header=True,
header_style="bold magenta",
)
# Token stats
if "input_tokens" in stats:
table.add_row("Input Tokens", f"{stats['input_tokens']:,}")
if "output_tokens" in stats:
table.add_row("Output Tokens", f"{stats['output_tokens']:,}")
if "total_tokens" in stats:
table.add_row("Total Tokens", f"{stats['total_tokens']:,}")
# Cost stats
if "total_cost" in stats:
table.add_row("Total Cost", f"${stats['total_cost']:.4f}")
if "avg_cost" in stats:
table.add_row("Avg Cost/Message", f"${stats['avg_cost']:.4f}")
# Message stats
if "message_count" in stats:
table.add_row("Messages", str(stats["message_count"]))
# Credits
if "credits_left" in stats:
table.add_row("Credits Left", stats["credits_left"])
return table
def create_help_table(commands: Dict[str, Dict[str, str]]) -> Table:
"""
Create a help table for commands.
Args:
commands: Dictionary of command info
Returns:
Rich Table with command help
"""
table = Table(
"Command",
"Description",
"Example",
show_header=True,
header_style="bold magenta",
show_lines=False,
)
for cmd, info in commands.items():
description = info.get("description", "")
example = info.get("example", "")
table.add_row(cmd, description, example)
return table
def create_folder_table(
folders: List[Dict[str, Any]],
gitignore_info: str = "",
) -> Table:
"""
Create a table for MCP folder listing.
Args:
folders: List of folder dictionaries
gitignore_info: Optional gitignore status info
Returns:
Rich Table with folder information
"""
table = Table(
"No.",
"Path",
"Files",
"Size",
show_header=True,
header_style="bold magenta",
)
for folder in folders:
number = str(folder.get("number", ""))
path = folder.get("path", "")
if folder.get("exists", True):
files = f"📁 {folder.get('file_count', 0)}"
size = f"{folder.get('size_mb', 0):.1f} MB"
else:
files = "[red]Not found[/red]"
size = "-"
table.add_row(number, path, files, size)
return table
def create_database_table(databases: List[Dict[str, Any]]) -> Table:
"""
Create a table for MCP database listing.
Args:
databases: List of database dictionaries
Returns:
Rich Table with database information
"""
table = Table(
"No.",
"Name",
"Tables",
"Size",
"Status",
show_header=True,
header_style="bold magenta",
)
for db in databases:
number = str(db.get("number", ""))
name = db.get("name", "")
table_count = f"{db.get('table_count', 0)} tables"
size = f"{db.get('size_mb', 0):.1f} MB"
if db.get("warning"):
status = f"[red]{db['warning']}[/red]"
else:
status = "[green]✓[/green]"
table.add_row(number, name, table_count, size, status)
return table
def display_paginated_table(
table: Table,
title: str,
terminal_height: Optional[int] = None,
) -> None:
"""
Display a table with pagination for large datasets.
Allows navigating through pages with keyboard input.
Press SPACE for next page, any other key to exit.
Args:
table: Rich Table to display
title: Title for the table
terminal_height: Override terminal height (auto-detected if None)
"""
# Get terminal dimensions
try:
term_height = terminal_height or os.get_terminal_size().lines - 8
except OSError:
term_height = 20
# Render table to segments
from rich.segment import Segment
segments = list(console.render(table))
# Group segments into lines
current_line_segments: List[Segment] = []
all_lines: List[List[Segment]] = []
for segment in segments:
if segment.text == "\n":
all_lines.append(current_line_segments)
current_line_segments = []
else:
current_line_segments.append(segment)
if current_line_segments:
all_lines.append(current_line_segments)
total_lines = len(all_lines)
# If table fits in one screen, just display it
if total_lines <= term_height:
console.print(Panel(table, title=title, title_align="left"))
return
# Extract header and footer lines
header_lines: List[List[Segment]] = []
data_lines: List[List[Segment]] = []
footer_line: List[Segment] = []
# Find header end (line after the header text with border)
header_end_index = 0
found_header_text = False
for i, line_segments in enumerate(all_lines):
has_header_style = any(
seg.style and ("bold" in str(seg.style) or "magenta" in str(seg.style))
for seg in line_segments
)
if has_header_style:
found_header_text = True
if found_header_text and i > 0:
line_text = "".join(seg.text for seg in line_segments)
if any(char in line_text for char in ["─", "━", "│", "┼", "├", "┤"]):
header_end_index = i
break
# Extract footer (bottom border)
if all_lines:
last_line_text = "".join(seg.text for seg in all_lines[-1])
if any(char in last_line_text for char in ["└", "┘", "─", "┴", "╰", "╯"]):
footer_line = all_lines[-1]
all_lines = all_lines[:-1]
# Split into header and data
if header_end_index > 0:
header_lines = all_lines[: header_end_index + 1]
data_lines = all_lines[header_end_index + 1 :]
else:
header_lines = all_lines[: min(3, len(all_lines))]
data_lines = all_lines[min(3, len(all_lines)) :]
lines_per_page = term_height - len(header_lines)
current_line = 0
page_number = 1
# Paginate
while current_line < len(data_lines):
clear_screen()
console.print(f"[bold cyan]{title} (Page {page_number})[/]")
# Print header
for line_segments in header_lines:
for segment in line_segments:
console.print(segment.text, style=segment.style, end="")
console.print()
# Print data rows for this page
end_line = min(current_line + lines_per_page, len(data_lines))
for line_segments in data_lines[current_line:end_line]:
for segment in line_segments:
console.print(segment.text, style=segment.style, end="")
console.print()
# Print footer
if footer_line:
for segment in footer_line:
console.print(segment.text, style=segment.style, end="")
console.print()
current_line = end_line
page_number += 1
# Prompt for next page
if current_line < len(data_lines):
console.print(
f"\n[dim yellow]--- Press SPACE for next page, "
f"or any other key to finish (Page {page_number - 1}, "
f"showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]"
)
try:
import termios
import tty
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(fd)
char = sys.stdin.read(1)
if char != " ":
break
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
except (ImportError, OSError, AttributeError):
# Fallback for non-Unix systems
try:
user_input = input()
if user_input.strip():
break
except (EOFError, KeyboardInterrupt):
break
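Putting the table helpers together, a sketch of the model-list flow. The sample dict mirrors the keys `create_model_table` reads; the model ID itself is made up:

```python
models = [
    {
        "id": "anthropic/claude-sonnet",  # illustrative ID
        "context_length": 200_000,
        "architecture": {"input_modalities": ["text", "image"]},
        "supported_parameters": ["tools"],
    },
]
table = create_model_table(models, show_capabilities=True)
display_paginated_table(table, title="Available Models")
```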

247
oai/utils/web_search.py Normal file
View File

@@ -0,0 +1,247 @@
"""
Web search utilities for oAI.
Provides web search capabilities for all providers (not just OpenRouter).
Uses DuckDuckGo by default (no API key needed).
"""
import html as html_module
import re
from typing import List
from urllib.parse import quote_plus, unquote
import requests
from oai.utils.logging import get_logger
logger = get_logger()
class WebSearchResult:
"""Container for a single search result."""
def __init__(self, title: str, url: str, snippet: str):
self.title = title
self.url = url
self.snippet = snippet
def __repr__(self) -> str:
return f"WebSearchResult(title='{self.title}', url='{self.url}')"
class WebSearchProvider:
"""Base class for web search providers."""
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
"""
Perform a web search.
Args:
query: Search query
num_results: Number of results to return
Returns:
List of search results
"""
raise NotImplementedError
class DuckDuckGoSearch(WebSearchProvider):
"""DuckDuckGo search provider (no API key needed)."""
def __init__(self):
self.session = requests.Session()
self.session.headers.update({
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36'
})
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
"""
Search using DuckDuckGo HTML interface.
Args:
query: Search query
num_results: Number of results to return (default: 5)
Returns:
List of search results
"""
try:
# Use DuckDuckGo HTML search
url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"
response = self.session.get(url, timeout=10)
response.raise_for_status()
results = []
html = response.text
# Parse results using regex (simple HTML parsing)
# Find all result blocks - they end at next result or end of results section
result_blocks = re.findall(
r'<div class="result results_links.*?(?=<div class="result results_links|<div id="links")',
html,
re.DOTALL
)
for block in result_blocks[:num_results]:
# Extract title and URL - look for result__a class
title_match = re.search(r'<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>([^<]+)</a>', block)
# Extract snippet - look for result__snippet class
snippet_match = re.search(r'<a[^>]*class="result__snippet"[^>]*>([^<]+)</a>', block)
if title_match:
url_raw = title_match.group(1)
title = title_match.group(2).strip()
# Decode HTML entities in the title
title = html_module.unescape(title)
snippet = ""
if snippet_match:
snippet = snippet_match.group(1).strip()
snippet = html_module.unescape(snippet)
# Clean up URL (DDG uses redirect links)
if 'uddg=' in url_raw:
# Extract actual URL from redirect
actual_url_match = re.search(r'uddg=([^&]+)', url_raw)
if actual_url_match:
url_raw = unquote(actual_url_match.group(1))
results.append(WebSearchResult(
title=title,
url=url_raw,
snippet=snippet
))
logger.info(f"DuckDuckGo search: found {len(results)} results for '{query}'")
return results
except requests.RequestException as e:
logger.error(f"DuckDuckGo search failed: {e}")
return []
except Exception as e:
logger.error(f"Error parsing DuckDuckGo results: {e}")
return []
class GoogleCustomSearch(WebSearchProvider):
"""Google Custom Search API provider (requires API key)."""
def __init__(self, api_key: str, search_engine_id: str):
"""
Initialize Google Custom Search.
Args:
api_key: Google API key
search_engine_id: Custom Search Engine ID
"""
self.api_key = api_key
self.search_engine_id = search_engine_id
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
"""
Search using Google Custom Search API.
Args:
query: Search query
num_results: Number of results to return
Returns:
List of search results
"""
try:
url = "https://www.googleapis.com/customsearch/v1"
params = {
'key': self.api_key,
'cx': self.search_engine_id,
'q': query,
'num': min(num_results, 10) # Google allows max 10
}
response = requests.get(url, params=params, timeout=10)
response.raise_for_status()
data = response.json()
results = []
for item in data.get('items', []):
results.append(WebSearchResult(
title=item.get('title', ''),
url=item.get('link', ''),
snippet=item.get('snippet', '')
))
logger.info(f"Google Custom Search: found {len(results)} results for '{query}'")
return results
except requests.RequestException as e:
logger.error(f"Google Custom Search failed: {e}")
return []
def perform_web_search(
query: str,
num_results: int = 5,
provider: str = "duckduckgo",
**kwargs
) -> List[WebSearchResult]:
"""
Perform a web search using the specified provider.
Args:
query: Search query
num_results: Number of results to return (default: 5)
provider: Search provider ("duckduckgo" or "google")
**kwargs: Provider-specific arguments (e.g., api_key for Google)
Returns:
List of search results
"""
if provider == "google":
api_key = kwargs.get("google_api_key")
search_engine_id = kwargs.get("google_search_engine_id")
if not api_key or not search_engine_id:
logger.warning("Google search requires api_key and search_engine_id, falling back to DuckDuckGo")
provider = "duckduckgo"
if provider == "google":
search_provider = GoogleCustomSearch(api_key, search_engine_id)
else:
search_provider = DuckDuckGoSearch()
return search_provider.search(query, num_results)
def format_search_results(results: List[WebSearchResult], max_length: int = 2000) -> str:
"""
Format search results for inclusion in AI prompt.
Args:
results: List of search results
max_length: Maximum total length of formatted results
Returns:
Formatted string with search results
"""
if not results:
return "No search results found."
formatted = "**Web Search Results:**\n\n"
for i, result in enumerate(results, 1):
result_text = f"{i}. **{result.title}**\n"
result_text += f" URL: {result.url}\n"
if result.snippet:
result_text += f" {result.snippet}\n"
result_text += "\n"
# Check if adding this result would exceed max_length
if len(formatted) + len(result_text) > max_length:
formatted += f"... ({len(results) - i + 1} more results truncated)\n"
break
formatted += result_text
return formatted.strip()
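End to end, the module is used roughly like this (the query and downstream prompt are illustrative):

```python
results = perform_web_search("Textual TUI framework", num_results=3)
context = format_search_results(results, max_length=1500)

# Prepend the formatted results to the user's question before sending it
# to whichever provider is active:
prompt = f"{context}\n\nUsing the sources above, summarize what Textual is."
```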

View File

pyproject.toml
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "oai"
-version = "2.1.0"
-description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
+version = "3.0.0-b3"  # MUST match oai/__init__.py __version__
+description = "Open AI Chat Client - Multi-provider terminal chat with MCP support"
 readme = "README.md"
 license = {text = "MIT"}
 authors = [
@@ -39,15 +39,17 @@ classifiers = [
 requires-python = ">=3.10"
 dependencies = [
     "anyio>=4.0.0",
+    "anthropic>=0.40.0",
     "click>=8.0.0",
     "httpx>=0.24.0",
     "markdown-it-py>=3.0.0",
+    "openai>=1.59.0",
     "openrouter>=0.0.19",
     "packaging>=21.0",
-    "prompt-toolkit>=3.0.0",
     "pyperclip>=1.8.0",
     "requests>=2.28.0",
     "rich>=13.0.0",
+    "textual>=0.50.0",
     "typer>=0.9.0",
     "mcp>=1.0.0",
 ]
@@ -73,7 +75,7 @@ Documentation = "https://iurl.no/oai"
 oai = "oai.cli:main"
 [tool.setuptools]
-packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.ui", "oai.utils"]
+packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.tui", "oai.tui.widgets", "oai.tui.screens", "oai.utils"]
 [tool.setuptools.package-data]
 oai = ["py.typed"]