899 lines
36 KiB
Python
899 lines
36 KiB
Python
"""
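main.py — FastAPI application entry point.

Provides:
- HTML pages: /, /chats, /agents, /agents/{agent_id}, /models, /audit, /help,
  /files, /settings, /admin/users
- Auth pages: /login, /login/mfa, /logout, /setup
- WebSocket: /ws/{session_id} (streaming agent responses)
- REST API: /api/* (plus the unauthenticated /health probe)
- 2nd Brain MCP server mounted at /brain-mcp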
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
# Configure root logging before the project modules below are imported, so any
# loggers they create inherit this format and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Make CalDAV tool logs visible at DEBUG level so every step is traceable
logging.getLogger("server.tools.caldav_tool").setLevel(logging.DEBUG)

# Module-level logger used by the request handlers in this file
# (login rate limiting, WebSocket MCP discovery and history restore).
logger = logging.getLogger(__name__)
|
|
|
|
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
|
|
|
|
|
|
from .agent.agent import Agent, AgentEvent, ConfirmationRequiredEvent, DoneEvent, ErrorEvent, ImageEvent, TextEvent, ToolDoneEvent, ToolStartEvent
|
|
from .agent.confirmation import confirmation_manager
|
|
from .agents.runner import agent_runner
|
|
from .agents.tasks import cleanup_stale_runs
|
|
from .auth import SYNTHETIC_API_ADMIN, CurrentUser, create_session_cookie, decode_session_cookie
|
|
from .brain.database import close_brain_db, init_brain_db
|
|
from .config import settings
|
|
from .context_vars import current_user as _current_user_var
|
|
from .database import close_db, credential_store, init_db
|
|
from .inbox.listener import inbox_listener
|
|
from .mcp import create_mcp_app, _session_manager
|
|
from .telegram.listener import telegram_listener
|
|
from .tools import build_registry
|
|
from .users import assign_existing_data_to_admin, create_user, get_user_by_username, user_count
|
|
from .web.routes import router as api_router
|
|
|
|
# Directory containing this module; web/templates and web/static hang off it.
BASE_DIR = Path(__file__).parent

# Shared Jinja2 environment for all HTML pages. Further globals (brand_name,
# logo_url, sv) are injected elsewhere in this module.
templates = Jinja2Templates(directory=str(BASE_DIR.joinpath("web", "templates")))
templates.env.globals.update(agent_name=settings.agent_name)
|
|
|
|
|
|
async def _migrate_email_accounts() -> None:
    """
    One-time startup migration: copy old inbox:* / inbox_* credentials into the
    new email_accounts table as 'trigger' type accounts.

    Idempotent — guarded by the 'email_accounts_migrated' credential flag.

    Steps:
      1. Create a global trigger account (user_id=None) from the ``inbox:*``
         keys in credential_store, if host/username/password are all set.
      2. Create one trigger account per user that has ``inbox_imap_host`` in
         user_settings, again only if host/username/password are all set.
      3. Set ``account_id`` on existing email triggers: unowned triggers get
         the global account, owned triggers get their user's account.
    The flag is written only after all three steps complete.

    NOTE(review): the steps are not transactional. If startup fails after some
    accounts were created but before the flag is written, the next start re-runs
    the whole migration and presumably creates duplicate accounts — confirm
    whether create_account guards against that.
    """
    if await credential_store.get("email_accounts_migrated") == "1":
        return

    # Imported lazily so these modules are only loaded when the migration runs.
    from .inbox.accounts import create_account
    from .inbox.triggers import list_triggers, update_trigger
    from .database import get_pool

    logger_main = logging.getLogger(__name__)
    logger_main.info("[migrate] Running email_accounts one-time migration…")

    # 1. Global trigger account (inbox:* keys in credential_store)
    global_host = await credential_store.get("inbox:imap_host")
    global_user = await credential_store.get("inbox:imap_username")
    global_pass = await credential_store.get("inbox:imap_password")

    global_account_id: str | None = None
    if global_host and global_user and global_pass:
        _smtp_port_raw = await credential_store.get("inbox:smtp_port")
        acct = await create_account(
            label="Global Inbox",
            account_type="trigger",
            imap_host=global_host,
            # Defaults: 993 for IMAP, 465 for SMTP when no port is stored.
            # NOTE(review): int() raises on a non-numeric stored port, which
            # aborts startup before the migration flag is written.
            imap_port=int(await credential_store.get("inbox:imap_port") or "993"),
            imap_username=global_user,
            imap_password=global_pass,
            smtp_host=await credential_store.get("inbox:smtp_host"),
            smtp_port=int(_smtp_port_raw) if _smtp_port_raw else 465,
            smtp_username=await credential_store.get("inbox:smtp_username"),
            smtp_password=await credential_store.get("inbox:smtp_password"),
            user_id=None,
        )
        global_account_id = str(acct["id"])
        logger_main.info("[migrate] Created global trigger account: %s", global_account_id)

    # 2. Per-user trigger accounts (inbox_imap_host in user_settings)
    from .database import user_settings_store
    pool = await get_pool()
    user_rows = await pool.fetch(
        "SELECT DISTINCT user_id FROM user_settings WHERE key = 'inbox_imap_host'"
    )
    user_account_map: dict[str, str] = {}  # user_id → account_id
    for row in user_rows:
        uid = row["user_id"]
        host = await user_settings_store.get(uid, "inbox_imap_host")
        uname = await user_settings_store.get(uid, "inbox_imap_username")
        pw = await user_settings_store.get(uid, "inbox_imap_password")
        # Skip users whose IMAP login is incomplete.
        if not (host and uname and pw):
            continue
        _u_smtp_port = await user_settings_store.get(uid, "inbox_smtp_port")
        acct = await create_account(
            label="My Inbox",
            account_type="trigger",
            imap_host=host,
            imap_port=int(await user_settings_store.get(uid, "inbox_imap_port") or "993"),
            imap_username=uname,
            imap_password=pw,
            smtp_host=await user_settings_store.get(uid, "inbox_smtp_host"),
            smtp_port=int(_u_smtp_port) if _u_smtp_port else 465,
            smtp_username=await user_settings_store.get(uid, "inbox_smtp_username"),
            smtp_password=await user_settings_store.get(uid, "inbox_smtp_password"),
            user_id=uid,
        )
        user_account_map[uid] = str(acct["id"])
        logger_main.info("[migrate] Created trigger account for user %s: %s", uid, acct["id"])

    # 3. Update existing email_triggers with account_id
    # NOTE(review): keys of user_account_map come from user_settings rows while
    # t_user_id comes from list_triggers — verify both use the same type
    # (e.g. str vs UUID), otherwise the membership test silently misses.
    all_triggers = await list_triggers(user_id=None)
    for t in all_triggers:
        tid = t["id"]
        t_user_id = t.get("user_id")
        if t_user_id is None and global_account_id:
            await update_trigger(tid, account_id=global_account_id)
        elif t_user_id and t_user_id in user_account_map:
            await update_trigger(tid, account_id=user_account_map[t_user_id])

    # Written last so a failed run is retried on the next startup.
    await credential_store.set("email_accounts_migrated", "1", "One-time email_accounts migration flag")
    logger_main.info("[migrate] email_accounts migration complete.")
|
|
|
|
|
|
async def _refresh_brand_globals() -> None:
    """Push the current branding into the Jinja2 globals ``brand_name`` and ``logo_url``.

    The brand name falls back to ``settings.agent_name``. The logo URL points
    at the stored custom logo only if that file exists under web/static;
    otherwise it is ``/static/logo.png``. Call at startup and again after the
    branding settings change.
    """
    logo_url = "/static/logo.png"
    name = await credential_store.get("system:brand_name") or settings.agent_name
    custom_logo = await credential_store.get("system:brand_logo_filename")
    if custom_logo and BASE_DIR.joinpath("web", "static", custom_logo).exists():
        logo_url = f"/static/{custom_logo}"
    templates.env.globals.update(brand_name=name, logo_url=logo_url)
|
|
|
|
# Cache-busting version: hash of static file contents so it always changes when files change.
|
|
# Avoids relying on git (not available in Docker container).
|
|
def _compute_static_version(static_dir: Path | None = None) -> str:
    """Return a short content hash of the top-level static ``*.js`` / ``*.css`` files.

    The value is exposed to templates as the ``sv`` global for cache busting,
    so browsers refetch assets whenever their contents change. It does not rely
    on git, which isn't available in the Docker container.

    Args:
        static_dir: Directory to scan. Defaults to ``BASE_DIR / "web" / "static"``.

    Returns:
        The first 10 hex characters of an MD5 digest over the file contents.
        ``*.js`` files are hashed first, then ``*.css``, each group in sorted
        path order.
    """
    if static_dir is None:
        static_dir = BASE_DIR / "web" / "static"
    # MD5 is only a change detector here, not a security primitive. Declaring
    # that keeps it usable on FIPS-restricted Python builds, where a plain
    # md5() call raises.
    digest = hashlib.md5(usedforsecurity=False)
    for asset in sorted(static_dir.glob("*.js")) + sorted(static_dir.glob("*.css")):
        try:
            digest.update(asset.read_bytes())
        except OSError:
            # Best effort: an unreadable file simply doesn't contribute.
            pass
    return digest.hexdigest()[:10]
|
|
|
|
# Computed once at import; templates append it to asset URLs as ``sv``.
_static_version = templates.env.globals["sv"] = _compute_static_version()
|
|
|
|
# ── First-run flag ─────────────────────────────────────────────────────────────
# True while no user account exists. lifespan() computes it at startup and
# /setup clears it once the first admin has been created.
_needs_setup: bool = False

# ── Global agent (singleton — shares session history across requests) ─────────
# Both are populated by lifespan(); None until startup has run.
_registry = None  # tool registry returned by build_registry()
_agent: Agent | None = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start every subsystem, serve while suspended at ``yield``, then shut down.

    Startup steps run in the order shown. Later steps consume what earlier ones
    build: ``_registry`` → ``_agent`` → ``agent_runner``. The module globals
    ``_registry``, ``_agent``, ``_needs_setup`` and ``_trusted_proxy_ips`` are
    set here.

    Shutdown sits in ``finally`` so the listeners, runner and DB pools are
    released even if the MCP session manager raises on exit.
    """
    global _registry, _agent, _needs_setup, _trusted_proxy_ips

    await init_db()
    await _refresh_brand_globals()
    await _ensure_session_secret()
    _needs_setup = await user_count() == 0
    # _ProxyTrustMiddleware reads this global when it builds its inner middleware.
    _trusted_proxy_ips = await credential_store.get("system:trusted_proxy_ips") or "127.0.0.1"
    await cleanup_stale_runs()
    await init_brain_db()

    _registry = build_registry()
    from .mcp_client.manager import discover_and_register_mcp_tools
    await discover_and_register_mcp_tools(_registry)
    _agent = Agent(registry=_registry)
    print("[aide] Agent ready.")
    agent_runner.init(_agent)
    await agent_runner.start()

    await _migrate_email_accounts()
    await inbox_listener.start_all()
    telegram_listener.start()

    try:
        async with _session_manager.run():
            yield
    finally:
        inbox_listener.stop_all()
        telegram_listener.stop()
        agent_runner.shutdown()
        await close_brain_db()
        await close_db()
|
|
|
|
|
|
# The ASGI application; lifespan() performs startup and shutdown.
app = FastAPI(
    title="oAI-Web API",
    version="0.5",
    lifespan=lifespan,
)
|
|
|
|
|
|
# ── Custom OpenAPI schema — adds X-API-Key "Authorize" button in Swagger ──────
|
|
|
|
def _custom_openapi():
    """Build the OpenAPI schema once and add a global ``X-API-Key`` security scheme.

    Declaring the scheme (and requiring it globally) makes Swagger UI show an
    "Authorize" button, so API-key requests can be tried from /docs. The result
    is cached on ``app.openapi_schema``.
    """
    if not app.openapi_schema:
        from fastapi.openapi.utils import get_openapi

        schema = get_openapi(title=app.title, version=app.version, routes=app.routes)
        components = schema.setdefault("components", {})
        components["securitySchemes"] = {
            "ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-API-Key"},
        }
        schema["security"] = [{"ApiKeyAuth": []}]
        app.openapi_schema = schema
    return app.openapi_schema


app.openapi = _custom_openapi
|
|
|
|
# ── Proxy trust ───────────────────────────────────────────────────────────────

# Passed as ``trusted_hosts`` to uvicorn's ProxyHeadersMiddleware. Defaults to
# loopback; lifespan() replaces it with credential_store's
# "system:trusted_proxy_ips" when that key is set.
_trusted_proxy_ips: str = "127.0.0.1"
|
|
|
|
|
|
class _ProxyTrustMiddleware:
    """Wrap uvicorn's ProxyHeadersMiddleware so the trusted IPs come from the DB.

    The inner middleware is built lazily from the module-level
    ``_trusted_proxy_ips`` on the first HTTP/WebSocket request. Other scope
    types, notably ``lifespan``, pass straight through. Starlette routes the
    lifespan scope through the middleware stack *before* ``lifespan()`` has
    loaded the configured value. Building the inner middleware at that point
    would freeze the hard-coded default for the life of the process.
    """

    def __init__(self, app):
        self._app = app
        # uvicorn ProxyHeadersMiddleware instance, created on the first HTTP/WS scope.
        self._inner = None

    async def __call__(self, scope, receive, send):
        if scope["type"] not in ("http", "websocket"):
            # ProxyHeadersMiddleware doesn't touch these scopes anyway.
            await self._app(scope, receive, send)
            return
        if self._inner is None:
            from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
            self._inner = ProxyHeadersMiddleware(self._app, trusted_hosts=_trusted_proxy_ips)
        await self._inner(scope, receive, send)
|
|
|
|
|
|
# Registered before _AuthMiddleware. In Starlette the last-added middleware is
# outermost, so this one runs inside the auth middleware. Check the relative
# order before moving either registration.
app.add_middleware(_ProxyTrustMiddleware)
|
|
|
|
|
|
# ── Auth middleware ────────────────────────────────────────────────────────────
|
|
#
|
|
# All routes require authentication. Two accepted paths:
|
|
# 1. User session cookie (aide_user) — set on login, carries identity.
|
|
# 2. API key (X-API-Key or Authorization: Bearer) — treated as synthetic admin.
|
|
#
|
|
# Exempt paths bypass auth entirely (login, setup, static, health, etc.).
|
|
# First-run: if no users exist (_needs_setup), all non-exempt paths → /setup.
|
|
|
|
import hashlib as _hashlib
|
|
import hmac as _hmac
|
|
import secrets as _secrets
|
|
import time as _time
|
|
|
|
# Name of the session cookie set on login and cleared on logout.
_USER_COOKIE = "aide_user"

# Exact paths reachable without authentication.
_EXEMPT_PATHS = frozenset({"/health", "/login", "/login/mfa", "/logout", "/setup"})

# Path prefixes reachable without authentication (matched with str.startswith).
_EXEMPT_PREFIXES = ("/static/", "/brain-mcp/", "/docs", "/redoc", "/openapi.json")

# API endpoints reachable without authentication.
# NOTE(review): confirm the handler for this path enforces its own access control.
_EXEMPT_API_PATHS = frozenset({"/api/settings/api-key"})
|
|
|
|
|
|
async def _ensure_session_secret() -> str:
    """Fetch the HMAC secret for session cookies, generating and storing one if missing.

    A new secret is ``secrets.token_hex(32)`` (64 hex chars), persisted under
    ``system:session_secret``.
    """
    existing = await credential_store.get("system:session_secret")
    if existing:
        return existing
    fresh = _secrets.token_hex(32)
    await credential_store.set(
        "system:session_secret",
        fresh,
        description="Web UI session token secret (auto-generated)",
    )
    return fresh
|
|
|
|
|
|
def _parse_user_cookie(raw_cookie: str) -> str:
    """Return the value of the ``_USER_COOKIE`` cookie from a raw Cookie header.

    Pairs are split on ``;`` and whitespace-trimmed; the first pair whose name
    matches wins. Returns ``""`` when the cookie is absent.
    """
    prefix = f"{_USER_COOKIE}="
    return next(
        (
            pair[len(prefix):]
            for pair in (chunk.strip() for chunk in raw_cookie.split(";"))
            if pair.startswith(prefix)
        ),
        "",
    )
|
|
|
|
|
|
async def _authenticate(headers: dict) -> CurrentUser | None:
    """Resolve the caller's identity from raw ASGI headers.

    Tried in order:
      1. The ``aide_user`` session cookie. It is accepted only if it decodes
         with the stored session secret AND the user row is still active.
      2. An API key in ``X-API-Key``, or in ``Authorization`` with an optional
         ``Bearer `` prefix. Its SHA-256 hex digest is compared in constant
         time with ``system:api_key_hash``; a match yields ``SYNTHETIC_API_ADMIN``.

    Args:
        headers: Mapping of lower-cased header-name bytes to header-value bytes.

    Returns:
        The authenticated ``CurrentUser``, or ``None``.
    """
    def header(name: bytes) -> str:
        # ASGI header values are raw bytes. latin-1 maps every byte to a code
        # point, so hostile non-UTF-8 input (e.g. a foreign cookie) can't raise.
        return headers.get(name, b"").decode("latin-1")

    # 1. User session cookie
    cookie_val = _parse_user_cookie(header(b"cookie"))
    if cookie_val:
        secret = await credential_store.get("system:session_secret")
        if secret:
            user = decode_session_cookie(cookie_val, secret)
            if user:
                # Verify the user is still active in the DB — catches deactivated
                # accounts whose session cookies haven't expired yet.
                from .users import get_user_by_id as _get_user_by_id

                db_user = await _get_user_by_id(user.id)
                if db_user and db_user.get("is_active", True):
                    return user

    # 2. API key
    key_hash = await credential_store.get("system:api_key_hash")
    if key_hash:
        provided = (
            header(b"x-api-key")
            or header(b"authorization").removeprefix("Bearer ").strip()
        )
        if provided:
            # Encoding back to latin-1 hashes the exact header bytes, identical
            # to a UTF-8 decode/encode round-trip for any valid key.
            provided_hash = _hashlib.sha256(provided.encode("latin-1")).hexdigest()
            # Constant-time compare so response timing reveals nothing about the stored hash.
            if _hmac.compare_digest(provided_hash.encode(), key_hash.encode()):
                return SYNTHETIC_API_ADMIN

    return None
|
|
|
|
|
|
class _AuthMiddleware:
    """Pure-ASGI middleware that authenticates every HTTP/WebSocket request.

    For each request:
      * non-HTTP/WS scopes and exempt paths pass straight through;
      * while first-run setup is pending, everything else goes to /setup
        (WebSockets are closed with 1008);
      * otherwise the caller is authenticated with ``_authenticate``. Failures
        get a WebSocket close (1008), a JSON 401 for /api/ and /ws/ paths, or a
        redirect to /login that carries the original target in ``next``;
      * on success the user is stored in ``scope["state"]["current_user"]``
        (visible as ``request.state.current_user``) and in the ``current_user``
        ContextVar for the duration of the downstream call.
    """

    def __init__(self, app):
        self._app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] not in ("http", "websocket"):
            await self._app(scope, receive, send)
            return

        path: str = scope.get("path", "")

        # Always let exempt paths through
        if path in _EXEMPT_PATHS or path in _EXEMPT_API_PATHS:
            await self._app(scope, receive, send)
            return
        if any(path.startswith(p) for p in _EXEMPT_PREFIXES):
            await self._app(scope, receive, send)
            return

        is_websocket = scope["type"] == "websocket"

        # First-run: redirect to /setup
        if _needs_setup:
            if is_websocket:
                # Closing before accept; per the ASGI spec the server rejects the handshake.
                await send({"type": "websocket.close", "code": 1008})
                return
            response = RedirectResponse("/setup")
            await response(scope, receive, send)
            return

        # Authenticate
        headers = dict(scope.get("headers", []))
        user = await _authenticate(headers)

        if user is None:
            if is_websocket:
                await send({"type": "websocket.close", "code": 1008})
                return
            if path.startswith("/api/") or path.startswith("/ws/"):
                response = JSONResponse({"error": "Authentication required"}, status_code=401)
            else:
                response = RedirectResponse(f"/login{self._next_param(scope, path)}")
            await response(scope, receive, send)
            return

        # Set user on request state (for templates) and ContextVar (for tools/audit)
        scope.setdefault("state", {})["current_user"] = user
        token = _current_user_var.set(user)
        try:
            await self._app(scope, receive, send)
        finally:
            _current_user_var.reset(token)

    @staticmethod
    def _next_param(scope, path: str) -> str:
        """Build the ``?next=…`` suffix for the login redirect (``""`` for a bare ``/``).

        The original path and query string are both kept, so e.g.
        ``/?session=<id>`` survives the login round-trip. The target is
        percent-encoded so characters such as ``&``, ``#`` or ``?`` cannot
        break out of the parameter.
        """
        from urllib.parse import quote

        query = scope.get("query_string", b"").decode("latin-1")
        target = f"{path}?{query}" if query else path
        if target == "/":
            return ""
        return f"?next={quote(target, safe='/')}"
|
|
|
|
|
|
# Registered after _ProxyTrustMiddleware. In Starlette the last-added
# middleware is outermost, so authentication wraps proxy-header handling.
# Check the relative order before moving either registration.
app.add_middleware(_AuthMiddleware)

# Static assets; the "/static/" prefix is in _EXEMPT_PREFIXES, so no auth is needed.
app.mount("/static", StaticFiles(directory=str(BASE_DIR / "web" / "static")), name="static")
# JSON API routes from .web.routes, all served under /api.
app.include_router(api_router, prefix="/api")

# 2nd Brain MCP server — mounted at /brain-mcp (SSE transport)
# NOTE(review): "/brain-mcp/" is auth-exempt; confirm the MCP app authenticates callers itself.
app.mount("/brain-mcp", create_mcp_app())
|
|
|
|
|
|
|
|
# ── Auth helpers ──────────────────────────────────────────────────────────────
|
|
|
|
def _get_current_user(request: Request) -> CurrentUser | None:
    """Return the user that _AuthMiddleware attached to ``request.state``, or None."""
    return getattr(request.state, "current_user", None)
|
|
|
|
|
|
def _require_admin(request: Request) -> bool:
    """Return whether the request carries an authenticated admin user."""
    current = _get_current_user(request)
    if current is None:
        return False
    return current.is_admin
|
|
|
|
|
|
# ── Login rate limiting ───────────────────────────────────────────────────────
|
|
|
|
from .login_limiter import is_locked as _login_is_locked
|
|
from .login_limiter import record_failure as _record_login_failure
|
|
from .login_limiter import clear_failures as _clear_login_failures
|
|
|
|
|
|
def _get_client_ip(request: Request) -> str:
    """Best-effort client IP, used as the key for login rate limiting.

    Returns the first entry of ``X-Forwarded-For`` when that header is present
    and non-empty. Otherwise returns the connection's peer address, or
    ``"unknown"`` when there is none.

    NOTE(review): X-Forwarded-For is honoured from *any* peer here, so a client
    can spoof it to dodge (or aim) the per-IP login lockout. The proxy-trust
    middleware presumably already rewrites ``request.client`` for trusted
    proxies; consider relying on ``request.client.host`` alone once the trusted
    proxy list is configured correctly.
    """
    forwarded_for = request.headers.get("x-forwarded-for")
    if not forwarded_for:
        return request.client.host if request.client else "unknown"
    first_hop, _sep, _rest = forwarded_for.partition(",")
    return first_hop.strip()
|
|
|
|
|
|
# ── Login / Logout / Setup ────────────────────────────────────────────────────
|
|
|
|
@app.get("/login", response_class=HTMLResponse)
async def login_get(request: Request, next: str = "/", error: str = ""):
    """Render the login form, or send an already-authenticated visitor to ``/``.

    /login is in the auth middleware's exempt list, so ``request.state`` never
    carries a user here. The caller's credentials are therefore checked
    directly; otherwise the "already logged in" redirect could never fire.

    ``error`` is a short code (set by the MFA flow) translated into a message.
    Unknown codes show no message.
    """
    user = _get_current_user(request) or await _authenticate(dict(request.scope.get("headers", [])))
    if user:
        return RedirectResponse("/")
    _ERROR_MESSAGES = {
        "session_expired": "MFA session expired. Please sign in again.",
        "too_many_attempts": "Too many incorrect codes. Please sign in again.",
    }
    error_msg = _ERROR_MESSAGES.get(error) if error else None
    return templates.TemplateResponse("login.html", {"request": request, "next": next, "error": error_msg})
|
|
|
|
|
|
def _sanitize_next_url(raw_next: str) -> str:
    r"""Return ``raw_next`` (stripped) if it is a safe same-site path, else ``"/"``.

    Rejects anything that is not root-relative, plus forms a browser could
    resolve to another host:
      * ``//host`` (protocol-relative);
      * ``/\host`` (browsers treat ``\`` like ``/``);
      * paths containing ASCII control characters (browsers drop tabs and
        newlines, so ``/<TAB>/host`` becomes ``//host``).
    """
    candidate = raw_next.strip() or "/"
    if not candidate.startswith("/") or candidate.startswith(("//", "/\\")):
        return "/"
    if any(ord(ch) < 0x20 or ch == "\x7f" for ch in candidate):
        return "/"
    return candidate


@app.post("/login")
async def login_post(request: Request):
    """Handle the username/password login form.

    Steps:
      1. Sanitise ``next``.
      2. Refuse a rate-limit-locked client IP with HTTP 429.
      3. Verify the credentials.
      4. For users with a TOTP secret, store a 5-minute MFA challenge and
         redirect to /login/mfa; otherwise set the 30-day session cookie and
         redirect to ``next``.
    Bad credentials are recorded against the client IP and re-render the form
    with HTTP 401.
    """
    from datetime import datetime, timezone, timedelta
    from .auth import verify_password

    form = await request.form()
    username = str(form.get("username", "")).strip()
    password = str(form.get("password", ""))
    # Reject absolute / protocol-relative targets to prevent open redirects.
    next_url = _sanitize_next_url(str(form.get("next", "/")))

    ip = _get_client_ip(request)
    locked, lock_kind = _login_is_locked(ip)
    if locked:
        logger.warning("[login] blocked IP %s (%s)", ip, lock_kind)
        if lock_kind == "permanent":
            msg = "This IP has been permanently blocked due to repeated login failures. Contact an administrator."
        else:
            msg = "Too many failed attempts. Please try again in 30 minutes."
        return templates.TemplateResponse("login.html", {
            "request": request,
            "next": next_url,
            "error": msg,
        }, status_code=429)

    user = await get_user_by_username(username)
    if user and user["is_active"] and verify_password(password, user["password_hash"]):
        _clear_login_failures(ip)

        # MFA branch: park the login in a short-lived challenge row keyed by a random token.
        if user.get("totp_secret"):
            token = _secrets.token_hex(32)
            pool = await _db_pool()
            now = datetime.now(timezone.utc)
            expires = now + timedelta(minutes=5)
            await pool.execute(
                "INSERT INTO mfa_challenges (token, user_id, next_url, created_at, expires_at) "
                "VALUES ($1, $2, $3, $4, $5)",
                token, user["id"], next_url, now, expires,
            )
            response = RedirectResponse("/login/mfa", status_code=303)
            # Cookie lifetime matches the 5-minute challenge and is scoped to the MFA page only.
            response.set_cookie(
                "mfa_challenge", token,
                httponly=True, samesite="lax", max_age=300, path="/login/mfa",
            )
            return response

        # No MFA — create session directly (30-day cookie).
        secret = await _ensure_session_secret()
        cookie_val = create_session_cookie(user, secret)
        response = RedirectResponse(next_url, status_code=303)
        response.set_cookie(
            _USER_COOKIE, cookie_val,
            httponly=True, samesite="lax", max_age=2592000, path="/",
        )
        return response

    _record_login_failure(ip)
    return templates.TemplateResponse("login.html", {
        "request": request,
        "next": next_url,
        "error": "Invalid username or password.",
    }, status_code=401)
|
|
|
|
|
|
async def _db_pool():
    """Return the shared database pool (``.database.get_pool`` is imported lazily)."""
    from .database import get_pool as _get_pool

    pool = await _get_pool()
    return pool
|
|
|
|
|
|
@app.get("/login/mfa", response_class=HTMLResponse)
async def login_mfa_get(request: Request):
    """Show the TOTP form for a live MFA challenge.

    The challenge is looked up by the ``mfa_challenge`` cookie. A missing or
    expired challenge redirects to ``/login?error=session_expired``.
    """
    from datetime import datetime, timezone

    challenge_token = request.cookies.get("mfa_challenge", "")
    pool = await _db_pool()
    challenge = await pool.fetchrow(
        "SELECT user_id, next_url, expires_at FROM mfa_challenges WHERE token = $1",
        challenge_token,
    )
    is_live = bool(challenge) and not challenge["expires_at"] < datetime.now(timezone.utc)
    if not is_live:
        return RedirectResponse("/login?error=session_expired", status_code=303)
    context = {"request": request, "next": challenge["next_url"], "error": None}
    return templates.TemplateResponse("mfa.html", context)
|
|
|
|
|
|
@app.post("/login/mfa")
async def login_mfa_post(request: Request):
    """Verify the TOTP code for a pending MFA challenge and, on success, log in.

    The challenge token comes from the ``mfa_challenge`` cookie set by the
    password step. Each submission atomically increments the challenge's
    ``attempts`` counter *before* the code is checked. Concurrent submissions
    therefore each get a distinct attempt number and cannot exceed the budget
    of 5 codes per challenge. After the 5th wrong code (or any submission past
    the budget) the challenge is deleted and the user must sign in again.
    """
    from datetime import datetime, timezone
    from .auth import verify_totp
    from .users import get_user_by_id

    max_attempts = 5

    form = await request.form()
    code = str(form.get("code", "")).strip()
    token = request.cookies.get("mfa_challenge", "")
    pool = await _db_pool()

    row = await pool.fetchrow(
        "UPDATE mfa_challenges SET attempts = attempts + 1 WHERE token = $1 "
        "RETURNING user_id, next_url, expires_at, attempts",
        token,
    )
    if not row or row["expires_at"] < datetime.now(timezone.utc):
        return RedirectResponse("/login?error=session_expired", status_code=303)

    attempts = row["attempts"]  # includes this submission
    if attempts > max_attempts:
        # Budget already used up by earlier (possibly concurrent) submissions.
        await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token)
        return RedirectResponse("/login?error=too_many_attempts", status_code=303)

    next_url = row["next_url"] or "/"
    user = await get_user_by_id(row["user_id"])
    if not user or not user.get("is_active", True) or not user.get("totp_secret"):
        await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token)
        return RedirectResponse("/login", status_code=303)

    if not verify_totp(user["totp_secret"], code):
        if attempts >= max_attempts:
            await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token)
            return RedirectResponse("/login?error=too_many_attempts", status_code=303)
        return templates.TemplateResponse("mfa.html", {
            "request": request,
            "next": next_url,
            "error": "Invalid code. Try again.",
        }, status_code=401)

    # Success: consume the challenge and issue the 30-day session cookie.
    await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token)
    secret = await _ensure_session_secret()
    cookie_val = create_session_cookie(user, secret)
    response = RedirectResponse(next_url, status_code=303)
    response.set_cookie(
        _USER_COOKIE, cookie_val,
        httponly=True, samesite="lax", max_age=2592000, path="/",
    )
    response.delete_cookie("mfa_challenge", path="/login/mfa")
    return response
|
|
|
|
|
|
@app.get("/logout")
async def logout(request: Request):
    """Log out: delete the session cookie and clear per-user browser state.

    The returned page removes ``current_session_id`` and ``preferred-model``
    from localStorage and then replaces the location with /login. This stops
    the next person on the same browser from reopening the previous user's
    conversation.
    """
    page = """<!doctype html>
<html><head><title>Logging out…</title></head><body>
<script>
localStorage.removeItem("current_session_id");
localStorage.removeItem("preferred-model");
window.location.replace("/login");
</script>
</body></html>"""
    response = HTMLResponse(page)
    response.delete_cookie(_USER_COOKIE, path="/")
    return response
|
|
|
|
|
|
@app.get("/setup", response_class=HTMLResponse)
async def setup_get(request: Request):
    """Render the first-run form for creating the initial admin account.

    ``_needs_setup`` is a process-local flag computed at startup. While it is
    set, the user table is re-checked. If an account already exists (e.g.
    setup was completed by another process), the stale flag is cleared and the
    visitor is sent to ``/`` instead of seeing the form again.
    """
    global _needs_setup
    if _needs_setup and await user_count() > 0:
        _needs_setup = False
    if not _needs_setup:
        return RedirectResponse("/")
    return templates.TemplateResponse("setup.html", {"request": request, "errors": [], "username": ""})
|
|
|
|
|
|
@app.post("/setup")
async def setup_post(request: Request):
    """Create the first admin account from the setup form and log that admin in.

    The form is validated (username, email containing ``@``, password of at
    least 8 characters that matches its confirmation); on error it is
    re-rendered with HTTP 400. /setup bypasses authentication, and
    ``_needs_setup`` is only a process-local flag. The user table is therefore
    re-checked immediately before the admin is created, so a stale flag
    (e.g. in another worker process) cannot be used to create a second admin.
    """
    global _needs_setup
    if not _needs_setup:
        return RedirectResponse("/", status_code=303)

    form = await request.form()
    username = str(form.get("username", "")).strip()
    password = str(form.get("password", ""))
    confirm = str(form.get("confirm", ""))
    email = str(form.get("email", "")).strip().lower()

    errors = []
    if not username:
        errors.append("Username is required.")
    if not email or "@" not in email:
        errors.append("A valid email address is required.")
    if len(password) < 8:
        errors.append("Password must be at least 8 characters.")
    if password != confirm:
        errors.append("Passwords do not match.")

    if errors:
        return templates.TemplateResponse("setup.html", {
            "request": request,
            "errors": errors,
            "username": username,
            "email": email,
        }, status_code=400)

    if await user_count() > 0:
        # Setup was already completed elsewhere; never mint another admin here.
        _needs_setup = False
        return RedirectResponse("/", status_code=303)

    user = await create_user(username, password, role="admin", email=email)
    await assign_existing_data_to_admin(user["id"])
    _needs_setup = False

    secret = await _ensure_session_secret()
    cookie_val = create_session_cookie(user, secret)
    response = RedirectResponse("/", status_code=303)
    response.set_cookie(_USER_COOKIE, cookie_val, httponly=True, samesite="lax", max_age=2592000, path="/")
    return response
|
|
|
|
|
|
# ── HTML pages ────────────────────────────────────────────────────────────────
|
|
|
|
async def _ctx(request: Request, **extra):
    """Assemble the common template context for HTML pages.

    Always includes ``request``, ``current_user``, ``theme_css`` (CSS for the
    user's ``theme`` setting, or ``""`` when anonymous) and
    ``needs_personality_setup``. The latter is True only for a non-admin user
    whose ``personality_setup_done`` setting is unset. Keyword arguments in
    ``extra`` are merged last and override these keys.
    """
    from .web.themes import get_theme_css, DEFAULT_THEME
    from .database import user_settings_store

    user = _get_current_user(request)
    context = {
        "request": request,
        "current_user": user,
        "theme_css": "",
        "needs_personality_setup": False,
    }
    if user:
        theme_id = await user_settings_store.get(user.id, "theme") or DEFAULT_THEME
        context["theme_css"] = get_theme_css(theme_id)
        if user.role != "admin":
            setup_done = await user_settings_store.get(user.id, "personality_setup_done")
            context["needs_personality_setup"] = not setup_done
    context.update(extra)
    return context
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def chat_page(request: Request, session: str = ""):
    """Main chat page.

    ``?session=<id>`` reopens a saved conversation; without it (or when it is
    blank) a fresh UUID4 session id is generated.
    """
    session_id = session.strip() or str(uuid.uuid4())
    context = await _ctx(request, session_id=session_id)
    return templates.TemplateResponse("chat.html", context)
|
|
|
|
|
|
@app.get("/chats", response_class=HTMLResponse)
async def chats_page(request: Request):
    """Render chats.html with the common page context."""
    context = await _ctx(request)
    return templates.TemplateResponse("chats.html", context)
|
|
|
|
|
|
@app.get("/agents", response_class=HTMLResponse)
async def agents_page(request: Request):
    """Render agents.html with the common page context."""
    context = await _ctx(request)
    return templates.TemplateResponse("agents.html", context)
|
|
|
|
|
|
@app.get("/agents/{agent_id}", response_class=HTMLResponse)
async def agent_detail_page(request: Request, agent_id: str):
    """Render agent_detail.html; ``agent_id`` from the path is passed to the template."""
    context = await _ctx(request, agent_id=agent_id)
    return templates.TemplateResponse("agent_detail.html", context)
|
|
|
|
|
|
@app.get("/models", response_class=HTMLResponse)
async def models_page(request: Request):
    """Render models.html with the common page context."""
    context = await _ctx(request)
    return templates.TemplateResponse("models.html", context)
|
|
|
|
|
|
@app.get("/audit", response_class=HTMLResponse)
async def audit_page(request: Request):
    """Render audit.html with the common page context."""
    context = await _ctx(request)
    return templates.TemplateResponse("audit.html", context)
|
|
|
|
|
|
@app.get("/help", response_class=HTMLResponse)
async def help_page(request: Request):
    """Render help.html with the common page context."""
    context = await _ctx(request)
    return templates.TemplateResponse("help.html", context)
|
|
|
|
|
|
@app.get("/files", response_class=HTMLResponse)
async def files_page(request: Request):
    """Render files.html with the common page context."""
    context = await _ctx(request)
    return templates.TemplateResponse("files.html", context)
|
|
|
|
|
|
@app.get("/settings", response_class=HTMLResponse)
async def settings_page(request: Request):
    """Render the settings page; anonymous visitors are redirected to log in.

    Admins additionally get ``credential_keys`` (all credential-store key
    names) and ``is_paused`` (kill-switch state) in the template context.
    """
    user = _get_current_user(request)
    if user is None:
        return RedirectResponse("/login?next=/settings")
    ctx = await _ctx(request)
    if not user.is_admin:
        return templates.TemplateResponse("settings.html", ctx)
    key_names = [entry["key"] for entry in await credential_store.list_keys()]
    paused = await credential_store.get("system:paused") == "1"
    ctx.update(credential_keys=key_names, is_paused=paused)
    return templates.TemplateResponse("settings.html", ctx)
|
|
|
|
|
|
@app.get("/admin/users", response_class=HTMLResponse)
async def admin_users_page(request: Request):
    """Render the user-administration page; non-admins are redirected to ``/``."""
    if _require_admin(request):
        return templates.TemplateResponse("admin_users.html", await _ctx(request))
    return RedirectResponse("/")
|
|
|
|
|
|
# ── Kill switch ───────────────────────────────────────────────────────────────
|
|
|
|
@app.post("/api/pause")
async def pause_agent(request: Request):
    """Admin-only kill switch: set the ``system:paused`` credential flag to "1".

    Non-admins get HTTP 403.
    """
    if _require_admin(request):
        await credential_store.set("system:paused", "1", description="Kill switch")
        return {"status": "paused"}
    raise HTTPException(status_code=403, detail="Admin only")
|
|
|
|
|
|
@app.post("/api/resume")
async def resume_agent(request: Request):
    """Admin-only: release the kill switch by deleting ``system:paused``.

    Non-admins get HTTP 403.
    """
    if _require_admin(request):
        await credential_store.delete("system:paused")
        return {"status": "running"}
    raise HTTPException(status_code=403, detail="Admin only")
|
|
|
|
|
|
@app.get("/api/status")
async def agent_status():
    """Report the kill-switch state and the pending tool confirmations.

    NOTE(review): ``list_pending()`` is called without any user filter and this
    route has no admin check — confirm every authenticated user may see it.
    """
    paused = await credential_store.get("system:paused") == "1"
    pending = confirmation_manager.list_pending()
    return {"paused": paused, "pending_confirmations": pending}
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe; ``/health`` is in the auth middleware's exempt paths."""
    body = {"status": "ok"}
    return body
|
|
|
|
|
|
# ── WebSocket ─────────────────────────────────────────────────────────────────
|
|
|
|
def _message_text(content) -> str:
    """Flatten a stored message ``content`` value into plain text.

    A list is treated as content blocks: the ``text`` of every dict block with
    ``type == "text"`` is joined with single spaces. Any other truthy value is
    passed through ``str()``; falsy values give ``""``.
    """
    if isinstance(content, list):
        return " ".join(
            block.get("text") or ""
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        )
    return str(content) if content else ""


def _build_restore_turns(messages) -> list[dict]:
    """Reduce stored conversation messages to the text turns sent in a ``restore`` event.

    Only ``user`` and ``assistant`` messages are kept. Non-text content blocks,
    other roles, non-dict entries and turns whose text is blank are dropped.
    """
    turns: list[dict] = []
    for message in messages:
        if not isinstance(message, dict):
            continue
        role = message.get("role")
        if role not in ("user", "assistant"):
            continue
        text = _message_text(message.get("content", "")).strip()
        if text:
            turns.append({"role": role, "text": text})
    return turns


@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Chat WebSocket for one conversation (``session_id``).

    On connect the server sends, in order:
      * ``{"type": "models", ...}``: the models and capabilities this user may use;
      * ``{"type": "restore", ...}``: prior text turns, if the conversation
        exists and belongs to this user (or is unowned).

    Two coroutines then run concurrently:
      * ``receiver`` reads client frames. ``confirm`` answers a pending tool
        confirmation, ``clear`` wipes the agent history, and ``message`` is queued.
      * ``sender`` feeds queued messages to the agent and streams every event
        back through ``_event_to_dict``.

    NOTE(review): ``_agent.run``, ``clear_history`` and
    ``confirmation_manager.respond`` are keyed by ``session_id``; only the
    history restore checks ownership. Confirm session ids cannot be reused
    across users.
    """
    await websocket.accept()
    _ws_user = getattr(websocket.state, "current_user", None)
    # Fail closed: without a resolved user, never grant admin-tier models/tools.
    _ws_is_admin = _ws_user.is_admin if _ws_user else False
    _ws_user_id = _ws_user.id if _ws_user else None

    # Send available models immediately on connect (filtered per user's access tier)
    from .providers.models import get_available_models, get_capability_map
    try:
        _models, _default = await get_available_models(user_id=_ws_user_id, is_admin=_ws_is_admin)
        _caps = await get_capability_map(user_id=_ws_user_id, is_admin=_ws_is_admin)
        await websocket.send_json({
            "type": "models",
            "models": _models,
            "default": _default,
            "capabilities": _caps,
        })
    except WebSocketDisconnect:
        return

    # Discover per-user MCP tools (3-E) — discovered once per connection
    _user_mcp_tools: list = []
    if _ws_user_id:
        try:
            from .mcp_client.manager import discover_user_mcp_tools
            _user_mcp_tools = await discover_user_mcp_tools(_ws_user_id)
        except Exception as _e:
            logger.warning("Failed to discover user MCP tools: %s", _e)

    # If this session has existing history (reopened chat), send it to the client
    try:
        from .database import get_pool as _get_pool
        _pool = await _get_pool()
        # Only restore if this session belongs to the current user (or is unowned)
        _conv = await _pool.fetchrow(
            "SELECT messages, title, model FROM conversations WHERE id = $1 AND (user_id = $2 OR user_id IS NULL)",
            session_id, _ws_user_id,
        )
        if _conv and _conv["messages"]:
            _msgs = _conv["messages"]
            if isinstance(_msgs, str):
                # The column may come back as JSON text rather than decoded data.
                _msgs = json.loads(_msgs)
            _restore_turns = _build_restore_turns(_msgs)
            if _restore_turns:
                await websocket.send_json({
                    "type": "restore",
                    "session_id": session_id,
                    "title": _conv["title"] or "",
                    "model": _conv["model"] or "",
                    "messages": _restore_turns,
                })
    except Exception as _e:
        logger.warning("Failed to send restore event for session %s: %s", session_id, _e)

    # Queue for incoming user messages (so receiver and agent run concurrently)
    msg_queue: asyncio.Queue[dict] = asyncio.Queue()

    async def receiver():
        """Receive messages from client. Confirmations handled immediately."""
        try:
            async for raw in websocket.iter_json():
                if not isinstance(raw, dict):
                    continue  # valid JSON but not an object: nothing to dispatch on
                kind = raw.get("type")
                if kind == "confirm":
                    confirmation_manager.respond(session_id, raw.get("approved", False))
                elif kind == "message":
                    await msg_queue.put(raw)
                elif kind == "clear":
                    if _agent:
                        _agent.clear_history(session_id)
        except WebSocketDisconnect:
            pass
        finally:
            # Starlette's iter_json() ends quietly on a normal disconnect
            # instead of raising, so the sender must always be woken here.
            # Otherwise it would wait on the queue forever.
            msg_queue.put_nowait({"type": "_disconnect"})

    async def sender():
        """Process queued messages through the agent, stream events back."""
        while True:
            raw = await msg_queue.get()
            if raw.get("type") == "_disconnect":
                break

            content = str(raw.get("content") or "").strip()
            attachments = raw.get("attachments") or None  # list of {media_type, data}
            if not content and not attachments:
                continue

            if _agent is None:
                await websocket.send_json({"type": "error", "message": "Agent not ready."})
                continue

            model = raw.get("model") or None

            try:
                # Non-admins get every registered tool except "bash".
                chat_allowed_tools: list[str] | None = None
                if not _ws_is_admin and _registry is not None:
                    all_names = [t.name for t in _registry.all_tools()]
                    chat_allowed_tools = [t for t in all_names if t != "bash"]
                stream = await _agent.run(
                    message=content,
                    session_id=session_id,
                    model=model,
                    allowed_tools=chat_allowed_tools,
                    user_id=_ws_user_id,
                    extra_tools=_user_mcp_tools or None,
                    attachments=attachments,
                )
                async for event in stream:
                    await websocket.send_json(_event_to_dict(event))
            except Exception as e:
                await websocket.send_json({"type": "error", "message": str(e)})

    try:
        await asyncio.gather(receiver(), sender())
    except WebSocketDisconnect:
        pass
|
|
|
|
|
|
def _event_to_dict(event: AgentEvent) -> dict:
    """Serialise an agent event into the JSON message sent to the chat client.

    Event classes are tried with ``isinstance`` in a fixed order; anything
    unrecognised becomes ``{"type": "unknown"}``.
    """
    serializers = (
        (TextEvent, lambda e: {"type": "text", "content": e.content}),
        (ToolStartEvent, lambda e: {
            "type": "tool_start",
            "call_id": e.call_id,
            "tool_name": e.tool_name,
            "arguments": e.arguments,
        }),
        (ToolDoneEvent, lambda e: {
            "type": "tool_done",
            "call_id": e.call_id,
            "tool_name": e.tool_name,
            "success": e.success,
            "result": e.result_summary,
            "confirmed": e.confirmed,
        }),
        (ConfirmationRequiredEvent, lambda e: {
            "type": "confirmation_required",
            "call_id": e.call_id,
            "tool_name": e.tool_name,
            "arguments": e.arguments,
            "description": e.description,
        }),
        (DoneEvent, lambda e: {
            "type": "done",
            "tool_calls_made": e.tool_calls_made,
            "usage": {"input": e.usage.input_tokens, "output": e.usage.output_tokens},
        }),
        (ImageEvent, lambda e: {"type": "image", "data_urls": e.data_urls}),
        (ErrorEvent, lambda e: {"type": "error", "message": e.message}),
    )
    for event_cls, serialize in serializers:
        if isinstance(event, event_cls):
            return serialize(event)
    return {"type": "unknown"}
|