Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 06a3c898d3 | |||
| ecc2489eef | |||
| 1191fa6d19 | |||
| 6298158d3c |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -45,3 +45,4 @@ b0.sh
|
|||||||
requirements.txt
|
requirements.txt
|
||||||
system_prompt.txt
|
system_prompt.txt
|
||||||
CLAUDE*
|
CLAUDE*
|
||||||
|
SESSION*_COMPLETE.md
|
||||||
|
|||||||
228
README.md
228
README.md
@@ -1,19 +1,25 @@
|
|||||||
# oAI - OpenRouter AI Chat Client
|
# oAI - Open AI Chat Client
|
||||||
|
|
||||||
A powerful, extensible terminal-based chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
|
A powerful, modern **Textual TUI** chat client with **multi-provider support** (OpenRouter, Anthropic, OpenAI, Ollama) and **MCP (Model Context Protocol)** integration, enabling AI to access local files and query SQLite databases.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
### Core Features
|
### Core Features
|
||||||
- 🤖 Interactive chat with 300+ AI models via OpenRouter
|
- 🖥️ **Modern Textual TUI** with async streaming and beautiful interface
|
||||||
- 🔍 Model selection with search and filtering
|
- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local)
|
||||||
|
- 🤖 Interactive chat with 300+ AI models across providers
|
||||||
|
- 🔍 Model selection with search, filtering, and capability icons
|
||||||
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
|
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
|
||||||
- 📎 File attachments (images, PDFs, code files)
|
- 📎 File attachments (images, PDFs, code files)
|
||||||
- 💰 Real-time cost tracking and credit monitoring
|
- 💰 Real-time cost tracking and credit monitoring (OpenRouter)
|
||||||
- 🎨 Rich terminal UI with syntax highlighting
|
- 🎨 Dark theme with syntax highlighting and Markdown rendering
|
||||||
- 📝 Persistent command history with search (Ctrl+R)
|
- 📝 Command history navigation (Up/Down arrows)
|
||||||
- 🌐 Online mode (web search capabilities)
|
- 🌐 **Universal Online Mode** - Web search for ALL providers:
|
||||||
|
- **Anthropic Native** - Built-in search with automatic citations ($0.01/search)
|
||||||
|
- **DuckDuckGo** - Free web scraping (all providers)
|
||||||
|
- **Google Custom Search** - Premium search option
|
||||||
- 🧠 Conversation memory toggle
|
- 🧠 Conversation memory toggle
|
||||||
|
- ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats)
|
||||||
|
|
||||||
### MCP Integration
|
### MCP Integration
|
||||||
- 🔧 **File Mode**: AI can read, search, and list local files
|
- 🔧 **File Mode**: AI can read, search, and list local files
|
||||||
@@ -34,30 +40,23 @@ A powerful, extensible terminal-based chat client for OpenRouter API with **MCP
|
|||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- Python 3.10-3.13
|
- Python 3.10-3.13
|
||||||
- OpenRouter API key ([get one here](https://openrouter.ai))
|
- API key for your chosen provider:
|
||||||
|
- **OpenRouter**: [openrouter.ai](https://openrouter.ai)
|
||||||
|
- **Anthropic**: [console.anthropic.com](https://console.anthropic.com)
|
||||||
|
- **OpenAI**: [platform.openai.com](https://platform.openai.com)
|
||||||
|
- **Ollama**: No API key needed (local server)
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
### Option 1: Install from Source (Recommended)
|
### Option 1: Pre-built Binary (macOS/Linux) (Recommended)
|
||||||
|
|
||||||
```bash
|
|
||||||
# Clone the repository
|
|
||||||
git clone https://gitlab.pm/rune/oai.git
|
|
||||||
cd oai
|
|
||||||
|
|
||||||
# Install with pip
|
|
||||||
pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 2: Pre-built Binary (macOS/Linux)
|
|
||||||
|
|
||||||
Download from [Releases](https://gitlab.pm/rune/oai/releases):
|
Download from [Releases](https://gitlab.pm/rune/oai/releases):
|
||||||
- **macOS (Apple Silicon)**: `oai_v2.1.0_mac_arm64.zip`
|
- **macOS (Apple Silicon)**: `oai_v3.0.0_mac_arm64.zip`
|
||||||
- **Linux (x86_64)**: `oai_v2.1.0_linux_x86_64.zip`
|
- **Linux (x86_64)**: `oai_v3.0.0_linux_x86_64.zip`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Extract and install
|
# Extract and install
|
||||||
unzip oai_v2.1.0_*.zip
|
unzip oai_v3.0.0_*.zip
|
||||||
mkdir -p ~/.local/bin
|
mkdir -p ~/.local/bin
|
||||||
mv oai ~/.local/bin/
|
mv oai ~/.local/bin/
|
||||||
|
|
||||||
@@ -73,29 +72,130 @@ xattr -cr ~/.local/bin/oai
|
|||||||
export PATH="$HOME/.local/bin:$PATH"
|
export PATH="$HOME/.local/bin:$PATH"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Option 2: Install from Source
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone the repository
|
||||||
|
git clone https://gitlab.pm/rune/oai.git
|
||||||
|
cd oai
|
||||||
|
|
||||||
|
# Install with pip
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Start the chat client
|
# Start oAI (launches TUI)
|
||||||
oai chat
|
oai
|
||||||
|
|
||||||
|
# Start with specific provider
|
||||||
|
oai --provider anthropic
|
||||||
|
oai --provider openai
|
||||||
|
oai --provider ollama
|
||||||
|
|
||||||
# Or with options
|
# Or with options
|
||||||
oai chat --model gpt-4o --mcp
|
oai --provider openrouter --model gpt-4o --online --mcp
|
||||||
|
|
||||||
|
# Show version
|
||||||
|
oai version
|
||||||
```
|
```
|
||||||
|
|
||||||
On first run, you'll be prompted for your OpenRouter API key.
|
On first run, you'll be prompted for your API key. Configure additional providers anytime with `/config`.
|
||||||
|
|
||||||
|
### Enable Web Search (All Providers)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Using Anthropic native search (best quality, automatic citations)
|
||||||
|
/config search_provider anthropic_native
|
||||||
|
/online on
|
||||||
|
|
||||||
|
# Using free DuckDuckGo (works with all providers)
|
||||||
|
/config search_provider duckduckgo
|
||||||
|
/online on
|
||||||
|
|
||||||
|
# Using Google Custom Search (requires API key)
|
||||||
|
/config search_provider google
|
||||||
|
/config google_api_key YOUR_KEY
|
||||||
|
/config google_search_engine_id YOUR_ID
|
||||||
|
/online on
|
||||||
|
```
|
||||||
|
|
||||||
### Basic Commands
|
### Basic Commands
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# In the chat interface:
|
# In the TUI interface:
|
||||||
/model # Select AI model
|
/provider # Show current provider or switch
|
||||||
/help # Show all commands
|
/provider anthropic # Switch to Anthropic (Claude)
|
||||||
|
/provider openai # Switch to OpenAI (ChatGPT)
|
||||||
|
/provider ollama # Switch to Ollama (local)
|
||||||
|
/model # Select AI model (or press F2)
|
||||||
|
/online on # Enable web search
|
||||||
|
/help # Show all commands (or press F1)
|
||||||
/mcp on # Enable file/database access
|
/mcp on # Enable file/database access
|
||||||
/stats # View session statistics
|
/stats # View session statistics (or press Ctrl+S)
|
||||||
exit # Quit
|
/config # View configuration settings
|
||||||
|
/credits # Check account credits (shows API balance or console link)
|
||||||
|
Ctrl+Q # Quit
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Web Search
|
||||||
|
|
||||||
|
oAI provides universal web search capabilities for all AI providers with three options:
|
||||||
|
|
||||||
|
### Anthropic Native Search (Recommended for Anthropic)
|
||||||
|
|
||||||
|
Anthropic's built-in web search API with automatic citations:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
/config search_provider anthropic_native
|
||||||
|
/online on
|
||||||
|
|
||||||
|
# Now ask questions requiring current information
|
||||||
|
What are the latest developments in quantum computing?
|
||||||
|
```
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- ✅ **Automatic citations** - Claude cites its sources
|
||||||
|
- ✅ **Smart searching** - Claude decides when to search
|
||||||
|
- ✅ **Progressive searches** - Multiple searches for complex queries
|
||||||
|
- ✅ **Best quality** - Professional-grade results
|
||||||
|
|
||||||
|
**Pricing:** $10 per 1,000 searches ($0.01 per search) + token costs
|
||||||
|
|
||||||
|
**Note:** Only works with Anthropic provider (Claude models)
|
||||||
|
|
||||||
|
### DuckDuckGo Search (Default - Free)
|
||||||
|
|
||||||
|
Free web scraping that works with ALL providers:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
/config search_provider duckduckgo # Default
|
||||||
|
/online on
|
||||||
|
|
||||||
|
# Works with Anthropic, OpenAI, Ollama, and OpenRouter
|
||||||
|
```
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- ✅ **Free** - No API key or costs
|
||||||
|
- ✅ **Universal** - Works with all providers
|
||||||
|
- ✅ **Privacy-friendly** - Uses DuckDuckGo
|
||||||
|
|
||||||
|
### Google Custom Search (Premium Option)
|
||||||
|
|
||||||
|
Google's Custom Search API for high-quality results:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
/config search_provider google
|
||||||
|
/config google_api_key YOUR_GOOGLE_API_KEY
|
||||||
|
/config google_search_engine_id YOUR_SEARCH_ENGINE_ID
|
||||||
|
/online on
|
||||||
|
```
|
||||||
|
|
||||||
|
Get your API key: [Google Custom Search API](https://developers.google.com/custom-search/v1/overview)
|
||||||
|
|
||||||
## MCP (Model Context Protocol)
|
## MCP (Model Context Protocol)
|
||||||
|
|
||||||
MCP allows the AI to interact with your local files and databases.
|
MCP allows the AI to interact with your local files and databases.
|
||||||
@@ -175,30 +275,39 @@ MCP allows the AI to interact with your local files and databases.
|
|||||||
| Command | Description |
|
| Command | Description |
|
||||||
|---------|-------------|
|
|---------|-------------|
|
||||||
| `/config` | View settings |
|
| `/config` | View settings |
|
||||||
| `/config api` | Set API key |
|
| `/config provider <name>` | Set default provider |
|
||||||
|
| `/config openrouter_api_key` | Set OpenRouter API key |
|
||||||
|
| `/config anthropic_api_key` | Set Anthropic API key |
|
||||||
|
| `/config openai_api_key` | Set OpenAI API key |
|
||||||
|
| `/config ollama_base_url` | Set Ollama server URL |
|
||||||
|
| `/config search_provider <provider>` | Set search provider (anthropic_native/duckduckgo/google) |
|
||||||
|
| `/config google_api_key` | Set Google API key (for Google search) |
|
||||||
|
| `/config online on\|off` | Set default online mode |
|
||||||
| `/config model <id>` | Set default model |
|
| `/config model <id>` | Set default model |
|
||||||
| `/config stream on\|off` | Toggle streaming |
|
| `/config stream on\|off` | Toggle streaming |
|
||||||
| `/stats` | Session statistics |
|
| `/stats` | Session statistics |
|
||||||
| `/credits` | Check credits |
|
| `/credits` | Check credits (OpenRouter) |
|
||||||
|
|
||||||
## CLI Options
|
## CLI Options
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
oai chat [OPTIONS]
|
oai [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
|
-p, --provider TEXT Provider to use (openrouter/anthropic/openai/ollama)
|
||||||
-m, --model TEXT Model ID to use
|
-m, --model TEXT Model ID to use
|
||||||
-s, --system TEXT System prompt
|
-s, --system TEXT System prompt
|
||||||
-o, --online Enable online mode
|
-o, --online Enable online mode (OpenRouter only)
|
||||||
--mcp Enable MCP server
|
--mcp Enable MCP server
|
||||||
|
-v, --version Show version
|
||||||
--help Show help
|
--help Show help
|
||||||
```
|
```
|
||||||
|
|
||||||
Other commands:
|
Commands:
|
||||||
```bash
|
```bash
|
||||||
oai config [setting] [value] # Configure settings
|
oai # Launch TUI (default)
|
||||||
oai version # Show version
|
oai version # Show version information
|
||||||
oai credits # Check credits
|
oai --help # Show help message
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
@@ -218,14 +327,18 @@ oai/
|
|||||||
├── oai/
|
├── oai/
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── __main__.py # Entry point for python -m oai
|
│ ├── __main__.py # Entry point for python -m oai
|
||||||
│ ├── cli.py # Main CLI interface
|
│ ├── cli.py # Main CLI entry point
|
||||||
│ ├── constants.py # Configuration constants
|
│ ├── constants.py # Configuration constants
|
||||||
│ ├── commands/ # Slash command handlers
|
│ ├── commands/ # Slash command handlers
|
||||||
│ ├── config/ # Settings and database
|
│ ├── config/ # Settings and database
|
||||||
│ ├── core/ # Chat client and session
|
│ ├── core/ # Chat client and session
|
||||||
│ ├── mcp/ # MCP server and tools
|
│ ├── mcp/ # MCP server and tools
|
||||||
│ ├── providers/ # AI provider abstraction
|
│ ├── providers/ # AI provider abstraction
|
||||||
│ ├── ui/ # Terminal UI utilities
|
│ ├── tui/ # Textual TUI interface
|
||||||
|
│ │ ├── app.py # Main TUI application
|
||||||
|
│ │ ├── widgets/ # Custom widgets
|
||||||
|
│ │ ├── screens/ # Modal screens
|
||||||
|
│ │ └── styles.tcss # TUI styling
|
||||||
│ └── utils/ # Logging, export, etc.
|
│ └── utils/ # Logging, export, etc.
|
||||||
├── pyproject.toml # Package configuration
|
├── pyproject.toml # Package configuration
|
||||||
├── build.sh # Binary build script
|
├── build.sh # Binary build script
|
||||||
@@ -266,7 +379,34 @@ pip install -e . --force-reinstall
|
|||||||
|
|
||||||
## Version History
|
## Version History
|
||||||
|
|
||||||
### v2.1.0 (Current)
|
### v3.0.0-b3 (Current - Beta 3)
|
||||||
|
- 🔄 **Multi-Provider Support** - OpenRouter, Anthropic (Claude), OpenAI (ChatGPT), Ollama (local)
|
||||||
|
- 🔌 **Provider Switching** - Switch between providers mid-session with `/provider`
|
||||||
|
- 🧠 **Provider Model Memory** - Remembers last used model per provider
|
||||||
|
- ⚙️ **Provider Configuration** - Separate API keys for each provider
|
||||||
|
- 🌐 **Universal Web Search** - Three search options for all providers:
|
||||||
|
- **Anthropic Native** - Built-in search with citations ($0.01/search)
|
||||||
|
- **DuckDuckGo** - Free web scraping (default)
|
||||||
|
- **Google Custom Search** - Premium option
|
||||||
|
- 🎨 **Enhanced Header** - Shows current provider and model
|
||||||
|
- 📊 **Updated Config Screen** - Shows provider-specific settings
|
||||||
|
- 🔧 **Command Dropdown** - Added `/provider` commands for discoverability
|
||||||
|
- 🎯 **Auto-Scrolling** - Chat window auto-scrolls during streaming responses
|
||||||
|
- 🎨 **Refined UI** - Thinner scrollbar, cleaner styling
|
||||||
|
- 🔧 **Updated Models** - Claude 4.5 (Sonnet, Haiku, Opus)
|
||||||
|
|
||||||
|
### v3.0.0 (Beta)
|
||||||
|
- 🎨 **Complete migration to Textual TUI** - Modern async terminal interface
|
||||||
|
- 🗑️ **Removed CLI interface** - TUI-only for cleaner codebase (11.6% smaller)
|
||||||
|
- 🖱️ **Modal screens** - Help, stats, config, credits, model selector
|
||||||
|
- ⌨️ **Keyboard shortcuts** - F1 (help), F2 (models), Ctrl+S (stats), etc.
|
||||||
|
- 🎯 **Capability indicators** - Visual icons for model features (vision, tools, online)
|
||||||
|
- 🎨 **Consistent dark theme** - Professional styling throughout
|
||||||
|
- 📊 **Enhanced model selector** - Search, filter, capability columns
|
||||||
|
- 🚀 **Default command** - Just run `oai` to launch TUI
|
||||||
|
- 🧹 **Code cleanup** - Removed 1,300+ lines of CLI code
|
||||||
|
|
||||||
|
### v2.1.0
|
||||||
- 🏗️ Complete codebase refactoring to modular package structure
|
- 🏗️ Complete codebase refactoring to modular package structure
|
||||||
- 🔌 Extensible provider architecture for adding new AI providers
|
- 🔌 Extensible provider architecture for adding new AI providers
|
||||||
- 📦 Proper Python packaging with pyproject.toml
|
- 📦 Proper Python packaging with pyproject.toml
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
"""
|
"""
|
||||||
oAI - OpenRouter AI Chat Client
|
oAI - Open AI Chat Client
|
||||||
|
|
||||||
A feature-rich terminal-based chat application that provides an interactive CLI
|
A feature-rich terminal-based chat application with multi-provider support
|
||||||
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
|
(OpenRouter, Anthropic, OpenAI, Ollama) and advanced Model Context Protocol (MCP)
|
||||||
integration for filesystem and database access.
|
integration for filesystem and database access.
|
||||||
|
|
||||||
Author: Rune
|
Author: Rune
|
||||||
License: MIT
|
License: MIT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "2.1.0"
|
__version__ = "3.0.0-b3"
|
||||||
__author__ = "Rune"
|
__author__ = "Rune Olsen"
|
||||||
__license__ = "MIT"
|
__license__ = "MIT"
|
||||||
|
|
||||||
# Lazy imports to avoid circular dependencies and improve startup time
|
# Lazy imports to avoid circular dependencies and improve startup time
|
||||||
|
|||||||
759
oai/cli.py
759
oai/cli.py
@@ -1,55 +1,27 @@
|
|||||||
"""
|
"""
|
||||||
Main CLI entry point for oAI.
|
Main CLI entry point for oAI.
|
||||||
|
|
||||||
This module provides the command-line interface for the oAI chat application.
|
This module provides the command-line interface for the oAI TUI application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
from prompt_toolkit import PromptSession
|
|
||||||
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
|
||||||
from prompt_toolkit.history import FileHistory
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
from rich.panel import Panel
|
|
||||||
|
|
||||||
from oai import __version__
|
from oai import __version__
|
||||||
from oai.commands import register_all_commands, registry
|
from oai.commands import register_all_commands
|
||||||
from oai.commands.registry import CommandContext, CommandStatus
|
|
||||||
from oai.config.database import Database
|
|
||||||
from oai.config.settings import Settings
|
from oai.config.settings import Settings
|
||||||
from oai.constants import (
|
from oai.constants import APP_URL, APP_VERSION
|
||||||
APP_NAME,
|
|
||||||
APP_URL,
|
|
||||||
APP_VERSION,
|
|
||||||
CONFIG_DIR,
|
|
||||||
HISTORY_FILE,
|
|
||||||
VALID_COMMANDS,
|
|
||||||
)
|
|
||||||
from oai.core.client import AIClient
|
from oai.core.client import AIClient
|
||||||
from oai.core.session import ChatSession
|
from oai.core.session import ChatSession
|
||||||
from oai.mcp.manager import MCPManager
|
from oai.mcp.manager import MCPManager
|
||||||
from oai.providers.base import UsageStats
|
|
||||||
from oai.providers.openrouter import OpenRouterProvider
|
|
||||||
from oai.ui.console import (
|
|
||||||
clear_screen,
|
|
||||||
console,
|
|
||||||
display_panel,
|
|
||||||
print_error,
|
|
||||||
print_info,
|
|
||||||
print_success,
|
|
||||||
print_warning,
|
|
||||||
)
|
|
||||||
from oai.ui.tables import create_model_table, display_paginated_table
|
|
||||||
from oai.utils.logging import LoggingManager, get_logger
|
from oai.utils.logging import LoggingManager, get_logger
|
||||||
|
|
||||||
# Create Typer app
|
# Create Typer app
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
name="oai",
|
name="oai",
|
||||||
help=f"oAI - OpenRouter AI Chat Client\n\nVersion: {APP_VERSION}",
|
help=f"oAI - Open AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
|
||||||
add_completion=False,
|
add_completion=False,
|
||||||
epilog="For more information, visit: " + APP_URL,
|
epilog="For more information, visit: " + APP_URL,
|
||||||
)
|
)
|
||||||
@@ -65,374 +37,151 @@ def main_callback(
|
|||||||
help="Show version information",
|
help="Show version information",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
),
|
),
|
||||||
|
model: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--model",
|
||||||
|
"-m",
|
||||||
|
help="Model ID to use",
|
||||||
|
),
|
||||||
|
system: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--system",
|
||||||
|
"-s",
|
||||||
|
help="System prompt",
|
||||||
|
),
|
||||||
|
online: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--online",
|
||||||
|
"-o",
|
||||||
|
help="Enable online mode",
|
||||||
|
),
|
||||||
|
mcp: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--mcp",
|
||||||
|
help="Enable MCP server",
|
||||||
|
),
|
||||||
|
provider: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--provider",
|
||||||
|
"-p",
|
||||||
|
help="AI provider to use (openrouter, anthropic, openai, ollama)",
|
||||||
|
),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Main callback to handle global options."""
|
"""Main callback - launches TUI by default."""
|
||||||
# Show version with update check if --version flag
|
|
||||||
if version_flag:
|
if version_flag:
|
||||||
version_info = check_for_updates(APP_VERSION)
|
typer.echo(f"oAI version {APP_VERSION}")
|
||||||
console.print(version_info)
|
|
||||||
raise typer.Exit()
|
raise typer.Exit()
|
||||||
|
|
||||||
# Show version with update check when --help is requested
|
# If no subcommand provided, launch TUI
|
||||||
if "--help" in sys.argv or "-h" in sys.argv:
|
|
||||||
version_info = check_for_updates(APP_VERSION)
|
|
||||||
console.print(f"\n{version_info}\n")
|
|
||||||
|
|
||||||
# Continue to subcommand if provided
|
|
||||||
if ctx.invoked_subcommand is None:
|
if ctx.invoked_subcommand is None:
|
||||||
return
|
_launch_tui(model, system, online, mcp, provider)
|
||||||
|
|
||||||
|
|
||||||
def check_for_updates(current_version: str) -> str:
|
def _launch_tui(
|
||||||
"""Check for available updates."""
|
model: Optional[str] = None,
|
||||||
import requests
|
system: Optional[str] = None,
|
||||||
from packaging import version as pkg_version
|
online: bool = False,
|
||||||
|
mcp: bool = False,
|
||||||
try:
|
provider: Optional[str] = None,
|
||||||
response = requests.get(
|
|
||||||
"https://gitlab.pm/api/v1/repos/rune/oai/releases/latest",
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
timeout=1.0,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
version_online = data.get("tag_name", "").lstrip("v")
|
|
||||||
|
|
||||||
if not version_online:
|
|
||||||
return f"[bold green]oAI version {current_version}[/]"
|
|
||||||
|
|
||||||
current = pkg_version.parse(current_version)
|
|
||||||
latest = pkg_version.parse(version_online)
|
|
||||||
|
|
||||||
if latest > current:
|
|
||||||
return (
|
|
||||||
f"[bold green]oAI version {current_version}[/] "
|
|
||||||
f"[bold red](Update available: {current_version} → {version_online})[/]"
|
|
||||||
)
|
|
||||||
return f"[bold green]oAI version {current_version} (up to date)[/]"
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return f"[bold green]oAI version {current_version}[/]"
|
|
||||||
|
|
||||||
|
|
||||||
def show_welcome(settings: Settings, version_info: str) -> None:
|
|
||||||
"""Display welcome message."""
|
|
||||||
console.print(Panel.fit(
|
|
||||||
f"{version_info}\n\n"
|
|
||||||
"[bold cyan]Commands:[/] /help for commands, /model to select model\n"
|
|
||||||
"[bold cyan]MCP:[/] /mcp on to enable file/database access\n"
|
|
||||||
"[bold cyan]Exit:[/] Type 'exit', 'quit', or 'bye'",
|
|
||||||
title=f"[bold green]Welcome to {APP_NAME}[/]",
|
|
||||||
border_style="green",
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
def select_model(client: AIClient, search_term: Optional[str] = None) -> Optional[dict]:
|
|
||||||
"""Display model selection interface."""
|
|
||||||
try:
|
|
||||||
models = client.provider.get_raw_models()
|
|
||||||
if not models:
|
|
||||||
print_error("No models available")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Filter by search term if provided
|
|
||||||
if search_term:
|
|
||||||
search_lower = search_term.lower()
|
|
||||||
models = [m for m in models if search_lower in m.get("id", "").lower()]
|
|
||||||
|
|
||||||
if not models:
|
|
||||||
print_error(f"No models found matching '{search_term}'")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Create and display table
|
|
||||||
table = create_model_table(models)
|
|
||||||
display_paginated_table(
|
|
||||||
table,
|
|
||||||
f"[bold green]Available Models ({len(models)})[/]",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prompt for selection
|
|
||||||
console.print("")
|
|
||||||
try:
|
|
||||||
choice = input("Enter model number (or press Enter to cancel): ").strip()
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not choice:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
index = int(choice) - 1
|
|
||||||
if 0 <= index < len(models):
|
|
||||||
selected = models[index]
|
|
||||||
print_success(f"Selected model: {selected['id']}")
|
|
||||||
return selected
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
print_error("Invalid selection")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print_error(f"Failed to fetch models: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def run_chat_loop(
|
|
||||||
session: ChatSession,
|
|
||||||
prompt_session: PromptSession,
|
|
||||||
settings: Settings,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run the main chat loop."""
|
"""Launch the Textual TUI interface."""
|
||||||
|
from oai.constants import VALID_PROVIDERS
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging_manager = LoggingManager()
|
||||||
|
logging_manager.setup()
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
mcp_manager = session.mcp_manager
|
|
||||||
|
|
||||||
while True:
|
# Load settings
|
||||||
|
settings = Settings.load()
|
||||||
|
|
||||||
|
# Determine provider
|
||||||
|
selected_provider = provider or settings.default_provider
|
||||||
|
|
||||||
|
# Validate provider
|
||||||
|
if selected_provider not in VALID_PROVIDERS:
|
||||||
|
typer.echo(f"Error: Invalid provider: {selected_provider}", err=True)
|
||||||
|
typer.echo(f"Valid providers: {', '.join(VALID_PROVIDERS)}", err=True)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Build provider API keys dict
|
||||||
|
provider_api_keys = {
|
||||||
|
"openrouter": settings.openrouter_api_key,
|
||||||
|
"anthropic": settings.anthropic_api_key,
|
||||||
|
"openai": settings.openai_api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if provider is configured (except Ollama which doesn't need API key)
|
||||||
|
if selected_provider != "ollama":
|
||||||
|
if not provider_api_keys.get(selected_provider):
|
||||||
|
typer.echo(f"Error: No API key configured for {selected_provider}", err=True)
|
||||||
|
typer.echo(f"Set it with: oai config {selected_provider}_api_key <key>", err=True)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Initialize client
|
||||||
|
try:
|
||||||
|
client = AIClient(
|
||||||
|
provider_name=selected_provider,
|
||||||
|
provider_api_keys=provider_api_keys,
|
||||||
|
ollama_base_url=settings.ollama_base_url,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
typer.echo(f"Error: Failed to initialize client: {e}", err=True)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Register commands
|
||||||
|
register_all_commands()
|
||||||
|
|
||||||
|
# Initialize MCP manager (always create it, even if not enabled)
|
||||||
|
mcp_manager = MCPManager()
|
||||||
|
if mcp:
|
||||||
try:
|
try:
|
||||||
# Build prompt prefix
|
result = mcp_manager.enable()
|
||||||
prefix = "You> "
|
if result["success"]:
|
||||||
if mcp_manager and mcp_manager.enabled:
|
logger.info("MCP server enabled in files mode")
|
||||||
if mcp_manager.mode == "files":
|
|
||||||
if mcp_manager.write_enabled:
|
|
||||||
prefix = "[🔧✍️ MCP: Files+Write] You> "
|
|
||||||
else:
|
|
||||||
prefix = "[🔧 MCP: Files] You> "
|
|
||||||
elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None:
|
|
||||||
prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> "
|
|
||||||
|
|
||||||
# Get user input
|
|
||||||
user_input = prompt_session.prompt(
|
|
||||||
prefix,
|
|
||||||
auto_suggest=AutoSuggestFromHistory(),
|
|
||||||
).strip()
|
|
||||||
|
|
||||||
if not user_input:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle escape sequence
|
|
||||||
if user_input.startswith("//"):
|
|
||||||
user_input = user_input[1:]
|
|
||||||
|
|
||||||
# Check for exit
|
|
||||||
if user_input.lower() in ["exit", "quit", "bye"]:
|
|
||||||
console.print(
|
|
||||||
f"\n[bold yellow]Goodbye![/]\n"
|
|
||||||
f"[dim]Session: {session.stats.total_tokens:,} tokens, "
|
|
||||||
f"${session.stats.total_cost:.4f}[/]"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Session ended. Messages: {session.stats.message_count}, "
|
|
||||||
f"Tokens: {session.stats.total_tokens}, "
|
|
||||||
f"Cost: ${session.stats.total_cost:.4f}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check for unknown commands
|
|
||||||
if user_input.startswith("/"):
|
|
||||||
cmd_word = user_input.split()[0].lower()
|
|
||||||
if not registry.is_command(user_input):
|
|
||||||
# Check if it's a valid command prefix
|
|
||||||
is_valid = any(cmd_word.startswith(cmd) for cmd in VALID_COMMANDS)
|
|
||||||
if not is_valid:
|
|
||||||
print_error(f"Unknown command: {cmd_word}")
|
|
||||||
print_info("Type /help to see available commands.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Try to execute as command
|
|
||||||
context = session.get_context()
|
|
||||||
result = registry.execute(user_input, context)
|
|
||||||
|
|
||||||
if result:
|
|
||||||
# Update session state from context
|
|
||||||
session.memory_enabled = context.memory_enabled
|
|
||||||
session.memory_start_index = context.memory_start_index
|
|
||||||
session.online_enabled = context.online_enabled
|
|
||||||
session.middle_out_enabled = context.middle_out_enabled
|
|
||||||
session.session_max_token = context.session_max_token
|
|
||||||
session.current_index = context.current_index
|
|
||||||
session.system_prompt = context.session_system_prompt
|
|
||||||
|
|
||||||
if result.status == CommandStatus.EXIT:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle special results
|
|
||||||
if result.data:
|
|
||||||
# Retry - resend last prompt
|
|
||||||
if "retry_prompt" in result.data:
|
|
||||||
user_input = result.data["retry_prompt"]
|
|
||||||
# Fall through to send message
|
|
||||||
|
|
||||||
# Paste - send clipboard content
|
|
||||||
elif "paste_prompt" in result.data:
|
|
||||||
user_input = result.data["paste_prompt"]
|
|
||||||
# Fall through to send message
|
|
||||||
|
|
||||||
# Model selection
|
|
||||||
elif "show_model_selector" in result.data:
|
|
||||||
search = result.data.get("search", "")
|
|
||||||
model = select_model(session.client, search if search else None)
|
|
||||||
if model:
|
|
||||||
session.set_model(model)
|
|
||||||
# If this came from /config model, also save as default
|
|
||||||
if result.data.get("set_as_default"):
|
|
||||||
settings.set_default_model(model["id"])
|
|
||||||
print_success(f"Default model set to: {model['id']}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Load conversation
|
|
||||||
elif "load_conversation" in result.data:
|
|
||||||
history = result.data.get("history", [])
|
|
||||||
session.history.clear()
|
|
||||||
from oai.core.session import HistoryEntry
|
|
||||||
for entry in history:
|
|
||||||
session.history.append(HistoryEntry(
|
|
||||||
prompt=entry.get("prompt", ""),
|
|
||||||
response=entry.get("response", ""),
|
|
||||||
prompt_tokens=entry.get("prompt_tokens", 0),
|
|
||||||
completion_tokens=entry.get("completion_tokens", 0),
|
|
||||||
msg_cost=entry.get("msg_cost", 0.0),
|
|
||||||
))
|
|
||||||
session.current_index = len(session.history) - 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Normal command completed
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# Command completed with no special data
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Ensure model is selected
|
|
||||||
if not session.selected_model:
|
|
||||||
print_warning("Please select a model first with /model")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Send message
|
|
||||||
stream = settings.stream_enabled
|
|
||||||
if mcp_manager and mcp_manager.enabled:
|
|
||||||
tools = session.get_mcp_tools()
|
|
||||||
if tools:
|
|
||||||
stream = False # Disable streaming with tools
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
console.print(
|
|
||||||
"[bold green]Streaming response...[/] "
|
|
||||||
"[dim](Press Ctrl+C to cancel)[/]"
|
|
||||||
)
|
|
||||||
if session.online_enabled:
|
|
||||||
console.print("[dim cyan]🌐 Online mode active[/]")
|
|
||||||
console.print("")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response_text, usage, response_time = session.send_message(
|
|
||||||
user_input,
|
|
||||||
stream=stream,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print_error(f"Error: {e}")
|
|
||||||
logger.error(f"Message error: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not response_text:
|
|
||||||
print_error("No response received")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Display non-streaming response
|
|
||||||
if not stream:
|
|
||||||
console.print()
|
|
||||||
display_panel(
|
|
||||||
Markdown(response_text),
|
|
||||||
title="[bold green]AI Response[/]",
|
|
||||||
border_style="green",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate cost and tokens
|
|
||||||
cost = 0.0
|
|
||||||
tokens = 0
|
|
||||||
estimated = False
|
|
||||||
|
|
||||||
if usage and (usage.prompt_tokens > 0 or usage.completion_tokens > 0):
|
|
||||||
tokens = usage.total_tokens
|
|
||||||
if usage.total_cost_usd:
|
|
||||||
cost = usage.total_cost_usd
|
|
||||||
else:
|
|
||||||
cost = session.client.estimate_cost(
|
|
||||||
session.selected_model["id"],
|
|
||||||
usage.prompt_tokens,
|
|
||||||
usage.completion_tokens,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Estimate tokens when usage not available (streaming fallback)
|
logger.warning(f"MCP: {result.get('error', 'Failed to enable')}")
|
||||||
# Rough estimate: ~4 characters per token for English text
|
except Exception as e:
|
||||||
est_input_tokens = len(user_input) // 4 + 1
|
logger.warning(f"Failed to enable MCP: {e}")
|
||||||
est_output_tokens = len(response_text) // 4 + 1
|
|
||||||
tokens = est_input_tokens + est_output_tokens
|
|
||||||
cost = session.client.estimate_cost(
|
|
||||||
session.selected_model["id"],
|
|
||||||
est_input_tokens,
|
|
||||||
est_output_tokens,
|
|
||||||
)
|
|
||||||
# Create estimated usage for session tracking
|
|
||||||
usage = UsageStats(
|
|
||||||
prompt_tokens=est_input_tokens,
|
|
||||||
completion_tokens=est_output_tokens,
|
|
||||||
total_tokens=tokens,
|
|
||||||
)
|
|
||||||
estimated = True
|
|
||||||
|
|
||||||
# Add to history
|
# Create session with MCP manager
|
||||||
session.add_to_history(user_input, response_text, usage, cost)
|
session = ChatSession(
|
||||||
|
client=client,
|
||||||
|
settings=settings,
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
)
|
||||||
|
|
||||||
# Display metrics
|
# Set system prompt if provided
|
||||||
est_marker = "~" if estimated else ""
|
if system:
|
||||||
context_info = ""
|
session.set_system_prompt(system)
|
||||||
if session.memory_enabled:
|
|
||||||
context_count = len(session.history) - session.memory_start_index
|
|
||||||
if context_count > 1:
|
|
||||||
context_info = f", Context: {context_count} msg(s)"
|
|
||||||
else:
|
|
||||||
context_info = ", Memory: OFF"
|
|
||||||
|
|
||||||
online_emoji = " 🌐" if session.online_enabled else ""
|
# Enable online mode if requested
|
||||||
mcp_emoji = ""
|
if online:
|
||||||
if mcp_manager and mcp_manager.enabled:
|
session.online_enabled = True
|
||||||
if mcp_manager.mode == "files":
|
|
||||||
mcp_emoji = " 🔧"
|
|
||||||
elif mcp_manager.mode == "database":
|
|
||||||
mcp_emoji = " 🗄️"
|
|
||||||
|
|
||||||
console.print(
|
# Set model if specified, otherwise use default
|
||||||
f"\n[dim blue]📊 {est_marker}{tokens} tokens | {est_marker}${cost:.4f} | {response_time:.2f}s"
|
if model:
|
||||||
f"{context_info}{online_emoji}{mcp_emoji} | "
|
raw_model = client.get_raw_model(model)
|
||||||
f"Session: {est_marker}{session.stats.total_tokens:,} tokens | "
|
if raw_model:
|
||||||
f"{est_marker}${session.stats.total_cost:.4f}[/]"
|
session.set_model(raw_model)
|
||||||
)
|
else:
|
||||||
|
logger.warning(f"Model '{model}' not found")
|
||||||
|
elif settings.default_model:
|
||||||
|
raw_model = client.get_raw_model(settings.default_model)
|
||||||
|
if raw_model:
|
||||||
|
session.set_model(raw_model)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Default model '{settings.default_model}' not available")
|
||||||
|
|
||||||
# Check warnings
|
# Run Textual app
|
||||||
warnings = session.check_warnings()
|
from oai.tui.app import oAIChatApp
|
||||||
for warning in warnings:
|
|
||||||
print_warning(warning)
|
|
||||||
|
|
||||||
# Offer to copy
|
app_instance = oAIChatApp(session, settings, model)
|
||||||
console.print("")
|
app_instance.run()
|
||||||
try:
|
|
||||||
from oai.ui.prompts import prompt_copy_response
|
|
||||||
prompt_copy_response(response_text)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
console.print("")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
console.print("\n[dim]Input cancelled[/]")
|
|
||||||
continue
|
|
||||||
except EOFError:
|
|
||||||
console.print("\n[bold yellow]Goodbye![/]")
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def chat(
|
def tui(
|
||||||
model: Optional[str] = typer.Option(
|
model: Optional[str] = typer.Option(
|
||||||
None,
|
None,
|
||||||
"--model",
|
"--model",
|
||||||
@@ -457,261 +206,19 @@ def chat(
|
|||||||
help="Enable MCP server",
|
help="Enable MCP server",
|
||||||
),
|
),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start an interactive chat session."""
|
"""Start Textual TUI interface (alias for just running 'oai')."""
|
||||||
# Setup logging
|
_launch_tui(model, system, online, mcp)
|
||||||
logging_manager = LoggingManager()
|
|
||||||
logging_manager.setup()
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
# Clear screen
|
|
||||||
clear_screen()
|
|
||||||
|
|
||||||
# Load settings
|
|
||||||
settings = Settings.load()
|
|
||||||
|
|
||||||
# Check API key
|
|
||||||
if not settings.api_key:
|
|
||||||
print_error("No API key configured")
|
|
||||||
print_info("Run: oai --config api to set your API key")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Initialize client
|
|
||||||
try:
|
|
||||||
client = AIClient(
|
|
||||||
api_key=settings.api_key,
|
|
||||||
base_url=settings.base_url,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print_error(f"Failed to initialize client: {e}")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Register commands
|
|
||||||
register_all_commands()
|
|
||||||
|
|
||||||
# Check for updates and show welcome
|
|
||||||
version_info = check_for_updates(APP_VERSION)
|
|
||||||
show_welcome(settings, version_info)
|
|
||||||
|
|
||||||
# Initialize MCP manager
|
|
||||||
mcp_manager = MCPManager()
|
|
||||||
if mcp:
|
|
||||||
result = mcp_manager.enable()
|
|
||||||
if result["success"]:
|
|
||||||
print_success("MCP enabled")
|
|
||||||
else:
|
|
||||||
print_warning(f"MCP: {result.get('error', 'Failed to enable')}")
|
|
||||||
|
|
||||||
# Create session
|
|
||||||
session = ChatSession(
|
|
||||||
client=client,
|
|
||||||
settings=settings,
|
|
||||||
mcp_manager=mcp_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set system prompt
|
|
||||||
if system:
|
|
||||||
session.system_prompt = system
|
|
||||||
print_info(f"System prompt: {system}")
|
|
||||||
|
|
||||||
# Set online mode
|
|
||||||
if online:
|
|
||||||
session.online_enabled = True
|
|
||||||
print_info("Online mode enabled")
|
|
||||||
|
|
||||||
# Select model
|
|
||||||
if model:
|
|
||||||
raw_model = client.get_raw_model(model)
|
|
||||||
if raw_model:
|
|
||||||
session.set_model(raw_model)
|
|
||||||
else:
|
|
||||||
print_warning(f"Model '{model}' not found")
|
|
||||||
elif settings.default_model:
|
|
||||||
raw_model = client.get_raw_model(settings.default_model)
|
|
||||||
if raw_model:
|
|
||||||
session.set_model(raw_model)
|
|
||||||
else:
|
|
||||||
print_warning(f"Default model '{settings.default_model}' not available")
|
|
||||||
|
|
||||||
# Setup prompt session
|
|
||||||
HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
prompt_session = PromptSession(
|
|
||||||
history=FileHistory(str(HISTORY_FILE)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run chat loop
|
|
||||||
run_chat_loop(session, prompt_session, settings)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def config(
|
|
||||||
setting: Optional[str] = typer.Argument(
|
|
||||||
None,
|
|
||||||
help="Setting to configure (api, url, model, system, stream, costwarning, maxtoken, online, log, loglevel)",
|
|
||||||
),
|
|
||||||
value: Optional[str] = typer.Argument(
|
|
||||||
None,
|
|
||||||
help="Value to set",
|
|
||||||
),
|
|
||||||
) -> None:
|
|
||||||
"""View or modify configuration settings."""
|
|
||||||
settings = Settings.load()
|
|
||||||
|
|
||||||
if not setting:
|
|
||||||
# Show all settings
|
|
||||||
from rich.table import Table
|
|
||||||
from oai.constants import DEFAULT_SYSTEM_PROMPT
|
|
||||||
|
|
||||||
table = Table("Setting", "Value", show_header=True, header_style="bold magenta")
|
|
||||||
table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set")
|
|
||||||
table.add_row("Base URL", settings.base_url)
|
|
||||||
table.add_row("Default Model", settings.default_model or "Not set")
|
|
||||||
|
|
||||||
# Show system prompt status
|
|
||||||
if settings.default_system_prompt is None:
|
|
||||||
system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
|
|
||||||
elif settings.default_system_prompt == "":
|
|
||||||
system_prompt_display = "[blank]"
|
|
||||||
else:
|
|
||||||
system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt
|
|
||||||
table.add_row("System Prompt", system_prompt_display)
|
|
||||||
|
|
||||||
table.add_row("Streaming", "on" if settings.stream_enabled else "off")
|
|
||||||
table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}")
|
|
||||||
table.add_row("Max Tokens", str(settings.max_tokens))
|
|
||||||
table.add_row("Default Online", "on" if settings.default_online_mode else "off")
|
|
||||||
table.add_row("Log Level", settings.log_level)
|
|
||||||
|
|
||||||
display_panel(table, title="[bold green]Configuration[/]")
|
|
||||||
return
|
|
||||||
|
|
||||||
setting = setting.lower()
|
|
||||||
|
|
||||||
if setting == "api":
|
|
||||||
if value:
|
|
||||||
settings.set_api_key(value)
|
|
||||||
else:
|
|
||||||
from oai.ui.prompts import prompt_input
|
|
||||||
new_key = prompt_input("Enter API key", password=True)
|
|
||||||
if new_key:
|
|
||||||
settings.set_api_key(new_key)
|
|
||||||
print_success("API key updated")
|
|
||||||
|
|
||||||
elif setting == "url":
|
|
||||||
settings.set_base_url(value or "https://openrouter.ai/api/v1")
|
|
||||||
print_success(f"Base URL set to: {settings.base_url}")
|
|
||||||
|
|
||||||
elif setting == "model":
|
|
||||||
if value:
|
|
||||||
settings.set_default_model(value)
|
|
||||||
print_success(f"Default model set to: {value}")
|
|
||||||
else:
|
|
||||||
print_info(f"Current default model: {settings.default_model or 'Not set'}")
|
|
||||||
|
|
||||||
elif setting == "system":
|
|
||||||
from oai.constants import DEFAULT_SYSTEM_PROMPT
|
|
||||||
|
|
||||||
if value:
|
|
||||||
# Decode escape sequences like \n for newlines
|
|
||||||
value = value.encode().decode('unicode_escape')
|
|
||||||
settings.set_default_system_prompt(value)
|
|
||||||
if value:
|
|
||||||
print_success(f"Default system prompt set to: {value}")
|
|
||||||
else:
|
|
||||||
print_success("Default system prompt set to blank.")
|
|
||||||
else:
|
|
||||||
if settings.default_system_prompt is None:
|
|
||||||
print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...")
|
|
||||||
elif settings.default_system_prompt == "":
|
|
||||||
print_info("System prompt: [blank]")
|
|
||||||
else:
|
|
||||||
print_info(f"System prompt: {settings.default_system_prompt}")
|
|
||||||
|
|
||||||
elif setting == "stream":
|
|
||||||
if value and value.lower() in ["on", "off"]:
|
|
||||||
settings.set_stream_enabled(value.lower() == "on")
|
|
||||||
print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}")
|
|
||||||
else:
|
|
||||||
print_info("Usage: oai config stream [on|off]")
|
|
||||||
|
|
||||||
elif setting == "costwarning":
|
|
||||||
if value:
|
|
||||||
try:
|
|
||||||
threshold = float(value)
|
|
||||||
settings.set_cost_warning_threshold(threshold)
|
|
||||||
print_success(f"Cost warning threshold set to: ${threshold:.4f}")
|
|
||||||
except ValueError:
|
|
||||||
print_error("Please enter a valid number")
|
|
||||||
else:
|
|
||||||
print_info(f"Current threshold: ${settings.cost_warning_threshold:.4f}")
|
|
||||||
|
|
||||||
elif setting == "maxtoken":
|
|
||||||
if value:
|
|
||||||
try:
|
|
||||||
max_tok = int(value)
|
|
||||||
settings.set_max_tokens(max_tok)
|
|
||||||
print_success(f"Max tokens set to: {max_tok}")
|
|
||||||
except ValueError:
|
|
||||||
print_error("Please enter a valid number")
|
|
||||||
else:
|
|
||||||
print_info(f"Current max tokens: {settings.max_tokens}")
|
|
||||||
|
|
||||||
elif setting == "online":
|
|
||||||
if value and value.lower() in ["on", "off"]:
|
|
||||||
settings.set_default_online_mode(value.lower() == "on")
|
|
||||||
print_success(f"Default online mode {'enabled' if settings.default_online_mode else 'disabled'}")
|
|
||||||
else:
|
|
||||||
print_info("Usage: oai config online [on|off]")
|
|
||||||
|
|
||||||
elif setting == "loglevel":
|
|
||||||
valid_levels = ["debug", "info", "warning", "error", "critical"]
|
|
||||||
if value and value.lower() in valid_levels:
|
|
||||||
settings.set_log_level(value.lower())
|
|
||||||
print_success(f"Log level set to: {value.lower()}")
|
|
||||||
else:
|
|
||||||
print_info(f"Valid levels: {', '.join(valid_levels)}")
|
|
||||||
|
|
||||||
else:
|
|
||||||
print_error(f"Unknown setting: {setting}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def version() -> None:
|
def version() -> None:
|
||||||
"""Show version information."""
|
"""Show version information."""
|
||||||
version_info = check_for_updates(APP_VERSION)
|
typer.echo(f"oAI version {APP_VERSION}")
|
||||||
console.print(version_info)
|
typer.echo(f"Visit {APP_URL} for more information")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def credits() -> None:
|
|
||||||
"""Check account credits."""
|
|
||||||
settings = Settings.load()
|
|
||||||
|
|
||||||
if not settings.api_key:
|
|
||||||
print_error("No API key configured")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
client = AIClient(api_key=settings.api_key, base_url=settings.base_url)
|
|
||||||
credits_data = client.get_credits()
|
|
||||||
|
|
||||||
if not credits_data:
|
|
||||||
print_error("Failed to fetch credits")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
from rich.table import Table
|
|
||||||
|
|
||||||
table = Table("Metric", "Value", show_header=True, header_style="bold magenta")
|
|
||||||
table.add_row("Total Credits", credits_data.get("total_credits_formatted", "N/A"))
|
|
||||||
table.add_row("Used Credits", credits_data.get("used_credits_formatted", "N/A"))
|
|
||||||
table.add_row("Credits Left", credits_data.get("credits_left_formatted", "N/A"))
|
|
||||||
|
|
||||||
display_panel(table, title="[bold green]Account Credits[/]")
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""Main entry point."""
|
"""Entry point for the CLI."""
|
||||||
# Default to 'chat' command if no arguments provided
|
|
||||||
if len(sys.argv) == 1:
|
|
||||||
sys.argv.append("chat")
|
|
||||||
app()
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -80,6 +80,7 @@ class CommandContext:
|
|||||||
online_enabled: Whether online mode is enabled
|
online_enabled: Whether online mode is enabled
|
||||||
session_tokens: Session token counts
|
session_tokens: Session token counts
|
||||||
session_cost: Session cost total
|
session_cost: Session cost total
|
||||||
|
session: Reference to the ChatSession (for provider switching, etc.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
settings: Optional["Settings"] = None
|
settings: Optional["Settings"] = None
|
||||||
@@ -98,7 +99,9 @@ class CommandContext:
|
|||||||
total_output_tokens: int = 0
|
total_output_tokens: int = 0
|
||||||
total_cost: float = 0.0
|
total_cost: float = 0.0
|
||||||
message_count: int = 0
|
message_count: int = 0
|
||||||
|
is_tui: bool = False # Flag for TUI mode
|
||||||
current_index: int = 0
|
current_index: int = 0
|
||||||
|
session: Optional[Any] = None # Reference to ChatSession (avoid circular import)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ configuration with type safety, validation, and persistence.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, Dict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
|
||||||
from oai.constants import (
|
from oai.constants import (
|
||||||
DEFAULT_BASE_URL,
|
DEFAULT_BASE_URL,
|
||||||
@@ -20,6 +21,8 @@ from oai.constants import (
|
|||||||
DEFAULT_LOG_LEVEL,
|
DEFAULT_LOG_LEVEL,
|
||||||
DEFAULT_SYSTEM_PROMPT,
|
DEFAULT_SYSTEM_PROMPT,
|
||||||
VALID_LOG_LEVELS,
|
VALID_LOG_LEVELS,
|
||||||
|
DEFAULT_PROVIDER,
|
||||||
|
OLLAMA_DEFAULT_URL,
|
||||||
)
|
)
|
||||||
from oai.config.database import get_database
|
from oai.config.database import get_database
|
||||||
|
|
||||||
@@ -34,7 +37,7 @@ class Settings:
|
|||||||
initialization and can be persisted back.
|
initialization and can be persisted back.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
api_key: OpenRouter API key
|
api_key: Legacy OpenRouter API key (deprecated, use openrouter_api_key)
|
||||||
base_url: API base URL
|
base_url: API base URL
|
||||||
default_model: Default model ID to use
|
default_model: Default model ID to use
|
||||||
default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
|
default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
|
||||||
@@ -45,9 +48,33 @@ class Settings:
|
|||||||
log_max_size_mb: Maximum log file size in MB
|
log_max_size_mb: Maximum log file size in MB
|
||||||
log_backup_count: Number of log file backups to keep
|
log_backup_count: Number of log file backups to keep
|
||||||
log_level: Logging level (debug/info/warning/error/critical)
|
log_level: Logging level (debug/info/warning/error/critical)
|
||||||
|
|
||||||
|
# Provider-specific settings
|
||||||
|
default_provider: Default AI provider to use
|
||||||
|
openrouter_api_key: OpenRouter API key
|
||||||
|
anthropic_api_key: Anthropic API key
|
||||||
|
openai_api_key: OpenAI API key
|
||||||
|
ollama_base_url: Ollama server URL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Legacy field (kept for backward compatibility)
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
|
# Provider configuration
|
||||||
|
default_provider: str = DEFAULT_PROVIDER
|
||||||
|
openrouter_api_key: Optional[str] = None
|
||||||
|
anthropic_api_key: Optional[str] = None
|
||||||
|
openai_api_key: Optional[str] = None
|
||||||
|
ollama_base_url: str = OLLAMA_DEFAULT_URL
|
||||||
|
provider_models: Dict[str, str] = field(default_factory=dict) # provider -> last_model_id
|
||||||
|
|
||||||
|
# Web search configuration (for online mode with non-OpenRouter providers)
|
||||||
|
search_provider: str = "duckduckgo" # "duckduckgo" or "google"
|
||||||
|
google_api_key: Optional[str] = None
|
||||||
|
google_search_engine_id: Optional[str] = None
|
||||||
|
search_num_results: int = 5
|
||||||
|
|
||||||
|
# General settings
|
||||||
base_url: str = DEFAULT_BASE_URL
|
base_url: str = DEFAULT_BASE_URL
|
||||||
default_model: Optional[str] = None
|
default_model: Optional[str] = None
|
||||||
default_system_prompt: Optional[str] = None
|
default_system_prompt: Optional[str] = None
|
||||||
@@ -134,8 +161,43 @@ class Settings:
|
|||||||
# Get system prompt from DB: None means not set (use default), "" means explicitly blank
|
# Get system prompt from DB: None means not set (use default), "" means explicitly blank
|
||||||
system_prompt_value = db.get_config("default_system_prompt")
|
system_prompt_value = db.get_config("default_system_prompt")
|
||||||
|
|
||||||
|
# Migration: copy legacy api_key to openrouter_api_key if not already set
|
||||||
|
legacy_api_key = db.get_config("api_key")
|
||||||
|
openrouter_key = db.get_config("openrouter_api_key")
|
||||||
|
|
||||||
|
if legacy_api_key and not openrouter_key:
|
||||||
|
db.set_config("openrouter_api_key", legacy_api_key)
|
||||||
|
openrouter_key = legacy_api_key
|
||||||
|
# Note: We keep the legacy api_key in DB for backward compatibility
|
||||||
|
|
||||||
|
# Load provider-model mapping
|
||||||
|
provider_models_json = db.get_config("provider_models")
|
||||||
|
provider_models = {}
|
||||||
|
if provider_models_json:
|
||||||
|
try:
|
||||||
|
provider_models = json.loads(provider_models_json)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
provider_models = {}
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
api_key=db.get_config("api_key"),
|
# Legacy field
|
||||||
|
api_key=legacy_api_key,
|
||||||
|
|
||||||
|
# Provider configuration
|
||||||
|
default_provider=db.get_config("default_provider") or DEFAULT_PROVIDER,
|
||||||
|
openrouter_api_key=openrouter_key,
|
||||||
|
anthropic_api_key=db.get_config("anthropic_api_key"),
|
||||||
|
openai_api_key=db.get_config("openai_api_key"),
|
||||||
|
ollama_base_url=db.get_config("ollama_base_url") or OLLAMA_DEFAULT_URL,
|
||||||
|
provider_models=provider_models,
|
||||||
|
|
||||||
|
# Web search configuration
|
||||||
|
search_provider=db.get_config("search_provider") or "duckduckgo",
|
||||||
|
google_api_key=db.get_config("google_api_key"),
|
||||||
|
google_search_engine_id=db.get_config("google_search_engine_id"),
|
||||||
|
search_num_results=parse_int(db.get_config("search_num_results"), 5),
|
||||||
|
|
||||||
|
# General settings
|
||||||
base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
|
base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
|
||||||
default_model=db.get_config("default_model"),
|
default_model=db.get_config("default_model"),
|
||||||
default_system_prompt=system_prompt_value,
|
default_system_prompt=system_prompt_value,
|
||||||
@@ -331,6 +393,155 @@ class Settings:
|
|||||||
self.log_max_size_mb = min(size_mb, 100)
|
self.log_max_size_mb = min(size_mb, 100)
|
||||||
get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
|
get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
|
||||||
|
|
||||||
|
def set_provider_api_key(self, provider: str, api_key: str) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist an API key for a specific provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name ("openrouter", "anthropic", "openai")
|
||||||
|
api_key: The API key to set
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is invalid
|
||||||
|
"""
|
||||||
|
provider = provider.lower()
|
||||||
|
api_key = api_key.strip()
|
||||||
|
|
||||||
|
if provider == "openrouter":
|
||||||
|
self.openrouter_api_key = api_key
|
||||||
|
get_database().set_config("openrouter_api_key", api_key)
|
||||||
|
elif provider == "anthropic":
|
||||||
|
self.anthropic_api_key = api_key
|
||||||
|
get_database().set_config("anthropic_api_key", api_key)
|
||||||
|
elif provider == "openai":
|
||||||
|
self.openai_api_key = api_key
|
||||||
|
get_database().set_config("openai_api_key", api_key)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid provider: {provider}")
|
||||||
|
|
||||||
|
def get_provider_api_key(self, provider: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the API key for a specific provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name ("openrouter", "anthropic", "openai", "ollama")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API key or None if not set
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is invalid
|
||||||
|
"""
|
||||||
|
provider = provider.lower()
|
||||||
|
|
||||||
|
if provider == "openrouter":
|
||||||
|
return self.openrouter_api_key
|
||||||
|
elif provider == "anthropic":
|
||||||
|
return self.anthropic_api_key
|
||||||
|
elif provider == "openai":
|
||||||
|
return self.openai_api_key
|
||||||
|
elif provider == "ollama":
|
||||||
|
return "" # Ollama doesn't require an API key
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid provider: {provider}")
|
||||||
|
|
||||||
|
def set_default_provider(self, provider: str) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the default provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is invalid
|
||||||
|
"""
|
||||||
|
from oai.constants import VALID_PROVIDERS
|
||||||
|
|
||||||
|
provider = provider.lower()
|
||||||
|
if provider not in VALID_PROVIDERS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid provider: {provider}. "
|
||||||
|
f"Valid providers: {', '.join(VALID_PROVIDERS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.default_provider = provider
|
||||||
|
get_database().set_config("default_provider", provider)
|
||||||
|
|
||||||
|
def set_ollama_base_url(self, url: str) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the Ollama base URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Ollama server URL
|
||||||
|
"""
|
||||||
|
self.ollama_base_url = url.strip()
|
||||||
|
get_database().set_config("ollama_base_url", self.ollama_base_url)
|
||||||
|
|
||||||
|
def set_search_provider(self, provider: str) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the web search provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Search provider ("anthropic_native", "duckduckgo", "google")
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is invalid
|
||||||
|
"""
|
||||||
|
valid_providers = ["anthropic_native", "duckduckgo", "google"]
|
||||||
|
provider = provider.lower()
|
||||||
|
if provider not in valid_providers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid search provider: {provider}. "
|
||||||
|
f"Valid providers: {', '.join(valid_providers)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.search_provider = provider
|
||||||
|
get_database().set_config("search_provider", provider)
|
||||||
|
|
||||||
|
def set_google_api_key(self, api_key: str) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the Google API key for Google Custom Search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: The Google API key
|
||||||
|
"""
|
||||||
|
self.google_api_key = api_key.strip()
|
||||||
|
get_database().set_config("google_api_key", self.google_api_key)
|
||||||
|
|
||||||
|
def set_google_search_engine_id(self, engine_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the Google Custom Search Engine ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine_id: The Google Search Engine ID
|
||||||
|
"""
|
||||||
|
self.google_search_engine_id = engine_id.strip()
|
||||||
|
get_database().set_config("google_search_engine_id", self.google_search_engine_id)
|
||||||
|
|
||||||
|
def get_provider_model(self, provider: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the last used model for a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model ID or None if not set
|
||||||
|
"""
|
||||||
|
return self.provider_models.get(provider)
|
||||||
|
|
||||||
|
def set_provider_model(self, provider: str, model_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the last used model for a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name
|
||||||
|
model_id: Model ID to remember
|
||||||
|
"""
|
||||||
|
self.provider_models[provider] = model_id
|
||||||
|
# Save to database as JSON
|
||||||
|
get_database().set_config("provider_models", json.dumps(self.provider_models))
|
||||||
|
|
||||||
|
|
||||||
# Global settings instance
|
# Global settings instance
|
||||||
_settings: Optional[Settings] = None
|
_settings: Optional[Settings] = None
|
||||||
|
|||||||
@@ -10,14 +10,17 @@ from pathlib import Path
|
|||||||
from typing import Set, Dict, Any
|
from typing import Set, Dict, Any
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import version from single source of truth
|
||||||
|
from oai import __version__
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# APPLICATION METADATA
|
# APPLICATION METADATA
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
APP_NAME = "oAI"
|
APP_NAME = "oAI"
|
||||||
APP_VERSION = "2.1.0"
|
APP_VERSION = __version__ # Single source of truth in oai/__init__.py
|
||||||
APP_URL = "https://iurl.no/oai"
|
APP_URL = "https://iurl.no/oai"
|
||||||
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"
|
APP_DESCRIPTION = "Open AI Chat Client with Multi-Provider Support"
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# FILE PATHS
|
# FILE PATHS
|
||||||
@@ -39,6 +42,26 @@ DEFAULT_STREAM_ENABLED = True
|
|||||||
DEFAULT_MAX_TOKENS = 100_000
|
DEFAULT_MAX_TOKENS = 100_000
|
||||||
DEFAULT_ONLINE_MODE = False
|
DEFAULT_ONLINE_MODE = False
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# PROVIDER CONFIGURATION
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Provider names
|
||||||
|
PROVIDER_OPENROUTER = "openrouter"
|
||||||
|
PROVIDER_ANTHROPIC = "anthropic"
|
||||||
|
PROVIDER_OPENAI = "openai"
|
||||||
|
PROVIDER_OLLAMA = "ollama"
|
||||||
|
|
||||||
|
VALID_PROVIDERS = [PROVIDER_OPENROUTER, PROVIDER_ANTHROPIC, PROVIDER_OPENAI, PROVIDER_OLLAMA]
|
||||||
|
|
||||||
|
# Provider base URLs
|
||||||
|
ANTHROPIC_BASE_URL = "https://api.anthropic.com"
|
||||||
|
OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||||
|
OLLAMA_DEFAULT_URL = "http://localhost:11434"
|
||||||
|
|
||||||
|
# Default provider
|
||||||
|
DEFAULT_PROVIDER = PROVIDER_OPENROUTER
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# DEFAULT SYSTEM PROMPT
|
# DEFAULT SYSTEM PROMPT
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
|
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING, OLLAMA_DEFAULT_URL
|
||||||
from oai.providers.base import (
|
from oai.providers.base import (
|
||||||
AIProvider,
|
AIProvider,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
@@ -19,7 +19,6 @@ from oai.providers.base import (
|
|||||||
ToolCall,
|
ToolCall,
|
||||||
UsageStats,
|
UsageStats,
|
||||||
)
|
)
|
||||||
from oai.providers.openrouter import OpenRouterProvider
|
|
||||||
from oai.utils.logging import get_logger
|
from oai.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
@@ -32,37 +31,66 @@ class AIClient:
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
provider: The underlying AI provider
|
provider: The underlying AI provider
|
||||||
|
provider_name: Name of the current provider
|
||||||
default_model: Default model ID to use
|
default_model: Default model ID to use
|
||||||
http_headers: Custom HTTP headers for requests
|
http_headers: Custom HTTP headers for requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
provider_name: str = "openrouter",
|
||||||
base_url: Optional[str] = None,
|
provider_api_keys: Optional[Dict[str, str]] = None,
|
||||||
provider_class: type = OpenRouterProvider,
|
ollama_base_url: str = OLLAMA_DEFAULT_URL,
|
||||||
app_name: str = APP_NAME,
|
app_name: str = APP_NAME,
|
||||||
app_url: str = APP_URL,
|
app_url: str = APP_URL,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the AI client.
|
Initialize the AI client with specified provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: API key for authentication
|
provider_name: Provider to use ("openrouter", "anthropic", "openai", "ollama")
|
||||||
base_url: Optional custom base URL
|
provider_api_keys: Dict mapping provider names to API keys
|
||||||
provider_class: Provider class to use (default: OpenRouterProvider)
|
ollama_base_url: Base URL for Ollama server
|
||||||
app_name: Application name for headers
|
app_name: Application name for headers
|
||||||
app_url: Application URL for headers
|
app_url: Application URL for headers
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is invalid or not configured
|
||||||
"""
|
"""
|
||||||
self.provider: AIProvider = provider_class(
|
from oai.providers.registry import get_provider_class
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
self.provider_name = provider_name
|
||||||
app_name=app_name,
|
self.provider_api_keys = provider_api_keys or {}
|
||||||
app_url=app_url,
|
self.ollama_base_url = ollama_base_url
|
||||||
)
|
self.app_name = app_name
|
||||||
|
self.app_url = app_url
|
||||||
|
|
||||||
|
# Get provider class
|
||||||
|
provider_class = get_provider_class(provider_name)
|
||||||
|
if not provider_class:
|
||||||
|
raise ValueError(f"Unknown provider: {provider_name}")
|
||||||
|
|
||||||
|
# Get API key for this provider
|
||||||
|
api_key = self.provider_api_keys.get(provider_name, "")
|
||||||
|
|
||||||
|
# Initialize provider with appropriate parameters
|
||||||
|
if provider_name == "ollama":
|
||||||
|
self.provider: AIProvider = provider_class(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=ollama_base_url,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.provider: AIProvider = provider_class(
|
||||||
|
api_key=api_key,
|
||||||
|
app_name=app_name,
|
||||||
|
app_url=app_url,
|
||||||
|
)
|
||||||
|
|
||||||
self.default_model: Optional[str] = None
|
self.default_model: Optional[str] = None
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
|
||||||
|
self.logger.info(f"Initialized {provider_name} provider")
|
||||||
|
|
||||||
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
||||||
"""
|
"""
|
||||||
Get available models.
|
Get available models.
|
||||||
@@ -420,3 +448,49 @@ class AIClient:
|
|||||||
"""Clear the provider's model cache."""
|
"""Clear the provider's model cache."""
|
||||||
if hasattr(self.provider, "clear_cache"):
|
if hasattr(self.provider, "clear_cache"):
|
||||||
self.provider.clear_cache()
|
self.provider.clear_cache()
|
||||||
|
|
||||||
|
def switch_provider(self, provider_name: str, ollama_base_url: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Switch to a different provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_name: Provider to switch to
|
||||||
|
ollama_base_url: Optional Ollama base URL (if switching to Ollama)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is invalid or not configured
|
||||||
|
"""
|
||||||
|
from oai.providers.registry import get_provider_class
|
||||||
|
|
||||||
|
# Get provider class
|
||||||
|
provider_class = get_provider_class(provider_name)
|
||||||
|
if not provider_class:
|
||||||
|
raise ValueError(f"Unknown provider: {provider_name}")
|
||||||
|
|
||||||
|
# Get API key
|
||||||
|
api_key = self.provider_api_keys.get(provider_name, "")
|
||||||
|
|
||||||
|
# Check API key requirement
|
||||||
|
if provider_name != "ollama" and not api_key:
|
||||||
|
raise ValueError(f"No API key configured for {provider_name}")
|
||||||
|
|
||||||
|
# Initialize new provider
|
||||||
|
if provider_name == "ollama":
|
||||||
|
base_url = ollama_base_url or self.ollama_base_url
|
||||||
|
self.provider = provider_class(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
self.ollama_base_url = base_url
|
||||||
|
else:
|
||||||
|
self.provider = provider_class(
|
||||||
|
api_key=api_key,
|
||||||
|
app_name=self.app_name,
|
||||||
|
app_url=self.app_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.provider_name = provider_name
|
||||||
|
self.logger.info(f"Switched to {provider_name} provider")
|
||||||
|
|
||||||
|
# Clear model cache when switching providers
|
||||||
|
self.default_model = None
|
||||||
|
|||||||
@@ -9,10 +9,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
|
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from rich.live import Live
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
|
|
||||||
from oai.commands.registry import CommandContext, CommandResult, registry
|
from oai.commands.registry import CommandContext, CommandResult, registry
|
||||||
from oai.config.database import Database
|
from oai.config.database import Database
|
||||||
@@ -25,17 +22,10 @@ from oai.constants import (
|
|||||||
from oai.core.client import AIClient
|
from oai.core.client import AIClient
|
||||||
from oai.mcp.manager import MCPManager
|
from oai.mcp.manager import MCPManager
|
||||||
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
|
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
|
||||||
from oai.ui.console import (
|
|
||||||
console,
|
|
||||||
display_markdown,
|
|
||||||
display_panel,
|
|
||||||
print_error,
|
|
||||||
print_info,
|
|
||||||
print_success,
|
|
||||||
print_warning,
|
|
||||||
)
|
|
||||||
from oai.ui.prompts import prompt_copy_response
|
|
||||||
from oai.utils.logging import get_logger
|
from oai.utils.logging import get_logger
|
||||||
|
from oai.utils.web_search import perform_web_search, format_search_results
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -177,6 +167,7 @@ class ChatSession:
|
|||||||
total_cost=self.stats.total_cost,
|
total_cost=self.stats.total_cost,
|
||||||
message_count=self.stats.message_count,
|
message_count=self.stats.message_count,
|
||||||
current_index=self.current_index,
|
current_index=self.current_index,
|
||||||
|
session=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_model(self, model: Dict[str, Any]) -> None:
|
def set_model(self, model: Dict[str, Any]) -> None:
|
||||||
@@ -303,6 +294,44 @@ class ChatSession:
|
|||||||
|
|
||||||
# Build request parameters
|
# Build request parameters
|
||||||
model_id = self.selected_model["id"]
|
model_id = self.selected_model["id"]
|
||||||
|
|
||||||
|
# Handle online mode
|
||||||
|
enable_web_search = False
|
||||||
|
web_search_config = {}
|
||||||
|
|
||||||
|
if self.online_enabled:
|
||||||
|
# OpenRouter handles online mode natively with :online suffix
|
||||||
|
if self.client.provider_name == "openrouter":
|
||||||
|
if hasattr(self.client.provider, "get_effective_model_id"):
|
||||||
|
model_id = self.client.provider.get_effective_model_id(model_id, True)
|
||||||
|
# Anthropic has native web search when search provider is set to anthropic_native
|
||||||
|
elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
|
||||||
|
enable_web_search = True
|
||||||
|
web_search_config = {
|
||||||
|
"max_uses": self.settings.search_num_results or 5
|
||||||
|
}
|
||||||
|
logger.info("Using Anthropic native web search")
|
||||||
|
else:
|
||||||
|
# For other providers, perform web search and inject results
|
||||||
|
logger.info(f"Performing web search for: {user_input}")
|
||||||
|
search_results = perform_web_search(
|
||||||
|
user_input,
|
||||||
|
num_results=self.settings.search_num_results,
|
||||||
|
provider=self.settings.search_provider,
|
||||||
|
google_api_key=self.settings.google_api_key,
|
||||||
|
google_search_engine_id=self.settings.google_search_engine_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if search_results:
|
||||||
|
# Inject search results into messages
|
||||||
|
formatted_results = format_search_results(search_results)
|
||||||
|
search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
|
||||||
|
|
||||||
|
# Add search results to the last user message
|
||||||
|
if messages and messages[-1]["role"] == "user":
|
||||||
|
messages[-1]["content"] += search_context
|
||||||
|
|
||||||
|
logger.info(f"Injected {len(search_results)} search results into context")
|
||||||
if self.online_enabled:
|
if self.online_enabled:
|
||||||
if hasattr(self.client.provider, "get_effective_model_id"):
|
if hasattr(self.client.provider, "get_effective_model_id"):
|
||||||
model_id = self.client.provider.get_effective_model_id(model_id, True)
|
model_id = self.client.provider.get_effective_model_id(model_id, True)
|
||||||
@@ -333,6 +362,8 @@ class ChatSession:
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
transforms=transforms,
|
transforms=transforms,
|
||||||
on_chunk=on_stream_chunk,
|
on_chunk=on_stream_chunk,
|
||||||
|
enable_web_search=enable_web_search,
|
||||||
|
web_search_config=web_search_config,
|
||||||
)
|
)
|
||||||
response_time = time.time() - start_time
|
response_time = time.time() - start_time
|
||||||
return full_text, usage, response_time
|
return full_text, usage, response_time
|
||||||
@@ -396,7 +427,7 @@ class ChatSession:
|
|||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
console.print(f"\n[dim yellow]🔧 AI requesting {len(tool_calls)} tool call(s)...[/]")
|
# Tool calls requested by AI
|
||||||
|
|
||||||
tool_results = []
|
tool_results = []
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
@@ -417,15 +448,17 @@ class ChatSession:
|
|||||||
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
|
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
|
||||||
for k, v in args.items()
|
for k, v in args.items()
|
||||||
)
|
)
|
||||||
console.print(f"[dim cyan] → {tc.function.name}({args_display})[/]")
|
# Executing tool: {tc.function.name}
|
||||||
|
|
||||||
# Execute tool
|
# Execute tool
|
||||||
result = asyncio.run(self.execute_tool(tc.function.name, args))
|
result = asyncio.run(self.execute_tool(tc.function.name, args))
|
||||||
|
|
||||||
if "error" in result:
|
if "error" in result:
|
||||||
console.print(f"[dim red] ✗ Error: {result['error']}[/]")
|
# Tool execution error logged
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
self._display_tool_success(tc.function.name, result)
|
# Tool execution successful
|
||||||
|
pass
|
||||||
|
|
||||||
tool_results.append({
|
tool_results.append({
|
||||||
"tool_call_id": tc.id,
|
"tool_call_id": tc.id,
|
||||||
@@ -452,38 +485,12 @@ class ChatSession:
|
|||||||
})
|
})
|
||||||
api_messages.extend(tool_results)
|
api_messages.extend(tool_results)
|
||||||
|
|
||||||
console.print("\n[dim cyan]💭 Processing tool results...[/]")
|
# Processing tool results
|
||||||
loop_count += 1
|
loop_count += 1
|
||||||
|
|
||||||
self.logger.warning(f"Reached max tool loops ({max_loops})")
|
self.logger.warning(f"Reached max tool loops ({max_loops})")
|
||||||
console.print(f"[bold yellow]⚠️ Reached maximum tool calls ({max_loops})[/]")
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _display_tool_success(self, tool_name: str, result: Dict[str, Any]) -> None:
|
|
||||||
"""Display a success message for a tool call."""
|
|
||||||
if tool_name == "search_files":
|
|
||||||
count = result.get("count", 0)
|
|
||||||
console.print(f"[dim green] ✓ Found {count} file(s)[/]")
|
|
||||||
elif tool_name == "read_file":
|
|
||||||
size = result.get("size", 0)
|
|
||||||
truncated = " (truncated)" if result.get("truncated") else ""
|
|
||||||
console.print(f"[dim green] ✓ Read {size} bytes{truncated}[/]")
|
|
||||||
elif tool_name == "list_directory":
|
|
||||||
count = result.get("count", 0)
|
|
||||||
console.print(f"[dim green] ✓ Listed {count} item(s)[/]")
|
|
||||||
elif tool_name == "inspect_database":
|
|
||||||
if "table" in result:
|
|
||||||
console.print(f"[dim green] ✓ Inspected table: {result['table']}[/]")
|
|
||||||
else:
|
|
||||||
console.print(f"[dim green] ✓ Inspected database ({result.get('table_count', 0)} tables)[/]")
|
|
||||||
elif tool_name == "search_database":
|
|
||||||
count = result.get("count", 0)
|
|
||||||
console.print(f"[dim green] ✓ Found {count} match(es)[/]")
|
|
||||||
elif tool_name == "query_database":
|
|
||||||
count = result.get("count", 0)
|
|
||||||
console.print(f"[dim green] ✓ Query returned {count} row(s)[/]")
|
|
||||||
else:
|
|
||||||
console.print("[dim green] ✓ Success[/]")
|
|
||||||
|
|
||||||
def _stream_response(
|
def _stream_response(
|
||||||
self,
|
self,
|
||||||
@@ -492,6 +499,8 @@ class ChatSession:
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
transforms: Optional[List[str]] = None,
|
transforms: Optional[List[str]] = None,
|
||||||
on_chunk: Optional[Callable[[str], None]] = None,
|
on_chunk: Optional[Callable[[str], None]] = None,
|
||||||
|
enable_web_search: bool = False,
|
||||||
|
web_search_config: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[str, Optional[UsageStats]]:
|
) -> Tuple[str, Optional[UsageStats]]:
|
||||||
"""
|
"""
|
||||||
Stream a response with live display.
|
Stream a response with live display.
|
||||||
@@ -502,6 +511,8 @@ class ChatSession:
|
|||||||
max_tokens: Max tokens
|
max_tokens: Max tokens
|
||||||
transforms: Transforms
|
transforms: Transforms
|
||||||
on_chunk: Callback for chunks
|
on_chunk: Callback for chunks
|
||||||
|
enable_web_search: Whether to enable Anthropic native web search
|
||||||
|
web_search_config: Web search configuration
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (full_text, usage)
|
Tuple of (full_text, usage)
|
||||||
@@ -512,6 +523,8 @@ class ChatSession:
|
|||||||
stream=True,
|
stream=True,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
transforms=transforms,
|
transforms=transforms,
|
||||||
|
enable_web_search=enable_web_search,
|
||||||
|
web_search_config=web_search_config or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(response, ChatResponse):
|
if isinstance(response, ChatResponse):
|
||||||
@@ -521,27 +534,339 @@ class ChatSession:
|
|||||||
usage: Optional[UsageStats] = None
|
usage: Optional[UsageStats] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Live("", console=console, refresh_per_second=10) as live:
|
for chunk in response:
|
||||||
for chunk in response:
|
if chunk.error:
|
||||||
if chunk.error:
|
self.logger.error(f"Stream error: {chunk.error}")
|
||||||
console.print(f"\n[bold red]Stream error: {chunk.error}[/]")
|
break
|
||||||
break
|
|
||||||
|
|
||||||
if chunk.delta_content:
|
if chunk.delta_content:
|
||||||
full_text += chunk.delta_content
|
full_text += chunk.delta_content
|
||||||
live.update(Markdown(full_text))
|
if on_chunk:
|
||||||
if on_chunk:
|
on_chunk(chunk.delta_content)
|
||||||
on_chunk(chunk.delta_content)
|
|
||||||
|
|
||||||
if chunk.usage:
|
if chunk.usage:
|
||||||
usage = chunk.usage
|
usage = chunk.usage
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\n[bold yellow]⚠️ Streaming interrupted[/]")
|
self.logger.info("Streaming interrupted")
|
||||||
return "", None
|
return "", None
|
||||||
|
|
||||||
return full_text, usage
|
return full_text, usage
|
||||||
|
|
||||||
|
# ========== ASYNC METHODS FOR TUI ==========
|
||||||
|
|
||||||
|
async def send_message_async(
|
||||||
|
self,
|
||||||
|
user_input: str,
|
||||||
|
stream: bool = True,
|
||||||
|
) -> AsyncIterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Async version of send_message for Textual TUI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_input: User's input text
|
||||||
|
stream: Whether to stream the response
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects for progressive display
|
||||||
|
"""
|
||||||
|
if not self.selected_model:
|
||||||
|
raise ValueError("No model selected")
|
||||||
|
|
||||||
|
messages = self.build_api_messages(user_input)
|
||||||
|
tools = self.get_mcp_tools()
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
# Disable streaming when tools are present
|
||||||
|
stream = False
|
||||||
|
|
||||||
|
# Handle online mode
|
||||||
|
model_id = self.selected_model["id"]
|
||||||
|
enable_web_search = False
|
||||||
|
web_search_config = {}
|
||||||
|
|
||||||
|
if self.online_enabled:
|
||||||
|
# OpenRouter handles online mode natively with :online suffix
|
||||||
|
if self.client.provider_name == "openrouter":
|
||||||
|
if hasattr(self.client.provider, "get_effective_model_id"):
|
||||||
|
model_id = self.client.provider.get_effective_model_id(model_id, True)
|
||||||
|
# Anthropic has native web search when search provider is set to anthropic_native
|
||||||
|
elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
|
||||||
|
enable_web_search = True
|
||||||
|
web_search_config = {
|
||||||
|
"max_uses": self.settings.search_num_results or 5
|
||||||
|
}
|
||||||
|
logger.info("Using Anthropic native web search")
|
||||||
|
else:
|
||||||
|
# For other providers, perform web search and inject results
|
||||||
|
logger.info(f"Performing web search for: {user_input}")
|
||||||
|
search_results = await asyncio.to_thread(
|
||||||
|
perform_web_search,
|
||||||
|
user_input,
|
||||||
|
num_results=self.settings.search_num_results,
|
||||||
|
provider=self.settings.search_provider,
|
||||||
|
google_api_key=self.settings.google_api_key,
|
||||||
|
google_search_engine_id=self.settings.google_search_engine_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if search_results:
|
||||||
|
# Inject search results into messages
|
||||||
|
formatted_results = format_search_results(search_results)
|
||||||
|
search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
|
||||||
|
|
||||||
|
# Add search results to the last user message
|
||||||
|
if messages and messages[-1]["role"] == "user":
|
||||||
|
messages[-1]["content"] += search_context
|
||||||
|
|
||||||
|
logger.info(f"Injected {len(search_results)} search results into context")
|
||||||
|
|
||||||
|
transforms = ["middle-out"] if self.middle_out_enabled else None
|
||||||
|
max_tokens = None
|
||||||
|
if self.session_max_token > 0:
|
||||||
|
max_tokens = self.session_max_token
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
# Use async tool handling flow
|
||||||
|
async for chunk in self._send_with_tools_async(
|
||||||
|
messages=messages,
|
||||||
|
model_id=model_id,
|
||||||
|
tools=tools,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
transforms=transforms,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
elif stream:
|
||||||
|
# Use async streaming flow
|
||||||
|
async for chunk in self._stream_response_async(
|
||||||
|
messages=messages,
|
||||||
|
model_id=model_id,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
transforms=transforms,
|
||||||
|
enable_web_search=enable_web_search,
|
||||||
|
web_search_config=web_search_config,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
# Non-streaming request
|
||||||
|
response = self.client.chat(
|
||||||
|
messages=messages,
|
||||||
|
model=model_id,
|
||||||
|
stream=False,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
transforms=transforms,
|
||||||
|
)
|
||||||
|
if isinstance(response, ChatResponse):
|
||||||
|
# Yield single chunk with complete response
|
||||||
|
chunk = StreamChunk(
|
||||||
|
id="",
|
||||||
|
delta_content=response.content,
|
||||||
|
usage=response.usage,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def _send_with_tools_async(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
model_id: str,
|
||||||
|
tools: List[Dict[str, Any]],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
transforms: Optional[List[str]] = None,
|
||||||
|
) -> AsyncIterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Async version of _send_with_tools for TUI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: API messages
|
||||||
|
model_id: Model ID
|
||||||
|
tools: Tool definitions
|
||||||
|
max_tokens: Max tokens
|
||||||
|
transforms: Transforms list
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects including tool call notifications
|
||||||
|
"""
|
||||||
|
max_loops = 5
|
||||||
|
loop_count = 0
|
||||||
|
api_messages = list(messages)
|
||||||
|
|
||||||
|
while loop_count < max_loops:
|
||||||
|
response = self.client.chat(
|
||||||
|
messages=api_messages,
|
||||||
|
model=model_id,
|
||||||
|
stream=False,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
transforms=transforms,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(response, ChatResponse):
|
||||||
|
raise ValueError("Expected ChatResponse")
|
||||||
|
|
||||||
|
tool_calls = response.tool_calls
|
||||||
|
if not tool_calls:
|
||||||
|
# Final response, yield it
|
||||||
|
chunk = StreamChunk(
|
||||||
|
id="",
|
||||||
|
delta_content=response.content,
|
||||||
|
usage=response.usage,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
return
|
||||||
|
|
||||||
|
# Yield notification about tool calls
|
||||||
|
tool_notification = f"\n🔧 AI requesting {len(tool_calls)} tool call(s)...\n"
|
||||||
|
yield StreamChunk(id="", delta_content=tool_notification, usage=None, error=None)
|
||||||
|
|
||||||
|
tool_results = []
|
||||||
|
for tc in tool_calls:
|
||||||
|
try:
|
||||||
|
args = json.loads(tc.function.arguments)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
self.logger.error(f"Failed to parse tool arguments: {e}")
|
||||||
|
tool_results.append({
|
||||||
|
"tool_call_id": tc.id,
|
||||||
|
"role": "tool",
|
||||||
|
"name": tc.function.name,
|
||||||
|
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Yield tool call display
|
||||||
|
args_display = ", ".join(
|
||||||
|
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
|
||||||
|
for k, v in args.items()
|
||||||
|
)
|
||||||
|
tool_display = f" → {tc.function.name}({args_display})\n"
|
||||||
|
yield StreamChunk(id="", delta_content=tool_display, usage=None, error=None)
|
||||||
|
|
||||||
|
# Execute tool (await instead of asyncio.run)
|
||||||
|
result = await self.execute_tool(tc.function.name, args)
|
||||||
|
|
||||||
|
if "error" in result:
|
||||||
|
error_msg = f" ✗ Error: {result['error']}\n"
|
||||||
|
yield StreamChunk(id="", delta_content=error_msg, usage=None, error=None)
|
||||||
|
else:
|
||||||
|
success_msg = self._format_tool_success(tc.function.name, result)
|
||||||
|
yield StreamChunk(id="", delta_content=success_msg, usage=None, error=None)
|
||||||
|
|
||||||
|
tool_results.append({
|
||||||
|
"tool_call_id": tc.id,
|
||||||
|
"role": "tool",
|
||||||
|
"name": tc.function.name,
|
||||||
|
"content": json.dumps(result),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add assistant message with tool calls
|
||||||
|
api_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": response.content,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": tc.type,
|
||||||
|
"function": {
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": tc.function.arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for tc in tool_calls
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add tool results
|
||||||
|
api_messages.extend(tool_results)
|
||||||
|
loop_count += 1
|
||||||
|
|
||||||
|
# Max loops reached
|
||||||
|
yield StreamChunk(
|
||||||
|
id="",
|
||||||
|
delta_content="\n⚠️ Maximum tool call loops reached\n",
|
||||||
|
usage=None,
|
||||||
|
error="Max loops reached"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_tool_success(self, tool_name: str, result: Dict[str, Any]) -> str:
|
||||||
|
"""Format a success message for a tool call."""
|
||||||
|
if tool_name == "search_files":
|
||||||
|
count = result.get("count", 0)
|
||||||
|
return f" ✓ Found {count} file(s)\n"
|
||||||
|
elif tool_name == "read_file":
|
||||||
|
size = result.get("size", 0)
|
||||||
|
truncated = " (truncated)" if result.get("truncated") else ""
|
||||||
|
return f" ✓ Read {size} bytes{truncated}\n"
|
||||||
|
elif tool_name == "list_directory":
|
||||||
|
count = result.get("count", 0)
|
||||||
|
return f" ✓ Listed {count} item(s)\n"
|
||||||
|
elif tool_name == "inspect_database":
|
||||||
|
if "table" in result:
|
||||||
|
return f" ✓ Inspected table: {result['table']}\n"
|
||||||
|
else:
|
||||||
|
return f" ✓ Inspected database ({result.get('table_count', 0)} tables)\n"
|
||||||
|
elif tool_name == "search_database":
|
||||||
|
count = result.get("count", 0)
|
||||||
|
return f" ✓ Found {count} match(es)\n"
|
||||||
|
elif tool_name == "query_database":
|
||||||
|
count = result.get("count", 0)
|
||||||
|
return f" ✓ Query returned {count} row(s)\n"
|
||||||
|
else:
|
||||||
|
return " ✓ Success\n"
|
||||||
|
|
||||||
|
async def _stream_response_async(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
model_id: str,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
transforms: Optional[List[str]] = None,
|
||||||
|
enable_web_search: bool = False,
|
||||||
|
web_search_config: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> AsyncIterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Async version of _stream_response for TUI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: API messages
|
||||||
|
model_id: Model ID
|
||||||
|
max_tokens: Max tokens
|
||||||
|
transforms: Transforms
|
||||||
|
enable_web_search: Whether to enable Anthropic native web search
|
||||||
|
web_search_config: Web search configuration
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects
|
||||||
|
"""
|
||||||
|
response = self.client.chat(
|
||||||
|
messages=messages,
|
||||||
|
model=model_id,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
transforms=transforms,
|
||||||
|
enable_web_search=enable_web_search,
|
||||||
|
web_search_config=web_search_config or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(response, ChatResponse):
|
||||||
|
# Non-streaming response
|
||||||
|
chunk = StreamChunk(
|
||||||
|
id="",
|
||||||
|
delta_content=response.content,
|
||||||
|
usage=response.usage,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stream the response
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.error:
|
||||||
|
yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# ========== END ASYNC METHODS ==========
|
||||||
|
|
||||||
def add_to_history(
|
def add_to_history(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|||||||
@@ -16,6 +16,16 @@ from oai.providers.base import (
|
|||||||
ProviderCapabilities,
|
ProviderCapabilities,
|
||||||
)
|
)
|
||||||
from oai.providers.openrouter import OpenRouterProvider
|
from oai.providers.openrouter import OpenRouterProvider
|
||||||
|
from oai.providers.anthropic import AnthropicProvider
|
||||||
|
from oai.providers.openai import OpenAIProvider
|
||||||
|
from oai.providers.ollama import OllamaProvider
|
||||||
|
from oai.providers.registry import register_provider
|
||||||
|
|
||||||
|
# Register all providers
|
||||||
|
register_provider("openrouter", OpenRouterProvider)
|
||||||
|
register_provider("anthropic", AnthropicProvider)
|
||||||
|
register_provider("openai", OpenAIProvider)
|
||||||
|
register_provider("ollama", OllamaProvider)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base classes and types
|
# Base classes and types
|
||||||
@@ -29,4 +39,7 @@ __all__ = [
|
|||||||
"ProviderCapabilities",
|
"ProviderCapabilities",
|
||||||
# Provider implementations
|
# Provider implementations
|
||||||
"OpenRouterProvider",
|
"OpenRouterProvider",
|
||||||
|
"AnthropicProvider",
|
||||||
|
"OpenAIProvider",
|
||||||
|
"OllamaProvider",
|
||||||
]
|
]
|
||||||
|
|||||||
673
oai/providers/anthropic.py
Normal file
673
oai/providers/anthropic.py
Normal file
@@ -0,0 +1,673 @@
|
|||||||
|
"""
|
||||||
|
Anthropic provider for Claude models.
|
||||||
|
|
||||||
|
This provider connects to Anthropic's API for accessing Claude models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
|
import anthropic
|
||||||
|
from anthropic.types import Message, MessageStreamEvent
|
||||||
|
|
||||||
|
from oai.constants import ANTHROPIC_BASE_URL
|
||||||
|
from oai.providers.base import (
|
||||||
|
AIProvider,
|
||||||
|
ChatMessage,
|
||||||
|
ChatResponse,
|
||||||
|
ChatResponseChoice,
|
||||||
|
ModelInfo,
|
||||||
|
ProviderCapabilities,
|
||||||
|
StreamChunk,
|
||||||
|
ToolCall,
|
||||||
|
ToolFunction,
|
||||||
|
UsageStats,
|
||||||
|
)
|
||||||
|
from oai.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# Model name aliases
|
||||||
|
MODEL_ALIASES = {
|
||||||
|
"claude-sonnet": "claude-sonnet-4-5-20250929",
|
||||||
|
"claude-haiku": "claude-haiku-4-5-20251001",
|
||||||
|
"claude-opus": "claude-opus-4-5-20251101",
|
||||||
|
# Legacy aliases
|
||||||
|
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||||
|
"claude-3-7-sonnet": "claude-3-7-sonnet-20250219",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(AIProvider):
|
||||||
|
"""
|
||||||
|
Anthropic API provider.
|
||||||
|
|
||||||
|
Provides access to Claude 3.5 Sonnet, Claude 3 Opus, and other Anthropic models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
app_name: str = "oAI",
|
||||||
|
app_url: str = "",
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize Anthropic provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Anthropic API key
|
||||||
|
base_url: Optional custom base URL
|
||||||
|
app_name: Application name (for headers)
|
||||||
|
app_url: Application URL (for headers)
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
"""
|
||||||
|
super().__init__(api_key, base_url or ANTHROPIC_BASE_URL)
|
||||||
|
self.client = anthropic.Anthropic(api_key=api_key)
|
||||||
|
self.async_client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||||
|
self._models_cache: Optional[List[ModelInfo]] = None
|
||||||
|
|
||||||
|
def _create_web_search_tool(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create Anthropic native web search tool definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Optional configuration for web search (max_uses, allowed_domains, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool definition dict
|
||||||
|
"""
|
||||||
|
tool: Dict[str, Any] = {
|
||||||
|
"type": "web_search_20250305",
|
||||||
|
"name": "web_search",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters if provided
|
||||||
|
if "max_uses" in config:
|
||||||
|
tool["max_uses"] = config["max_uses"]
|
||||||
|
else:
|
||||||
|
tool["max_uses"] = 5 # Default
|
||||||
|
|
||||||
|
if "allowed_domains" in config:
|
||||||
|
tool["allowed_domains"] = config["allowed_domains"]
|
||||||
|
|
||||||
|
if "blocked_domains" in config:
|
||||||
|
tool["blocked_domains"] = config["blocked_domains"]
|
||||||
|
|
||||||
|
if "user_location" in config:
|
||||||
|
tool["user_location"] = config["user_location"]
|
||||||
|
|
||||||
|
return tool
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Get provider name."""
|
||||||
|
return "Anthropic"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capabilities(self) -> ProviderCapabilities:
|
||||||
|
"""Get provider capabilities."""
|
||||||
|
return ProviderCapabilities(
|
||||||
|
streaming=True,
|
||||||
|
tools=True,
|
||||||
|
images=True,
|
||||||
|
online=True, # Web search via DuckDuckGo/Google
|
||||||
|
max_context=200000,
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
||||||
|
"""
|
||||||
|
List available Anthropic models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_text_only: Whether to filter for text models only
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ModelInfo objects
|
||||||
|
"""
|
||||||
|
if self._models_cache:
|
||||||
|
return self._models_cache
|
||||||
|
|
||||||
|
# Anthropic doesn't have a models list API, so we hardcode the available models
|
||||||
|
models = [
|
||||||
|
# Current Claude 4.5 models
|
||||||
|
ModelInfo(
|
||||||
|
id="claude-sonnet-4-5-20250929",
|
||||||
|
name="Claude Sonnet 4.5",
|
||||||
|
description="Smart model for complex agents and coding (recommended)",
|
||||||
|
context_length=200000,
|
||||||
|
pricing={"input": 3.0, "output": 15.0},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
input_modalities=["text", "image"],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="claude-haiku-4-5-20251001",
|
||||||
|
name="Claude Haiku 4.5",
|
||||||
|
description="Fastest model with near-frontier intelligence",
|
||||||
|
context_length=200000,
|
||||||
|
pricing={"input": 1.0, "output": 5.0},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
input_modalities=["text", "image"],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="claude-opus-4-5-20251101",
|
||||||
|
name="Claude Opus 4.5",
|
||||||
|
description="Premium model with maximum intelligence",
|
||||||
|
context_length=200000,
|
||||||
|
pricing={"input": 5.0, "output": 25.0},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
input_modalities=["text", "image"],
|
||||||
|
),
|
||||||
|
# Legacy models (still available)
|
||||||
|
ModelInfo(
|
||||||
|
id="claude-3-7-sonnet-20250219",
|
||||||
|
name="Claude Sonnet 3.7",
|
||||||
|
description="Legacy model - recommend migrating to 4.5",
|
||||||
|
context_length=200000,
|
||||||
|
pricing={"input": 3.0, "output": 15.0},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
input_modalities=["text", "image"],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="claude-3-haiku-20240307",
|
||||||
|
name="Claude 3 Haiku",
|
||||||
|
description="Legacy fast model - recommend migrating to 4.5",
|
||||||
|
context_length=200000,
|
||||||
|
pricing={"input": 0.25, "output": 1.25},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
input_modalities=["text", "image"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self._models_cache = models
|
||||||
|
logger.info(f"Loaded {len(models)} Anthropic models")
|
||||||
|
return models
|
||||||
|
|
||||||
|
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
||||||
|
"""
|
||||||
|
Get information about a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelInfo or None
|
||||||
|
"""
|
||||||
|
# Resolve alias
|
||||||
|
resolved_id = MODEL_ALIASES.get(model_id, model_id)
|
||||||
|
|
||||||
|
models = self.list_models()
|
||||||
|
for model in models:
|
||||||
|
if model.id == resolved_id or model.id == model_id:
|
||||||
|
return model
|
||||||
|
return None
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
||||||
|
"""
|
||||||
|
Send chat completion request to Anthropic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model ID
|
||||||
|
messages: Chat messages
|
||||||
|
stream: Whether to stream response
|
||||||
|
max_tokens: Maximum tokens
|
||||||
|
temperature: Sampling temperature
|
||||||
|
tools: Tool definitions
|
||||||
|
tool_choice: Tool selection mode
|
||||||
|
**kwargs: Additional parameters (including enable_web_search)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse or Iterator[StreamChunk]
|
||||||
|
"""
|
||||||
|
# Resolve model alias
|
||||||
|
model_id = MODEL_ALIASES.get(model, model)
|
||||||
|
|
||||||
|
# Extract system message (Anthropic requires it separate from messages)
|
||||||
|
system_prompt, anthropic_messages = self._convert_messages(messages)
|
||||||
|
|
||||||
|
# Build request parameters
|
||||||
|
params: Dict[str, Any] = {
|
||||||
|
"model": model_id,
|
||||||
|
"messages": anthropic_messages,
|
||||||
|
"max_tokens": max_tokens or 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_prompt:
|
||||||
|
params["system"] = system_prompt
|
||||||
|
|
||||||
|
if temperature is not None:
|
||||||
|
params["temperature"] = temperature
|
||||||
|
|
||||||
|
# Prepare tools list
|
||||||
|
tools_list = []
|
||||||
|
|
||||||
|
# Add web search tool if requested via kwargs
|
||||||
|
if kwargs.get("enable_web_search", False):
|
||||||
|
web_search_config = kwargs.get("web_search_config", {})
|
||||||
|
tools_list.append(self._create_web_search_tool(web_search_config))
|
||||||
|
logger.info("Added Anthropic native web search tool")
|
||||||
|
|
||||||
|
# Add user-provided tools
|
||||||
|
if tools:
|
||||||
|
# Convert tools to Anthropic format
|
||||||
|
converted_tools = self._convert_tools(tools)
|
||||||
|
tools_list.extend(converted_tools)
|
||||||
|
|
||||||
|
if tools_list:
|
||||||
|
params["tools"] = tools_list
|
||||||
|
|
||||||
|
if tool_choice and tool_choice != "auto":
|
||||||
|
# Anthropic uses different tool_choice format
|
||||||
|
if tool_choice == "none":
|
||||||
|
pass # Don't include tools
|
||||||
|
elif tool_choice == "required":
|
||||||
|
params["tool_choice"] = {"type": "any"}
|
||||||
|
else:
|
||||||
|
params["tool_choice"] = {"type": "tool", "name": tool_choice}
|
||||||
|
|
||||||
|
logger.debug(f"Anthropic request: model={model_id}, messages={len(anthropic_messages)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat(params)
|
||||||
|
else:
|
||||||
|
return self._sync_chat(params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Anthropic request failed: {e}")
|
||||||
|
return ChatResponse(
|
||||||
|
id="error",
|
||||||
|
choices=[
|
||||||
|
ChatResponseChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_messages(self, messages: List[ChatMessage]) -> tuple[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Convert messages to Anthropic format.
|
||||||
|
|
||||||
|
Anthropic requires system messages to be separate from the conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of ChatMessage objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (system_prompt, anthropic_messages)
|
||||||
|
"""
|
||||||
|
system_prompt = ""
|
||||||
|
anthropic_messages = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg.role == "system":
|
||||||
|
# Accumulate system messages
|
||||||
|
if system_prompt:
|
||||||
|
system_prompt += "\n\n"
|
||||||
|
system_prompt += msg.content or ""
|
||||||
|
else:
|
||||||
|
# Convert to Anthropic format
|
||||||
|
message_dict: Dict[str, Any] = {"role": msg.role}
|
||||||
|
|
||||||
|
# Handle content
|
||||||
|
if msg.content:
|
||||||
|
message_dict["content"] = msg.content
|
||||||
|
|
||||||
|
# Handle tool calls (assistant messages)
|
||||||
|
if msg.tool_calls:
|
||||||
|
# Anthropic format for tool use
|
||||||
|
content_blocks = []
|
||||||
|
if msg.content:
|
||||||
|
content_blocks.append({"type": "text", "text": msg.content})
|
||||||
|
|
||||||
|
for tc in msg.tool_calls:
|
||||||
|
content_blocks.append({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": tc.id,
|
||||||
|
"name": tc.function.name,
|
||||||
|
"input": json.loads(tc.function.arguments),
|
||||||
|
})
|
||||||
|
|
||||||
|
message_dict["content"] = content_blocks
|
||||||
|
|
||||||
|
# Handle tool results (tool messages)
|
||||||
|
if msg.role == "tool" and msg.tool_call_id:
|
||||||
|
# Convert to Anthropic's tool_result format
|
||||||
|
anthropic_messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": [{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": msg.tool_call_id,
|
||||||
|
"content": msg.content or "",
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
anthropic_messages.append(message_dict)
|
||||||
|
|
||||||
|
return system_prompt, anthropic_messages
|
||||||
|
|
||||||
|
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Convert OpenAI-style tools to Anthropic format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: OpenAI tool definitions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Anthropic tool definitions
|
||||||
|
"""
|
||||||
|
anthropic_tools = []
|
||||||
|
|
||||||
|
for tool in tools:
|
||||||
|
if tool.get("type") == "function":
|
||||||
|
func = tool.get("function", {})
|
||||||
|
anthropic_tools.append({
|
||||||
|
"name": func.get("name"),
|
||||||
|
"description": func.get("description", ""),
|
||||||
|
"input_schema": func.get("parameters", {}),
|
||||||
|
})
|
||||||
|
|
||||||
|
return anthropic_tools
|
||||||
|
|
||||||
|
def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
|
||||||
|
"""
|
||||||
|
Send synchronous chat request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse
|
||||||
|
"""
|
||||||
|
message: Message = self.client.messages.create(**params)
|
||||||
|
|
||||||
|
# Extract content
|
||||||
|
content = ""
|
||||||
|
tool_calls = []
|
||||||
|
|
||||||
|
for block in message.content:
|
||||||
|
if block.type == "text":
|
||||||
|
content += block.text
|
||||||
|
elif block.type == "tool_use":
|
||||||
|
# Convert to ToolCall format
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=block.id,
|
||||||
|
type="function",
|
||||||
|
function=ToolFunction(
|
||||||
|
name=block.name,
|
||||||
|
arguments=json.dumps(block.input),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build ChatMessage
|
||||||
|
chat_message = ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=content if content else None,
|
||||||
|
tool_calls=tool_calls if tool_calls else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract usage
|
||||||
|
usage = None
|
||||||
|
if message.usage:
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=message.usage.input_tokens,
|
||||||
|
completion_tokens=message.usage.output_tokens,
|
||||||
|
total_tokens=message.usage.input_tokens + message.usage.output_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
id=message.id,
|
||||||
|
choices=[
|
||||||
|
ChatResponseChoice(
|
||||||
|
index=0,
|
||||||
|
message=chat_message,
|
||||||
|
finish_reason=message.stop_reason,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=usage,
|
||||||
|
model=message.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Stream chat response from Anthropic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Request parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects
|
||||||
|
"""
|
||||||
|
stream = self.client.messages.stream(**params)
|
||||||
|
|
||||||
|
with stream as event_stream:
|
||||||
|
for event in event_stream:
|
||||||
|
event_data: MessageStreamEvent = event
|
||||||
|
|
||||||
|
# Handle different event types
|
||||||
|
if event_data.type == "content_block_delta":
|
||||||
|
delta = event_data.delta
|
||||||
|
if hasattr(delta, "text"):
|
||||||
|
yield StreamChunk(
|
||||||
|
id="stream",
|
||||||
|
delta_content=delta.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_data.type == "message_stop":
|
||||||
|
# Final event with usage
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif event_data.type == "message_delta":
|
||||||
|
# Contains stop reason and usage
|
||||||
|
usage = None
|
||||||
|
if hasattr(event_data, "usage"):
|
||||||
|
usage_data = event_data.usage
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=0,
|
||||||
|
completion_tokens=usage_data.output_tokens,
|
||||||
|
total_tokens=usage_data.output_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield StreamChunk(
|
||||||
|
id="stream",
|
||||||
|
finish_reason=event_data.delta.stop_reason if hasattr(event_data.delta, "stop_reason") else None,
|
||||||
|
usage=usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_async(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
|
||||||
|
"""
|
||||||
|
Send async chat request to Anthropic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model ID
|
||||||
|
messages: Chat messages
|
||||||
|
stream: Whether to stream
|
||||||
|
max_tokens: Max tokens
|
||||||
|
temperature: Temperature
|
||||||
|
tools: Tool definitions
|
||||||
|
tool_choice: Tool choice
|
||||||
|
**kwargs: Additional args
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse or AsyncIterator[StreamChunk]
|
||||||
|
"""
|
||||||
|
# Resolve model alias
|
||||||
|
model_id = MODEL_ALIASES.get(model, model)
|
||||||
|
|
||||||
|
# Convert messages
|
||||||
|
system_prompt, anthropic_messages = self._convert_messages(messages)
|
||||||
|
|
||||||
|
# Build params
|
||||||
|
params: Dict[str, Any] = {
|
||||||
|
"model": model_id,
|
||||||
|
"messages": anthropic_messages,
|
||||||
|
"max_tokens": max_tokens or 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_prompt:
|
||||||
|
params["system"] = system_prompt
|
||||||
|
if temperature is not None:
|
||||||
|
params["temperature"] = temperature
|
||||||
|
if tools:
|
||||||
|
params["tools"] = self._convert_tools(tools)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat_async(params)
|
||||||
|
else:
|
||||||
|
message = await self.async_client.messages.create(**params)
|
||||||
|
return self._convert_message(message)
|
||||||
|
|
||||||
|
async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Stream async chat response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Request parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects
|
||||||
|
"""
|
||||||
|
stream = await self.async_client.messages.stream(**params)
|
||||||
|
|
||||||
|
async with stream as event_stream:
|
||||||
|
async for event in event_stream:
|
||||||
|
if event.type == "content_block_delta":
|
||||||
|
delta = event.delta
|
||||||
|
if hasattr(delta, "text"):
|
||||||
|
yield StreamChunk(
|
||||||
|
id="stream",
|
||||||
|
delta_content=delta.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_message(self, message: Message) -> ChatResponse:
|
||||||
|
"""Helper to convert Anthropic message to ChatResponse."""
|
||||||
|
content = ""
|
||||||
|
tool_calls = []
|
||||||
|
|
||||||
|
for block in message.content:
|
||||||
|
if block.type == "text":
|
||||||
|
content += block.text
|
||||||
|
elif block.type == "tool_use":
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=block.id,
|
||||||
|
type="function",
|
||||||
|
function=ToolFunction(
|
||||||
|
name=block.name,
|
||||||
|
arguments=json.dumps(block.input),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_message = ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=content if content else None,
|
||||||
|
tool_calls=tool_calls if tool_calls else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = None
|
||||||
|
if message.usage:
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=message.usage.input_tokens,
|
||||||
|
completion_tokens=message.usage.output_tokens,
|
||||||
|
total_tokens=message.usage.input_tokens + message.usage.output_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
id=message.id,
|
||||||
|
choices=[
|
||||||
|
ChatResponseChoice(
|
||||||
|
index=0,
|
||||||
|
message=chat_message,
|
||||||
|
finish_reason=message.stop_reason,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=usage,
|
||||||
|
model=message.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_credits(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get account credits from Anthropic.
|
||||||
|
|
||||||
|
Note: Anthropic does not currently provide a public API endpoint
|
||||||
|
for checking account credits/balance. This information is only
|
||||||
|
available through the Anthropic Console web interface.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None (credits API not available)
|
||||||
|
"""
|
||||||
|
# Anthropic doesn't provide a public credits API endpoint
|
||||||
|
# Users must check their balance at console.anthropic.com
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear model cache."""
|
||||||
|
self._models_cache = None
|
||||||
|
|
||||||
|
def get_raw_models(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get raw model data as dictionaries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model dictionaries
|
||||||
|
"""
|
||||||
|
models = self.list_models()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": model.id,
|
||||||
|
"name": model.name,
|
||||||
|
"description": model.description,
|
||||||
|
"context_length": model.context_length,
|
||||||
|
"pricing": model.pricing,
|
||||||
|
}
|
||||||
|
for model in models
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get raw model data for a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model dictionary or None
|
||||||
|
"""
|
||||||
|
model = self.get_model(model_id)
|
||||||
|
if model:
|
||||||
|
return {
|
||||||
|
"id": model.id,
|
||||||
|
"name": model.name,
|
||||||
|
"description": model.description,
|
||||||
|
"context_length": model.context_length,
|
||||||
|
"pricing": model.pricing,
|
||||||
|
}
|
||||||
|
return None
|
||||||
423
oai/providers/ollama.py
Normal file
423
oai/providers/ollama.py
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
"""
|
||||||
|
Ollama provider for local AI model serving.
|
||||||
|
|
||||||
|
This provider connects to a local Ollama server for running models
|
||||||
|
locally without API keys or external dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from oai.constants import OLLAMA_DEFAULT_URL
|
||||||
|
from oai.providers.base import (
|
||||||
|
AIProvider,
|
||||||
|
ChatMessage,
|
||||||
|
ChatResponse,
|
||||||
|
ChatResponseChoice,
|
||||||
|
ModelInfo,
|
||||||
|
ProviderCapabilities,
|
||||||
|
StreamChunk,
|
||||||
|
UsageStats,
|
||||||
|
)
|
||||||
|
from oai.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaProvider(AIProvider):
|
||||||
|
"""
|
||||||
|
Ollama local model provider.
|
||||||
|
|
||||||
|
Connects to a local Ollama server for running models locally.
|
||||||
|
No API key required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str = "",
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize Ollama provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Not used (Ollama doesn't require API keys)
|
||||||
|
base_url: Ollama server URL (default: http://localhost:11434)
|
||||||
|
**kwargs: Additional arguments (ignored)
|
||||||
|
"""
|
||||||
|
super().__init__(api_key or "", base_url)
|
||||||
|
self.base_url = base_url or OLLAMA_DEFAULT_URL
|
||||||
|
self._check_server_available()
|
||||||
|
|
||||||
|
def _check_server_available(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if Ollama server is accessible.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if server is accessible
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{self.base_url}/api/tags", timeout=2)
|
||||||
|
if response.ok:
|
||||||
|
logger.info(f"Ollama server accessible at {self.base_url}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"Ollama server returned status {response.status_code}")
|
||||||
|
return False
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.warning(f"Ollama server not accessible at {self.base_url}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Get provider name."""
|
||||||
|
return "Ollama"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capabilities(self) -> ProviderCapabilities:
|
||||||
|
"""Get provider capabilities."""
|
||||||
|
return ProviderCapabilities(
|
||||||
|
streaming=True,
|
||||||
|
tools=False, # Tool support varies by model
|
||||||
|
images=False, # Image support varies by model
|
||||||
|
online=True, # Web search via DuckDuckGo/Google
|
||||||
|
max_context=8192, # Varies by model
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
||||||
|
"""
|
||||||
|
List models from local Ollama installation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_text_only: Ignored for Ollama
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of available models
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{self.base_url}/api/tags", timeout=5)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for model_data in data.get("models", []):
|
||||||
|
models.append(self._parse_model(model_data))
|
||||||
|
|
||||||
|
logger.info(f"Found {len(models)} Ollama models")
|
||||||
|
return models
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Failed to list Ollama models: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
|
||||||
|
"""
|
||||||
|
Parse Ollama model data into ModelInfo.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_data: Raw model data from Ollama API
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelInfo object
|
||||||
|
"""
|
||||||
|
model_name = model_data.get("name", "unknown")
|
||||||
|
size_bytes = model_data.get("size", 0)
|
||||||
|
size_gb = size_bytes / (1024 ** 3) if size_bytes else 0
|
||||||
|
|
||||||
|
return ModelInfo(
|
||||||
|
id=model_name,
|
||||||
|
name=model_name,
|
||||||
|
description=f"Size: {size_gb:.1f}GB",
|
||||||
|
context_length=8192, # Default, varies by model
|
||||||
|
pricing={}, # Local models are free
|
||||||
|
supported_parameters=["stream", "temperature", "max_tokens"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
||||||
|
"""
|
||||||
|
Get information about a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelInfo or None if not found
|
||||||
|
"""
|
||||||
|
models = self.list_models()
|
||||||
|
for model in models:
|
||||||
|
if model.id == model_id:
|
||||||
|
return model
|
||||||
|
return None
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
||||||
|
"""
|
||||||
|
Send chat request to Ollama.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
messages: Chat messages
|
||||||
|
stream: Whether to stream response
|
||||||
|
max_tokens: Maximum tokens (Ollama calls this num_predict)
|
||||||
|
temperature: Sampling temperature
|
||||||
|
tools: Not supported
|
||||||
|
tool_choice: Not supported
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse or Iterator[StreamChunk]
|
||||||
|
"""
|
||||||
|
# Convert messages to Ollama format
|
||||||
|
ollama_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
ollama_messages.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content or "",
|
||||||
|
})
|
||||||
|
|
||||||
|
# Build request payload
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"model": model,
|
||||||
|
"messages": ollama_messages,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters
|
||||||
|
options = {}
|
||||||
|
if temperature is not None:
|
||||||
|
options["temperature"] = temperature
|
||||||
|
if max_tokens is not None:
|
||||||
|
options["num_predict"] = max_tokens
|
||||||
|
|
||||||
|
if options:
|
||||||
|
payload["options"] = options
|
||||||
|
|
||||||
|
logger.debug(f"Ollama request: model={model}, messages={len(ollama_messages)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat(payload)
|
||||||
|
else:
|
||||||
|
return self._sync_chat(payload)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Ollama request failed: {e}")
|
||||||
|
# Return error response
|
||||||
|
return ChatResponse(
|
||||||
|
id="error",
|
||||||
|
choices=[
|
||||||
|
ChatResponseChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=f"Error: {str(e)}",
|
||||||
|
),
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sync_chat(self, payload: Dict[str, Any]) -> ChatResponse:
|
||||||
|
"""
|
||||||
|
Send synchronous chat request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Request payload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse
|
||||||
|
"""
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/api/chat",
|
||||||
|
json=payload,
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
message_data = data.get("message", {})
|
||||||
|
content = message_data.get("content", "")
|
||||||
|
|
||||||
|
# Extract token usage if available
|
||||||
|
usage = None
|
||||||
|
if "prompt_eval_count" in data or "eval_count" in data:
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=data.get("prompt_eval_count", 0),
|
||||||
|
completion_tokens=data.get("eval_count", 0),
|
||||||
|
total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
|
||||||
|
total_cost_usd=0.0, # Local models are free
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
id=str(time.time()),
|
||||||
|
choices=[
|
||||||
|
ChatResponseChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(role="assistant", content=content),
|
||||||
|
finish_reason="stop",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=usage,
|
||||||
|
model=data.get("model"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _stream_chat(self, payload: Dict[str, Any]) -> Iterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Stream chat response from Ollama.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Request payload
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects
|
||||||
|
"""
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/api/chat",
|
||||||
|
json=payload,
|
||||||
|
stream=True,
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
total_prompt_tokens = 0
|
||||||
|
total_completion_tokens = 0
|
||||||
|
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
|
||||||
|
# Extract content delta
|
||||||
|
message_data = data.get("message", {})
|
||||||
|
content = message_data.get("content", "")
|
||||||
|
|
||||||
|
# Check if done
|
||||||
|
done = data.get("done", False)
|
||||||
|
finish_reason = "stop" if done else None
|
||||||
|
|
||||||
|
# Extract usage if available
|
||||||
|
usage = None
|
||||||
|
if done and ("prompt_eval_count" in data or "eval_count" in data):
|
||||||
|
total_prompt_tokens = data.get("prompt_eval_count", 0)
|
||||||
|
total_completion_tokens = data.get("eval_count", 0)
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=total_prompt_tokens,
|
||||||
|
completion_tokens=total_completion_tokens,
|
||||||
|
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||||
|
total_cost_usd=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield StreamChunk(
|
||||||
|
id=str(time.time()),
|
||||||
|
delta_content=content if content else None,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Failed to parse Ollama stream chunk: {e}")
|
||||||
|
yield StreamChunk(
|
||||||
|
id="error",
|
||||||
|
error=f"Parse error: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_async(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
|
||||||
|
"""
|
||||||
|
Async chat not implemented for Ollama.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
messages: Chat messages
|
||||||
|
stream: Whether to stream
|
||||||
|
max_tokens: Max tokens
|
||||||
|
temperature: Temperature
|
||||||
|
tools: Tools (not supported)
|
||||||
|
tool_choice: Tool choice (not supported)
|
||||||
|
**kwargs: Additional args
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse or AsyncIterator[StreamChunk]
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: Async not implemented
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Async chat not implemented for Ollama provider")
|
||||||
|
|
||||||
|
def get_credits(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get account credits.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None (Ollama is local and free)
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear model cache (no-op for Ollama)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_raw_models(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get raw model data as dictionaries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model dictionaries
|
||||||
|
"""
|
||||||
|
models = self.list_models()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": model.id,
|
||||||
|
"name": model.name,
|
||||||
|
"description": model.description,
|
||||||
|
"context_length": model.context_length,
|
||||||
|
"pricing": model.pricing,
|
||||||
|
}
|
||||||
|
for model in models
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get raw model data for a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model dictionary or None
|
||||||
|
"""
|
||||||
|
model = self.get_model(model_id)
|
||||||
|
if model:
|
||||||
|
return {
|
||||||
|
"id": model.id,
|
||||||
|
"name": model.name,
|
||||||
|
"description": model.description,
|
||||||
|
"context_length": model.context_length,
|
||||||
|
"pricing": model.pricing,
|
||||||
|
}
|
||||||
|
return None
|
||||||
630
oai/providers/openai.py
Normal file
630
oai/providers/openai.py
Normal file
@@ -0,0 +1,630 @@
|
|||||||
|
"""
|
||||||
|
OpenAI provider for GPT models.
|
||||||
|
|
||||||
|
This provider connects to OpenAI's API for accessing GPT-4, GPT-3.5, and other OpenAI models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
|
from openai import OpenAI, AsyncOpenAI
|
||||||
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
|
|
||||||
|
from oai.constants import OPENAI_BASE_URL
|
||||||
|
from oai.providers.base import (
|
||||||
|
AIProvider,
|
||||||
|
ChatMessage,
|
||||||
|
ChatResponse,
|
||||||
|
ChatResponseChoice,
|
||||||
|
ModelInfo,
|
||||||
|
ProviderCapabilities,
|
||||||
|
StreamChunk,
|
||||||
|
ToolCall,
|
||||||
|
ToolFunction,
|
||||||
|
UsageStats,
|
||||||
|
)
|
||||||
|
from oai.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# Model aliases for convenience
|
||||||
|
MODEL_ALIASES = {
|
||||||
|
"gpt-4": "gpt-4-turbo",
|
||||||
|
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
||||||
|
"gpt-4o": "gpt-4o-2024-11-20",
|
||||||
|
"gpt-4o-mini": "gpt-4o-mini-2024-07-18",
|
||||||
|
"gpt-3.5": "gpt-3.5-turbo",
|
||||||
|
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
||||||
|
"o1": "o1-2024-12-17",
|
||||||
|
"o1-mini": "o1-mini-2024-09-12",
|
||||||
|
"o1-preview": "o1-preview-2024-09-12",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProvider(AIProvider):
|
||||||
|
"""
|
||||||
|
OpenAI API provider.
|
||||||
|
|
||||||
|
Provides access to GPT-4, GPT-3.5, o1, and other OpenAI models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
app_name: str = "oAI",
|
||||||
|
app_url: str = "",
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize OpenAI provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: OpenAI API key
|
||||||
|
base_url: Optional custom base URL
|
||||||
|
app_name: Application name (for headers)
|
||||||
|
app_url: Application URL (for headers)
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
"""
|
||||||
|
super().__init__(api_key, base_url or OPENAI_BASE_URL)
|
||||||
|
self.client = OpenAI(api_key=api_key, base_url=self.base_url)
|
||||||
|
self.async_client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
|
||||||
|
self._models_cache: Optional[List[ModelInfo]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Get provider name."""
|
||||||
|
return "OpenAI"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capabilities(self) -> ProviderCapabilities:
|
||||||
|
"""Get provider capabilities."""
|
||||||
|
return ProviderCapabilities(
|
||||||
|
streaming=True,
|
||||||
|
tools=True,
|
||||||
|
images=True,
|
||||||
|
online=True, # Web search via DuckDuckGo/Google
|
||||||
|
max_context=128000,
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
||||||
|
"""
|
||||||
|
List available OpenAI models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_text_only: Whether to filter for text models only
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ModelInfo objects
|
||||||
|
"""
|
||||||
|
if self._models_cache:
|
||||||
|
return self._models_cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
models_response = self.client.models.list()
|
||||||
|
models = []
|
||||||
|
|
||||||
|
for model in models_response.data:
|
||||||
|
# Filter for chat models
|
||||||
|
if "gpt" in model.id or "o1" in model.id:
|
||||||
|
models.append(self._parse_model(model))
|
||||||
|
|
||||||
|
# Sort by name
|
||||||
|
models.sort(key=lambda m: m.name)
|
||||||
|
self._models_cache = models
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(models)} OpenAI models")
|
||||||
|
return models
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to list OpenAI models: {e}")
|
||||||
|
return self._get_fallback_models()
|
||||||
|
|
||||||
|
def _get_fallback_models(self) -> List[ModelInfo]:
|
||||||
|
"""
|
||||||
|
Get fallback list of common OpenAI models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of common models
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
ModelInfo(
|
||||||
|
id="gpt-4o",
|
||||||
|
name="GPT-4o",
|
||||||
|
description="Most capable GPT-4 model",
|
||||||
|
context_length=128000,
|
||||||
|
pricing={"input": 5.0, "output": 15.0},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
input_modalities=["text", "image"],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="gpt-4o-mini",
|
||||||
|
name="GPT-4o Mini",
|
||||||
|
description="Affordable and fast GPT-4 class model",
|
||||||
|
context_length=128000,
|
||||||
|
pricing={"input": 0.15, "output": 0.6},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
input_modalities=["text", "image"],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="gpt-4-turbo",
|
||||||
|
name="GPT-4 Turbo",
|
||||||
|
description="GPT-4 Turbo with vision",
|
||||||
|
context_length=128000,
|
||||||
|
pricing={"input": 10.0, "output": 30.0},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
input_modalities=["text", "image"],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="gpt-3.5-turbo",
|
||||||
|
name="GPT-3.5 Turbo",
|
||||||
|
description="Fast and affordable model",
|
||||||
|
context_length=16384,
|
||||||
|
pricing={"input": 0.5, "output": 1.5},
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="o1",
|
||||||
|
name="o1",
|
||||||
|
description="Advanced reasoning model",
|
||||||
|
context_length=200000,
|
||||||
|
pricing={"input": 15.0, "output": 60.0},
|
||||||
|
supported_parameters=["max_tokens"],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="o1-mini",
|
||||||
|
name="o1-mini",
|
||||||
|
description="Fast reasoning model",
|
||||||
|
context_length=128000,
|
||||||
|
pricing={"input": 3.0, "output": 12.0},
|
||||||
|
supported_parameters=["max_tokens"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _parse_model(self, model: Any) -> ModelInfo:
|
||||||
|
"""
|
||||||
|
Parse OpenAI model into ModelInfo.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: OpenAI model object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelInfo object
|
||||||
|
"""
|
||||||
|
model_id = model.id
|
||||||
|
|
||||||
|
# Determine context length
|
||||||
|
context_length = 8192 # Default
|
||||||
|
if "gpt-4o" in model_id or "gpt-4-turbo" in model_id:
|
||||||
|
context_length = 128000
|
||||||
|
elif "gpt-4" in model_id:
|
||||||
|
context_length = 8192
|
||||||
|
elif "gpt-3.5-turbo" in model_id:
|
||||||
|
context_length = 16384
|
||||||
|
elif "o1" in model_id:
|
||||||
|
context_length = 128000
|
||||||
|
|
||||||
|
# Determine pricing (approximate)
|
||||||
|
pricing = {}
|
||||||
|
if "gpt-4o-mini" in model_id:
|
||||||
|
pricing = {"input": 0.15, "output": 0.6}
|
||||||
|
elif "gpt-4o" in model_id:
|
||||||
|
pricing = {"input": 5.0, "output": 15.0}
|
||||||
|
elif "gpt-4-turbo" in model_id:
|
||||||
|
pricing = {"input": 10.0, "output": 30.0}
|
||||||
|
elif "gpt-4" in model_id:
|
||||||
|
pricing = {"input": 30.0, "output": 60.0}
|
||||||
|
elif "gpt-3.5" in model_id:
|
||||||
|
pricing = {"input": 0.5, "output": 1.5}
|
||||||
|
elif "o1" in model_id and "mini" not in model_id:
|
||||||
|
pricing = {"input": 15.0, "output": 60.0}
|
||||||
|
elif "o1-mini" in model_id:
|
||||||
|
pricing = {"input": 3.0, "output": 12.0}
|
||||||
|
|
||||||
|
return ModelInfo(
|
||||||
|
id=model_id,
|
||||||
|
name=model_id,
|
||||||
|
description="",
|
||||||
|
context_length=context_length,
|
||||||
|
pricing=pricing,
|
||||||
|
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
||||||
|
"""
|
||||||
|
Get information about a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelInfo or None
|
||||||
|
"""
|
||||||
|
# Resolve alias
|
||||||
|
resolved_id = MODEL_ALIASES.get(model_id, model_id)
|
||||||
|
|
||||||
|
models = self.list_models()
|
||||||
|
for model in models:
|
||||||
|
if model.id == resolved_id or model.id == model_id:
|
||||||
|
return model
|
||||||
|
|
||||||
|
# Try to fetch directly
|
||||||
|
try:
|
||||||
|
model = self.client.models.retrieve(resolved_id)
|
||||||
|
return self._parse_model(model)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
||||||
|
"""
|
||||||
|
Send chat completion request to OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model ID
|
||||||
|
messages: Chat messages
|
||||||
|
stream: Whether to stream response
|
||||||
|
max_tokens: Maximum tokens
|
||||||
|
temperature: Sampling temperature
|
||||||
|
tools: Tool definitions
|
||||||
|
tool_choice: Tool selection mode
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse or Iterator[StreamChunk]
|
||||||
|
"""
|
||||||
|
# Resolve model alias
|
||||||
|
model_id = MODEL_ALIASES.get(model, model)
|
||||||
|
|
||||||
|
# Convert messages to OpenAI format
|
||||||
|
openai_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
message_dict = {"role": msg.role, "content": msg.content or ""}
|
||||||
|
|
||||||
|
if msg.tool_calls:
|
||||||
|
message_dict["tool_calls"] = [
|
||||||
|
{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": tc.type,
|
||||||
|
"function": {
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": tc.function.arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for tc in msg.tool_calls
|
||||||
|
]
|
||||||
|
|
||||||
|
if msg.tool_call_id:
|
||||||
|
message_dict["tool_call_id"] = msg.tool_call_id
|
||||||
|
|
||||||
|
openai_messages.append(message_dict)
|
||||||
|
|
||||||
|
# Build request parameters
|
||||||
|
params: Dict[str, Any] = {
|
||||||
|
"model": model_id,
|
||||||
|
"messages": openai_messages,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters
|
||||||
|
if max_tokens is not None:
|
||||||
|
params["max_tokens"] = max_tokens
|
||||||
|
if temperature is not None and "o1" not in model_id:
|
||||||
|
# o1 models don't support temperature
|
||||||
|
params["temperature"] = temperature
|
||||||
|
if tools:
|
||||||
|
params["tools"] = tools
|
||||||
|
if tool_choice:
|
||||||
|
params["tool_choice"] = tool_choice
|
||||||
|
|
||||||
|
logger.debug(f"OpenAI request: model={model_id}, messages={len(openai_messages)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat(params)
|
||||||
|
else:
|
||||||
|
return self._sync_chat(params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenAI request failed: {e}")
|
||||||
|
return ChatResponse(
|
||||||
|
id="error",
|
||||||
|
choices=[
|
||||||
|
ChatResponseChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
|
||||||
|
"""
|
||||||
|
Send synchronous chat request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse
|
||||||
|
"""
|
||||||
|
completion: ChatCompletion = self.client.chat.completions.create(**params)
|
||||||
|
|
||||||
|
# Convert to our format
|
||||||
|
choices = []
|
||||||
|
for choice in completion.choices:
|
||||||
|
# Convert tool calls if present
|
||||||
|
tool_calls = None
|
||||||
|
if choice.message.tool_calls:
|
||||||
|
tool_calls = [
|
||||||
|
ToolCall(
|
||||||
|
id=tc.id,
|
||||||
|
type=tc.type,
|
||||||
|
function=ToolFunction(
|
||||||
|
name=tc.function.name,
|
||||||
|
arguments=tc.function.arguments,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tc in choice.message.tool_calls
|
||||||
|
]
|
||||||
|
|
||||||
|
choices.append(
|
||||||
|
ChatResponseChoice(
|
||||||
|
index=choice.index,
|
||||||
|
message=ChatMessage(
|
||||||
|
role=choice.message.role,
|
||||||
|
content=choice.message.content,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
),
|
||||||
|
finish_reason=choice.finish_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert usage
|
||||||
|
usage = None
|
||||||
|
if completion.usage:
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=completion.usage.prompt_tokens,
|
||||||
|
completion_tokens=completion.usage.completion_tokens,
|
||||||
|
total_tokens=completion.usage.total_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
id=completion.id,
|
||||||
|
choices=choices,
|
||||||
|
usage=usage,
|
||||||
|
model=completion.model,
|
||||||
|
created=completion.created,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Stream chat response from OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Request parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects
|
||||||
|
"""
|
||||||
|
stream = self.client.chat.completions.create(**params)
|
||||||
|
|
||||||
|
for chunk in stream:
|
||||||
|
chunk_data: ChatCompletionChunk = chunk
|
||||||
|
|
||||||
|
if not chunk_data.choices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
choice = chunk_data.choices[0]
|
||||||
|
delta = choice.delta
|
||||||
|
|
||||||
|
# Extract content
|
||||||
|
content = delta.content if delta.content else None
|
||||||
|
|
||||||
|
# Extract finish reason
|
||||||
|
finish_reason = choice.finish_reason
|
||||||
|
|
||||||
|
# Extract usage (usually in last chunk)
|
||||||
|
usage = None
|
||||||
|
if hasattr(chunk_data, "usage") and chunk_data.usage:
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=chunk_data.usage.prompt_tokens,
|
||||||
|
completion_tokens=chunk_data.usage.completion_tokens,
|
||||||
|
total_tokens=chunk_data.usage.total_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield StreamChunk(
|
||||||
|
id=chunk_data.id,
|
||||||
|
delta_content=content,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_async(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
|
||||||
|
"""
|
||||||
|
Send async chat request to OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model ID
|
||||||
|
messages: Chat messages
|
||||||
|
stream: Whether to stream
|
||||||
|
max_tokens: Max tokens
|
||||||
|
temperature: Temperature
|
||||||
|
tools: Tool definitions
|
||||||
|
tool_choice: Tool choice
|
||||||
|
**kwargs: Additional args
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResponse or AsyncIterator[StreamChunk]
|
||||||
|
"""
|
||||||
|
# Resolve model alias
|
||||||
|
model_id = MODEL_ALIASES.get(model, model)
|
||||||
|
|
||||||
|
# Convert messages
|
||||||
|
openai_messages = [msg.to_dict() for msg in messages]
|
||||||
|
|
||||||
|
# Build params
|
||||||
|
params: Dict[str, Any] = {
|
||||||
|
"model": model_id,
|
||||||
|
"messages": openai_messages,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
if max_tokens:
|
||||||
|
params["max_tokens"] = max_tokens
|
||||||
|
if temperature is not None and "o1" not in model_id:
|
||||||
|
params["temperature"] = temperature
|
||||||
|
if tools:
|
||||||
|
params["tools"] = tools
|
||||||
|
if tool_choice:
|
||||||
|
params["tool_choice"] = tool_choice
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat_async(params)
|
||||||
|
else:
|
||||||
|
completion = await self.async_client.chat.completions.create(**params)
|
||||||
|
# Convert to ChatResponse (similar to _sync_chat)
|
||||||
|
return self._convert_completion(completion)
|
||||||
|
|
||||||
|
async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Stream async chat response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Request parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects
|
||||||
|
"""
|
||||||
|
stream = await self.async_client.chat.completions.create(**params)
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
if not chunk.choices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
delta = choice.delta
|
||||||
|
|
||||||
|
yield StreamChunk(
|
||||||
|
id=chunk.id,
|
||||||
|
delta_content=delta.content,
|
||||||
|
finish_reason=choice.finish_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_completion(self, completion: ChatCompletion) -> ChatResponse:
|
||||||
|
"""Helper to convert OpenAI completion to ChatResponse."""
|
||||||
|
choices = []
|
||||||
|
for choice in completion.choices:
|
||||||
|
tool_calls = None
|
||||||
|
if choice.message.tool_calls:
|
||||||
|
tool_calls = [
|
||||||
|
ToolCall(
|
||||||
|
id=tc.id,
|
||||||
|
type=tc.type,
|
||||||
|
function=ToolFunction(
|
||||||
|
name=tc.function.name,
|
||||||
|
arguments=tc.function.arguments,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tc in choice.message.tool_calls
|
||||||
|
]
|
||||||
|
|
||||||
|
choices.append(
|
||||||
|
ChatResponseChoice(
|
||||||
|
index=choice.index,
|
||||||
|
message=ChatMessage(
|
||||||
|
role=choice.message.role,
|
||||||
|
content=choice.message.content,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
),
|
||||||
|
finish_reason=choice.finish_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = None
|
||||||
|
if completion.usage:
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=completion.usage.prompt_tokens,
|
||||||
|
completion_tokens=completion.usage.completion_tokens,
|
||||||
|
total_tokens=completion.usage.total_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
id=completion.id,
|
||||||
|
choices=choices,
|
||||||
|
usage=usage,
|
||||||
|
model=completion.model,
|
||||||
|
created=completion.created,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_credits(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get account credits.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None (OpenAI doesn't provide credit API)
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear model cache."""
|
||||||
|
self._models_cache = None
|
||||||
|
|
||||||
|
def get_raw_models(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get raw model data as dictionaries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model dictionaries
|
||||||
|
"""
|
||||||
|
models = self.list_models()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": model.id,
|
||||||
|
"name": model.name,
|
||||||
|
"description": model.description,
|
||||||
|
"context_length": model.context_length,
|
||||||
|
"pricing": model.pricing,
|
||||||
|
}
|
||||||
|
for model in models
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get raw model data for a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model dictionary or None
|
||||||
|
"""
|
||||||
|
model = self.get_model(model_id)
|
||||||
|
if model:
|
||||||
|
return {
|
||||||
|
"id": model.id,
|
||||||
|
"name": model.name,
|
||||||
|
"description": model.description,
|
||||||
|
"context_length": model.context_length,
|
||||||
|
"pricing": model.pricing,
|
||||||
|
}
|
||||||
|
return None
|
||||||
@@ -269,10 +269,17 @@ class OpenRouterProvider(AIProvider):
|
|||||||
completion_tokens = usage_data.get("output_tokens", 0) or 0
|
completion_tokens = usage_data.get("output_tokens", 0) or 0
|
||||||
|
|
||||||
# Get cost if available
|
# Get cost if available
|
||||||
|
# OpenRouter returns cost in different places:
|
||||||
|
# 1. As 'total_cost_usd' in usage object (rare)
|
||||||
|
# 2. As 'usage' at root level (common - this is the dollar amount)
|
||||||
|
total_cost = None
|
||||||
if hasattr(usage_data, "total_cost_usd"):
|
if hasattr(usage_data, "total_cost_usd"):
|
||||||
total_cost = getattr(usage_data, "total_cost_usd", None)
|
total_cost = getattr(usage_data, "total_cost_usd", None)
|
||||||
|
elif hasattr(usage_data, "usage"):
|
||||||
|
# OpenRouter puts cost as 'usage' field (dollar amount)
|
||||||
|
total_cost = getattr(usage_data, "usage", None)
|
||||||
elif isinstance(usage_data, dict):
|
elif isinstance(usage_data, dict):
|
||||||
total_cost = usage_data.get("total_cost_usd")
|
total_cost = usage_data.get("total_cost_usd") or usage_data.get("usage")
|
||||||
|
|
||||||
return UsageStats(
|
return UsageStats(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
|
|||||||
60
oai/providers/registry.py
Normal file
60
oai/providers/registry.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""
|
||||||
|
Provider registry for AI model providers.
|
||||||
|
|
||||||
|
This module maintains a central registry of all available AI providers,
|
||||||
|
allowing dynamic provider lookup and registration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from oai.providers.base import AIProvider
|
||||||
|
|
||||||
|
# Global provider registry
|
||||||
|
PROVIDER_REGISTRY: Dict[str, Type[AIProvider]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_provider(name: str, provider_class: Type[AIProvider]) -> None:
|
||||||
|
"""
|
||||||
|
Register a provider class with the given name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Provider identifier (e.g., "openrouter", "anthropic")
|
||||||
|
provider_class: The provider class to register
|
||||||
|
"""
|
||||||
|
PROVIDER_REGISTRY[name] = provider_class
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_class(name: str) -> Optional[Type[AIProvider]]:
|
||||||
|
"""
|
||||||
|
Get a provider class by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Provider identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Provider class or None if not found
|
||||||
|
"""
|
||||||
|
return PROVIDER_REGISTRY.get(name)
|
||||||
|
|
||||||
|
|
||||||
|
def list_providers() -> List[str]:
|
||||||
|
"""
|
||||||
|
List all registered provider names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of provider identifiers
|
||||||
|
"""
|
||||||
|
return list(PROVIDER_REGISTRY.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def is_provider_registered(name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a provider is registered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Provider identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if provider is registered
|
||||||
|
"""
|
||||||
|
return name in PROVIDER_REGISTRY
|
||||||
5
oai/tui/__init__.py
Normal file
5
oai/tui/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Textual TUI interface for oAI."""
|
||||||
|
|
||||||
|
from oai.tui.app import oAIChatApp
|
||||||
|
|
||||||
|
__all__ = ["oAIChatApp"]
|
||||||
1069
oai/tui/app.py
Normal file
1069
oai/tui/app.py
Normal file
File diff suppressed because it is too large
Load Diff
23
oai/tui/screens/__init__.py
Normal file
23
oai/tui/screens/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""TUI screens for oAI."""
|
||||||
|
|
||||||
|
from oai.tui.screens.commands_screen import CommandsScreen
|
||||||
|
from oai.tui.screens.config_screen import ConfigScreen
|
||||||
|
from oai.tui.screens.conversation_selector import ConversationSelectorScreen
|
||||||
|
from oai.tui.screens.credits_screen import CreditsScreen
|
||||||
|
from oai.tui.screens.dialogs import AlertDialog, ConfirmDialog, InputDialog
|
||||||
|
from oai.tui.screens.help_screen import HelpScreen
|
||||||
|
from oai.tui.screens.model_selector import ModelSelectorScreen
|
||||||
|
from oai.tui.screens.stats_screen import StatsScreen
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AlertDialog",
|
||||||
|
"CommandsScreen",
|
||||||
|
"ConfirmDialog",
|
||||||
|
"ConfigScreen",
|
||||||
|
"ConversationSelectorScreen",
|
||||||
|
"CreditsScreen",
|
||||||
|
"InputDialog",
|
||||||
|
"HelpScreen",
|
||||||
|
"ModelSelectorScreen",
|
||||||
|
"StatsScreen",
|
||||||
|
]
|
||||||
172
oai/tui/screens/commands_screen.py
Normal file
172
oai/tui/screens/commands_screen.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""Commands reference screen for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Vertical, VerticalScroll
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, Static
|
||||||
|
|
||||||
|
|
||||||
|
class CommandsScreen(ModalScreen[None]):
|
||||||
|
"""Modal screen showing all available commands."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
CommandsScreen {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandsScreen > Container {
|
||||||
|
width: 90;
|
||||||
|
height: 40;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandsScreen .header {
|
||||||
|
dock: top;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandsScreen .content {
|
||||||
|
width: 100%;
|
||||||
|
height: 1fr;
|
||||||
|
background: #1e1e1e;
|
||||||
|
padding: 2;
|
||||||
|
color: #cccccc;
|
||||||
|
overflow-y: auto;
|
||||||
|
scrollbar-background: #1e1e1e;
|
||||||
|
scrollbar-color: #555555;
|
||||||
|
scrollbar-size: 1 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandsScreen .footer {
|
||||||
|
dock: bottom;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
padding: 1 2;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the commands screen."""
|
||||||
|
with Container():
|
||||||
|
yield Static("[bold]Commands Reference[/]", classes="header")
|
||||||
|
with VerticalScroll(classes="content"):
|
||||||
|
yield Static(self._get_commands_text(), markup=True)
|
||||||
|
with Vertical(classes="footer"):
|
||||||
|
yield Button("Close", id="close-button", variant="primary")
|
||||||
|
|
||||||
|
def _get_commands_text(self) -> str:
|
||||||
|
"""Generate formatted commands text."""
|
||||||
|
return """[bold cyan]General Commands[/]
|
||||||
|
|
||||||
|
[green]/help[/] - Show help screen with keyboard shortcuts
|
||||||
|
[green]/commands[/] - Show this commands reference
|
||||||
|
[green]/model[/] - Open model selector (or press F2)
|
||||||
|
[green]/stats[/] - Show session statistics (or press Ctrl+S)
|
||||||
|
[green]/credits[/] - Check account credits (OpenRouter) or view console link
|
||||||
|
[green]/clear[/] - Clear chat display
|
||||||
|
[green]/reset[/] - Reset conversation history
|
||||||
|
[green]/retry[/] - Retry last prompt
|
||||||
|
[green]/paste[/] - Paste from clipboard
|
||||||
|
|
||||||
|
[bold cyan]Provider Commands[/]
|
||||||
|
|
||||||
|
[green]/provider[/] - Show current provider
|
||||||
|
[green]/provider openrouter[/] - Switch to OpenRouter
|
||||||
|
[green]/provider anthropic[/] - Switch to Anthropic (Claude)
|
||||||
|
[green]/provider openai[/] - Switch to OpenAI (ChatGPT)
|
||||||
|
[green]/provider ollama[/] - Switch to Ollama (local)
|
||||||
|
|
||||||
|
[bold cyan]Online Mode (Web Search)[/]
|
||||||
|
|
||||||
|
[green]/online[/] - Show online mode status
|
||||||
|
[green]/online on[/] - Enable web search
|
||||||
|
[green]/online off[/] - Disable web search
|
||||||
|
|
||||||
|
[dim]Search Providers:[/]
|
||||||
|
• [yellow]anthropic_native[/] - Anthropic's native search with citations ($0.01/search)
|
||||||
|
• [yellow]duckduckgo[/] - Free web scraping (default, works with all providers)
|
||||||
|
• [yellow]google[/] - Google Custom Search (requires API key)
|
||||||
|
|
||||||
|
[bold cyan]Configuration Commands[/]
|
||||||
|
|
||||||
|
[green]/config[/] - View all settings
|
||||||
|
[green]/config provider <name>[/] - Set default provider
|
||||||
|
[green]/config search_provider <provider>[/] - Set search provider (anthropic_native/duckduckgo/google)
|
||||||
|
[green]/config openrouter_api_key <key>[/] - Set OpenRouter API key
|
||||||
|
[green]/config anthropic_api_key <key>[/] - Set Anthropic API key
|
||||||
|
[green]/config openai_api_key <key>[/] - Set OpenAI API key
|
||||||
|
[green]/config ollama_base_url <url>[/] - Set Ollama server URL
|
||||||
|
[green]/config google_api_key <key>[/] - Set Google API key (for Google search)
|
||||||
|
[green]/config google_search_engine_id <id>[/] - Set Google Search Engine ID
|
||||||
|
[green]/config online on|off[/] - Set default online mode
|
||||||
|
[green]/config stream on|off[/] - Toggle streaming
|
||||||
|
[green]/config model <id>[/] - Set default model
|
||||||
|
[green]/config system <prompt>[/] - Set system prompt
|
||||||
|
[green]/config maxtoken <num>[/] - Set token limit
|
||||||
|
|
||||||
|
[bold cyan]Memory & Context[/]
|
||||||
|
|
||||||
|
[green]/memory on[/] - Enable conversation memory
|
||||||
|
[green]/memory off[/] - Disable memory (fresh context each message)
|
||||||
|
|
||||||
|
[bold cyan]Conversation Management[/]
|
||||||
|
|
||||||
|
[green]/save <name>[/] - Save current conversation
|
||||||
|
[green]/load <name>[/] - Load saved conversation
|
||||||
|
[green]/list[/] - List all saved conversations
|
||||||
|
[green]/delete <name>[/] - Delete a conversation
|
||||||
|
[green]/prev[/] - Show previous message
|
||||||
|
[green]/next[/] - Show next message
|
||||||
|
|
||||||
|
[bold cyan]Export Commands[/]
|
||||||
|
|
||||||
|
[green]/export md <file>[/] - Export conversation as Markdown
|
||||||
|
[green]/export json <file>[/] - Export as JSON
|
||||||
|
[green]/export html <file>[/] - Export as HTML
|
||||||
|
|
||||||
|
[bold cyan]MCP (Model Context Protocol)[/]
|
||||||
|
|
||||||
|
[green]/mcp on[/] - Enable MCP file access
|
||||||
|
[green]/mcp off[/] - Disable MCP
|
||||||
|
[green]/mcp status[/] - Show MCP status
|
||||||
|
[green]/mcp add <path>[/] - Add folder for file access
|
||||||
|
[green]/mcp add db <path>[/] - Add SQLite database
|
||||||
|
[green]/mcp remove <path>[/] - Remove folder/database
|
||||||
|
[green]/mcp list[/] - List allowed folders
|
||||||
|
[green]/mcp db list[/] - List added databases
|
||||||
|
[green]/mcp db <n>[/] - Switch to database mode
|
||||||
|
[green]/mcp files[/] - Switch to file mode
|
||||||
|
[green]/mcp write on[/] - Enable write mode (allows file modifications)
|
||||||
|
[green]/mcp write off[/] - Disable write mode
|
||||||
|
|
||||||
|
[bold cyan]System Prompt[/]
|
||||||
|
|
||||||
|
[green]/system <prompt>[/] - Set custom system prompt for session
|
||||||
|
[green]/config system <prompt>[/] - Set default system prompt
|
||||||
|
|
||||||
|
[bold cyan]Keyboard Shortcuts[/]
|
||||||
|
|
||||||
|
• [yellow]F1[/] - Help screen
|
||||||
|
• [yellow]F2[/] - Model selector
|
||||||
|
• [yellow]Ctrl+S[/] - Statistics
|
||||||
|
• [yellow]Ctrl+Q[/] - Quit
|
||||||
|
• [yellow]Ctrl+Y[/] - Copy latest reply in Markdown
|
||||||
|
• [yellow]Up/Down[/] - Command history
|
||||||
|
• [yellow]Tab[/] - Command completion
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
self.dismiss()
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key in ("escape", "enter"):
|
||||||
|
self.dismiss()
|
||||||
163
oai/tui/screens/config_screen.py
Normal file
163
oai/tui/screens/config_screen.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""Configuration screen for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Vertical
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, Static
|
||||||
|
|
||||||
|
from oai.config.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigScreen(ModalScreen[None]):
|
||||||
|
"""Modal screen displaying configuration settings."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
ConfigScreen {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfigScreen > Container {
|
||||||
|
width: 70;
|
||||||
|
height: auto;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfigScreen .header {
|
||||||
|
dock: top;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfigScreen .content {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #1e1e1e;
|
||||||
|
padding: 2;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfigScreen .footer {
|
||||||
|
dock: bottom;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
padding: 1 2;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings, session=None):
|
||||||
|
super().__init__()
|
||||||
|
self.settings = settings
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the screen."""
|
||||||
|
with Container():
|
||||||
|
yield Static("[bold]Configuration[/]", classes="header")
|
||||||
|
with Vertical(classes="content"):
|
||||||
|
yield Static(self._get_config_text(), markup=True)
|
||||||
|
with Vertical(classes="footer"):
|
||||||
|
yield Button("Close", id="close", variant="primary")
|
||||||
|
|
||||||
|
def _get_config_text(self) -> str:
|
||||||
|
"""Generate the configuration text."""
|
||||||
|
from oai.constants import DEFAULT_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
config_lines = ["[bold cyan]═══ CONFIGURATION ═══[/]\n"]
|
||||||
|
|
||||||
|
# Current Session Info
|
||||||
|
if self.session and self.session.client:
|
||||||
|
config_lines.append("[bold yellow]Current Session:[/]")
|
||||||
|
|
||||||
|
provider = self.session.client.provider_name
|
||||||
|
config_lines.append(f"[bold]Provider:[/] [green]{provider}[/]")
|
||||||
|
|
||||||
|
# Provider URL
|
||||||
|
if hasattr(self.session.client.provider, 'base_url'):
|
||||||
|
provider_url = self.session.client.provider.base_url
|
||||||
|
config_lines.append(f"[bold]Provider URL:[/] {provider_url}")
|
||||||
|
|
||||||
|
# API Key status
|
||||||
|
if provider == "ollama":
|
||||||
|
config_lines.append(f"[bold]API Key:[/] [dim]Not required (local)[/]")
|
||||||
|
else:
|
||||||
|
api_key = self.settings.get_provider_api_key(provider)
|
||||||
|
if api_key:
|
||||||
|
key_display = "***" + api_key[-4:] if len(api_key) > 4 else "***"
|
||||||
|
config_lines.append(f"[bold]API Key:[/] [green]{key_display}[/]")
|
||||||
|
else:
|
||||||
|
config_lines.append(f"[bold]API Key:[/] [red]Not set[/]")
|
||||||
|
|
||||||
|
# Current model
|
||||||
|
if self.session.selected_model:
|
||||||
|
model_name = self.session.selected_model.get("name", "Unknown")
|
||||||
|
config_lines.append(f"[bold]Current Model:[/] {model_name}")
|
||||||
|
|
||||||
|
config_lines.append("")
|
||||||
|
|
||||||
|
# Default Settings
|
||||||
|
config_lines.append("[bold yellow]Default Settings:[/]")
|
||||||
|
config_lines.append(f"[bold]Default Provider:[/] {self.settings.default_provider}")
|
||||||
|
|
||||||
|
# Mask helper
|
||||||
|
def mask_key(key):
|
||||||
|
if not key:
|
||||||
|
return "[red]Not set[/]"
|
||||||
|
return "***" + key[-4:] if len(key) > 4 else "***"
|
||||||
|
|
||||||
|
config_lines.append(f"[bold]OpenRouter Key:[/] {mask_key(self.settings.openrouter_api_key)}")
|
||||||
|
config_lines.append(f"[bold]Anthropic Key:[/] {mask_key(self.settings.anthropic_api_key)}")
|
||||||
|
config_lines.append(f"[bold]OpenAI Key:[/] {mask_key(self.settings.openai_api_key)}")
|
||||||
|
config_lines.append(f"[bold]Ollama URL:[/] {self.settings.ollama_base_url}")
|
||||||
|
config_lines.append(f"[bold]Default Model:[/] {self.settings.default_model or 'Not set'}")
|
||||||
|
|
||||||
|
# System prompt display
|
||||||
|
if self.settings.default_system_prompt is None:
|
||||||
|
system_prompt_display = f"[dim][default][/] {DEFAULT_SYSTEM_PROMPT[:40]}..."
|
||||||
|
elif self.settings.default_system_prompt == "":
|
||||||
|
system_prompt_display = "[dim][blank][/]"
|
||||||
|
else:
|
||||||
|
prompt = self.settings.default_system_prompt
|
||||||
|
system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt
|
||||||
|
|
||||||
|
config_lines.append(f"[bold]System Prompt:[/] {system_prompt_display}")
|
||||||
|
config_lines.append("")
|
||||||
|
|
||||||
|
# Web Search Configuration
|
||||||
|
config_lines.append("[bold yellow]Web Search Configuration:[/]")
|
||||||
|
config_lines.append(f"[bold]Search Provider:[/] {self.settings.search_provider}")
|
||||||
|
|
||||||
|
# Show API key status based on search provider
|
||||||
|
if self.settings.search_provider == "google":
|
||||||
|
config_lines.append(f"[bold]Google API Key:[/] {mask_key(self.settings.google_api_key)}")
|
||||||
|
config_lines.append(f"[bold]Search Engine ID:[/] {mask_key(self.settings.google_search_engine_id)}")
|
||||||
|
elif self.settings.search_provider == "duckduckgo":
|
||||||
|
config_lines.append("[dim] DuckDuckGo requires no configuration (free)[/]")
|
||||||
|
elif self.settings.search_provider == "anthropic_native":
|
||||||
|
config_lines.append("[dim] Uses Anthropic API ($0.01 per search)[/]")
|
||||||
|
|
||||||
|
config_lines.append("")
|
||||||
|
|
||||||
|
# Other settings
|
||||||
|
config_lines.append("[bold]Streaming:[/] " + ("on" if self.settings.stream_enabled else "off"))
|
||||||
|
config_lines.append(f"[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}")
|
||||||
|
config_lines.append(f"[bold]Max Tokens:[/] {self.settings.max_tokens}")
|
||||||
|
config_lines.append(f"[bold]Default Online:[/] " + ("on" if self.settings.default_online_mode else "off"))
|
||||||
|
config_lines.append(f"[bold]Log Level:[/] {self.settings.log_level}")
|
||||||
|
config_lines.append("\n[dim]Use /config [setting] [value] to modify settings[/]")
|
||||||
|
|
||||||
|
return "\n".join(config_lines)
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
self.dismiss()
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key in ("escape", "enter"):
|
||||||
|
self.dismiss()
|
||||||
205
oai/tui/screens/conversation_selector.py
Normal file
205
oai/tui/screens/conversation_selector.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Conversation selector screen for oAI TUI."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Vertical
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, DataTable, Input, Static
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationSelectorScreen(ModalScreen[Optional[dict]]):
|
||||||
|
"""Modal screen for selecting a saved conversation."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
ConversationSelectorScreen {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConversationSelectorScreen > Container {
|
||||||
|
width: 80%;
|
||||||
|
height: 70%;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
layout: vertical;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConversationSelectorScreen .header {
|
||||||
|
height: 3;
|
||||||
|
width: 100%;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 2;
|
||||||
|
content-align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConversationSelectorScreen .search-input {
|
||||||
|
height: 3;
|
||||||
|
width: 100%;
|
||||||
|
background: #2a2a2a;
|
||||||
|
border: solid #555555;
|
||||||
|
margin: 0 0 1 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConversationSelectorScreen .search-input:focus {
|
||||||
|
border: solid #888888;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConversationSelectorScreen DataTable {
|
||||||
|
height: 1fr;
|
||||||
|
width: 100%;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConversationSelectorScreen .footer {
|
||||||
|
height: 5;
|
||||||
|
width: 100%;
|
||||||
|
background: #2d2d2d;
|
||||||
|
padding: 1 2;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConversationSelectorScreen Button {
|
||||||
|
margin: 0 1;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, conversations: List[dict]):
|
||||||
|
super().__init__()
|
||||||
|
self.all_conversations = conversations
|
||||||
|
self.filtered_conversations = conversations
|
||||||
|
self.selected_conversation: Optional[dict] = None
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the screen."""
|
||||||
|
with Container():
|
||||||
|
yield Static(
|
||||||
|
f"[bold]Load Conversation[/] [dim]({len(self.all_conversations)} saved)[/]",
|
||||||
|
classes="header"
|
||||||
|
)
|
||||||
|
yield Input(placeholder="Search conversations...", id="search-input", classes="search-input")
|
||||||
|
yield DataTable(id="conv-table", cursor_type="row", show_header=True, zebra_stripes=True)
|
||||||
|
with Vertical(classes="footer"):
|
||||||
|
yield Button("Load", id="load", variant="success")
|
||||||
|
yield Button("Cancel", id="cancel", variant="error")
|
||||||
|
|
||||||
|
def on_mount(self) -> None:
|
||||||
|
"""Initialize the table when mounted."""
|
||||||
|
table = self.query_one("#conv-table", DataTable)
|
||||||
|
|
||||||
|
# Add columns
|
||||||
|
table.add_column("#", width=5)
|
||||||
|
table.add_column("Name", width=40)
|
||||||
|
table.add_column("Messages", width=12)
|
||||||
|
table.add_column("Last Saved", width=20)
|
||||||
|
|
||||||
|
# Populate table
|
||||||
|
self._populate_table()
|
||||||
|
|
||||||
|
# Focus table if list is small (fits on screen), otherwise focus search
|
||||||
|
if len(self.all_conversations) <= 10:
|
||||||
|
table.focus()
|
||||||
|
else:
|
||||||
|
search_input = self.query_one("#search-input", Input)
|
||||||
|
search_input.focus()
|
||||||
|
|
||||||
|
def _populate_table(self) -> None:
|
||||||
|
"""Populate the table with conversations."""
|
||||||
|
table = self.query_one("#conv-table", DataTable)
|
||||||
|
table.clear()
|
||||||
|
|
||||||
|
for idx, conv in enumerate(self.filtered_conversations, 1):
|
||||||
|
name = conv.get("name", "Unknown")
|
||||||
|
message_count = str(conv.get("message_count", 0))
|
||||||
|
last_saved = conv.get("last_saved", "Unknown")
|
||||||
|
|
||||||
|
# Format timestamp if it's a full datetime
|
||||||
|
if "T" in last_saved or len(last_saved) > 20:
|
||||||
|
try:
|
||||||
|
# Truncate to just date and time
|
||||||
|
last_saved = last_saved[:19].replace("T", " ")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
table.add_row(
|
||||||
|
str(idx),
|
||||||
|
name,
|
||||||
|
message_count,
|
||||||
|
last_saved,
|
||||||
|
key=str(idx)
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_input_changed(self, event: Input.Changed) -> None:
|
||||||
|
"""Filter conversations based on search input."""
|
||||||
|
if event.input.id != "search-input":
|
||||||
|
return
|
||||||
|
|
||||||
|
search_term = event.value.lower()
|
||||||
|
|
||||||
|
if not search_term:
|
||||||
|
self.filtered_conversations = self.all_conversations
|
||||||
|
else:
|
||||||
|
self.filtered_conversations = [
|
||||||
|
c for c in self.all_conversations
|
||||||
|
if search_term in c.get("name", "").lower()
|
||||||
|
]
|
||||||
|
|
||||||
|
self._populate_table()
|
||||||
|
|
||||||
|
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
|
||||||
|
"""Handle row selection (click)."""
|
||||||
|
try:
|
||||||
|
row_index = int(event.row_key.value) - 1
|
||||||
|
if 0 <= row_index < len(self.filtered_conversations):
|
||||||
|
self.selected_conversation = self.filtered_conversations[row_index]
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_data_table_row_highlighted(self, event) -> None:
|
||||||
|
"""Handle row highlight (arrow key navigation)."""
|
||||||
|
try:
|
||||||
|
table = self.query_one("#conv-table", DataTable)
|
||||||
|
if table.cursor_row is not None:
|
||||||
|
row_data = table.get_row_at(table.cursor_row)
|
||||||
|
if row_data:
|
||||||
|
row_index = int(row_data[0]) - 1
|
||||||
|
if 0 <= row_index < len(self.filtered_conversations):
|
||||||
|
self.selected_conversation = self.filtered_conversations[row_index]
|
||||||
|
except (ValueError, IndexError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
if event.button.id == "load":
|
||||||
|
if self.selected_conversation:
|
||||||
|
self.dismiss(self.selected_conversation)
|
||||||
|
else:
|
||||||
|
self.dismiss(None)
|
||||||
|
else:
|
||||||
|
self.dismiss(None)
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key == "escape":
|
||||||
|
self.dismiss(None)
|
||||||
|
elif event.key == "enter":
|
||||||
|
# If in search input, move to table
|
||||||
|
search_input = self.query_one("#search-input", Input)
|
||||||
|
if search_input.has_focus:
|
||||||
|
table = self.query_one("#conv-table", DataTable)
|
||||||
|
table.focus()
|
||||||
|
# If in table, select current row
|
||||||
|
else:
|
||||||
|
table = self.query_one("#conv-table", DataTable)
|
||||||
|
if table.cursor_row is not None:
|
||||||
|
try:
|
||||||
|
row_data = table.get_row_at(table.cursor_row)
|
||||||
|
if row_data:
|
||||||
|
row_index = int(row_data[0]) - 1
|
||||||
|
if 0 <= row_index < len(self.filtered_conversations):
|
||||||
|
selected = self.filtered_conversations[row_index]
|
||||||
|
self.dismiss(selected)
|
||||||
|
except (ValueError, IndexError, AttributeError):
|
||||||
|
if self.selected_conversation:
|
||||||
|
self.dismiss(self.selected_conversation)
|
||||||
158
oai/tui/screens/credits_screen.py
Normal file
158
oai/tui/screens/credits_screen.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Credits screen for oAI TUI."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Vertical
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, Static
|
||||||
|
|
||||||
|
from oai.core.client import AIClient
|
||||||
|
|
||||||
|
|
||||||
|
class CreditsScreen(ModalScreen[None]):
|
||||||
|
"""Modal screen displaying account credits."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
CreditsScreen {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
CreditsScreen > Container {
|
||||||
|
width: 60;
|
||||||
|
height: auto;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
CreditsScreen .header {
|
||||||
|
dock: top;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
CreditsScreen .content {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #1e1e1e;
|
||||||
|
padding: 2;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
CreditsScreen .footer {
|
||||||
|
dock: bottom;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
padding: 1 2;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, client: AIClient):
|
||||||
|
super().__init__()
|
||||||
|
self.client = client
|
||||||
|
self.credits_data: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the screen."""
|
||||||
|
with Container():
|
||||||
|
yield Static("[bold]Account Credits[/]", classes="header")
|
||||||
|
with Vertical(classes="content"):
|
||||||
|
yield Static("[dim]Loading...[/]", id="credits-content", markup=True)
|
||||||
|
with Vertical(classes="footer"):
|
||||||
|
yield Button("Close", id="close", variant="primary")
|
||||||
|
|
||||||
|
def on_mount(self) -> None:
|
||||||
|
"""Fetch credits when mounted."""
|
||||||
|
self.fetch_credits()
|
||||||
|
|
||||||
|
def fetch_credits(self) -> None:
|
||||||
|
"""Fetch and display credits information."""
|
||||||
|
try:
|
||||||
|
self.credits_data = self.client.provider.get_credits()
|
||||||
|
content = self.query_one("#credits-content", Static)
|
||||||
|
content.update(self._get_credits_text())
|
||||||
|
except Exception as e:
|
||||||
|
content = self.query_one("#credits-content", Static)
|
||||||
|
content.update(f"[red]Error fetching credits:[/]\n{str(e)}")
|
||||||
|
|
||||||
|
def _get_credits_text(self) -> str:
|
||||||
|
"""Generate the credits text."""
|
||||||
|
if not self.credits_data:
|
||||||
|
# Provider-specific message when credits aren't available
|
||||||
|
if self.client.provider_name == "anthropic":
|
||||||
|
return """[yellow]Credit information not available via API[/]
|
||||||
|
|
||||||
|
Anthropic does not provide a public API endpoint for checking
|
||||||
|
account credits.
|
||||||
|
|
||||||
|
To view your account balance and usage:
|
||||||
|
[cyan]→ Visit console.anthropic.com[/]
|
||||||
|
[cyan]→ Navigate to Settings → Billing[/]
|
||||||
|
"""
|
||||||
|
elif self.client.provider_name == "openai":
|
||||||
|
return """[yellow]Credit information not available via API[/]
|
||||||
|
|
||||||
|
OpenAI does not provide credit balance through their API.
|
||||||
|
|
||||||
|
To view your account usage:
|
||||||
|
[cyan]→ Visit platform.openai.com[/]
|
||||||
|
[cyan]→ Navigate to Usage[/]
|
||||||
|
"""
|
||||||
|
elif self.client.provider_name == "ollama":
|
||||||
|
return """[yellow]Credit information not applicable[/]
|
||||||
|
|
||||||
|
Ollama runs locally on your machine and does not use credits.
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
return "[yellow]No credit information available[/]"
|
||||||
|
|
||||||
|
provider_name = self.client.provider_name.upper()
|
||||||
|
total = self.credits_data.get("total_credits")
|
||||||
|
used = self.credits_data.get("used_credits")
|
||||||
|
remaining = self.credits_data.get("credits_left", 0)
|
||||||
|
|
||||||
|
# Determine color based on absolute remaining amount
|
||||||
|
if remaining > 10:
|
||||||
|
remaining_color = "green"
|
||||||
|
elif remaining > 2:
|
||||||
|
remaining_color = "yellow"
|
||||||
|
else:
|
||||||
|
remaining_color = "red"
|
||||||
|
|
||||||
|
lines = [f"[bold cyan]═══ {provider_name} CREDITS ═══[/]\n"]
|
||||||
|
|
||||||
|
# If we have total/used info (OpenRouter)
|
||||||
|
if total is not None and used is not None:
|
||||||
|
percent_used = (used / total) * 100 if total > 0 else 0
|
||||||
|
percent_remaining = (remaining / total) * 100 if total > 0 else 0
|
||||||
|
|
||||||
|
lines.append(f"[bold]Total Credits:[/] ${total:.2f}")
|
||||||
|
lines.append(f"[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]")
|
||||||
|
lines.append(f"[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]")
|
||||||
|
else:
|
||||||
|
# Anthropic - only shows balance
|
||||||
|
lines.append(f"[bold]Current Balance:[/] [{remaining_color}]${remaining:.2f}[/]")
|
||||||
|
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Provider-specific footer
|
||||||
|
if self.client.provider_name == "openrouter":
|
||||||
|
lines.append("[dim]Visit openrouter.ai to add more credits[/]")
|
||||||
|
elif self.client.provider_name == "anthropic":
|
||||||
|
lines.append("[dim]Visit console.anthropic.com to manage billing[/]")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
self.dismiss()
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key in ("escape", "enter"):
|
||||||
|
self.dismiss()
|
||||||
236
oai/tui/screens/dialogs.py
Normal file
236
oai/tui/screens/dialogs.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""Modal dialog screens for oAI TUI."""
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Horizontal, Vertical
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, Input, Label, Static
|
||||||
|
|
||||||
|
|
||||||
|
class ConfirmDialog(ModalScreen[bool]):
|
||||||
|
"""A confirmation dialog with Yes/No buttons."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
ConfirmDialog {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfirmDialog > Container {
|
||||||
|
width: 60;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
border: solid #555555;
|
||||||
|
padding: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfirmDialog Label {
|
||||||
|
width: 100%;
|
||||||
|
content-align: center middle;
|
||||||
|
margin-bottom: 2;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfirmDialog Horizontal {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConfirmDialog Button {
|
||||||
|
margin: 0 1;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
title: str = "Confirm",
|
||||||
|
yes_label: str = "Yes",
|
||||||
|
no_label: str = "No",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.message = message
|
||||||
|
self.title = title
|
||||||
|
self.yes_label = yes_label
|
||||||
|
self.no_label = no_label
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the dialog."""
|
||||||
|
with Container():
|
||||||
|
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
|
||||||
|
yield Label(self.message)
|
||||||
|
with Horizontal():
|
||||||
|
yield Button(self.yes_label, id="yes", variant="success")
|
||||||
|
yield Button(self.no_label, id="no", variant="error")
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
if event.button.id == "yes":
|
||||||
|
self.dismiss(True)
|
||||||
|
else:
|
||||||
|
self.dismiss(False)
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key == "escape":
|
||||||
|
self.dismiss(False)
|
||||||
|
elif event.key == "enter":
|
||||||
|
self.dismiss(True)
|
||||||
|
|
||||||
|
|
||||||
|
class InputDialog(ModalScreen[Optional[str]]):
|
||||||
|
"""An input dialog for text entry."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
InputDialog {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
InputDialog > Container {
|
||||||
|
width: 70;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
border: solid #555555;
|
||||||
|
padding: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
InputDialog Label {
|
||||||
|
width: 100%;
|
||||||
|
margin-bottom: 1;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
InputDialog Input {
|
||||||
|
width: 100%;
|
||||||
|
margin-bottom: 2;
|
||||||
|
background: #3a3a3a;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
InputDialog Input:focus {
|
||||||
|
border: solid #888888;
|
||||||
|
}
|
||||||
|
|
||||||
|
InputDialog Horizontal {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
InputDialog Button {
|
||||||
|
margin: 0 1;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
title: str = "Input",
|
||||||
|
default: str = "",
|
||||||
|
placeholder: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.message = message
|
||||||
|
self.title = title
|
||||||
|
self.default = default
|
||||||
|
self.placeholder = placeholder
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the dialog."""
|
||||||
|
with Container():
|
||||||
|
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
|
||||||
|
yield Label(self.message)
|
||||||
|
yield Input(
|
||||||
|
value=self.default,
|
||||||
|
placeholder=self.placeholder,
|
||||||
|
id="input-field"
|
||||||
|
)
|
||||||
|
with Horizontal():
|
||||||
|
yield Button("OK", id="ok", variant="primary")
|
||||||
|
yield Button("Cancel", id="cancel")
|
||||||
|
|
||||||
|
def on_mount(self) -> None:
|
||||||
|
"""Focus the input field when mounted."""
|
||||||
|
input_field = self.query_one("#input-field", Input)
|
||||||
|
input_field.focus()
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
if event.button.id == "ok":
|
||||||
|
input_field = self.query_one("#input-field", Input)
|
||||||
|
self.dismiss(input_field.value)
|
||||||
|
else:
|
||||||
|
self.dismiss(None)
|
||||||
|
|
||||||
|
def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||||
|
"""Handle Enter key in input field."""
|
||||||
|
self.dismiss(event.value)
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key == "escape":
|
||||||
|
self.dismiss(None)
|
||||||
|
|
||||||
|
|
||||||
|
class AlertDialog(ModalScreen[None]):
|
||||||
|
"""A simple alert/message dialog."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
AlertDialog {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
AlertDialog > Container {
|
||||||
|
width: 60;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
border: solid #555555;
|
||||||
|
padding: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
AlertDialog Label {
|
||||||
|
width: 100%;
|
||||||
|
content-align: center middle;
|
||||||
|
margin-bottom: 2;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
AlertDialog Horizontal {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, title: str = "Alert", variant: str = "default"):
|
||||||
|
super().__init__()
|
||||||
|
self.message = message
|
||||||
|
self.title = title
|
||||||
|
self.variant = variant
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the dialog."""
|
||||||
|
# Choose color based on variant (using design system)
|
||||||
|
color = "$primary"
|
||||||
|
if self.variant == "error":
|
||||||
|
color = "$error"
|
||||||
|
elif self.variant == "success":
|
||||||
|
color = "$success"
|
||||||
|
elif self.variant == "warning":
|
||||||
|
color = "$warning"
|
||||||
|
|
||||||
|
with Container():
|
||||||
|
yield Static(f"[bold {color}]{self.title}[/]", classes="dialog-title")
|
||||||
|
yield Label(self.message)
|
||||||
|
with Horizontal():
|
||||||
|
yield Button("OK", id="ok", variant="primary")
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
self.dismiss()
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key in ("escape", "enter"):
|
||||||
|
self.dismiss()
|
||||||
140
oai/tui/screens/help_screen.py
Normal file
140
oai/tui/screens/help_screen.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""Help screen for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Vertical
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, Static
|
||||||
|
|
||||||
|
|
||||||
|
class HelpScreen(ModalScreen[None]):
|
||||||
|
"""Modal screen displaying help and commands."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
HelpScreen {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
HelpScreen > Container {
|
||||||
|
width: 90%;
|
||||||
|
height: 85%;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
HelpScreen .header {
|
||||||
|
dock: top;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
HelpScreen .content {
|
||||||
|
height: 1fr;
|
||||||
|
background: #1e1e1e;
|
||||||
|
padding: 2;
|
||||||
|
overflow-y: auto;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
HelpScreen .footer {
|
||||||
|
dock: bottom;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
padding: 1 2;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the screen."""
|
||||||
|
with Container():
|
||||||
|
yield Static("[bold]oAI Help & Commands[/]", classes="header")
|
||||||
|
with Vertical(classes="content"):
|
||||||
|
yield Static(self._get_help_text(), markup=True)
|
||||||
|
with Vertical(classes="footer"):
|
||||||
|
yield Button("Close", id="close", variant="primary")
|
||||||
|
|
||||||
|
def _get_help_text(self) -> str:
|
||||||
|
"""Generate the help text."""
|
||||||
|
return """
|
||||||
|
[bold cyan]═══ KEYBOARD SHORTCUTS ═══[/]
|
||||||
|
[bold]F1[/] Show this help (Ctrl+H may not work)
|
||||||
|
[bold]F2[/] Open model selector (Ctrl+M may not work)
|
||||||
|
[bold]F3[/] Copy last AI response to clipboard
|
||||||
|
[bold]Ctrl+S[/] Show session statistics
|
||||||
|
[bold]Ctrl+L[/] Clear chat display
|
||||||
|
[bold]Ctrl+P[/] Show previous message
|
||||||
|
[bold]Ctrl+N[/] Show next message
|
||||||
|
[bold]Ctrl+Y[/] Copy last AI response (alternative to F3)
|
||||||
|
[bold]Ctrl+Q[/] Quit application
|
||||||
|
[bold]Up/Down[/] Navigate input history
|
||||||
|
[bold]ESC[/] Close dialogs
|
||||||
|
[dim]Note: Some Ctrl keys may be captured by your terminal[/]
|
||||||
|
|
||||||
|
[bold cyan]═══ SLASH COMMANDS ═══[/]
|
||||||
|
[bold yellow]Session Control:[/]
|
||||||
|
/reset Clear conversation history (with confirmation)
|
||||||
|
/clear Clear the chat display
|
||||||
|
/memory on/off Toggle conversation memory
|
||||||
|
/online on/off Toggle online search mode
|
||||||
|
/exit, /quit, /bye Exit the application
|
||||||
|
|
||||||
|
[bold yellow]Model & Configuration:[/]
|
||||||
|
/model [search] Open model selector with optional search
|
||||||
|
/config View configuration settings
|
||||||
|
/config api Set API key (prompts for input)
|
||||||
|
/config stream on Enable streaming responses
|
||||||
|
/system [prompt] Set session system prompt
|
||||||
|
/maxtoken [n] Set session token limit
|
||||||
|
|
||||||
|
[bold yellow]Conversation Management:[/]
|
||||||
|
/save [name] Save current conversation
|
||||||
|
/load [name] Load saved conversation (shows picker if no name)
|
||||||
|
/list List all saved conversations
|
||||||
|
/delete <name> Delete a saved conversation
|
||||||
|
|
||||||
|
[bold yellow]Export:[/]
|
||||||
|
/export md [file] Export as Markdown
|
||||||
|
/export json [file] Export as JSON
|
||||||
|
/export html [file] Export as HTML
|
||||||
|
|
||||||
|
[bold yellow]History Navigation:[/]
|
||||||
|
/prev Show previous message in history
|
||||||
|
/next Show next message in history
|
||||||
|
|
||||||
|
[bold yellow]MCP (Model Context Protocol):[/]
|
||||||
|
/mcp on Enable MCP file access
|
||||||
|
/mcp off Disable MCP
|
||||||
|
/mcp status Show MCP status
|
||||||
|
/mcp add <path> Add folder for file access
|
||||||
|
/mcp list List registered folders
|
||||||
|
/mcp write Toggle write permissions
|
||||||
|
|
||||||
|
[bold yellow]Information & Utilities:[/]
|
||||||
|
/help Show this help screen
|
||||||
|
/stats Show session statistics
|
||||||
|
/credits Check account credits
|
||||||
|
/retry Retry last prompt
|
||||||
|
/paste Paste from clipboard and send
|
||||||
|
|
||||||
|
[bold cyan]═══ TIPS ═══[/]
|
||||||
|
• Type [bold]/[/] to see command suggestions with [bold]Tab[/] to autocomplete
|
||||||
|
• Use [bold]Up/Down arrows[/] to navigate your input history
|
||||||
|
• Type [bold]//[/] at start to escape commands (sends /help as literal message)
|
||||||
|
• All messages support [bold]Markdown formatting[/] with syntax highlighting
|
||||||
|
• Responses stream in real-time for better interactivity
|
||||||
|
• Enable MCP to let AI access your local files and databases
|
||||||
|
• Use [bold]F1[/] or [bold]F2[/] if Ctrl shortcuts don't work in your terminal
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
self.dismiss()
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key in ("escape", "enter"):
|
||||||
|
self.dismiss()
|
||||||
254
oai/tui/screens/model_selector.py
Normal file
254
oai/tui/screens/model_selector.py
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
"""Model selector screen for oAI TUI."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Vertical
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, DataTable, Input, Label, Static
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSelectorScreen(ModalScreen[Optional[dict]]):
|
||||||
|
"""Modal screen for selecting an AI model."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
ModelSelectorScreen {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelSelectorScreen > Container {
|
||||||
|
width: 90%;
|
||||||
|
height: 85%;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
layout: vertical;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelSelectorScreen .header {
|
||||||
|
height: 3;
|
||||||
|
width: 100%;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 2;
|
||||||
|
content-align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelSelectorScreen .search-input {
|
||||||
|
height: 3;
|
||||||
|
width: 100%;
|
||||||
|
background: #2a2a2a;
|
||||||
|
border: solid #555555;
|
||||||
|
margin: 0 0 1 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelSelectorScreen .search-input:focus {
|
||||||
|
border: solid #888888;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelSelectorScreen DataTable {
|
||||||
|
height: 1fr;
|
||||||
|
width: 100%;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelSelectorScreen .footer {
|
||||||
|
height: 5;
|
||||||
|
width: 100%;
|
||||||
|
background: #2d2d2d;
|
||||||
|
padding: 1 2;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelSelectorScreen Button {
|
||||||
|
margin: 0 1;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, models: List[dict], current_model: Optional[str] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.all_models = models
|
||||||
|
self.filtered_models = models
|
||||||
|
self.current_model = current_model
|
||||||
|
self.selected_model: Optional[dict] = None
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the screen."""
|
||||||
|
with Container():
|
||||||
|
yield Static(
|
||||||
|
f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
|
||||||
|
classes="header"
|
||||||
|
)
|
||||||
|
yield Input(placeholder="Search to filter models...", id="search-input", classes="search-input")
|
||||||
|
yield DataTable(id="model-table", cursor_type="row", show_header=True, zebra_stripes=True)
|
||||||
|
with Vertical(classes="footer"):
|
||||||
|
yield Button("Select", id="select", variant="success")
|
||||||
|
yield Button("Cancel", id="cancel", variant="error")
|
||||||
|
|
||||||
|
def on_mount(self) -> None:
|
||||||
|
"""Initialize the table when mounted."""
|
||||||
|
table = self.query_one("#model-table", DataTable)
|
||||||
|
|
||||||
|
# Add columns
|
||||||
|
table.add_column("#", width=5)
|
||||||
|
table.add_column("Model ID", width=35)
|
||||||
|
table.add_column("Name", width=30)
|
||||||
|
table.add_column("Context", width=10)
|
||||||
|
table.add_column("Price", width=12)
|
||||||
|
table.add_column("Img", width=4)
|
||||||
|
table.add_column("Tools", width=6)
|
||||||
|
table.add_column("Online", width=7)
|
||||||
|
|
||||||
|
# Populate table
|
||||||
|
self._populate_table()
|
||||||
|
|
||||||
|
# Focus table if list is small (fits on screen), otherwise focus search
|
||||||
|
if len(self.filtered_models) <= 20:
|
||||||
|
table.focus()
|
||||||
|
else:
|
||||||
|
search_input = self.query_one("#search-input", Input)
|
||||||
|
search_input.focus()
|
||||||
|
|
||||||
|
def _populate_table(self) -> None:
|
||||||
|
"""Populate the table with models."""
|
||||||
|
table = self.query_one("#model-table", DataTable)
|
||||||
|
table.clear()
|
||||||
|
|
||||||
|
rows_added = 0
|
||||||
|
for idx, model in enumerate(self.filtered_models, 1):
|
||||||
|
try:
|
||||||
|
model_id = model.get("id", "")
|
||||||
|
name = model.get("name", "")
|
||||||
|
context = str(model.get("context_length", "N/A"))
|
||||||
|
|
||||||
|
# Format pricing
|
||||||
|
pricing = model.get("pricing", {})
|
||||||
|
prompt_price = pricing.get("prompt", "0")
|
||||||
|
completion_price = pricing.get("completion", "0")
|
||||||
|
|
||||||
|
# Convert to numbers and format
|
||||||
|
try:
|
||||||
|
prompt = float(prompt_price) * 1000000 # Convert to per 1M tokens
|
||||||
|
completion = float(completion_price) * 1000000
|
||||||
|
if prompt == 0 and completion == 0:
|
||||||
|
price = "Free"
|
||||||
|
else:
|
||||||
|
price = f"${prompt:.2f}/${completion:.2f}"
|
||||||
|
except:
|
||||||
|
price = "N/A"
|
||||||
|
|
||||||
|
# Check capabilities
|
||||||
|
architecture = model.get("architecture", {})
|
||||||
|
modality = architecture.get("modality", "")
|
||||||
|
supported_params = model.get("supported_parameters", [])
|
||||||
|
|
||||||
|
# Vision support: check if modality contains "image"
|
||||||
|
supports_vision = "image" in modality
|
||||||
|
|
||||||
|
# Tool support: check if "tools" or "tool_choice" in supported_parameters
|
||||||
|
supports_tools = "tools" in supported_params or "tool_choice" in supported_params
|
||||||
|
|
||||||
|
# Online support: check if model can use :online suffix (most models can)
|
||||||
|
# Models that already have :online in their ID support it
|
||||||
|
supports_online = ":online" in model_id or model_id not in ["openrouter/free"]
|
||||||
|
|
||||||
|
# Format capability indicators
|
||||||
|
img_indicator = "✓" if supports_vision else "-"
|
||||||
|
tools_indicator = "✓" if supports_tools else "-"
|
||||||
|
web_indicator = "✓" if supports_online else "-"
|
||||||
|
|
||||||
|
# Add row
|
||||||
|
table.add_row(
|
||||||
|
str(idx),
|
||||||
|
model_id,
|
||||||
|
name,
|
||||||
|
context,
|
||||||
|
price,
|
||||||
|
img_indicator,
|
||||||
|
tools_indicator,
|
||||||
|
web_indicator,
|
||||||
|
key=str(idx)
|
||||||
|
)
|
||||||
|
rows_added += 1
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# Silently skip rows that fail to add
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_input_changed(self, event: Input.Changed) -> None:
|
||||||
|
"""Filter models based on search input."""
|
||||||
|
if event.input.id != "search-input":
|
||||||
|
return
|
||||||
|
|
||||||
|
search_term = event.value.lower()
|
||||||
|
|
||||||
|
if not search_term:
|
||||||
|
self.filtered_models = self.all_models
|
||||||
|
else:
|
||||||
|
self.filtered_models = [
|
||||||
|
m for m in self.all_models
|
||||||
|
if search_term in m.get("id", "").lower()
|
||||||
|
or search_term in m.get("name", "").lower()
|
||||||
|
]
|
||||||
|
|
||||||
|
self._populate_table()
|
||||||
|
|
||||||
|
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
|
||||||
|
"""Handle row selection (click or arrow navigation)."""
|
||||||
|
try:
|
||||||
|
row_index = int(event.row_key.value) - 1
|
||||||
|
if 0 <= row_index < len(self.filtered_models):
|
||||||
|
self.selected_model = self.filtered_models[row_index]
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_data_table_row_highlighted(self, event) -> None:
|
||||||
|
"""Handle row highlight (arrow key navigation)."""
|
||||||
|
try:
|
||||||
|
table = self.query_one("#model-table", DataTable)
|
||||||
|
if table.cursor_row is not None:
|
||||||
|
row_data = table.get_row_at(table.cursor_row)
|
||||||
|
if row_data:
|
||||||
|
row_index = int(row_data[0]) - 1
|
||||||
|
if 0 <= row_index < len(self.filtered_models):
|
||||||
|
self.selected_model = self.filtered_models[row_index]
|
||||||
|
except (ValueError, IndexError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
if event.button.id == "select":
|
||||||
|
if self.selected_model:
|
||||||
|
self.dismiss(self.selected_model)
|
||||||
|
else:
|
||||||
|
# No selection, dismiss without result
|
||||||
|
self.dismiss(None)
|
||||||
|
else:
|
||||||
|
self.dismiss(None)
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key == "escape":
|
||||||
|
self.dismiss(None)
|
||||||
|
elif event.key == "enter":
|
||||||
|
# If in search input, move to table
|
||||||
|
search_input = self.query_one("#search-input", Input)
|
||||||
|
if search_input.has_focus:
|
||||||
|
table = self.query_one("#model-table", DataTable)
|
||||||
|
table.focus()
|
||||||
|
# If in table or anywhere else, select current row
|
||||||
|
else:
|
||||||
|
table = self.query_one("#model-table", DataTable)
|
||||||
|
# Get the currently highlighted row
|
||||||
|
if table.cursor_row is not None:
|
||||||
|
try:
|
||||||
|
row_key = table.get_row_at(table.cursor_row)
|
||||||
|
if row_key:
|
||||||
|
row_index = int(row_key[0]) - 1
|
||||||
|
if 0 <= row_index < len(self.filtered_models):
|
||||||
|
selected = self.filtered_models[row_index]
|
||||||
|
self.dismiss(selected)
|
||||||
|
except (ValueError, IndexError, AttributeError):
|
||||||
|
# Fall back to previously selected model
|
||||||
|
if self.selected_model:
|
||||||
|
self.dismiss(self.selected_model)
|
||||||
129
oai/tui/screens/stats_screen.py
Normal file
129
oai/tui/screens/stats_screen.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""Statistics screen for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Vertical
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, Static
|
||||||
|
|
||||||
|
from oai.core.session import ChatSession
|
||||||
|
|
||||||
|
|
||||||
|
class StatsScreen(ModalScreen[None]):
|
||||||
|
"""Modal screen displaying session statistics."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
StatsScreen {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatsScreen > Container {
|
||||||
|
width: 70;
|
||||||
|
height: auto;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatsScreen .header {
|
||||||
|
dock: top;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatsScreen .content {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #1e1e1e;
|
||||||
|
padding: 2;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatsScreen .footer {
|
||||||
|
dock: bottom;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
padding: 1 2;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session: ChatSession):
|
||||||
|
super().__init__()
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the screen."""
|
||||||
|
with Container():
|
||||||
|
yield Static("[bold]Session Statistics[/]", classes="header")
|
||||||
|
with Vertical(classes="content"):
|
||||||
|
yield Static(self._get_stats_text(), markup=True)
|
||||||
|
with Vertical(classes="footer"):
|
||||||
|
yield Button("Close", id="close", variant="primary")
|
||||||
|
|
||||||
|
def _get_stats_text(self) -> str:
|
||||||
|
"""Generate the statistics text."""
|
||||||
|
stats = self.session.stats
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
avg_input = stats.total_input_tokens // stats.message_count if stats.message_count > 0 else 0
|
||||||
|
avg_output = stats.total_output_tokens // stats.message_count if stats.message_count > 0 else 0
|
||||||
|
avg_cost = stats.total_cost / stats.message_count if stats.message_count > 0 else 0
|
||||||
|
|
||||||
|
# Get model info
|
||||||
|
model_name = "None"
|
||||||
|
model_context = "N/A"
|
||||||
|
if self.session.selected_model:
|
||||||
|
model_name = self.session.selected_model.get("name", "Unknown")
|
||||||
|
model_context = str(self.session.selected_model.get("context_length", "N/A"))
|
||||||
|
|
||||||
|
# MCP status
|
||||||
|
mcp_status = "Disabled"
|
||||||
|
if self.session.mcp_manager and self.session.mcp_manager.enabled:
|
||||||
|
mode = self.session.mcp_manager.mode
|
||||||
|
if mode == "files":
|
||||||
|
write = " (Write)" if self.session.mcp_manager.write_enabled else ""
|
||||||
|
mcp_status = f"Enabled - Files{write}"
|
||||||
|
elif mode == "database":
|
||||||
|
db_idx = self.session.mcp_manager.selected_db_index
|
||||||
|
if db_idx is not None:
|
||||||
|
db_name = self.session.mcp_manager.databases[db_idx]["name"]
|
||||||
|
mcp_status = f"Enabled - Database ({db_name})"
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
[bold cyan]═══ SESSION INFO ═══[/]
|
||||||
|
[bold]Messages:[/] {stats.message_count}
|
||||||
|
[bold]Current Model:[/] {model_name}
|
||||||
|
[bold]Context Length:[/] {model_context}
|
||||||
|
[bold]Memory:[/] {"Enabled" if self.session.memory_enabled else "Disabled"}
|
||||||
|
[bold]Online Mode:[/] {"Enabled" if self.session.online_enabled else "Disabled"}
|
||||||
|
[bold]MCP:[/] {mcp_status}
|
||||||
|
|
||||||
|
[bold cyan]═══ TOKEN USAGE ═══[/]
|
||||||
|
[bold]Input Tokens:[/] {stats.total_input_tokens:,}
|
||||||
|
[bold]Output Tokens:[/] {stats.total_output_tokens:,}
|
||||||
|
[bold]Total Tokens:[/] {stats.total_tokens:,}
|
||||||
|
|
||||||
|
[bold]Avg Input/Msg:[/] {avg_input:,}
|
||||||
|
[bold]Avg Output/Msg:[/] {avg_output:,}
|
||||||
|
|
||||||
|
[bold cyan]═══ COSTS ═══[/]
|
||||||
|
[bold]Total Cost:[/] ${stats.total_cost:.6f}
|
||||||
|
[bold]Avg Cost/Msg:[/] ${avg_cost:.6f}
|
||||||
|
|
||||||
|
[bold cyan]═══ HISTORY ═══[/]
|
||||||
|
[bold]History Size:[/] {len(self.session.history)} entries
|
||||||
|
[bold]Current Index:[/] {self.session.current_index + 1 if self.session.history else 0}
|
||||||
|
[bold]Memory Start:[/] {self.session.memory_start_index + 1}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
self.dismiss()
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key in ("escape", "enter"):
|
||||||
|
self.dismiss()
|
||||||
174
oai/tui/styles.tcss
Normal file
174
oai/tui/styles.tcss
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
/* Textual CSS for oAI TUI - Using Textual Design System */
|
||||||
|
|
||||||
|
Screen {
|
||||||
|
background: $background;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
Header {
|
||||||
|
dock: top;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 1;
|
||||||
|
border-bottom: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatDisplay {
|
||||||
|
background: $background;
|
||||||
|
border: none;
|
||||||
|
padding: 1 0 1 1; /* top right bottom left - no right padding for scrollbar */
|
||||||
|
scrollbar-background: $background;
|
||||||
|
scrollbar-color: $primary;
|
||||||
|
scrollbar-size: 1 1;
|
||||||
|
scrollbar-gutter: stable; /* Reserve space for scrollbar */
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
UserMessageWidget {
|
||||||
|
margin: 0 0 1 0;
|
||||||
|
padding: 1;
|
||||||
|
background: $surface;
|
||||||
|
border-left: thick $success;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
SystemMessageWidget {
|
||||||
|
margin: 0 0 1 0;
|
||||||
|
padding: 1;
|
||||||
|
background: #2a2a2a;
|
||||||
|
border-left: thick #888888;
|
||||||
|
height: auto;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
AssistantMessageWidget {
|
||||||
|
margin: 0 0 1 0;
|
||||||
|
padding: 1;
|
||||||
|
background: $background;
|
||||||
|
border-left: thick $accent;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
#assistant-label {
|
||||||
|
margin-bottom: 1;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
#assistant-content {
|
||||||
|
height: auto;
|
||||||
|
max-height: 100%;
|
||||||
|
color: #cccccc;
|
||||||
|
link-color: #888888;
|
||||||
|
link-style: none;
|
||||||
|
border: none;
|
||||||
|
scrollbar-background: transparent;
|
||||||
|
scrollbar-color: #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
InputBar {
|
||||||
|
dock: bottom;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
align: center middle;
|
||||||
|
border-top: solid #555555;
|
||||||
|
padding: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#input-prefix {
|
||||||
|
width: auto;
|
||||||
|
padding: 0 1;
|
||||||
|
content-align: center middle;
|
||||||
|
color: #888888;
|
||||||
|
}
|
||||||
|
|
||||||
|
#input-prefix.prefix-hidden {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
#chat-input {
|
||||||
|
width: 85%;
|
||||||
|
height: 5;
|
||||||
|
min-height: 5;
|
||||||
|
background: #3a3a3a;
|
||||||
|
border: none;
|
||||||
|
padding: 1 2;
|
||||||
|
color: #ffffff;
|
||||||
|
content-align: left top;
|
||||||
|
}
|
||||||
|
|
||||||
|
#chat-input:focus {
|
||||||
|
background: #404040;
|
||||||
|
}
|
||||||
|
|
||||||
|
#command-dropdown {
|
||||||
|
display: none;
|
||||||
|
dock: bottom;
|
||||||
|
offset-y: -5;
|
||||||
|
offset-x: 7.5%;
|
||||||
|
height: auto;
|
||||||
|
max-height: 12;
|
||||||
|
width: 85%;
|
||||||
|
background: #2d2d2d;
|
||||||
|
border: solid #555555;
|
||||||
|
padding: 0;
|
||||||
|
layer: overlay;
|
||||||
|
}
|
||||||
|
|
||||||
|
#command-dropdown.visible {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
#command-dropdown #command-list {
|
||||||
|
background: #2d2d2d;
|
||||||
|
border: none;
|
||||||
|
scrollbar-background: #2d2d2d;
|
||||||
|
scrollbar-color: #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
Footer {
|
||||||
|
dock: bottom;
|
||||||
|
height: auto;
|
||||||
|
background: #252525;
|
||||||
|
color: #888888;
|
||||||
|
padding: 0 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Button styles */
|
||||||
|
Button {
|
||||||
|
height: 3;
|
||||||
|
min-width: 10;
|
||||||
|
background: #3a3a3a;
|
||||||
|
color: #cccccc;
|
||||||
|
border: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
Button:hover {
|
||||||
|
background: #4a4a4a;
|
||||||
|
}
|
||||||
|
|
||||||
|
Button:focus {
|
||||||
|
background: #505050;
|
||||||
|
}
|
||||||
|
|
||||||
|
Button.-primary {
|
||||||
|
background: #3a3a3a;
|
||||||
|
}
|
||||||
|
|
||||||
|
Button.-success {
|
||||||
|
background: #2d5016;
|
||||||
|
color: #90ee90;
|
||||||
|
}
|
||||||
|
|
||||||
|
Button.-success:hover {
|
||||||
|
background: #3a6b1e;
|
||||||
|
}
|
||||||
|
|
||||||
|
Button.-error {
|
||||||
|
background: #5a1a1a;
|
||||||
|
color: #ff6b6b;
|
||||||
|
}
|
||||||
|
|
||||||
|
Button.-error:hover {
|
||||||
|
background: #6e2222;
|
||||||
|
}
|
||||||
17
oai/tui/widgets/__init__.py
Normal file
17
oai/tui/widgets/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""TUI widgets for oAI."""
|
||||||
|
|
||||||
|
from oai.tui.widgets.chat_display import ChatDisplay
|
||||||
|
from oai.tui.widgets.footer import Footer
|
||||||
|
from oai.tui.widgets.header import Header
|
||||||
|
from oai.tui.widgets.input_bar import InputBar
|
||||||
|
from oai.tui.widgets.message import AssistantMessageWidget, SystemMessageWidget, UserMessageWidget
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ChatDisplay",
|
||||||
|
"Footer",
|
||||||
|
"Header",
|
||||||
|
"InputBar",
|
||||||
|
"UserMessageWidget",
|
||||||
|
"SystemMessageWidget",
|
||||||
|
"AssistantMessageWidget",
|
||||||
|
]
|
||||||
21
oai/tui/widgets/chat_display.py
Normal file
21
oai/tui/widgets/chat_display.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Chat display widget for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.containers import ScrollableContainer
|
||||||
|
from textual.widgets import Static
|
||||||
|
|
||||||
|
|
||||||
|
class ChatDisplay(ScrollableContainer):
|
||||||
|
"""Scrollable container for chat messages."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(id="chat-display")
|
||||||
|
|
||||||
|
async def add_message(self, widget: Static) -> None:
|
||||||
|
"""Add a message widget to the display."""
|
||||||
|
await self.mount(widget)
|
||||||
|
self.scroll_end(animate=False)
|
||||||
|
|
||||||
|
def clear_messages(self) -> None:
|
||||||
|
"""Clear all messages from the display."""
|
||||||
|
for child in list(self.children):
|
||||||
|
child.remove()
|
||||||
214
oai/tui/widgets/command_dropdown.py
Normal file
214
oai/tui/widgets/command_dropdown.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Command dropdown menu for TUI input."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import VerticalScroll
|
||||||
|
from textual.widget import Widget
|
||||||
|
from textual.widgets import Label, OptionList
|
||||||
|
from textual.widgets.option_list import Option
|
||||||
|
|
||||||
|
from oai.commands import registry
|
||||||
|
|
||||||
|
|
||||||
|
class CommandDropdown(VerticalScroll):
|
||||||
|
"""Dropdown menu showing available commands."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
CommandDropdown {
|
||||||
|
display: none;
|
||||||
|
height: auto;
|
||||||
|
max-height: 12;
|
||||||
|
width: 80;
|
||||||
|
background: #2d2d2d;
|
||||||
|
border: solid #555555;
|
||||||
|
padding: 0;
|
||||||
|
layer: overlay;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandDropdown.visible {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandDropdown OptionList {
|
||||||
|
height: auto;
|
||||||
|
max-height: 12;
|
||||||
|
background: #2d2d2d;
|
||||||
|
border: none;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandDropdown OptionList > .option-list--option {
|
||||||
|
padding: 0 2;
|
||||||
|
color: #cccccc;
|
||||||
|
background: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandDropdown OptionList > .option-list--option-highlighted {
|
||||||
|
background: #3e3e3e;
|
||||||
|
color: #ffffff;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the command dropdown."""
|
||||||
|
super().__init__(id="command-dropdown")
|
||||||
|
self._all_commands = []
|
||||||
|
self._load_commands()
|
||||||
|
|
||||||
|
def _load_commands(self) -> None:
|
||||||
|
"""Load all available commands."""
|
||||||
|
# Get base commands with descriptions
|
||||||
|
base_commands = [
|
||||||
|
("/help", "Show help screen"),
|
||||||
|
("/commands", "Show all commands"),
|
||||||
|
("/model", "Select AI model"),
|
||||||
|
("/provider", "Switch AI provider"),
|
||||||
|
("/provider openrouter", "Switch to OpenRouter"),
|
||||||
|
("/provider anthropic", "Switch to Anthropic (Claude)"),
|
||||||
|
("/provider openai", "Switch to OpenAI (GPT)"),
|
||||||
|
("/provider ollama", "Switch to Ollama (local)"),
|
||||||
|
("/stats", "Show session statistics"),
|
||||||
|
("/credits", "Check account credits or view console link"),
|
||||||
|
("/clear", "Clear chat display"),
|
||||||
|
("/reset", "Reset conversation history"),
|
||||||
|
("/memory on", "Enable conversation memory"),
|
||||||
|
("/memory off", "Disable memory"),
|
||||||
|
("/online on", "Enable online search"),
|
||||||
|
("/online off", "Disable online search"),
|
||||||
|
("/save", "Save current conversation"),
|
||||||
|
("/load", "Load saved conversation"),
|
||||||
|
("/list", "List saved conversations"),
|
||||||
|
("/delete", "Delete a conversation"),
|
||||||
|
("/export md", "Export as Markdown"),
|
||||||
|
("/export json", "Export as JSON"),
|
||||||
|
("/export html", "Export as HTML"),
|
||||||
|
("/prev", "Show previous message"),
|
||||||
|
("/next", "Show next message"),
|
||||||
|
("/config", "View configuration"),
|
||||||
|
("/config provider", "Set default provider"),
|
||||||
|
("/config search_provider", "Set search provider (anthropic_native/duckduckgo/google)"),
|
||||||
|
("/config openrouter_api_key", "Set OpenRouter API key"),
|
||||||
|
("/config anthropic_api_key", "Set Anthropic API key"),
|
||||||
|
("/config openai_api_key", "Set OpenAI API key"),
|
||||||
|
("/config ollama_base_url", "Set Ollama server URL"),
|
||||||
|
("/config google_api_key", "Set Google API key"),
|
||||||
|
("/config google_search_engine_id", "Set Google Search Engine ID"),
|
||||||
|
("/config online", "Set default online mode (on/off)"),
|
||||||
|
("/config stream", "Toggle streaming (on/off)"),
|
||||||
|
("/config model", "Set default model"),
|
||||||
|
("/config system", "Set system prompt"),
|
||||||
|
("/config maxtoken", "Set token limit"),
|
||||||
|
("/system", "Set system prompt"),
|
||||||
|
("/maxtoken", "Set token limit"),
|
||||||
|
("/retry", "Retry last prompt"),
|
||||||
|
("/paste", "Paste from clipboard"),
|
||||||
|
("/mcp on", "Enable MCP file access"),
|
||||||
|
("/mcp off", "Disable MCP"),
|
||||||
|
("/mcp status", "Show MCP status"),
|
||||||
|
("/mcp add", "Add folder/database"),
|
||||||
|
("/mcp remove", "Remove folder/database"),
|
||||||
|
("/mcp list", "List folders"),
|
||||||
|
("/mcp write on", "Enable write mode"),
|
||||||
|
("/mcp write off", "Disable write mode"),
|
||||||
|
("/mcp files", "Switch to file mode"),
|
||||||
|
("/mcp db list", "List databases"),
|
||||||
|
]
|
||||||
|
|
||||||
|
self._all_commands = base_commands
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the dropdown."""
|
||||||
|
yield OptionList(id="command-list")
|
||||||
|
|
||||||
|
def show_commands(self, filter_text: str = "") -> None:
|
||||||
|
"""Show commands matching the filter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_text: Text to filter commands by
|
||||||
|
"""
|
||||||
|
option_list = self.query_one("#command-list", OptionList)
|
||||||
|
option_list.clear_options()
|
||||||
|
|
||||||
|
if not filter_text.startswith("/"):
|
||||||
|
self.remove_class("visible")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove the leading slash for filtering
|
||||||
|
filter_without_slash = filter_text[1:].lower()
|
||||||
|
|
||||||
|
# Filter commands
|
||||||
|
if filter_without_slash:
|
||||||
|
matching = []
|
||||||
|
|
||||||
|
# Check if user typed a parent command (e.g., "/config")
|
||||||
|
# If so, show all sub-commands that start with it
|
||||||
|
parent_command = "/" + filter_without_slash
|
||||||
|
has_subcommands = any(
|
||||||
|
cmd.startswith(parent_command + " ")
|
||||||
|
for cmd, _ in self._all_commands
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_subcommands and parent_command in [cmd for cmd, _ in self._all_commands]:
|
||||||
|
# Show the parent command and all its sub-commands
|
||||||
|
matching = [
|
||||||
|
(cmd, desc) for cmd, desc in self._all_commands
|
||||||
|
if cmd == parent_command or cmd.startswith(parent_command + " ")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Regular filtering - show if filter text is contained anywhere in the command
|
||||||
|
matching = [
|
||||||
|
(cmd, desc) for cmd, desc in self._all_commands
|
||||||
|
if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Show all commands when just "/" is typed
|
||||||
|
matching = self._all_commands
|
||||||
|
|
||||||
|
if not matching:
|
||||||
|
self.remove_class("visible")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add options - limit to 15 results (increased from 10 for sub-commands)
|
||||||
|
for cmd, desc in matching[:15]:
|
||||||
|
# Format: command in white, description in gray, separated by spaces
|
||||||
|
label = f"{cmd} [dim]{desc}[/]" if desc else cmd
|
||||||
|
option_list.add_option(Option(label, id=cmd))
|
||||||
|
|
||||||
|
self.add_class("visible")
|
||||||
|
|
||||||
|
# Auto-select first option
|
||||||
|
if len(option_list._options) > 0:
|
||||||
|
option_list.highlighted = 0
|
||||||
|
|
||||||
|
def hide(self) -> None:
|
||||||
|
"""Hide the dropdown."""
|
||||||
|
self.remove_class("visible")
|
||||||
|
|
||||||
|
def get_selected_command(self) -> str | None:
|
||||||
|
"""Get the currently selected command.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Selected command text or None
|
||||||
|
"""
|
||||||
|
option_list = self.query_one("#command-list", OptionList)
|
||||||
|
if option_list.highlighted is not None:
|
||||||
|
option = option_list.get_option_at_index(option_list.highlighted)
|
||||||
|
return option.id
|
||||||
|
return None
|
||||||
|
|
||||||
|
def move_selection_up(self) -> None:
|
||||||
|
"""Move selection up in the list."""
|
||||||
|
option_list = self.query_one("#command-list", OptionList)
|
||||||
|
if option_list.option_count > 0:
|
||||||
|
if option_list.highlighted is None:
|
||||||
|
option_list.highlighted = option_list.option_count - 1
|
||||||
|
elif option_list.highlighted > 0:
|
||||||
|
option_list.highlighted -= 1
|
||||||
|
|
||||||
|
def move_selection_down(self) -> None:
|
||||||
|
"""Move selection down in the list."""
|
||||||
|
option_list = self.query_one("#command-list", OptionList)
|
||||||
|
if option_list.option_count > 0:
|
||||||
|
if option_list.highlighted is None:
|
||||||
|
option_list.highlighted = 0
|
||||||
|
elif option_list.highlighted < option_list.option_count - 1:
|
||||||
|
option_list.highlighted += 1
|
||||||
58
oai/tui/widgets/command_suggester.py
Normal file
58
oai/tui/widgets/command_suggester.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Command suggester for TUI input."""
|
||||||
|
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from textual.suggester import Suggester
|
||||||
|
|
||||||
|
from oai.commands import registry
|
||||||
|
|
||||||
|
|
||||||
|
class CommandSuggester(Suggester):
|
||||||
|
"""Suggester that provides command completions."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the command suggester."""
|
||||||
|
super().__init__(use_cache=False, case_sensitive=False)
|
||||||
|
# Get all command names from registry
|
||||||
|
self._commands = []
|
||||||
|
self._update_commands()
|
||||||
|
|
||||||
|
def _update_commands(self) -> None:
|
||||||
|
"""Update the list of available commands."""
|
||||||
|
# Get all registered command names
|
||||||
|
command_names = registry.get_all_names()
|
||||||
|
# Add common MCP subcommands for better UX
|
||||||
|
mcp_subcommands = [
|
||||||
|
"/mcp on",
|
||||||
|
"/mcp off",
|
||||||
|
"/mcp status",
|
||||||
|
"/mcp add",
|
||||||
|
"/mcp remove",
|
||||||
|
"/mcp list",
|
||||||
|
"/mcp write on",
|
||||||
|
"/mcp write off",
|
||||||
|
"/mcp files",
|
||||||
|
"/mcp db list",
|
||||||
|
]
|
||||||
|
self._commands = command_names + mcp_subcommands
|
||||||
|
|
||||||
|
async def get_suggestion(self, value: str) -> str | None:
|
||||||
|
"""Get a command suggestion based on the current input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Current input value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Suggested completion or None
|
||||||
|
"""
|
||||||
|
if not value or not value.startswith("/"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find the first command that starts with the input
|
||||||
|
value_lower = value.lower()
|
||||||
|
for cmd in self._commands:
|
||||||
|
if cmd.lower().startswith(value_lower) and cmd.lower() != value_lower:
|
||||||
|
# Return the rest of the command (after what's already typed)
|
||||||
|
return cmd[len(value):]
|
||||||
|
|
||||||
|
return None
|
||||||
39
oai/tui/widgets/footer.py
Normal file
39
oai/tui/widgets/footer.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Footer widget for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.widgets import Static
|
||||||
|
|
||||||
|
|
||||||
|
class Footer(Static):
|
||||||
|
"""Footer displaying session metrics."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.tokens_in = 0
|
||||||
|
self.tokens_out = 0
|
||||||
|
self.cost = 0.0
|
||||||
|
self.messages = 0
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the footer."""
|
||||||
|
yield Static(self._format_footer(), id="footer-content")
|
||||||
|
|
||||||
|
def _format_footer(self) -> str:
|
||||||
|
"""Format the footer text."""
|
||||||
|
return (
|
||||||
|
f"[dim]Messages: {self.messages} | "
|
||||||
|
f"Tokens: {self.tokens_in + self.tokens_out:,} "
|
||||||
|
f"({self.tokens_in:,} in, {self.tokens_out:,} out) | "
|
||||||
|
f"Cost: ${self.cost:.4f}[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_stats(
|
||||||
|
self, tokens_in: int, tokens_out: int, cost: float, messages: int
|
||||||
|
) -> None:
|
||||||
|
"""Update the displayed statistics."""
|
||||||
|
self.tokens_in = tokens_in
|
||||||
|
self.tokens_out = tokens_out
|
||||||
|
self.cost = cost
|
||||||
|
self.messages = messages
|
||||||
|
content = self.query_one("#footer-content", Static)
|
||||||
|
content.update(self._format_footer())
|
||||||
83
oai/tui/widgets/header.py
Normal file
83
oai/tui/widgets/header.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Header widget for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.widgets import Static
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class Header(Static):
|
||||||
|
"""Header displaying app title, version, current model, and capabilities."""
|
||||||
|
|
||||||
|
def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None, provider: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
self.version = version
|
||||||
|
self.model = model
|
||||||
|
self.model_info = model_info or {}
|
||||||
|
self.provider = provider
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the header."""
|
||||||
|
yield Static(self._format_header(), id="header-content")
|
||||||
|
|
||||||
|
def _format_capabilities(self) -> str:
|
||||||
|
"""Format capability icons based on model info."""
|
||||||
|
if not self.model_info:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
icons = []
|
||||||
|
|
||||||
|
# Check vision support
|
||||||
|
architecture = self.model_info.get("architecture", {})
|
||||||
|
modality = architecture.get("modality", "")
|
||||||
|
if "image" in modality:
|
||||||
|
icons.append("[bold cyan]👁️[/]") # Bright if supported
|
||||||
|
else:
|
||||||
|
icons.append("[dim]👁️[/]") # Dim if not supported
|
||||||
|
|
||||||
|
# Check tool support
|
||||||
|
supported_params = self.model_info.get("supported_parameters", [])
|
||||||
|
if "tools" in supported_params or "tool_choice" in supported_params:
|
||||||
|
icons.append("[bold cyan]🔧[/]")
|
||||||
|
else:
|
||||||
|
icons.append("[dim]🔧[/]")
|
||||||
|
|
||||||
|
# Check online support (most models support :online suffix)
|
||||||
|
model_id = self.model_info.get("id", "")
|
||||||
|
if ":online" in model_id or model_id not in ["openrouter/free"]:
|
||||||
|
icons.append("[bold cyan]🌐[/]")
|
||||||
|
else:
|
||||||
|
icons.append("[dim]🌐[/]")
|
||||||
|
|
||||||
|
return " ".join(icons) if icons else ""
|
||||||
|
|
||||||
|
def _format_header(self) -> str:
|
||||||
|
"""Format the header text."""
|
||||||
|
# Show provider : model format
|
||||||
|
if self.provider and self.model:
|
||||||
|
provider_model = f"[bold cyan]{self.provider}[/] [dim]:[/] [bold]{self.model}[/]"
|
||||||
|
elif self.provider:
|
||||||
|
provider_model = f"[bold cyan]{self.provider}[/]"
|
||||||
|
elif self.model:
|
||||||
|
provider_model = f"[bold]{self.model}[/]"
|
||||||
|
else:
|
||||||
|
provider_model = ""
|
||||||
|
|
||||||
|
capabilities = self._format_capabilities()
|
||||||
|
capabilities_text = f" {capabilities}" if capabilities else ""
|
||||||
|
|
||||||
|
# Format: oAI v{version} | provider : model capabilities
|
||||||
|
version_text = f"[bold cyan]oAI[/] [dim]v{self.version}[/]"
|
||||||
|
if provider_model:
|
||||||
|
return f"{version_text} [dim]|[/] {provider_model}{capabilities_text}"
|
||||||
|
else:
|
||||||
|
return version_text
|
||||||
|
|
||||||
|
def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None, provider: Optional[str] = None) -> None:
|
||||||
|
"""Update the displayed model and capabilities."""
|
||||||
|
self.model = model
|
||||||
|
if model_info:
|
||||||
|
self.model_info = model_info
|
||||||
|
if provider is not None:
|
||||||
|
self.provider = provider
|
||||||
|
content = self.query_one("#header-content", Static)
|
||||||
|
content.update(self._format_header())
|
||||||
49
oai/tui/widgets/input_bar.py
Normal file
49
oai/tui/widgets/input_bar.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Input bar widget for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Horizontal
|
||||||
|
from textual.widgets import Input, Static
|
||||||
|
|
||||||
|
|
||||||
|
class InputBar(Horizontal):
|
||||||
|
"""Input bar with prompt prefix and text input."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(id="input-bar")
|
||||||
|
self.mcp_status = ""
|
||||||
|
self.online_mode = False
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the input bar."""
|
||||||
|
yield Static(self._format_prefix(), id="input-prefix", classes="prefix-hidden" if not (self.mcp_status or self.online_mode) else "")
|
||||||
|
yield Input(
|
||||||
|
placeholder="Type a message or /command...",
|
||||||
|
id="chat-input"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_prefix(self) -> str:
|
||||||
|
"""Format the input prefix with status indicators."""
|
||||||
|
indicators = []
|
||||||
|
if self.mcp_status:
|
||||||
|
indicators.append(f"[cyan]{self.mcp_status}[/]")
|
||||||
|
if self.online_mode:
|
||||||
|
indicators.append("[green]🌐[/]")
|
||||||
|
|
||||||
|
prefix = " ".join(indicators) + " " if indicators else ""
|
||||||
|
return f"{prefix}[bold]>[/]"
|
||||||
|
|
||||||
|
def update_mcp_status(self, status: str) -> None:
|
||||||
|
"""Update MCP status indicator."""
|
||||||
|
self.mcp_status = status
|
||||||
|
prefix = self.query_one("#input-prefix", Static)
|
||||||
|
prefix.update(self._format_prefix())
|
||||||
|
|
||||||
|
def update_online_mode(self, online: bool) -> None:
|
||||||
|
"""Update online mode indicator."""
|
||||||
|
self.online_mode = online
|
||||||
|
prefix = self.query_one("#input-prefix", Static)
|
||||||
|
prefix.update(self._format_prefix())
|
||||||
|
|
||||||
|
def get_input(self) -> Input:
|
||||||
|
"""Get the input widget."""
|
||||||
|
return self.query_one("#chat-input", Input)
|
||||||
98
oai/tui/widgets/message.py
Normal file
98
oai/tui/widgets/message.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Message widgets for oAI TUI."""
|
||||||
|
|
||||||
|
from typing import Any, AsyncIterator, Tuple
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.style import Style
|
||||||
|
from rich.theme import Theme
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.widgets import RichLog, Static
|
||||||
|
|
||||||
|
# Custom theme for Markdown rendering - neutral colors matching the dark theme
|
||||||
|
MARKDOWN_THEME = Theme({
|
||||||
|
"markdown.text": Style(color="#cccccc"),
|
||||||
|
"markdown.paragraph": Style(color="#cccccc"),
|
||||||
|
"markdown.code": Style(color="#e0e0e0", bgcolor="#2a2a2a"),
|
||||||
|
"markdown.code_block": Style(color="#e0e0e0", bgcolor="#2a2a2a"),
|
||||||
|
"markdown.heading": Style(color="#ffffff", bold=True),
|
||||||
|
"markdown.h1": Style(color="#ffffff", bold=True),
|
||||||
|
"markdown.h2": Style(color="#eeeeee", bold=True),
|
||||||
|
"markdown.h3": Style(color="#dddddd", bold=True),
|
||||||
|
"markdown.link": Style(color="#aaaaaa", underline=False),
|
||||||
|
"markdown.link_url": Style(color="#888888"),
|
||||||
|
"markdown.emphasis": Style(color="#cccccc", italic=True),
|
||||||
|
"markdown.strong": Style(color="#ffffff", bold=True),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class UserMessageWidget(Static):
|
||||||
|
"""Widget for displaying user messages."""
|
||||||
|
|
||||||
|
def __init__(self, content: str):
|
||||||
|
super().__init__()
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the user message."""
|
||||||
|
yield Static(f"[bold green]You:[/] {self.content}")
|
||||||
|
|
||||||
|
|
||||||
|
class SystemMessageWidget(Static):
|
||||||
|
"""Widget for displaying system/info messages without 'You:' prefix."""
|
||||||
|
|
||||||
|
def __init__(self, content: str):
|
||||||
|
super().__init__()
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the system message."""
|
||||||
|
yield Static(self.content)
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantMessageWidget(Static):
|
||||||
|
"""Widget for displaying assistant responses with streaming support."""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "Assistant", chat_display=None):
|
||||||
|
super().__init__()
|
||||||
|
self.model_name = model_name
|
||||||
|
self.full_text = ""
|
||||||
|
self.chat_display = chat_display
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the assistant message."""
|
||||||
|
yield Static(f"[bold]{self.model_name}:[/]", id="assistant-label")
|
||||||
|
yield RichLog(id="assistant-content", highlight=True, markup=True, wrap=True)
|
||||||
|
|
||||||
|
async def stream_response(self, response_iterator: AsyncIterator) -> Tuple[str, Any]:
|
||||||
|
"""Stream tokens progressively and return final text and usage."""
|
||||||
|
log = self.query_one("#assistant-content", RichLog)
|
||||||
|
self.full_text = ""
|
||||||
|
usage = None
|
||||||
|
|
||||||
|
async for chunk in response_iterator:
|
||||||
|
if hasattr(chunk, "delta_content") and chunk.delta_content:
|
||||||
|
self.full_text += chunk.delta_content
|
||||||
|
log.clear()
|
||||||
|
# Use neutral code theme for syntax highlighting
|
||||||
|
md = Markdown(self.full_text, code_theme="github-dark", inline_code_theme="github-dark")
|
||||||
|
log.write(md)
|
||||||
|
|
||||||
|
# Auto-scroll to keep the latest content visible
|
||||||
|
# Use call_after_refresh to ensure scroll happens after layout update
|
||||||
|
if self.chat_display:
|
||||||
|
self.chat_display.call_after_refresh(self.chat_display.scroll_end, animate=False)
|
||||||
|
|
||||||
|
if hasattr(chunk, "usage") and chunk.usage:
|
||||||
|
usage = chunk.usage
|
||||||
|
|
||||||
|
return self.full_text, usage
|
||||||
|
|
||||||
|
def set_content(self, content: str) -> None:
|
||||||
|
"""Set the complete content (non-streaming)."""
|
||||||
|
self.full_text = content
|
||||||
|
log = self.query_one("#assistant-content", RichLog)
|
||||||
|
log.clear()
|
||||||
|
# Use neutral code theme for syntax highlighting
|
||||||
|
md = Markdown(content, code_theme="github-dark", inline_code_theme="github-dark")
|
||||||
|
log.write(md)
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
"""
|
|
||||||
UI utilities for oAI.
|
|
||||||
|
|
||||||
This module provides rich terminal UI components and display helpers
|
|
||||||
for the chat application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from oai.ui.console import (
|
|
||||||
console,
|
|
||||||
clear_screen,
|
|
||||||
display_panel,
|
|
||||||
display_table,
|
|
||||||
display_markdown,
|
|
||||||
print_error,
|
|
||||||
print_warning,
|
|
||||||
print_success,
|
|
||||||
print_info,
|
|
||||||
)
|
|
||||||
from oai.ui.tables import (
|
|
||||||
create_model_table,
|
|
||||||
create_stats_table,
|
|
||||||
create_help_table,
|
|
||||||
display_paginated_table,
|
|
||||||
)
|
|
||||||
from oai.ui.prompts import (
|
|
||||||
prompt_confirm,
|
|
||||||
prompt_choice,
|
|
||||||
prompt_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# Console utilities
|
|
||||||
"console",
|
|
||||||
"clear_screen",
|
|
||||||
"display_panel",
|
|
||||||
"display_table",
|
|
||||||
"display_markdown",
|
|
||||||
"print_error",
|
|
||||||
"print_warning",
|
|
||||||
"print_success",
|
|
||||||
"print_info",
|
|
||||||
# Table utilities
|
|
||||||
"create_model_table",
|
|
||||||
"create_stats_table",
|
|
||||||
"create_help_table",
|
|
||||||
"display_paginated_table",
|
|
||||||
# Prompt utilities
|
|
||||||
"prompt_confirm",
|
|
||||||
"prompt_choice",
|
|
||||||
"prompt_input",
|
|
||||||
]
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
"""
|
|
||||||
Console utilities for oAI.
|
|
||||||
|
|
||||||
This module provides the Rich console instance and common display functions
|
|
||||||
for formatted terminal output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
# Global console instance for the application
|
|
||||||
console = Console()
|
|
||||||
|
|
||||||
|
|
||||||
def clear_screen() -> None:
|
|
||||||
"""
|
|
||||||
Clear the terminal screen.
|
|
||||||
|
|
||||||
Uses ANSI escape codes for fast clearing, with a fallback
|
|
||||||
for terminals that don't support them.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
print("\033[H\033[J", end="", flush=True)
|
|
||||||
except Exception:
|
|
||||||
# Fallback: print many newlines
|
|
||||||
print("\n" * 100)
|
|
||||||
|
|
||||||
|
|
||||||
def display_panel(
|
|
||||||
content: Any,
|
|
||||||
title: Optional[str] = None,
|
|
||||||
subtitle: Optional[str] = None,
|
|
||||||
border_style: str = "green",
|
|
||||||
title_align: str = "left",
|
|
||||||
subtitle_align: str = "right",
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Display content in a bordered panel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: Content to display (string, Table, or Markdown)
|
|
||||||
title: Optional panel title
|
|
||||||
subtitle: Optional panel subtitle
|
|
||||||
border_style: Border color/style
|
|
||||||
title_align: Title alignment ("left", "center", "right")
|
|
||||||
subtitle_align: Subtitle alignment
|
|
||||||
"""
|
|
||||||
panel = Panel(
|
|
||||||
content,
|
|
||||||
title=title,
|
|
||||||
subtitle=subtitle,
|
|
||||||
border_style=border_style,
|
|
||||||
title_align=title_align,
|
|
||||||
subtitle_align=subtitle_align,
|
|
||||||
)
|
|
||||||
console.print(panel)
|
|
||||||
|
|
||||||
|
|
||||||
def display_table(
|
|
||||||
table: Table,
|
|
||||||
title: Optional[str] = None,
|
|
||||||
subtitle: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Display a table with optional title panel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table: Rich Table to display
|
|
||||||
title: Optional panel title
|
|
||||||
subtitle: Optional panel subtitle
|
|
||||||
"""
|
|
||||||
if title:
|
|
||||||
display_panel(table, title=title, subtitle=subtitle)
|
|
||||||
else:
|
|
||||||
console.print(table)
|
|
||||||
|
|
||||||
|
|
||||||
def display_markdown(
|
|
||||||
content: str,
|
|
||||||
panel: bool = False,
|
|
||||||
title: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Display markdown-formatted content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: Markdown text to display
|
|
||||||
panel: Whether to wrap in a panel
|
|
||||||
title: Optional panel title (if panel=True)
|
|
||||||
"""
|
|
||||||
md = Markdown(content)
|
|
||||||
if panel:
|
|
||||||
display_panel(md, title=title)
|
|
||||||
else:
|
|
||||||
console.print(md)
|
|
||||||
|
|
||||||
|
|
||||||
def print_error(message: str, prefix: str = "Error:") -> None:
|
|
||||||
"""
|
|
||||||
Print an error message in red.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: Error message to display
|
|
||||||
prefix: Prefix before the message (default: "Error:")
|
|
||||||
"""
|
|
||||||
console.print(f"[bold red]{prefix}[/] {message}")
|
|
||||||
|
|
||||||
|
|
||||||
def print_warning(message: str, prefix: str = "Warning:") -> None:
|
|
||||||
"""
|
|
||||||
Print a warning message in yellow.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: Warning message to display
|
|
||||||
prefix: Prefix before the message (default: "Warning:")
|
|
||||||
"""
|
|
||||||
console.print(f"[bold yellow]{prefix}[/] {message}")
|
|
||||||
|
|
||||||
|
|
||||||
def print_success(message: str, prefix: str = "✓") -> None:
|
|
||||||
"""
|
|
||||||
Print a success message in green.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: Success message to display
|
|
||||||
prefix: Prefix before the message (default: "✓")
|
|
||||||
"""
|
|
||||||
console.print(f"[bold green]{prefix}[/] {message}")
|
|
||||||
|
|
||||||
|
|
||||||
def print_info(message: str, dim: bool = False) -> None:
|
|
||||||
"""
|
|
||||||
Print an informational message in cyan.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: Info message to display
|
|
||||||
dim: Whether to dim the message
|
|
||||||
"""
|
|
||||||
if dim:
|
|
||||||
console.print(f"[dim cyan]{message}[/]")
|
|
||||||
else:
|
|
||||||
console.print(f"[bold cyan]{message}[/]")
|
|
||||||
|
|
||||||
|
|
||||||
def print_metrics(
|
|
||||||
tokens: int,
|
|
||||||
cost: float,
|
|
||||||
time_seconds: float,
|
|
||||||
context_info: str = "",
|
|
||||||
online: bool = False,
|
|
||||||
mcp_mode: Optional[str] = None,
|
|
||||||
tool_loops: int = 0,
|
|
||||||
session_tokens: int = 0,
|
|
||||||
session_cost: float = 0.0,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Print formatted metrics for a response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokens: Total tokens used
|
|
||||||
cost: Cost in USD
|
|
||||||
time_seconds: Response time
|
|
||||||
context_info: Context information string
|
|
||||||
online: Whether online mode is active
|
|
||||||
mcp_mode: MCP mode ("files", "database", or None)
|
|
||||||
tool_loops: Number of tool call loops
|
|
||||||
session_tokens: Total session tokens
|
|
||||||
session_cost: Total session cost
|
|
||||||
"""
|
|
||||||
parts = [
|
|
||||||
f"📊 Metrics: {tokens} tokens",
|
|
||||||
f"${cost:.4f}",
|
|
||||||
f"{time_seconds:.2f}s",
|
|
||||||
]
|
|
||||||
|
|
||||||
if context_info:
|
|
||||||
parts.append(context_info)
|
|
||||||
|
|
||||||
if online:
|
|
||||||
parts.append("🌐")
|
|
||||||
|
|
||||||
if mcp_mode == "files":
|
|
||||||
parts.append("🔧")
|
|
||||||
elif mcp_mode == "database":
|
|
||||||
parts.append("🗄️")
|
|
||||||
|
|
||||||
if tool_loops > 0:
|
|
||||||
parts.append(f"({tool_loops} tool loop(s))")
|
|
||||||
|
|
||||||
parts.append(f"Session: {session_tokens} tokens")
|
|
||||||
parts.append(f"${session_cost:.4f}")
|
|
||||||
|
|
||||||
console.print(f"\n[dim blue]{' | '.join(parts)}[/]")
|
|
||||||
|
|
||||||
|
|
||||||
def format_size(size_bytes: int) -> str:
|
|
||||||
"""
|
|
||||||
Format a size in bytes to a human-readable string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
size_bytes: Size in bytes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted size string (e.g., "1.5 MB")
|
|
||||||
"""
|
|
||||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
|
||||||
if abs(size_bytes) < 1024.0:
|
|
||||||
return f"{size_bytes:.1f} {unit}"
|
|
||||||
size_bytes /= 1024.0
|
|
||||||
return f"{size_bytes:.1f} PB"
|
|
||||||
|
|
||||||
|
|
||||||
def format_tokens(tokens: int) -> str:
|
|
||||||
"""
|
|
||||||
Format token count with thousands separators.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokens: Number of tokens
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted token string (e.g., "1,234,567")
|
|
||||||
"""
|
|
||||||
return f"{tokens:,}"
|
|
||||||
|
|
||||||
|
|
||||||
def format_cost(cost: float, precision: int = 4) -> str:
|
|
||||||
"""
|
|
||||||
Format cost in USD.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cost: Cost in dollars
|
|
||||||
precision: Decimal places
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted cost string (e.g., "$0.0123")
|
|
||||||
"""
|
|
||||||
return f"${cost:.{precision}f}"
|
|
||||||
@@ -1,274 +0,0 @@
|
|||||||
"""
|
|
||||||
Prompt utilities for oAI.
|
|
||||||
|
|
||||||
This module provides functions for gathering user input
|
|
||||||
through confirmations, choices, and text prompts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Optional, TypeVar
|
|
||||||
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from oai.ui.console import console
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_confirm(
|
|
||||||
message: str,
|
|
||||||
default: bool = False,
|
|
||||||
abort: bool = False,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Prompt the user for a yes/no confirmation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The question to ask
|
|
||||||
default: Default value if user presses Enter
|
|
||||||
abort: Whether to abort on "no" response
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if user confirms, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return typer.confirm(message, default=default, abort=abort)
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
console.print("\n[yellow]Cancelled[/]")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_choice(
|
|
||||||
message: str,
|
|
||||||
choices: List[str],
|
|
||||||
default: Optional[str] = None,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Prompt the user to select from a list of choices.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The question to ask
|
|
||||||
choices: List of valid choices
|
|
||||||
default: Default choice if user presses Enter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Selected choice or None if cancelled
|
|
||||||
"""
|
|
||||||
# Display choices
|
|
||||||
console.print(f"\n[bold cyan]{message}[/]")
|
|
||||||
for i, choice in enumerate(choices, 1):
|
|
||||||
default_marker = " [default]" if choice == default else ""
|
|
||||||
console.print(f" {i}. {choice}{default_marker}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = input("\nEnter number or value: ").strip()
|
|
||||||
|
|
||||||
if not response and default:
|
|
||||||
return default
|
|
||||||
|
|
||||||
# Try as number first
|
|
||||||
try:
|
|
||||||
index = int(response) - 1
|
|
||||||
if 0 <= index < len(choices):
|
|
||||||
return choices[index]
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try as exact match
|
|
||||||
if response in choices:
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Try case-insensitive match
|
|
||||||
response_lower = response.lower()
|
|
||||||
for choice in choices:
|
|
||||||
if choice.lower() == response_lower:
|
|
||||||
return choice
|
|
||||||
|
|
||||||
console.print(f"[red]Invalid choice: {response}[/]")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
console.print("\n[yellow]Cancelled[/]")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_input(
|
|
||||||
message: str,
|
|
||||||
default: Optional[str] = None,
|
|
||||||
password: bool = False,
|
|
||||||
required: bool = False,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Prompt the user for text input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The prompt message
|
|
||||||
default: Default value if user presses Enter
|
|
||||||
password: Whether to hide input (for sensitive data)
|
|
||||||
required: Whether input is required (loops until provided)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User input or default, None if cancelled
|
|
||||||
"""
|
|
||||||
prompt_text = message
|
|
||||||
if default:
|
|
||||||
prompt_text += f" [{default}]"
|
|
||||||
prompt_text += ": "
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
if password:
|
|
||||||
import getpass
|
|
||||||
|
|
||||||
response = getpass.getpass(prompt_text)
|
|
||||||
else:
|
|
||||||
response = input(prompt_text).strip()
|
|
||||||
|
|
||||||
if not response:
|
|
||||||
if default:
|
|
||||||
return default
|
|
||||||
if required:
|
|
||||||
console.print("[yellow]Input required[/]")
|
|
||||||
continue
|
|
||||||
return None
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
console.print("\n[yellow]Cancelled[/]")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_number(
|
|
||||||
message: str,
|
|
||||||
min_value: Optional[int] = None,
|
|
||||||
max_value: Optional[int] = None,
|
|
||||||
default: Optional[int] = None,
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""
|
|
||||||
Prompt the user for a numeric input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The prompt message
|
|
||||||
min_value: Minimum allowed value
|
|
||||||
max_value: Maximum allowed value
|
|
||||||
default: Default value if user presses Enter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Integer value or None if cancelled
|
|
||||||
"""
|
|
||||||
prompt_text = message
|
|
||||||
if default is not None:
|
|
||||||
prompt_text += f" [{default}]"
|
|
||||||
prompt_text += ": "
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
response = input(prompt_text).strip()
|
|
||||||
|
|
||||||
if not response:
|
|
||||||
if default is not None:
|
|
||||||
return default
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
value = int(response)
|
|
||||||
except ValueError:
|
|
||||||
console.print("[red]Please enter a valid number[/]")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if min_value is not None and value < min_value:
|
|
||||||
console.print(f"[red]Value must be at least {min_value}[/]")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if max_value is not None and value > max_value:
|
|
||||||
console.print(f"[red]Value must be at most {max_value}[/]")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
console.print("\n[yellow]Cancelled[/]")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_selection(
|
|
||||||
items: List[T],
|
|
||||||
message: str = "Select an item",
|
|
||||||
display_func: Optional[callable] = None,
|
|
||||||
allow_cancel: bool = True,
|
|
||||||
) -> Optional[T]:
|
|
||||||
"""
|
|
||||||
Prompt the user to select an item from a list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
items: List of items to choose from
|
|
||||||
message: The selection prompt
|
|
||||||
display_func: Function to convert item to display string
|
|
||||||
allow_cancel: Whether to allow cancellation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Selected item or None if cancelled
|
|
||||||
"""
|
|
||||||
if not items:
|
|
||||||
console.print("[yellow]No items to select[/]")
|
|
||||||
return None
|
|
||||||
|
|
||||||
display = display_func or str
|
|
||||||
|
|
||||||
console.print(f"\n[bold cyan]{message}[/]")
|
|
||||||
for i, item in enumerate(items, 1):
|
|
||||||
console.print(f" {i}. {display(item)}")
|
|
||||||
|
|
||||||
if allow_cancel:
|
|
||||||
console.print(f" 0. Cancel")
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
response = input("\nEnter number: ").strip()
|
|
||||||
|
|
||||||
try:
|
|
||||||
index = int(response)
|
|
||||||
except ValueError:
|
|
||||||
console.print("[red]Please enter a valid number[/]")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if allow_cancel and index == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if 1 <= index <= len(items):
|
|
||||||
return items[index - 1]
|
|
||||||
|
|
||||||
console.print(f"[red]Please enter a number between 1 and {len(items)}[/]")
|
|
||||||
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
console.print("\n[yellow]Cancelled[/]")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_copy_response(response: str) -> bool:
|
|
||||||
"""
|
|
||||||
Prompt user to copy a response to clipboard.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: The response text
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if copied, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
copy_choice = input("💾 Type 'c' to copy response, or press Enter to continue: ").strip().lower()
|
|
||||||
if copy_choice == "c":
|
|
||||||
try:
|
|
||||||
import pyperclip
|
|
||||||
|
|
||||||
pyperclip.copy(response)
|
|
||||||
console.print("[bold green]✅ Response copied to clipboard![/]")
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
console.print("[yellow]pyperclip not installed - cannot copy to clipboard[/]")
|
|
||||||
except Exception as e:
|
|
||||||
console.print(f"[red]Failed to copy: {e}[/]")
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return False
|
|
||||||
373
oai/ui/tables.py
373
oai/ui/tables.py
@@ -1,373 +0,0 @@
|
|||||||
"""
|
|
||||||
Table utilities for oAI.
|
|
||||||
|
|
||||||
This module provides functions for creating and displaying
|
|
||||||
formatted tables with pagination support.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
|
|
||||||
from oai.ui.console import clear_screen, console
|
|
||||||
|
|
||||||
|
|
||||||
def create_model_table(
|
|
||||||
models: List[Dict[str, Any]],
|
|
||||||
show_capabilities: bool = True,
|
|
||||||
) -> Table:
|
|
||||||
"""
|
|
||||||
Create a table displaying available AI models.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
models: List of model dictionaries
|
|
||||||
show_capabilities: Whether to show capability columns
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rich Table with model information
|
|
||||||
"""
|
|
||||||
if show_capabilities:
|
|
||||||
table = Table(
|
|
||||||
"No.",
|
|
||||||
"Model ID",
|
|
||||||
"Context",
|
|
||||||
"Image",
|
|
||||||
"Online",
|
|
||||||
"Tools",
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold magenta",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
table = Table(
|
|
||||||
"No.",
|
|
||||||
"Model ID",
|
|
||||||
"Context",
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold magenta",
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, model in enumerate(models, 1):
|
|
||||||
model_id = model.get("id", "Unknown")
|
|
||||||
context = model.get("context_length", 0)
|
|
||||||
context_str = f"{context:,}" if context else "-"
|
|
||||||
|
|
||||||
if show_capabilities:
|
|
||||||
# Get modalities and parameters
|
|
||||||
architecture = model.get("architecture", {})
|
|
||||||
input_modalities = architecture.get("input_modalities", [])
|
|
||||||
supported_params = model.get("supported_parameters", [])
|
|
||||||
|
|
||||||
has_image = "✓" if "image" in input_modalities else "-"
|
|
||||||
has_online = "✓" if "tools" in supported_params else "-"
|
|
||||||
has_tools = "✓" if "tools" in supported_params or "functions" in supported_params else "-"
|
|
||||||
|
|
||||||
table.add_row(
|
|
||||||
str(i),
|
|
||||||
model_id,
|
|
||||||
context_str,
|
|
||||||
has_image,
|
|
||||||
has_online,
|
|
||||||
has_tools,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
table.add_row(str(i), model_id, context_str)
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def create_stats_table(stats: Dict[str, Any]) -> Table:
|
|
||||||
"""
|
|
||||||
Create a table displaying session statistics.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stats: Dictionary with statistics data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rich Table with stats
|
|
||||||
"""
|
|
||||||
table = Table(
|
|
||||||
"Metric",
|
|
||||||
"Value",
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold magenta",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Token stats
|
|
||||||
if "input_tokens" in stats:
|
|
||||||
table.add_row("Input Tokens", f"{stats['input_tokens']:,}")
|
|
||||||
if "output_tokens" in stats:
|
|
||||||
table.add_row("Output Tokens", f"{stats['output_tokens']:,}")
|
|
||||||
if "total_tokens" in stats:
|
|
||||||
table.add_row("Total Tokens", f"{stats['total_tokens']:,}")
|
|
||||||
|
|
||||||
# Cost stats
|
|
||||||
if "total_cost" in stats:
|
|
||||||
table.add_row("Total Cost", f"${stats['total_cost']:.4f}")
|
|
||||||
if "avg_cost" in stats:
|
|
||||||
table.add_row("Avg Cost/Message", f"${stats['avg_cost']:.4f}")
|
|
||||||
|
|
||||||
# Message stats
|
|
||||||
if "message_count" in stats:
|
|
||||||
table.add_row("Messages", str(stats["message_count"]))
|
|
||||||
|
|
||||||
# Credits
|
|
||||||
if "credits_left" in stats:
|
|
||||||
table.add_row("Credits Left", stats["credits_left"])
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def create_help_table(commands: Dict[str, Dict[str, str]]) -> Table:
|
|
||||||
"""
|
|
||||||
Create a help table for commands.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
commands: Dictionary of command info
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rich Table with command help
|
|
||||||
"""
|
|
||||||
table = Table(
|
|
||||||
"Command",
|
|
||||||
"Description",
|
|
||||||
"Example",
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold magenta",
|
|
||||||
show_lines=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
for cmd, info in commands.items():
|
|
||||||
description = info.get("description", "")
|
|
||||||
example = info.get("example", "")
|
|
||||||
table.add_row(cmd, description, example)
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def create_folder_table(
|
|
||||||
folders: List[Dict[str, Any]],
|
|
||||||
gitignore_info: str = "",
|
|
||||||
) -> Table:
|
|
||||||
"""
|
|
||||||
Create a table for MCP folder listing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
folders: List of folder dictionaries
|
|
||||||
gitignore_info: Optional gitignore status info
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rich Table with folder information
|
|
||||||
"""
|
|
||||||
table = Table(
|
|
||||||
"No.",
|
|
||||||
"Path",
|
|
||||||
"Files",
|
|
||||||
"Size",
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold magenta",
|
|
||||||
)
|
|
||||||
|
|
||||||
for folder in folders:
|
|
||||||
number = str(folder.get("number", ""))
|
|
||||||
path = folder.get("path", "")
|
|
||||||
|
|
||||||
if folder.get("exists", True):
|
|
||||||
files = f"📁 {folder.get('file_count', 0)}"
|
|
||||||
size = f"{folder.get('size_mb', 0):.1f} MB"
|
|
||||||
else:
|
|
||||||
files = "[red]Not found[/red]"
|
|
||||||
size = "-"
|
|
||||||
|
|
||||||
table.add_row(number, path, files, size)
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def create_database_table(databases: List[Dict[str, Any]]) -> Table:
|
|
||||||
"""
|
|
||||||
Create a table for MCP database listing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
databases: List of database dictionaries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rich Table with database information
|
|
||||||
"""
|
|
||||||
table = Table(
|
|
||||||
"No.",
|
|
||||||
"Name",
|
|
||||||
"Tables",
|
|
||||||
"Size",
|
|
||||||
"Status",
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold magenta",
|
|
||||||
)
|
|
||||||
|
|
||||||
for db in databases:
|
|
||||||
number = str(db.get("number", ""))
|
|
||||||
name = db.get("name", "")
|
|
||||||
table_count = f"{db.get('table_count', 0)} tables"
|
|
||||||
size = f"{db.get('size_mb', 0):.1f} MB"
|
|
||||||
|
|
||||||
if db.get("warning"):
|
|
||||||
status = f"[red]{db['warning']}[/red]"
|
|
||||||
else:
|
|
||||||
status = "[green]✓[/green]"
|
|
||||||
|
|
||||||
table.add_row(number, name, table_count, size, status)
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def display_paginated_table(
|
|
||||||
table: Table,
|
|
||||||
title: str,
|
|
||||||
terminal_height: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Display a table with pagination for large datasets.
|
|
||||||
|
|
||||||
Allows navigating through pages with keyboard input.
|
|
||||||
Press SPACE for next page, any other key to exit.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table: Rich Table to display
|
|
||||||
title: Title for the table
|
|
||||||
terminal_height: Override terminal height (auto-detected if None)
|
|
||||||
"""
|
|
||||||
# Get terminal dimensions
|
|
||||||
try:
|
|
||||||
term_height = terminal_height or os.get_terminal_size().lines - 8
|
|
||||||
except OSError:
|
|
||||||
term_height = 20
|
|
||||||
|
|
||||||
# Render table to segments
|
|
||||||
from rich.segment import Segment
|
|
||||||
|
|
||||||
segments = list(console.render(table))
|
|
||||||
|
|
||||||
# Group segments into lines
|
|
||||||
current_line_segments: List[Segment] = []
|
|
||||||
all_lines: List[List[Segment]] = []
|
|
||||||
|
|
||||||
for segment in segments:
|
|
||||||
if segment.text == "\n":
|
|
||||||
all_lines.append(current_line_segments)
|
|
||||||
current_line_segments = []
|
|
||||||
else:
|
|
||||||
current_line_segments.append(segment)
|
|
||||||
|
|
||||||
if current_line_segments:
|
|
||||||
all_lines.append(current_line_segments)
|
|
||||||
|
|
||||||
total_lines = len(all_lines)
|
|
||||||
|
|
||||||
# If table fits in one screen, just display it
|
|
||||||
if total_lines <= term_height:
|
|
||||||
console.print(Panel(table, title=title, title_align="left"))
|
|
||||||
return
|
|
||||||
|
|
||||||
# Extract header and footer lines
|
|
||||||
header_lines: List[List[Segment]] = []
|
|
||||||
data_lines: List[List[Segment]] = []
|
|
||||||
footer_line: List[Segment] = []
|
|
||||||
|
|
||||||
# Find header end (line after the header text with border)
|
|
||||||
header_end_index = 0
|
|
||||||
found_header_text = False
|
|
||||||
|
|
||||||
for i, line_segments in enumerate(all_lines):
|
|
||||||
has_header_style = any(
|
|
||||||
seg.style and ("bold" in str(seg.style) or "magenta" in str(seg.style))
|
|
||||||
for seg in line_segments
|
|
||||||
)
|
|
||||||
|
|
||||||
if has_header_style:
|
|
||||||
found_header_text = True
|
|
||||||
|
|
||||||
if found_header_text and i > 0:
|
|
||||||
line_text = "".join(seg.text for seg in line_segments)
|
|
||||||
if any(char in line_text for char in ["─", "━", "┼", "╪", "┤", "├"]):
|
|
||||||
header_end_index = i
|
|
||||||
break
|
|
||||||
|
|
||||||
# Extract footer (bottom border)
|
|
||||||
if all_lines:
|
|
||||||
last_line_text = "".join(seg.text for seg in all_lines[-1])
|
|
||||||
if any(char in last_line_text for char in ["─", "━", "┴", "╧", "┘", "└"]):
|
|
||||||
footer_line = all_lines[-1]
|
|
||||||
all_lines = all_lines[:-1]
|
|
||||||
|
|
||||||
# Split into header and data
|
|
||||||
if header_end_index > 0:
|
|
||||||
header_lines = all_lines[: header_end_index + 1]
|
|
||||||
data_lines = all_lines[header_end_index + 1 :]
|
|
||||||
else:
|
|
||||||
header_lines = all_lines[: min(3, len(all_lines))]
|
|
||||||
data_lines = all_lines[min(3, len(all_lines)) :]
|
|
||||||
|
|
||||||
lines_per_page = term_height - len(header_lines)
|
|
||||||
current_line = 0
|
|
||||||
page_number = 1
|
|
||||||
|
|
||||||
# Paginate
|
|
||||||
while current_line < len(data_lines):
|
|
||||||
clear_screen()
|
|
||||||
console.print(f"[bold cyan]{title} (Page {page_number})[/]")
|
|
||||||
|
|
||||||
# Print header
|
|
||||||
for line_segments in header_lines:
|
|
||||||
for segment in line_segments:
|
|
||||||
console.print(segment.text, style=segment.style, end="")
|
|
||||||
console.print()
|
|
||||||
|
|
||||||
# Print data rows for this page
|
|
||||||
end_line = min(current_line + lines_per_page, len(data_lines))
|
|
||||||
for line_segments in data_lines[current_line:end_line]:
|
|
||||||
for segment in line_segments:
|
|
||||||
console.print(segment.text, style=segment.style, end="")
|
|
||||||
console.print()
|
|
||||||
|
|
||||||
# Print footer
|
|
||||||
if footer_line:
|
|
||||||
for segment in footer_line:
|
|
||||||
console.print(segment.text, style=segment.style, end="")
|
|
||||||
console.print()
|
|
||||||
|
|
||||||
current_line = end_line
|
|
||||||
page_number += 1
|
|
||||||
|
|
||||||
# Prompt for next page
|
|
||||||
if current_line < len(data_lines):
|
|
||||||
console.print(
|
|
||||||
f"\n[dim yellow]--- Press SPACE for next page, "
|
|
||||||
f"or any other key to finish (Page {page_number - 1}, "
|
|
||||||
f"showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import termios
|
|
||||||
import tty
|
|
||||||
|
|
||||||
fd = sys.stdin.fileno()
|
|
||||||
old_settings = termios.tcgetattr(fd)
|
|
||||||
|
|
||||||
try:
|
|
||||||
tty.setraw(fd)
|
|
||||||
char = sys.stdin.read(1)
|
|
||||||
if char != " ":
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
|
||||||
|
|
||||||
except (ImportError, OSError, AttributeError):
|
|
||||||
# Fallback for non-Unix systems
|
|
||||||
try:
|
|
||||||
user_input = input()
|
|
||||||
if user_input.strip():
|
|
||||||
break
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
break
|
|
||||||
247
oai/utils/web_search.py
Normal file
247
oai/utils/web_search.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""
|
||||||
|
Web search utilities for oAI.
|
||||||
|
|
||||||
|
Provides web search capabilities for all providers (not just OpenRouter).
|
||||||
|
Uses DuckDuckGo by default (no API key needed).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from oai.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchResult:
|
||||||
|
"""Container for a single search result."""
|
||||||
|
|
||||||
|
def __init__(self, title: str, url: str, snippet: str):
|
||||||
|
self.title = title
|
||||||
|
self.url = url
|
||||||
|
self.snippet = snippet
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"WebSearchResult(title='{self.title}', url='{self.url}')"
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchProvider:
|
||||||
|
"""Base class for web search providers."""
|
||||||
|
|
||||||
|
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
|
||||||
|
"""
|
||||||
|
Perform a web search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query
|
||||||
|
num_results: Number of results to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class DuckDuckGoSearch(WebSearchProvider):
|
||||||
|
"""DuckDuckGo search provider (no API key needed)."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.session = requests.Session()
|
||||||
|
self.session.headers.update({
|
||||||
|
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36'
|
||||||
|
})
|
||||||
|
|
||||||
|
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
|
||||||
|
"""
|
||||||
|
Search using DuckDuckGo HTML interface.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query
|
||||||
|
num_results: Number of results to return (default: 5)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use DuckDuckGo HTML search
|
||||||
|
url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"
|
||||||
|
response = self.session.get(url, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
html = response.text
|
||||||
|
|
||||||
|
# Parse results using regex (simple HTML parsing)
|
||||||
|
# Find all result blocks - they end at next result or end of results section
|
||||||
|
result_blocks = re.findall(
|
||||||
|
r'<div class="result results_links.*?(?=<div class="result results_links|<div id="links")',
|
||||||
|
html,
|
||||||
|
re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
|
for block in result_blocks[:num_results]:
|
||||||
|
# Extract title and URL - look for result__a class
|
||||||
|
title_match = re.search(r'<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>([^<]+)</a>', block)
|
||||||
|
# Extract snippet - look for result__snippet class
|
||||||
|
snippet_match = re.search(r'<a[^>]*class="result__snippet"[^>]*>([^<]+)</a>', block)
|
||||||
|
|
||||||
|
if title_match:
|
||||||
|
url_raw = title_match.group(1)
|
||||||
|
title = title_match.group(2).strip()
|
||||||
|
|
||||||
|
# Decode HTML entities in title
|
||||||
|
import html as html_module
|
||||||
|
title = html_module.unescape(title)
|
||||||
|
|
||||||
|
snippet = ""
|
||||||
|
if snippet_match:
|
||||||
|
snippet = snippet_match.group(1).strip()
|
||||||
|
snippet = html_module.unescape(snippet)
|
||||||
|
|
||||||
|
# Clean up URL (DDG uses redirect links)
|
||||||
|
if 'uddg=' in url_raw:
|
||||||
|
# Extract actual URL from redirect
|
||||||
|
actual_url_match = re.search(r'uddg=([^&]+)', url_raw)
|
||||||
|
if actual_url_match:
|
||||||
|
from urllib.parse import unquote
|
||||||
|
url_raw = unquote(actual_url_match.group(1))
|
||||||
|
|
||||||
|
results.append(WebSearchResult(
|
||||||
|
title=title,
|
||||||
|
url=url_raw,
|
||||||
|
snippet=snippet
|
||||||
|
))
|
||||||
|
|
||||||
|
logger.info(f"DuckDuckGo search: found {len(results)} results for '{query}'")
|
||||||
|
return results
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"DuckDuckGo search failed: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing DuckDuckGo results: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleCustomSearch(WebSearchProvider):
|
||||||
|
"""Google Custom Search API provider (requires API key)."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, search_engine_id: str):
|
||||||
|
"""
|
||||||
|
Initialize Google Custom Search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Google API key
|
||||||
|
search_engine_id: Custom Search Engine ID
|
||||||
|
"""
|
||||||
|
self.api_key = api_key
|
||||||
|
self.search_engine_id = search_engine_id
|
||||||
|
|
||||||
|
def search(self, query: str, num_results: int = 5) -> List[WebSearchResult]:
|
||||||
|
"""
|
||||||
|
Search using Google Custom Search API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query
|
||||||
|
num_results: Number of results to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
url = "https://www.googleapis.com/customsearch/v1"
|
||||||
|
params = {
|
||||||
|
'key': self.api_key,
|
||||||
|
'cx': self.search_engine_id,
|
||||||
|
'q': query,
|
||||||
|
'num': min(num_results, 10) # Google allows max 10
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(url, params=params, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for item in data.get('items', []):
|
||||||
|
results.append(WebSearchResult(
|
||||||
|
title=item.get('title', ''),
|
||||||
|
url=item.get('link', ''),
|
||||||
|
snippet=item.get('snippet', '')
|
||||||
|
))
|
||||||
|
|
||||||
|
logger.info(f"Google Custom Search: found {len(results)} results for '{query}'")
|
||||||
|
return results
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Google Custom Search failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def perform_web_search(
|
||||||
|
query: str,
|
||||||
|
num_results: int = 5,
|
||||||
|
provider: str = "duckduckgo",
|
||||||
|
**kwargs
|
||||||
|
) -> List[WebSearchResult]:
|
||||||
|
"""
|
||||||
|
Perform a web search using the specified provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query
|
||||||
|
num_results: Number of results to return (default: 5)
|
||||||
|
provider: Search provider ("duckduckgo" or "google")
|
||||||
|
**kwargs: Provider-specific arguments (e.g., api_key for Google)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results
|
||||||
|
"""
|
||||||
|
if provider == "google":
|
||||||
|
api_key = kwargs.get("google_api_key")
|
||||||
|
search_engine_id = kwargs.get("google_search_engine_id")
|
||||||
|
if not api_key or not search_engine_id:
|
||||||
|
logger.warning("Google search requires api_key and search_engine_id, falling back to DuckDuckGo")
|
||||||
|
provider = "duckduckgo"
|
||||||
|
|
||||||
|
if provider == "google":
|
||||||
|
search_provider = GoogleCustomSearch(api_key, search_engine_id)
|
||||||
|
else:
|
||||||
|
search_provider = DuckDuckGoSearch()
|
||||||
|
|
||||||
|
return search_provider.search(query, num_results)
|
||||||
|
|
||||||
|
|
||||||
|
def format_search_results(results: List[WebSearchResult], max_length: int = 2000) -> str:
|
||||||
|
"""
|
||||||
|
Format search results for inclusion in AI prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of search results
|
||||||
|
max_length: Maximum total length of formatted results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string with search results
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return "No search results found."
|
||||||
|
|
||||||
|
formatted = "**Web Search Results:**\n\n"
|
||||||
|
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
result_text = f"{i}. **{result.title}**\n"
|
||||||
|
result_text += f" URL: {result.url}\n"
|
||||||
|
if result.snippet:
|
||||||
|
result_text += f" {result.snippet}\n"
|
||||||
|
result_text += "\n"
|
||||||
|
|
||||||
|
# Check if adding this result would exceed max_length
|
||||||
|
if len(formatted) + len(result_text) > max_length:
|
||||||
|
formatted += f"... ({len(results) - i + 1} more results truncated)\n"
|
||||||
|
break
|
||||||
|
|
||||||
|
formatted += result_text
|
||||||
|
|
||||||
|
return formatted.strip()
|
||||||
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "oai"
|
name = "oai"
|
||||||
version = "2.1.0"
|
version = "3.0.0-b3" # MUST match oai/__init__.py __version__
|
||||||
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
|
description = "Open AI Chat Client - Multi-provider terminal chat with MCP support"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
authors = [
|
authors = [
|
||||||
@@ -39,15 +39,17 @@ classifiers = [
|
|||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyio>=4.0.0",
|
"anyio>=4.0.0",
|
||||||
|
"anthropic>=0.40.0",
|
||||||
"click>=8.0.0",
|
"click>=8.0.0",
|
||||||
"httpx>=0.24.0",
|
"httpx>=0.24.0",
|
||||||
"markdown-it-py>=3.0.0",
|
"markdown-it-py>=3.0.0",
|
||||||
|
"openai>=1.59.0",
|
||||||
"openrouter>=0.0.19",
|
"openrouter>=0.0.19",
|
||||||
"packaging>=21.0",
|
"packaging>=21.0",
|
||||||
"prompt-toolkit>=3.0.0",
|
|
||||||
"pyperclip>=1.8.0",
|
"pyperclip>=1.8.0",
|
||||||
"requests>=2.28.0",
|
"requests>=2.28.0",
|
||||||
"rich>=13.0.0",
|
"rich>=13.0.0",
|
||||||
|
"textual>=0.50.0",
|
||||||
"typer>=0.9.0",
|
"typer>=0.9.0",
|
||||||
"mcp>=1.0.0",
|
"mcp>=1.0.0",
|
||||||
]
|
]
|
||||||
@@ -73,7 +75,7 @@ Documentation = "https://iurl.no/oai"
|
|||||||
oai = "oai.cli:main"
|
oai = "oai.cli:main"
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.ui", "oai.utils"]
|
packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.tui", "oai.tui.widgets", "oai.tui.screens", "oai.utils"]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
oai = ["py.typed"]
|
oai = ["py.typed"]
|
||||||
|
|||||||
Reference in New Issue
Block a user