Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e7c49bf68 |
7
.gitignore
vendored
7
.gitignore
vendored
@@ -23,9 +23,6 @@ Pipfile.lock # Consider if you want to include or exclude
|
||||
*~.nib
|
||||
*~.xib
|
||||
|
||||
# Claude Code local settings
|
||||
.claude/
|
||||
|
||||
# Added by author
|
||||
*.zip
|
||||
.note
|
||||
@@ -42,7 +39,3 @@ b0.sh
|
||||
*.old
|
||||
*.sh
|
||||
*.back
|
||||
requirements.txt
|
||||
system_prompt.txt
|
||||
CLAUDE*
|
||||
SESSION*_COMPLETE.md
|
||||
|
||||
633
README.md
633
README.md
@@ -1,326 +1,529 @@
|
||||
# oAI - OpenRouter AI Chat Client
|
||||
# oAI - OpenRouter AI Chat
|
||||
|
||||
A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
|
||||
A powerful terminal-based chat interface for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI agents to access local files and query SQLite databases directly.
|
||||
|
||||
## Description
|
||||
|
||||
oAI is a feature-rich command-line chat application that provides an interactive interface to OpenRouter's AI models. It now includes **MCP integration** for local file system access and read-only database querying, allowing AI to help with code analysis, data exploration, and more.
|
||||
|
||||
## Features
|
||||
|
||||
### Core Features
|
||||
- 🖥️ **Modern Textual TUI** with async streaming and beautiful interface
|
||||
- 🤖 Interactive chat with 300+ AI models via OpenRouter
|
||||
- 🔍 Model selection with search, filtering, and capability icons
|
||||
- 🔍 Model selection with search and capability filtering
|
||||
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
|
||||
- 📎 File attachments (images, PDFs, code files)
|
||||
- 💰 Real-time cost tracking and credit monitoring
|
||||
- 🎨 Dark theme with syntax highlighting and Markdown rendering
|
||||
- 📝 Command history navigation (Up/Down arrows)
|
||||
- 📎 File attachment support (images, PDFs, code files)
|
||||
- 💰 Session cost tracking and credit monitoring
|
||||
- 🎨 Rich terminal formatting with syntax highlighting
|
||||
- 📝 Persistent command history with search (Ctrl+R)
|
||||
- ⚙️ Configurable system prompts and token limits
|
||||
- 🗄️ SQLite-based configuration and conversation storage
|
||||
- 🌐 Online mode (web search capabilities)
|
||||
- 🧠 Conversation memory toggle
|
||||
- ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats)
|
||||
- 🧠 Conversation memory toggle (save costs with stateless mode)
|
||||
|
||||
### MCP Integration
|
||||
- 🔧 **File Mode**: AI can read, search, and list local files
|
||||
### NEW: MCP (Model Context Protocol) v2.1.0-beta
|
||||
- 🔧 **File Mode**: AI can read, search, and list your local files
|
||||
- Automatic .gitignore filtering
|
||||
- Virtual environment exclusion
|
||||
- Virtual environment exclusion (venv, node_modules, etc.)
|
||||
- Supports code files, text, JSON, YAML, and more
|
||||
- Large file handling (auto-truncates >50KB)
|
||||
|
||||
- 🗄️ **Database Mode**: AI can query your SQLite databases
|
||||
- Read-only access (no data modification possible)
|
||||
- Schema inspection (tables, columns, indexes)
|
||||
- Full-text search across all tables
|
||||
- SQL query execution (SELECT, JOINs, CTEs, subqueries)
|
||||
- Query validation and timeout protection
|
||||
- Result limiting (max 1000 rows)
|
||||
|
||||
- ✍️ **Write Mode**: AI can modify files with permission
|
||||
- Create, edit, delete files
|
||||
- Move, copy, organize files
|
||||
- Always requires explicit opt-in
|
||||
|
||||
- 🗄️ **Database Mode**: AI can query SQLite databases
|
||||
- Read-only access (safe)
|
||||
- Schema inspection
|
||||
- Full SQL query support
|
||||
- 🔒 **Security Features**:
|
||||
- Explicit folder/database approval required
|
||||
- System directory blocking
|
||||
- Read-only database access
|
||||
- SQL injection protection
|
||||
- Query timeout (5 seconds)
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.10-3.13
|
||||
- OpenRouter API key ([get one here](https://openrouter.ai))
|
||||
- Python 3.10-3.13 (3.14 not supported yet)
|
||||
- OpenRouter API key (get one at https://openrouter.ai)
|
||||
- Function-calling model required for MCP features (GPT-4, Claude, etc.)
|
||||
|
||||
## Screenshot
|
||||
|
||||
[<img src="https://gitlab.pm/rune/oai/raw/branch/main/images/screenshot_01.png">](https://gitlab.pm/rune/oai/src/branch/main/README.md)
|
||||
|
||||
*Screenshot from version 1.0 - MCP interface shows mode indicators like `[🔧 MCP: Files]` or `[🗄️ MCP: DB #1]`*
|
||||
|
||||
## Installation
|
||||
|
||||
### Option 1: Pre-built Binary (macOS/Linux) (Recommended)
|
||||
### Option 1: From Source (Recommended for Development)
|
||||
|
||||
Download from [Releases](https://gitlab.pm/rune/oai/releases):
|
||||
- **macOS (Apple Silicon)**: `oai_v3.0.0_mac_arm64.zip`
|
||||
- **Linux (x86_64)**: `oai_v3.0.0_linux_x86_64.zip`
|
||||
#### 1. Install Dependencies
|
||||
|
||||
```bash
|
||||
# Extract and install
|
||||
unzip oai_v3.0.0_*.zip
|
||||
mkdir -p ~/.local/bin
|
||||
mv oai ~/.local/bin/
|
||||
|
||||
# macOS only: Remove quarantine and approve
|
||||
xattr -cr ~/.local/bin/oai
|
||||
# Then right-click oai in Finder → Open With → Terminal → Click "Open"
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Add to PATH
|
||||
#### 2. Make Executable
|
||||
|
||||
```bash
|
||||
# Add to ~/.zshrc or ~/.bashrc
|
||||
chmod +x oai.py
|
||||
```
|
||||
|
||||
#### 3. Copy to PATH
|
||||
|
||||
```bash
|
||||
# Option 1: System-wide (requires sudo)
|
||||
sudo cp oai.py /usr/local/bin/oai
|
||||
|
||||
# Option 2: User-local (recommended)
|
||||
mkdir -p ~/.local/bin
|
||||
cp oai.py ~/.local/bin/oai
|
||||
|
||||
# Add to PATH if needed (add to ~/.bashrc or ~/.zshrc)
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
```
|
||||
|
||||
|
||||
### Option 2: Install from Source
|
||||
#### 4. Verify Installation
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://gitlab.pm/rune/oai.git
|
||||
cd oai
|
||||
|
||||
# Install with pip
|
||||
pip install -e .
|
||||
oai --version
|
||||
```
|
||||
|
||||
### Option 2: Pre-built Binaries
|
||||
|
||||
Download platform-specific binaries:
|
||||
- **macOS (Apple Silicon)**: `oai_vx.x.x_mac_arm64.zip`
|
||||
- **Linux (x86_64)**: `oai_vx.x.x-linux-x86_64.zip`
|
||||
|
||||
```bash
|
||||
# Extract and install
|
||||
unzip oai_vx.x.x_mac_arm64.zip # or `oai_vx.x.x-linux-x86_64.zip`
|
||||
chmod +x oai
|
||||
mkdir -p ~/.local/bin
|
||||
mv oai ~/.local/bin/
|
||||
```
|
||||
|
||||
### Option 3: Build Your Own Binary
|
||||
|
||||
```bash
|
||||
# Install build dependencies
|
||||
pip install -r requirements.txt
|
||||
pip install nuitka ordered-set zstandard
|
||||
|
||||
# Run build script
|
||||
chmod +x build.sh
|
||||
./build.sh
|
||||
|
||||
# Binary will be in dist/oai
|
||||
cp dist/oai ~/.local/bin/
|
||||
```
|
||||
|
||||
### Alternative: Shell Alias
|
||||
|
||||
```bash
|
||||
# Add to ~/.bashrc or ~/.zshrc
|
||||
alias oai='python3 /path/to/oai.py'
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### First Run Setup
|
||||
|
||||
```bash
|
||||
# Start oAI (launches TUI)
|
||||
oai
|
||||
```
|
||||
|
||||
On first run, you'll be prompted to enter your OpenRouter API key.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
# Start chatting
|
||||
oai
|
||||
|
||||
# Or with options
|
||||
oai --model gpt-4o --online --mcp
|
||||
# Select a model
|
||||
You> /model
|
||||
|
||||
# Show version
|
||||
oai version
|
||||
# Enable MCP for file access
|
||||
You> /mcp enable
|
||||
You> /mcp add ~/Documents
|
||||
|
||||
# Ask AI to help with files
|
||||
[🔧 MCP: Files] You> List all Python files in Documents
|
||||
[🔧 MCP: Files] You> Read and explain main.py
|
||||
|
||||
# Switch to database mode
|
||||
You> /mcp add db ~/myapp/data.db
|
||||
You> /mcp db 1
|
||||
[🗄️ MCP: DB #1] You> Show me all tables
|
||||
[🗄️ MCP: DB #1] You> Find all users created this month
|
||||
```
|
||||
|
||||
On first run, you'll be prompted for your OpenRouter API key.
|
||||
## MCP Guide
|
||||
|
||||
### Basic Commands
|
||||
### File Mode (Default)
|
||||
|
||||
**Setup:**
|
||||
```bash
|
||||
# In the TUI interface:
|
||||
/model # Select AI model (or press F2)
|
||||
/help # Show all commands (or press F1)
|
||||
/mcp on # Enable file/database access
|
||||
/stats # View session statistics (or press Ctrl+S)
|
||||
/config # View configuration settings
|
||||
/credits # Check account credits
|
||||
Ctrl+Q # Quit
|
||||
/mcp enable # Start MCP server
|
||||
/mcp add ~/Projects # Grant access to folder
|
||||
/mcp add ~/Documents # Add another folder
|
||||
/mcp list # View all allowed folders
|
||||
```
|
||||
|
||||
## MCP (Model Context Protocol)
|
||||
|
||||
MCP allows the AI to interact with your local files and databases.
|
||||
|
||||
### File Access
|
||||
|
||||
```bash
|
||||
/mcp on # Enable MCP
|
||||
/mcp add ~/Projects # Grant access to folder
|
||||
/mcp list # View allowed folders
|
||||
|
||||
# Now ask the AI:
|
||||
**Natural Language Usage:**
|
||||
```
|
||||
"List all Python files in Projects"
|
||||
"Read and explain main.py"
|
||||
"Read and explain config.yaml"
|
||||
"Search for files containing 'TODO'"
|
||||
"What's in my Documents folder?"
|
||||
```
|
||||
|
||||
### Write Mode
|
||||
**Available Tools:**
|
||||
- `read_file` - Read complete file contents
|
||||
- `list_directory` - List files/folders (recursive optional)
|
||||
- `search_files` - Search by name or content
|
||||
|
||||
```bash
|
||||
/mcp write on # Enable file modifications
|
||||
|
||||
# AI can now:
|
||||
"Create a new file called utils.py"
|
||||
"Edit config.json and update the API URL"
|
||||
"Delete the old backup files" # Always asks for confirmation
|
||||
```
|
||||
**Features:**
|
||||
- ✅ Automatic .gitignore filtering
|
||||
- ✅ Skips virtual environments (venv, node_modules)
|
||||
- ✅ Handles large files (auto-truncates >50KB)
|
||||
- ✅ Cross-platform (macOS, Linux, Windows via WSL)
|
||||
|
||||
### Database Mode
|
||||
|
||||
**Setup:**
|
||||
```bash
|
||||
/mcp add db ~/app/data.db # Add database
|
||||
/mcp db 1 # Switch to database mode
|
||||
/mcp add db ~/app/database.db # Add SQLite database
|
||||
/mcp db list # View all databases
|
||||
/mcp db 1 # Switch to database #1
|
||||
```
|
||||
|
||||
# Ask the AI:
|
||||
"Show all tables"
|
||||
"Find users created this month"
|
||||
"What's the schema for the orders table?"
|
||||
**Natural Language Usage:**
|
||||
```
|
||||
"Show me all tables in this database"
|
||||
"Find records mentioning 'error'"
|
||||
"How many users registered last week?"
|
||||
"Get the schema for the orders table"
|
||||
"Show me the 10 most recent transactions"
|
||||
```
|
||||
|
||||
**Available Tools:**
|
||||
- `inspect_database` - View schema, tables, columns, indexes
|
||||
- `search_database` - Full-text search across tables
|
||||
- `query_database` - Execute read-only SQL queries
|
||||
|
||||
**Supported Queries:**
|
||||
- ✅ SELECT statements
|
||||
- ✅ JOINs (INNER, LEFT, RIGHT, FULL)
|
||||
- ✅ Subqueries
|
||||
- ✅ CTEs (Common Table Expressions)
|
||||
- ✅ Aggregations (COUNT, SUM, AVG, etc.)
|
||||
- ✅ WHERE, GROUP BY, HAVING, ORDER BY, LIMIT
|
||||
- ❌ INSERT/UPDATE/DELETE (blocked for safety)
|
||||
|
||||
### Mode Management
|
||||
|
||||
```bash
|
||||
/mcp status # Show current mode, stats, folders/databases
|
||||
/mcp files # Switch to file mode
|
||||
/mcp db <number> # Switch to database mode
|
||||
/mcp gitignore on # Enable .gitignore filtering (default)
|
||||
/mcp remove 2 # Remove folder/database by number
|
||||
```
|
||||
|
||||
## Command Reference
|
||||
|
||||
### Chat Commands
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help [cmd]` | Show help |
|
||||
| `/model [search]` | Select model |
|
||||
| `/info [model]` | Model details |
|
||||
| `/memory on\|off` | Toggle context |
|
||||
| `/online on\|off` | Toggle web search |
|
||||
| `/retry` | Resend last message |
|
||||
| `/clear` | Clear screen |
|
||||
### Session Commands
|
||||
```
|
||||
/help [command] Show help menu or detailed command help
|
||||
/help mcp Comprehensive MCP guide
|
||||
/clear or /cl Clear terminal screen (or Ctrl+L)
|
||||
/memory on|off Toggle conversation memory (save costs)
|
||||
/online on|off Enable/disable web search
|
||||
/paste [prompt] Paste clipboard content
|
||||
/retry Resend last prompt
|
||||
/reset Clear history and system prompt
|
||||
/prev View previous response
|
||||
/next View next response
|
||||
```
|
||||
|
||||
### MCP Commands
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/mcp on\|off` | Enable/disable MCP |
|
||||
| `/mcp status` | Show MCP status |
|
||||
| `/mcp add <path>` | Add folder |
|
||||
| `/mcp add db <path>` | Add database |
|
||||
| `/mcp list` | List folders |
|
||||
| `/mcp db list` | List databases |
|
||||
| `/mcp db <n>` | Switch to database |
|
||||
| `/mcp files` | Switch to file mode |
|
||||
| `/mcp write on\|off` | Toggle write mode |
|
||||
```
|
||||
/mcp enable Start MCP server
|
||||
/mcp disable Stop MCP server
|
||||
/mcp status Show comprehensive status
|
||||
/mcp add <folder> Add folder for file access
|
||||
/mcp add db <path> Add SQLite database
|
||||
/mcp list List all folders
|
||||
/mcp db list List all databases
|
||||
/mcp db <number> Switch to database mode
|
||||
/mcp files Switch to file mode
|
||||
/mcp remove <num> Remove folder/database
|
||||
/mcp gitignore on Enable .gitignore filtering
|
||||
```
|
||||
|
||||
### Conversation Commands
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/save <name>` | Save conversation |
|
||||
| `/load <name>` | Load conversation |
|
||||
| `/list` | List saved conversations |
|
||||
| `/delete <name>` | Delete conversation |
|
||||
| `/export md\|json\|html <file>` | Export |
|
||||
### Model Commands
|
||||
```
|
||||
/model [search] Select/change AI model
|
||||
/info [model_id] Show model details (pricing, capabilities)
|
||||
```
|
||||
|
||||
### Configuration
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/config` | View settings |
|
||||
| `/config api` | Set API key |
|
||||
| `/config model <id>` | Set default model |
|
||||
| `/config stream on\|off` | Toggle streaming |
|
||||
| `/stats` | Session statistics |
|
||||
| `/credits` | Check credits |
|
||||
```
|
||||
/config View all settings
|
||||
/config api Set API key
|
||||
/config model Set default model
|
||||
/config online Set default online mode (on|off)
|
||||
/config stream Enable/disable streaming (on|off)
|
||||
/config maxtoken Set max token limit
|
||||
/config costwarning Set cost warning threshold ($)
|
||||
/config loglevel Set log level (debug/info/warning/error)
|
||||
/config log Set log file size (MB)
|
||||
```
|
||||
|
||||
## CLI Options
|
||||
### Conversation Management
|
||||
```
|
||||
/save <name> Save conversation
|
||||
/load <name|num> Load saved conversation
|
||||
/delete <name|num> Delete conversation
|
||||
/list List saved conversations
|
||||
/export md|json|html <file> Export conversation
|
||||
```
|
||||
|
||||
### Token & System
|
||||
```
|
||||
/maxtoken [value] Set session token limit
|
||||
/system [prompt] Set system prompt (use 'clear' to reset)
|
||||
/middleout on|off Enable prompt compression
|
||||
```
|
||||
|
||||
### Monitoring
|
||||
```
|
||||
/stats View session statistics
|
||||
/credits Check OpenRouter credits
|
||||
```
|
||||
|
||||
### File Attachments
|
||||
```
|
||||
@/path/to/file Attach file (images, PDFs, code)
|
||||
|
||||
Examples:
|
||||
Debug @script.py
|
||||
Analyze @data.json
|
||||
Review @screenshot.png
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
All configuration stored in `~/.config/oai/`:
|
||||
|
||||
### Files
|
||||
- `oai_config.db` - SQLite database (settings, conversations, MCP config)
|
||||
- `oai.log` - Application logs (rotating, configurable size)
|
||||
- `history.txt` - Command history (searchable with Ctrl+R)
|
||||
|
||||
### Key Settings
|
||||
- **API Key**: OpenRouter authentication
|
||||
- **Default Model**: Auto-select on startup
|
||||
- **Streaming**: Real-time response display
|
||||
- **Max Tokens**: Global and session limits
|
||||
- **Cost Warning**: Alert threshold for expensive requests
|
||||
- **Online Mode**: Default web search setting
|
||||
- **Log Level**: debug/info/warning/error/critical
|
||||
- **Log Size**: Rotating file size in MB
|
||||
|
||||
## Supported File Types
|
||||
|
||||
### Code Files
|
||||
`.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm`
|
||||
|
||||
### Data Files
|
||||
`.json, .yaml, .yml, .xml, .csv, .txt, .md`
|
||||
|
||||
### Images
|
||||
All standard formats: PNG, JPEG, JPG, GIF, WEBP, BMP
|
||||
|
||||
### Documents
|
||||
PDF (models with document support)
|
||||
|
||||
### Size Limits
|
||||
- Images: 10 MB max
|
||||
- Code/Text: Auto-truncates files >50KB
|
||||
- Binary data: Displayed as `<binary: X bytes>`
|
||||
|
||||
## MCP Security
|
||||
|
||||
### Access Control
|
||||
- ✅ Explicit folder/database approval required
|
||||
- ✅ System directories blocked automatically
|
||||
- ✅ User confirmation for each addition
|
||||
- ✅ .gitignore patterns respected (file mode)
|
||||
|
||||
### Database Safety
|
||||
- ✅ Read-only mode (cannot modify data)
|
||||
- ✅ SQL query validation (blocks INSERT/UPDATE/DELETE)
|
||||
- ✅ Query timeout (5 seconds max)
|
||||
- ✅ Result limits (1000 rows max)
|
||||
- ✅ Database opened in `mode=ro`
|
||||
|
||||
### File System Safety
|
||||
- ✅ Read-only access (no write/delete)
|
||||
- ✅ Virtual environment exclusion
|
||||
- ✅ Build artifact filtering
|
||||
- ✅ Maximum file size (10 MB)
|
||||
|
||||
## Tips & Tricks
|
||||
|
||||
### Command History
|
||||
- **↑/↓ arrows**: Navigate previous commands
|
||||
- **Ctrl+R**: Search command history
|
||||
- **Auto-complete**: Start typing `/` for command suggestions
|
||||
|
||||
### Cost Optimization
|
||||
```bash
|
||||
oai [OPTIONS]
|
||||
|
||||
Options:
|
||||
-m, --model TEXT Model ID to use
|
||||
-s, --system TEXT System prompt
|
||||
-o, --online Enable online mode
|
||||
--mcp Enable MCP server
|
||||
-v, --version Show version
|
||||
--help Show help
|
||||
/memory off # Disable context (stateless mode)
|
||||
/maxtoken 1000 # Limit response length
|
||||
/config costwarning 0.01 # Set alert threshold
|
||||
```
|
||||
|
||||
Commands:
|
||||
### MCP Best Practices
|
||||
```bash
|
||||
oai # Launch TUI (default)
|
||||
oai version # Show version information
|
||||
oai --help # Show help message
|
||||
# Check status frequently
|
||||
/mcp status
|
||||
|
||||
# Use specific paths to reduce search time
|
||||
"List Python files in Projects/app/" # Better than
|
||||
"List all Python files" # Slower
|
||||
|
||||
# Database queries - be specific
|
||||
"SELECT * FROM users LIMIT 10" # Good
|
||||
"SELECT * FROM users" # May hit row limit
|
||||
```
|
||||
|
||||
## Configuration
|
||||
### Debugging
|
||||
```bash
|
||||
# Enable debug logging
|
||||
/config loglevel debug
|
||||
|
||||
Configuration is stored in `~/.config/oai/`:
|
||||
# Check log file
|
||||
tail -f ~/.config/oai/oai.log
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `oai_config.db` | Settings, conversations, MCP config |
|
||||
| `oai.log` | Application logs |
|
||||
| `history.txt` | Command history |
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
oai/
|
||||
├── oai/
|
||||
│ ├── __init__.py
|
||||
│ ├── __main__.py # Entry point for python -m oai
|
||||
│ ├── cli.py # Main CLI entry point
|
||||
│ ├── constants.py # Configuration constants
|
||||
│ ├── commands/ # Slash command handlers
|
||||
│ ├── config/ # Settings and database
|
||||
│ ├── core/ # Chat client and session
|
||||
│ ├── mcp/ # MCP server and tools
|
||||
│ ├── providers/ # AI provider abstraction
|
||||
│ ├── tui/ # Textual TUI interface
|
||||
│ │ ├── app.py # Main TUI application
|
||||
│ │ ├── widgets/ # Custom widgets
|
||||
│ │ ├── screens/ # Modal screens
|
||||
│ │ └── styles.tcss # TUI styling
|
||||
│ └── utils/ # Logging, export, etc.
|
||||
├── pyproject.toml # Package configuration
|
||||
├── build.sh # Binary build script
|
||||
└── README.md
|
||||
# View MCP statistics
|
||||
/mcp status # Shows tool call counts
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### macOS Binary Issues
|
||||
|
||||
```bash
|
||||
# Remove quarantine attribute
|
||||
xattr -cr ~/.local/bin/oai
|
||||
|
||||
# Then in Finder: right-click oai → Open With → Terminal → Click "Open"
|
||||
# After this, oai works from any terminal
|
||||
```
|
||||
|
||||
### MCP Not Working
|
||||
|
||||
```bash
|
||||
# Check if model supports function calling
|
||||
# 1. Check if MCP is installed
|
||||
python3 -c "import mcp; print('MCP OK')"
|
||||
|
||||
# 2. Verify model supports function calling
|
||||
/info # Look for "tools" in supported parameters
|
||||
|
||||
# Check MCP status
|
||||
# 3. Check MCP status
|
||||
/mcp status
|
||||
|
||||
# View logs
|
||||
tail -f ~/.config/oai/oai.log
|
||||
# 4. Review logs
|
||||
tail ~/.config/oai/oai.log
|
||||
```
|
||||
|
||||
### Import Errors
|
||||
|
||||
```bash
|
||||
# Reinstall package
|
||||
pip install -e . --force-reinstall
|
||||
# Reinstall dependencies
|
||||
pip install --force-reinstall -r requirements.txt
|
||||
```
|
||||
|
||||
### Binary Issues (macOS)
|
||||
```bash
|
||||
# Remove quarantine
|
||||
xattr -cr ~/.local/bin/oai
|
||||
|
||||
# Check security settings
|
||||
# System Settings > Privacy & Security > "Allow Anyway"
|
||||
```
|
||||
|
||||
### Database Errors
|
||||
```bash
|
||||
# Verify it's a valid SQLite database
|
||||
sqlite3 database.db ".tables"
|
||||
|
||||
# Check file permissions
|
||||
ls -la database.db
|
||||
```
|
||||
|
||||
## Version History
|
||||
|
||||
### v3.0.0 (Current)
|
||||
- 🎨 **Complete migration to Textual TUI** - Modern async terminal interface
|
||||
- 🗑️ **Removed CLI interface** - TUI-only for cleaner codebase (11.6% smaller)
|
||||
- 🖱️ **Modal screens** - Help, stats, config, credits, model selector
|
||||
- ⌨️ **Keyboard shortcuts** - F1 (help), F2 (models), Ctrl+S (stats), etc.
|
||||
- 🎯 **Capability indicators** - Visual icons for model features (vision, tools, online)
|
||||
- 🎨 **Consistent dark theme** - Professional styling throughout
|
||||
- 📊 **Enhanced model selector** - Search, filter, capability columns
|
||||
- 🚀 **Default command** - Just run `oai` to launch TUI
|
||||
- 🧹 **Code cleanup** - Removed 1,300+ lines of CLI code
|
||||
### v2.1.0-beta (Current)
|
||||
- ✨ **NEW**: MCP (Model Context Protocol) integration
|
||||
- ✨ **NEW**: File system access (read, search, list)
|
||||
- ✨ **NEW**: SQLite database querying (read-only)
|
||||
- ✨ **NEW**: Dual mode support (Files & Database)
|
||||
- ✨ **NEW**: .gitignore filtering
|
||||
- ✨ **NEW**: Binary data handling in databases
|
||||
- ✨ **NEW**: Mode indicators in prompt
|
||||
- ✨ **NEW**: Comprehensive `/help mcp` guide
|
||||
- 🔧 Improved error handling for tool calls
|
||||
- 🔧 Enhanced logging for MCP operations
|
||||
- 🔧 Statistics tracking for tool usage
|
||||
|
||||
### v2.1.0
|
||||
- 🏗️ Complete codebase refactoring to modular package structure
|
||||
- 🔌 Extensible provider architecture for adding new AI providers
|
||||
- 📦 Proper Python packaging with pyproject.toml
|
||||
- ✨ MCP integration (file access, write mode, database queries)
|
||||
- 🔧 Command registry pattern for slash commands
|
||||
- 📊 Improved cost tracking and session statistics
|
||||
|
||||
### v1.9.x
|
||||
- Single-file implementation
|
||||
- Core chat functionality
|
||||
- File attachments
|
||||
### v1.9.6
|
||||
- Base version with core chat functionality
|
||||
- Conversation management
|
||||
- File attachments
|
||||
- Cost tracking
|
||||
- Export capabilities
|
||||
|
||||
## License
|
||||
|
||||
MIT License - See [LICENSE](LICENSE) for details.
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024-2025 Rune Olsen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Full license: https://opensource.org/licenses/MIT
|
||||
|
||||
## Author
|
||||
|
||||
**Rune Olsen**
|
||||
|
||||
- Blog: https://blog.rune.pm
|
||||
- Project: https://iurl.no/oai
|
||||
- Repository: https://gitlab.pm/rune/oai
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions welcome! Please:
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Submit a pull request
|
||||
3. Submit a pull request with detailed description
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
- OpenRouter team for the unified AI API
|
||||
- Rich library for beautiful terminal output
|
||||
- MCP community for the protocol specification
|
||||
|
||||
---
|
||||
|
||||
**⭐ Star this project if you find it useful!**
|
||||
**Star ⭐ this project if you find it useful!**
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
oAI - OpenRouter AI Chat Client
|
||||
|
||||
A feature-rich terminal-based chat application that provides an interactive CLI
|
||||
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
|
||||
integration for filesystem and database access.
|
||||
|
||||
Author: Rune
|
||||
License: MIT
|
||||
"""
|
||||
|
||||
__version__ = "3.0.0-b2"
|
||||
__author__ = "Rune"
|
||||
__license__ = "MIT"
|
||||
|
||||
# Lazy imports to avoid circular dependencies and improve startup time
|
||||
# Full imports are available via submodules:
|
||||
# from oai.config import Settings, Database
|
||||
# from oai.providers import OpenRouterProvider, AIProvider
|
||||
# from oai.mcp import MCPManager
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"__author__",
|
||||
"__license__",
|
||||
]
|
||||
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
Entry point for running oAI as a module: python -m oai
|
||||
"""
|
||||
|
||||
from oai.cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
199
oai/cli.py
199
oai/cli.py
@@ -1,199 +0,0 @@
|
||||
"""
|
||||
Main CLI entry point for oAI.
|
||||
|
||||
This module provides the command-line interface for the oAI TUI application.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
|
||||
from oai import __version__
|
||||
from oai.commands import register_all_commands
|
||||
from oai.config.settings import Settings
|
||||
from oai.constants import APP_URL, APP_VERSION
|
||||
from oai.core.client import AIClient
|
||||
from oai.core.session import ChatSession
|
||||
from oai.mcp.manager import MCPManager
|
||||
from oai.utils.logging import LoggingManager, get_logger
|
||||
|
||||
# Create Typer app
|
||||
app = typer.Typer(
|
||||
name="oai",
|
||||
help=f"oAI - OpenRouter AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
|
||||
add_completion=False,
|
||||
epilog="For more information, visit: " + APP_URL,
|
||||
)
|
||||
|
||||
|
||||
@app.callback(invoke_without_command=True)
|
||||
def main_callback(
|
||||
ctx: typer.Context,
|
||||
version_flag: bool = typer.Option(
|
||||
False,
|
||||
"--version",
|
||||
"-v",
|
||||
help="Show version information",
|
||||
is_flag=True,
|
||||
),
|
||||
model: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--model",
|
||||
"-m",
|
||||
help="Model ID to use",
|
||||
),
|
||||
system: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--system",
|
||||
"-s",
|
||||
help="System prompt",
|
||||
),
|
||||
online: bool = typer.Option(
|
||||
False,
|
||||
"--online",
|
||||
"-o",
|
||||
help="Enable online mode",
|
||||
),
|
||||
mcp: bool = typer.Option(
|
||||
False,
|
||||
"--mcp",
|
||||
help="Enable MCP server",
|
||||
),
|
||||
) -> None:
|
||||
"""Main callback - launches TUI by default."""
|
||||
if version_flag:
|
||||
typer.echo(f"oAI version {APP_VERSION}")
|
||||
raise typer.Exit()
|
||||
|
||||
# If no subcommand provided, launch TUI
|
||||
if ctx.invoked_subcommand is None:
|
||||
_launch_tui(model, system, online, mcp)
|
||||
|
||||
|
||||
def _launch_tui(
|
||||
model: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
online: bool = False,
|
||||
mcp: bool = False,
|
||||
) -> None:
|
||||
"""Launch the Textual TUI interface."""
|
||||
# Setup logging
|
||||
logging_manager = LoggingManager()
|
||||
logging_manager.setup()
|
||||
logger = get_logger()
|
||||
|
||||
# Load settings
|
||||
settings = Settings.load()
|
||||
|
||||
# Check API key
|
||||
if not settings.api_key:
|
||||
typer.echo("Error: No API key configured", err=True)
|
||||
typer.echo("Run: oai config api to set your API key", err=True)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Initialize client
|
||||
try:
|
||||
client = AIClient(
|
||||
api_key=settings.api_key,
|
||||
base_url=settings.base_url,
|
||||
)
|
||||
except Exception as e:
|
||||
typer.echo(f"Error: Failed to initialize client: {e}", err=True)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Register commands
|
||||
register_all_commands()
|
||||
|
||||
# Initialize MCP manager (always create it, even if not enabled)
|
||||
mcp_manager = MCPManager()
|
||||
if mcp:
|
||||
try:
|
||||
result = mcp_manager.enable()
|
||||
if result["success"]:
|
||||
logger.info("MCP server enabled in files mode")
|
||||
else:
|
||||
logger.warning(f"MCP: {result.get('error', 'Failed to enable')}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enable MCP: {e}")
|
||||
|
||||
# Create session with MCP manager
|
||||
session = ChatSession(
|
||||
client=client,
|
||||
settings=settings,
|
||||
mcp_manager=mcp_manager,
|
||||
)
|
||||
|
||||
# Set system prompt if provided
|
||||
if system:
|
||||
session.set_system_prompt(system)
|
||||
|
||||
# Enable online mode if requested
|
||||
if online:
|
||||
session.online_enabled = True
|
||||
|
||||
# Set model if specified, otherwise use default
|
||||
if model:
|
||||
raw_model = client.get_raw_model(model)
|
||||
if raw_model:
|
||||
session.set_model(raw_model)
|
||||
else:
|
||||
logger.warning(f"Model '{model}' not found")
|
||||
elif settings.default_model:
|
||||
raw_model = client.get_raw_model(settings.default_model)
|
||||
if raw_model:
|
||||
session.set_model(raw_model)
|
||||
else:
|
||||
logger.warning(f"Default model '{settings.default_model}' not available")
|
||||
|
||||
# Run Textual app
|
||||
from oai.tui.app import oAIChatApp
|
||||
|
||||
app_instance = oAIChatApp(session, settings, model)
|
||||
app_instance.run()
|
||||
|
||||
|
||||
@app.command()
|
||||
def tui(
|
||||
model: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--model",
|
||||
"-m",
|
||||
help="Model ID to use",
|
||||
),
|
||||
system: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--system",
|
||||
"-s",
|
||||
help="System prompt",
|
||||
),
|
||||
online: bool = typer.Option(
|
||||
False,
|
||||
"--online",
|
||||
"-o",
|
||||
help="Enable online mode",
|
||||
),
|
||||
mcp: bool = typer.Option(
|
||||
False,
|
||||
"--mcp",
|
||||
help="Enable MCP server",
|
||||
),
|
||||
) -> None:
|
||||
"""Start Textual TUI interface (alias for just running 'oai')."""
|
||||
_launch_tui(model, system, online, mcp)
|
||||
|
||||
|
||||
@app.command()
|
||||
def version() -> None:
|
||||
"""Show version information."""
|
||||
typer.echo(f"oAI version {APP_VERSION}")
|
||||
typer.echo(f"Visit {APP_URL} for more information")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Entry point for the CLI."""
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,24 +0,0 @@
|
||||
"""
|
||||
Command system for oAI.
|
||||
|
||||
This module provides a command registry and handler system
|
||||
for processing slash commands in the chat interface.
|
||||
"""
|
||||
|
||||
from oai.commands.registry import (
|
||||
Command,
|
||||
CommandRegistry,
|
||||
CommandContext,
|
||||
CommandResult,
|
||||
registry,
|
||||
)
|
||||
from oai.commands.handlers import register_all_commands
|
||||
|
||||
__all__ = [
|
||||
"Command",
|
||||
"CommandRegistry",
|
||||
"CommandContext",
|
||||
"CommandResult",
|
||||
"registry",
|
||||
"register_all_commands",
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,382 +0,0 @@
|
||||
"""
|
||||
Command registry for oAI.
|
||||
|
||||
This module defines the command system infrastructure including
|
||||
the Command base class, CommandContext for state, and CommandRegistry
|
||||
for managing available commands.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from oai.config.settings import Settings
|
||||
from oai.providers.base import AIProvider, ModelInfo
|
||||
from oai.mcp.manager import MCPManager
|
||||
|
||||
|
||||
class CommandStatus(str, Enum):
|
||||
"""Status of command execution."""
|
||||
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
CONTINUE = "continue" # Continue to next handler
|
||||
EXIT = "exit" # Exit the application
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandResult:
|
||||
"""
|
||||
Result of a command execution.
|
||||
|
||||
Attributes:
|
||||
status: Execution status
|
||||
message: Optional message to display
|
||||
data: Optional data payload
|
||||
should_continue: Whether to continue the main loop
|
||||
"""
|
||||
|
||||
status: CommandStatus = CommandStatus.SUCCESS
|
||||
message: Optional[str] = None
|
||||
data: Optional[Any] = None
|
||||
should_continue: bool = True
|
||||
|
||||
@classmethod
|
||||
def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult":
|
||||
"""Create a success result."""
|
||||
return cls(status=CommandStatus.SUCCESS, message=message, data=data)
|
||||
|
||||
@classmethod
|
||||
def error(cls, message: str) -> "CommandResult":
|
||||
"""Create an error result."""
|
||||
return cls(status=CommandStatus.ERROR, message=message)
|
||||
|
||||
@classmethod
|
||||
def exit(cls, message: Optional[str] = None) -> "CommandResult":
|
||||
"""Create an exit result."""
|
||||
return cls(status=CommandStatus.EXIT, message=message, should_continue=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandContext:
|
||||
"""
|
||||
Context object providing state to command handlers.
|
||||
|
||||
Contains all the session state needed by commands including
|
||||
settings, provider, conversation history, and MCP manager.
|
||||
|
||||
Attributes:
|
||||
settings: Application settings
|
||||
provider: AI provider instance
|
||||
mcp_manager: MCP manager instance
|
||||
selected_model: Currently selected model
|
||||
session_history: Conversation history
|
||||
session_system_prompt: Current system prompt
|
||||
memory_enabled: Whether memory is enabled
|
||||
online_enabled: Whether online mode is enabled
|
||||
session_tokens: Session token counts
|
||||
session_cost: Session cost total
|
||||
"""
|
||||
|
||||
settings: Optional["Settings"] = None
|
||||
provider: Optional["AIProvider"] = None
|
||||
mcp_manager: Optional["MCPManager"] = None
|
||||
selected_model: Optional["ModelInfo"] = None
|
||||
selected_model_raw: Optional[Dict[str, Any]] = None
|
||||
session_history: List[Dict[str, Any]] = field(default_factory=list)
|
||||
session_system_prompt: str = ""
|
||||
memory_enabled: bool = True
|
||||
memory_start_index: int = 0
|
||||
online_enabled: bool = False
|
||||
middle_out_enabled: bool = False
|
||||
session_max_token: int = 0
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_cost: float = 0.0
|
||||
message_count: int = 0
|
||||
is_tui: bool = False # Flag for TUI mode
|
||||
current_index: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandHelp:
|
||||
"""
|
||||
Help information for a command.
|
||||
|
||||
Attributes:
|
||||
description: Brief description
|
||||
usage: Usage syntax
|
||||
examples: List of (description, example) tuples
|
||||
notes: Additional notes
|
||||
aliases: Command aliases
|
||||
"""
|
||||
|
||||
description: str
|
||||
usage: str = ""
|
||||
examples: List[tuple] = field(default_factory=list)
|
||||
notes: str = ""
|
||||
aliases: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class Command(ABC):
|
||||
"""
|
||||
Abstract base class for all commands.
|
||||
|
||||
Commands implement the execute method to handle their logic.
|
||||
They can also provide help information and aliases.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Get the primary command name (e.g., '/help')."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def aliases(self) -> List[str]:
|
||||
"""Get command aliases (e.g., ['/h'] for help)."""
|
||||
return []
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def help(self) -> CommandHelp:
|
||||
"""Get command help information."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, args: str, context: CommandContext) -> CommandResult:
|
||||
"""
|
||||
Execute the command.
|
||||
|
||||
Args:
|
||||
args: Arguments passed to the command
|
||||
context: Command execution context
|
||||
|
||||
Returns:
|
||||
CommandResult indicating success/failure
|
||||
"""
|
||||
pass
|
||||
|
||||
def matches(self, input_text: str) -> bool:
|
||||
"""
|
||||
Check if this command matches the input.
|
||||
|
||||
Args:
|
||||
input_text: User input text
|
||||
|
||||
Returns:
|
||||
True if this command should handle the input
|
||||
"""
|
||||
input_lower = input_text.lower()
|
||||
cmd_word = input_lower.split()[0] if input_lower.split() else ""
|
||||
|
||||
# Check primary name
|
||||
if cmd_word == self.name.lower():
|
||||
return True
|
||||
|
||||
# Check aliases
|
||||
for alias in self.aliases:
|
||||
if cmd_word == alias.lower():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_args(self, input_text: str) -> str:
|
||||
"""
|
||||
Extract arguments from the input text.
|
||||
|
||||
Args:
|
||||
input_text: Full user input
|
||||
|
||||
Returns:
|
||||
Arguments portion of the input
|
||||
"""
|
||||
parts = input_text.split(maxsplit=1)
|
||||
return parts[1] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
class CommandRegistry:
|
||||
"""
|
||||
Registry for managing available commands.
|
||||
|
||||
Provides registration, lookup, and execution of commands.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize an empty command registry."""
|
||||
self._commands: Dict[str, Command] = {}
|
||||
self._aliases: Dict[str, str] = {}
|
||||
self.logger = get_logger()
|
||||
|
||||
def register(self, command: Command) -> None:
|
||||
"""
|
||||
Register a command.
|
||||
|
||||
Args:
|
||||
command: Command instance to register
|
||||
|
||||
Raises:
|
||||
ValueError: If command name already registered
|
||||
"""
|
||||
name = command.name.lower()
|
||||
|
||||
if name in self._commands:
|
||||
raise ValueError(f"Command '{name}' already registered")
|
||||
|
||||
self._commands[name] = command
|
||||
|
||||
# Register aliases
|
||||
for alias in command.aliases:
|
||||
alias_lower = alias.lower()
|
||||
if alias_lower in self._aliases:
|
||||
self.logger.warning(
|
||||
f"Alias '{alias}' already registered, overwriting"
|
||||
)
|
||||
self._aliases[alias_lower] = name
|
||||
|
||||
self.logger.debug(f"Registered command: {name}")
|
||||
|
||||
def register_function(
|
||||
self,
|
||||
name: str,
|
||||
handler: Callable[[str, CommandContext], CommandResult],
|
||||
description: str,
|
||||
usage: str = "",
|
||||
aliases: Optional[List[str]] = None,
|
||||
examples: Optional[List[tuple]] = None,
|
||||
notes: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Register a function-based command.
|
||||
|
||||
Convenience method for simple commands that don't need
|
||||
a full Command class.
|
||||
|
||||
Args:
|
||||
name: Command name (e.g., '/help')
|
||||
handler: Function to execute
|
||||
description: Help description
|
||||
usage: Usage syntax
|
||||
aliases: Command aliases
|
||||
examples: Example usages
|
||||
notes: Additional notes
|
||||
"""
|
||||
aliases = aliases or []
|
||||
examples = examples or []
|
||||
|
||||
class FunctionCommand(Command):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return name
|
||||
|
||||
@property
|
||||
def aliases(self) -> List[str]:
|
||||
return aliases
|
||||
|
||||
@property
|
||||
def help(self) -> CommandHelp:
|
||||
return CommandHelp(
|
||||
description=description,
|
||||
usage=usage,
|
||||
examples=examples,
|
||||
notes=notes,
|
||||
aliases=aliases,
|
||||
)
|
||||
|
||||
def execute(self, args: str, context: CommandContext) -> CommandResult:
|
||||
return handler(args, context)
|
||||
|
||||
self.register(FunctionCommand())
|
||||
|
||||
def get(self, name: str) -> Optional[Command]:
|
||||
"""
|
||||
Get a command by name or alias.
|
||||
|
||||
Args:
|
||||
name: Command name or alias
|
||||
|
||||
Returns:
|
||||
Command instance or None if not found
|
||||
"""
|
||||
name_lower = name.lower()
|
||||
|
||||
# Check direct match
|
||||
if name_lower in self._commands:
|
||||
return self._commands[name_lower]
|
||||
|
||||
# Check aliases
|
||||
if name_lower in self._aliases:
|
||||
return self._commands[self._aliases[name_lower]]
|
||||
|
||||
return None
|
||||
|
||||
def find(self, input_text: str) -> Optional[Command]:
|
||||
"""
|
||||
Find a command that matches the input.
|
||||
|
||||
Args:
|
||||
input_text: User input text
|
||||
|
||||
Returns:
|
||||
Matching Command or None
|
||||
"""
|
||||
cmd_word = input_text.lower().split()[0] if input_text.split() else ""
|
||||
return self.get(cmd_word)
|
||||
|
||||
def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]:
|
||||
"""
|
||||
Execute a command matching the input.
|
||||
|
||||
Args:
|
||||
input_text: User input text
|
||||
context: Execution context
|
||||
|
||||
Returns:
|
||||
CommandResult or None if no matching command
|
||||
"""
|
||||
command = self.find(input_text)
|
||||
if command:
|
||||
args = command.get_args(input_text)
|
||||
self.logger.debug(f"Executing command: {command.name} with args: {args}")
|
||||
return command.execute(args, context)
|
||||
return None
|
||||
|
||||
def is_command(self, input_text: str) -> bool:
|
||||
"""
|
||||
Check if input is a valid command.
|
||||
|
||||
Args:
|
||||
input_text: User input text
|
||||
|
||||
Returns:
|
||||
True if input matches a registered command
|
||||
"""
|
||||
return self.find(input_text) is not None
|
||||
|
||||
def list_commands(self) -> List[Command]:
|
||||
"""
|
||||
Get all registered commands.
|
||||
|
||||
Returns:
|
||||
List of Command instances
|
||||
"""
|
||||
return list(self._commands.values())
|
||||
|
||||
def get_all_names(self) -> List[str]:
|
||||
"""
|
||||
Get all command names and aliases.
|
||||
|
||||
Returns:
|
||||
List of command names including aliases
|
||||
"""
|
||||
names = list(self._commands.keys())
|
||||
names.extend(self._aliases.keys())
|
||||
return sorted(set(names))
|
||||
|
||||
|
||||
# Global registry instance
|
||||
registry = CommandRegistry()
|
||||
@@ -1,11 +0,0 @@
|
||||
"""
|
||||
Configuration management for oAI.
|
||||
|
||||
This package handles all configuration persistence, settings management,
|
||||
and database operations for the application.
|
||||
"""
|
||||
|
||||
from oai.config.settings import Settings
|
||||
from oai.config.database import Database
|
||||
|
||||
__all__ = ["Settings", "Database"]
|
||||
@@ -1,472 +0,0 @@
|
||||
"""
|
||||
Database persistence layer for oAI.
|
||||
|
||||
This module provides a clean abstraction for SQLite operations including
|
||||
configuration storage, conversation persistence, and MCP statistics tracking.
|
||||
All database operations are centralized here for maintainability.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any
|
||||
from contextlib import contextmanager
|
||||
|
||||
from oai.constants import DATABASE_FILE, CONFIG_DIR
|
||||
|
||||
|
||||
class Database:
|
||||
"""
|
||||
SQLite database manager for oAI.
|
||||
|
||||
Handles all database operations including:
|
||||
- Configuration key-value storage
|
||||
- Conversation session persistence
|
||||
- MCP configuration and statistics
|
||||
- Database registrations for MCP
|
||||
|
||||
Uses context managers for safe connection handling and supports
|
||||
automatic table creation on first use.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Optional[Path] = None):
|
||||
"""
|
||||
Initialize the database manager.
|
||||
|
||||
Args:
|
||||
db_path: Optional custom database path. Defaults to standard location.
|
||||
"""
|
||||
self.db_path = db_path or DATABASE_FILE
|
||||
self._ensure_directories()
|
||||
self._ensure_tables()
|
||||
|
||||
def _ensure_directories(self) -> None:
|
||||
"""Ensure the configuration directory exists."""
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@contextmanager
|
||||
def _connection(self):
|
||||
"""
|
||||
Context manager for database connections.
|
||||
|
||||
Yields:
|
||||
sqlite3.Connection: Active database connection
|
||||
|
||||
Example:
|
||||
with self._connection() as conn:
|
||||
conn.execute("SELECT * FROM config")
|
||||
"""
|
||||
conn = sqlite3.connect(str(self.db_path))
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _ensure_tables(self) -> None:
|
||||
"""Create all required tables if they don't exist."""
|
||||
with self._connection() as conn:
|
||||
# Main configuration table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# Conversation sessions table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS conversation_sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
data TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# MCP configuration table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS mcp_config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# MCP statistics table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS mcp_stats (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
tool_name TEXT NOT NULL,
|
||||
folder TEXT,
|
||||
success INTEGER NOT NULL,
|
||||
error_message TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# MCP databases table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS mcp_databases (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
path TEXT NOT NULL UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
size INTEGER,
|
||||
tables TEXT,
|
||||
added_timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# =========================================================================
|
||||
# CONFIGURATION METHODS
|
||||
# =========================================================================
|
||||
|
||||
def get_config(self, key: str) -> Optional[str]:
|
||||
"""
|
||||
Retrieve a configuration value by key.
|
||||
|
||||
Args:
|
||||
key: The configuration key to retrieve
|
||||
|
||||
Returns:
|
||||
The configuration value, or None if not found
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT value FROM config WHERE key = ?",
|
||||
(key,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
def set_config(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set a configuration value.
|
||||
|
||||
Args:
|
||||
key: The configuration key
|
||||
value: The value to store
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)",
|
||||
(key, value)
|
||||
)
|
||||
|
||||
def delete_config(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a configuration value.
|
||||
|
||||
Args:
|
||||
key: The configuration key to delete
|
||||
|
||||
Returns:
|
||||
True if a row was deleted, False otherwise
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM config WHERE key = ?",
|
||||
(key,)
|
||||
)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_all_config(self) -> Dict[str, str]:
|
||||
"""
|
||||
Retrieve all configuration values.
|
||||
|
||||
Returns:
|
||||
Dictionary of all key-value pairs
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute("SELECT key, value FROM config")
|
||||
return dict(cursor.fetchall())
|
||||
|
||||
# =========================================================================
|
||||
# MCP CONFIGURATION METHODS
|
||||
# =========================================================================
|
||||
|
||||
def get_mcp_config(self, key: str) -> Optional[str]:
|
||||
"""
|
||||
Retrieve an MCP configuration value.
|
||||
|
||||
Args:
|
||||
key: The MCP configuration key
|
||||
|
||||
Returns:
|
||||
The configuration value, or None if not found
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT value FROM mcp_config WHERE key = ?",
|
||||
(key,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
def set_mcp_config(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set an MCP configuration value.
|
||||
|
||||
Args:
|
||||
key: The MCP configuration key
|
||||
value: The value to store
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)",
|
||||
(key, value)
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# MCP STATISTICS METHODS
|
||||
# =========================================================================
|
||||
|
||||
def log_mcp_stat(
|
||||
self,
|
||||
tool_name: str,
|
||||
folder: Optional[str],
|
||||
success: bool,
|
||||
error_message: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Log an MCP tool usage event.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the MCP tool that was called
|
||||
folder: The folder path involved (if any)
|
||||
success: Whether the call succeeded
|
||||
error_message: Error message if the call failed
|
||||
"""
|
||||
timestamp = datetime.datetime.now().isoformat()
|
||||
with self._connection() as conn:
|
||||
conn.execute(
|
||||
"""INSERT INTO mcp_stats
|
||||
(timestamp, tool_name, folder, success, error_message)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(timestamp, tool_name, folder, 1 if success else 0, error_message)
|
||||
)
|
||||
|
||||
def get_mcp_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get aggregated MCP usage statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing usage statistics:
|
||||
- total_calls: Total number of tool calls
|
||||
- reads: Number of file reads
|
||||
- lists: Number of directory listings
|
||||
- searches: Number of file searches
|
||||
- db_inspects: Number of database inspections
|
||||
- db_searches: Number of database searches
|
||||
- db_queries: Number of database queries
|
||||
- last_used: Timestamp of last usage
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total_calls,
|
||||
SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads,
|
||||
SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists,
|
||||
SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches,
|
||||
SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects,
|
||||
SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches,
|
||||
SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries,
|
||||
MAX(timestamp) as last_used
|
||||
FROM mcp_stats
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
return {
|
||||
"total_calls": row[0] or 0,
|
||||
"reads": row[1] or 0,
|
||||
"lists": row[2] or 0,
|
||||
"searches": row[3] or 0,
|
||||
"db_inspects": row[4] or 0,
|
||||
"db_searches": row[5] or 0,
|
||||
"db_queries": row[6] or 0,
|
||||
"last_used": row[7],
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# MCP DATABASE REGISTRY METHODS
|
||||
# =========================================================================
|
||||
|
||||
def add_mcp_database(self, db_info: Dict[str, Any]) -> int:
|
||||
"""
|
||||
Register a database for MCP access.
|
||||
|
||||
Args:
|
||||
db_info: Dictionary containing:
|
||||
- path: Database file path
|
||||
- name: Display name
|
||||
- size: File size in bytes
|
||||
- tables: List of table names
|
||||
- added: Timestamp when added
|
||||
|
||||
Returns:
|
||||
The database ID
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
conn.execute(
|
||||
"""INSERT INTO mcp_databases
|
||||
(path, name, size, tables, added_timestamp)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(
|
||||
db_info["path"],
|
||||
db_info["name"],
|
||||
db_info["size"],
|
||||
json.dumps(db_info["tables"]),
|
||||
db_info["added"]
|
||||
)
|
||||
)
|
||||
cursor = conn.execute(
|
||||
"SELECT id FROM mcp_databases WHERE path = ?",
|
||||
(db_info["path"],)
|
||||
)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def remove_mcp_database(self, db_path: str) -> bool:
|
||||
"""
|
||||
Remove a database from the MCP registry.
|
||||
|
||||
Args:
|
||||
db_path: Path to the database file
|
||||
|
||||
Returns:
|
||||
True if a row was deleted, False otherwise
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM mcp_databases WHERE path = ?",
|
||||
(db_path,)
|
||||
)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_mcp_databases(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve all registered MCP databases.
|
||||
|
||||
Returns:
|
||||
List of database information dictionaries
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"""SELECT id, path, name, size, tables, added_timestamp
|
||||
FROM mcp_databases ORDER BY id"""
|
||||
)
|
||||
databases = []
|
||||
for row in cursor.fetchall():
|
||||
tables_list = json.loads(row[4]) if row[4] else []
|
||||
databases.append({
|
||||
"id": row[0],
|
||||
"path": row[1],
|
||||
"name": row[2],
|
||||
"size": row[3],
|
||||
"tables": tables_list,
|
||||
"added": row[5],
|
||||
})
|
||||
return databases
|
||||
|
||||
# =========================================================================
|
||||
# CONVERSATION METHODS
|
||||
# =========================================================================
|
||||
|
||||
def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None:
|
||||
"""
|
||||
Save a conversation session.
|
||||
|
||||
Args:
|
||||
name: Name/identifier for the conversation
|
||||
data: List of message dictionaries with 'prompt' and 'response' keys
|
||||
"""
|
||||
timestamp = datetime.datetime.now().isoformat()
|
||||
data_json = json.dumps(data)
|
||||
with self._connection() as conn:
|
||||
conn.execute(
|
||||
"""INSERT INTO conversation_sessions
|
||||
(name, timestamp, data) VALUES (?, ?, ?)""",
|
||||
(name, timestamp, data_json)
|
||||
)
|
||||
|
||||
def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]:
|
||||
"""
|
||||
Load a conversation by name.
|
||||
|
||||
Args:
|
||||
name: Name of the conversation to load
|
||||
|
||||
Returns:
|
||||
List of messages, or None if not found
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"""SELECT data FROM conversation_sessions
|
||||
WHERE name = ?
|
||||
ORDER BY timestamp DESC LIMIT 1""",
|
||||
(name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return json.loads(result[0])
|
||||
return None
|
||||
|
||||
def delete_conversation(self, name: str) -> int:
|
||||
"""
|
||||
Delete a conversation by name.
|
||||
|
||||
Args:
|
||||
name: Name of the conversation to delete
|
||||
|
||||
Returns:
|
||||
Number of rows deleted
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM conversation_sessions WHERE name = ?",
|
||||
(name,)
|
||||
)
|
||||
return cursor.rowcount
|
||||
|
||||
def list_conversations(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List all saved conversations.
|
||||
|
||||
Returns:
|
||||
List of conversation summaries with name, timestamp, and message_count
|
||||
"""
|
||||
with self._connection() as conn:
|
||||
cursor = conn.execute("""
|
||||
SELECT name, MAX(timestamp) as last_saved, data
|
||||
FROM conversation_sessions
|
||||
GROUP BY name
|
||||
ORDER BY last_saved DESC
|
||||
""")
|
||||
conversations = []
|
||||
for row in cursor.fetchall():
|
||||
name, timestamp, data_json = row
|
||||
data = json.loads(data_json)
|
||||
conversations.append({
|
||||
"name": name,
|
||||
"timestamp": timestamp,
|
||||
"message_count": len(data),
|
||||
})
|
||||
return conversations
|
||||
|
||||
|
||||
# Global database instance for convenience
|
||||
_db: Optional[Database] = None
|
||||
|
||||
|
||||
def get_database() -> Database:
|
||||
"""
|
||||
Get the global database instance.
|
||||
|
||||
Returns:
|
||||
The shared Database instance
|
||||
"""
|
||||
global _db
|
||||
if _db is None:
|
||||
_db = Database()
|
||||
return _db
|
||||
@@ -1,361 +0,0 @@
|
||||
"""
|
||||
Settings management for oAI.
|
||||
|
||||
This module provides a centralized settings class that handles all application
|
||||
configuration with type safety, validation, and persistence.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from oai.constants import (
|
||||
DEFAULT_BASE_URL,
|
||||
DEFAULT_STREAM_ENABLED,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_ONLINE_MODE,
|
||||
DEFAULT_COST_WARNING_THRESHOLD,
|
||||
DEFAULT_LOG_MAX_SIZE_MB,
|
||||
DEFAULT_LOG_BACKUP_COUNT,
|
||||
DEFAULT_LOG_LEVEL,
|
||||
DEFAULT_SYSTEM_PROMPT,
|
||||
VALID_LOG_LEVELS,
|
||||
)
|
||||
from oai.config.database import get_database
|
||||
|
||||
|
||||
@dataclass
|
||||
class Settings:
|
||||
"""
|
||||
Application settings with persistence support.
|
||||
|
||||
This class provides a clean interface for managing all configuration
|
||||
options. Settings are automatically loaded from the database on
|
||||
initialization and can be persisted back.
|
||||
|
||||
Attributes:
|
||||
api_key: OpenRouter API key
|
||||
base_url: API base URL
|
||||
default_model: Default model ID to use
|
||||
default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
|
||||
stream_enabled: Whether to stream responses
|
||||
max_tokens: Maximum tokens per request
|
||||
cost_warning_threshold: Alert threshold for message cost
|
||||
default_online_mode: Whether online mode is enabled by default
|
||||
log_max_size_mb: Maximum log file size in MB
|
||||
log_backup_count: Number of log file backups to keep
|
||||
log_level: Logging level (debug/info/warning/error/critical)
|
||||
"""
|
||||
|
||||
api_key: Optional[str] = None
|
||||
base_url: str = DEFAULT_BASE_URL
|
||||
default_model: Optional[str] = None
|
||||
default_system_prompt: Optional[str] = None
|
||||
stream_enabled: bool = DEFAULT_STREAM_ENABLED
|
||||
max_tokens: int = DEFAULT_MAX_TOKENS
|
||||
cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD
|
||||
default_online_mode: bool = DEFAULT_ONLINE_MODE
|
||||
log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
|
||||
log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT
|
||||
log_level: str = DEFAULT_LOG_LEVEL
|
||||
|
||||
@property
|
||||
def effective_system_prompt(self) -> str:
|
||||
"""
|
||||
Get the effective system prompt to use.
|
||||
|
||||
Returns:
|
||||
The custom prompt if set, hardcoded default if None, or blank if explicitly set to ""
|
||||
"""
|
||||
if self.default_system_prompt is None:
|
||||
return DEFAULT_SYSTEM_PROMPT
|
||||
return self.default_system_prompt
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate settings after initialization."""
|
||||
self._validate()
|
||||
|
||||
def _validate(self) -> None:
|
||||
"""Validate all settings values."""
|
||||
# Validate log level
|
||||
if self.log_level.lower() not in VALID_LOG_LEVELS:
|
||||
raise ValueError(
|
||||
f"Invalid log level: {self.log_level}. "
|
||||
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
|
||||
)
|
||||
|
||||
# Validate numeric bounds
|
||||
if self.max_tokens < 1:
|
||||
raise ValueError("max_tokens must be at least 1")
|
||||
|
||||
if self.cost_warning_threshold < 0:
|
||||
raise ValueError("cost_warning_threshold must be non-negative")
|
||||
|
||||
if self.log_max_size_mb < 1:
|
||||
raise ValueError("log_max_size_mb must be at least 1")
|
||||
|
||||
if self.log_backup_count < 0:
|
||||
raise ValueError("log_backup_count must be non-negative")
|
||||
|
||||
@classmethod
|
||||
def load(cls) -> "Settings":
|
||||
"""
|
||||
Load settings from the database.
|
||||
|
||||
Returns:
|
||||
Settings instance with values from database
|
||||
"""
|
||||
db = get_database()
|
||||
|
||||
# Helper to safely parse boolean
|
||||
def parse_bool(value: Optional[str], default: bool) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
return value.lower() in ("on", "true", "1", "yes")
|
||||
|
||||
# Helper to safely parse int
|
||||
def parse_int(value: Optional[str], default: int) -> int:
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
# Helper to safely parse float
|
||||
def parse_float(value: Optional[str], default: float) -> float:
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
# Get system prompt from DB: None means not set (use default), "" means explicitly blank
|
||||
system_prompt_value = db.get_config("default_system_prompt")
|
||||
|
||||
return cls(
|
||||
api_key=db.get_config("api_key"),
|
||||
base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
|
||||
default_model=db.get_config("default_model"),
|
||||
default_system_prompt=system_prompt_value,
|
||||
stream_enabled=parse_bool(
|
||||
db.get_config("stream_enabled"),
|
||||
DEFAULT_STREAM_ENABLED
|
||||
),
|
||||
max_tokens=parse_int(
|
||||
db.get_config("max_token"),
|
||||
DEFAULT_MAX_TOKENS
|
||||
),
|
||||
cost_warning_threshold=parse_float(
|
||||
db.get_config("cost_warning_threshold"),
|
||||
DEFAULT_COST_WARNING_THRESHOLD
|
||||
),
|
||||
default_online_mode=parse_bool(
|
||||
db.get_config("default_online_mode"),
|
||||
DEFAULT_ONLINE_MODE
|
||||
),
|
||||
log_max_size_mb=parse_int(
|
||||
db.get_config("log_max_size_mb"),
|
||||
DEFAULT_LOG_MAX_SIZE_MB
|
||||
),
|
||||
log_backup_count=parse_int(
|
||||
db.get_config("log_backup_count"),
|
||||
DEFAULT_LOG_BACKUP_COUNT
|
||||
),
|
||||
log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL,
|
||||
)
|
||||
|
||||
def save(self) -> None:
|
||||
"""Persist all settings to the database."""
|
||||
db = get_database()
|
||||
|
||||
# Only save API key if it exists
|
||||
if self.api_key:
|
||||
db.set_config("api_key", self.api_key)
|
||||
|
||||
db.set_config("base_url", self.base_url)
|
||||
|
||||
if self.default_model:
|
||||
db.set_config("default_model", self.default_model)
|
||||
|
||||
# Save system prompt: None means not set (don't save), otherwise save the value (even if "")
|
||||
if self.default_system_prompt is not None:
|
||||
db.set_config("default_system_prompt", self.default_system_prompt)
|
||||
|
||||
db.set_config("stream_enabled", "on" if self.stream_enabled else "off")
|
||||
db.set_config("max_token", str(self.max_tokens))
|
||||
db.set_config("cost_warning_threshold", str(self.cost_warning_threshold))
|
||||
db.set_config("default_online_mode", "on" if self.default_online_mode else "off")
|
||||
db.set_config("log_max_size_mb", str(self.log_max_size_mb))
|
||||
db.set_config("log_backup_count", str(self.log_backup_count))
|
||||
db.set_config("log_level", self.log_level)
|
||||
|
||||
def set_api_key(self, api_key: str) -> None:
|
||||
"""
|
||||
Set and persist the API key.
|
||||
|
||||
Args:
|
||||
api_key: The new API key
|
||||
"""
|
||||
self.api_key = api_key.strip()
|
||||
get_database().set_config("api_key", self.api_key)
|
||||
|
||||
def set_base_url(self, url: str) -> None:
|
||||
"""
|
||||
Set and persist the base URL.
|
||||
|
||||
Args:
|
||||
url: The new base URL
|
||||
"""
|
||||
self.base_url = url.strip()
|
||||
get_database().set_config("base_url", self.base_url)
|
||||
|
||||
def set_default_model(self, model_id: str) -> None:
|
||||
"""
|
||||
Set and persist the default model.
|
||||
|
||||
Args:
|
||||
model_id: The model ID to set as default
|
||||
"""
|
||||
self.default_model = model_id
|
||||
get_database().set_config("default_model", model_id)
|
||||
|
||||
def set_default_system_prompt(self, prompt: str) -> None:
|
||||
"""
|
||||
Set and persist the default system prompt.
|
||||
|
||||
Args:
|
||||
prompt: The system prompt to use for all new sessions.
|
||||
Empty string "" means blank prompt (no system message).
|
||||
"""
|
||||
self.default_system_prompt = prompt
|
||||
get_database().set_config("default_system_prompt", prompt)
|
||||
|
||||
def clear_default_system_prompt(self) -> None:
|
||||
"""
|
||||
Clear the custom system prompt and revert to hardcoded default.
|
||||
|
||||
This removes the custom prompt from the database, causing the
|
||||
application to use the built-in DEFAULT_SYSTEM_PROMPT.
|
||||
"""
|
||||
self.default_system_prompt = None
|
||||
# Remove from database to indicate "not set"
|
||||
db = get_database()
|
||||
with db._connection() as conn:
|
||||
conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",))
|
||||
conn.commit()
|
||||
|
||||
def set_stream_enabled(self, enabled: bool) -> None:
|
||||
"""
|
||||
Set and persist the streaming preference.
|
||||
|
||||
Args:
|
||||
enabled: Whether to enable streaming
|
||||
"""
|
||||
self.stream_enabled = enabled
|
||||
get_database().set_config("stream_enabled", "on" if enabled else "off")
|
||||
|
||||
def set_max_tokens(self, max_tokens: int) -> None:
|
||||
"""
|
||||
Set and persist the maximum tokens.
|
||||
|
||||
Args:
|
||||
max_tokens: Maximum number of tokens
|
||||
|
||||
Raises:
|
||||
ValueError: If max_tokens is less than 1
|
||||
"""
|
||||
if max_tokens < 1:
|
||||
raise ValueError("max_tokens must be at least 1")
|
||||
self.max_tokens = max_tokens
|
||||
get_database().set_config("max_token", str(max_tokens))
|
||||
|
||||
def set_cost_warning_threshold(self, threshold: float) -> None:
|
||||
"""
|
||||
Set and persist the cost warning threshold.
|
||||
|
||||
Args:
|
||||
threshold: Cost threshold in USD
|
||||
|
||||
Raises:
|
||||
ValueError: If threshold is negative
|
||||
"""
|
||||
if threshold < 0:
|
||||
raise ValueError("cost_warning_threshold must be non-negative")
|
||||
self.cost_warning_threshold = threshold
|
||||
get_database().set_config("cost_warning_threshold", str(threshold))
|
||||
|
||||
def set_default_online_mode(self, enabled: bool) -> None:
|
||||
"""
|
||||
Set and persist the default online mode.
|
||||
|
||||
Args:
|
||||
enabled: Whether online mode should be enabled by default
|
||||
"""
|
||||
self.default_online_mode = enabled
|
||||
get_database().set_config("default_online_mode", "on" if enabled else "off")
|
||||
|
||||
def set_log_level(self, level: str) -> None:
|
||||
"""
|
||||
Set and persist the log level.
|
||||
|
||||
Args:
|
||||
level: The log level (debug/info/warning/error/critical)
|
||||
|
||||
Raises:
|
||||
ValueError: If level is not valid
|
||||
"""
|
||||
level_lower = level.lower()
|
||||
if level_lower not in VALID_LOG_LEVELS:
|
||||
raise ValueError(
|
||||
f"Invalid log level: {level}. "
|
||||
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
|
||||
)
|
||||
self.log_level = level_lower
|
||||
get_database().set_config("log_level", level_lower)
|
||||
|
||||
def set_log_max_size(self, size_mb: int) -> None:
|
||||
"""
|
||||
Set and persist the maximum log file size.
|
||||
|
||||
Args:
|
||||
size_mb: Maximum size in megabytes
|
||||
|
||||
Raises:
|
||||
ValueError: If size_mb is less than 1
|
||||
"""
|
||||
if size_mb < 1:
|
||||
raise ValueError("log_max_size_mb must be at least 1")
|
||||
# Cap at 100 MB for safety
|
||||
self.log_max_size_mb = min(size_mb, 100)
|
||||
get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
|
||||
|
||||
|
||||
# Global settings instance
|
||||
_settings: Optional[Settings] = None
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""
|
||||
Get the global settings instance.
|
||||
|
||||
Returns:
|
||||
The shared Settings instance, loading from database if needed
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings.load()
|
||||
return _settings
|
||||
|
||||
|
||||
def reload_settings() -> Settings:
|
||||
"""
|
||||
Force reload settings from the database.
|
||||
|
||||
Returns:
|
||||
Fresh Settings instance
|
||||
"""
|
||||
global _settings
|
||||
_settings = Settings.load()
|
||||
return _settings
|
||||
451
oai/constants.py
451
oai/constants.py
@@ -1,451 +0,0 @@
|
||||
"""
|
||||
Application-wide constants for oAI.
|
||||
|
||||
This module contains all configuration constants, default values, and static
|
||||
definitions used throughout the application. Centralizing these values makes
|
||||
the codebase easier to maintain and configure.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Set, Dict, Any
|
||||
import logging
|
||||
|
||||
# Import version from single source of truth
|
||||
from oai import __version__
|
||||
|
||||
# =============================================================================
|
||||
# APPLICATION METADATA
|
||||
# =============================================================================
|
||||
|
||||
APP_NAME = "oAI"
|
||||
APP_VERSION = __version__ # Single source of truth in oai/__init__.py
|
||||
APP_URL = "https://iurl.no/oai"
|
||||
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"
|
||||
|
||||
# =============================================================================
|
||||
# FILE PATHS
|
||||
# =============================================================================
|
||||
|
||||
HOME_DIR = Path.home()
|
||||
CONFIG_DIR = HOME_DIR / ".config" / "oai"
|
||||
CACHE_DIR = HOME_DIR / ".cache" / "oai"
|
||||
HISTORY_FILE = CONFIG_DIR / "history.txt"
|
||||
DATABASE_FILE = CONFIG_DIR / "oai_config.db"
|
||||
LOG_FILE = CONFIG_DIR / "oai.log"
|
||||
|
||||
# =============================================================================
|
||||
# API CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
DEFAULT_STREAM_ENABLED = True
|
||||
DEFAULT_MAX_TOKENS = 100_000
|
||||
DEFAULT_ONLINE_MODE = False
|
||||
|
||||
# =============================================================================
|
||||
# DEFAULT SYSTEM PROMPT
|
||||
# =============================================================================
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a knowledgeable and helpful AI assistant. Provide clear, accurate, "
|
||||
"and well-structured responses. Be concise yet thorough. When uncertain about "
|
||||
"something, acknowledge your limitations. For technical topics, include relevant "
|
||||
"details and examples when helpful."
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# PRICING DEFAULTS (per million tokens)
|
||||
# =============================================================================
|
||||
|
||||
DEFAULT_INPUT_PRICE = 3.0
|
||||
DEFAULT_OUTPUT_PRICE = 15.0
|
||||
|
||||
MODEL_PRICING: Dict[str, float] = {
|
||||
"input": DEFAULT_INPUT_PRICE,
|
||||
"output": DEFAULT_OUTPUT_PRICE,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# CREDIT ALERTS
|
||||
# =============================================================================
|
||||
|
||||
LOW_CREDIT_RATIO = 0.1 # Alert when credits < 10% of total
|
||||
LOW_CREDIT_AMOUNT = 1.0 # Alert when credits < $1.00
|
||||
DEFAULT_COST_WARNING_THRESHOLD = 0.01 # Alert when single message cost exceeds this
|
||||
COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD # Alias for convenience
|
||||
|
||||
# =============================================================================
|
||||
# LOGGING CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
DEFAULT_LOG_MAX_SIZE_MB = 10
|
||||
DEFAULT_LOG_BACKUP_COUNT = 2
|
||||
DEFAULT_LOG_LEVEL = "info"
|
||||
|
||||
VALID_LOG_LEVELS: Dict[str, int] = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# FILE HANDLING
|
||||
# =============================================================================
|
||||
|
||||
# Maximum file size for reading (10 MB)
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# Content truncation threshold (50 KB)
|
||||
CONTENT_TRUNCATION_THRESHOLD = 50 * 1024
|
||||
|
||||
# Maximum items in directory listing
|
||||
MAX_LIST_ITEMS = 1000
|
||||
|
||||
# Supported code file extensions for syntax highlighting
|
||||
SUPPORTED_CODE_EXTENSIONS: Set[str] = {
|
||||
".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp",
|
||||
".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go",
|
||||
".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart",
|
||||
".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt",
|
||||
}
|
||||
|
||||
# All allowed file extensions for attachment
|
||||
ALLOWED_FILE_EXTENSIONS: Set[str] = {
|
||||
# Code files
|
||||
".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx",
|
||||
".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php",
|
||||
".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1",
|
||||
# Data files
|
||||
".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3",
|
||||
# Documents
|
||||
".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties",
|
||||
# Images
|
||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico",
|
||||
# Archives
|
||||
".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz",
|
||||
# Config files
|
||||
".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc",
|
||||
".prettierrc", ".babelrc", ".nvmrc", ".npmrc",
|
||||
# Binary/Compiled
|
||||
".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app",
|
||||
".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa",
|
||||
# ML/AI
|
||||
".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx",
|
||||
".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn",
|
||||
# Data formats
|
||||
".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet",
|
||||
".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds",
|
||||
# Other
|
||||
".pdf", ".class", ".jar", ".war",
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# SECURITY CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
# System directories that should never be accessed
|
||||
SYSTEM_DIRS_BLACKLIST: Set[str] = {
|
||||
# macOS
|
||||
"/System", "/Library", "/private", "/usr", "/bin", "/sbin",
|
||||
# Linux
|
||||
"/boot", "/dev", "/proc", "/sys", "/root",
|
||||
# Windows
|
||||
"C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)",
|
||||
}
|
||||
|
||||
# Directories to skip during file operations
|
||||
SKIP_DIRECTORIES: Set[str] = {
|
||||
# Python virtual environments
|
||||
".venv", "venv", "env", "virtualenv",
|
||||
"site-packages", "dist-packages",
|
||||
# Python caches
|
||||
"__pycache__", ".pytest_cache", ".mypy_cache",
|
||||
# JavaScript/Node
|
||||
"node_modules",
|
||||
# Version control
|
||||
".git", ".svn",
|
||||
# IDEs
|
||||
".idea", ".vscode",
|
||||
# Build directories
|
||||
"build", "dist", "eggs", ".eggs",
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# DATABASE QUERIES - SQL SAFETY
|
||||
# =============================================================================
|
||||
|
||||
# Maximum query execution timeout (seconds)
|
||||
MAX_QUERY_TIMEOUT = 5
|
||||
|
||||
# Maximum rows returned from queries
|
||||
MAX_QUERY_RESULTS = 1000
|
||||
|
||||
# Default rows per query
|
||||
DEFAULT_QUERY_LIMIT = 100
|
||||
|
||||
# Keywords that are blocked in database queries
|
||||
DANGEROUS_SQL_KEYWORDS: Set[str] = {
|
||||
"INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
|
||||
"ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH",
|
||||
"PRAGMA", "VACUUM", "REINDEX",
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# MCP CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
# Maximum tool call iterations per request
|
||||
MAX_TOOL_LOOPS = 5
|
||||
|
||||
# =============================================================================
|
||||
# VALID COMMANDS
|
||||
# =============================================================================
|
||||
|
||||
VALID_COMMANDS: Set[str] = {
|
||||
"/retry", "/online", "/memory", "/paste", "/export", "/save", "/load",
|
||||
"/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset",
|
||||
"/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear",
|
||||
"/cl", "/help", "/mcp",
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# COMMAND HELP DATABASE
|
||||
# =============================================================================
|
||||
|
||||
COMMAND_HELP: Dict[str, Dict[str, Any]] = {
|
||||
"/clear": {
|
||||
"aliases": ["/cl"],
|
||||
"description": "Clear the terminal screen for a clean interface.",
|
||||
"usage": "/clear\n/cl",
|
||||
"examples": [
|
||||
("Clear screen", "/clear"),
|
||||
("Using short alias", "/cl"),
|
||||
],
|
||||
"notes": "You can also use the keyboard shortcut Ctrl+L.",
|
||||
},
|
||||
"/help": {
|
||||
"description": "Display help information for commands.",
|
||||
"usage": "/help [command|topic]",
|
||||
"examples": [
|
||||
("Show all commands", "/help"),
|
||||
("Get help for a specific command", "/help /model"),
|
||||
("Get detailed MCP help", "/help mcp"),
|
||||
],
|
||||
"notes": "Use /help without arguments to see the full command list.",
|
||||
},
|
||||
"mcp": {
|
||||
"description": "Complete guide to MCP (Model Context Protocol).",
|
||||
"usage": "See detailed examples below",
|
||||
"examples": [],
|
||||
"notes": """
|
||||
MCP (Model Context Protocol) gives your AI assistant direct access to:
|
||||
• Local files and folders (read, search, list)
|
||||
• SQLite databases (inspect, search, query)
|
||||
|
||||
FILE MODE (default):
|
||||
/mcp on Start MCP server
|
||||
/mcp add ~/Documents Grant access to folder
|
||||
/mcp list View all allowed folders
|
||||
|
||||
DATABASE MODE:
|
||||
/mcp add db ~/app/data.db Add specific database
|
||||
/mcp db list View all databases
|
||||
/mcp db 1 Work with database #1
|
||||
/mcp files Switch back to file mode
|
||||
|
||||
WRITE MODE (optional):
|
||||
/mcp write on Enable file modifications
|
||||
/mcp write off Disable write mode (back to read-only)
|
||||
|
||||
For command-specific help: /help /mcp
|
||||
""",
|
||||
},
|
||||
"/mcp": {
|
||||
"description": "Manage MCP for local file access and SQLite database querying.",
|
||||
"usage": "/mcp <command> [args]",
|
||||
"examples": [
|
||||
("Enable MCP server", "/mcp on"),
|
||||
("Disable MCP server", "/mcp off"),
|
||||
("Show MCP status", "/mcp status"),
|
||||
("", ""),
|
||||
("━━━ FILE MODE ━━━", ""),
|
||||
("Add folder for file access", "/mcp add ~/Documents"),
|
||||
("Remove folder", "/mcp remove ~/Desktop"),
|
||||
("List allowed folders", "/mcp list"),
|
||||
("Enable write mode", "/mcp write on"),
|
||||
("", ""),
|
||||
("━━━ DATABASE MODE ━━━", ""),
|
||||
("Add SQLite database", "/mcp add db ~/app/data.db"),
|
||||
("List all databases", "/mcp db list"),
|
||||
("Switch to database #1", "/mcp db 1"),
|
||||
("Switch back to file mode", "/mcp files"),
|
||||
],
|
||||
"notes": "MCP allows AI to read local files and query SQLite databases.",
|
||||
},
|
||||
"/memory": {
|
||||
"description": "Toggle conversation memory.",
|
||||
"usage": "/memory [on|off]",
|
||||
"examples": [
|
||||
("Check current memory status", "/memory"),
|
||||
("Enable conversation memory", "/memory on"),
|
||||
("Disable memory (save costs)", "/memory off"),
|
||||
],
|
||||
"notes": "Memory is ON by default. Disabling saves tokens.",
|
||||
},
|
||||
"/online": {
|
||||
"description": "Enable or disable online mode (web search).",
|
||||
"usage": "/online [on|off]",
|
||||
"examples": [
|
||||
("Check online mode status", "/online"),
|
||||
("Enable web search", "/online on"),
|
||||
("Disable web search", "/online off"),
|
||||
],
|
||||
"notes": "Not all models support online mode.",
|
||||
},
|
||||
"/paste": {
|
||||
"description": "Paste plain text from clipboard and send to the AI.",
|
||||
"usage": "/paste [prompt]",
|
||||
"examples": [
|
||||
("Paste clipboard content", "/paste"),
|
||||
("Paste with a question", "/paste Explain this code"),
|
||||
],
|
||||
"notes": "Only plain text is supported.",
|
||||
},
|
||||
"/retry": {
|
||||
"description": "Resend the last prompt from conversation history.",
|
||||
"usage": "/retry",
|
||||
"examples": [("Retry last message", "/retry")],
|
||||
"notes": "Requires at least one message in history.",
|
||||
},
|
||||
"/next": {
|
||||
"description": "View the next response in conversation history.",
|
||||
"usage": "/next",
|
||||
"examples": [("Navigate to next response", "/next")],
|
||||
"notes": "Use /prev to go backward.",
|
||||
},
|
||||
"/prev": {
|
||||
"description": "View the previous response in conversation history.",
|
||||
"usage": "/prev",
|
||||
"examples": [("Navigate to previous response", "/prev")],
|
||||
"notes": "Use /next to go forward.",
|
||||
},
|
||||
"/reset": {
|
||||
"description": "Clear conversation history and reset system prompt.",
|
||||
"usage": "/reset",
|
||||
"examples": [("Reset conversation", "/reset")],
|
||||
"notes": "Requires confirmation.",
|
||||
},
|
||||
"/info": {
|
||||
"description": "Display detailed information about a model.",
|
||||
"usage": "/info [model_id]",
|
||||
"examples": [
|
||||
("Show current model info", "/info"),
|
||||
("Show specific model info", "/info gpt-4o"),
|
||||
],
|
||||
"notes": "Shows pricing, capabilities, and context length.",
|
||||
},
|
||||
"/model": {
|
||||
"description": "Select or change the AI model.",
|
||||
"usage": "/model [search_term]",
|
||||
"examples": [
|
||||
("List all models", "/model"),
|
||||
("Search for GPT models", "/model gpt"),
|
||||
("Search for Claude models", "/model claude"),
|
||||
],
|
||||
"notes": "Models are numbered for easy selection.",
|
||||
},
|
||||
"/config": {
|
||||
"description": "View or modify application configuration.",
|
||||
"usage": "/config [setting] [value]",
|
||||
"examples": [
|
||||
("View all settings", "/config"),
|
||||
("Set API key", "/config api"),
|
||||
("Set default model", "/config model"),
|
||||
("Set system prompt", "/config system You are a helpful assistant"),
|
||||
("Enable streaming", "/config stream on"),
|
||||
],
|
||||
"notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.",
|
||||
},
|
||||
"/maxtoken": {
|
||||
"description": "Set a temporary session token limit.",
|
||||
"usage": "/maxtoken [value]",
|
||||
"examples": [
|
||||
("View current session limit", "/maxtoken"),
|
||||
("Set session limit to 2000", "/maxtoken 2000"),
|
||||
],
|
||||
"notes": "Cannot exceed stored max token limit.",
|
||||
},
|
||||
"/system": {
|
||||
"description": "Set or clear the session-level system prompt.",
|
||||
"usage": "/system [prompt|clear|default <prompt>]",
|
||||
"examples": [
|
||||
("View current system prompt", "/system"),
|
||||
("Set as Python expert", "/system You are a Python expert"),
|
||||
("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."),
|
||||
("Save as default", "/system default You are a helpful assistant"),
|
||||
("Revert to default", "/system clear"),
|
||||
("Blank prompt", '/system ""'),
|
||||
],
|
||||
"notes": r"Use \n for newlines. /system clear reverts to hardcoded default.",
|
||||
},
|
||||
"/save": {
|
||||
"description": "Save the current conversation history.",
|
||||
"usage": "/save <name>",
|
||||
"examples": [("Save conversation", "/save my_chat")],
|
||||
"notes": "Saved conversations can be loaded later with /load.",
|
||||
},
|
||||
"/load": {
|
||||
"description": "Load a saved conversation.",
|
||||
"usage": "/load <name|number>",
|
||||
"examples": [
|
||||
("Load by name", "/load my_chat"),
|
||||
("Load by number from /list", "/load 3"),
|
||||
],
|
||||
"notes": "Use /list to see numbered conversations.",
|
||||
},
|
||||
"/delete": {
|
||||
"description": "Delete a saved conversation.",
|
||||
"usage": "/delete <name|number>",
|
||||
"examples": [("Delete by name", "/delete my_chat")],
|
||||
"notes": "Requires confirmation. Cannot be undone.",
|
||||
},
|
||||
"/list": {
|
||||
"description": "List all saved conversations.",
|
||||
"usage": "/list",
|
||||
"examples": [("Show saved conversations", "/list")],
|
||||
"notes": "Conversations are numbered for use with /load and /delete.",
|
||||
},
|
||||
"/export": {
|
||||
"description": "Export the current conversation to a file.",
|
||||
"usage": "/export <format> <filename>",
|
||||
"examples": [
|
||||
("Export as Markdown", "/export md notes.md"),
|
||||
("Export as JSON", "/export json conversation.json"),
|
||||
("Export as HTML", "/export html report.html"),
|
||||
],
|
||||
"notes": "Available formats: md, json, html.",
|
||||
},
|
||||
"/stats": {
|
||||
"description": "Display session statistics.",
|
||||
"usage": "/stats",
|
||||
"examples": [("View session statistics", "/stats")],
|
||||
"notes": "Shows tokens, costs, and credits.",
|
||||
},
|
||||
"/credits": {
|
||||
"description": "Display your OpenRouter account credits.",
|
||||
"usage": "/credits",
|
||||
"examples": [("Check credits", "/credits")],
|
||||
"notes": "Shows total, used, and remaining credits.",
|
||||
},
|
||||
"/middleout": {
|
||||
"description": "Toggle middle-out transform for long prompts.",
|
||||
"usage": "/middleout [on|off]",
|
||||
"examples": [
|
||||
("Check status", "/middleout"),
|
||||
("Enable compression", "/middleout on"),
|
||||
],
|
||||
"notes": "Compresses prompts exceeding context size.",
|
||||
},
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
"""
|
||||
Core functionality for oAI.
|
||||
|
||||
This module provides the main session management and AI client
|
||||
classes that power the chat application.
|
||||
"""
|
||||
|
||||
from oai.core.session import ChatSession
|
||||
from oai.core.client import AIClient
|
||||
|
||||
__all__ = [
|
||||
"ChatSession",
|
||||
"AIClient",
|
||||
]
|
||||
@@ -1,422 +0,0 @@
|
||||
"""
|
||||
AI Client for oAI.
|
||||
|
||||
This module provides a high-level client for interacting with AI models
|
||||
through the provider abstraction layer.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
||||
|
||||
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
|
||||
from oai.providers.base import (
|
||||
AIProvider,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ModelInfo,
|
||||
StreamChunk,
|
||||
ToolCall,
|
||||
UsageStats,
|
||||
)
|
||||
from oai.providers.openrouter import OpenRouterProvider
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
|
||||
class AIClient:
|
||||
"""
|
||||
High-level AI client for chat interactions.
|
||||
|
||||
Provides a simplified interface for sending chat requests,
|
||||
handling streaming, and managing tool calls.
|
||||
|
||||
Attributes:
|
||||
provider: The underlying AI provider
|
||||
default_model: Default model ID to use
|
||||
http_headers: Custom HTTP headers for requests
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
provider_class: type = OpenRouterProvider,
|
||||
app_name: str = APP_NAME,
|
||||
app_url: str = APP_URL,
|
||||
):
|
||||
"""
|
||||
Initialize the AI client.
|
||||
|
||||
Args:
|
||||
api_key: API key for authentication
|
||||
base_url: Optional custom base URL
|
||||
provider_class: Provider class to use (default: OpenRouterProvider)
|
||||
app_name: Application name for headers
|
||||
app_url: Application URL for headers
|
||||
"""
|
||||
self.provider: AIProvider = provider_class(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
app_name=app_name,
|
||||
app_url=app_url,
|
||||
)
|
||||
self.default_model: Optional[str] = None
|
||||
self.logger = get_logger()
|
||||
|
||||
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
||||
"""
|
||||
Get available models.
|
||||
|
||||
Args:
|
||||
filter_text_only: Whether to exclude video-only models
|
||||
|
||||
Returns:
|
||||
List of ModelInfo objects
|
||||
"""
|
||||
return self.provider.list_models(filter_text_only=filter_text_only)
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
||||
"""
|
||||
Get information about a specific model.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
|
||||
Returns:
|
||||
ModelInfo or None if not found
|
||||
"""
|
||||
return self.provider.get_model(model_id)
|
||||
|
||||
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get raw model data for provider-specific fields.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
|
||||
Returns:
|
||||
Raw model dictionary or None
|
||||
"""
|
||||
if hasattr(self.provider, "get_raw_model"):
|
||||
return self.provider.get_raw_model(model_id)
|
||||
return None
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
online: bool = False,
|
||||
transforms: Optional[List[str]] = None,
|
||||
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
||||
"""
|
||||
Send a chat request.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
model: Model ID (uses default if not specified)
|
||||
stream: Whether to stream the response
|
||||
max_tokens: Maximum tokens in response
|
||||
temperature: Sampling temperature
|
||||
tools: Tool definitions for function calling
|
||||
tool_choice: Tool selection mode
|
||||
system_prompt: System prompt to prepend
|
||||
online: Whether to enable online mode
|
||||
transforms: List of transforms (e.g., ["middle-out"])
|
||||
|
||||
Returns:
|
||||
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
|
||||
|
||||
Raises:
|
||||
ValueError: If no model specified and no default set
|
||||
"""
|
||||
model_id = model or self.default_model
|
||||
if not model_id:
|
||||
raise ValueError("No model specified and no default set")
|
||||
|
||||
# Apply online mode suffix
|
||||
if online and hasattr(self.provider, "get_effective_model_id"):
|
||||
model_id = self.provider.get_effective_model_id(model_id, True)
|
||||
|
||||
# Convert dict messages to ChatMessage objects
|
||||
chat_messages = []
|
||||
|
||||
# Add system prompt if provided
|
||||
if system_prompt:
|
||||
chat_messages.append(ChatMessage(role="system", content=system_prompt))
|
||||
|
||||
# Convert message dicts
|
||||
for msg in messages:
|
||||
# Convert tool_calls dicts to ToolCall objects if present
|
||||
tool_calls_data = msg.get("tool_calls")
|
||||
tool_calls_obj = None
|
||||
if tool_calls_data:
|
||||
from oai.providers.base import ToolCall, ToolFunction
|
||||
tool_calls_obj = []
|
||||
for tc in tool_calls_data:
|
||||
# Handle both ToolCall objects and dicts
|
||||
if isinstance(tc, ToolCall):
|
||||
tool_calls_obj.append(tc)
|
||||
elif isinstance(tc, dict):
|
||||
func_data = tc.get("function", {})
|
||||
tool_calls_obj.append(
|
||||
ToolCall(
|
||||
id=tc.get("id", ""),
|
||||
type=tc.get("type", "function"),
|
||||
function=ToolFunction(
|
||||
name=func_data.get("name", ""),
|
||||
arguments=func_data.get("arguments", "{}"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
chat_messages.append(
|
||||
ChatMessage(
|
||||
role=msg.get("role", "user"),
|
||||
content=msg.get("content"),
|
||||
tool_calls=tool_calls_obj,
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
)
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Sending chat request: model={model_id}, "
|
||||
f"messages={len(chat_messages)}, stream={stream}"
|
||||
)
|
||||
|
||||
return self.provider.chat(
|
||||
model=model_id,
|
||||
messages=chat_messages,
|
||||
stream=stream,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
transforms=transforms,
|
||||
)
|
||||
|
||||
def chat_with_tools(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: List[Dict[str, Any]],
|
||||
tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]],
|
||||
model: Optional[str] = None,
|
||||
max_loops: int = 5,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
on_tool_call: Optional[Callable[[ToolCall], None]] = None,
|
||||
on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
||||
) -> ChatResponse:
|
||||
"""
|
||||
Send a chat request with automatic tool call handling.
|
||||
|
||||
Executes tool calls returned by the model and continues
|
||||
the conversation until no more tool calls are requested.
|
||||
|
||||
Args:
|
||||
messages: Initial messages
|
||||
tools: Tool definitions
|
||||
tool_executor: Function to execute tool calls
|
||||
model: Model ID
|
||||
max_loops: Maximum tool call iterations
|
||||
max_tokens: Maximum response tokens
|
||||
system_prompt: System prompt
|
||||
on_tool_call: Callback when tool is called
|
||||
on_tool_result: Callback when tool returns result
|
||||
|
||||
Returns:
|
||||
Final ChatResponse after all tool calls complete
|
||||
"""
|
||||
model_id = model or self.default_model
|
||||
if not model_id:
|
||||
raise ValueError("No model specified and no default set")
|
||||
|
||||
# Build initial messages
|
||||
chat_messages = []
|
||||
if system_prompt:
|
||||
chat_messages.append({"role": "system", "content": system_prompt})
|
||||
chat_messages.extend(messages)
|
||||
|
||||
loop_count = 0
|
||||
current_response: Optional[ChatResponse] = None
|
||||
|
||||
while loop_count < max_loops:
|
||||
# Send request
|
||||
response = self.chat(
|
||||
messages=chat_messages,
|
||||
model=model_id,
|
||||
stream=False,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
if not isinstance(response, ChatResponse):
|
||||
raise ValueError("Expected non-streaming response")
|
||||
|
||||
current_response = response
|
||||
|
||||
# Check for tool calls
|
||||
tool_calls = response.tool_calls
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
self.logger.info(f"Model requested {len(tool_calls)} tool call(s)")
|
||||
|
||||
# Process each tool call
|
||||
tool_results = []
|
||||
for tc in tool_calls:
|
||||
if on_tool_call:
|
||||
on_tool_call(tc)
|
||||
|
||||
try:
|
||||
args = json.loads(tc.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.error(f"Failed to parse tool arguments: {e}")
|
||||
result = {"error": f"Invalid arguments: {e}"}
|
||||
else:
|
||||
result = tool_executor(tc.function.name, args)
|
||||
|
||||
if on_tool_result:
|
||||
on_tool_result(tc.function.name, result)
|
||||
|
||||
tool_results.append({
|
||||
"tool_call_id": tc.id,
|
||||
"role": "tool",
|
||||
"name": tc.function.name,
|
||||
"content": json.dumps(result),
|
||||
})
|
||||
|
||||
# Add assistant message with tool calls
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": response.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in tool_calls
|
||||
],
|
||||
}
|
||||
chat_messages.append(assistant_msg)
|
||||
chat_messages.extend(tool_results)
|
||||
|
||||
loop_count += 1
|
||||
|
||||
if loop_count >= max_loops:
|
||||
self.logger.warning(f"Reached max tool call loops ({max_loops})")
|
||||
|
||||
return current_response
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
online: bool = False,
|
||||
on_chunk: Optional[Callable[[StreamChunk], None]] = None,
|
||||
) -> tuple[str, Optional[UsageStats]]:
|
||||
"""
|
||||
Stream a chat response and collect the full text.
|
||||
|
||||
Args:
|
||||
messages: Chat messages
|
||||
model: Model ID
|
||||
max_tokens: Maximum tokens
|
||||
system_prompt: System prompt
|
||||
online: Online mode
|
||||
on_chunk: Optional callback for each chunk
|
||||
|
||||
Returns:
|
||||
Tuple of (full_response_text, usage_stats)
|
||||
"""
|
||||
response = self.chat(
|
||||
messages=messages,
|
||||
model=model,
|
||||
stream=True,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
online=online,
|
||||
)
|
||||
|
||||
if isinstance(response, ChatResponse):
|
||||
# Not actually streaming
|
||||
return response.content or "", response.usage
|
||||
|
||||
full_text = ""
|
||||
usage: Optional[UsageStats] = None
|
||||
|
||||
for chunk in response:
|
||||
if chunk.error:
|
||||
self.logger.error(f"Stream error: {chunk.error}")
|
||||
break
|
||||
|
||||
if chunk.delta_content:
|
||||
full_text += chunk.delta_content
|
||||
if on_chunk:
|
||||
on_chunk(chunk)
|
||||
|
||||
if chunk.usage:
|
||||
usage = chunk.usage
|
||||
|
||||
return full_text, usage
|
||||
|
||||
def get_credits(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get account credit information.
|
||||
|
||||
Returns:
|
||||
Credit info dict or None if unavailable
|
||||
"""
|
||||
return self.provider.get_credits()
|
||||
|
||||
def estimate_cost(
|
||||
self,
|
||||
model_id: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
) -> float:
|
||||
"""
|
||||
Estimate cost for a completion.
|
||||
|
||||
Args:
|
||||
model_id: Model ID
|
||||
input_tokens: Number of input tokens
|
||||
output_tokens: Number of output tokens
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD
|
||||
"""
|
||||
if hasattr(self.provider, "estimate_cost"):
|
||||
return self.provider.estimate_cost(model_id, input_tokens, output_tokens)
|
||||
|
||||
# Fallback to default pricing
|
||||
input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
|
||||
output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
|
||||
return input_cost + output_cost
|
||||
|
||||
def set_default_model(self, model_id: str) -> None:
|
||||
"""
|
||||
Set the default model.
|
||||
|
||||
Args:
|
||||
model_id: Model ID to use as default
|
||||
"""
|
||||
self.default_model = model_id
|
||||
self.logger.info(f"Default model set to: {model_id}")
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the provider's model cache."""
|
||||
if hasattr(self.provider, "clear_cache"):
|
||||
self.provider.clear_cache()
|
||||
@@ -1,891 +0,0 @@
|
||||
"""
|
||||
Chat session management for oAI.
|
||||
|
||||
This module provides the ChatSession class that manages an interactive
|
||||
chat session including history, state, and message handling.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
from oai.commands.registry import CommandContext, CommandResult, registry
|
||||
from oai.config.database import Database
|
||||
from oai.config.settings import Settings
|
||||
from oai.constants import (
|
||||
COST_WARNING_THRESHOLD,
|
||||
LOW_CREDIT_AMOUNT,
|
||||
LOW_CREDIT_RATIO,
|
||||
)
|
||||
from oai.core.client import AIClient
|
||||
from oai.mcp.manager import MCPManager
|
||||
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStats:
|
||||
"""
|
||||
Statistics for the current session.
|
||||
|
||||
Tracks tokens, costs, and message counts.
|
||||
"""
|
||||
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_cost: float = 0.0
|
||||
message_count: int = 0
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get total token count."""
|
||||
return self.total_input_tokens + self.total_output_tokens
|
||||
|
||||
def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None:
|
||||
"""
|
||||
Add usage stats from a response.
|
||||
|
||||
Args:
|
||||
usage: Usage statistics
|
||||
cost: Cost if not in usage
|
||||
"""
|
||||
if usage:
|
||||
self.total_input_tokens += usage.prompt_tokens
|
||||
self.total_output_tokens += usage.completion_tokens
|
||||
if usage.total_cost_usd:
|
||||
self.total_cost += usage.total_cost_usd
|
||||
else:
|
||||
self.total_cost += cost
|
||||
else:
|
||||
self.total_cost += cost
|
||||
self.message_count += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistoryEntry:
|
||||
"""
|
||||
A single entry in the conversation history.
|
||||
|
||||
Stores the user prompt, assistant response, and metrics.
|
||||
"""
|
||||
|
||||
prompt: str
|
||||
response: str
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
msg_cost: float = 0.0
|
||||
timestamp: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"prompt": self.prompt,
|
||||
"response": self.response,
|
||||
"prompt_tokens": self.prompt_tokens,
|
||||
"completion_tokens": self.completion_tokens,
|
||||
"msg_cost": self.msg_cost,
|
||||
}
|
||||
|
||||
|
||||
class ChatSession:
|
||||
"""
|
||||
Manages an interactive chat session.
|
||||
|
||||
Handles conversation history, state management, command processing,
|
||||
and communication with the AI client.
|
||||
|
||||
Attributes:
|
||||
client: AI client for API requests
|
||||
settings: Application settings
|
||||
mcp_manager: MCP manager for file/database access
|
||||
history: Conversation history
|
||||
stats: Session statistics
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AIClient,
|
||||
settings: Settings,
|
||||
mcp_manager: Optional[MCPManager] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a chat session.
|
||||
|
||||
Args:
|
||||
client: AI client instance
|
||||
settings: Application settings
|
||||
mcp_manager: Optional MCP manager
|
||||
"""
|
||||
self.client = client
|
||||
self.settings = settings
|
||||
self.mcp_manager = mcp_manager
|
||||
self.db = Database()
|
||||
|
||||
self.history: List[HistoryEntry] = []
|
||||
self.stats = SessionStats()
|
||||
|
||||
# Session state
|
||||
self.system_prompt: str = settings.effective_system_prompt
|
||||
self.memory_enabled: bool = True
|
||||
self.memory_start_index: int = 0
|
||||
self.online_enabled: bool = settings.default_online_mode
|
||||
self.middle_out_enabled: bool = False
|
||||
self.session_max_token: int = 0
|
||||
self.current_index: int = 0
|
||||
|
||||
# Selected model
|
||||
self.selected_model: Optional[Dict[str, Any]] = None
|
||||
|
||||
self.logger = get_logger()
|
||||
|
||||
def get_context(self) -> CommandContext:
|
||||
"""
|
||||
Get the current command context.
|
||||
|
||||
Returns:
|
||||
CommandContext with current session state
|
||||
"""
|
||||
return CommandContext(
|
||||
settings=self.settings,
|
||||
provider=self.client.provider,
|
||||
mcp_manager=self.mcp_manager,
|
||||
selected_model_raw=self.selected_model,
|
||||
session_history=[e.to_dict() for e in self.history],
|
||||
session_system_prompt=self.system_prompt,
|
||||
memory_enabled=self.memory_enabled,
|
||||
memory_start_index=self.memory_start_index,
|
||||
online_enabled=self.online_enabled,
|
||||
middle_out_enabled=self.middle_out_enabled,
|
||||
session_max_token=self.session_max_token,
|
||||
total_input_tokens=self.stats.total_input_tokens,
|
||||
total_output_tokens=self.stats.total_output_tokens,
|
||||
total_cost=self.stats.total_cost,
|
||||
message_count=self.stats.message_count,
|
||||
current_index=self.current_index,
|
||||
)
|
||||
|
||||
def set_model(self, model: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Set the selected model.
|
||||
|
||||
Args:
|
||||
model: Raw model dictionary
|
||||
"""
|
||||
self.selected_model = model
|
||||
self.client.set_default_model(model["id"])
|
||||
self.logger.info(f"Model selected: {model['id']}")
|
||||
|
||||
def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Build the messages array for an API request.
|
||||
|
||||
Includes system prompt, history (if memory enabled), and current input.
|
||||
|
||||
Args:
|
||||
user_input: Current user input
|
||||
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# Add system prompt
|
||||
if self.system_prompt:
|
||||
messages.append({"role": "system", "content": self.system_prompt})
|
||||
|
||||
# Add database context if in database mode
|
||||
if self.mcp_manager and self.mcp_manager.enabled:
|
||||
if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None:
|
||||
db = self.mcp_manager.databases[self.mcp_manager.selected_db_index]
|
||||
db_context = (
|
||||
f"You are connected to SQLite database: {db['name']}\n"
|
||||
f"Available tables: {', '.join(db['tables'])}\n\n"
|
||||
"Use inspect_database, search_database, or query_database tools. "
|
||||
"All queries are read-only."
|
||||
)
|
||||
messages.append({"role": "system", "content": db_context})
|
||||
|
||||
# Add history if memory enabled
|
||||
if self.memory_enabled:
|
||||
for i in range(self.memory_start_index, len(self.history)):
|
||||
entry = self.history[i]
|
||||
messages.append({"role": "user", "content": entry.prompt})
|
||||
messages.append({"role": "assistant", "content": entry.response})
|
||||
|
||||
# Add current message
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
return messages
|
||||
|
||||
def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Get MCP tool definitions if available.
|
||||
|
||||
Returns:
|
||||
List of tool schemas or None
|
||||
"""
|
||||
if not self.mcp_manager or not self.mcp_manager.enabled:
|
||||
return None
|
||||
|
||||
if not self.selected_model:
|
||||
return None
|
||||
|
||||
# Check if model supports tools
|
||||
supported_params = self.selected_model.get("supported_parameters", [])
|
||||
if "tools" not in supported_params and "functions" not in supported_params:
|
||||
return None
|
||||
|
||||
return self.mcp_manager.get_tools_schema()
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute an MCP tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
tool_args: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self.mcp_manager:
|
||||
return {"error": "MCP not available"}
|
||||
|
||||
return await self.mcp_manager.call_tool(tool_name, **tool_args)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
user_input: str,
|
||||
stream: bool = True,
|
||||
on_stream_chunk: Optional[Callable[[str], None]] = None,
|
||||
) -> Tuple[str, Optional[UsageStats], float]:
|
||||
"""
|
||||
Send a message and get a response.
|
||||
|
||||
Args:
|
||||
user_input: User's input text
|
||||
stream: Whether to stream the response
|
||||
on_stream_chunk: Callback for stream chunks
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, usage_stats, response_time)
|
||||
"""
|
||||
if not self.selected_model:
|
||||
raise ValueError("No model selected")
|
||||
|
||||
start_time = time.time()
|
||||
messages = self.build_api_messages(user_input)
|
||||
|
||||
# Get MCP tools
|
||||
tools = self.get_mcp_tools()
|
||||
if tools:
|
||||
# Disable streaming when tools are present
|
||||
stream = False
|
||||
|
||||
# Build request parameters
|
||||
model_id = self.selected_model["id"]
|
||||
if self.online_enabled:
|
||||
if hasattr(self.client.provider, "get_effective_model_id"):
|
||||
model_id = self.client.provider.get_effective_model_id(model_id, True)
|
||||
|
||||
transforms = ["middle-out"] if self.middle_out_enabled else None
|
||||
|
||||
max_tokens = None
|
||||
if self.session_max_token > 0:
|
||||
max_tokens = self.session_max_token
|
||||
|
||||
if tools:
|
||||
# Use tool handling flow
|
||||
response = self._send_with_tools(
|
||||
messages=messages,
|
||||
model_id=model_id,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
transforms=transforms,
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
return response.content or "", response.usage, response_time
|
||||
|
||||
elif stream:
|
||||
# Use streaming flow
|
||||
full_text, usage = self._stream_response(
|
||||
messages=messages,
|
||||
model_id=model_id,
|
||||
max_tokens=max_tokens,
|
||||
transforms=transforms,
|
||||
on_chunk=on_stream_chunk,
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
return full_text, usage, response_time
|
||||
|
||||
else:
|
||||
# Non-streaming request
|
||||
response = self.client.chat(
|
||||
messages=messages,
|
||||
model=model_id,
|
||||
stream=False,
|
||||
max_tokens=max_tokens,
|
||||
transforms=transforms,
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
if isinstance(response, ChatResponse):
|
||||
return response.content or "", response.usage, response_time
|
||||
else:
|
||||
return "", None, response_time
|
||||
|
||||
def _send_with_tools(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
max_tokens: Optional[int] = None,
|
||||
transforms: Optional[List[str]] = None,
|
||||
) -> ChatResponse:
|
||||
"""
|
||||
Send a request with tool call handling.
|
||||
|
||||
Args:
|
||||
messages: API messages
|
||||
model_id: Model ID
|
||||
tools: Tool definitions
|
||||
max_tokens: Max tokens
|
||||
transforms: Transforms list
|
||||
|
||||
Returns:
|
||||
Final ChatResponse
|
||||
"""
|
||||
max_loops = 5
|
||||
loop_count = 0
|
||||
api_messages = list(messages)
|
||||
|
||||
while loop_count < max_loops:
|
||||
response = self.client.chat(
|
||||
messages=api_messages,
|
||||
model=model_id,
|
||||
stream=False,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
transforms=transforms,
|
||||
)
|
||||
|
||||
if not isinstance(response, ChatResponse):
|
||||
raise ValueError("Expected ChatResponse")
|
||||
|
||||
tool_calls = response.tool_calls
|
||||
if not tool_calls:
|
||||
return response
|
||||
|
||||
# Tool calls requested by AI
|
||||
|
||||
tool_results = []
|
||||
for tc in tool_calls:
|
||||
try:
|
||||
args = json.loads(tc.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.error(f"Failed to parse tool arguments: {e}")
|
||||
tool_results.append({
|
||||
"tool_call_id": tc.id,
|
||||
"role": "tool",
|
||||
"name": tc.function.name,
|
||||
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
|
||||
})
|
||||
continue
|
||||
|
||||
# Display tool call
|
||||
args_display = ", ".join(
|
||||
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
|
||||
for k, v in args.items()
|
||||
)
|
||||
# Executing tool: {tc.function.name}
|
||||
|
||||
# Execute tool
|
||||
result = asyncio.run(self.execute_tool(tc.function.name, args))
|
||||
|
||||
if "error" in result:
|
||||
# Tool execution error logged
|
||||
pass
|
||||
else:
|
||||
# Tool execution successful
|
||||
pass
|
||||
|
||||
tool_results.append({
|
||||
"tool_call_id": tc.id,
|
||||
"role": "tool",
|
||||
"name": tc.function.name,
|
||||
"content": json.dumps(result),
|
||||
})
|
||||
|
||||
# Add assistant message with tool calls
|
||||
api_messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in tool_calls
|
||||
],
|
||||
})
|
||||
api_messages.extend(tool_results)
|
||||
|
||||
# Processing tool results
|
||||
loop_count += 1
|
||||
|
||||
self.logger.warning(f"Reached max tool loops ({max_loops})")
|
||||
return response
|
||||
|
||||
|
||||
def _stream_response(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model_id: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
transforms: Optional[List[str]] = None,
|
||||
on_chunk: Optional[Callable[[str], None]] = None,
|
||||
) -> Tuple[str, Optional[UsageStats]]:
|
||||
"""
|
||||
Stream a response with live display.
|
||||
|
||||
Args:
|
||||
messages: API messages
|
||||
model_id: Model ID
|
||||
max_tokens: Max tokens
|
||||
transforms: Transforms
|
||||
on_chunk: Callback for chunks
|
||||
|
||||
Returns:
|
||||
Tuple of (full_text, usage)
|
||||
"""
|
||||
response = self.client.chat(
|
||||
messages=messages,
|
||||
model=model_id,
|
||||
stream=True,
|
||||
max_tokens=max_tokens,
|
||||
transforms=transforms,
|
||||
)
|
||||
|
||||
if isinstance(response, ChatResponse):
|
||||
return response.content or "", response.usage
|
||||
|
||||
full_text = ""
|
||||
usage: Optional[UsageStats] = None
|
||||
|
||||
try:
|
||||
for chunk in response:
|
||||
if chunk.error:
|
||||
self.logger.error(f"Stream error: {chunk.error}")
|
||||
break
|
||||
|
||||
if chunk.delta_content:
|
||||
full_text += chunk.delta_content
|
||||
if on_chunk:
|
||||
on_chunk(chunk.delta_content)
|
||||
|
||||
if chunk.usage:
|
||||
usage = chunk.usage
|
||||
|
||||
except KeyboardInterrupt:
|
||||
self.logger.info("Streaming interrupted")
|
||||
return "", None
|
||||
|
||||
return full_text, usage
|
||||
|
||||
# ========== ASYNC METHODS FOR TUI ==========
|
||||
|
||||
async def send_message_async(
|
||||
self,
|
||||
user_input: str,
|
||||
stream: bool = True,
|
||||
) -> AsyncIterator[StreamChunk]:
|
||||
"""
|
||||
Async version of send_message for Textual TUI.
|
||||
|
||||
Args:
|
||||
user_input: User's input text
|
||||
stream: Whether to stream the response
|
||||
|
||||
Yields:
|
||||
StreamChunk objects for progressive display
|
||||
"""
|
||||
if not self.selected_model:
|
||||
raise ValueError("No model selected")
|
||||
|
||||
messages = self.build_api_messages(user_input)
|
||||
tools = self.get_mcp_tools()
|
||||
|
||||
if tools:
|
||||
# Disable streaming when tools are present
|
||||
stream = False
|
||||
|
||||
model_id = self.selected_model["id"]
|
||||
if self.online_enabled:
|
||||
if hasattr(self.client.provider, "get_effective_model_id"):
|
||||
model_id = self.client.provider.get_effective_model_id(model_id, True)
|
||||
|
||||
transforms = ["middle-out"] if self.middle_out_enabled else None
|
||||
max_tokens = None
|
||||
if self.session_max_token > 0:
|
||||
max_tokens = self.session_max_token
|
||||
|
||||
if tools:
|
||||
# Use async tool handling flow
|
||||
async for chunk in self._send_with_tools_async(
|
||||
messages=messages,
|
||||
model_id=model_id,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
transforms=transforms,
|
||||
):
|
||||
yield chunk
|
||||
elif stream:
|
||||
# Use async streaming flow
|
||||
async for chunk in self._stream_response_async(
|
||||
messages=messages,
|
||||
model_id=model_id,
|
||||
max_tokens=max_tokens,
|
||||
transforms=transforms,
|
||||
):
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming request
|
||||
response = self.client.chat(
|
||||
messages=messages,
|
||||
model=model_id,
|
||||
stream=False,
|
||||
max_tokens=max_tokens,
|
||||
transforms=transforms,
|
||||
)
|
||||
if isinstance(response, ChatResponse):
|
||||
# Yield single chunk with complete response
|
||||
chunk = StreamChunk(
|
||||
id="",
|
||||
delta_content=response.content,
|
||||
usage=response.usage,
|
||||
error=None,
|
||||
)
|
||||
yield chunk
|
||||
|
||||
async def _send_with_tools_async(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
max_tokens: Optional[int] = None,
|
||||
transforms: Optional[List[str]] = None,
|
||||
) -> AsyncIterator[StreamChunk]:
|
||||
"""
|
||||
Async version of _send_with_tools for TUI.
|
||||
|
||||
Args:
|
||||
messages: API messages
|
||||
model_id: Model ID
|
||||
tools: Tool definitions
|
||||
max_tokens: Max tokens
|
||||
transforms: Transforms list
|
||||
|
||||
Yields:
|
||||
StreamChunk objects including tool call notifications
|
||||
"""
|
||||
max_loops = 5
|
||||
loop_count = 0
|
||||
api_messages = list(messages)
|
||||
|
||||
while loop_count < max_loops:
|
||||
response = self.client.chat(
|
||||
messages=api_messages,
|
||||
model=model_id,
|
||||
stream=False,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
transforms=transforms,
|
||||
)
|
||||
|
||||
if not isinstance(response, ChatResponse):
|
||||
raise ValueError("Expected ChatResponse")
|
||||
|
||||
tool_calls = response.tool_calls
|
||||
if not tool_calls:
|
||||
# Final response, yield it
|
||||
chunk = StreamChunk(
|
||||
id="",
|
||||
delta_content=response.content,
|
||||
usage=response.usage,
|
||||
error=None,
|
||||
)
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Yield notification about tool calls
|
||||
tool_notification = f"\n🔧 AI requesting {len(tool_calls)} tool call(s)...\n"
|
||||
yield StreamChunk(id="", delta_content=tool_notification, usage=None, error=None)
|
||||
|
||||
tool_results = []
|
||||
for tc in tool_calls:
|
||||
try:
|
||||
args = json.loads(tc.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.error(f"Failed to parse tool arguments: {e}")
|
||||
tool_results.append({
|
||||
"tool_call_id": tc.id,
|
||||
"role": "tool",
|
||||
"name": tc.function.name,
|
||||
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
|
||||
})
|
||||
continue
|
||||
|
||||
# Yield tool call display
|
||||
args_display = ", ".join(
|
||||
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
|
||||
for k, v in args.items()
|
||||
)
|
||||
tool_display = f" → {tc.function.name}({args_display})\n"
|
||||
yield StreamChunk(id="", delta_content=tool_display, usage=None, error=None)
|
||||
|
||||
# Execute tool (await instead of asyncio.run)
|
||||
result = await self.execute_tool(tc.function.name, args)
|
||||
|
||||
if "error" in result:
|
||||
error_msg = f" ✗ Error: {result['error']}\n"
|
||||
yield StreamChunk(id="", delta_content=error_msg, usage=None, error=None)
|
||||
else:
|
||||
success_msg = self._format_tool_success(tc.function.name, result)
|
||||
yield StreamChunk(id="", delta_content=success_msg, usage=None, error=None)
|
||||
|
||||
tool_results.append({
|
||||
"tool_call_id": tc.id,
|
||||
"role": "tool",
|
||||
"name": tc.function.name,
|
||||
"content": json.dumps(result),
|
||||
})
|
||||
|
||||
# Add assistant message with tool calls
|
||||
api_messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in tool_calls
|
||||
],
|
||||
})
|
||||
|
||||
# Add tool results
|
||||
api_messages.extend(tool_results)
|
||||
loop_count += 1
|
||||
|
||||
# Max loops reached
|
||||
yield StreamChunk(
|
||||
id="",
|
||||
delta_content="\n⚠️ Maximum tool call loops reached\n",
|
||||
usage=None,
|
||||
error="Max loops reached"
|
||||
)
|
||||
|
||||
def _format_tool_success(self, tool_name: str, result: Dict[str, Any]) -> str:
|
||||
"""Format a success message for a tool call."""
|
||||
if tool_name == "search_files":
|
||||
count = result.get("count", 0)
|
||||
return f" ✓ Found {count} file(s)\n"
|
||||
elif tool_name == "read_file":
|
||||
size = result.get("size", 0)
|
||||
truncated = " (truncated)" if result.get("truncated") else ""
|
||||
return f" ✓ Read {size} bytes{truncated}\n"
|
||||
elif tool_name == "list_directory":
|
||||
count = result.get("count", 0)
|
||||
return f" ✓ Listed {count} item(s)\n"
|
||||
elif tool_name == "inspect_database":
|
||||
if "table" in result:
|
||||
return f" ✓ Inspected table: {result['table']}\n"
|
||||
else:
|
||||
return f" ✓ Inspected database ({result.get('table_count', 0)} tables)\n"
|
||||
elif tool_name == "search_database":
|
||||
count = result.get("count", 0)
|
||||
return f" ✓ Found {count} match(es)\n"
|
||||
elif tool_name == "query_database":
|
||||
count = result.get("count", 0)
|
||||
return f" ✓ Query returned {count} row(s)\n"
|
||||
else:
|
||||
return " ✓ Success\n"
|
||||
|
||||
async def _stream_response_async(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model_id: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
transforms: Optional[List[str]] = None,
|
||||
) -> AsyncIterator[StreamChunk]:
|
||||
"""
|
||||
Async version of _stream_response for TUI.
|
||||
|
||||
Args:
|
||||
messages: API messages
|
||||
model_id: Model ID
|
||||
max_tokens: Max tokens
|
||||
transforms: Transforms
|
||||
|
||||
Yields:
|
||||
StreamChunk objects
|
||||
"""
|
||||
response = self.client.chat(
|
||||
messages=messages,
|
||||
model=model_id,
|
||||
stream=True,
|
||||
max_tokens=max_tokens,
|
||||
transforms=transforms,
|
||||
)
|
||||
|
||||
if isinstance(response, ChatResponse):
|
||||
# Non-streaming response
|
||||
chunk = StreamChunk(
|
||||
id="",
|
||||
delta_content=response.content,
|
||||
usage=response.usage,
|
||||
error=None,
|
||||
)
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Stream the response
|
||||
for chunk in response:
|
||||
if chunk.error:
|
||||
yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
|
||||
break
|
||||
yield chunk
|
||||
|
||||
# ========== END ASYNC METHODS ==========
|
||||
|
||||
def add_to_history(
|
||||
self,
|
||||
prompt: str,
|
||||
response: str,
|
||||
usage: Optional[UsageStats] = None,
|
||||
cost: float = 0.0,
|
||||
) -> None:
|
||||
"""
|
||||
Add an exchange to the history.
|
||||
|
||||
Args:
|
||||
prompt: User prompt
|
||||
response: Assistant response
|
||||
usage: Usage statistics
|
||||
cost: Cost if not in usage
|
||||
"""
|
||||
entry = HistoryEntry(
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
prompt_tokens=usage.prompt_tokens if usage else 0,
|
||||
completion_tokens=usage.completion_tokens if usage else 0,
|
||||
msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost,
|
||||
timestamp=time.time(),
|
||||
)
|
||||
self.history.append(entry)
|
||||
self.current_index = len(self.history) - 1
|
||||
self.stats.add_usage(usage, cost)
|
||||
|
||||
def save_conversation(self, name: str) -> bool:
|
||||
"""
|
||||
Save the current conversation.
|
||||
|
||||
Args:
|
||||
name: Name for the saved conversation
|
||||
|
||||
Returns:
|
||||
True if saved successfully
|
||||
"""
|
||||
if not self.history:
|
||||
return False
|
||||
|
||||
data = [e.to_dict() for e in self.history]
|
||||
self.db.save_conversation(name, data)
|
||||
self.logger.info(f"Saved conversation: {name}")
|
||||
return True
|
||||
|
||||
def load_conversation(self, name: str) -> bool:
|
||||
"""
|
||||
Load a saved conversation.
|
||||
|
||||
Args:
|
||||
name: Name of the conversation to load
|
||||
|
||||
Returns:
|
||||
True if loaded successfully
|
||||
"""
|
||||
data = self.db.load_conversation(name)
|
||||
if not data:
|
||||
return False
|
||||
|
||||
self.history.clear()
|
||||
for entry_dict in data:
|
||||
self.history.append(HistoryEntry(
|
||||
prompt=entry_dict.get("prompt", ""),
|
||||
response=entry_dict.get("response", ""),
|
||||
prompt_tokens=entry_dict.get("prompt_tokens", 0),
|
||||
completion_tokens=entry_dict.get("completion_tokens", 0),
|
||||
msg_cost=entry_dict.get("msg_cost", 0.0),
|
||||
))
|
||||
|
||||
self.current_index = len(self.history) - 1
|
||||
self.memory_start_index = 0
|
||||
self.stats = SessionStats() # Reset stats for loaded conversation
|
||||
self.logger.info(f"Loaded conversation: {name}")
|
||||
return True
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the session state."""
|
||||
self.history.clear()
|
||||
self.stats = SessionStats()
|
||||
self.system_prompt = ""
|
||||
self.memory_start_index = 0
|
||||
self.current_index = 0
|
||||
self.logger.info("Session reset")
|
||||
|
||||
def check_warnings(self) -> List[str]:
|
||||
"""
|
||||
Check for cost and credit warnings.
|
||||
|
||||
Returns:
|
||||
List of warning messages
|
||||
"""
|
||||
warnings = []
|
||||
|
||||
# Check last message cost
|
||||
if self.history:
|
||||
last_cost = self.history[-1].msg_cost
|
||||
threshold = self.settings.cost_warning_threshold
|
||||
if last_cost > threshold:
|
||||
warnings.append(
|
||||
f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}"
|
||||
)
|
||||
|
||||
# Check credits
|
||||
credits = self.client.get_credits()
|
||||
if credits:
|
||||
left = credits.get("credits_left", 0)
|
||||
total = credits.get("total_credits", 0)
|
||||
|
||||
if left < LOW_CREDIT_AMOUNT:
|
||||
warnings.append(f"Low credits: ${left:.2f} remaining!")
|
||||
elif total > 0 and left < total * LOW_CREDIT_RATIO:
|
||||
warnings.append(f"Credits low: less than 10% remaining (${left:.2f})")
|
||||
|
||||
return warnings
|
||||
@@ -1,28 +0,0 @@
|
||||
"""
|
||||
Model Context Protocol (MCP) integration for oAI.
|
||||
|
||||
This package provides filesystem and database access capabilities
|
||||
through the MCP standard, allowing AI models to interact with
|
||||
local files and SQLite databases safely.
|
||||
|
||||
Key components:
|
||||
- MCPManager: High-level manager for MCP operations
|
||||
- MCPFilesystemServer: Filesystem and database access implementation
|
||||
- GitignoreParser: Pattern matching for .gitignore support
|
||||
- SQLiteQueryValidator: Query safety validation
|
||||
- CrossPlatformMCPConfig: OS-specific configuration
|
||||
"""
|
||||
|
||||
from oai.mcp.manager import MCPManager
|
||||
from oai.mcp.server import MCPFilesystemServer
|
||||
from oai.mcp.gitignore import GitignoreParser
|
||||
from oai.mcp.validators import SQLiteQueryValidator
|
||||
from oai.mcp.platform import CrossPlatformMCPConfig
|
||||
|
||||
__all__ = [
|
||||
"MCPManager",
|
||||
"MCPFilesystemServer",
|
||||
"GitignoreParser",
|
||||
"SQLiteQueryValidator",
|
||||
"CrossPlatformMCPConfig",
|
||||
]
|
||||
@@ -1,166 +0,0 @@
|
||||
"""
|
||||
Gitignore pattern parsing for oAI MCP.
|
||||
|
||||
This module implements .gitignore pattern matching to filter files
|
||||
during MCP filesystem operations.
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
|
||||
class GitignoreParser:
|
||||
"""
|
||||
Parse and apply .gitignore patterns.
|
||||
|
||||
Supports standard gitignore syntax including:
|
||||
- Wildcards (*) and double wildcards (**)
|
||||
- Directory-only patterns (ending with /)
|
||||
- Negation patterns (starting with !)
|
||||
- Comments (lines starting with #)
|
||||
|
||||
Patterns are applied relative to the directory containing
|
||||
the .gitignore file.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize an empty pattern collection."""
|
||||
# List of (pattern, is_negation, source_dir)
|
||||
self.patterns: List[Tuple[str, bool, Path]] = []
|
||||
|
||||
def add_gitignore(self, gitignore_path: Path) -> None:
|
||||
"""
|
||||
Parse and add patterns from a .gitignore file.
|
||||
|
||||
Args:
|
||||
gitignore_path: Path to the .gitignore file
|
||||
"""
|
||||
logger = get_logger()
|
||||
|
||||
if not gitignore_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
source_dir = gitignore_path.parent
|
||||
|
||||
with open(gitignore_path, "r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.rstrip("\n\r")
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Check for negation pattern
|
||||
is_negation = line.startswith("!")
|
||||
if is_negation:
|
||||
line = line[1:]
|
||||
|
||||
# Remove leading slash (make relative to gitignore location)
|
||||
if line.startswith("/"):
|
||||
line = line[1:]
|
||||
|
||||
self.patterns.append((line, is_negation, source_dir))
|
||||
|
||||
logger.debug(
|
||||
f"Loaded {len(self.patterns)} patterns from {gitignore_path}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading {gitignore_path}: {e}")
|
||||
|
||||
def should_ignore(self, path: Path) -> bool:
|
||||
"""
|
||||
Check if a path should be ignored based on gitignore patterns.
|
||||
|
||||
Patterns are evaluated in order, with later patterns overriding
|
||||
earlier ones. Negation patterns (starting with !) un-ignore
|
||||
previously matched paths.
|
||||
|
||||
Args:
|
||||
path: Path to check
|
||||
|
||||
Returns:
|
||||
True if the path should be ignored
|
||||
"""
|
||||
if not self.patterns:
|
||||
return False
|
||||
|
||||
ignored = False
|
||||
|
||||
for pattern, is_negation, source_dir in self.patterns:
|
||||
# Only apply pattern if path is under the source directory
|
||||
try:
|
||||
rel_path = path.relative_to(source_dir)
|
||||
except ValueError:
|
||||
# Path is not relative to this gitignore's directory
|
||||
continue
|
||||
|
||||
rel_path_str = str(rel_path)
|
||||
|
||||
# Check if pattern matches
|
||||
if self._match_pattern(pattern, rel_path_str, path.is_dir()):
|
||||
if is_negation:
|
||||
ignored = False # Negation patterns un-ignore
|
||||
else:
|
||||
ignored = True
|
||||
|
||||
return ignored
|
||||
|
||||
def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool:
|
||||
"""
|
||||
Match a gitignore pattern against a path.
|
||||
|
||||
Args:
|
||||
pattern: The gitignore pattern
|
||||
path: The relative path string to match
|
||||
is_dir: Whether the path is a directory
|
||||
|
||||
Returns:
|
||||
True if the pattern matches
|
||||
"""
|
||||
# Directory-only pattern (ends with /)
|
||||
if pattern.endswith("/"):
|
||||
if not is_dir:
|
||||
return False
|
||||
pattern = pattern[:-1]
|
||||
|
||||
# Handle ** patterns (matches any number of directories)
|
||||
if "**" in pattern:
|
||||
pattern_parts = pattern.split("**")
|
||||
if len(pattern_parts) == 2:
|
||||
prefix, suffix = pattern_parts
|
||||
|
||||
# Match if path starts with prefix and ends with suffix
|
||||
if prefix:
|
||||
if not path.startswith(prefix.rstrip("/")):
|
||||
return False
|
||||
if suffix:
|
||||
suffix = suffix.lstrip("/")
|
||||
if not (path.endswith(suffix) or f"/{suffix}" in path):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Direct match using fnmatch
|
||||
if fnmatch.fnmatch(path, pattern):
|
||||
return True
|
||||
|
||||
# Match as subdirectory pattern (pattern without / matches in any directory)
|
||||
if "/" not in pattern:
|
||||
parts = path.split("/")
|
||||
if any(fnmatch.fnmatch(part, pattern) for part in parts):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all loaded patterns."""
|
||||
self.patterns = []
|
||||
|
||||
@property
|
||||
def pattern_count(self) -> int:
|
||||
"""Get the number of loaded patterns."""
|
||||
return len(self.patterns)
|
||||
1365
oai/mcp/manager.py
1365
oai/mcp/manager.py
File diff suppressed because it is too large
Load Diff
@@ -1,228 +0,0 @@
|
||||
"""
|
||||
Cross-platform MCP configuration for oAI.
|
||||
|
||||
This module handles OS-specific configuration, path handling,
|
||||
and security checks for the MCP filesystem server.
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from oai.constants import SYSTEM_DIRS_BLACKLIST
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
|
||||
class CrossPlatformMCPConfig:
|
||||
"""
|
||||
Handle OS-specific MCP configuration.
|
||||
|
||||
Provides methods for path normalization, security validation,
|
||||
and OS-specific default directories.
|
||||
|
||||
Attributes:
|
||||
system: Operating system name
|
||||
is_macos: Whether running on macOS
|
||||
is_linux: Whether running on Linux
|
||||
is_windows: Whether running on Windows
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize platform detection."""
|
||||
self.system = platform.system()
|
||||
self.is_macos = self.system == "Darwin"
|
||||
self.is_linux = self.system == "Linux"
|
||||
self.is_windows = self.system == "Windows"
|
||||
|
||||
logger = get_logger()
|
||||
logger.info(f"Detected OS: {self.system}")
|
||||
|
||||
def get_default_allowed_dirs(self) -> List[Path]:
|
||||
"""
|
||||
Get safe default directories for the current OS.
|
||||
|
||||
Returns:
|
||||
List of default directories that are safe to access
|
||||
"""
|
||||
home = Path.home()
|
||||
|
||||
if self.is_macos:
|
||||
return [
|
||||
home / "Documents",
|
||||
home / "Desktop",
|
||||
home / "Downloads",
|
||||
]
|
||||
|
||||
elif self.is_linux:
|
||||
dirs = [home / "Documents"]
|
||||
|
||||
# Try to get XDG directories
|
||||
try:
|
||||
for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]:
|
||||
result = subprocess.run(
|
||||
["xdg-user-dir", xdg_dir],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=1
|
||||
)
|
||||
if result.returncode == 0:
|
||||
dir_path = Path(result.stdout.strip())
|
||||
if dir_path.exists():
|
||||
dirs.append(dir_path)
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
# Fallback to standard locations
|
||||
dirs.extend([
|
||||
home / "Desktop",
|
||||
home / "Downloads",
|
||||
])
|
||||
|
||||
return list(set(dirs))
|
||||
|
||||
elif self.is_windows:
|
||||
return [
|
||||
home / "Documents",
|
||||
home / "Desktop",
|
||||
home / "Downloads",
|
||||
]
|
||||
|
||||
# Fallback for unknown OS
|
||||
return [home]
|
||||
|
||||
def get_python_command(self) -> str:
|
||||
"""
|
||||
Get the Python executable path.
|
||||
|
||||
Returns:
|
||||
Path to the Python executable
|
||||
"""
|
||||
import sys
|
||||
return sys.executable
|
||||
|
||||
def get_filesystem_warning(self) -> str:
|
||||
"""
|
||||
Get OS-specific security warning message.
|
||||
|
||||
Returns:
|
||||
Warning message for the current OS
|
||||
"""
|
||||
if self.is_macos:
|
||||
return """
|
||||
Note: macOS Security
|
||||
The Filesystem MCP server needs access to your selected folder.
|
||||
You may see a security prompt - click 'Allow' to proceed.
|
||||
(System Settings > Privacy & Security > Files and Folders)
|
||||
"""
|
||||
elif self.is_linux:
|
||||
return """
|
||||
Note: Linux Security
|
||||
The Filesystem MCP server will access your selected folder.
|
||||
Ensure oAI has appropriate file permissions.
|
||||
"""
|
||||
elif self.is_windows:
|
||||
return """
|
||||
Note: Windows Security
|
||||
The Filesystem MCP server will access your selected folder.
|
||||
You may need to grant file access permissions.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def normalize_path(self, path: str) -> Path:
|
||||
"""
|
||||
Normalize a path for the current OS.
|
||||
|
||||
Expands user directory (~) and resolves to absolute path.
|
||||
|
||||
Args:
|
||||
path: Path string to normalize
|
||||
|
||||
Returns:
|
||||
Normalized absolute Path
|
||||
"""
|
||||
return Path(os.path.expanduser(path)).resolve()
|
||||
|
||||
def is_system_directory(self, path: Path) -> bool:
|
||||
"""
|
||||
Check if a path is a protected system directory.
|
||||
|
||||
Args:
|
||||
path: Path to check
|
||||
|
||||
Returns:
|
||||
True if the path is a system directory
|
||||
"""
|
||||
path_str = str(path)
|
||||
for blocked in SYSTEM_DIRS_BLACKLIST:
|
||||
if path_str.startswith(blocked):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool:
|
||||
"""
|
||||
Check if a path is within allowed directories.
|
||||
|
||||
Args:
|
||||
requested_path: Path being requested
|
||||
allowed_dirs: List of allowed parent directories
|
||||
|
||||
Returns:
|
||||
True if the path is within an allowed directory
|
||||
"""
|
||||
try:
|
||||
requested = requested_path.resolve()
|
||||
|
||||
for allowed in allowed_dirs:
|
||||
try:
|
||||
allowed_resolved = allowed.resolve()
|
||||
requested.relative_to(allowed_resolved)
|
||||
return True
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_folder_stats(self, folder: Path) -> Dict[str, Any]:
|
||||
"""
|
||||
Get statistics for a folder.
|
||||
|
||||
Args:
|
||||
folder: Path to the folder
|
||||
|
||||
Returns:
|
||||
Dictionary with folder statistics:
|
||||
- exists: Whether the folder exists
|
||||
- file_count: Number of files (if exists)
|
||||
- total_size: Total size in bytes (if exists)
|
||||
- size_mb: Size in megabytes (if exists)
|
||||
- error: Error message (if any)
|
||||
"""
|
||||
logger = get_logger()
|
||||
|
||||
try:
|
||||
if not folder.exists() or not folder.is_dir():
|
||||
return {"exists": False}
|
||||
|
||||
file_count = 0
|
||||
total_size = 0
|
||||
|
||||
for item in folder.rglob("*"):
|
||||
if item.is_file():
|
||||
file_count += 1
|
||||
try:
|
||||
total_size += item.stat().st_size
|
||||
except (OSError, PermissionError):
|
||||
pass
|
||||
|
||||
return {
|
||||
"exists": True,
|
||||
"file_count": file_count,
|
||||
"total_size": total_size,
|
||||
"size_mb": total_size / (1024 * 1024),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting folder stats for {folder}: {e}")
|
||||
return {"exists": False, "error": str(e)}
|
||||
1368
oai/mcp/server.py
1368
oai/mcp/server.py
File diff suppressed because it is too large
Load Diff
@@ -1,123 +0,0 @@
|
||||
"""
|
||||
Query validation for oAI MCP database operations.
|
||||
|
||||
This module provides safety validation for SQL queries to ensure
|
||||
only read-only operations are executed.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from oai.constants import DANGEROUS_SQL_KEYWORDS
|
||||
|
||||
|
||||
class SQLiteQueryValidator:
|
||||
"""
|
||||
Validate SQLite queries for read-only safety.
|
||||
|
||||
Ensures that only SELECT queries (including CTEs) are allowed
|
||||
and blocks potentially dangerous operations like INSERT, UPDATE,
|
||||
DELETE, DROP, etc.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def is_safe_query(query: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate that a query is a safe read-only SELECT.
|
||||
|
||||
The validation:
|
||||
1. Checks that query starts with SELECT or WITH
|
||||
2. Strips string literals before checking for dangerous keywords
|
||||
3. Blocks any dangerous keywords outside of string literals
|
||||
|
||||
Args:
|
||||
query: SQL query string to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_safe, error_message)
|
||||
- is_safe: True if the query is safe to execute
|
||||
- error_message: Description of why query is unsafe (empty if safe)
|
||||
|
||||
Examples:
|
||||
>>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users")
|
||||
(True, "")
|
||||
>>> SQLiteQueryValidator.is_safe_query("DELETE FROM users")
|
||||
(False, "Only SELECT queries are allowed...")
|
||||
>>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users")
|
||||
(True, "") # 'DELETE' is inside a string literal
|
||||
"""
|
||||
query_upper = query.strip().upper()
|
||||
|
||||
# Must start with SELECT or WITH (for CTEs)
|
||||
if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")):
|
||||
return False, "Only SELECT queries are allowed (including WITH/CTE)"
|
||||
|
||||
# Remove string literals before checking for dangerous keywords
|
||||
# This prevents false positives when keywords appear in data
|
||||
query_no_strings = re.sub(r"'[^']*'", "", query_upper)
|
||||
query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings)
|
||||
|
||||
# Check for dangerous keywords outside of quotes
|
||||
for keyword in DANGEROUS_SQL_KEYWORDS:
|
||||
if re.search(r"\b" + keyword + r"\b", query_no_strings):
|
||||
return False, f"Keyword '{keyword}' not allowed in read-only mode"
|
||||
|
||||
return True, ""
|
||||
|
||||
@staticmethod
|
||||
def sanitize_table_name(table_name: str) -> str:
|
||||
"""
|
||||
Sanitize a table name to prevent SQL injection.
|
||||
|
||||
Only allows alphanumeric characters and underscores.
|
||||
|
||||
Args:
|
||||
table_name: Table name to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized table name
|
||||
|
||||
Raises:
|
||||
ValueError: If table name contains invalid characters
|
||||
"""
|
||||
# Remove any characters that aren't alphanumeric or underscore
|
||||
sanitized = re.sub(r"[^\w]", "", table_name)
|
||||
|
||||
if not sanitized:
|
||||
raise ValueError("Table name cannot be empty after sanitization")
|
||||
|
||||
if sanitized != table_name:
|
||||
raise ValueError(
|
||||
f"Table name contains invalid characters: {table_name}"
|
||||
)
|
||||
|
||||
return sanitized
|
||||
|
||||
@staticmethod
|
||||
def sanitize_column_name(column_name: str) -> str:
|
||||
"""
|
||||
Sanitize a column name to prevent SQL injection.
|
||||
|
||||
Only allows alphanumeric characters and underscores.
|
||||
|
||||
Args:
|
||||
column_name: Column name to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized column name
|
||||
|
||||
Raises:
|
||||
ValueError: If column name contains invalid characters
|
||||
"""
|
||||
# Remove any characters that aren't alphanumeric or underscore
|
||||
sanitized = re.sub(r"[^\w]", "", column_name)
|
||||
|
||||
if not sanitized:
|
||||
raise ValueError("Column name cannot be empty after sanitization")
|
||||
|
||||
if sanitized != column_name:
|
||||
raise ValueError(
|
||||
f"Column name contains invalid characters: {column_name}"
|
||||
)
|
||||
|
||||
return sanitized
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
Provider abstraction for oAI.
|
||||
|
||||
This module provides a unified interface for AI model providers,
|
||||
enabling easy extension to support additional providers beyond OpenRouter.
|
||||
"""
|
||||
|
||||
from oai.providers.base import (
|
||||
AIProvider,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ToolCall,
|
||||
ToolFunction,
|
||||
UsageStats,
|
||||
ModelInfo,
|
||||
ProviderCapabilities,
|
||||
)
|
||||
from oai.providers.openrouter import OpenRouterProvider
|
||||
|
||||
__all__ = [
|
||||
# Base classes and types
|
||||
"AIProvider",
|
||||
"ChatMessage",
|
||||
"ChatResponse",
|
||||
"ToolCall",
|
||||
"ToolFunction",
|
||||
"UsageStats",
|
||||
"ModelInfo",
|
||||
"ProviderCapabilities",
|
||||
# Provider implementations
|
||||
"OpenRouterProvider",
|
||||
]
|
||||
@@ -1,413 +0,0 @@
|
||||
"""
|
||||
Abstract base provider for AI model integration.
|
||||
|
||||
This module defines the interface that all AI providers must implement,
|
||||
along with common data structures for requests and responses.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""Message roles in a conversation."""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolFunction:
|
||||
"""
|
||||
Represents a function within a tool call.
|
||||
|
||||
Attributes:
|
||||
name: The function name
|
||||
arguments: JSON string of function arguments
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""
|
||||
Represents a tool/function call requested by the model.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for this tool call
|
||||
type: Type of tool call (usually "function")
|
||||
function: The function being called
|
||||
"""
|
||||
|
||||
id: str
|
||||
type: str
|
||||
function: ToolFunction
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageStats:
|
||||
"""
|
||||
Token usage statistics from an API response.
|
||||
|
||||
Attributes:
|
||||
prompt_tokens: Number of tokens in the prompt
|
||||
completion_tokens: Number of tokens in the completion
|
||||
total_tokens: Total tokens used
|
||||
total_cost_usd: Cost in USD (if available from API)
|
||||
"""
|
||||
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
total_cost_usd: Optional[float] = None
|
||||
|
||||
@property
|
||||
def input_tokens(self) -> int:
|
||||
"""Alias for prompt_tokens."""
|
||||
return self.prompt_tokens
|
||||
|
||||
@property
|
||||
def output_tokens(self) -> int:
|
||||
"""Alias for completion_tokens."""
|
||||
return self.completion_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""
|
||||
A single message in a chat conversation.
|
||||
|
||||
Attributes:
|
||||
role: The role of the message sender
|
||||
content: Message content (text or structured content blocks)
|
||||
name: Optional name for the sender
|
||||
tool_calls: List of tool calls (for assistant messages)
|
||||
tool_call_id: Tool call ID this message responds to (for tool messages)
|
||||
"""
|
||||
|
||||
role: str
|
||||
content: Union[str, List[Dict[str, Any]], None] = None
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format for API requests."""
|
||||
result: Dict[str, Any] = {"role": self.role}
|
||||
|
||||
if self.content is not None:
|
||||
result["content"] = self.content
|
||||
|
||||
if self.name:
|
||||
result["name"] = self.name
|
||||
|
||||
if self.tool_calls:
|
||||
result["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in self.tool_calls
|
||||
]
|
||||
|
||||
if self.tool_call_id:
|
||||
result["tool_call_id"] = self.tool_call_id
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatResponseChoice:
|
||||
"""
|
||||
A single choice in a chat response.
|
||||
|
||||
Attributes:
|
||||
index: Index of this choice
|
||||
message: The response message
|
||||
finish_reason: Why the response ended
|
||||
"""
|
||||
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatResponse:
|
||||
"""
|
||||
Response from a chat completion request.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for this response
|
||||
choices: List of response choices
|
||||
usage: Token usage statistics
|
||||
model: Model that generated this response
|
||||
created: Unix timestamp of creation
|
||||
"""
|
||||
|
||||
id: str
|
||||
choices: List[ChatResponseChoice]
|
||||
usage: Optional[UsageStats] = None
|
||||
model: Optional[str] = None
|
||||
created: Optional[int] = None
|
||||
|
||||
@property
|
||||
def message(self) -> Optional[ChatMessage]:
|
||||
"""Get the first choice's message."""
|
||||
if self.choices:
|
||||
return self.choices[0].message
|
||||
return None
|
||||
|
||||
@property
|
||||
def content(self) -> Optional[str]:
|
||||
"""Get the text content of the first choice."""
|
||||
msg = self.message
|
||||
if msg and isinstance(msg.content, str):
|
||||
return msg.content
|
||||
return None
|
||||
|
||||
@property
|
||||
def tool_calls(self) -> Optional[List[ToolCall]]:
|
||||
"""Get tool calls from the first choice."""
|
||||
msg = self.message
|
||||
if msg:
|
||||
return msg.tool_calls
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamChunk:
|
||||
"""
|
||||
A single chunk from a streaming response.
|
||||
|
||||
Attributes:
|
||||
id: Response ID
|
||||
delta_content: New content in this chunk
|
||||
finish_reason: Finish reason (if this is the last chunk)
|
||||
usage: Usage stats (usually in the last chunk)
|
||||
error: Error message if something went wrong
|
||||
"""
|
||||
|
||||
id: str
|
||||
delta_content: Optional[str] = None
|
||||
finish_reason: Optional[str] = None
|
||||
usage: Optional[UsageStats] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""
|
||||
Information about an AI model.
|
||||
|
||||
Attributes:
|
||||
id: Unique model identifier
|
||||
name: Display name
|
||||
description: Model description
|
||||
context_length: Maximum context window size
|
||||
pricing: Pricing info (input/output per million tokens)
|
||||
supported_parameters: List of supported API parameters
|
||||
input_modalities: Supported input types (text, image, etc.)
|
||||
output_modalities: Supported output types
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str = ""
|
||||
context_length: int = 0
|
||||
pricing: Dict[str, float] = field(default_factory=dict)
|
||||
supported_parameters: List[str] = field(default_factory=list)
|
||||
input_modalities: List[str] = field(default_factory=lambda: ["text"])
|
||||
output_modalities: List[str] = field(default_factory=lambda: ["text"])
|
||||
|
||||
def supports_images(self) -> bool:
|
||||
"""Check if model supports image input."""
|
||||
return "image" in self.input_modalities
|
||||
|
||||
def supports_tools(self) -> bool:
|
||||
"""Check if model supports function calling/tools."""
|
||||
return "tools" in self.supported_parameters or "functions" in self.supported_parameters
|
||||
|
||||
def supports_streaming(self) -> bool:
|
||||
"""Check if model supports streaming responses."""
|
||||
return "stream" in self.supported_parameters
|
||||
|
||||
def supports_online(self) -> bool:
|
||||
"""Check if model supports web search (online mode)."""
|
||||
return self.supports_tools()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderCapabilities:
|
||||
"""
|
||||
Capabilities supported by a provider.
|
||||
|
||||
Attributes:
|
||||
streaming: Provider supports streaming responses
|
||||
tools: Provider supports function calling
|
||||
images: Provider supports image inputs
|
||||
online: Provider supports web search
|
||||
max_context: Maximum context length across all models
|
||||
"""
|
||||
|
||||
streaming: bool = True
|
||||
tools: bool = True
|
||||
images: bool = True
|
||||
online: bool = False
|
||||
max_context: int = 128000
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
"""
|
||||
Abstract base class for AI model providers.
|
||||
|
||||
All provider implementations must inherit from this class
|
||||
and implement the required abstract methods.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None):
|
||||
"""
|
||||
Initialize the provider.
|
||||
|
||||
Args:
|
||||
api_key: API key for authentication
|
||||
base_url: Optional custom base URL for the API
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Get the provider name."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def capabilities(self) -> ProviderCapabilities:
|
||||
"""Get provider capabilities."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self) -> List[ModelInfo]:
|
||||
"""
|
||||
Fetch available models from the provider.
|
||||
|
||||
Returns:
|
||||
List of available models with their info
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
||||
"""
|
||||
Get information about a specific model.
|
||||
|
||||
Args:
|
||||
model_id: The model identifier
|
||||
|
||||
Returns:
|
||||
Model information or None if not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def chat(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[ChatMessage],
|
||||
stream: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
||||
"""
|
||||
Send a chat completion request.
|
||||
|
||||
Args:
|
||||
model: Model ID to use
|
||||
messages: List of chat messages
|
||||
stream: Whether to stream the response
|
||||
max_tokens: Maximum tokens in response
|
||||
temperature: Sampling temperature
|
||||
tools: List of tool definitions for function calling
|
||||
tool_choice: How to handle tool selection ("auto", "none", etc.)
|
||||
**kwargs: Additional provider-specific parameters
|
||||
|
||||
Returns:
|
||||
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat_async(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[ChatMessage],
|
||||
stream: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
|
||||
"""
|
||||
Send an async chat completion request.
|
||||
|
||||
Args:
|
||||
model: Model ID to use
|
||||
messages: List of chat messages
|
||||
stream: Whether to stream the response
|
||||
max_tokens: Maximum tokens in response
|
||||
temperature: Sampling temperature
|
||||
tools: List of tool definitions for function calling
|
||||
tool_choice: How to handle tool selection
|
||||
**kwargs: Additional provider-specific parameters
|
||||
|
||||
Returns:
|
||||
ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_credits(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get account credit/balance information.
|
||||
|
||||
Returns:
|
||||
Dict with credit info or None if not supported
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_api_key(self) -> bool:
|
||||
"""
|
||||
Validate that the API key is valid.
|
||||
|
||||
Returns:
|
||||
True if API key is valid
|
||||
"""
|
||||
try:
|
||||
self.list_models()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -1,630 +0,0 @@
|
||||
"""
|
||||
OpenRouter provider implementation.
|
||||
|
||||
This module implements the AIProvider interface for OpenRouter,
|
||||
supporting chat completions, streaming, and function calling.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import requests
|
||||
from openrouter import OpenRouter
|
||||
|
||||
from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL
|
||||
from oai.providers.base import (
|
||||
AIProvider,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseChoice,
|
||||
ModelInfo,
|
||||
ProviderCapabilities,
|
||||
StreamChunk,
|
||||
ToolCall,
|
||||
ToolFunction,
|
||||
UsageStats,
|
||||
)
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
|
||||
class OpenRouterProvider(AIProvider):
|
||||
"""
|
||||
OpenRouter API provider implementation.
|
||||
|
||||
Provides access to multiple AI models through OpenRouter's unified API,
|
||||
supporting chat completions, streaming responses, and function calling.
|
||||
|
||||
Attributes:
|
||||
client: The underlying OpenRouter client
|
||||
_models_cache: Cached list of available models
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
app_name: str = APP_NAME,
|
||||
app_url: str = APP_URL,
|
||||
):
|
||||
"""
|
||||
Initialize the OpenRouter provider.
|
||||
|
||||
Args:
|
||||
api_key: OpenRouter API key
|
||||
base_url: Optional custom base URL
|
||||
app_name: Application name for API headers
|
||||
app_url: Application URL for API headers
|
||||
"""
|
||||
super().__init__(api_key, base_url or DEFAULT_BASE_URL)
|
||||
self.app_name = app_name
|
||||
self.app_url = app_url
|
||||
self.client = OpenRouter(api_key=api_key)
|
||||
self._models_cache: Optional[List[ModelInfo]] = None
|
||||
self._raw_models_cache: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
self.logger = get_logger()
|
||||
self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Get the provider name."""
|
||||
return "OpenRouter"
|
||||
|
||||
@property
|
||||
def capabilities(self) -> ProviderCapabilities:
|
||||
"""Get provider capabilities."""
|
||||
return ProviderCapabilities(
|
||||
streaming=True,
|
||||
tools=True,
|
||||
images=True,
|
||||
online=True,
|
||||
max_context=2000000, # Claude models support up to 200k
|
||||
)
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Get standard HTTP headers for API requests."""
|
||||
headers = {
|
||||
"HTTP-Referer": self.app_url,
|
||||
"X-Title": self.app_name,
|
||||
}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
return headers
|
||||
|
||||
def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
|
||||
"""
|
||||
Parse raw model data into ModelInfo.
|
||||
|
||||
Args:
|
||||
model_data: Raw model data from API
|
||||
|
||||
Returns:
|
||||
Parsed ModelInfo object
|
||||
"""
|
||||
architecture = model_data.get("architecture", {})
|
||||
pricing_data = model_data.get("pricing", {})
|
||||
|
||||
# Parse pricing (convert from string to float if needed)
|
||||
pricing = {}
|
||||
for key in ["prompt", "completion"]:
|
||||
value = pricing_data.get(key)
|
||||
if value is not None:
|
||||
try:
|
||||
# Convert from per-token to per-million-tokens
|
||||
pricing[key] = float(value) * 1_000_000
|
||||
except (ValueError, TypeError):
|
||||
pricing[key] = 0.0
|
||||
|
||||
return ModelInfo(
|
||||
id=model_data.get("id", ""),
|
||||
name=model_data.get("name", model_data.get("id", "")),
|
||||
description=model_data.get("description", ""),
|
||||
context_length=model_data.get("context_length", 0),
|
||||
pricing=pricing,
|
||||
supported_parameters=model_data.get("supported_parameters", []),
|
||||
input_modalities=architecture.get("input_modalities", ["text"]),
|
||||
output_modalities=architecture.get("output_modalities", ["text"]),
|
||||
)
|
||||
|
||||
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
||||
"""
|
||||
Fetch available models from OpenRouter.
|
||||
|
||||
Args:
|
||||
filter_text_only: If True, exclude video-only models
|
||||
|
||||
Returns:
|
||||
List of available models
|
||||
|
||||
Raises:
|
||||
Exception: If API request fails
|
||||
"""
|
||||
if self._models_cache is not None:
|
||||
return self._models_cache
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/models",
|
||||
headers=self._get_headers(),
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_models = response.json().get("data", [])
|
||||
self._raw_models_cache = raw_models
|
||||
|
||||
models = []
|
||||
for model_data in raw_models:
|
||||
# Optionally filter out video-only models
|
||||
if filter_text_only:
|
||||
modalities = model_data.get("modalities", [])
|
||||
if modalities and "video" in modalities and "text" not in modalities:
|
||||
continue
|
||||
|
||||
models.append(self._parse_model(model_data))
|
||||
|
||||
self._models_cache = models
|
||||
self.logger.info(f"Fetched {len(models)} models from OpenRouter")
|
||||
return models
|
||||
|
||||
except requests.RequestException as e:
|
||||
self.logger.error(f"Failed to fetch models: {e}")
|
||||
raise
|
||||
|
||||
def get_raw_models(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get raw model data as returned by the API.
|
||||
|
||||
Useful for accessing provider-specific fields not in ModelInfo.
|
||||
|
||||
Returns:
|
||||
List of raw model dictionaries
|
||||
"""
|
||||
if self._raw_models_cache is None:
|
||||
self.list_models()
|
||||
return self._raw_models_cache or []
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
||||
"""
|
||||
Get information about a specific model.
|
||||
|
||||
Args:
|
||||
model_id: The model identifier
|
||||
|
||||
Returns:
|
||||
Model information or None if not found
|
||||
"""
|
||||
models = self.list_models()
|
||||
for model in models:
|
||||
if model.id == model_id:
|
||||
return model
|
||||
return None
|
||||
|
||||
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get raw model data for a specific model.
|
||||
|
||||
Args:
|
||||
model_id: The model identifier
|
||||
|
||||
Returns:
|
||||
Raw model dictionary or None if not found
|
||||
"""
|
||||
raw_models = self.get_raw_models()
|
||||
for model in raw_models:
|
||||
if model.get("id") == model_id:
|
||||
return model
|
||||
return None
|
||||
|
||||
def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert ChatMessage objects to API format.
|
||||
|
||||
Args:
|
||||
messages: List of ChatMessage objects
|
||||
|
||||
Returns:
|
||||
List of message dictionaries for the API
|
||||
"""
|
||||
return [msg.to_dict() for msg in messages]
|
||||
|
||||
def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]:
|
||||
"""
|
||||
Parse usage data from API response.
|
||||
|
||||
Args:
|
||||
usage_data: Raw usage data from API
|
||||
|
||||
Returns:
|
||||
Parsed UsageStats or None
|
||||
"""
|
||||
if not usage_data:
|
||||
return None
|
||||
|
||||
# Handle both attribute and dict access
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
total_cost = None
|
||||
|
||||
if hasattr(usage_data, "prompt_tokens"):
|
||||
prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0
|
||||
elif isinstance(usage_data, dict):
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0) or 0
|
||||
|
||||
if hasattr(usage_data, "completion_tokens"):
|
||||
completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0
|
||||
elif isinstance(usage_data, dict):
|
||||
completion_tokens = usage_data.get("completion_tokens", 0) or 0
|
||||
|
||||
# Try alternative naming (input_tokens/output_tokens)
|
||||
if prompt_tokens == 0:
|
||||
if hasattr(usage_data, "input_tokens"):
|
||||
prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0
|
||||
elif isinstance(usage_data, dict):
|
||||
prompt_tokens = usage_data.get("input_tokens", 0) or 0
|
||||
|
||||
if completion_tokens == 0:
|
||||
if hasattr(usage_data, "output_tokens"):
|
||||
completion_tokens = getattr(usage_data, "output_tokens", 0) or 0
|
||||
elif isinstance(usage_data, dict):
|
||||
completion_tokens = usage_data.get("output_tokens", 0) or 0
|
||||
|
||||
# Get cost if available
|
||||
# OpenRouter returns cost in different places:
|
||||
# 1. As 'total_cost_usd' in usage object (rare)
|
||||
# 2. As 'usage' at root level (common - this is the dollar amount)
|
||||
total_cost = None
|
||||
if hasattr(usage_data, "total_cost_usd"):
|
||||
total_cost = getattr(usage_data, "total_cost_usd", None)
|
||||
elif hasattr(usage_data, "usage"):
|
||||
# OpenRouter puts cost as 'usage' field (dollar amount)
|
||||
total_cost = getattr(usage_data, "usage", None)
|
||||
elif isinstance(usage_data, dict):
|
||||
total_cost = usage_data.get("total_cost_usd") or usage_data.get("usage")
|
||||
|
||||
return UsageStats(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_cost_usd=float(total_cost) if total_cost else None,
|
||||
)
|
||||
|
||||
def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]:
|
||||
"""
|
||||
Parse tool calls from API response.
|
||||
|
||||
Args:
|
||||
tool_calls_data: Raw tool calls data
|
||||
|
||||
Returns:
|
||||
List of ToolCall objects or None
|
||||
"""
|
||||
if not tool_calls_data:
|
||||
return None
|
||||
|
||||
tool_calls = []
|
||||
for tc in tool_calls_data:
|
||||
# Handle both attribute and dict access
|
||||
if hasattr(tc, "id"):
|
||||
tc_id = tc.id
|
||||
tc_type = getattr(tc, "type", "function")
|
||||
func = tc.function
|
||||
func_name = func.name
|
||||
func_args = func.arguments
|
||||
else:
|
||||
tc_id = tc.get("id", "")
|
||||
tc_type = tc.get("type", "function")
|
||||
func = tc.get("function", {})
|
||||
func_name = func.get("name", "")
|
||||
func_args = func.get("arguments", "{}")
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tc_id,
|
||||
type=tc_type,
|
||||
function=ToolFunction(name=func_name, arguments=func_args),
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls if tool_calls else None
|
||||
|
||||
def _parse_response(self, response: Any) -> ChatResponse:
|
||||
"""
|
||||
Parse API response into ChatResponse.
|
||||
|
||||
Args:
|
||||
response: Raw API response
|
||||
|
||||
Returns:
|
||||
Parsed ChatResponse
|
||||
"""
|
||||
choices = []
|
||||
for choice in response.choices:
|
||||
msg = choice.message
|
||||
message = ChatMessage(
|
||||
role=msg.role if hasattr(msg, "role") else "assistant",
|
||||
content=msg.content if hasattr(msg, "content") else None,
|
||||
tool_calls=self._parse_tool_calls(
|
||||
getattr(msg, "tool_calls", None)
|
||||
),
|
||||
)
|
||||
choices.append(
|
||||
ChatResponseChoice(
|
||||
index=choice.index if hasattr(choice, "index") else 0,
|
||||
message=message,
|
||||
finish_reason=getattr(choice, "finish_reason", None),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatResponse(
|
||||
id=response.id if hasattr(response, "id") else "",
|
||||
choices=choices,
|
||||
usage=self._parse_usage(getattr(response, "usage", None)),
|
||||
model=getattr(response, "model", None),
|
||||
created=getattr(response, "created", None),
|
||||
)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[ChatMessage],
|
||||
stream: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
transforms: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
||||
"""
|
||||
Send a chat completion request to OpenRouter.
|
||||
|
||||
Args:
|
||||
model: Model ID to use
|
||||
messages: List of chat messages
|
||||
stream: Whether to stream the response
|
||||
max_tokens: Maximum tokens in response
|
||||
temperature: Sampling temperature (0-2)
|
||||
tools: List of tool definitions for function calling
|
||||
tool_choice: How to handle tool selection ("auto", "none", etc.)
|
||||
transforms: List of transforms (e.g., ["middle-out"])
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
|
||||
"""
|
||||
# Build request parameters
|
||||
params: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": self._convert_messages(messages),
|
||||
"stream": stream,
|
||||
"http_headers": self._get_headers(),
|
||||
}
|
||||
|
||||
# Request usage stats in streaming responses
|
||||
if stream:
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
if max_tokens is not None:
|
||||
params["max_tokens"] = max_tokens
|
||||
|
||||
if temperature is not None:
|
||||
params["temperature"] = temperature
|
||||
|
||||
if tools:
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
if transforms:
|
||||
params["transforms"] = transforms
|
||||
|
||||
# Add any additional parameters
|
||||
params.update(kwargs)
|
||||
|
||||
self.logger.debug(f"Sending chat request to model {model}")
|
||||
|
||||
try:
|
||||
response = self.client.chat.send(**params)
|
||||
|
||||
if stream:
|
||||
return self._stream_response(response)
|
||||
else:
|
||||
return self._parse_response(response)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Chat request failed: {e}")
|
||||
raise
|
||||
|
||||
def _stream_response(self, response: Any) -> Iterator[StreamChunk]:
|
||||
"""
|
||||
Process a streaming response.
|
||||
|
||||
Args:
|
||||
response: Streaming response from API
|
||||
|
||||
Yields:
|
||||
StreamChunk objects
|
||||
"""
|
||||
last_usage = None
|
||||
|
||||
try:
|
||||
for chunk in response:
|
||||
# Check for errors
|
||||
if hasattr(chunk, "error") and chunk.error:
|
||||
yield StreamChunk(
|
||||
id=getattr(chunk, "id", ""),
|
||||
error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error),
|
||||
)
|
||||
return
|
||||
|
||||
# Extract delta content
|
||||
delta_content = None
|
||||
finish_reason = None
|
||||
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta"):
|
||||
delta = choice.delta
|
||||
if hasattr(delta, "content") and delta.content:
|
||||
delta_content = delta.content
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
|
||||
# Track usage from last chunk
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
last_usage = self._parse_usage(chunk.usage)
|
||||
|
||||
yield StreamChunk(
|
||||
id=getattr(chunk, "id", ""),
|
||||
delta_content=delta_content,
|
||||
finish_reason=finish_reason,
|
||||
usage=last_usage if finish_reason else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Stream error: {e}")
|
||||
yield StreamChunk(id="", error=str(e))
|
||||
|
||||
async def chat_async(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[ChatMessage],
|
||||
stream: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
|
||||
"""
|
||||
Send an async chat completion request.
|
||||
|
||||
Note: Currently wraps the sync implementation.
|
||||
TODO: Implement true async support when OpenRouter SDK supports it.
|
||||
|
||||
Args:
|
||||
model: Model ID to use
|
||||
messages: List of chat messages
|
||||
stream: Whether to stream the response
|
||||
max_tokens: Maximum tokens in response
|
||||
temperature: Sampling temperature
|
||||
tools: List of tool definitions
|
||||
tool_choice: Tool selection mode
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
ChatResponse for non-streaming, AsyncIterator for streaming
|
||||
"""
|
||||
# For now, use sync implementation
|
||||
# TODO: Add true async when SDK supports it
|
||||
result = self.chat(
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if stream and isinstance(result, Iterator):
|
||||
# Convert sync iterator to async
|
||||
async def async_iter() -> AsyncIterator[StreamChunk]:
|
||||
for chunk in result:
|
||||
yield chunk
|
||||
|
||||
return async_iter()
|
||||
|
||||
return result
|
||||
|
||||
def get_credits(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get OpenRouter account credit information.
|
||||
|
||||
Returns:
|
||||
Dict with credit info:
|
||||
- total_credits: Total credits purchased
|
||||
- used_credits: Credits used
|
||||
- credits_left: Remaining credits
|
||||
|
||||
Raises:
|
||||
Exception: If API request fails
|
||||
"""
|
||||
if not self.api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/credits",
|
||||
headers=self._get_headers(),
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json().get("data", {})
|
||||
total_credits = float(data.get("total_credits", 0))
|
||||
total_usage = float(data.get("total_usage", 0))
|
||||
credits_left = total_credits - total_usage
|
||||
|
||||
return {
|
||||
"total_credits": total_credits,
|
||||
"used_credits": total_usage,
|
||||
"credits_left": credits_left,
|
||||
"total_credits_formatted": f"${total_credits:.2f}",
|
||||
"used_credits_formatted": f"${total_usage:.2f}",
|
||||
"credits_left_formatted": f"${credits_left:.2f}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to fetch credits: {e}")
|
||||
return None
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the models cache to force a refresh."""
|
||||
self._models_cache = None
|
||||
self._raw_models_cache = None
|
||||
self.logger.debug("Models cache cleared")
|
||||
|
||||
def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str:
|
||||
"""
|
||||
Get the effective model ID with online suffix if needed.
|
||||
|
||||
Args:
|
||||
model_id: Base model ID
|
||||
online_enabled: Whether online mode is enabled
|
||||
|
||||
Returns:
|
||||
Model ID with :online suffix if applicable
|
||||
"""
|
||||
if online_enabled and not model_id.endswith(":online"):
|
||||
return f"{model_id}:online"
|
||||
return model_id
|
||||
|
||||
def estimate_cost(
|
||||
self,
|
||||
model_id: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
) -> float:
|
||||
"""
|
||||
Estimate the cost for a completion.
|
||||
|
||||
Args:
|
||||
model_id: Model ID
|
||||
input_tokens: Number of input tokens
|
||||
output_tokens: Number of output tokens
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD
|
||||
"""
|
||||
model = self.get_model(model_id)
|
||||
if model and model.pricing:
|
||||
input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000
|
||||
output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000
|
||||
return input_cost + output_cost
|
||||
|
||||
# Fallback to default pricing if model not found
|
||||
from oai.constants import MODEL_PRICING
|
||||
|
||||
input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
|
||||
output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
|
||||
return input_cost + output_cost
|
||||
@@ -1,2 +0,0 @@
|
||||
# Marker file for PEP 561
|
||||
# This package supports type checking
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Textual TUI interface for oAI."""
|
||||
|
||||
from oai.tui.app import oAIChatApp
|
||||
|
||||
__all__ = ["oAIChatApp"]
|
||||
1055
oai/tui/app.py
1055
oai/tui/app.py
File diff suppressed because it is too large
Load Diff
@@ -1,21 +0,0 @@
|
||||
"""TUI screens for oAI."""
|
||||
|
||||
from oai.tui.screens.config_screen import ConfigScreen
|
||||
from oai.tui.screens.conversation_selector import ConversationSelectorScreen
|
||||
from oai.tui.screens.credits_screen import CreditsScreen
|
||||
from oai.tui.screens.dialogs import AlertDialog, ConfirmDialog, InputDialog
|
||||
from oai.tui.screens.help_screen import HelpScreen
|
||||
from oai.tui.screens.model_selector import ModelSelectorScreen
|
||||
from oai.tui.screens.stats_screen import StatsScreen
|
||||
|
||||
__all__ = [
|
||||
"AlertDialog",
|
||||
"ConfirmDialog",
|
||||
"ConfigScreen",
|
||||
"ConversationSelectorScreen",
|
||||
"CreditsScreen",
|
||||
"InputDialog",
|
||||
"HelpScreen",
|
||||
"ModelSelectorScreen",
|
||||
"StatsScreen",
|
||||
]
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Configuration screen for oAI TUI."""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Static
|
||||
|
||||
from oai.config.settings import Settings
|
||||
|
||||
|
||||
class ConfigScreen(ModalScreen[None]):
|
||||
"""Modal screen displaying configuration settings."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
ConfigScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ConfigScreen > Container {
|
||||
width: 70;
|
||||
height: auto;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
ConfigScreen .header {
|
||||
dock: top;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 2;
|
||||
}
|
||||
|
||||
ConfigScreen .content {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #1e1e1e;
|
||||
padding: 2;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
ConfigScreen .footer {
|
||||
dock: bottom;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
padding: 1 2;
|
||||
align: center middle;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__()
|
||||
self.settings = settings
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the screen."""
|
||||
with Container():
|
||||
yield Static("[bold]Configuration[/]", classes="header")
|
||||
with Vertical(classes="content"):
|
||||
yield Static(self._get_config_text(), markup=True)
|
||||
with Vertical(classes="footer"):
|
||||
yield Button("Close", id="close", variant="primary")
|
||||
|
||||
def _get_config_text(self) -> str:
|
||||
"""Generate the configuration text."""
|
||||
from oai.constants import DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
# API Key display
|
||||
api_key_display = "***" + self.settings.api_key[-4:] if self.settings.api_key else "Not set"
|
||||
|
||||
# System prompt display
|
||||
if self.settings.default_system_prompt is None:
|
||||
system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
|
||||
elif self.settings.default_system_prompt == "":
|
||||
system_prompt_display = "[blank]"
|
||||
else:
|
||||
prompt = self.settings.default_system_prompt
|
||||
system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt
|
||||
|
||||
return f"""
|
||||
[bold cyan]═══ CONFIGURATION ═══[/]
|
||||
|
||||
[bold]API Key:[/] {api_key_display}
|
||||
[bold]Base URL:[/] {self.settings.base_url}
|
||||
[bold]Default Model:[/] {self.settings.default_model or "Not set"}
|
||||
|
||||
[bold]System Prompt:[/] {system_prompt_display}
|
||||
|
||||
[bold]Streaming:[/] {"on" if self.settings.stream_enabled else "off"}
|
||||
[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}
|
||||
[bold]Max Tokens:[/] {self.settings.max_tokens}
|
||||
[bold]Default Online:[/] {"on" if self.settings.default_online_mode else "off"}
|
||||
[bold]Log Level:[/] {self.settings.log_level}
|
||||
|
||||
[dim]Use /config [setting] [value] to modify settings[/]
|
||||
"""
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
self.dismiss()
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key in ("escape", "enter"):
|
||||
self.dismiss()
|
||||
@@ -1,205 +0,0 @@
|
||||
"""Conversation selector screen for oAI TUI."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, DataTable, Input, Static
|
||||
|
||||
|
||||
class ConversationSelectorScreen(ModalScreen[Optional[dict]]):
|
||||
"""Modal screen for selecting a saved conversation."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
ConversationSelectorScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ConversationSelectorScreen > Container {
|
||||
width: 80%;
|
||||
height: 70%;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
layout: vertical;
|
||||
}
|
||||
|
||||
ConversationSelectorScreen .header {
|
||||
height: 3;
|
||||
width: 100%;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 2;
|
||||
content-align: center middle;
|
||||
}
|
||||
|
||||
ConversationSelectorScreen .search-input {
|
||||
height: 3;
|
||||
width: 100%;
|
||||
background: #2a2a2a;
|
||||
border: solid #555555;
|
||||
margin: 0 0 1 0;
|
||||
}
|
||||
|
||||
ConversationSelectorScreen .search-input:focus {
|
||||
border: solid #888888;
|
||||
}
|
||||
|
||||
ConversationSelectorScreen DataTable {
|
||||
height: 1fr;
|
||||
width: 100%;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
ConversationSelectorScreen .footer {
|
||||
height: 5;
|
||||
width: 100%;
|
||||
background: #2d2d2d;
|
||||
padding: 1 2;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ConversationSelectorScreen Button {
|
||||
margin: 0 1;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, conversations: List[dict]):
|
||||
super().__init__()
|
||||
self.all_conversations = conversations
|
||||
self.filtered_conversations = conversations
|
||||
self.selected_conversation: Optional[dict] = None
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the screen."""
|
||||
with Container():
|
||||
yield Static(
|
||||
f"[bold]Load Conversation[/] [dim]({len(self.all_conversations)} saved)[/]",
|
||||
classes="header"
|
||||
)
|
||||
yield Input(placeholder="Search conversations...", id="search-input", classes="search-input")
|
||||
yield DataTable(id="conv-table", cursor_type="row", show_header=True, zebra_stripes=True)
|
||||
with Vertical(classes="footer"):
|
||||
yield Button("Load", id="load", variant="success")
|
||||
yield Button("Cancel", id="cancel", variant="error")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Initialize the table when mounted."""
|
||||
table = self.query_one("#conv-table", DataTable)
|
||||
|
||||
# Add columns
|
||||
table.add_column("#", width=5)
|
||||
table.add_column("Name", width=40)
|
||||
table.add_column("Messages", width=12)
|
||||
table.add_column("Last Saved", width=20)
|
||||
|
||||
# Populate table
|
||||
self._populate_table()
|
||||
|
||||
# Focus table if list is small (fits on screen), otherwise focus search
|
||||
if len(self.all_conversations) <= 10:
|
||||
table.focus()
|
||||
else:
|
||||
search_input = self.query_one("#search-input", Input)
|
||||
search_input.focus()
|
||||
|
||||
def _populate_table(self) -> None:
|
||||
"""Populate the table with conversations."""
|
||||
table = self.query_one("#conv-table", DataTable)
|
||||
table.clear()
|
||||
|
||||
for idx, conv in enumerate(self.filtered_conversations, 1):
|
||||
name = conv.get("name", "Unknown")
|
||||
message_count = str(conv.get("message_count", 0))
|
||||
last_saved = conv.get("last_saved", "Unknown")
|
||||
|
||||
# Format timestamp if it's a full datetime
|
||||
if "T" in last_saved or len(last_saved) > 20:
|
||||
try:
|
||||
# Truncate to just date and time
|
||||
last_saved = last_saved[:19].replace("T", " ")
|
||||
except:
|
||||
pass
|
||||
|
||||
table.add_row(
|
||||
str(idx),
|
||||
name,
|
||||
message_count,
|
||||
last_saved,
|
||||
key=str(idx)
|
||||
)
|
||||
|
||||
def on_input_changed(self, event: Input.Changed) -> None:
|
||||
"""Filter conversations based on search input."""
|
||||
if event.input.id != "search-input":
|
||||
return
|
||||
|
||||
search_term = event.value.lower()
|
||||
|
||||
if not search_term:
|
||||
self.filtered_conversations = self.all_conversations
|
||||
else:
|
||||
self.filtered_conversations = [
|
||||
c for c in self.all_conversations
|
||||
if search_term in c.get("name", "").lower()
|
||||
]
|
||||
|
||||
self._populate_table()
|
||||
|
||||
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
|
||||
"""Handle row selection (click)."""
|
||||
try:
|
||||
row_index = int(event.row_key.value) - 1
|
||||
if 0 <= row_index < len(self.filtered_conversations):
|
||||
self.selected_conversation = self.filtered_conversations[row_index]
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
def on_data_table_row_highlighted(self, event) -> None:
|
||||
"""Handle row highlight (arrow key navigation)."""
|
||||
try:
|
||||
table = self.query_one("#conv-table", DataTable)
|
||||
if table.cursor_row is not None:
|
||||
row_data = table.get_row_at(table.cursor_row)
|
||||
if row_data:
|
||||
row_index = int(row_data[0]) - 1
|
||||
if 0 <= row_index < len(self.filtered_conversations):
|
||||
self.selected_conversation = self.filtered_conversations[row_index]
|
||||
except (ValueError, IndexError, AttributeError):
|
||||
pass
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
if event.button.id == "load":
|
||||
if self.selected_conversation:
|
||||
self.dismiss(self.selected_conversation)
|
||||
else:
|
||||
self.dismiss(None)
|
||||
else:
|
||||
self.dismiss(None)
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key == "escape":
|
||||
self.dismiss(None)
|
||||
elif event.key == "enter":
|
||||
# If in search input, move to table
|
||||
search_input = self.query_one("#search-input", Input)
|
||||
if search_input.has_focus:
|
||||
table = self.query_one("#conv-table", DataTable)
|
||||
table.focus()
|
||||
# If in table, select current row
|
||||
else:
|
||||
table = self.query_one("#conv-table", DataTable)
|
||||
if table.cursor_row is not None:
|
||||
try:
|
||||
row_data = table.get_row_at(table.cursor_row)
|
||||
if row_data:
|
||||
row_index = int(row_data[0]) - 1
|
||||
if 0 <= row_index < len(self.filtered_conversations):
|
||||
selected = self.filtered_conversations[row_index]
|
||||
self.dismiss(selected)
|
||||
except (ValueError, IndexError, AttributeError):
|
||||
if self.selected_conversation:
|
||||
self.dismiss(self.selected_conversation)
|
||||
@@ -1,125 +0,0 @@
|
||||
"""Credits screen for oAI TUI."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Static
|
||||
|
||||
from oai.core.client import AIClient
|
||||
|
||||
|
||||
class CreditsScreen(ModalScreen[None]):
|
||||
"""Modal screen displaying account credits."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
CreditsScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
CreditsScreen > Container {
|
||||
width: 60;
|
||||
height: auto;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
CreditsScreen .header {
|
||||
dock: top;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 2;
|
||||
}
|
||||
|
||||
CreditsScreen .content {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #1e1e1e;
|
||||
padding: 2;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
CreditsScreen .footer {
|
||||
dock: bottom;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
padding: 1 2;
|
||||
align: center middle;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, client: AIClient):
|
||||
super().__init__()
|
||||
self.client = client
|
||||
self.credits_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the screen."""
|
||||
with Container():
|
||||
yield Static("[bold]Account Credits[/]", classes="header")
|
||||
with Vertical(classes="content"):
|
||||
yield Static("[dim]Loading...[/]", id="credits-content", markup=True)
|
||||
with Vertical(classes="footer"):
|
||||
yield Button("Close", id="close", variant="primary")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Fetch credits when mounted."""
|
||||
self.fetch_credits()
|
||||
|
||||
def fetch_credits(self) -> None:
|
||||
"""Fetch and display credits information."""
|
||||
try:
|
||||
self.credits_data = self.client.provider.get_credits()
|
||||
content = self.query_one("#credits-content", Static)
|
||||
content.update(self._get_credits_text())
|
||||
except Exception as e:
|
||||
content = self.query_one("#credits-content", Static)
|
||||
content.update(f"[red]Error fetching credits:[/]\n{str(e)}")
|
||||
|
||||
def _get_credits_text(self) -> str:
|
||||
"""Generate the credits text."""
|
||||
if not self.credits_data:
|
||||
return "[yellow]No credit information available[/]"
|
||||
|
||||
total = self.credits_data.get("total_credits", 0)
|
||||
used = self.credits_data.get("used_credits", 0)
|
||||
remaining = self.credits_data.get("credits_left", 0)
|
||||
|
||||
# Calculate percentage used
|
||||
if total > 0:
|
||||
percent_used = (used / total) * 100
|
||||
percent_remaining = (remaining / total) * 100
|
||||
else:
|
||||
percent_used = 0
|
||||
percent_remaining = 0
|
||||
|
||||
# Color code based on remaining credits
|
||||
if percent_remaining > 50:
|
||||
remaining_color = "green"
|
||||
elif percent_remaining > 20:
|
||||
remaining_color = "yellow"
|
||||
else:
|
||||
remaining_color = "red"
|
||||
|
||||
return f"""
|
||||
[bold cyan]═══ OPENROUTER CREDITS ═══[/]
|
||||
|
||||
[bold]Total Credits:[/] ${total:.2f}
|
||||
[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]
|
||||
[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]
|
||||
|
||||
[dim]Visit openrouter.ai to add more credits[/]
|
||||
"""
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
self.dismiss()
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key in ("escape", "enter"):
|
||||
self.dismiss()
|
||||
@@ -1,236 +0,0 @@
|
||||
"""Modal dialog screens for oAI TUI."""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Horizontal, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Input, Label, Static
|
||||
|
||||
|
||||
class ConfirmDialog(ModalScreen[bool]):
|
||||
"""A confirmation dialog with Yes/No buttons."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
ConfirmDialog {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ConfirmDialog > Container {
|
||||
width: 60;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
border: solid #555555;
|
||||
padding: 2;
|
||||
}
|
||||
|
||||
ConfirmDialog Label {
|
||||
width: 100%;
|
||||
content-align: center middle;
|
||||
margin-bottom: 2;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
ConfirmDialog Horizontal {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ConfirmDialog Button {
|
||||
margin: 0 1;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
title: str = "Confirm",
|
||||
yes_label: str = "Yes",
|
||||
no_label: str = "No",
|
||||
):
|
||||
super().__init__()
|
||||
self.message = message
|
||||
self.title = title
|
||||
self.yes_label = yes_label
|
||||
self.no_label = no_label
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the dialog."""
|
||||
with Container():
|
||||
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
|
||||
yield Label(self.message)
|
||||
with Horizontal():
|
||||
yield Button(self.yes_label, id="yes", variant="success")
|
||||
yield Button(self.no_label, id="no", variant="error")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
if event.button.id == "yes":
|
||||
self.dismiss(True)
|
||||
else:
|
||||
self.dismiss(False)
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key == "escape":
|
||||
self.dismiss(False)
|
||||
elif event.key == "enter":
|
||||
self.dismiss(True)
|
||||
|
||||
|
||||
class InputDialog(ModalScreen[Optional[str]]):
|
||||
"""An input dialog for text entry."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
InputDialog {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
InputDialog > Container {
|
||||
width: 70;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
border: solid #555555;
|
||||
padding: 2;
|
||||
}
|
||||
|
||||
InputDialog Label {
|
||||
width: 100%;
|
||||
margin-bottom: 1;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
InputDialog Input {
|
||||
width: 100%;
|
||||
margin-bottom: 2;
|
||||
background: #3a3a3a;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
InputDialog Input:focus {
|
||||
border: solid #888888;
|
||||
}
|
||||
|
||||
InputDialog Horizontal {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
InputDialog Button {
|
||||
margin: 0 1;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
title: str = "Input",
|
||||
default: str = "",
|
||||
placeholder: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.message = message
|
||||
self.title = title
|
||||
self.default = default
|
||||
self.placeholder = placeholder
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the dialog."""
|
||||
with Container():
|
||||
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
|
||||
yield Label(self.message)
|
||||
yield Input(
|
||||
value=self.default,
|
||||
placeholder=self.placeholder,
|
||||
id="input-field"
|
||||
)
|
||||
with Horizontal():
|
||||
yield Button("OK", id="ok", variant="primary")
|
||||
yield Button("Cancel", id="cancel")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Focus the input field when mounted."""
|
||||
input_field = self.query_one("#input-field", Input)
|
||||
input_field.focus()
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
if event.button.id == "ok":
|
||||
input_field = self.query_one("#input-field", Input)
|
||||
self.dismiss(input_field.value)
|
||||
else:
|
||||
self.dismiss(None)
|
||||
|
||||
def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||
"""Handle Enter key in input field."""
|
||||
self.dismiss(event.value)
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key == "escape":
|
||||
self.dismiss(None)
|
||||
|
||||
|
||||
class AlertDialog(ModalScreen[None]):
|
||||
"""A simple alert/message dialog."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
AlertDialog {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
AlertDialog > Container {
|
||||
width: 60;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
border: solid #555555;
|
||||
padding: 2;
|
||||
}
|
||||
|
||||
AlertDialog Label {
|
||||
width: 100%;
|
||||
content-align: center middle;
|
||||
margin-bottom: 2;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
AlertDialog Horizontal {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
align: center middle;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, title: str = "Alert", variant: str = "default"):
|
||||
super().__init__()
|
||||
self.message = message
|
||||
self.title = title
|
||||
self.variant = variant
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the dialog."""
|
||||
# Choose color based on variant (using design system)
|
||||
color = "$primary"
|
||||
if self.variant == "error":
|
||||
color = "$error"
|
||||
elif self.variant == "success":
|
||||
color = "$success"
|
||||
elif self.variant == "warning":
|
||||
color = "$warning"
|
||||
|
||||
with Container():
|
||||
yield Static(f"[bold {color}]{self.title}[/]", classes="dialog-title")
|
||||
yield Label(self.message)
|
||||
with Horizontal():
|
||||
yield Button("OK", id="ok", variant="primary")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
self.dismiss()
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key in ("escape", "enter"):
|
||||
self.dismiss()
|
||||
@@ -1,140 +0,0 @@
|
||||
"""Help screen for oAI TUI."""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Static
|
||||
|
||||
|
||||
class HelpScreen(ModalScreen[None]):
|
||||
"""Modal screen displaying help and commands."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
HelpScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
HelpScreen > Container {
|
||||
width: 90%;
|
||||
height: 85%;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
HelpScreen .header {
|
||||
dock: top;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 2;
|
||||
}
|
||||
|
||||
HelpScreen .content {
|
||||
height: 1fr;
|
||||
background: #1e1e1e;
|
||||
padding: 2;
|
||||
overflow-y: auto;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
HelpScreen .footer {
|
||||
dock: bottom;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
padding: 1 2;
|
||||
align: center middle;
|
||||
}
|
||||
"""
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the screen."""
|
||||
with Container():
|
||||
yield Static("[bold]oAI Help & Commands[/]", classes="header")
|
||||
with Vertical(classes="content"):
|
||||
yield Static(self._get_help_text(), markup=True)
|
||||
with Vertical(classes="footer"):
|
||||
yield Button("Close", id="close", variant="primary")
|
||||
|
||||
def _get_help_text(self) -> str:
|
||||
"""Generate the help text."""
|
||||
return """
|
||||
[bold cyan]═══ KEYBOARD SHORTCUTS ═══[/]
|
||||
[bold]F1[/] Show this help (Ctrl+H may not work)
|
||||
[bold]F2[/] Open model selector (Ctrl+M may not work)
|
||||
[bold]F3[/] Copy last AI response to clipboard
|
||||
[bold]Ctrl+S[/] Show session statistics
|
||||
[bold]Ctrl+L[/] Clear chat display
|
||||
[bold]Ctrl+P[/] Show previous message
|
||||
[bold]Ctrl+N[/] Show next message
|
||||
[bold]Ctrl+Y[/] Copy last AI response (alternative to F3)
|
||||
[bold]Ctrl+Q[/] Quit application
|
||||
[bold]Up/Down[/] Navigate input history
|
||||
[bold]ESC[/] Close dialogs
|
||||
[dim]Note: Some Ctrl keys may be captured by your terminal[/]
|
||||
|
||||
[bold cyan]═══ SLASH COMMANDS ═══[/]
|
||||
[bold yellow]Session Control:[/]
|
||||
/reset Clear conversation history (with confirmation)
|
||||
/clear Clear the chat display
|
||||
/memory on/off Toggle conversation memory
|
||||
/online on/off Toggle online search mode
|
||||
/exit, /quit, /bye Exit the application
|
||||
|
||||
[bold yellow]Model & Configuration:[/]
|
||||
/model [search] Open model selector with optional search
|
||||
/config View configuration settings
|
||||
/config api Set API key (prompts for input)
|
||||
/config stream on Enable streaming responses
|
||||
/system [prompt] Set session system prompt
|
||||
/maxtoken [n] Set session token limit
|
||||
|
||||
[bold yellow]Conversation Management:[/]
|
||||
/save [name] Save current conversation
|
||||
/load [name] Load saved conversation (shows picker if no name)
|
||||
/list List all saved conversations
|
||||
/delete <name> Delete a saved conversation
|
||||
|
||||
[bold yellow]Export:[/]
|
||||
/export md [file] Export as Markdown
|
||||
/export json [file] Export as JSON
|
||||
/export html [file] Export as HTML
|
||||
|
||||
[bold yellow]History Navigation:[/]
|
||||
/prev Show previous message in history
|
||||
/next Show next message in history
|
||||
|
||||
[bold yellow]MCP (Model Context Protocol):[/]
|
||||
/mcp on Enable MCP file access
|
||||
/mcp off Disable MCP
|
||||
/mcp status Show MCP status
|
||||
/mcp add <path> Add folder for file access
|
||||
/mcp list List registered folders
|
||||
/mcp write Toggle write permissions
|
||||
|
||||
[bold yellow]Information & Utilities:[/]
|
||||
/help Show this help screen
|
||||
/stats Show session statistics
|
||||
/credits Check account credits
|
||||
/retry Retry last prompt
|
||||
/paste Paste from clipboard and send
|
||||
|
||||
[bold cyan]═══ TIPS ═══[/]
|
||||
• Type [bold]/[/] to see command suggestions with [bold]Tab[/] to autocomplete
|
||||
• Use [bold]Up/Down arrows[/] to navigate your input history
|
||||
• Type [bold]//[/] at start to escape commands (sends /help as literal message)
|
||||
• All messages support [bold]Markdown formatting[/] with syntax highlighting
|
||||
• Responses stream in real-time for better interactivity
|
||||
• Enable MCP to let AI access your local files and databases
|
||||
• Use [bold]F1[/] or [bold]F2[/] if Ctrl shortcuts don't work in your terminal
|
||||
"""
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
self.dismiss()
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key in ("escape", "enter"):
|
||||
self.dismiss()
|
||||
@@ -1,254 +0,0 @@
|
||||
"""Model selector screen for oAI TUI."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, DataTable, Input, Label, Static
|
||||
|
||||
|
||||
class ModelSelectorScreen(ModalScreen[Optional[dict]]):
|
||||
"""Modal screen for selecting an AI model."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
ModelSelectorScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ModelSelectorScreen > Container {
|
||||
width: 90%;
|
||||
height: 85%;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
layout: vertical;
|
||||
}
|
||||
|
||||
ModelSelectorScreen .header {
|
||||
height: 3;
|
||||
width: 100%;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 2;
|
||||
content-align: center middle;
|
||||
}
|
||||
|
||||
ModelSelectorScreen .search-input {
|
||||
height: 3;
|
||||
width: 100%;
|
||||
background: #2a2a2a;
|
||||
border: solid #555555;
|
||||
margin: 0 0 1 0;
|
||||
}
|
||||
|
||||
ModelSelectorScreen .search-input:focus {
|
||||
border: solid #888888;
|
||||
}
|
||||
|
||||
ModelSelectorScreen DataTable {
|
||||
height: 1fr;
|
||||
width: 100%;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
ModelSelectorScreen .footer {
|
||||
height: 5;
|
||||
width: 100%;
|
||||
background: #2d2d2d;
|
||||
padding: 1 2;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ModelSelectorScreen Button {
|
||||
margin: 0 1;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, models: List[dict], current_model: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.all_models = models
|
||||
self.filtered_models = models
|
||||
self.current_model = current_model
|
||||
self.selected_model: Optional[dict] = None
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the screen."""
|
||||
with Container():
|
||||
yield Static(
|
||||
f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
|
||||
classes="header"
|
||||
)
|
||||
yield Input(placeholder="Search to filter models...", id="search-input", classes="search-input")
|
||||
yield DataTable(id="model-table", cursor_type="row", show_header=True, zebra_stripes=True)
|
||||
with Vertical(classes="footer"):
|
||||
yield Button("Select", id="select", variant="success")
|
||||
yield Button("Cancel", id="cancel", variant="error")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Initialize the table when mounted."""
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
|
||||
# Add columns
|
||||
table.add_column("#", width=5)
|
||||
table.add_column("Model ID", width=35)
|
||||
table.add_column("Name", width=30)
|
||||
table.add_column("Context", width=10)
|
||||
table.add_column("Price", width=12)
|
||||
table.add_column("Img", width=4)
|
||||
table.add_column("Tools", width=6)
|
||||
table.add_column("Online", width=7)
|
||||
|
||||
# Populate table
|
||||
self._populate_table()
|
||||
|
||||
# Focus table if list is small (fits on screen), otherwise focus search
|
||||
if len(self.filtered_models) <= 20:
|
||||
table.focus()
|
||||
else:
|
||||
search_input = self.query_one("#search-input", Input)
|
||||
search_input.focus()
|
||||
|
||||
def _populate_table(self) -> None:
|
||||
"""Populate the table with models."""
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
table.clear()
|
||||
|
||||
rows_added = 0
|
||||
for idx, model in enumerate(self.filtered_models, 1):
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
name = model.get("name", "")
|
||||
context = str(model.get("context_length", "N/A"))
|
||||
|
||||
# Format pricing
|
||||
pricing = model.get("pricing", {})
|
||||
prompt_price = pricing.get("prompt", "0")
|
||||
completion_price = pricing.get("completion", "0")
|
||||
|
||||
# Convert to numbers and format
|
||||
try:
|
||||
prompt = float(prompt_price) * 1000000 # Convert to per 1M tokens
|
||||
completion = float(completion_price) * 1000000
|
||||
if prompt == 0 and completion == 0:
|
||||
price = "Free"
|
||||
else:
|
||||
price = f"${prompt:.2f}/${completion:.2f}"
|
||||
except:
|
||||
price = "N/A"
|
||||
|
||||
# Check capabilities
|
||||
architecture = model.get("architecture", {})
|
||||
modality = architecture.get("modality", "")
|
||||
supported_params = model.get("supported_parameters", [])
|
||||
|
||||
# Vision support: check if modality contains "image"
|
||||
supports_vision = "image" in modality
|
||||
|
||||
# Tool support: check if "tools" or "tool_choice" in supported_parameters
|
||||
supports_tools = "tools" in supported_params or "tool_choice" in supported_params
|
||||
|
||||
# Online support: check if model can use :online suffix (most models can)
|
||||
# Models that already have :online in their ID support it
|
||||
supports_online = ":online" in model_id or model_id not in ["openrouter/free"]
|
||||
|
||||
# Format capability indicators
|
||||
img_indicator = "✓" if supports_vision else "-"
|
||||
tools_indicator = "✓" if supports_tools else "-"
|
||||
web_indicator = "✓" if supports_online else "-"
|
||||
|
||||
# Add row
|
||||
table.add_row(
|
||||
str(idx),
|
||||
model_id,
|
||||
name,
|
||||
context,
|
||||
price,
|
||||
img_indicator,
|
||||
tools_indicator,
|
||||
web_indicator,
|
||||
key=str(idx)
|
||||
)
|
||||
rows_added += 1
|
||||
|
||||
except Exception:
|
||||
# Silently skip rows that fail to add
|
||||
pass
|
||||
|
||||
def on_input_changed(self, event: Input.Changed) -> None:
|
||||
"""Filter models based on search input."""
|
||||
if event.input.id != "search-input":
|
||||
return
|
||||
|
||||
search_term = event.value.lower()
|
||||
|
||||
if not search_term:
|
||||
self.filtered_models = self.all_models
|
||||
else:
|
||||
self.filtered_models = [
|
||||
m for m in self.all_models
|
||||
if search_term in m.get("id", "").lower()
|
||||
or search_term in m.get("name", "").lower()
|
||||
]
|
||||
|
||||
self._populate_table()
|
||||
|
||||
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
|
||||
"""Handle row selection (click or arrow navigation)."""
|
||||
try:
|
||||
row_index = int(event.row_key.value) - 1
|
||||
if 0 <= row_index < len(self.filtered_models):
|
||||
self.selected_model = self.filtered_models[row_index]
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
def on_data_table_row_highlighted(self, event) -> None:
|
||||
"""Handle row highlight (arrow key navigation)."""
|
||||
try:
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
if table.cursor_row is not None:
|
||||
row_data = table.get_row_at(table.cursor_row)
|
||||
if row_data:
|
||||
row_index = int(row_data[0]) - 1
|
||||
if 0 <= row_index < len(self.filtered_models):
|
||||
self.selected_model = self.filtered_models[row_index]
|
||||
except (ValueError, IndexError, AttributeError):
|
||||
pass
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
if event.button.id == "select":
|
||||
if self.selected_model:
|
||||
self.dismiss(self.selected_model)
|
||||
else:
|
||||
# No selection, dismiss without result
|
||||
self.dismiss(None)
|
||||
else:
|
||||
self.dismiss(None)
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key == "escape":
|
||||
self.dismiss(None)
|
||||
elif event.key == "enter":
|
||||
# If in search input, move to table
|
||||
search_input = self.query_one("#search-input", Input)
|
||||
if search_input.has_focus:
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
table.focus()
|
||||
# If in table or anywhere else, select current row
|
||||
else:
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
# Get the currently highlighted row
|
||||
if table.cursor_row is not None:
|
||||
try:
|
||||
row_key = table.get_row_at(table.cursor_row)
|
||||
if row_key:
|
||||
row_index = int(row_key[0]) - 1
|
||||
if 0 <= row_index < len(self.filtered_models):
|
||||
selected = self.filtered_models[row_index]
|
||||
self.dismiss(selected)
|
||||
except (ValueError, IndexError, AttributeError):
|
||||
# Fall back to previously selected model
|
||||
if self.selected_model:
|
||||
self.dismiss(self.selected_model)
|
||||
@@ -1,129 +0,0 @@
|
||||
"""Statistics screen for oAI TUI."""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Static
|
||||
|
||||
from oai.core.session import ChatSession
|
||||
|
||||
|
||||
class StatsScreen(ModalScreen[None]):
|
||||
"""Modal screen displaying session statistics."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
StatsScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
StatsScreen > Container {
|
||||
width: 70;
|
||||
height: auto;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
StatsScreen .header {
|
||||
dock: top;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 2;
|
||||
}
|
||||
|
||||
StatsScreen .content {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #1e1e1e;
|
||||
padding: 2;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
StatsScreen .footer {
|
||||
dock: bottom;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
padding: 1 2;
|
||||
align: center middle;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, session: ChatSession):
|
||||
super().__init__()
|
||||
self.session = session
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the screen."""
|
||||
with Container():
|
||||
yield Static("[bold]Session Statistics[/]", classes="header")
|
||||
with Vertical(classes="content"):
|
||||
yield Static(self._get_stats_text(), markup=True)
|
||||
with Vertical(classes="footer"):
|
||||
yield Button("Close", id="close", variant="primary")
|
||||
|
||||
def _get_stats_text(self) -> str:
|
||||
"""Generate the statistics text."""
|
||||
stats = self.session.stats
|
||||
|
||||
# Calculate averages
|
||||
avg_input = stats.total_input_tokens // stats.message_count if stats.message_count > 0 else 0
|
||||
avg_output = stats.total_output_tokens // stats.message_count if stats.message_count > 0 else 0
|
||||
avg_cost = stats.total_cost / stats.message_count if stats.message_count > 0 else 0
|
||||
|
||||
# Get model info
|
||||
model_name = "None"
|
||||
model_context = "N/A"
|
||||
if self.session.selected_model:
|
||||
model_name = self.session.selected_model.get("name", "Unknown")
|
||||
model_context = str(self.session.selected_model.get("context_length", "N/A"))
|
||||
|
||||
# MCP status
|
||||
mcp_status = "Disabled"
|
||||
if self.session.mcp_manager and self.session.mcp_manager.enabled:
|
||||
mode = self.session.mcp_manager.mode
|
||||
if mode == "files":
|
||||
write = " (Write)" if self.session.mcp_manager.write_enabled else ""
|
||||
mcp_status = f"Enabled - Files{write}"
|
||||
elif mode == "database":
|
||||
db_idx = self.session.mcp_manager.selected_db_index
|
||||
if db_idx is not None:
|
||||
db_name = self.session.mcp_manager.databases[db_idx]["name"]
|
||||
mcp_status = f"Enabled - Database ({db_name})"
|
||||
|
||||
return f"""
|
||||
[bold cyan]═══ SESSION INFO ═══[/]
|
||||
[bold]Messages:[/] {stats.message_count}
|
||||
[bold]Current Model:[/] {model_name}
|
||||
[bold]Context Length:[/] {model_context}
|
||||
[bold]Memory:[/] {"Enabled" if self.session.memory_enabled else "Disabled"}
|
||||
[bold]Online Mode:[/] {"Enabled" if self.session.online_enabled else "Disabled"}
|
||||
[bold]MCP:[/] {mcp_status}
|
||||
|
||||
[bold cyan]═══ TOKEN USAGE ═══[/]
|
||||
[bold]Input Tokens:[/] {stats.total_input_tokens:,}
|
||||
[bold]Output Tokens:[/] {stats.total_output_tokens:,}
|
||||
[bold]Total Tokens:[/] {stats.total_tokens:,}
|
||||
|
||||
[bold]Avg Input/Msg:[/] {avg_input:,}
|
||||
[bold]Avg Output/Msg:[/] {avg_output:,}
|
||||
|
||||
[bold cyan]═══ COSTS ═══[/]
|
||||
[bold]Total Cost:[/] ${stats.total_cost:.6f}
|
||||
[bold]Avg Cost/Msg:[/] ${avg_cost:.6f}
|
||||
|
||||
[bold cyan]═══ HISTORY ═══[/]
|
||||
[bold]History Size:[/] {len(self.session.history)} entries
|
||||
[bold]Current Index:[/] {self.session.current_index + 1 if self.session.history else 0}
|
||||
[bold]Memory Start:[/] {self.session.memory_start_index + 1}
|
||||
"""
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
self.dismiss()
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key in ("escape", "enter"):
|
||||
self.dismiss()
|
||||
@@ -1,169 +0,0 @@
|
||||
/* Textual CSS for oAI TUI - Using Textual Design System */
|
||||
|
||||
Screen {
|
||||
background: $background;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
Header {
|
||||
dock: top;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 1;
|
||||
border-bottom: solid #555555;
|
||||
}
|
||||
|
||||
ChatDisplay {
|
||||
background: $background;
|
||||
border: none;
|
||||
padding: 1;
|
||||
scrollbar-background: $background;
|
||||
scrollbar-color: $primary;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
UserMessageWidget {
|
||||
margin: 0 0 1 0;
|
||||
padding: 1;
|
||||
background: $surface;
|
||||
border-left: thick $success;
|
||||
height: auto;
|
||||
}
|
||||
|
||||
SystemMessageWidget {
|
||||
margin: 0 0 1 0;
|
||||
padding: 1;
|
||||
background: #2a2a2a;
|
||||
border-left: thick #888888;
|
||||
height: auto;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
AssistantMessageWidget {
|
||||
margin: 0 0 1 0;
|
||||
padding: 1;
|
||||
background: $panel;
|
||||
border-left: thick $accent;
|
||||
height: auto;
|
||||
}
|
||||
|
||||
#assistant-label {
|
||||
margin-bottom: 1;
|
||||
color: #cccccc;
|
||||
}
|
||||
|
||||
#assistant-content {
|
||||
height: auto;
|
||||
max-height: 100%;
|
||||
color: #cccccc;
|
||||
link-color: #888888;
|
||||
link-style: none;
|
||||
}
|
||||
|
||||
InputBar {
|
||||
dock: bottom;
|
||||
height: auto;
|
||||
background: #2d2d2d;
|
||||
align: center middle;
|
||||
border-top: solid #555555;
|
||||
padding: 1;
|
||||
}
|
||||
|
||||
#input-prefix {
|
||||
width: auto;
|
||||
padding: 0 1;
|
||||
content-align: center middle;
|
||||
color: #888888;
|
||||
}
|
||||
|
||||
#input-prefix.prefix-hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
#chat-input {
|
||||
width: 85%;
|
||||
height: 5;
|
||||
min-height: 5;
|
||||
background: #3a3a3a;
|
||||
border: none;
|
||||
padding: 1 2;
|
||||
color: #ffffff;
|
||||
content-align: left top;
|
||||
}
|
||||
|
||||
#chat-input:focus {
|
||||
background: #404040;
|
||||
}
|
||||
|
||||
#command-dropdown {
|
||||
display: none;
|
||||
dock: bottom;
|
||||
offset-y: -5;
|
||||
offset-x: 7.5%;
|
||||
height: auto;
|
||||
max-height: 12;
|
||||
width: 85%;
|
||||
background: #2d2d2d;
|
||||
border: solid #555555;
|
||||
padding: 0;
|
||||
layer: overlay;
|
||||
}
|
||||
|
||||
#command-dropdown.visible {
|
||||
display: block;
|
||||
}
|
||||
|
||||
#command-dropdown #command-list {
|
||||
background: #2d2d2d;
|
||||
border: none;
|
||||
scrollbar-background: #2d2d2d;
|
||||
scrollbar-color: #555555;
|
||||
}
|
||||
|
||||
Footer {
|
||||
dock: bottom;
|
||||
height: auto;
|
||||
background: #252525;
|
||||
color: #888888;
|
||||
padding: 0 1;
|
||||
}
|
||||
|
||||
/* Button styles */
|
||||
Button {
|
||||
height: 3;
|
||||
min-width: 10;
|
||||
background: #3a3a3a;
|
||||
color: #cccccc;
|
||||
border: none;
|
||||
}
|
||||
|
||||
Button:hover {
|
||||
background: #4a4a4a;
|
||||
}
|
||||
|
||||
Button:focus {
|
||||
background: #505050;
|
||||
}
|
||||
|
||||
Button.-primary {
|
||||
background: #3a3a3a;
|
||||
}
|
||||
|
||||
Button.-success {
|
||||
background: #2d5016;
|
||||
color: #90ee90;
|
||||
}
|
||||
|
||||
Button.-success:hover {
|
||||
background: #3a6b1e;
|
||||
}
|
||||
|
||||
Button.-error {
|
||||
background: #5a1a1a;
|
||||
color: #ff6b6b;
|
||||
}
|
||||
|
||||
Button.-error:hover {
|
||||
background: #6e2222;
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
"""TUI widgets for oAI."""
|
||||
|
||||
from oai.tui.widgets.chat_display import ChatDisplay
|
||||
from oai.tui.widgets.footer import Footer
|
||||
from oai.tui.widgets.header import Header
|
||||
from oai.tui.widgets.input_bar import InputBar
|
||||
from oai.tui.widgets.message import AssistantMessageWidget, SystemMessageWidget, UserMessageWidget
|
||||
|
||||
__all__ = [
|
||||
"ChatDisplay",
|
||||
"Footer",
|
||||
"Header",
|
||||
"InputBar",
|
||||
"UserMessageWidget",
|
||||
"SystemMessageWidget",
|
||||
"AssistantMessageWidget",
|
||||
]
|
||||
@@ -1,21 +0,0 @@
|
||||
"""Chat display widget for oAI TUI."""
|
||||
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.widgets import Static
|
||||
|
||||
|
||||
class ChatDisplay(ScrollableContainer):
|
||||
"""Scrollable container for chat messages."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="chat-display")
|
||||
|
||||
async def add_message(self, widget: Static) -> None:
|
||||
"""Add a message widget to the display."""
|
||||
await self.mount(widget)
|
||||
self.scroll_end(animate=False)
|
||||
|
||||
def clear_messages(self) -> None:
|
||||
"""Clear all messages from the display."""
|
||||
for child in list(self.children):
|
||||
child.remove()
|
||||
@@ -1,178 +0,0 @@
|
||||
"""Command dropdown menu for TUI input."""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import VerticalScroll
|
||||
from textual.widget import Widget
|
||||
from textual.widgets import Label, OptionList
|
||||
from textual.widgets.option_list import Option
|
||||
|
||||
from oai.commands import registry
|
||||
|
||||
|
||||
class CommandDropdown(VerticalScroll):
|
||||
"""Dropdown menu showing available commands."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
CommandDropdown {
|
||||
display: none;
|
||||
height: auto;
|
||||
max-height: 12;
|
||||
width: 80;
|
||||
background: #2d2d2d;
|
||||
border: solid #555555;
|
||||
padding: 0;
|
||||
layer: overlay;
|
||||
}
|
||||
|
||||
CommandDropdown.visible {
|
||||
display: block;
|
||||
}
|
||||
|
||||
CommandDropdown OptionList {
|
||||
height: auto;
|
||||
max-height: 12;
|
||||
background: #2d2d2d;
|
||||
border: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
CommandDropdown OptionList > .option-list--option {
|
||||
padding: 0 2;
|
||||
color: #cccccc;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
CommandDropdown OptionList > .option-list--option-highlighted {
|
||||
background: #3e3e3e;
|
||||
color: #ffffff;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the command dropdown."""
|
||||
super().__init__(id="command-dropdown")
|
||||
self._all_commands = []
|
||||
self._load_commands()
|
||||
|
||||
def _load_commands(self) -> None:
|
||||
"""Load all available commands."""
|
||||
# Get base commands with descriptions
|
||||
base_commands = [
|
||||
("/help", "Show help screen"),
|
||||
("/model", "Select AI model"),
|
||||
("/stats", "Show session statistics"),
|
||||
("/credits", "Check account credits"),
|
||||
("/clear", "Clear chat display"),
|
||||
("/reset", "Reset conversation history"),
|
||||
("/memory on", "Enable conversation memory"),
|
||||
("/memory off", "Disable memory"),
|
||||
("/online on", "Enable online search"),
|
||||
("/online off", "Disable online search"),
|
||||
("/save", "Save current conversation"),
|
||||
("/load", "Load saved conversation"),
|
||||
("/list", "List saved conversations"),
|
||||
("/delete", "Delete a conversation"),
|
||||
("/export md", "Export as Markdown"),
|
||||
("/export json", "Export as JSON"),
|
||||
("/export html", "Export as HTML"),
|
||||
("/prev", "Show previous message"),
|
||||
("/next", "Show next message"),
|
||||
("/config", "View configuration"),
|
||||
("/config api", "Set API key"),
|
||||
("/system", "Set system prompt"),
|
||||
("/maxtoken", "Set token limit"),
|
||||
("/retry", "Retry last prompt"),
|
||||
("/paste", "Paste from clipboard"),
|
||||
("/mcp on", "Enable MCP file access"),
|
||||
("/mcp off", "Disable MCP"),
|
||||
("/mcp status", "Show MCP status"),
|
||||
("/mcp add", "Add folder/database"),
|
||||
("/mcp remove", "Remove folder/database"),
|
||||
("/mcp list", "List folders"),
|
||||
("/mcp write on", "Enable write mode"),
|
||||
("/mcp write off", "Disable write mode"),
|
||||
("/mcp files", "Switch to file mode"),
|
||||
("/mcp db list", "List databases"),
|
||||
]
|
||||
|
||||
self._all_commands = base_commands
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the dropdown."""
|
||||
yield OptionList(id="command-list")
|
||||
|
||||
def show_commands(self, filter_text: str = "") -> None:
|
||||
"""Show commands matching the filter.
|
||||
|
||||
Args:
|
||||
filter_text: Text to filter commands by
|
||||
"""
|
||||
option_list = self.query_one("#command-list", OptionList)
|
||||
option_list.clear_options()
|
||||
|
||||
if not filter_text.startswith("/"):
|
||||
self.remove_class("visible")
|
||||
return
|
||||
|
||||
# Remove the leading slash for filtering
|
||||
filter_without_slash = filter_text[1:].lower()
|
||||
|
||||
# Filter commands - show if filter text is contained anywhere in the command
|
||||
if filter_without_slash:
|
||||
matching = [
|
||||
(cmd, desc) for cmd, desc in self._all_commands
|
||||
if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching
|
||||
]
|
||||
else:
|
||||
# Show all commands when just "/" is typed
|
||||
matching = self._all_commands
|
||||
|
||||
if not matching:
|
||||
self.remove_class("visible")
|
||||
return
|
||||
|
||||
# Add options - limit to 10 results
|
||||
for cmd, desc in matching[:10]:
|
||||
# Format: command in white, description in gray, separated by spaces
|
||||
label = f"{cmd} [dim]{desc}[/]" if desc else cmd
|
||||
option_list.add_option(Option(label, id=cmd))
|
||||
|
||||
self.add_class("visible")
|
||||
|
||||
# Auto-select first option
|
||||
if len(option_list._options) > 0:
|
||||
option_list.highlighted = 0
|
||||
|
||||
def hide(self) -> None:
|
||||
"""Hide the dropdown."""
|
||||
self.remove_class("visible")
|
||||
|
||||
def get_selected_command(self) -> str | None:
|
||||
"""Get the currently selected command.
|
||||
|
||||
Returns:
|
||||
Selected command text or None
|
||||
"""
|
||||
option_list = self.query_one("#command-list", OptionList)
|
||||
if option_list.highlighted is not None:
|
||||
option = option_list.get_option_at_index(option_list.highlighted)
|
||||
return option.id
|
||||
return None
|
||||
|
||||
def move_selection_up(self) -> None:
|
||||
"""Move selection up in the list."""
|
||||
option_list = self.query_one("#command-list", OptionList)
|
||||
if option_list.option_count > 0:
|
||||
if option_list.highlighted is None:
|
||||
option_list.highlighted = option_list.option_count - 1
|
||||
elif option_list.highlighted > 0:
|
||||
option_list.highlighted -= 1
|
||||
|
||||
def move_selection_down(self) -> None:
|
||||
"""Move selection down in the list."""
|
||||
option_list = self.query_one("#command-list", OptionList)
|
||||
if option_list.option_count > 0:
|
||||
if option_list.highlighted is None:
|
||||
option_list.highlighted = 0
|
||||
elif option_list.highlighted < option_list.option_count - 1:
|
||||
option_list.highlighted += 1
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Command suggester for TUI input."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from textual.suggester import Suggester
|
||||
|
||||
from oai.commands import registry
|
||||
|
||||
|
||||
class CommandSuggester(Suggester):
|
||||
"""Suggester that provides command completions."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the command suggester."""
|
||||
super().__init__(use_cache=False, case_sensitive=False)
|
||||
# Get all command names from registry
|
||||
self._commands = []
|
||||
self._update_commands()
|
||||
|
||||
def _update_commands(self) -> None:
|
||||
"""Update the list of available commands."""
|
||||
# Get all registered command names
|
||||
command_names = registry.get_all_names()
|
||||
# Add common MCP subcommands for better UX
|
||||
mcp_subcommands = [
|
||||
"/mcp on",
|
||||
"/mcp off",
|
||||
"/mcp status",
|
||||
"/mcp add",
|
||||
"/mcp remove",
|
||||
"/mcp list",
|
||||
"/mcp write on",
|
||||
"/mcp write off",
|
||||
"/mcp files",
|
||||
"/mcp db list",
|
||||
]
|
||||
self._commands = command_names + mcp_subcommands
|
||||
|
||||
async def get_suggestion(self, value: str) -> str | None:
|
||||
"""Get a command suggestion based on the current input.
|
||||
|
||||
Args:
|
||||
value: Current input value
|
||||
|
||||
Returns:
|
||||
Suggested completion or None
|
||||
"""
|
||||
if not value or not value.startswith("/"):
|
||||
return None
|
||||
|
||||
# Find the first command that starts with the input
|
||||
value_lower = value.lower()
|
||||
for cmd in self._commands:
|
||||
if cmd.lower().startswith(value_lower) and cmd.lower() != value_lower:
|
||||
# Return the rest of the command (after what's already typed)
|
||||
return cmd[len(value):]
|
||||
|
||||
return None
|
||||
@@ -1,39 +0,0 @@
|
||||
"""Footer widget for oAI TUI."""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.widgets import Static
|
||||
|
||||
|
||||
class Footer(Static):
|
||||
"""Footer displaying session metrics."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tokens_in = 0
|
||||
self.tokens_out = 0
|
||||
self.cost = 0.0
|
||||
self.messages = 0
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the footer."""
|
||||
yield Static(self._format_footer(), id="footer-content")
|
||||
|
||||
def _format_footer(self) -> str:
|
||||
"""Format the footer text."""
|
||||
return (
|
||||
f"[dim]Messages: {self.messages} | "
|
||||
f"Tokens: {self.tokens_in + self.tokens_out:,} "
|
||||
f"({self.tokens_in:,} in, {self.tokens_out:,} out) | "
|
||||
f"Cost: ${self.cost:.4f}[/]"
|
||||
)
|
||||
|
||||
def update_stats(
|
||||
self, tokens_in: int, tokens_out: int, cost: float, messages: int
|
||||
) -> None:
|
||||
"""Update the displayed statistics."""
|
||||
self.tokens_in = tokens_in
|
||||
self.tokens_out = tokens_out
|
||||
self.cost = cost
|
||||
self.messages = messages
|
||||
content = self.query_one("#footer-content", Static)
|
||||
content.update(self._format_footer())
|
||||
@@ -1,65 +0,0 @@
|
||||
"""Header widget for oAI TUI."""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.widgets import Static
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class Header(Static):
|
||||
"""Header displaying app title, version, current model, and capabilities."""
|
||||
|
||||
def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None):
|
||||
super().__init__()
|
||||
self.version = version
|
||||
self.model = model
|
||||
self.model_info = model_info or {}
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the header."""
|
||||
yield Static(self._format_header(), id="header-content")
|
||||
|
||||
def _format_capabilities(self) -> str:
|
||||
"""Format capability icons based on model info."""
|
||||
if not self.model_info:
|
||||
return ""
|
||||
|
||||
icons = []
|
||||
|
||||
# Check vision support
|
||||
architecture = self.model_info.get("architecture", {})
|
||||
modality = architecture.get("modality", "")
|
||||
if "image" in modality:
|
||||
icons.append("[bold cyan]👁️[/]") # Bright if supported
|
||||
else:
|
||||
icons.append("[dim]👁️[/]") # Dim if not supported
|
||||
|
||||
# Check tool support
|
||||
supported_params = self.model_info.get("supported_parameters", [])
|
||||
if "tools" in supported_params or "tool_choice" in supported_params:
|
||||
icons.append("[bold cyan]🔧[/]")
|
||||
else:
|
||||
icons.append("[dim]🔧[/]")
|
||||
|
||||
# Check online support (most models support :online suffix)
|
||||
model_id = self.model_info.get("id", "")
|
||||
if ":online" in model_id or model_id not in ["openrouter/free"]:
|
||||
icons.append("[bold cyan]🌐[/]")
|
||||
else:
|
||||
icons.append("[dim]🌐[/]")
|
||||
|
||||
return " ".join(icons) if icons else ""
|
||||
|
||||
def _format_header(self) -> str:
|
||||
"""Format the header text."""
|
||||
model_text = f" | {self.model}" if self.model else ""
|
||||
capabilities = self._format_capabilities()
|
||||
capabilities_text = f" {capabilities}" if capabilities else ""
|
||||
return f"[bold cyan]oAI[/] [dim]v{self.version}[/]{model_text}{capabilities_text}"
|
||||
|
||||
def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Update the displayed model and capabilities."""
|
||||
self.model = model
|
||||
if model_info:
|
||||
self.model_info = model_info
|
||||
content = self.query_one("#header-content", Static)
|
||||
content.update(self._format_header())
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Input bar widget for oAI TUI."""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Horizontal
|
||||
from textual.widgets import Input, Static
|
||||
|
||||
|
||||
class InputBar(Horizontal):
|
||||
"""Input bar with prompt prefix and text input."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="input-bar")
|
||||
self.mcp_status = ""
|
||||
self.online_mode = False
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the input bar."""
|
||||
yield Static(self._format_prefix(), id="input-prefix", classes="prefix-hidden" if not (self.mcp_status or self.online_mode) else "")
|
||||
yield Input(
|
||||
placeholder="Type a message or /command...",
|
||||
id="chat-input"
|
||||
)
|
||||
|
||||
def _format_prefix(self) -> str:
|
||||
"""Format the input prefix with status indicators."""
|
||||
indicators = []
|
||||
if self.mcp_status:
|
||||
indicators.append(f"[cyan]{self.mcp_status}[/]")
|
||||
if self.online_mode:
|
||||
indicators.append("[green]🌐[/]")
|
||||
|
||||
prefix = " ".join(indicators) + " " if indicators else ""
|
||||
return f"{prefix}[bold]>[/]"
|
||||
|
||||
def update_mcp_status(self, status: str) -> None:
|
||||
"""Update MCP status indicator."""
|
||||
self.mcp_status = status
|
||||
prefix = self.query_one("#input-prefix", Static)
|
||||
prefix.update(self._format_prefix())
|
||||
|
||||
def update_online_mode(self, online: bool) -> None:
|
||||
"""Update online mode indicator."""
|
||||
self.online_mode = online
|
||||
prefix = self.query_one("#input-prefix", Static)
|
||||
prefix.update(self._format_prefix())
|
||||
|
||||
def get_input(self) -> Input:
|
||||
"""Get the input widget."""
|
||||
return self.query_one("#chat-input", Input)
|
||||
@@ -1,92 +0,0 @@
|
||||
"""Message widgets for oAI TUI."""
|
||||
|
||||
from typing import Any, AsyncIterator, Tuple
|
||||
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.style import Style
|
||||
from rich.theme import Theme
|
||||
from textual.app import ComposeResult
|
||||
from textual.widgets import RichLog, Static
|
||||
|
||||
# Custom theme for Markdown rendering - neutral colors matching the dark theme
|
||||
MARKDOWN_THEME = Theme({
|
||||
"markdown.text": Style(color="#cccccc"),
|
||||
"markdown.paragraph": Style(color="#cccccc"),
|
||||
"markdown.code": Style(color="#e0e0e0", bgcolor="#2a2a2a"),
|
||||
"markdown.code_block": Style(color="#e0e0e0", bgcolor="#2a2a2a"),
|
||||
"markdown.heading": Style(color="#ffffff", bold=True),
|
||||
"markdown.h1": Style(color="#ffffff", bold=True),
|
||||
"markdown.h2": Style(color="#eeeeee", bold=True),
|
||||
"markdown.h3": Style(color="#dddddd", bold=True),
|
||||
"markdown.link": Style(color="#aaaaaa", underline=False),
|
||||
"markdown.link_url": Style(color="#888888"),
|
||||
"markdown.emphasis": Style(color="#cccccc", italic=True),
|
||||
"markdown.strong": Style(color="#ffffff", bold=True),
|
||||
})
|
||||
|
||||
|
||||
class UserMessageWidget(Static):
|
||||
"""Widget for displaying user messages."""
|
||||
|
||||
def __init__(self, content: str):
|
||||
super().__init__()
|
||||
self.content = content
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the user message."""
|
||||
yield Static(f"[bold green]You:[/] {self.content}")
|
||||
|
||||
|
||||
class SystemMessageWidget(Static):
|
||||
"""Widget for displaying system/info messages without 'You:' prefix."""
|
||||
|
||||
def __init__(self, content: str):
|
||||
super().__init__()
|
||||
self.content = content
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the system message."""
|
||||
yield Static(self.content)
|
||||
|
||||
|
||||
class AssistantMessageWidget(Static):
|
||||
"""Widget for displaying assistant responses with streaming support."""
|
||||
|
||||
def __init__(self, model_name: str = "Assistant"):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.full_text = ""
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the assistant message."""
|
||||
yield Static(f"[bold]{self.model_name}:[/]", id="assistant-label")
|
||||
yield RichLog(id="assistant-content", highlight=True, markup=True, wrap=True)
|
||||
|
||||
async def stream_response(self, response_iterator: AsyncIterator) -> Tuple[str, Any]:
|
||||
"""Stream tokens progressively and return final text and usage."""
|
||||
log = self.query_one("#assistant-content", RichLog)
|
||||
self.full_text = ""
|
||||
usage = None
|
||||
|
||||
async for chunk in response_iterator:
|
||||
if hasattr(chunk, "delta_content") and chunk.delta_content:
|
||||
self.full_text += chunk.delta_content
|
||||
log.clear()
|
||||
# Use neutral code theme for syntax highlighting
|
||||
md = Markdown(self.full_text, code_theme="github-dark", inline_code_theme="github-dark")
|
||||
log.write(md)
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage = chunk.usage
|
||||
|
||||
return self.full_text, usage
|
||||
|
||||
def set_content(self, content: str) -> None:
|
||||
"""Set the complete content (non-streaming)."""
|
||||
self.full_text = content
|
||||
log = self.query_one("#assistant-content", RichLog)
|
||||
log.clear()
|
||||
# Use neutral code theme for syntax highlighting
|
||||
md = Markdown(content, code_theme="github-dark", inline_code_theme="github-dark")
|
||||
log.write(md)
|
||||
@@ -1,20 +0,0 @@
|
||||
"""
|
||||
Utility modules for oAI.
|
||||
|
||||
This package provides common utilities used throughout the application
|
||||
including logging, file handling, and export functionality.
|
||||
"""
|
||||
|
||||
from oai.utils.logging import setup_logging, get_logger
|
||||
from oai.utils.files import read_file_safe, is_binary_file
|
||||
from oai.utils.export import export_as_markdown, export_as_json, export_as_html
|
||||
|
||||
__all__ = [
|
||||
"setup_logging",
|
||||
"get_logger",
|
||||
"read_file_safe",
|
||||
"is_binary_file",
|
||||
"export_as_markdown",
|
||||
"export_as_json",
|
||||
"export_as_html",
|
||||
]
|
||||
@@ -1,248 +0,0 @@
|
||||
"""
|
||||
Export utilities for oAI.
|
||||
|
||||
This module provides functions for exporting conversation history
|
||||
in various formats including Markdown, JSON, and HTML.
|
||||
"""
|
||||
|
||||
import json
|
||||
import datetime
|
||||
from typing import List, Dict
|
||||
from html import escape as html_escape
|
||||
|
||||
from oai.constants import APP_VERSION, APP_URL
|
||||
|
||||
|
||||
def export_as_markdown(
|
||||
session_history: List[Dict[str, str]],
|
||||
session_system_prompt: str = ""
|
||||
) -> str:
|
||||
"""
|
||||
Export conversation history as Markdown.
|
||||
|
||||
Args:
|
||||
session_history: List of message dictionaries with 'prompt' and 'response'
|
||||
session_system_prompt: Optional system prompt to include
|
||||
|
||||
Returns:
|
||||
Markdown formatted string
|
||||
"""
|
||||
lines = ["# Conversation Export", ""]
|
||||
|
||||
if session_system_prompt:
|
||||
lines.extend([f"**System Prompt:** {session_system_prompt}", ""])
|
||||
|
||||
lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
lines.append(f"**Messages:** {len(session_history)}")
|
||||
lines.append("")
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
|
||||
for i, entry in enumerate(session_history, 1):
|
||||
lines.append(f"## Message {i}")
|
||||
lines.append("")
|
||||
lines.append("**User:**")
|
||||
lines.append("")
|
||||
lines.append(entry.get("prompt", ""))
|
||||
lines.append("")
|
||||
lines.append("**Assistant:**")
|
||||
lines.append("")
|
||||
lines.append(entry.get("response", ""))
|
||||
lines.append("")
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
|
||||
lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def export_as_json(
|
||||
session_history: List[Dict[str, str]],
|
||||
session_system_prompt: str = ""
|
||||
) -> str:
|
||||
"""
|
||||
Export conversation history as JSON.
|
||||
|
||||
Args:
|
||||
session_history: List of message dictionaries
|
||||
session_system_prompt: Optional system prompt to include
|
||||
|
||||
Returns:
|
||||
JSON formatted string
|
||||
"""
|
||||
export_data = {
|
||||
"export_date": datetime.datetime.now().isoformat(),
|
||||
"app_version": APP_VERSION,
|
||||
"system_prompt": session_system_prompt,
|
||||
"message_count": len(session_history),
|
||||
"messages": [
|
||||
{
|
||||
"index": i + 1,
|
||||
"prompt": entry.get("prompt", ""),
|
||||
"response": entry.get("response", ""),
|
||||
"prompt_tokens": entry.get("prompt_tokens", 0),
|
||||
"completion_tokens": entry.get("completion_tokens", 0),
|
||||
"cost": entry.get("msg_cost", 0.0),
|
||||
}
|
||||
for i, entry in enumerate(session_history)
|
||||
],
|
||||
"totals": {
|
||||
"prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history),
|
||||
"completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history),
|
||||
"total_cost": sum(e.get("msg_cost", 0.0) for e in session_history),
|
||||
}
|
||||
}
|
||||
return json.dumps(export_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def export_as_html(
|
||||
session_history: List[Dict[str, str]],
|
||||
session_system_prompt: str = ""
|
||||
) -> str:
|
||||
"""
|
||||
Export conversation history as styled HTML.
|
||||
|
||||
Args:
|
||||
session_history: List of message dictionaries
|
||||
session_system_prompt: Optional system prompt to include
|
||||
|
||||
Returns:
|
||||
HTML formatted string with embedded CSS
|
||||
"""
|
||||
html_parts = [
|
||||
"<!DOCTYPE html>",
|
||||
"<html>",
|
||||
"<head>",
|
||||
" <meta charset='UTF-8'>",
|
||||
" <meta name='viewport' content='width=device-width, initial-scale=1.0'>",
|
||||
" <title>Conversation Export - oAI</title>",
|
||||
" <style>",
|
||||
" * { box-sizing: border-box; }",
|
||||
" body {",
|
||||
" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;",
|
||||
" max-width: 900px;",
|
||||
" margin: 40px auto;",
|
||||
" padding: 20px;",
|
||||
" background: #f5f5f5;",
|
||||
" color: #333;",
|
||||
" }",
|
||||
" .header {",
|
||||
" background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);",
|
||||
" color: white;",
|
||||
" padding: 30px;",
|
||||
" border-radius: 10px;",
|
||||
" margin-bottom: 30px;",
|
||||
" box-shadow: 0 4px 6px rgba(0,0,0,0.1);",
|
||||
" }",
|
||||
" .header h1 {",
|
||||
" margin: 0 0 10px 0;",
|
||||
" font-size: 2em;",
|
||||
" }",
|
||||
" .export-info {",
|
||||
" opacity: 0.9;",
|
||||
" font-size: 0.95em;",
|
||||
" margin: 5px 0;",
|
||||
" }",
|
||||
" .system-prompt {",
|
||||
" background: #fff3cd;",
|
||||
" padding: 20px;",
|
||||
" border-radius: 8px;",
|
||||
" margin-bottom: 25px;",
|
||||
" border-left: 5px solid #ffc107;",
|
||||
" box-shadow: 0 2px 4px rgba(0,0,0,0.05);",
|
||||
" }",
|
||||
" .system-prompt strong {",
|
||||
" color: #856404;",
|
||||
" display: block;",
|
||||
" margin-bottom: 10px;",
|
||||
" font-size: 1.1em;",
|
||||
" }",
|
||||
" .message-container { margin-bottom: 20px; }",
|
||||
" .message {",
|
||||
" background: white;",
|
||||
" padding: 20px;",
|
||||
" border-radius: 8px;",
|
||||
" box-shadow: 0 2px 4px rgba(0,0,0,0.08);",
|
||||
" margin-bottom: 12px;",
|
||||
" }",
|
||||
" .user-message { border-left: 5px solid #10b981; }",
|
||||
" .assistant-message { border-left: 5px solid #3b82f6; }",
|
||||
" .role {",
|
||||
" font-weight: bold;",
|
||||
" margin-bottom: 12px;",
|
||||
" font-size: 1.05em;",
|
||||
" text-transform: uppercase;",
|
||||
" letter-spacing: 0.5px;",
|
||||
" }",
|
||||
" .user-role { color: #10b981; }",
|
||||
" .assistant-role { color: #3b82f6; }",
|
||||
" .content {",
|
||||
" line-height: 1.8;",
|
||||
" white-space: pre-wrap;",
|
||||
" color: #333;",
|
||||
" }",
|
||||
" .message-number {",
|
||||
" color: #6b7280;",
|
||||
" font-size: 0.85em;",
|
||||
" margin-bottom: 15px;",
|
||||
" font-weight: 600;",
|
||||
" }",
|
||||
" .footer {",
|
||||
" text-align: center;",
|
||||
" margin-top: 40px;",
|
||||
" padding: 20px;",
|
||||
" color: #6b7280;",
|
||||
" font-size: 0.9em;",
|
||||
" }",
|
||||
" .footer a { color: #667eea; text-decoration: none; }",
|
||||
" .footer a:hover { text-decoration: underline; }",
|
||||
" @media print {",
|
||||
" body { background: white; }",
|
||||
" .message { break-inside: avoid; }",
|
||||
" }",
|
||||
" </style>",
|
||||
"</head>",
|
||||
"<body>",
|
||||
" <div class='header'>",
|
||||
" <h1>Conversation Export</h1>",
|
||||
f" <div class='export-info'>Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</div>",
|
||||
f" <div class='export-info'>Total Messages: {len(session_history)}</div>",
|
||||
" </div>",
|
||||
]
|
||||
|
||||
if session_system_prompt:
|
||||
html_parts.extend([
|
||||
" <div class='system-prompt'>",
|
||||
" <strong>System Prompt</strong>",
|
||||
f" <div>{html_escape(session_system_prompt)}</div>",
|
||||
" </div>",
|
||||
])
|
||||
|
||||
for i, entry in enumerate(session_history, 1):
|
||||
prompt = html_escape(entry.get("prompt", ""))
|
||||
response = html_escape(entry.get("response", ""))
|
||||
|
||||
html_parts.extend([
|
||||
" <div class='message-container'>",
|
||||
f" <div class='message-number'>Message {i} of {len(session_history)}</div>",
|
||||
" <div class='message user-message'>",
|
||||
" <div class='role user-role'>User</div>",
|
||||
f" <div class='content'>{prompt}</div>",
|
||||
" </div>",
|
||||
" <div class='message assistant-message'>",
|
||||
" <div class='role assistant-role'>Assistant</div>",
|
||||
f" <div class='content'>{response}</div>",
|
||||
" </div>",
|
||||
" </div>",
|
||||
])
|
||||
|
||||
html_parts.extend([
|
||||
" <div class='footer'>",
|
||||
f" <p>Generated by oAI v{APP_VERSION} • <a href='{APP_URL}'>{APP_URL}</a></p>",
|
||||
" </div>",
|
||||
"</body>",
|
||||
"</html>",
|
||||
])
|
||||
|
||||
return "\n".join(html_parts)
|
||||
@@ -1,323 +0,0 @@
|
||||
"""
|
||||
File handling utilities for oAI.
|
||||
|
||||
This module provides safe file reading, type detection, and other
|
||||
file-related operations used throughout the application.
|
||||
"""
|
||||
|
||||
import os
|
||||
import mimetypes
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
|
||||
from oai.constants import (
|
||||
MAX_FILE_SIZE,
|
||||
CONTENT_TRUNCATION_THRESHOLD,
|
||||
SUPPORTED_CODE_EXTENSIONS,
|
||||
ALLOWED_FILE_EXTENSIONS,
|
||||
)
|
||||
from oai.utils.logging import get_logger
|
||||
|
||||
|
||||
def is_binary_file(file_path: Path) -> bool:
|
||||
"""
|
||||
Check if a file appears to be binary.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to check
|
||||
|
||||
Returns:
|
||||
True if the file appears to be binary, False otherwise
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
# Read first 8KB to check for binary content
|
||||
chunk = f.read(8192)
|
||||
# Check for null bytes (common in binary files)
|
||||
if b"\x00" in chunk:
|
||||
return True
|
||||
# Try to decode as UTF-8
|
||||
try:
|
||||
chunk.decode("utf-8")
|
||||
return False
|
||||
except UnicodeDecodeError:
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def get_file_type(file_path: Path) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Determine the MIME type and category of a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
Tuple of (mime_type, category) where category is one of:
|
||||
'image', 'pdf', 'code', 'text', 'binary', 'unknown'
|
||||
"""
|
||||
mime_type, _ = mimetypes.guess_type(str(file_path))
|
||||
ext = file_path.suffix.lower()
|
||||
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
return mime_type, "image"
|
||||
elif mime_type == "application/pdf" or ext == ".pdf":
|
||||
return mime_type or "application/pdf", "pdf"
|
||||
elif ext in SUPPORTED_CODE_EXTENSIONS:
|
||||
return mime_type or "text/plain", "code"
|
||||
elif mime_type and mime_type.startswith("text/"):
|
||||
return mime_type, "text"
|
||||
elif is_binary_file(file_path):
|
||||
return mime_type, "binary"
|
||||
else:
|
||||
return mime_type, "unknown"
|
||||
|
||||
|
||||
def read_file_safe(
|
||||
file_path: Path,
|
||||
max_size: int = MAX_FILE_SIZE,
|
||||
truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Safely read a file with size limits and truncation support.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to read
|
||||
max_size: Maximum file size to read (bytes)
|
||||
truncate_threshold: Threshold for truncating large files
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- content: File content (text or base64)
|
||||
- size: File size in bytes
|
||||
- truncated: Whether content was truncated
|
||||
- encoding: 'text', 'base64', or None on error
|
||||
- error: Error message if reading failed
|
||||
"""
|
||||
logger = get_logger()
|
||||
|
||||
try:
|
||||
path = Path(file_path).resolve()
|
||||
|
||||
if not path.exists():
|
||||
return {
|
||||
"content": None,
|
||||
"size": 0,
|
||||
"truncated": False,
|
||||
"encoding": None,
|
||||
"error": f"File not found: {path}"
|
||||
}
|
||||
|
||||
if not path.is_file():
|
||||
return {
|
||||
"content": None,
|
||||
"size": 0,
|
||||
"truncated": False,
|
||||
"encoding": None,
|
||||
"error": f"Not a file: {path}"
|
||||
}
|
||||
|
||||
file_size = path.stat().st_size
|
||||
|
||||
if file_size > max_size:
|
||||
return {
|
||||
"content": None,
|
||||
"size": file_size,
|
||||
"truncated": False,
|
||||
"encoding": None,
|
||||
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)"
|
||||
}
|
||||
|
||||
# Try to read as text first
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
|
||||
# Check if truncation is needed
|
||||
if file_size > truncate_threshold:
|
||||
lines = content.split("\n")
|
||||
total_lines = len(lines)
|
||||
|
||||
# Keep first 500 lines and last 100 lines
|
||||
head_lines = 500
|
||||
tail_lines = 100
|
||||
|
||||
if total_lines > (head_lines + tail_lines):
|
||||
truncated_content = (
|
||||
"\n".join(lines[:head_lines]) +
|
||||
f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" +
|
||||
"\n".join(lines[-tail_lines:])
|
||||
)
|
||||
logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)")
|
||||
return {
|
||||
"content": truncated_content,
|
||||
"size": file_size,
|
||||
"truncated": True,
|
||||
"total_lines": total_lines,
|
||||
"lines_shown": head_lines + tail_lines,
|
||||
"encoding": "text",
|
||||
"error": None
|
||||
}
|
||||
|
||||
logger.info(f"Read file: {path} ({file_size} bytes)")
|
||||
return {
|
||||
"content": content,
|
||||
"size": file_size,
|
||||
"truncated": False,
|
||||
"encoding": "text",
|
||||
"error": None
|
||||
}
|
||||
|
||||
except UnicodeDecodeError:
|
||||
# File is binary, return base64 encoded
|
||||
with open(path, "rb") as f:
|
||||
binary_data = f.read()
|
||||
b64_content = base64.b64encode(binary_data).decode("utf-8")
|
||||
logger.info(f"Read binary file: {path} ({file_size} bytes)")
|
||||
return {
|
||||
"content": b64_content,
|
||||
"size": file_size,
|
||||
"truncated": False,
|
||||
"encoding": "base64",
|
||||
"error": None
|
||||
}
|
||||
|
||||
except PermissionError as e:
|
||||
return {
|
||||
"content": None,
|
||||
"size": 0,
|
||||
"truncated": False,
|
||||
"encoding": None,
|
||||
"error": f"Permission denied: {e}"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading file {file_path}: {e}")
|
||||
return {
|
||||
"content": None,
|
||||
"size": 0,
|
||||
"truncated": False,
|
||||
"encoding": None,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
def get_file_extension(file_path: Path) -> str:
|
||||
"""
|
||||
Get the lowercase file extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
Lowercase extension including the dot (e.g., '.py')
|
||||
"""
|
||||
return file_path.suffix.lower()
|
||||
|
||||
|
||||
def is_allowed_extension(file_path: Path) -> bool:
|
||||
"""
|
||||
Check if a file has an allowed extension for attachment.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
True if the extension is allowed, False otherwise
|
||||
"""
|
||||
return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS
|
||||
|
||||
|
||||
def format_file_size(size_bytes: int) -> str:
|
||||
"""
|
||||
Format a file size in human-readable format.
|
||||
|
||||
Args:
|
||||
size_bytes: Size in bytes
|
||||
|
||||
Returns:
|
||||
Formatted string (e.g., '1.5 MB', '512 KB')
|
||||
"""
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if abs(size_bytes) < 1024:
|
||||
return f"{size_bytes:.1f} {unit}"
|
||||
size_bytes /= 1024
|
||||
return f"{size_bytes:.1f} PB"
|
||||
|
||||
|
||||
def prepare_file_attachment(
|
||||
file_path: Path,
|
||||
model_capabilities: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Prepare a file for attachment to an API request.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
model_capabilities: Model capability information
|
||||
|
||||
Returns:
|
||||
Content block dictionary for the API, or None if unsupported
|
||||
"""
|
||||
logger = get_logger()
|
||||
path = Path(file_path).resolve()
|
||||
|
||||
if not path.exists():
|
||||
logger.warning(f"File not found: {path}")
|
||||
return None
|
||||
|
||||
mime_type, category = get_file_type(path)
|
||||
file_size = path.stat().st_size
|
||||
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
logger.warning(f"File too large: {path} ({format_file_size(file_size)})")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
if category == "image":
|
||||
# Check if model supports images
|
||||
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
|
||||
if "image" not in input_modalities:
|
||||
logger.warning(f"Model does not support images")
|
||||
return None
|
||||
|
||||
b64_data = base64.b64encode(file_data).decode("utf-8")
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{mime_type};base64,{b64_data}"}
|
||||
}
|
||||
|
||||
elif category == "pdf":
|
||||
# Check if model supports PDFs
|
||||
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
|
||||
supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
|
||||
if not supports_pdf:
|
||||
logger.warning(f"Model does not support PDFs")
|
||||
return None
|
||||
|
||||
b64_data = base64.b64encode(file_data).decode("utf-8")
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:application/pdf;base64,{b64_data}"}
|
||||
}
|
||||
|
||||
elif category in ("code", "text"):
|
||||
text_content = file_data.decode("utf-8")
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"File: {path.name}\n\n{text_content}"
|
||||
}
|
||||
|
||||
else:
|
||||
logger.warning(f"Unsupported file type: {category} ({mime_type})")
|
||||
return None
|
||||
|
||||
except UnicodeDecodeError:
|
||||
logger.error(f"Cannot decode file as UTF-8: {path}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing file attachment {path}: {e}")
|
||||
return None
|
||||
@@ -1,297 +0,0 @@
|
||||
"""
|
||||
Logging configuration for oAI.
|
||||
|
||||
This module provides centralized logging setup with Rich formatting,
|
||||
file rotation, and configurable log levels.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import datetime
|
||||
import shutil
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from oai.constants import (
|
||||
LOG_FILE,
|
||||
CONFIG_DIR,
|
||||
DEFAULT_LOG_MAX_SIZE_MB,
|
||||
DEFAULT_LOG_BACKUP_COUNT,
|
||||
DEFAULT_LOG_LEVEL,
|
||||
VALID_LOG_LEVELS,
|
||||
)
|
||||
|
||||
|
||||
class RotatingRichHandler(RotatingFileHandler):
|
||||
"""
|
||||
Custom log handler combining file rotation with Rich formatting.
|
||||
|
||||
This handler writes Rich-formatted log output to a rotating file,
|
||||
providing colored and formatted logs even in file output while
|
||||
managing file size and backups automatically.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the handler with Rich console for formatting."""
|
||||
super().__init__(*args, **kwargs)
|
||||
# Create an internal console for Rich formatting
|
||||
self.rich_console = Console(
|
||||
file=io.StringIO(),
|
||||
width=120,
|
||||
force_terminal=False
|
||||
)
|
||||
self.rich_handler = RichHandler(
|
||||
console=self.rich_console,
|
||||
show_time=True,
|
||||
show_path=True,
|
||||
rich_tracebacks=True,
|
||||
tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"]
|
||||
)
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
"""
|
||||
Emit a log record with Rich formatting.
|
||||
|
||||
Args:
|
||||
record: The log record to emit
|
||||
"""
|
||||
try:
|
||||
# Format with Rich
|
||||
self.rich_handler.emit(record)
|
||||
output = self.rich_console.file.getvalue()
|
||||
self.rich_console.file.seek(0)
|
||||
self.rich_console.file.truncate(0)
|
||||
|
||||
if output:
|
||||
self.stream.write(output)
|
||||
self.flush()
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class LoggingManager:
|
||||
"""
|
||||
Manages application logging configuration.
|
||||
|
||||
Provides methods to setup, configure, and manage logging with
|
||||
support for runtime reconfiguration and level changes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the logging manager."""
|
||||
self.handler: Optional[RotatingRichHandler] = None
|
||||
self.app_logger: Optional[logging.Logger] = None
|
||||
self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
|
||||
self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT
|
||||
self.log_level: str = DEFAULT_LOG_LEVEL
|
||||
|
||||
def setup(
|
||||
self,
|
||||
max_size_mb: Optional[int] = None,
|
||||
backup_count: Optional[int] = None,
|
||||
log_level: Optional[str] = None
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Setup or reconfigure logging.
|
||||
|
||||
Args:
|
||||
max_size_mb: Maximum log file size in MB
|
||||
backup_count: Number of backup files to keep
|
||||
log_level: Logging level string
|
||||
|
||||
Returns:
|
||||
The configured application logger
|
||||
"""
|
||||
# Update configuration if provided
|
||||
if max_size_mb is not None:
|
||||
self.max_size_mb = max_size_mb
|
||||
if backup_count is not None:
|
||||
self.backup_count = backup_count
|
||||
if log_level is not None:
|
||||
self.log_level = log_level
|
||||
|
||||
# Ensure config directory exists
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get root logger
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
# Remove existing handler if present
|
||||
if self.handler is not None:
|
||||
root_logger.removeHandler(self.handler)
|
||||
try:
|
||||
self.handler.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check if log needs manual rotation
|
||||
self._check_rotation()
|
||||
|
||||
# Create new handler
|
||||
max_bytes = self.max_size_mb * 1024 * 1024
|
||||
self.handler = RotatingRichHandler(
|
||||
filename=str(LOG_FILE),
|
||||
maxBytes=max_bytes,
|
||||
backupCount=self.backup_count,
|
||||
encoding="utf-8"
|
||||
)
|
||||
|
||||
self.handler.setLevel(logging.NOTSET)
|
||||
root_logger.setLevel(logging.WARNING)
|
||||
root_logger.addHandler(self.handler)
|
||||
|
||||
# Suppress noisy third-party loggers
|
||||
for logger_name in [
|
||||
"asyncio", "urllib3", "requests", "httpx",
|
||||
"httpcore", "openai", "openrouter"
|
||||
]:
|
||||
logging.getLogger(logger_name).setLevel(logging.WARNING)
|
||||
|
||||
# Configure application logger
|
||||
self.app_logger = logging.getLogger("oai_app")
|
||||
level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO)
|
||||
self.app_logger.setLevel(level)
|
||||
self.app_logger.propagate = True
|
||||
|
||||
return self.app_logger
|
||||
|
||||
def _check_rotation(self) -> None:
|
||||
"""Check if log file needs rotation and perform if necessary."""
|
||||
if not LOG_FILE.exists():
|
||||
return
|
||||
|
||||
current_size = LOG_FILE.stat().st_size
|
||||
max_bytes = self.max_size_mb * 1024 * 1024
|
||||
|
||||
if current_size >= max_bytes:
|
||||
# Perform manual rotation
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_file = f"{LOG_FILE}.{timestamp}"
|
||||
|
||||
try:
|
||||
shutil.move(str(LOG_FILE), backup_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clean old backups
|
||||
self._cleanup_old_backups()
|
||||
|
||||
def _cleanup_old_backups(self) -> None:
|
||||
"""Remove old backup files exceeding the backup count."""
|
||||
log_dir = LOG_FILE.parent
|
||||
backup_pattern = f"{LOG_FILE.name}.*"
|
||||
backups = sorted(glob.glob(str(log_dir / backup_pattern)))
|
||||
|
||||
while len(backups) > self.backup_count:
|
||||
oldest = backups.pop(0)
|
||||
try:
|
||||
os.remove(oldest)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def set_level(self, level: str) -> bool:
|
||||
"""
|
||||
Set the application log level.
|
||||
|
||||
Args:
|
||||
level: Log level string (debug/info/warning/error/critical)
|
||||
|
||||
Returns:
|
||||
True if level was set successfully, False otherwise
|
||||
"""
|
||||
level_lower = level.lower()
|
||||
if level_lower not in VALID_LOG_LEVELS:
|
||||
return False
|
||||
|
||||
self.log_level = level_lower
|
||||
if self.app_logger:
|
||||
self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower])
|
||||
|
||||
return True
|
||||
|
||||
def get_logger(self) -> logging.Logger:
|
||||
"""
|
||||
Get the application logger, initializing if necessary.
|
||||
|
||||
Returns:
|
||||
The application logger
|
||||
"""
|
||||
if self.app_logger is None:
|
||||
self.setup()
|
||||
return self.app_logger
|
||||
|
||||
|
||||
# Global logging manager instance
|
||||
_logging_manager = LoggingManager()
|
||||
|
||||
|
||||
def setup_logging(
|
||||
max_size_mb: Optional[int] = None,
|
||||
backup_count: Optional[int] = None,
|
||||
log_level: Optional[str] = None
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Setup application logging.
|
||||
|
||||
This is the main entry point for configuring logging. Call this
|
||||
early in application startup.
|
||||
|
||||
Args:
|
||||
max_size_mb: Maximum log file size in MB
|
||||
backup_count: Number of backup files to keep
|
||||
log_level: Logging level string
|
||||
|
||||
Returns:
|
||||
The configured application logger
|
||||
"""
|
||||
return _logging_manager.setup(max_size_mb, backup_count, log_level)
|
||||
|
||||
|
||||
def get_logger() -> logging.Logger:
|
||||
"""
|
||||
Get the application logger.
|
||||
|
||||
Returns:
|
||||
The application logger instance
|
||||
"""
|
||||
return _logging_manager.get_logger()
|
||||
|
||||
|
||||
def set_log_level(level: str) -> bool:
|
||||
"""
|
||||
Set the application log level.
|
||||
|
||||
Args:
|
||||
level: Log level string
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
return _logging_manager.set_level(level)
|
||||
|
||||
|
||||
def reload_logging(
|
||||
max_size_mb: Optional[int] = None,
|
||||
backup_count: Optional[int] = None,
|
||||
log_level: Optional[str] = None
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Reload logging configuration.
|
||||
|
||||
Useful when settings change at runtime.
|
||||
|
||||
Args:
|
||||
max_size_mb: New maximum log file size
|
||||
backup_count: New backup count
|
||||
log_level: New log level
|
||||
|
||||
Returns:
|
||||
The reconfigured logger
|
||||
"""
|
||||
return _logging_manager.setup(max_size_mb, backup_count, log_level)
|
||||
134
pyproject.toml
134
pyproject.toml
@@ -1,134 +0,0 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "oai"
|
||||
version = "3.0.0-b2" # MUST match oai/__init__.py __version__
|
||||
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
authors = [
|
||||
{name = "Rune", email = "rune@example.com"}
|
||||
]
|
||||
maintainers = [
|
||||
{name = "Rune", email = "rune@example.com"}
|
||||
]
|
||||
keywords = [
|
||||
"ai",
|
||||
"chat",
|
||||
"openrouter",
|
||||
"cli",
|
||||
"terminal",
|
||||
"mcp",
|
||||
"llm",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Environment :: Console",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Utilities",
|
||||
]
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"anyio>=4.0.0",
|
||||
"click>=8.0.0",
|
||||
"httpx>=0.24.0",
|
||||
"markdown-it-py>=3.0.0",
|
||||
"openrouter>=0.0.19",
|
||||
"packaging>=21.0",
|
||||
"pyperclip>=1.8.0",
|
||||
"requests>=2.28.0",
|
||||
"rich>=13.0.0",
|
||||
"textual>=0.50.0",
|
||||
"typer>=0.9.0",
|
||||
"mcp>=1.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0.0",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
"black>=23.0.0",
|
||||
"isort>=5.12.0",
|
||||
"mypy>=1.0.0",
|
||||
"ruff>=0.1.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://iurl.no/oai"
|
||||
Repository = "https://gitlab.pm/rune/oai"
|
||||
Documentation = "https://iurl.no/oai"
|
||||
"Bug Tracker" = "https://gitlab.pm/rune/oai/issues"
|
||||
|
||||
[project.scripts]
|
||||
oai = "oai.cli:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.tui", "oai.tui.widgets", "oai.tui.screens", "oai.utils"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
oai = ["py.typed"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
target-version = ["py310", "py311", "py312"]
|
||||
include = '\.pyi?$'
|
||||
exclude = '''
|
||||
/(
|
||||
\.git
|
||||
| \.mypy_cache
|
||||
| \.pytest_cache
|
||||
| \.venv
|
||||
| build
|
||||
| dist
|
||||
)/
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 100
|
||||
skip_gitignore = true
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
ignore_missing_imports = true
|
||||
exclude = [
|
||||
"build",
|
||||
"dist",
|
||||
".venv",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py310"
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # Pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by black)
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
"C901", # too complex
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
asyncio_mode = "auto"
|
||||
addopts = "-v --tb=short"
|
||||
Reference in New Issue
Block a user