1082 lines
40 KiB
Python
1082 lines
40 KiB
Python
"""Main Textual TUI application for oAI."""
|
||
|
||
import asyncio
|
||
import platform
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
import pyperclip
|
||
from textual.app import App, ComposeResult
|
||
from textual.widgets import Input
|
||
|
||
from oai import __version__
|
||
from oai.commands.registry import CommandStatus, registry
|
||
from oai.config.settings import Settings
|
||
from oai.core.client import AIClient
|
||
from oai.core.session import ChatSession
|
||
from oai.tui.screens import (
|
||
AlertDialog,
|
||
CommandsScreen,
|
||
ConfirmDialog,
|
||
ConfigScreen,
|
||
ConversationSelectorScreen,
|
||
CreditsScreen,
|
||
HelpScreen,
|
||
InputDialog,
|
||
ModelInfoScreen,
|
||
ModelSelectorScreen,
|
||
StatsScreen,
|
||
)
|
||
from oai.tui.widgets import (
|
||
AssistantMessageWidget,
|
||
ChatDisplay,
|
||
Footer,
|
||
Header,
|
||
InputBar,
|
||
SystemMessageWidget,
|
||
UserMessageWidget,
|
||
)
|
||
from oai.tui.widgets.command_dropdown import CommandDropdown
|
||
|
||
|
||
class oAIChatApp(App):
    """Textual TUI for oAI chat.

    Hosts the header, chat transcript, input bar, slash-command dropdown and
    footer. Submitted input is either dispatched to the slash-command handlers
    below or sent to the model via ``ChatSession.send_message_async``.
    """

    CSS_PATH = Path(__file__).parent / "styles.tcss"
    TITLE = "oAI Chat"

    BINDINGS = [
        ("ctrl+q", "quit", "Quit"),
        ("ctrl+m", "show_model_selector", "Model"),
        ("ctrl+h", "show_help", "Help"),
        ("ctrl+l", "clear_chat", "Clear"),
        ("ctrl+s", "show_stats", "Stats"),
    ]

    # Dropdown selections that expect further arguments: accepting one of these
    # with Enter inserts it plus a trailing space instead of submitting it.
    _COMMANDS_NEEDING_ARGS = (
        "/mcp add",
        "/mcp remove",
        "/config api",
        "/delete",
        "/system",
        "/maxtoken",
    )

    # Command words (after the leading "/") that quit the application.
    _EXIT_WORDS = ("exit", "quit", "bye")

    def __init__(
        self,
        session: ChatSession,
        settings: Settings,
        model: Optional[str] = None,
    ):
        """Create the app.

        Args:
            session: Chat session holding the client, history and mode toggles.
            settings: Persistent user settings.
            model: Model requested by the caller; stored as ``initial_model``
                (not otherwise read inside this class).
        """
        super().__init__()
        self.session = session
        self.settings = settings
        self.initial_model = model
        # Submitted lines, oldest first, for Up/Down recall in the chat input.
        self.input_history: list[str] = []
        # Index into input_history while recalling; -1 means "not recalling".
        self.history_index: int = -1
        # Set briefly while history recall rewrites the input so that
        # on_input_changed does not pop up the command dropdown.
        self._navigating_history: bool = False

    # ------------------------------------------------------------------
    # Layout and lifecycle
    # ------------------------------------------------------------------

    def compose(self) -> ComposeResult:
        """Build the layout: header, chat display, input bar, dropdown, footer."""
        selected = self.session.selected_model
        model_name = selected.get("name", "") if selected else ""
        model_info = selected or None
        provider_name = self.session.client.provider_name if self.session.client else ""

        yield Header(version=__version__, model=model_name, model_info=model_info, provider=provider_name)
        yield ChatDisplay()
        yield InputBar()
        yield CommandDropdown()
        yield Footer()

    def on_mount(self) -> None:
        """Focus the chat input and initialise the MCP / online indicators."""
        input_bar = self.query_one(InputBar)
        input_bar.get_input().focus()

        # Only set a label at startup when MCP has a labelled mode active.
        mcp_label = self._mcp_status_label()
        if mcp_label:
            input_bar.update_mcp_status(mcp_label)

        if self.session.online_enabled:
            input_bar.update_online_mode(True)

    def _mcp_status_label(self) -> Optional[str]:
        """Return the input-bar label for the current MCP state.

        Returns:
            "📁 Files" or "🗄️ DB" for those two modes, "" when there is no MCP
            manager or it is disabled, and None when MCP is enabled in some
            other mode (callers then leave the indicator untouched).
        """
        manager = self.session.mcp_manager
        if not (manager and manager.enabled):
            return ""
        mode = manager.mode
        if mode == "files":
            return "📁 Files"
        if mode == "database":
            return "🗄️ DB"
        return None

    # ------------------------------------------------------------------
    # Keyboard and input events
    # ------------------------------------------------------------------

    def on_key(self, event) -> None:
        """Handle global keyboard shortcuts.

        While the chat input has focus, Up/Down/Tab/Enter/Escape drive the
        command dropdown when it is visible, and Up/Down recall input history
        when it is not. The Ctrl/F-key shortcuts below apply either way.
        Nothing is handled while a modal screen is on top, so the modal
        receives its own keys.
        """
        if len(self.screen_stack) > 1:
            return

        input_bar = self.query_one(InputBar)
        chat_input = input_bar.get_input()
        dropdown = self.query_one(CommandDropdown)

        if chat_input.has_focus:
            if dropdown.has_class("visible"):
                if self._handle_dropdown_key(event, chat_input, dropdown):
                    return
            elif event.key == "up":
                event.prevent_default()
                self._navigate_history_backward(chat_input)
                return
            elif event.key == "down":
                event.prevent_default()
                self._navigate_history_forward(chat_input)
                return

        # Shortcuts that work regardless of focus.
        key = event.key
        if key == "ctrl+q":
            event.prevent_default()
            self.exit()
        elif key in ("ctrl+h", "f1"):  # F1 as alternative for help
            event.prevent_default()
            self.notify("Opening help...", severity="information")
            self.action_show_help()
        elif key in ("ctrl+m", "f2"):  # F2 as alternative for model
            event.prevent_default()
            self.notify("Opening model selector...", severity="information")
            self.action_show_model_selector()
        elif key == "ctrl+s":
            event.prevent_default()
            self.action_show_stats()
        elif key == "ctrl+l":
            event.prevent_default()
            self.action_clear_chat()
        elif key == "ctrl+p":
            event.prevent_default()
            self.call_later(self._handle_prev_command)
        elif key == "ctrl+n":
            event.prevent_default()
            self.call_later(self._handle_next_command)
        elif key in ("f3", "ctrl+y"):
            # F3 or Ctrl+Y copies the last AI response.
            event.prevent_default()
            self.action_copy_last_response()

    def _handle_dropdown_key(self, event, chat_input: Input, dropdown: CommandDropdown) -> bool:
        """Apply a key press to the visible command dropdown.

        Args:
            event: The key event being dispatched.
            chat_input: The main chat input widget.
            dropdown: The visible command dropdown.

        Returns:
            True if the key was consumed, or False if it should fall through to
            the global shortcuts.
        """
        key = event.key
        if key == "up":
            event.prevent_default()
            dropdown.move_selection_up()
        elif key == "down":
            event.prevent_default()
            dropdown.move_selection_down()
        elif key == "tab":
            # Tab accepts the selected command and adds a space for arguments.
            event.prevent_default()
            selected = dropdown.get_selected_command()
            if selected:
                chat_input.value = selected + " "
                chat_input.cursor_position = len(chat_input.value)
                dropdown.hide()
        elif key == "enter":
            # Enter accepts the selection: commands that need arguments get a
            # trailing space, complete commands are submitted immediately.
            event.prevent_default()
            selected = dropdown.get_selected_command()
            if selected:
                if selected.startswith(self._COMMANDS_NEEDING_ARGS):
                    chat_input.value = selected + " "
                    chat_input.cursor_position = len(chat_input.value)
                    dropdown.hide()
                else:
                    dropdown.hide()
                    chat_input.value = ""  # Clear immediately

                    async def submit_command():
                        await self._process_submitted_input(selected)

                    self.call_later(submit_command)
        elif key == "escape":
            event.prevent_default()
            dropdown.hide()
        else:
            return False
        return True

    def on_input_changed(self, event: Input.Changed) -> None:
        """Show or hide the command dropdown as the chat input changes."""
        if event.input.id != "chat-input":
            return

        # Don't show the dropdown while history recall is rewriting the input.
        if self._navigating_history:
            return

        dropdown = self.query_one(CommandDropdown)
        value = event.value

        # A single leading "/" starts a command; "//" does not.
        if value.startswith("/") and not value.startswith("//"):
            dropdown.show_commands(value)
        else:
            dropdown.hide()

    async def on_input_submitted(self, event: Input.Submitted) -> None:
        """Handle Enter in the chat input.

        Submissions from other inputs (e.g. inside dialogs) can bubble up to
        the app; they are ignored so they are never echoed into the chat or
        sent to the model.
        """
        if event.input.id != "chat-input":
            return

        user_input = event.value.strip()
        if not user_input:
            return

        # Clear the input field immediately.
        event.input.value = ""

        # Process the input (waits for the AI response for chat messages).
        await self._process_submitted_input(user_input)

    async def _process_submitted_input(self, user_input: str) -> None:
        """Process submitted input (command or message).

        Args:
            user_input: The input text to process.
        """
        if not user_input:
            return

        self.query_one(CommandDropdown).hide()

        self.input_history.append(user_input)
        self.history_index = -1

        # Always echo what the user typed.
        chat_display = self.query_one(ChatDisplay)
        await chat_display.add_message(UserMessageWidget(user_input))

        if user_input.startswith("/"):
            await self.handle_command(user_input)
        else:
            await self.handle_message(user_input)

    # ------------------------------------------------------------------
    # Slash commands
    # ------------------------------------------------------------------

    async def handle_command(self, command_text: str) -> None:
        """Handle a slash command.

        TUI-specific commands (screens, toggles, save/load/export, history
        paging) are handled here. Everything else goes to the command registry.
        The registry context's state is then copied back onto the session and
        reflected in the footer, header and input-bar indicators.

        Args:
            command_text: The full command line including the leading "/".
        """
        # First word without leading slashes; "" for input such as "/".
        words = command_text.lstrip("/").lower().split()
        cmd_word = words[0] if words else ""

        if cmd_word in self._EXIT_WORDS:
            self.exit()
            return

        if cmd_word == "help":
            await self.push_screen(HelpScreen())
            return

        if cmd_word == "commands":
            await self.push_screen(CommandsScreen())
            return

        if cmd_word == "stats":
            await self.push_screen(StatsScreen(self.session))
            return

        if cmd_word == "config":
            # Bare "/config" opens the settings screen; "/config <setting> ..."
            # is handled inline.
            if len(command_text.split(maxsplit=1)) == 1:
                await self.push_screen(ConfigScreen(self.settings, self.session))
            else:
                await self._handle_config_command(command_text)
            return

        if cmd_word == "credits":
            await self.push_screen(CreditsScreen(self.session.client))
            return

        if cmd_word == "info":
            if self.session.selected_model:
                provider_name = self.session.client.provider_name
                await self.push_screen(ModelInfoScreen(self.session.selected_model, provider_name))
            else:
                chat_display = self.query_one(ChatDisplay)
                error_widget = SystemMessageWidget("❌ No model selected. Use /model to select a model.")
                await chat_display.add_message(error_widget)
            return

        if cmd_word == "clear":
            self.query_one(ChatDisplay).clear_messages()
            return

        if cmd_word == "reset":
            def handle_reset_confirmation(confirmed: bool) -> None:
                if confirmed:
                    self.session.history.clear()
                    self.session.current_index = -1
                    self.query_one(ChatDisplay).clear_messages()
                    self._update_footer()

            self.push_screen(
                ConfirmDialog(
                    "Reset conversation and clear all history?",
                    "Reset Conversation"
                ),
                callback=handle_reset_confirmation
            )
            return

        if cmd_word == "memory":
            self.session.memory_enabled, status = self._resolve_toggle(
                command_text, self.session.memory_enabled
            )
            chat_display = self.query_one(ChatDisplay)
            await chat_display.add_message(UserMessageWidget(f"✓ Memory {status}"))
            return

        if cmd_word == "online":
            self.session.online_enabled, status = self._resolve_toggle(
                command_text, self.session.online_enabled
            )
            self.query_one(InputBar).update_online_mode(self.session.online_enabled)
            chat_display = self.query_one(ChatDisplay)
            await chat_display.add_message(UserMessageWidget(f"✓ Online mode {status}"))
            return

        if cmd_word == "save":
            await self._handle_save_command(command_text)
            return

        if cmd_word == "load":
            await self._handle_load_command(command_text)
            return

        if cmd_word == "list":
            await self._handle_list_command()
            return

        if cmd_word == "export":
            await self._handle_export_command(command_text)
            return

        if cmd_word == "prev":
            await self._handle_prev_command()
            return

        if cmd_word == "next":
            await self._handle_next_command()
            return

        # Everything else goes through the shared command registry.
        context = self.session.get_context()
        context.is_tui = True  # Flag for TUI mode
        result = registry.execute(command_text, context)

        if result is None:
            chat_display = self.query_one(ChatDisplay)
            error_widget = SystemMessageWidget(f"❌ Unknown command: {command_text}")
            await chat_display.add_message(error_widget)
            return

        # Copy any state the command changed back onto the session.
        self.session.memory_enabled = context.memory_enabled
        self.session.memory_start_index = context.memory_start_index
        self.session.online_enabled = context.online_enabled
        self.session.middle_out_enabled = context.middle_out_enabled
        self.session.session_max_token = context.session_max_token
        self.session.current_index = context.current_index
        self.session.system_prompt = context.session_system_prompt

        if result.status == CommandStatus.EXIT:
            self.exit()
        elif result.status == CommandStatus.ERROR:
            chat_display = self.query_one(ChatDisplay)
            await chat_display.add_message(SystemMessageWidget(f"❌ {result.message}"))
        elif result.message:
            chat_display = self.query_one(ChatDisplay)
            await chat_display.add_message(SystemMessageWidget(f"ℹ️ {result.message}"))

        # Special result payloads (model selector, retry, paste).
        if result.data:
            await self._handle_command_data(result.data)

        self._update_footer()

        # Refresh the header in case the model changed.
        if self.session.selected_model:
            header = self.query_one(Header)
            provider_name = self.session.client.provider_name if self.session.client else ""
            header.update_model(
                self.session.selected_model.get("name", ""),
                self.session.selected_model,
                provider_name
            )

        # Refresh the MCP and online indicators in the input bar.
        input_bar = self.query_one(InputBar)
        mcp_label = self._mcp_status_label()
        if mcp_label is not None:
            input_bar.update_mcp_status(mcp_label)
        input_bar.update_online_mode(self.session.online_enabled)

    @staticmethod
    def _resolve_toggle(command_text: str, current: bool) -> tuple[bool, str]:
        """Work out the new value and status word for an on/off toggle command.

        "<cmd> on" / "<cmd> off" set the flag and report "enabled"/"disabled".
        A bare "<cmd>" flips the flag and reports the new state the same way.
        Any other argument leaves the flag unchanged and reports "on"/"off".

        Args:
            command_text: The full command line, e.g. "/memory off".
            current: The flag's current value.

        Returns:
            A ``(new_value, status_word)`` pair.
        """
        args = command_text.split(maxsplit=1)
        if len(args) < 2:
            new_value = not current
            return new_value, "enabled" if new_value else "disabled"

        state = args[1].lower()
        if state == "on":
            return True, "enabled"
        if state == "off":
            return False, "disabled"
        return current, "on" if current else "off"

    # ------------------------------------------------------------------
    # Chat messages
    # ------------------------------------------------------------------

    async def handle_message(self, user_input: str) -> None:
        """Send a chat message and stream the reply (the caller has already echoed it).

        Errors from sending or streaming replace the reply bubble with an
        error message.
        """
        chat_display = self.query_one(ChatDisplay)

        model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
        assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
        await chat_display.add_message(assistant_widget)

        # Show a loading indicator until the first chunk arrives.
        assistant_widget.set_content("_Thinking..._")

        try:
            response_iterator = self.session.send_message_async(
                user_input,
                stream=self.settings.stream_enabled,
            )
            full_text, usage = await assistant_widget.stream_response(response_iterator)

            if full_text:
                cost = self._response_cost(usage)
                self.session.add_to_history(
                    prompt=user_input,
                    response=full_text,
                    usage=usage,
                    cost=cost,
                )

            self._update_footer()

        except Exception as e:
            assistant_widget.set_content(f"❌ Error: {str(e)}")

    def _response_cost(self, usage) -> float:
        """Return the dollar cost of one response.

        Prefers a provider-reported ``total_cost_usd``. Otherwise multiplies
        the token counts by the selected model's per-token ``pricing`` values.
        Missing or malformed values count as 0, so a pricing gap never turns an
        already-streamed reply into an error.

        Args:
            usage: Usage object returned by the stream (may be None).
        """
        if not usage:
            return 0.0

        api_cost = getattr(usage, "total_cost_usd", None)
        if api_cost:
            self.notify(f"Cost from API: ${api_cost:.6f}", severity="information")
            return api_cost

        if not self.session.selected_model:
            return 0.0

        pricing = self.session.selected_model.get("pricing") or {}
        try:
            prompt_cost = float(pricing.get("prompt") or 0)
            completion_cost = float(pricing.get("completion") or 0)
        except (TypeError, ValueError):
            return 0.0

        # Prices are per token, so multiply by token counts to get dollars.
        prompt_total = (usage.prompt_tokens or 0) * prompt_cost
        completion_total = (usage.completion_tokens or 0) * completion_cost
        cost = prompt_total + completion_total

        if cost > 0:
            self.notify(f"Cost calculated: ${cost:.6f}", severity="information")
        return cost

    def _update_footer(self) -> None:
        """Refresh the footer's token, cost and message counters from the session."""
        footer = self.query_one(Footer)
        footer.update_stats(
            tokens_in=self.session.stats.total_input_tokens,
            tokens_out=self.session.stats.total_output_tokens,
            cost=self.session.stats.total_cost,
            messages=len(self.session.history),
        )

    # ------------------------------------------------------------------
    # Save / load / list / export
    # ------------------------------------------------------------------

    async def _handle_save_command(self, command_text: str) -> None:
        """Handle /save [name]; prompt for a name when none is given."""
        args = command_text.split(maxsplit=1)
        name = args[1] if len(args) > 1 else None

        if name:
            self._save_conversation(name)
            return

        def handle_save_input(input_name: Optional[str]) -> None:
            if input_name:
                self._save_conversation(input_name)

        self.push_screen(
            InputDialog(
                "Enter a name for this conversation:",
                "Save Conversation",
                placeholder="conversation-name"
            ),
            callback=handle_save_input
        )

    def _save_conversation(self, name: str) -> None:
        """Save the session history under ``name`` and report the outcome."""
        from oai.config.database import Database

        if not self.session.history:
            self.notify("No conversation to save", severity="warning")
            return

        try:
            db = Database()
            db.save_conversation(name, self.session.history)
            self.notify(f"✓ Saved as '{name}'", severity="information")
        except Exception as e:
            self.notify(f"Error saving: {e}", severity="error")

    async def _handle_load_command(self, command_text: str) -> None:
        """Handle /load [name|number]; with no argument show the conversation picker.

        A numeric argument is a 1-based index into the saved-conversation list.
        """
        from oai.config.database import Database

        args = command_text.split(maxsplit=1)
        name = args[1] if len(args) > 1 else None

        conversations = Database().list_conversations()
        if not conversations:
            self.notify("No saved conversations", severity="warning")
            return

        if not name:
            self._open_conversation_selector(conversations)
            return

        if name.isdigit():
            index = int(name) - 1
            if not 0 <= index < len(conversations):
                self.notify(f"Invalid number: {name}", severity="error")
                return
            name = conversations[index]["name"]

        self._load_conversation(name)

    def _load_conversation(self, name: str) -> None:
        """Replace the session history with the saved conversation ``name``."""
        from oai.config.database import Database

        try:
            db = Database()
            history = db.load_conversation(name)

            if history:
                self.session.history = history
                self.session.current_index = len(history) - 1
                self._update_footer()
                self.notify(f"✓ Loaded '{name}' ({len(history)} messages)", severity="information")
            else:
                self.notify(f"Conversation '{name}' not found", severity="error")
        except Exception as e:
            self.notify(f"Error loading: {e}", severity="error")

    async def _handle_list_command(self) -> None:
        """Handle /list: show the conversation picker."""
        from oai.config.database import Database

        conversations = Database().list_conversations()
        if not conversations:
            self.notify("No saved conversations", severity="information")
            return

        self._open_conversation_selector(conversations)

    def _open_conversation_selector(self, conversations: list) -> None:
        """Show the conversation picker and load whichever entry is chosen."""
        def handle_selection(selected: Optional[dict]) -> None:
            if selected:
                self._load_conversation(selected["name"])

        self.push_screen(
            ConversationSelectorScreen(conversations),
            callback=handle_selection
        )

    async def _handle_export_command(self, command_text: str) -> None:
        """Handle /export <md|json|html> [filename].

        Without a filename, writes ``conversation_<YYYYmmdd_HHMMSS>.<fmt>``
        relative to the current working directory.
        """
        import time

        from oai.utils.export import export_as_html, export_as_json, export_as_markdown

        parts = command_text.split(maxsplit=2)
        if len(parts) < 2:
            self.notify("Usage: /export <md|json|html> [filename]", severity="warning")
            return

        fmt = parts[1].lower()
        filename = parts[2] if len(parts) > 2 else None

        exporters = {
            "md": export_as_markdown,
            "json": export_as_json,
            "html": export_as_html,
        }
        if fmt not in exporters:
            self.notify(f"Unknown format: {fmt}. Use md, json, or html", severity="error")
            return

        if not self.session.history:
            self.notify("No conversation to export", severity="warning")
            return

        if not filename:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"conversation_{timestamp}.{fmt}"

        try:
            content = exporters[fmt](self.session.history)
            Path(filename).write_text(content, encoding="utf-8")
            self.notify(f"✓ Exported to {filename}", severity="information")
        except Exception as e:
            self.notify(f"Export failed: {e}", severity="error")

    # ------------------------------------------------------------------
    # Credits, history paging, config
    # ------------------------------------------------------------------

    async def _handle_credits_command(self) -> None:
        """Post the account credit summary into the chat.

        ``handle_command`` routes "/credits" to ``CreditsScreen``; this inline
        variant remains available for callers that want a chat message.
        """
        try:
            credit_info = self.session.client.provider.get_credits()
            if not credit_info:
                self.notify("Failed to fetch credits", severity="error")
                return

            msg = "**💰 Account Credits**\n\n"
            msg += f"**Total Credits:** {credit_info.get('total_credits_formatted', 'N/A')}\n"
            msg += f"**Used Credits:** {credit_info.get('used_credits_formatted', 'N/A')}\n"
            msg += f"**Credits Left:** {credit_info.get('credits_left_formatted', 'N/A')}\n"

            # Only warn when the provider reports a numeric balance.
            credits_left = credit_info.get("credits_left", 0)
            if isinstance(credits_left, (int, float)) and credits_left < 1.0:
                msg += "\n⚠️ **Low credits!** Consider adding more."

            chat_display = self.query_one(ChatDisplay)
            await chat_display.add_message(UserMessageWidget(msg))

        except Exception as e:
            self.notify(f"Error fetching credits: {e}", severity="error")

    async def _handle_prev_command(self) -> None:
        """Show the previous exchange in history (/prev and Ctrl+P)."""
        if not self.session.history:
            self.notify("No history available", severity="warning")
            return

        if self.session.current_index > 0:
            self.session.current_index -= 1
            await self._show_history_entry()
        else:
            self.notify("Already at first message", severity="information")

    async def _handle_next_command(self) -> None:
        """Show the next exchange in history (/next and Ctrl+N)."""
        if not self.session.history:
            self.notify("No history available", severity="warning")
            return

        if self.session.current_index < len(self.session.history) - 1:
            self.session.current_index += 1
            await self._show_history_entry()
        else:
            self.notify("Already at last message", severity="information")

    async def _show_history_entry(self) -> None:
        """Post the exchange at ``session.current_index`` into the chat display."""
        index = self.session.current_index
        entry = self.session.history[index]

        msg = f"**[Message {index + 1}/{len(self.session.history)}]**\n\n"
        msg += f"**You:** {entry.prompt}\n\n**Assistant:** {entry.response}"

        chat_display = self.query_one(ChatDisplay)
        await chat_display.add_message(UserMessageWidget(msg))

    async def _handle_config_command(self, command_text: str) -> None:
        """Handle /config [setting ...] with TUI dialogs.

        With no setting, posts a settings summary. "/config api" opens an
        API-key dialog. Anything else is delegated to the command registry.
        """
        parts = command_text.split(maxsplit=2)

        if len(parts) == 1:
            from oai.constants import DEFAULT_SYSTEM_PROMPT

            settings = self.settings

            msg = "**Configuration**\n\n"
            msg += f"**API Key:** {'***' + settings.api_key[-4:] if settings.api_key else 'Not set'}\n"
            msg += f"**Base URL:** {settings.base_url}\n"
            msg += f"**Default Model:** {settings.default_model or 'Not set'}\n"

            # None means "use the built-in default"; "" means deliberately blank.
            if settings.default_system_prompt is None:
                system_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
            elif settings.default_system_prompt == "":
                system_display = "[blank]"
            elif len(settings.default_system_prompt) > 50:
                system_display = settings.default_system_prompt[:50] + "..."
            else:
                system_display = settings.default_system_prompt
            msg += f"**System Prompt:** {system_display}\n"

            msg += f"**Streaming:** {'on' if settings.stream_enabled else 'off'}\n"
            msg += f"**Cost Warning:** ${settings.cost_warning_threshold:.4f}\n"
            msg += f"**Max Tokens:** {settings.max_tokens}\n"
            msg += f"**Default Online:** {'on' if settings.default_online_mode else 'off'}\n"
            msg += f"**Log Level:** {settings.log_level}"

            chat_display = self.query_one(ChatDisplay)
            await chat_display.add_message(UserMessageWidget(msg))
            return

        setting = parts[1].lower()

        if setting == "api":
            def handle_api_key(api_key: Optional[str]) -> None:
                if api_key:
                    self.settings.set_api_key(api_key)

                    async def show_success():
                        chat_display = self.query_one(ChatDisplay)
                        await chat_display.add_message(UserMessageWidget("✓ API key updated"))

                    self.call_later(show_success)

            self.push_screen(
                InputDialog(
                    "Enter your OpenRouter API key:",
                    "Configure API Key",
                    placeholder="sk-or-...",
                ),
                callback=handle_api_key
            )
            return

        # Other config commands use the generic registry handler.
        context = self.session.get_context()
        context.is_tui = True
        result = registry.execute(command_text, context)

        if result and result.message:
            chat_display = self.query_one(ChatDisplay)
            await chat_display.add_message(UserMessageWidget(result.message))

        # Handle special command data (e.g. show_model_selector).
        if result and result.data:
            await self._handle_command_data(result.data)

    async def _handle_command_data(self, data: dict) -> None:
        """Act on special payloads returned by registry commands.

        Recognised keys: "show_model_selector" (with optional "search" and
        "set_as_default"), "retry_prompt" and "paste_prompt".
        """
        if "show_model_selector" in data:
            search = data.get("search", "")
            set_as_default = data.get("set_as_default", False)
            self._show_model_selector(search, set_as_default)
        elif "retry_prompt" in data:
            await self.handle_message(data["retry_prompt"])
        elif "paste_prompt" in data:
            await self.handle_message(data["paste_prompt"])

    # ------------------------------------------------------------------
    # Model selection and dialogs
    # ------------------------------------------------------------------

    def _show_model_selector(self, search: str = "", set_as_default: bool = False) -> None:
        """Show the model selector, optionally pre-filtered by ``search``.

        Args:
            search: Case-insensitive substring matched against model id and name.
            set_as_default: Also persist the chosen model as the default.
        """
        def handle_model_selection(selected: Optional[dict]) -> None:
            """Apply the chosen model and post a confirmation."""
            if not selected:
                return

            self.session.set_model(selected)
            header = self.query_one(Header)
            provider_name = self.session.client.provider_name if self.session.client else ""
            header.update_model(selected.get("name", ""), selected, provider_name)

            # Remember this model as the last one used with this provider.
            model_id = selected.get("id")
            if model_id and provider_name:
                self.settings.set_provider_model(provider_name, model_id)

            if set_as_default:
                self.settings.set_default_model(selected["id"])

            async def add_confirmation():
                chat_display = self.query_one(ChatDisplay)
                if set_as_default:
                    info_widget = UserMessageWidget(f"✓ Default model set to: {selected['id']}")
                else:
                    info_widget = UserMessageWidget(f"✓ Model changed to: {selected['id']}")
                await chat_display.add_message(info_widget)

            self.call_later(add_confirmation)

        try:
            models = self.session.client.provider.get_raw_models()

            if not models:
                self.push_screen(
                    AlertDialog("No models available", "Error", variant="error")
                )
                return

            if search:
                search_lower = search.lower()
                models = [
                    m for m in models
                    if search_lower in m.get("id", "").lower()
                    or search_lower in m.get("name", "").lower()
                ]

                if not models:
                    self.push_screen(
                        AlertDialog(f"No models found matching '{search}'", "Error", variant="error")
                    )
                    return

            current_model = self.session.selected_model.get("id") if self.session.selected_model else None

            self.push_screen(
                ModelSelectorScreen(models, current_model),
                callback=handle_model_selection
            )

        except Exception as e:
            self.push_screen(
                AlertDialog(f"Error loading models: {str(e)}", "Error", variant="error")
            )

    async def show_confirm(self, message: str, title: str = "Confirm") -> bool:
        """Show a confirmation dialog and return the result.

        Uses ``push_screen_wait``; per the Textual docs that may only be
        awaited from within a worker.
        """
        result = await self.push_screen_wait(ConfirmDialog(message, title))
        return result

    async def show_input(
        self,
        message: str,
        title: str = "Input",
        default: str = "",
        placeholder: str = ""
    ) -> Optional[str]:
        """Show an input dialog and return the entered text.

        Uses ``push_screen_wait``; per the Textual docs that may only be
        awaited from within a worker.
        """
        result = await self.push_screen_wait(
            InputDialog(message, title, default, placeholder)
        )
        return result

    async def show_alert(self, message: str, title: str = "Alert", variant: str = "default") -> None:
        """Show an alert dialog and wait for it to be dismissed (worker-only, see show_confirm)."""
        await self.push_screen_wait(AlertDialog(message, title, variant))

    # ------------------------------------------------------------------
    # Actions
    # ------------------------------------------------------------------

    def action_quit(self) -> None:
        """Quit the application."""
        self.exit()

    def action_show_model_selector(self) -> None:
        """Show the unfiltered model selector."""
        self._show_model_selector()

    def action_show_help(self) -> None:
        """Show the help screen, reporting any failure as a notification."""
        try:
            self.push_screen(HelpScreen())
        except Exception as e:
            self.notify(f"Error showing help: {e}", severity="error")

    def action_show_stats(self) -> None:
        """Show the statistics screen, reporting any failure as a notification."""
        try:
            self.push_screen(StatsScreen(self.session))
        except Exception as e:
            self.notify(f"Error showing stats: {e}", severity="error")

    def action_clear_chat(self) -> None:
        """Clear the chat display after confirmation (session history is kept)."""
        def handle_confirmation(confirmed: bool) -> None:
            if confirmed:
                self.query_one(ChatDisplay).clear_messages()

        self.push_screen(
            ConfirmDialog(
                "Clear chat display? (History will be preserved)",
                "Clear Chat"
            ),
            callback=handle_confirmation
        )

    # ------------------------------------------------------------------
    # Input history recall
    # ------------------------------------------------------------------

    def _navigate_history_backward(self, input_widget: Input) -> None:
        """Recall the previous submitted line into ``input_widget`` (Up arrow)."""
        if not self.input_history:
            return

        # Suppress the command dropdown while we rewrite the input.
        self._navigating_history = True

        if self.history_index == -1:
            # Not recalling yet: start from the newest entry.
            self.history_index = len(self.input_history) - 1
        elif self.history_index > 0:
            # Step back unless already at the oldest entry.
            self.history_index -= 1

        if 0 <= self.history_index < len(self.input_history):
            input_widget.value = self.input_history[self.history_index]
            input_widget.cursor_position = len(input_widget.value)

        # Re-enable the dropdown after the Changed event has been handled.
        self.set_timer(0.1, lambda: setattr(self, "_navigating_history", False))

    def _navigate_history_forward(self, input_widget: Input) -> None:
        """Recall the next submitted line (Down arrow); past the newest, clear the input."""
        if not self.input_history or self.history_index == -1:
            return

        # Suppress the command dropdown while we rewrite the input.
        self._navigating_history = True

        if self.history_index < len(self.input_history) - 1:
            self.history_index += 1
            input_widget.value = self.input_history[self.history_index]
            input_widget.cursor_position = len(input_widget.value)
        else:
            # Past the newest entry: leave recall mode with an empty input.
            self.history_index = -1
            input_widget.value = ""

        # Re-enable the dropdown after the Changed event has been handled.
        self.set_timer(0.1, lambda: setattr(self, "_navigating_history", False))

    # ------------------------------------------------------------------
    # Clipboard
    # ------------------------------------------------------------------

    def action_copy_last_response(self) -> None:
        """Copy the most recent AI response in the chat display to the clipboard."""
        try:
            chat_display = self.query_one(ChatDisplay)

            assistant_widgets = [
                child for child in chat_display.children
                if isinstance(child, AssistantMessageWidget)
            ]
            if not assistant_widgets:
                self.notify("No AI responses to copy", severity="warning")
                return

            text = assistant_widgets[-1].full_text
            if not text:
                self.notify("Last response is empty", severity="warning")
                return

            pyperclip.copy(text)

            preview = text[:50] + "..." if len(text) > 50 else text
            self.notify(f"✓ Copied: {preview}", severity="information")

        except Exception as e:
            self.notify(f"Copy failed: {e}", severity="error")