"""
Chat session management for oAI.
This module provides the ChatSession class that manages an interactive
chat session including history, state, and message handling.
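
Example (illustrative; assumes a configured AIClient and Settings instance):

    session = ChatSession(client, settings)
    session.set_model(model_dict)  # raw model dict from the provider's model list
    text, usage, elapsed = session.send_message("Hello, world")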
"""
import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple
from oai.commands.registry import CommandContext, CommandResult, registry
from oai.config.database import Database
from oai.config.settings import Settings
from oai.constants import (
COST_WARNING_THRESHOLD,
LOW_CREDIT_AMOUNT,
LOW_CREDIT_RATIO,
)
from oai.core.client import AIClient
from oai.mcp.manager import MCPManager
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
from oai.utils.logging import get_logger
from oai.utils.web_search import perform_web_search, format_search_results

logger = get_logger()


@dataclass
class SessionStats:
"""
Statistics for the current session.
Tracks tokens, costs, and message counts.
"""
total_input_tokens: int = 0
total_output_tokens: int = 0
total_cost: float = 0.0
message_count: int = 0
@property
def total_tokens(self) -> int:
"""Get total token count."""
return self.total_input_tokens + self.total_output_tokens
def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None:
"""
Add usage stats from a response.
Args:
usage: Usage statistics
            cost: Fallback cost used when usage is missing or carries no cost
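
        Example (illustrative):

            stats = SessionStats()
            stats.add_usage(None, cost=0.002)  # no usage object: falls back to cost
            assert stats.message_count == 1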
"""
if usage:
self.total_input_tokens += usage.prompt_tokens
self.total_output_tokens += usage.completion_tokens
if usage.total_cost_usd:
self.total_cost += usage.total_cost_usd
else:
self.total_cost += cost
else:
self.total_cost += cost
self.message_count += 1


@dataclass
class HistoryEntry:
"""
A single entry in the conversation history.
Stores the user prompt, assistant response, and metrics.
"""
prompt: str
response: str
prompt_tokens: int = 0
completion_tokens: int = 0
msg_cost: float = 0.0
timestamp: Optional[float] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"prompt": self.prompt,
"response": self.response,
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"msg_cost": self.msg_cost,
}


class ChatSession:
"""
Manages an interactive chat session.
Handles conversation history, state management, command processing,
and communication with the AI client.
Attributes:
client: AI client for API requests
settings: Application settings
mcp_manager: MCP manager for file/database access
history: Conversation history
stats: Session statistics
"""
def __init__(
self,
client: AIClient,
settings: Settings,
mcp_manager: Optional[MCPManager] = None,
):
"""
Initialize a chat session.
Args:
client: AI client instance
settings: Application settings
mcp_manager: Optional MCP manager
"""
self.client = client
self.settings = settings
self.mcp_manager = mcp_manager
self.db = Database()
self.history: List[HistoryEntry] = []
self.stats = SessionStats()
# Session state
self.system_prompt: str = settings.effective_system_prompt
self.memory_enabled: bool = True
self.memory_start_index: int = 0
self.online_enabled: bool = settings.default_online_mode
self.middle_out_enabled: bool = False
self.session_max_token: int = 0
self.current_index: int = 0
# Selected model
self.selected_model: Optional[Dict[str, Any]] = None
self.logger = get_logger()
def get_context(self) -> CommandContext:
"""
Get the current command context.
Returns:
CommandContext with current session state
"""
return CommandContext(
settings=self.settings,
provider=self.client.provider,
mcp_manager=self.mcp_manager,
selected_model_raw=self.selected_model,
session_history=[e.to_dict() for e in self.history],
session_system_prompt=self.system_prompt,
memory_enabled=self.memory_enabled,
memory_start_index=self.memory_start_index,
online_enabled=self.online_enabled,
middle_out_enabled=self.middle_out_enabled,
session_max_token=self.session_max_token,
total_input_tokens=self.stats.total_input_tokens,
total_output_tokens=self.stats.total_output_tokens,
total_cost=self.stats.total_cost,
message_count=self.stats.message_count,
current_index=self.current_index,
session=self,
)
def set_model(self, model: Dict[str, Any]) -> None:
"""
Set the selected model.
Args:
model: Raw model dictionary
"""
self.selected_model = model
self.client.set_default_model(model["id"])
self.logger.info(f"Model selected: {model['id']}")
def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]:
"""
Build the messages array for an API request.
Includes system prompt, history (if memory enabled), and current input.
Args:
user_input: Current user input
Returns:
List of message dictionaries
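
        Example (illustrative shape of the result):

            [
                {"role": "system", "content": "<system prompt>"},
                {"role": "user", "content": "<earlier prompt>"},
                {"role": "assistant", "content": "<earlier response>"},
                {"role": "user", "content": "<current user_input>"},
            ]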
"""
messages = []
# Add system prompt
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
# Add database context if in database mode
if self.mcp_manager and self.mcp_manager.enabled:
if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None:
db = self.mcp_manager.databases[self.mcp_manager.selected_db_index]
db_context = (
f"You are connected to SQLite database: {db['name']}\n"
f"Available tables: {', '.join(db['tables'])}\n\n"
"Use inspect_database, search_database, or query_database tools. "
"All queries are read-only."
)
messages.append({"role": "system", "content": db_context})
# Add history if memory enabled
if self.memory_enabled:
for i in range(self.memory_start_index, len(self.history)):
entry = self.history[i]
messages.append({"role": "user", "content": entry.prompt})
messages.append({"role": "assistant", "content": entry.response})
# Add current message
messages.append({"role": "user", "content": user_input})
return messages
def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]:
"""
Get MCP tool definitions if available.
Returns:
List of tool schemas or None
"""
if not self.mcp_manager or not self.mcp_manager.enabled:
return None
if not self.selected_model:
return None
# Check if model supports tools
supported_params = self.selected_model.get("supported_parameters", [])
if "tools" not in supported_params and "functions" not in supported_params:
return None
return self.mcp_manager.get_tools_schema()
async def execute_tool(
self,
tool_name: str,
tool_args: Dict[str, Any],
) -> Dict[str, Any]:
"""
Execute an MCP tool.
Args:
tool_name: Name of the tool
tool_args: Tool arguments
Returns:
Tool execution result
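
        Example (illustrative; the tool name and "path" argument are
        assumptions about the configured MCP tool set):

            result = await session.execute_tool("read_file", {"path": "notes.txt"})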
"""
if not self.mcp_manager:
return {"error": "MCP not available"}
return await self.mcp_manager.call_tool(tool_name, **tool_args)
def send_message(
self,
user_input: str,
stream: bool = True,
on_stream_chunk: Optional[Callable[[str], None]] = None,
) -> Tuple[str, Optional[UsageStats], float]:
"""
Send a message and get a response.
Args:
user_input: User's input text
stream: Whether to stream the response
on_stream_chunk: Callback for stream chunks
Returns:
Tuple of (response_text, usage_stats, response_time)
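
        Example (illustrative):

            text, usage, elapsed = session.send_message(
                "Summarize our discussion",
                on_stream_chunk=lambda chunk: print(chunk, end=""),
            )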
"""
if not self.selected_model:
raise ValueError("No model selected")
start_time = time.time()
messages = self.build_api_messages(user_input)
# Get MCP tools
tools = self.get_mcp_tools()
if tools:
# Disable streaming when tools are present
stream = False
# Build request parameters
model_id = self.selected_model["id"]
# Handle online mode
enable_web_search = False
web_search_config = {}
if self.online_enabled:
# OpenRouter handles online mode natively with :online suffix
if self.client.provider_name == "openrouter":
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
# Anthropic has native web search when search provider is set to anthropic_native
elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
enable_web_search = True
web_search_config = {
"max_uses": self.settings.search_num_results or 5
}
logger.info("Using Anthropic native web search")
else:
# For other providers, perform web search and inject results
logger.info(f"Performing web search for: {user_input}")
search_results = perform_web_search(
user_input,
num_results=self.settings.search_num_results,
provider=self.settings.search_provider,
google_api_key=self.settings.google_api_key,
google_search_engine_id=self.settings.google_search_engine_id
)
if search_results:
# Inject search results into messages
formatted_results = format_search_results(search_results)
search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
# Add search results to the last user message
if messages and messages[-1]["role"] == "user":
messages[-1]["content"] += search_context
logger.info(f"Injected {len(search_results)} search results into context")
transforms = ["middle-out"] if self.middle_out_enabled else None
max_tokens = None
if self.session_max_token > 0:
max_tokens = self.session_max_token
if tools:
# Use tool handling flow
response = self._send_with_tools(
messages=messages,
model_id=model_id,
tools=tools,
max_tokens=max_tokens,
transforms=transforms,
)
response_time = time.time() - start_time
return response.content or "", response.usage, response_time
elif stream:
# Use streaming flow
full_text, usage = self._stream_response(
messages=messages,
model_id=model_id,
max_tokens=max_tokens,
transforms=transforms,
on_chunk=on_stream_chunk,
enable_web_search=enable_web_search,
web_search_config=web_search_config,
)
response_time = time.time() - start_time
return full_text, usage, response_time
else:
# Non-streaming request
response = self.client.chat(
messages=messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
transforms=transforms,
)
response_time = time.time() - start_time
if isinstance(response, ChatResponse):
return response.content or "", response.usage, response_time
else:
return "", None, response_time
def _send_with_tools(
self,
messages: List[Dict[str, Any]],
model_id: str,
tools: List[Dict[str, Any]],
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
) -> ChatResponse:
"""
Send a request with tool call handling.
Args:
messages: API messages
model_id: Model ID
tools: Tool definitions
max_tokens: Max tokens
transforms: Transforms list
Returns:
Final ChatResponse
"""
max_loops = 5
loop_count = 0
api_messages = list(messages)
while loop_count < max_loops:
response = self.client.chat(
messages=api_messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
transforms=transforms,
)
if not isinstance(response, ChatResponse):
raise ValueError("Expected ChatResponse")
tool_calls = response.tool_calls
if not tool_calls:
return response
# Tool calls requested by AI
tool_results = []
for tc in tool_calls:
try:
args = json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse tool arguments: {e}")
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
})
continue
                # Log the tool call
                args_display = ", ".join(
                    f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
                    for k, v in args.items()
                )
                self.logger.info(f"Executing tool: {tc.function.name}({args_display})")
                # Execute the tool (sync context, so run the coroutine to completion)
                result = asyncio.run(self.execute_tool(tc.function.name, args))
                if "error" in result:
                    self.logger.error(f"Tool {tc.function.name} failed: {result['error']}")
                else:
                    self.logger.info(f"Tool {tc.function.name} succeeded")
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps(result),
})
# Add assistant message with tool calls
api_messages.append({
"role": "assistant",
"content": response.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
})
api_messages.extend(tool_results)
# Processing tool results
loop_count += 1
self.logger.warning(f"Reached max tool loops ({max_loops})")
return response
def _stream_response(
self,
messages: List[Dict[str, Any]],
model_id: str,
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
on_chunk: Optional[Callable[[str], None]] = None,
enable_web_search: bool = False,
web_search_config: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Optional[UsageStats]]:
"""
Stream a response with live display.
Args:
messages: API messages
model_id: Model ID
max_tokens: Max tokens
transforms: Transforms
on_chunk: Callback for chunks
enable_web_search: Whether to enable Anthropic native web search
web_search_config: Web search configuration
Returns:
Tuple of (full_text, usage)
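
        Example (illustrative internal call):

            full_text, usage = self._stream_response(
                messages, model_id, on_chunk=lambda s: print(s, end="")
            )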
"""
# Build chat parameters
chat_params = {
"messages": messages,
"model": model_id,
"stream": True,
"max_tokens": max_tokens,
"transforms": transforms,
}
# Only pass web search params to Anthropic provider
if self.client.provider_name == "anthropic":
chat_params["enable_web_search"] = enable_web_search
chat_params["web_search_config"] = web_search_config or {}
response = self.client.chat(**chat_params)
if isinstance(response, ChatResponse):
return response.content or "", response.usage
full_text = ""
usage: Optional[UsageStats] = None
try:
for chunk in response:
if chunk.error:
self.logger.error(f"Stream error: {chunk.error}")
break
if chunk.delta_content:
full_text += chunk.delta_content
if on_chunk:
on_chunk(chunk.delta_content)
if chunk.usage:
usage = chunk.usage
except KeyboardInterrupt:
self.logger.info("Streaming interrupted")
return "", None
return full_text, usage
# ========== ASYNC METHODS FOR TUI ==========
async def send_message_async(
self,
user_input: str,
stream: bool = True,
) -> AsyncIterator[StreamChunk]:
"""
Async version of send_message for Textual TUI.
Args:
user_input: User's input text
stream: Whether to stream the response
Yields:
StreamChunk objects for progressive display
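
        Example (illustrative; output_widget is a placeholder for a TUI widget):

            async for chunk in session.send_message_async("Hello"):
                if chunk.delta_content:
                    output_widget.write(chunk.delta_content)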
"""
if not self.selected_model:
raise ValueError("No model selected")
messages = self.build_api_messages(user_input)
tools = self.get_mcp_tools()
if tools:
# Disable streaming when tools are present
stream = False
# Handle online mode
model_id = self.selected_model["id"]
enable_web_search = False
web_search_config = {}
if self.online_enabled:
# OpenRouter handles online mode natively with :online suffix
if self.client.provider_name == "openrouter":
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
# Anthropic has native web search when search provider is set to anthropic_native
elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
enable_web_search = True
web_search_config = {
"max_uses": self.settings.search_num_results or 5
}
logger.info("Using Anthropic native web search")
else:
# For other providers, perform web search and inject results
logger.info(f"Performing web search for: {user_input}")
search_results = await asyncio.to_thread(
perform_web_search,
user_input,
num_results=self.settings.search_num_results,
provider=self.settings.search_provider,
google_api_key=self.settings.google_api_key,
google_search_engine_id=self.settings.google_search_engine_id
)
if search_results:
# Inject search results into messages
formatted_results = format_search_results(search_results)
search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."
# Add search results to the last user message
if messages and messages[-1]["role"] == "user":
messages[-1]["content"] += search_context
logger.info(f"Injected {len(search_results)} search results into context")
transforms = ["middle-out"] if self.middle_out_enabled else None
max_tokens = None
if self.session_max_token > 0:
max_tokens = self.session_max_token
if tools:
# Use async tool handling flow
async for chunk in self._send_with_tools_async(
messages=messages,
model_id=model_id,
tools=tools,
max_tokens=max_tokens,
transforms=transforms,
):
yield chunk
elif stream:
# Use async streaming flow
async for chunk in self._stream_response_async(
messages=messages,
model_id=model_id,
max_tokens=max_tokens,
transforms=transforms,
enable_web_search=enable_web_search,
web_search_config=web_search_config,
):
yield chunk
else:
# Non-streaming request - run in thread to avoid blocking event loop
response = await asyncio.to_thread(
self.client.chat,
messages=messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
transforms=transforms,
)
if isinstance(response, ChatResponse):
# Yield single chunk with complete response
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
async def _send_with_tools_async(
self,
messages: List[Dict[str, Any]],
model_id: str,
tools: List[Dict[str, Any]],
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
) -> AsyncIterator[StreamChunk]:
"""
Async version of _send_with_tools for TUI.
Args:
messages: API messages
model_id: Model ID
tools: Tool definitions
max_tokens: Max tokens
transforms: Transforms list
Yields:
StreamChunk objects including tool call notifications
"""
max_loops = 5
loop_count = 0
api_messages = list(messages)
while loop_count < max_loops:
# Run in thread to avoid blocking event loop
response = await asyncio.to_thread(
self.client.chat,
messages=api_messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
transforms=transforms,
)
if not isinstance(response, ChatResponse):
raise ValueError("Expected ChatResponse")
tool_calls = response.tool_calls
if not tool_calls:
# Final response, yield it
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
return
# Yield notification about tool calls
tool_notification = f"\n🔧 AI requesting {len(tool_calls)} tool call(s)...\n"
yield StreamChunk(id="", delta_content=tool_notification, usage=None, error=None)
tool_results = []
for tc in tool_calls:
try:
args = json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse tool arguments: {e}")
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
})
continue
# Yield tool call display
args_display = ", ".join(
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
for k, v in args.items()
)
tool_display = f"{tc.function.name}({args_display})\n"
yield StreamChunk(id="", delta_content=tool_display, usage=None, error=None)
# Execute tool (await instead of asyncio.run)
result = await self.execute_tool(tc.function.name, args)
if "error" in result:
error_msg = f" ✗ Error: {result['error']}\n"
yield StreamChunk(id="", delta_content=error_msg, usage=None, error=None)
else:
success_msg = self._format_tool_success(tc.function.name, result)
yield StreamChunk(id="", delta_content=success_msg, usage=None, error=None)
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps(result),
})
# Add assistant message with tool calls
api_messages.append({
"role": "assistant",
"content": response.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
})
# Add tool results
api_messages.extend(tool_results)
loop_count += 1
# Max loops reached
yield StreamChunk(
id="",
delta_content="\n⚠️ Maximum tool call loops reached\n",
usage=None,
error="Max loops reached"
)
def _format_tool_success(self, tool_name: str, result: Dict[str, Any]) -> str:
"""Format a success message for a tool call."""
if tool_name == "search_files":
count = result.get("count", 0)
return f" ✓ Found {count} file(s)\n"
elif tool_name == "read_file":
size = result.get("size", 0)
truncated = " (truncated)" if result.get("truncated") else ""
return f" ✓ Read {size} bytes{truncated}\n"
elif tool_name == "list_directory":
count = result.get("count", 0)
return f" ✓ Listed {count} item(s)\n"
elif tool_name == "inspect_database":
if "table" in result:
return f" ✓ Inspected table: {result['table']}\n"
else:
return f" ✓ Inspected database ({result.get('table_count', 0)} tables)\n"
elif tool_name == "search_database":
count = result.get("count", 0)
return f" ✓ Found {count} match(es)\n"
elif tool_name == "query_database":
count = result.get("count", 0)
return f" ✓ Query returned {count} row(s)\n"
else:
return " ✓ Success\n"
async def _stream_response_async(
self,
messages: List[Dict[str, Any]],
model_id: str,
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
enable_web_search: bool = False,
web_search_config: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[StreamChunk]:
"""
Async version of _stream_response for TUI.
Args:
messages: API messages
model_id: Model ID
max_tokens: Max tokens
transforms: Transforms
enable_web_search: Whether to enable Anthropic native web search
web_search_config: Web search configuration
Yields:
StreamChunk objects
"""
# Build chat parameters
chat_params = {
"messages": messages,
"model": model_id,
"stream": True,
"max_tokens": max_tokens,
"transforms": transforms,
}
# Only pass web search params to Anthropic provider
if self.client.provider_name == "anthropic":
chat_params["enable_web_search"] = enable_web_search
chat_params["web_search_config"] = web_search_config or {}
# For streaming, call directly (generator yields control naturally)
# For non-streaming, we'll detect it and run in thread
if chat_params.get("stream", True):
# Streaming - call directly, iteration will yield control
response = self.client.chat(**chat_params)
if isinstance(response, ChatResponse):
# Provider returned non-streaming despite stream=True
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
return
# Stream the response - yield control between chunks
for chunk in response:
await asyncio.sleep(0) # Yield control to event loop
if chunk.error:
yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
break
yield chunk
else:
# Non-streaming - run in thread to avoid blocking
response = await asyncio.to_thread(self.client.chat, **chat_params)
if isinstance(response, ChatResponse):
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
# ========== END ASYNC METHODS ==========
def add_to_history(
self,
prompt: str,
response: str,
usage: Optional[UsageStats] = None,
cost: float = 0.0,
) -> None:
"""
Add an exchange to the history.
Args:
prompt: User prompt
response: Assistant response
usage: Usage statistics
            cost: Fallback cost used when usage is missing or carries no cost
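
        Example (illustrative):

            text, usage, elapsed = session.send_message(prompt)
            session.add_to_history(prompt, text, usage=usage)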
"""
entry = HistoryEntry(
prompt=prompt,
response=response,
prompt_tokens=usage.prompt_tokens if usage else 0,
completion_tokens=usage.completion_tokens if usage else 0,
msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost,
timestamp=time.time(),
)
self.history.append(entry)
self.current_index = len(self.history) - 1
self.stats.add_usage(usage, cost)
def save_conversation(self, name: str) -> bool:
"""
Save the current conversation.
Args:
name: Name for the saved conversation
Returns:
True if saved successfully
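
        Example (illustrative):

            session.save_conversation("api-design")
            # later, possibly in a new session:
            session.load_conversation("api-design")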
"""
if not self.history:
return False
data = [e.to_dict() for e in self.history]
self.db.save_conversation(name, data)
self.logger.info(f"Saved conversation: {name}")
return True
def load_conversation(self, name: str) -> bool:
"""
Load a saved conversation.
Args:
name: Name of the conversation to load
Returns:
True if loaded successfully
"""
data = self.db.load_conversation(name)
if not data:
return False
self.history.clear()
for entry_dict in data:
self.history.append(HistoryEntry(
prompt=entry_dict.get("prompt", ""),
response=entry_dict.get("response", ""),
prompt_tokens=entry_dict.get("prompt_tokens", 0),
completion_tokens=entry_dict.get("completion_tokens", 0),
msg_cost=entry_dict.get("msg_cost", 0.0),
))
self.current_index = len(self.history) - 1
self.memory_start_index = 0
self.stats = SessionStats() # Reset stats for loaded conversation
self.logger.info(f"Loaded conversation: {name}")
return True
def reset(self) -> None:
"""Reset the session state."""
self.history.clear()
self.stats = SessionStats()
self.system_prompt = ""
self.memory_start_index = 0
self.current_index = 0
self.logger.info("Session reset")
def check_warnings(self) -> List[str]:
"""
Check for cost and credit warnings.
Returns:
List of warning messages
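
        Example (illustrative output):

            ["High cost: $0.1200 exceeds threshold $0.1000",
             "Low credits: $0.75 remaining!"]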
"""
warnings = []
# Check last message cost
if self.history:
last_cost = self.history[-1].msg_cost
threshold = self.settings.cost_warning_threshold
if last_cost > threshold:
warnings.append(
f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}"
)
# Check credits
credits = self.client.get_credits()
if credits:
left = credits.get("credits_left", 0)
total = credits.get("total_credits", 0)
if left < LOW_CREDIT_AMOUNT:
warnings.append(f"Low credits: ${left:.2f} remaining!")
            elif total > 0 and left < total * LOW_CREDIT_RATIO:
                warnings.append(
                    f"Credits low: less than {LOW_CREDIT_RATIO:.0%} remaining (${left:.2f})"
                )
return warnings