# oai/oai/config/settings.py (573 lines, 19 KiB, Python)

"""
Settings management for oAI.
This module provides a centralized settings class that handles all application
configuration with type safety, validation, and persistence.
"""
from dataclasses import dataclass, field
from typing import Optional, Dict
from pathlib import Path
import json
from oai.constants import (
DEFAULT_BASE_URL,
DEFAULT_STREAM_ENABLED,
DEFAULT_MAX_TOKENS,
DEFAULT_ONLINE_MODE,
DEFAULT_COST_WARNING_THRESHOLD,
DEFAULT_LOG_MAX_SIZE_MB,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_LEVEL,
DEFAULT_SYSTEM_PROMPT,
VALID_LOG_LEVELS,
DEFAULT_PROVIDER,
OLLAMA_DEFAULT_URL,
)
from oai.config.database import get_database
@dataclass
class Settings:
    """
    Application settings with persistence support.

    This class provides a clean interface for managing all configuration
    options. Build an instance with ``Settings.load()`` to populate it from
    the database; ``save()`` writes every field back, and each ``set_*``
    method updates a single field and persists it immediately. Field values
    are validated in ``__post_init__``.

    Attributes:
        api_key: Legacy OpenRouter API key (deprecated, use openrouter_api_key)
        default_provider: Default AI provider to use
        openrouter_api_key: OpenRouter API key
        anthropic_api_key: Anthropic API key
        openai_api_key: OpenAI API key
        ollama_base_url: Ollama server URL
        provider_models: Mapping of provider name -> last used model ID
        search_provider: Web search backend ("anthropic_native", "duckduckgo" or "google")
        google_api_key: Google API key for Google Custom Search
        google_search_engine_id: Google Custom Search Engine ID
        search_num_results: Number of web search results to use
        base_url: API base URL
        default_model: Default model ID to use
        default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
        stream_enabled: Whether to stream responses
        max_tokens: Maximum tokens per request
        cost_warning_threshold: Alert threshold for message cost
        default_online_mode: Whether online mode is enabled by default
        log_max_size_mb: Maximum log file size in MB
        log_backup_count: Number of log file backups to keep
        log_level: Logging level (debug/info/warning/error/critical), stored lowercase
    """

    # Legacy field (kept for backward compatibility)
    api_key: Optional[str] = None

    # Provider configuration
    default_provider: str = DEFAULT_PROVIDER
    openrouter_api_key: Optional[str] = None
    anthropic_api_key: Optional[str] = None
    openai_api_key: Optional[str] = None
    ollama_base_url: str = OLLAMA_DEFAULT_URL
    provider_models: Dict[str, str] = field(default_factory=dict)  # provider -> last_model_id

    # Web search configuration (for online mode with non-OpenRouter providers)
    search_provider: str = "duckduckgo"  # "anthropic_native", "duckduckgo" or "google"
    google_api_key: Optional[str] = None
    google_search_engine_id: Optional[str] = None
    search_num_results: int = 5

    # General settings
    base_url: str = DEFAULT_BASE_URL
    default_model: Optional[str] = None
    default_system_prompt: Optional[str] = None
    stream_enabled: bool = DEFAULT_STREAM_ENABLED
    max_tokens: int = DEFAULT_MAX_TOKENS
    cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD
    default_online_mode: bool = DEFAULT_ONLINE_MODE
    log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
    log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT
    log_level: str = DEFAULT_LOG_LEVEL

    @property
    def effective_system_prompt(self) -> str:
        """
        Get the effective system prompt to use.

        Returns:
            The custom prompt if set, hardcoded default if None, or blank if explicitly set to ""
        """
        if self.default_system_prompt is None:
            return DEFAULT_SYSTEM_PROMPT
        return self.default_system_prompt

    def __post_init__(self):
        """Normalize and validate settings after initialization."""
        # Store the log level in canonical lowercase form, matching what
        # set_log_level() persists. Values coming from the database or from a
        # caller may use any case, and validation below is case-insensitive.
        self.log_level = self.log_level.lower()
        self._validate()

    def _validate(self) -> None:
        """
        Validate all settings values.

        Raises:
            ValueError: If the log level is unknown or a numeric field is out of range.
        """
        # Validate log level
        if self.log_level.lower() not in VALID_LOG_LEVELS:
            raise ValueError(
                f"Invalid log level: {self.log_level}. "
                f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
            )

        # Validate numeric bounds
        if self.max_tokens < 1:
            raise ValueError("max_tokens must be at least 1")
        if self.cost_warning_threshold < 0:
            raise ValueError("cost_warning_threshold must be non-negative")
        if self.log_max_size_mb < 1:
            raise ValueError("log_max_size_mb must be at least 1")
        if self.log_backup_count < 0:
            raise ValueError("log_backup_count must be non-negative")

    @classmethod
    def load(cls) -> "Settings":
        """
        Load settings from the database.

        Missing values and values that cannot be parsed fall back to their
        defaults. As a side effect, a legacy ``api_key`` is copied to
        ``openrouter_api_key`` in the database when the latter is not set yet.

        Returns:
            Settings instance with values from database

        Raises:
            ValueError: If a stored value fails validation (e.g. an unknown
                log level or a max_tokens below 1).
        """
        db = get_database()

        def parse_bool(value: Optional[str], default: bool) -> bool:
            """Interpret "on"/"true"/"1"/"yes" (any case) as True; None -> default."""
            if value is None:
                return default
            return value.lower() in ("on", "true", "1", "yes")

        def parse_int(value: Optional[str], default: int) -> int:
            """Parse an int, returning default for None or unparseable text."""
            if value is None:
                return default
            try:
                return int(value)
            except ValueError:
                return default

        def parse_float(value: Optional[str], default: float) -> float:
            """Parse a float, returning default for None or unparseable text."""
            if value is None:
                return default
            try:
                return float(value)
            except ValueError:
                return default

        # Get system prompt from DB: None means not set (use default), "" means explicitly blank
        system_prompt_value = db.get_config("default_system_prompt")

        # Migration: copy legacy api_key to openrouter_api_key if not already set.
        # The legacy api_key is intentionally left in the DB for backward compatibility.
        legacy_api_key = db.get_config("api_key")
        openrouter_key = db.get_config("openrouter_api_key")
        if legacy_api_key and not openrouter_key:
            db.set_config("openrouter_api_key", legacy_api_key)
            openrouter_key = legacy_api_key

        # Load provider -> model mapping. Only a JSON object is accepted:
        # valid JSON of another shape (e.g. "null" or a list) would otherwise
        # break the dict operations in get/set_provider_model.
        provider_models: Dict[str, str] = {}
        provider_models_json = db.get_config("provider_models")
        if provider_models_json:
            try:
                decoded = json.loads(provider_models_json)
            except json.JSONDecodeError:
                decoded = None
            if isinstance(decoded, dict):
                provider_models = decoded

        return cls(
            # Legacy field
            api_key=legacy_api_key,
            # Provider configuration
            default_provider=db.get_config("default_provider") or DEFAULT_PROVIDER,
            openrouter_api_key=openrouter_key,
            anthropic_api_key=db.get_config("anthropic_api_key"),
            openai_api_key=db.get_config("openai_api_key"),
            ollama_base_url=db.get_config("ollama_base_url") or OLLAMA_DEFAULT_URL,
            provider_models=provider_models,
            # Web search configuration
            search_provider=db.get_config("search_provider") or "duckduckgo",
            google_api_key=db.get_config("google_api_key"),
            google_search_engine_id=db.get_config("google_search_engine_id"),
            search_num_results=parse_int(db.get_config("search_num_results"), 5),
            # General settings
            base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
            default_model=db.get_config("default_model"),
            default_system_prompt=system_prompt_value,
            stream_enabled=parse_bool(
                db.get_config("stream_enabled"),
                DEFAULT_STREAM_ENABLED
            ),
            # NOTE: max_tokens is persisted under the key "max_token" (singular).
            # load(), save() and set_max_tokens() must agree; do not rename the
            # key without migrating existing databases.
            max_tokens=parse_int(
                db.get_config("max_token"),
                DEFAULT_MAX_TOKENS
            ),
            cost_warning_threshold=parse_float(
                db.get_config("cost_warning_threshold"),
                DEFAULT_COST_WARNING_THRESHOLD
            ),
            default_online_mode=parse_bool(
                db.get_config("default_online_mode"),
                DEFAULT_ONLINE_MODE
            ),
            log_max_size_mb=parse_int(
                db.get_config("log_max_size_mb"),
                DEFAULT_LOG_MAX_SIZE_MB
            ),
            log_backup_count=parse_int(
                db.get_config("log_backup_count"),
                DEFAULT_LOG_BACKUP_COUNT
            ),
            log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL,
        )

    def save(self) -> None:
        """
        Persist all settings to the database.

        Optional values (API keys, default model, Google search credentials)
        are only written when truthy, and the custom system prompt only when
        it is not None. As a result, unset values never overwrite what is
        already stored.
        """
        db = get_database()

        # Legacy + provider configuration
        if self.api_key:
            db.set_config("api_key", self.api_key)
        db.set_config("default_provider", self.default_provider)
        if self.openrouter_api_key:
            db.set_config("openrouter_api_key", self.openrouter_api_key)
        if self.anthropic_api_key:
            db.set_config("anthropic_api_key", self.anthropic_api_key)
        if self.openai_api_key:
            db.set_config("openai_api_key", self.openai_api_key)
        db.set_config("ollama_base_url", self.ollama_base_url)
        # Same JSON encoding that set_provider_model() uses and load() reads.
        db.set_config("provider_models", json.dumps(self.provider_models))

        # Web search configuration
        db.set_config("search_provider", self.search_provider)
        if self.google_api_key:
            db.set_config("google_api_key", self.google_api_key)
        if self.google_search_engine_id:
            db.set_config("google_search_engine_id", self.google_search_engine_id)
        db.set_config("search_num_results", str(self.search_num_results))

        # General settings
        db.set_config("base_url", self.base_url)
        if self.default_model:
            db.set_config("default_model", self.default_model)
        # Save system prompt: None means not set (don't save), otherwise save the value (even if "")
        if self.default_system_prompt is not None:
            db.set_config("default_system_prompt", self.default_system_prompt)
        db.set_config("stream_enabled", "on" if self.stream_enabled else "off")
        db.set_config("max_token", str(self.max_tokens))  # key is "max_token", see load()
        db.set_config("cost_warning_threshold", str(self.cost_warning_threshold))
        db.set_config("default_online_mode", "on" if self.default_online_mode else "off")
        db.set_config("log_max_size_mb", str(self.log_max_size_mb))
        db.set_config("log_backup_count", str(self.log_backup_count))
        db.set_config("log_level", self.log_level)

    def set_api_key(self, api_key: str) -> None:
        """
        Set and persist the legacy API key.

        Args:
            api_key: The new API key (surrounding whitespace is stripped)
        """
        self.api_key = api_key.strip()
        get_database().set_config("api_key", self.api_key)

    def set_base_url(self, url: str) -> None:
        """
        Set and persist the base URL.

        Args:
            url: The new base URL (surrounding whitespace is stripped)
        """
        self.base_url = url.strip()
        get_database().set_config("base_url", self.base_url)

    def set_default_model(self, model_id: str) -> None:
        """
        Set and persist the default model.

        Args:
            model_id: The model ID to set as default
        """
        self.default_model = model_id
        get_database().set_config("default_model", model_id)

    def set_default_system_prompt(self, prompt: str) -> None:
        """
        Set and persist the default system prompt.

        Args:
            prompt: The system prompt to use for all new sessions.
                Empty string "" means blank prompt (no system message).
        """
        self.default_system_prompt = prompt
        get_database().set_config("default_system_prompt", prompt)

    def clear_default_system_prompt(self) -> None:
        """
        Clear the custom system prompt and revert to hardcoded default.

        Deletes the ``default_system_prompt`` row from the ``config`` table,
        so subsequent loads see "not set" and use DEFAULT_SYSTEM_PROMPT.
        """
        self.default_system_prompt = None
        # Remove from database to indicate "not set".
        # NOTE(review): this reaches into the database's private _connection()
        # with raw SQL; a public delete helper on the database would be cleaner.
        db = get_database()
        with db._connection() as conn:
            conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",))
            conn.commit()

    def set_stream_enabled(self, enabled: bool) -> None:
        """
        Set and persist the streaming preference.

        Args:
            enabled: Whether to enable streaming
        """
        self.stream_enabled = enabled
        get_database().set_config("stream_enabled", "on" if enabled else "off")

    def set_max_tokens(self, max_tokens: int) -> None:
        """
        Set and persist the maximum tokens.

        Args:
            max_tokens: Maximum number of tokens

        Raises:
            ValueError: If max_tokens is less than 1
        """
        if max_tokens < 1:
            raise ValueError("max_tokens must be at least 1")
        self.max_tokens = max_tokens
        get_database().set_config("max_token", str(max_tokens))  # key is "max_token", see load()

    def set_cost_warning_threshold(self, threshold: float) -> None:
        """
        Set and persist the cost warning threshold.

        Args:
            threshold: Cost threshold in USD

        Raises:
            ValueError: If threshold is negative
        """
        if threshold < 0:
            raise ValueError("cost_warning_threshold must be non-negative")
        self.cost_warning_threshold = threshold
        get_database().set_config("cost_warning_threshold", str(threshold))

    def set_default_online_mode(self, enabled: bool) -> None:
        """
        Set and persist the default online mode.

        Args:
            enabled: Whether online mode should be enabled by default
        """
        self.default_online_mode = enabled
        get_database().set_config("default_online_mode", "on" if enabled else "off")

    def set_log_level(self, level: str) -> None:
        """
        Set and persist the log level (stored lowercase).

        Args:
            level: The log level (debug/info/warning/error/critical), any case

        Raises:
            ValueError: If level is not valid
        """
        level_lower = level.lower()
        if level_lower not in VALID_LOG_LEVELS:
            raise ValueError(
                f"Invalid log level: {level}. "
                f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
            )
        self.log_level = level_lower
        get_database().set_config("log_level", level_lower)

    def set_log_max_size(self, size_mb: int) -> None:
        """
        Set and persist the maximum log file size.

        Values above 100 MB are silently capped at 100.

        Args:
            size_mb: Maximum size in megabytes

        Raises:
            ValueError: If size_mb is less than 1
        """
        if size_mb < 1:
            raise ValueError("log_max_size_mb must be at least 1")
        # Cap at 100 MB for safety
        self.log_max_size_mb = min(size_mb, 100)
        get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))

    def set_provider_api_key(self, provider: str, api_key: str) -> None:
        """
        Set and persist an API key for a specific provider.

        Args:
            provider: Provider name ("openrouter", "anthropic", "openai"), any case
            api_key: The API key to set (surrounding whitespace is stripped)

        Raises:
            ValueError: If provider is invalid
        """
        provider = provider.lower()
        api_key = api_key.strip()

        if provider == "openrouter":
            self.openrouter_api_key = api_key
            get_database().set_config("openrouter_api_key", api_key)
        elif provider == "anthropic":
            self.anthropic_api_key = api_key
            get_database().set_config("anthropic_api_key", api_key)
        elif provider == "openai":
            self.openai_api_key = api_key
            get_database().set_config("openai_api_key", api_key)
        else:
            raise ValueError(f"Invalid provider: {provider}")

    def get_provider_api_key(self, provider: str) -> Optional[str]:
        """
        Get the API key for a specific provider.

        Args:
            provider: Provider name ("openrouter", "anthropic", "openai", "ollama"), any case

        Returns:
            API key or None if not set; always "" for "ollama"

        Raises:
            ValueError: If provider is invalid
        """
        provider = provider.lower()

        if provider == "openrouter":
            return self.openrouter_api_key
        elif provider == "anthropic":
            return self.anthropic_api_key
        elif provider == "openai":
            return self.openai_api_key
        elif provider == "ollama":
            return ""  # Ollama doesn't require an API key
        else:
            raise ValueError(f"Invalid provider: {provider}")

    def set_default_provider(self, provider: str) -> None:
        """
        Set and persist the default provider.

        Args:
            provider: Provider name, any case (must be in VALID_PROVIDERS once lowercased)

        Raises:
            ValueError: If provider is invalid
        """
        from oai.constants import VALID_PROVIDERS

        provider = provider.lower()
        if provider not in VALID_PROVIDERS:
            raise ValueError(
                f"Invalid provider: {provider}. "
                f"Valid providers: {', '.join(VALID_PROVIDERS)}"
            )
        self.default_provider = provider
        get_database().set_config("default_provider", provider)

    def set_ollama_base_url(self, url: str) -> None:
        """
        Set and persist the Ollama base URL.

        Args:
            url: Ollama server URL (surrounding whitespace is stripped)
        """
        self.ollama_base_url = url.strip()
        get_database().set_config("ollama_base_url", self.ollama_base_url)

    def set_search_provider(self, provider: str) -> None:
        """
        Set and persist the web search provider.

        Args:
            provider: Search provider ("anthropic_native", "duckduckgo", "google"), any case

        Raises:
            ValueError: If provider is invalid
        """
        valid_providers = ["anthropic_native", "duckduckgo", "google"]
        provider = provider.lower()
        if provider not in valid_providers:
            raise ValueError(
                f"Invalid search provider: {provider}. "
                f"Valid providers: {', '.join(valid_providers)}"
            )
        self.search_provider = provider
        get_database().set_config("search_provider", provider)

    def set_google_api_key(self, api_key: str) -> None:
        """
        Set and persist the Google API key for Google Custom Search.

        Args:
            api_key: The Google API key (surrounding whitespace is stripped)
        """
        self.google_api_key = api_key.strip()
        get_database().set_config("google_api_key", self.google_api_key)

    def set_google_search_engine_id(self, engine_id: str) -> None:
        """
        Set and persist the Google Custom Search Engine ID.

        Args:
            engine_id: The Google Search Engine ID (surrounding whitespace is stripped)
        """
        self.google_search_engine_id = engine_id.strip()
        get_database().set_config("google_search_engine_id", self.google_search_engine_id)

    def get_provider_model(self, provider: str) -> Optional[str]:
        """
        Get the last used model for a provider.

        Args:
            provider: Provider name (looked up as given, case-sensitive)

        Returns:
            Model ID or None if not set
        """
        return self.provider_models.get(provider)

    def set_provider_model(self, provider: str, model_id: str) -> None:
        """
        Set and persist the last used model for a provider.

        Args:
            provider: Provider name (stored as given, case-sensitive)
            model_id: Model ID to remember
        """
        self.provider_models[provider] = model_id
        # Persist the whole mapping as JSON (the format load() reads back)
        get_database().set_config("provider_models", json.dumps(self.provider_models))
# Process-wide Settings singleton; populated lazily by get_settings().
_settings: Optional[Settings] = None


def get_settings() -> Settings:
    """
    Return the shared Settings instance.

    On first access the instance is built with ``Settings.load()``; later
    calls return that same cached object.

    Returns:
        The shared Settings instance.
    """
    global _settings
    if _settings is not None:
        return _settings
    _settings = Settings.load()
    return _settings
def reload_settings() -> Settings:
    """
    Re-read settings from the database and replace the shared instance.

    If loading fails, the previously cached instance is left untouched.

    Returns:
        The newly loaded Settings instance, which is now the shared one.
    """
    global _settings
    fresh = Settings.load()
    _settings = fresh
    return fresh