"""
|
|
Chat session management for oAI.
|
|
|
|
This module provides the ChatSession class that manages an interactive
|
|
chat session including history, state, and message handling.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple
|
|
|
|
from oai.commands.registry import CommandContext, CommandResult, registry
|
|
from oai.config.database import Database
|
|
from oai.config.settings import Settings
|
|
from oai.constants import (
|
|
COST_WARNING_THRESHOLD,
|
|
LOW_CREDIT_AMOUNT,
|
|
LOW_CREDIT_RATIO,
|
|
)
|
|
from oai.core.client import AIClient
|
|
from oai.mcp.manager import MCPManager
|
|
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
|
|
from oai.utils.logging import get_logger
|
|
from oai.utils.web_search import perform_web_search, format_search_results
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
@dataclass
class SessionStats:
    """
    Statistics for the current session.

    Tracks tokens, costs, and message counts.
    """

    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    message_count: int = 0

    @property
    def total_tokens(self) -> int:
        """Get total token count."""
        return self.total_input_tokens + self.total_output_tokens

    def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None:
        """
        Add usage stats from a response.

        Args:
            usage: Usage statistics
            cost: Cost if not in usage
        """
        if usage:
            self.total_input_tokens += usage.prompt_tokens
            self.total_output_tokens += usage.completion_tokens
            if usage.total_cost_usd:
                self.total_cost += usage.total_cost_usd
            else:
                self.total_cost += cost
        else:
            self.total_cost += cost
        self.message_count += 1

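# Usage sketch (illustrative only): SessionStats accumulates per-response usage,
# preferring the provider-reported cost (usage.total_cost_usd) and falling back
# to the locally computed `cost` argument when the provider does not report one.
#
#     stats = SessionStats()
#     stats.add_usage(UsageStats(prompt_tokens=120, completion_tokens=80), cost=0.0021)
#     stats.total_tokens   # -> 200
#     stats.message_count  # -> 1
#
# The UsageStats keyword arguments above are assumptions for illustration; the
# actual field names are defined in oai.providers.base.
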
@dataclass
class HistoryEntry:
    """
    A single entry in the conversation history.

    Stores the user prompt, assistant response, and metrics.
    """

    prompt: str
    response: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
    msg_cost: float = 0.0
    timestamp: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        return {
            "prompt": self.prompt,
            "response": self.response,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "msg_cost": self.msg_cost,
        }

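# Shape sketch: to_dict() produces the plain-dict form used for saved
# conversations and CommandContext.session_history, e.g.
#
#     {"prompt": "...", "response": "...", "prompt_tokens": 12,
#      "completion_tokens": 34, "msg_cost": 0.0007}
#
# Note that the timestamp field is not included in the serialized form.
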
class ChatSession:
    """
    Manages an interactive chat session.

    Handles conversation history, state management, command processing,
    and communication with the AI client.

    Attributes:
        client: AI client for API requests
        settings: Application settings
        mcp_manager: MCP manager for file/database access
        history: Conversation history
        stats: Session statistics
    """

    def __init__(
        self,
        client: AIClient,
        settings: Settings,
        mcp_manager: Optional[MCPManager] = None,
    ):
        """
        Initialize a chat session.

        Args:
            client: AI client instance
            settings: Application settings
            mcp_manager: Optional MCP manager
        """
        self.client = client
        self.settings = settings
        self.mcp_manager = mcp_manager
        self.db = Database()

        self.history: List[HistoryEntry] = []
        self.stats = SessionStats()

        # Session state
        self.system_prompt: str = settings.effective_system_prompt
        self.memory_enabled: bool = True
        self.memory_start_index: int = 0
        self.online_enabled: bool = settings.default_online_mode
        self.middle_out_enabled: bool = False
        self.session_max_token: int = 0
        self.current_index: int = 0

        # Selected model
        self.selected_model: Optional[Dict[str, Any]] = None

        self.logger = get_logger()

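    # Usage sketch (illustrative, not a prescribed API): given an AIClient and
    # Settings instance built elsewhere in oAI, a session is wired up and given
    # a model before any messages are sent:
    #
    #     session = ChatSession(client, settings)             # mcp_manager is optional
    #     session.set_model({"id": "openai/gpt-4o",           # hypothetical model dict
    #                        "supported_parameters": ["tools"]})
    #
    # The model dictionary only needs the keys this class actually reads:
    # "id" and, for tool-support detection, "supported_parameters".
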
    def get_context(self) -> CommandContext:
        """
        Get the current command context.

        Returns:
            CommandContext with current session state
        """
        return CommandContext(
            settings=self.settings,
            provider=self.client.provider,
            mcp_manager=self.mcp_manager,
            selected_model_raw=self.selected_model,
            session_history=[e.to_dict() for e in self.history],
            session_system_prompt=self.system_prompt,
            memory_enabled=self.memory_enabled,
            memory_start_index=self.memory_start_index,
            online_enabled=self.online_enabled,
            middle_out_enabled=self.middle_out_enabled,
            session_max_token=self.session_max_token,
            total_input_tokens=self.stats.total_input_tokens,
            total_output_tokens=self.stats.total_output_tokens,
            total_cost=self.stats.total_cost,
            message_count=self.stats.message_count,
            current_index=self.current_index,
            session=self,
        )

    def set_model(self, model: Dict[str, Any]) -> None:
        """
        Set the selected model.

        Args:
            model: Raw model dictionary
        """
        self.selected_model = model
        self.client.set_default_model(model["id"])
        self.logger.info(f"Model selected: {model['id']}")

    def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]:
        """
        Build the messages array for an API request.

        Includes system prompt, history (if memory enabled), and current input.

        Args:
            user_input: Current user input

        Returns:
            List of message dictionaries
        """
        messages = []

        # Add system prompt
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})

        # Add database context if in database mode
        if self.mcp_manager and self.mcp_manager.enabled:
            if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None:
                db = self.mcp_manager.databases[self.mcp_manager.selected_db_index]
                db_context = (
                    f"You are connected to SQLite database: {db['name']}\n"
                    f"Available tables: {', '.join(db['tables'])}\n\n"
                    "Use inspect_database, search_database, or query_database tools. "
                    "All queries are read-only."
                )
                messages.append({"role": "system", "content": db_context})

        # Add history if memory enabled
        if self.memory_enabled:
            for i in range(self.memory_start_index, len(self.history)):
                entry = self.history[i]
                messages.append({"role": "user", "content": entry.prompt})
                messages.append({"role": "assistant", "content": entry.response})

        # Add current message
        messages.append({"role": "user", "content": user_input})

        return messages

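    # Resulting message shape (sketch): with memory enabled and one prior
    # exchange in history, build_api_messages("next question") returns roughly:
    #
    #     [
    #         {"role": "system", "content": "<system prompt>"},
    #         {"role": "user", "content": "<earlier prompt>"},
    #         {"role": "assistant", "content": "<earlier response>"},
    #         {"role": "user", "content": "next question"},
    #     ]
    #
    # A second system message describing the SQLite database is inserted only
    # when the MCP manager is enabled and in database mode.
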
    def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]:
        """
        Get MCP tool definitions if available.

        Returns:
            List of tool schemas or None
        """
        if not self.mcp_manager or not self.mcp_manager.enabled:
            return None

        if not self.selected_model:
            return None

        # Check if model supports tools
        supported_params = self.selected_model.get("supported_parameters", [])
        if "tools" not in supported_params and "functions" not in supported_params:
            return None

        return self.mcp_manager.get_tools_schema()

    async def execute_tool(
        self,
        tool_name: str,
        tool_args: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Execute an MCP tool.

        Args:
            tool_name: Name of the tool
            tool_args: Tool arguments

        Returns:
            Tool execution result
        """
        if not self.mcp_manager:
            return {"error": "MCP not available"}

        return await self.mcp_manager.call_tool(tool_name, **tool_args)

    def send_message(
        self,
        user_input: str,
        stream: bool = True,
        on_stream_chunk: Optional[Callable[[str], None]] = None,
    ) -> Tuple[str, Optional[UsageStats], float]:
        """
        Send a message and get a response.

        Args:
            user_input: User's input text
            stream: Whether to stream the response
            on_stream_chunk: Callback for stream chunks

        Returns:
            Tuple of (response_text, usage_stats, response_time)
        """
        if not self.selected_model:
            raise ValueError("No model selected")

        start_time = time.time()
        messages = self.build_api_messages(user_input)

        # Get MCP tools
        tools = self.get_mcp_tools()
        if tools:
            # Disable streaming when tools are present
            stream = False

        # Build request parameters
        model_id = self.selected_model["id"]

        # Handle online mode
        enable_web_search = False
        web_search_config = {}

        if self.online_enabled:
            # OpenRouter handles online mode natively with :online suffix
            if self.client.provider_name == "openrouter":
                if hasattr(self.client.provider, "get_effective_model_id"):
                    model_id = self.client.provider.get_effective_model_id(model_id, True)
            # Anthropic has native web search when search provider is set to anthropic_native
            elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
                enable_web_search = True
                web_search_config = {
                    "max_uses": self.settings.search_num_results or 5
                }
                logger.info("Using Anthropic native web search")
            else:
                # For other providers, perform web search and inject results
                logger.info(f"Performing web search for: {user_input}")
                search_results = perform_web_search(
                    user_input,
                    num_results=self.settings.search_num_results,
                    provider=self.settings.search_provider,
                    google_api_key=self.settings.google_api_key,
                    google_search_engine_id=self.settings.google_search_engine_id,
                )

                if search_results:
                    # Inject search results into messages
                    formatted_results = format_search_results(search_results)
                    search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."

                    # Add search results to the last user message
                    if messages and messages[-1]["role"] == "user":
                        messages[-1]["content"] += search_context

                    logger.info(f"Injected {len(search_results)} search results into context")

                if self.online_enabled:
                    if hasattr(self.client.provider, "get_effective_model_id"):
                        model_id = self.client.provider.get_effective_model_id(model_id, True)

        transforms = ["middle-out"] if self.middle_out_enabled else None

        max_tokens = None
        if self.session_max_token > 0:
            max_tokens = self.session_max_token

        if tools:
            # Use tool handling flow
            response = self._send_with_tools(
                messages=messages,
                model_id=model_id,
                tools=tools,
                max_tokens=max_tokens,
                transforms=transforms,
            )
            response_time = time.time() - start_time
            return response.content or "", response.usage, response_time

        elif stream:
            # Use streaming flow
            full_text, usage = self._stream_response(
                messages=messages,
                model_id=model_id,
                max_tokens=max_tokens,
                transforms=transforms,
                on_chunk=on_stream_chunk,
                enable_web_search=enable_web_search,
                web_search_config=web_search_config,
            )
            response_time = time.time() - start_time
            return full_text, usage, response_time

        else:
            # Non-streaming request
            response = self.client.chat(
                messages=messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                transforms=transforms,
            )
            response_time = time.time() - start_time

            if isinstance(response, ChatResponse):
                return response.content or "", response.usage, response_time
            else:
                return "", None, response_time

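    # Usage sketch (hedged): streaming from a plain CLI loop. The callback
    # receives raw text deltas; printing them without a newline reproduces the
    # usual "typing" effect.
    #
    #     def on_chunk(text: str) -> None:
    #         print(text, end="", flush=True)
    #
    #     reply, usage, elapsed = session.send_message(
    #         "Hello!", stream=True, on_stream_chunk=on_chunk
    #     )
    #     session.add_to_history("Hello!", reply, usage)
    #
    # Note that send_message does not record the exchange itself; callers are
    # expected to call add_to_history afterwards, as sketched above.
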
    def _send_with_tools(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
    ) -> ChatResponse:
        """
        Send a request with tool call handling.

        Args:
            messages: API messages
            model_id: Model ID
            tools: Tool definitions
            max_tokens: Max tokens
            transforms: Transforms list

        Returns:
            Final ChatResponse
        """
        max_loops = 5
        loop_count = 0
        api_messages = list(messages)

        while loop_count < max_loops:
            response = self.client.chat(
                messages=api_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
                transforms=transforms,
            )

            if not isinstance(response, ChatResponse):
                raise ValueError("Expected ChatResponse")

            tool_calls = response.tool_calls
            if not tool_calls:
                return response

            # Tool calls requested by AI
            tool_results = []
            for tc in tool_calls:
                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    tool_results.append({
                        "tool_call_id": tc.id,
                        "role": "tool",
                        "name": tc.function.name,
                        "content": json.dumps({"error": f"Invalid arguments: {e}"}),
                    })
                    continue

                # Display tool call
                args_display = ", ".join(
                    f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
                    for k, v in args.items()
                )
                # Executing tool: {tc.function.name}

                # Execute tool
                result = asyncio.run(self.execute_tool(tc.function.name, args))

                if "error" in result:
                    # Tool execution error logged
                    pass
                else:
                    # Tool execution successful
                    pass

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls
            api_messages.append({
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            })
            api_messages.extend(tool_results)

            # Processing tool results
            loop_count += 1

        self.logger.warning(f"Reached max tool loops ({max_loops})")
        return response

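    # Protocol sketch: each loop iteration appends an OpenAI-style assistant
    # message echoing the requested tool calls, followed by one "tool" message
    # per call, before asking the model again:
    #
    #     {"role": "assistant", "content": "<assistant text, possibly None>",
    #      "tool_calls": [{"id": "call_1", "type": "function",
    #                      "function": {"name": "read_file",
    #                                   "arguments": "{\"path\": \"a.txt\"}"}}]}
    #     {"role": "tool", "tool_call_id": "call_1", "name": "read_file",
    #      "content": "{\"size\": 42, \"truncated\": false}"}
    #
    # The tool names and result keys shown are illustrative; the actual schema
    # comes from MCPManager.get_tools_schema().
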
    def _stream_response(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
        on_chunk: Optional[Callable[[str], None]] = None,
        enable_web_search: bool = False,
        web_search_config: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Optional[UsageStats]]:
        """
        Stream a response with live display.

        Args:
            messages: API messages
            model_id: Model ID
            max_tokens: Max tokens
            transforms: Transforms
            on_chunk: Callback for chunks
            enable_web_search: Whether to enable Anthropic native web search
            web_search_config: Web search configuration

        Returns:
            Tuple of (full_text, usage)
        """
        # Build chat parameters
        chat_params = {
            "messages": messages,
            "model": model_id,
            "stream": True,
            "max_tokens": max_tokens,
            "transforms": transforms,
        }

        # Only pass web search params to Anthropic provider
        if self.client.provider_name == "anthropic":
            chat_params["enable_web_search"] = enable_web_search
            chat_params["web_search_config"] = web_search_config or {}

        response = self.client.chat(**chat_params)

        if isinstance(response, ChatResponse):
            return response.content or "", response.usage

        full_text = ""
        usage: Optional[UsageStats] = None

        try:
            for chunk in response:
                if chunk.error:
                    self.logger.error(f"Stream error: {chunk.error}")
                    break

                if chunk.delta_content:
                    full_text += chunk.delta_content
                    if on_chunk:
                        on_chunk(chunk.delta_content)

                if chunk.usage:
                    usage = chunk.usage

        except KeyboardInterrupt:
            self.logger.info("Streaming interrupted")
            return "", None

        return full_text, usage

    # ========== ASYNC METHODS FOR TUI ==========

    async def send_message_async(
        self,
        user_input: str,
        stream: bool = True,
    ) -> AsyncIterator[StreamChunk]:
        """
        Async version of send_message for Textual TUI.

        Args:
            user_input: User's input text
            stream: Whether to stream the response

        Yields:
            StreamChunk objects for progressive display
        """
        if not self.selected_model:
            raise ValueError("No model selected")

        messages = self.build_api_messages(user_input)
        tools = self.get_mcp_tools()

        if tools:
            # Disable streaming when tools are present
            stream = False

        # Handle online mode
        model_id = self.selected_model["id"]
        enable_web_search = False
        web_search_config = {}

        if self.online_enabled:
            # OpenRouter handles online mode natively with :online suffix
            if self.client.provider_name == "openrouter":
                if hasattr(self.client.provider, "get_effective_model_id"):
                    model_id = self.client.provider.get_effective_model_id(model_id, True)
            # Anthropic has native web search when search provider is set to anthropic_native
            elif self.client.provider_name == "anthropic" and self.settings.search_provider == "anthropic_native":
                enable_web_search = True
                web_search_config = {
                    "max_uses": self.settings.search_num_results or 5
                }
                logger.info("Using Anthropic native web search")
            else:
                # For other providers, perform web search and inject results
                logger.info(f"Performing web search for: {user_input}")
                search_results = await asyncio.to_thread(
                    perform_web_search,
                    user_input,
                    num_results=self.settings.search_num_results,
                    provider=self.settings.search_provider,
                    google_api_key=self.settings.google_api_key,
                    google_search_engine_id=self.settings.google_search_engine_id,
                )

                if search_results:
                    # Inject search results into messages
                    formatted_results = format_search_results(search_results)
                    search_context = f"\n\n{formatted_results}\n\nPlease use the above web search results to help answer the user's question."

                    # Add search results to the last user message
                    if messages and messages[-1]["role"] == "user":
                        messages[-1]["content"] += search_context

                    logger.info(f"Injected {len(search_results)} search results into context")

        transforms = ["middle-out"] if self.middle_out_enabled else None
        max_tokens = None
        if self.session_max_token > 0:
            max_tokens = self.session_max_token

        if tools:
            # Use async tool handling flow
            async for chunk in self._send_with_tools_async(
                messages=messages,
                model_id=model_id,
                tools=tools,
                max_tokens=max_tokens,
                transforms=transforms,
            ):
                yield chunk
        elif stream:
            # Use async streaming flow
            async for chunk in self._stream_response_async(
                messages=messages,
                model_id=model_id,
                max_tokens=max_tokens,
                transforms=transforms,
                enable_web_search=enable_web_search,
                web_search_config=web_search_config,
            ):
                yield chunk
        else:
            # Non-streaming request - run in thread to avoid blocking event loop
            response = await asyncio.to_thread(
                self.client.chat,
                messages=messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                transforms=transforms,
            )
            if isinstance(response, ChatResponse):
                # Yield single chunk with complete response
                chunk = StreamChunk(
                    id="",
                    delta_content=response.content,
                    usage=response.usage,
                    error=None,
                )
                yield chunk

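    # Usage sketch (hedged): consuming the async generator from a Textual
    # worker or any other asyncio context.
    #
    #     async def run_turn(session: "ChatSession", text: str) -> str:
    #         reply, usage = "", None
    #         async for chunk in session.send_message_async(text):
    #             if chunk.error:
    #                 break
    #             if chunk.delta_content:
    #                 reply += chunk.delta_content   # append to the UI as it arrives
    #             if chunk.usage:
    #                 usage = chunk.usage
    #         session.add_to_history(text, reply, usage)
    #         return reply
    #
    # As in the sync path, history bookkeeping is left to the caller.
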
    async def _send_with_tools_async(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Async version of _send_with_tools for TUI.

        Args:
            messages: API messages
            model_id: Model ID
            tools: Tool definitions
            max_tokens: Max tokens
            transforms: Transforms list

        Yields:
            StreamChunk objects including tool call notifications
        """
        max_loops = 5
        loop_count = 0
        api_messages = list(messages)

        while loop_count < max_loops:
            # Run in thread to avoid blocking event loop
            response = await asyncio.to_thread(
                self.client.chat,
                messages=api_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
                transforms=transforms,
            )

            if not isinstance(response, ChatResponse):
                raise ValueError("Expected ChatResponse")

            tool_calls = response.tool_calls
            if not tool_calls:
                # Final response, yield it
                chunk = StreamChunk(
                    id="",
                    delta_content=response.content,
                    usage=response.usage,
                    error=None,
                )
                yield chunk
                return

            # Yield notification about tool calls
            tool_notification = f"\n🔧 AI requesting {len(tool_calls)} tool call(s)...\n"
            yield StreamChunk(id="", delta_content=tool_notification, usage=None, error=None)

            tool_results = []
            for tc in tool_calls:
                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    tool_results.append({
                        "tool_call_id": tc.id,
                        "role": "tool",
                        "name": tc.function.name,
                        "content": json.dumps({"error": f"Invalid arguments: {e}"}),
                    })
                    continue

                # Yield tool call display
                args_display = ", ".join(
                    f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
                    for k, v in args.items()
                )
                tool_display = f" → {tc.function.name}({args_display})\n"
                yield StreamChunk(id="", delta_content=tool_display, usage=None, error=None)

                # Execute tool (await instead of asyncio.run)
                result = await self.execute_tool(tc.function.name, args)

                if "error" in result:
                    error_msg = f" ✗ Error: {result['error']}\n"
                    yield StreamChunk(id="", delta_content=error_msg, usage=None, error=None)
                else:
                    success_msg = self._format_tool_success(tc.function.name, result)
                    yield StreamChunk(id="", delta_content=success_msg, usage=None, error=None)

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls
            api_messages.append({
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            })

            # Add tool results
            api_messages.extend(tool_results)
            loop_count += 1

        # Max loops reached
        yield StreamChunk(
            id="",
            delta_content="\n⚠️ Maximum tool call loops reached\n",
            usage=None,
            error="Max loops reached",
        )

    def _format_tool_success(self, tool_name: str, result: Dict[str, Any]) -> str:
        """Format a success message for a tool call."""
        if tool_name == "search_files":
            count = result.get("count", 0)
            return f" ✓ Found {count} file(s)\n"
        elif tool_name == "read_file":
            size = result.get("size", 0)
            truncated = " (truncated)" if result.get("truncated") else ""
            return f" ✓ Read {size} bytes{truncated}\n"
        elif tool_name == "list_directory":
            count = result.get("count", 0)
            return f" ✓ Listed {count} item(s)\n"
        elif tool_name == "inspect_database":
            if "table" in result:
                return f" ✓ Inspected table: {result['table']}\n"
            else:
                return f" ✓ Inspected database ({result.get('table_count', 0)} tables)\n"
        elif tool_name == "search_database":
            count = result.get("count", 0)
            return f" ✓ Found {count} match(es)\n"
        elif tool_name == "query_database":
            count = result.get("count", 0)
            return f" ✓ Query returned {count} row(s)\n"
        else:
            return " ✓ Success\n"

    async def _stream_response_async(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
        enable_web_search: bool = False,
        web_search_config: Optional[Dict[str, Any]] = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Async version of _stream_response for TUI.

        Args:
            messages: API messages
            model_id: Model ID
            max_tokens: Max tokens
            transforms: Transforms
            enable_web_search: Whether to enable Anthropic native web search
            web_search_config: Web search configuration

        Yields:
            StreamChunk objects
        """
        # Build chat parameters
        chat_params = {
            "messages": messages,
            "model": model_id,
            "stream": True,
            "max_tokens": max_tokens,
            "transforms": transforms,
        }

        # Only pass web search params to Anthropic provider
        if self.client.provider_name == "anthropic":
            chat_params["enable_web_search"] = enable_web_search
            chat_params["web_search_config"] = web_search_config or {}

        # For streaming, call directly (generator yields control naturally)
        # For non-streaming, we'll detect it and run in thread
        if chat_params.get("stream", True):
            # Streaming - call directly, iteration will yield control
            response = self.client.chat(**chat_params)

            if isinstance(response, ChatResponse):
                # Provider returned non-streaming despite stream=True
                chunk = StreamChunk(
                    id="",
                    delta_content=response.content,
                    usage=response.usage,
                    error=None,
                )
                yield chunk
                return

            # Stream the response - yield control between chunks
            for chunk in response:
                await asyncio.sleep(0)  # Yield control to event loop
                if chunk.error:
                    yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
                    break
                yield chunk
        else:
            # Non-streaming - run in thread to avoid blocking
            response = await asyncio.to_thread(self.client.chat, **chat_params)
            if isinstance(response, ChatResponse):
                chunk = StreamChunk(
                    id="",
                    delta_content=response.content,
                    usage=response.usage,
                    error=None,
                )
                yield chunk

    # ========== END ASYNC METHODS ==========

    def add_to_history(
        self,
        prompt: str,
        response: str,
        usage: Optional[UsageStats] = None,
        cost: float = 0.0,
    ) -> None:
        """
        Add an exchange to the history.

        Args:
            prompt: User prompt
            response: Assistant response
            usage: Usage statistics
            cost: Cost if not in usage
        """
        entry = HistoryEntry(
            prompt=prompt,
            response=response,
            prompt_tokens=usage.prompt_tokens if usage else 0,
            completion_tokens=usage.completion_tokens if usage else 0,
            msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost,
            timestamp=time.time(),
        )
        self.history.append(entry)
        self.current_index = len(self.history) - 1
        self.stats.add_usage(usage, cost)

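    # Memory note (sketch): history always grows, but only entries from
    # memory_start_index onward are replayed by build_api_messages. A command
    # that "clears context" without discarding the transcript can therefore do:
    #
    #     session.memory_start_index = len(session.history)
    #
    # whereas reset() clears the history list itself.
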
    def save_conversation(self, name: str) -> bool:
        """
        Save the current conversation.

        Args:
            name: Name for the saved conversation

        Returns:
            True if saved successfully
        """
        if not self.history:
            return False

        data = [e.to_dict() for e in self.history]
        self.db.save_conversation(name, data)
        self.logger.info(f"Saved conversation: {name}")
        return True

    def load_conversation(self, name: str) -> bool:
        """
        Load a saved conversation.

        Args:
            name: Name of the conversation to load

        Returns:
            True if loaded successfully
        """
        data = self.db.load_conversation(name)
        if not data:
            return False

        self.history.clear()
        for entry_dict in data:
            self.history.append(HistoryEntry(
                prompt=entry_dict.get("prompt", ""),
                response=entry_dict.get("response", ""),
                prompt_tokens=entry_dict.get("prompt_tokens", 0),
                completion_tokens=entry_dict.get("completion_tokens", 0),
                msg_cost=entry_dict.get("msg_cost", 0.0),
            ))

        self.current_index = len(self.history) - 1
        self.memory_start_index = 0
        self.stats = SessionStats()  # Reset stats for loaded conversation
        self.logger.info(f"Loaded conversation: {name}")
        return True

    def reset(self) -> None:
        """Reset the session state."""
        self.history.clear()
        self.stats = SessionStats()
        self.system_prompt = ""
        self.memory_start_index = 0
        self.current_index = 0
        self.logger.info("Session reset")

    def check_warnings(self) -> List[str]:
        """
        Check for cost and credit warnings.

        Returns:
            List of warning messages
        """
        warnings = []

        # Check last message cost
        if self.history:
            last_cost = self.history[-1].msg_cost
            threshold = self.settings.cost_warning_threshold
            if last_cost > threshold:
                warnings.append(
                    f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}"
                )

        # Check credits
        credits = self.client.get_credits()
        if credits:
            left = credits.get("credits_left", 0)
            total = credits.get("total_credits", 0)

            if left < LOW_CREDIT_AMOUNT:
                warnings.append(f"Low credits: ${left:.2f} remaining!")
            elif total > 0 and left < total * LOW_CREDIT_RATIO:
                warnings.append(f"Credits low: less than 10% remaining (${left:.2f})")

        return warnings
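
# Threshold note (sketch): a credit warning fires when the remaining balance
# drops below the absolute floor LOW_CREDIT_AMOUNT, or below LOW_CREDIT_RATIO of
# the account total. For example, if LOW_CREDIT_RATIO were 0.10 (the value
# implied by the warning text), a total of 50.0 credits would trigger the
# relative warning once credits_left < 5.0. The concrete constant values live in
# oai.constants.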