""" main.py — FastAPI application entry point. Provides: - HTML pages: /, /agents, /audit, /settings, /login, /setup, /admin/users - WebSocket: /ws/{session_id} (streaming agent responses) - REST API: /api/* """ from __future__ import annotations import asyncio import hashlib import json import logging import uuid from contextlib import asynccontextmanager from pathlib import Path # Configure logging before anything else imports logging logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s %(name)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) # Make CalDAV tool logs visible at DEBUG level so every step is traceable logging.getLogger("server.tools.caldav_tool").setLevel(logging.DEBUG) from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from .agent.agent import Agent, AgentEvent, ConfirmationRequiredEvent, DoneEvent, ErrorEvent, ImageEvent, TextEvent, ToolDoneEvent, ToolStartEvent from .agent.confirmation import confirmation_manager from .agents.runner import agent_runner from .agents.tasks import cleanup_stale_runs from .auth import SYNTHETIC_API_ADMIN, CurrentUser, create_session_cookie, decode_session_cookie from .brain.database import close_brain_db, init_brain_db from .config import settings from .context_vars import current_user as _current_user_var from .database import close_db, credential_store, init_db from .inbox.listener import inbox_listener from .mcp import create_mcp_app, _session_manager from .telegram.listener import telegram_listener from .tools import build_registry from .users import assign_existing_data_to_admin, create_user, get_user_by_username, user_count from .web.routes import router as api_router BASE_DIR = Path(__file__).parent templates = Jinja2Templates(directory=str(BASE_DIR / "web" / "templates")) templates.env.globals["agent_name"] = settings.agent_name async def _migrate_email_accounts() -> None: """ One-time startup migration: copy old inbox:* / inbox_* credentials into the new email_accounts table as 'trigger' type accounts. Idempotent — guarded by the 'email_accounts_migrated' credential flag. """ if await credential_store.get("email_accounts_migrated") == "1": return from .inbox.accounts import create_account from .inbox.triggers import list_triggers, update_trigger from .database import get_pool logger_main = logging.getLogger(__name__) logger_main.info("[migrate] Running email_accounts one-time migration…") # 1. Global trigger account (inbox:* keys in credential_store) global_host = await credential_store.get("inbox:imap_host") global_user = await credential_store.get("inbox:imap_username") global_pass = await credential_store.get("inbox:imap_password") global_account_id: str | None = None if global_host and global_user and global_pass: _smtp_port_raw = await credential_store.get("inbox:smtp_port") acct = await create_account( label="Global Inbox", account_type="trigger", imap_host=global_host, imap_port=int(await credential_store.get("inbox:imap_port") or "993"), imap_username=global_user, imap_password=global_pass, smtp_host=await credential_store.get("inbox:smtp_host"), smtp_port=int(_smtp_port_raw) if _smtp_port_raw else 465, smtp_username=await credential_store.get("inbox:smtp_username"), smtp_password=await credential_store.get("inbox:smtp_password"), user_id=None, ) global_account_id = str(acct["id"]) logger_main.info("[migrate] Created global trigger account: %s", global_account_id) # 2. Per-user trigger accounts (inbox_imap_host in user_settings) from .database import user_settings_store pool = await get_pool() user_rows = await pool.fetch( "SELECT DISTINCT user_id FROM user_settings WHERE key = 'inbox_imap_host'" ) user_account_map: dict[str, str] = {} # user_id → account_id for row in user_rows: uid = row["user_id"] host = await user_settings_store.get(uid, "inbox_imap_host") uname = await user_settings_store.get(uid, "inbox_imap_username") pw = await user_settings_store.get(uid, "inbox_imap_password") if not (host and uname and pw): continue _u_smtp_port = await user_settings_store.get(uid, "inbox_smtp_port") acct = await create_account( label="My Inbox", account_type="trigger", imap_host=host, imap_port=int(await user_settings_store.get(uid, "inbox_imap_port") or "993"), imap_username=uname, imap_password=pw, smtp_host=await user_settings_store.get(uid, "inbox_smtp_host"), smtp_port=int(_u_smtp_port) if _u_smtp_port else 465, smtp_username=await user_settings_store.get(uid, "inbox_smtp_username"), smtp_password=await user_settings_store.get(uid, "inbox_smtp_password"), user_id=uid, ) user_account_map[uid] = str(acct["id"]) logger_main.info("[migrate] Created trigger account for user %s: %s", uid, acct["id"]) # 3. Update existing email_triggers with account_id all_triggers = await list_triggers(user_id=None) for t in all_triggers: tid = t["id"] t_user_id = t.get("user_id") if t_user_id is None and global_account_id: await update_trigger(tid, account_id=global_account_id) elif t_user_id and t_user_id in user_account_map: await update_trigger(tid, account_id=user_account_map[t_user_id]) await credential_store.set("email_accounts_migrated", "1", "One-time email_accounts migration flag") logger_main.info("[migrate] email_accounts migration complete.") async def _refresh_brand_globals() -> None: """Update brand_name and logo_url Jinja2 globals from credential_store. Call at startup and after branding changes.""" brand_name = await credential_store.get("system:brand_name") or settings.agent_name logo_filename = await credential_store.get("system:brand_logo_filename") if logo_filename and (BASE_DIR / "web" / "static" / logo_filename).exists(): logo_url = f"/static/{logo_filename}" else: logo_url = "/static/logo.png" templates.env.globals["brand_name"] = brand_name templates.env.globals["logo_url"] = logo_url # Cache-busting version: hash of static file contents so it always changes when files change. # Avoids relying on git (not available in Docker container). def _compute_static_version() -> str: static_dir = BASE_DIR / "web" / "static" h = hashlib.md5() for f in sorted(static_dir.glob("*.js")) + sorted(static_dir.glob("*.css")): try: h.update(f.read_bytes()) except OSError: pass return h.hexdigest()[:10] _static_version = _compute_static_version() templates.env.globals["sv"] = _static_version # ── First-run flag ───────────────────────────────────────────────────────────── # Set in lifespan; cleared when /setup creates the first admin. _needs_setup: bool = False # ── Global agent (singleton — shares session history across requests) ───────── _registry = None _agent: Agent | None = None @asynccontextmanager async def lifespan(app: FastAPI): global _registry, _agent, _needs_setup, _trusted_proxy_ips await init_db() await _refresh_brand_globals() await _ensure_session_secret() _needs_setup = await user_count() == 0 global _trusted_proxy_ips _trusted_proxy_ips = await credential_store.get("system:trusted_proxy_ips") or "127.0.0.1" await cleanup_stale_runs() await init_brain_db() _registry = build_registry() from .mcp_client.manager import discover_and_register_mcp_tools await discover_and_register_mcp_tools(_registry) _agent = Agent(registry=_registry) print("[aide] Agent ready.") agent_runner.init(_agent) await agent_runner.start() await _migrate_email_accounts() await inbox_listener.start_all() telegram_listener.start() async with _session_manager.run(): yield inbox_listener.stop_all() telegram_listener.stop() agent_runner.shutdown() await close_brain_db() await close_db() app = FastAPI(title="oAI-Web API", version="0.5", lifespan=lifespan) # ── Custom OpenAPI schema — adds X-API-Key "Authorize" button in Swagger ────── def _custom_openapi(): if app.openapi_schema: return app.openapi_schema from fastapi.openapi.utils import get_openapi schema = get_openapi(title=app.title, version=app.version, routes=app.routes) schema.setdefault("components", {})["securitySchemes"] = { "ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-API-Key"} } schema["security"] = [{"ApiKeyAuth": []}] app.openapi_schema = schema return schema app.openapi = _custom_openapi # ── Proxy trust ─────────────────────────────────────────────────────────────── _trusted_proxy_ips: str = "127.0.0.1" class _ProxyTrustMiddleware: """Thin wrapper so trusted IPs are read from DB at startup, not hard-coded.""" def __init__(self, app): from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware self._app = app self._inner: ProxyHeadersMiddleware | None = None async def __call__(self, scope, receive, send): if self._inner is None: from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware self._inner = ProxyHeadersMiddleware(self._app, trusted_hosts=_trusted_proxy_ips) await self._inner(scope, receive, send) app.add_middleware(_ProxyTrustMiddleware) # ── Auth middleware ──────────────────────────────────────────────────────────── # # All routes require authentication. Two accepted paths: # 1. User session cookie (aide_user) — set on login, carries identity. # 2. API key (X-API-Key or Authorization: Bearer) — treated as synthetic admin. # # Exempt paths bypass auth entirely (login, setup, static, health, etc.). # First-run: if no users exist (_needs_setup), all non-exempt paths → /setup. import hashlib as _hashlib import hmac as _hmac import secrets as _secrets import time as _time _USER_COOKIE = "aide_user" _EXEMPT_PATHS = frozenset({"/login", "/login/mfa", "/logout", "/setup", "/health"}) _EXEMPT_PREFIXES = ("/static/", "/brain-mcp/", "/docs", "/redoc", "/openapi.json") _EXEMPT_API_PATHS = frozenset({"/api/settings/api-key"}) async def _ensure_session_secret() -> str: """Return the session HMAC secret, creating it in the credential store if absent.""" secret = await credential_store.get("system:session_secret") if not secret: secret = _secrets.token_hex(32) await credential_store.set("system:session_secret", secret, description="Web UI session token secret (auto-generated)") return secret def _parse_user_cookie(raw_cookie: str) -> str: """Extract aide_user value from raw Cookie header string.""" for part in raw_cookie.split(";"): part = part.strip() if part.startswith(_USER_COOKIE + "="): return part[len(_USER_COOKIE) + 1:] return "" async def _authenticate(headers: dict) -> CurrentUser | None: """Try user session cookie, then API key. Returns CurrentUser or None.""" # Try user session cookie raw_cookie = headers.get(b"cookie", b"").decode() cookie_val = _parse_user_cookie(raw_cookie) if cookie_val: secret = await credential_store.get("system:session_secret") if secret: user = decode_session_cookie(cookie_val, secret) if user: # Verify the user is still active in the DB — catches deactivated accounts # whose session cookies haven't expired yet. from .users import get_user_by_id as _get_user_by_id db_user = await _get_user_by_id(user.id) if db_user and db_user.get("is_active", True): return user # Try API key key_hash = await credential_store.get("system:api_key_hash") if key_hash: provided = ( headers.get(b"x-api-key", b"").decode() or headers.get(b"authorization", b"").decode().removeprefix("Bearer ").strip() ) if provided and _hashlib.sha256(provided.encode()).hexdigest() == key_hash: return SYNTHETIC_API_ADMIN return None class _AuthMiddleware: """Unified authentication middleware. Guards all routes except exempt paths.""" def __init__(self, app): self._app = app async def __call__(self, scope, receive, send): if scope["type"] not in ("http", "websocket"): await self._app(scope, receive, send) return path: str = scope.get("path", "") # Always let exempt paths through if path in _EXEMPT_PATHS or path in _EXEMPT_API_PATHS: await self._app(scope, receive, send) return if any(path.startswith(p) for p in _EXEMPT_PREFIXES): await self._app(scope, receive, send) return # First-run: redirect to /setup if _needs_setup: if scope["type"] == "websocket": await send({"type": "websocket.close", "code": 1008}) return response = RedirectResponse("/setup") await response(scope, receive, send) return # Authenticate headers = dict(scope.get("headers", [])) user = await _authenticate(headers) if user is None: if scope["type"] == "websocket": await send({"type": "websocket.close", "code": 1008}) return is_api = path.startswith("/api/") or path.startswith("/ws/") if is_api: response = JSONResponse({"error": "Authentication required"}, status_code=401) await response(scope, receive, send) return else: next_param = f"?next={path}" if path != "/" else "" response = RedirectResponse(f"/login{next_param}") await response(scope, receive, send) return # Set user on request state (for templates) and ContextVar (for tools/audit) scope.setdefault("state", {})["current_user"] = user token = _current_user_var.set(user) try: await self._app(scope, receive, send) finally: _current_user_var.reset(token) app.add_middleware(_AuthMiddleware) app.mount("/static", StaticFiles(directory=str(BASE_DIR / "web" / "static")), name="static") app.include_router(api_router, prefix="/api") # 2nd Brain MCP server — mounted at /brain-mcp (SSE transport) app.mount("/brain-mcp", create_mcp_app()) # ── Auth helpers ────────────────────────────────────────────────────────────── def _get_current_user(request: Request) -> CurrentUser | None: try: return request.state.current_user except AttributeError: return None def _require_admin(request: Request) -> bool: u = _get_current_user(request) return u is not None and u.is_admin # ── Login rate limiting ─────────────────────────────────────────────────────── from .login_limiter import is_locked as _login_is_locked from .login_limiter import record_failure as _record_login_failure from .login_limiter import clear_failures as _clear_login_failures def _get_client_ip(request: Request) -> str: """Best-effort client IP, respecting X-Forwarded-For if set.""" forwarded = request.headers.get("x-forwarded-for") if forwarded: return forwarded.split(",")[0].strip() return request.client.host if request.client else "unknown" # ── Login / Logout / Setup ──────────────────────────────────────────────────── @app.get("/login", response_class=HTMLResponse) async def login_get(request: Request, next: str = "/", error: str = ""): if _get_current_user(request): return RedirectResponse("/") _ERROR_MESSAGES = { "session_expired": "MFA session expired. Please sign in again.", "too_many_attempts": "Too many incorrect codes. Please sign in again.", } error_msg = _ERROR_MESSAGES.get(error) if error else None return templates.TemplateResponse("login.html", {"request": request, "next": next, "error": error_msg}) @app.post("/login") async def login_post(request: Request): import secrets as _secrets from datetime import datetime, timezone, timedelta from .auth import verify_password form = await request.form() username = str(form.get("username", "")).strip() password = str(form.get("password", "")) raw_next = str(form.get("next", "/")).strip() or "/" # Reject absolute URLs and protocol-relative URLs to prevent open redirect next_url = raw_next if (raw_next.startswith("/") and not raw_next.startswith("//")) else "/" ip = _get_client_ip(request) locked, lock_kind = _login_is_locked(ip) if locked: logger.warning("[login] blocked IP %s (%s)", ip, lock_kind) if lock_kind == "permanent": msg = "This IP has been permanently blocked due to repeated login failures. Contact an administrator." else: msg = "Too many failed attempts. Please try again in 30 minutes." return templates.TemplateResponse("login.html", { "request": request, "next": next_url, "error": msg, }, status_code=429) user = await get_user_by_username(username) if user and user["is_active"] and verify_password(password, user["password_hash"]): _clear_login_failures(ip) # MFA branch: TOTP required if user.get("totp_secret"): token = _secrets.token_hex(32) pool = await _db_pool() now = datetime.now(timezone.utc) expires = now + timedelta(minutes=5) await pool.execute( "INSERT INTO mfa_challenges (token, user_id, next_url, created_at, expires_at) " "VALUES ($1, $2, $3, $4, $5)", token, user["id"], next_url, now, expires, ) response = RedirectResponse(f"/login/mfa", status_code=303) response.set_cookie( "mfa_challenge", token, httponly=True, samesite="lax", max_age=300, path="/login/mfa", ) return response # No MFA — create session directly secret = await _ensure_session_secret() cookie_val = create_session_cookie(user, secret) response = RedirectResponse(next_url, status_code=303) response.set_cookie( _USER_COOKIE, cookie_val, httponly=True, samesite="lax", max_age=2592000, path="/", ) return response _record_login_failure(ip) return templates.TemplateResponse("login.html", { "request": request, "next": next_url, "error": "Invalid username or password.", }, status_code=401) async def _db_pool(): from .database import get_pool return await get_pool() @app.get("/login/mfa", response_class=HTMLResponse) async def login_mfa_get(request: Request): from datetime import datetime, timezone token = request.cookies.get("mfa_challenge", "") pool = await _db_pool() row = await pool.fetchrow( "SELECT user_id, next_url, expires_at FROM mfa_challenges WHERE token = $1", token ) if not row or row["expires_at"] < datetime.now(timezone.utc): return RedirectResponse("/login?error=session_expired", status_code=303) return templates.TemplateResponse("mfa.html", { "request": request, "next": row["next_url"], "error": None, }) @app.post("/login/mfa") async def login_mfa_post(request: Request): from datetime import datetime, timezone from .auth import verify_totp form = await request.form() code = str(form.get("code", "")).strip() token = request.cookies.get("mfa_challenge", "") pool = await _db_pool() row = await pool.fetchrow( "SELECT user_id, next_url, expires_at, attempts FROM mfa_challenges WHERE token = $1", token ) if not row or row["expires_at"] < datetime.now(timezone.utc): return RedirectResponse("/login?error=session_expired", status_code=303) next_url = row["next_url"] or "/" from .users import get_user_by_id user = await get_user_by_id(row["user_id"]) if not user or not user.get("totp_secret"): await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token) return RedirectResponse("/login", status_code=303) if not verify_totp(user["totp_secret"], code): new_attempts = row["attempts"] + 1 if new_attempts >= 5: await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token) return RedirectResponse("/login?error=too_many_attempts", status_code=303) await pool.execute( "UPDATE mfa_challenges SET attempts = $1 WHERE token = $2", new_attempts, token ) response = templates.TemplateResponse("mfa.html", { "request": request, "next": next_url, "error": "Invalid code. Try again.", }, status_code=401) return response # Success await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token) secret = await _ensure_session_secret() cookie_val = create_session_cookie(user, secret) response = RedirectResponse(next_url, status_code=303) response.set_cookie( _USER_COOKIE, cookie_val, httponly=True, samesite="lax", max_age=2592000, path="/", ) response.delete_cookie("mfa_challenge", path="/login/mfa") return response @app.get("/logout") async def logout(request: Request): # Render a tiny page that clears localStorage then redirects to /login. # This prevents the next user on the same browser from restoring the # previous user's conversation via the persisted current_session_id key. response = HTMLResponse(""" Logging out… """) response.delete_cookie(_USER_COOKIE, path="/") return response @app.get("/setup", response_class=HTMLResponse) async def setup_get(request: Request): if not _needs_setup: return RedirectResponse("/") return templates.TemplateResponse("setup.html", {"request": request, "errors": [], "username": ""}) @app.post("/setup") async def setup_post(request: Request): global _needs_setup if not _needs_setup: return RedirectResponse("/", status_code=303) form = await request.form() username = str(form.get("username", "")).strip() password = str(form.get("password", "")) confirm = str(form.get("confirm", "")) email = str(form.get("email", "")).strip().lower() errors = [] if not username: errors.append("Username is required.") if not email or "@" not in email: errors.append("A valid email address is required.") if len(password) < 8: errors.append("Password must be at least 8 characters.") if password != confirm: errors.append("Passwords do not match.") if errors: return templates.TemplateResponse("setup.html", { "request": request, "errors": errors, "username": username, "email": email, }, status_code=400) user = await create_user(username, password, role="admin", email=email) await assign_existing_data_to_admin(user["id"]) _needs_setup = False secret = await _ensure_session_secret() cookie_val = create_session_cookie(user, secret) response = RedirectResponse("/", status_code=303) response.set_cookie(_USER_COOKIE, cookie_val, httponly=True, samesite="lax", max_age=2592000, path="/") return response # ── HTML pages ──────────────────────────────────────────────────────────────── async def _ctx(request: Request, **extra): """Build template context with current_user and active theme CSS injected.""" from .web.themes import get_theme_css, DEFAULT_THEME from .database import user_settings_store user = _get_current_user(request) theme_css = "" needs_personality_setup = False if user: theme_id = await user_settings_store.get(user.id, "theme") or DEFAULT_THEME theme_css = get_theme_css(theme_id) if user.role != "admin": done = await user_settings_store.get(user.id, "personality_setup_done") needs_personality_setup = not done return { "request": request, "current_user": user, "theme_css": theme_css, "needs_personality_setup": needs_personality_setup, **extra, } @app.get("/", response_class=HTMLResponse) async def chat_page(request: Request, session: str = ""): # Allow reopening a saved conversation via /?session= session_id = session.strip() if session.strip() else str(uuid.uuid4()) return templates.TemplateResponse("chat.html", await _ctx(request, session_id=session_id)) @app.get("/chats", response_class=HTMLResponse) async def chats_page(request: Request): return templates.TemplateResponse("chats.html", await _ctx(request)) @app.get("/agents", response_class=HTMLResponse) async def agents_page(request: Request): return templates.TemplateResponse("agents.html", await _ctx(request)) @app.get("/agents/{agent_id}", response_class=HTMLResponse) async def agent_detail_page(request: Request, agent_id: str): return templates.TemplateResponse("agent_detail.html", await _ctx(request, agent_id=agent_id)) @app.get("/models", response_class=HTMLResponse) async def models_page(request: Request): return templates.TemplateResponse("models.html", await _ctx(request)) @app.get("/audit", response_class=HTMLResponse) async def audit_page(request: Request): return templates.TemplateResponse("audit.html", await _ctx(request)) @app.get("/help", response_class=HTMLResponse) async def help_page(request: Request): return templates.TemplateResponse("help.html", await _ctx(request)) @app.get("/files", response_class=HTMLResponse) async def files_page(request: Request): return templates.TemplateResponse("files.html", await _ctx(request)) @app.get("/settings", response_class=HTMLResponse) async def settings_page(request: Request): user = _get_current_user(request) if user is None: return RedirectResponse("/login?next=/settings") ctx = await _ctx(request) if user.is_admin: rows = await credential_store.list_keys() is_paused = await credential_store.get("system:paused") == "1" ctx.update(credential_keys=[r["key"] for r in rows], is_paused=is_paused) return templates.TemplateResponse("settings.html", ctx) @app.get("/admin/users", response_class=HTMLResponse) async def admin_users_page(request: Request): if not _require_admin(request): return RedirectResponse("/") return templates.TemplateResponse("admin_users.html", await _ctx(request)) # ── Kill switch ─────────────────────────────────────────────────────────────── @app.post("/api/pause") async def pause_agent(request: Request): if not _require_admin(request): raise HTTPException(status_code=403, detail="Admin only") await credential_store.set("system:paused", "1", description="Kill switch") return {"status": "paused"} @app.post("/api/resume") async def resume_agent(request: Request): if not _require_admin(request): raise HTTPException(status_code=403, detail="Admin only") await credential_store.delete("system:paused") return {"status": "running"} @app.get("/api/status") async def agent_status(): return { "paused": await credential_store.get("system:paused") == "1", "pending_confirmations": confirmation_manager.list_pending(), } @app.get("/health") async def health(): return {"status": "ok"} # ── WebSocket ───────────────────────────────────────────────────────────────── @app.websocket("/ws/{session_id}") async def websocket_endpoint(websocket: WebSocket, session_id: str): await websocket.accept() _ws_user = getattr(websocket.state, "current_user", None) _ws_is_admin = _ws_user.is_admin if _ws_user else True _ws_user_id = _ws_user.id if _ws_user else None # Send available models immediately on connect (filtered per user's access tier) from .providers.models import get_available_models, get_capability_map try: _models, _default = await get_available_models(user_id=_ws_user_id, is_admin=_ws_is_admin) _caps = await get_capability_map(user_id=_ws_user_id, is_admin=_ws_is_admin) await websocket.send_json({ "type": "models", "models": _models, "default": _default, "capabilities": _caps, }) except WebSocketDisconnect: return # Discover per-user MCP tools (3-E) — discovered once per connection _user_mcp_tools: list = [] if _ws_user_id: try: from .mcp_client.manager import discover_user_mcp_tools _user_mcp_tools = await discover_user_mcp_tools(_ws_user_id) except Exception as _e: logger.warning("Failed to discover user MCP tools: %s", _e) # If this session has existing history (reopened chat), send it to the client try: from .database import get_pool as _get_pool _pool = await _get_pool() # Only restore if this session belongs to the current user (or is unowned) _conv = await _pool.fetchrow( "SELECT messages, title, model FROM conversations WHERE id = $1 AND (user_id = $2 OR user_id IS NULL)", session_id, _ws_user_id, ) if _conv and _conv["messages"]: _msgs = _conv["messages"] if isinstance(_msgs, str): _msgs = json.loads(_msgs) # Build a simplified view: only user + assistant text turns _restore_turns = [] for _m in _msgs: _role = _m.get("role") if _role == "user": _content = _m.get("content", "") if isinstance(_content, list): _text = " ".join(b.get("text", "") for b in _content if b.get("type") == "text") else: _text = str(_content) if _text.strip(): _restore_turns.append({"role": "user", "text": _text.strip()}) elif _role == "assistant": _content = _m.get("content", "") if isinstance(_content, list): _text = " ".join(b.get("text", "") for b in _content if b.get("type") == "text") else: _text = str(_content) if _content else "" if _text.strip(): _restore_turns.append({"role": "assistant", "text": _text.strip()}) if _restore_turns: await websocket.send_json({ "type": "restore", "session_id": session_id, "title": _conv["title"] or "", "model": _conv["model"] or "", "messages": _restore_turns, }) except Exception as _e: logger.warning("Failed to send restore event for session %s: %s", session_id, _e) # Queue for incoming user messages (so receiver and agent run concurrently) msg_queue: asyncio.Queue[dict] = asyncio.Queue() async def receiver(): """Receive messages from client. Confirmations handled immediately.""" try: async for raw in websocket.iter_json(): if raw.get("type") == "confirm": confirmation_manager.respond(session_id, raw.get("approved", False)) elif raw.get("type") == "message": await msg_queue.put(raw) elif raw.get("type") == "clear": if _agent: _agent.clear_history(session_id) except WebSocketDisconnect: await msg_queue.put({"type": "_disconnect"}) async def sender(): """Process queued messages through the agent, stream events back.""" while True: raw = await msg_queue.get() if raw.get("type") == "_disconnect": break content = raw.get("content", "").strip() attachments = raw.get("attachments") or None # list of {media_type, data} if not content and not attachments: continue if _agent is None: await websocket.send_json({"type": "error", "message": "Agent not ready."}) continue model = raw.get("model") or None try: chat_allowed_tools: list[str] | None = None if not _ws_is_admin and _registry is not None: all_names = [t.name for t in _registry.all_tools()] chat_allowed_tools = [t for t in all_names if t != "bash"] stream = await _agent.run( message=content, session_id=session_id, model=model, allowed_tools=chat_allowed_tools, user_id=_ws_user_id, extra_tools=_user_mcp_tools or None, attachments=attachments, ) async for event in stream: payload = _event_to_dict(event) await websocket.send_json(payload) except Exception as e: await websocket.send_json({"type": "error", "message": str(e)}) try: await asyncio.gather(receiver(), sender()) except WebSocketDisconnect: pass def _event_to_dict(event: AgentEvent) -> dict: if isinstance(event, TextEvent): return {"type": "text", "content": event.content} if isinstance(event, ToolStartEvent): return {"type": "tool_start", "call_id": event.call_id, "tool_name": event.tool_name, "arguments": event.arguments} if isinstance(event, ToolDoneEvent): return {"type": "tool_done", "call_id": event.call_id, "tool_name": event.tool_name, "success": event.success, "result": event.result_summary, "confirmed": event.confirmed} if isinstance(event, ConfirmationRequiredEvent): return {"type": "confirmation_required", "call_id": event.call_id, "tool_name": event.tool_name, "arguments": event.arguments, "description": event.description} if isinstance(event, DoneEvent): return {"type": "done", "tool_calls_made": event.tool_calls_made, "usage": {"input": event.usage.input_tokens, "output": event.usage.output_tokens}} if isinstance(event, ImageEvent): return {"type": "image", "data_urls": event.data_urls} if isinstance(event, ErrorEvent): return {"type": "error", "message": event.message} return {"type": "unknown"}