Files
oai-web/server/main.py
2026-04-08 12:43:24 +02:00

899 lines
36 KiB
Python

"""
main.py — FastAPI application entry point.
Provides:
- HTML pages: /, /chats, /agents, /agents/{agent_id}, /models, /audit, /help, /files,
  /settings, /admin/users, /login, /login/mfa, /logout, /setup
- WebSocket: /ws/{session_id} (streaming agent responses)
- REST API: /api/*
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import logging
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
# Configure logging before anything else imports logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s %(name)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Make CalDAV tool logs visible at DEBUG level so every step is traceable
logging.getLogger("server.tools.caldav_tool").setLevel(logging.DEBUG)
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from .agent.agent import Agent, AgentEvent, ConfirmationRequiredEvent, DoneEvent, ErrorEvent, ImageEvent, TextEvent, ToolDoneEvent, ToolStartEvent
from .agent.confirmation import confirmation_manager
from .agents.runner import agent_runner
from .agents.tasks import cleanup_stale_runs
from .auth import SYNTHETIC_API_ADMIN, CurrentUser, create_session_cookie, decode_session_cookie
from .brain.database import close_brain_db, init_brain_db
from .config import settings
from .context_vars import current_user as _current_user_var
from .database import close_db, credential_store, init_db
from .inbox.listener import inbox_listener
from .mcp import create_mcp_app, _session_manager
from .telegram.listener import telegram_listener
from .tools import build_registry
from .users import assign_existing_data_to_admin, create_user, get_user_by_username, user_count
from .web.routes import router as api_router
BASE_DIR = Path(__file__).parent
templates = Jinja2Templates(directory=str(BASE_DIR / "web" / "templates"))
templates.env.globals["agent_name"] = settings.agent_name
async def _migrate_email_accounts() -> None:
    """
    Copy legacy inbox credentials into the email_accounts table (runs once).

    Legacy settings live under inbox:* keys in credential_store (global) and
    under inbox_* keys in user_settings (per user). Every complete
    host/username/password set becomes a 'trigger' account, and each existing
    email trigger is then pointed at the account matching its owner.

    The 'email_accounts_migrated' credential flag is checked first and written
    last, so subsequent startups return immediately.
    """
    already_done = await credential_store.get("email_accounts_migrated")
    if already_done == "1":
        return

    from .inbox.accounts import create_account
    from .inbox.triggers import list_triggers, update_trigger
    from .database import get_pool

    log = logging.getLogger(__name__)
    log.info("[migrate] Running email_accounts one-time migration…")

    # 1. Global trigger account (inbox:* keys in credential_store)
    g_host = await credential_store.get("inbox:imap_host")
    g_user = await credential_store.get("inbox:imap_username")
    g_pass = await credential_store.get("inbox:imap_password")
    global_account_id: str | None = None
    if g_host and g_user and g_pass:
        g_smtp_port = await credential_store.get("inbox:smtp_port")
        # Dict displays evaluate in order, so the store is read in the same
        # sequence as the keyword arguments are listed.
        global_fields = {
            "label": "Global Inbox",
            "account_type": "trigger",
            "imap_host": g_host,
            "imap_port": int(await credential_store.get("inbox:imap_port") or "993"),
            "imap_username": g_user,
            "imap_password": g_pass,
            "smtp_host": await credential_store.get("inbox:smtp_host"),
            "smtp_port": int(g_smtp_port) if g_smtp_port else 465,
            "smtp_username": await credential_store.get("inbox:smtp_username"),
            "smtp_password": await credential_store.get("inbox:smtp_password"),
            "user_id": None,
        }
        global_account = await create_account(**global_fields)
        global_account_id = str(global_account["id"])
        log.info("[migrate] Created global trigger account: %s", global_account_id)

    # 2. Per-user trigger accounts (inbox_imap_host in user_settings)
    from .database import user_settings_store

    pool = await get_pool()
    owners = await pool.fetch(
        "SELECT DISTINCT user_id FROM user_settings WHERE key = 'inbox_imap_host'"
    )
    user_account_map: dict[str, str] = {}  # user_id → account_id
    for owner_row in owners:
        uid = owner_row["user_id"]
        u_host = await user_settings_store.get(uid, "inbox_imap_host")
        u_name = await user_settings_store.get(uid, "inbox_imap_username")
        u_pass = await user_settings_store.get(uid, "inbox_imap_password")
        if not u_host or not u_name or not u_pass:
            # Incomplete IMAP settings: nothing usable to migrate for this user.
            continue
        u_smtp_port = await user_settings_store.get(uid, "inbox_smtp_port")
        user_fields = {
            "label": "My Inbox",
            "account_type": "trigger",
            "imap_host": u_host,
            "imap_port": int(await user_settings_store.get(uid, "inbox_imap_port") or "993"),
            "imap_username": u_name,
            "imap_password": u_pass,
            "smtp_host": await user_settings_store.get(uid, "inbox_smtp_host"),
            "smtp_port": int(u_smtp_port) if u_smtp_port else 465,
            "smtp_username": await user_settings_store.get(uid, "inbox_smtp_username"),
            "smtp_password": await user_settings_store.get(uid, "inbox_smtp_password"),
            "user_id": uid,
        }
        user_account = await create_account(**user_fields)
        user_account_map[uid] = str(user_account["id"])
        log.info("[migrate] Created trigger account for user %s: %s", uid, user_account["id"])

    # 3. Update existing email_triggers with account_id
    for trigger in await list_triggers(user_id=None):
        trigger_id = trigger["id"]
        owner = trigger.get("user_id")
        if owner is None and global_account_id:
            await update_trigger(trigger_id, account_id=global_account_id)
        elif owner and owner in user_account_map:
            await update_trigger(trigger_id, account_id=user_account_map[owner])

    await credential_store.set(
        "email_accounts_migrated", "1", "One-time email_accounts migration flag"
    )
    log.info("[migrate] email_accounts migration complete.")
async def _refresh_brand_globals() -> None:
    """
    Refresh the brand_name and logo_url Jinja2 globals from credential_store.

    brand_name falls back to settings.agent_name. logo_url points at the
    configured logo file only if that file exists under web/static; otherwise
    it is "/static/logo.png". Call at startup and after branding changes.
    """
    name = (await credential_store.get("system:brand_name")) or settings.agent_name
    logo_file = await credential_store.get("system:brand_logo_filename")

    logo = "/static/logo.png"
    if logo_file and (BASE_DIR / "web" / "static" / logo_file).exists():
        logo = f"/static/{logo_file}"

    templates.env.globals.update(brand_name=name, logo_url=logo)
# Cache-busting version: hash of static file contents so it always changes when files change.
# Avoids relying on git (not available in Docker container).
def _compute_static_version() -> str:
    """
    Return a 10-character cache-busting token for the static assets.

    The token is the start of an MD5 hex digest over the bytes of every
    top-level *.js file (sorted) followed by every top-level *.css file
    (sorted) in web/static. Files that cannot be read are left out.
    """
    assets_dir = BASE_DIR / "web" / "static"
    scripts = sorted(assets_dir.glob("*.js"))
    styles = sorted(assets_dir.glob("*.css"))

    digest = hashlib.md5()
    for asset in scripts + styles:
        try:
            payload = asset.read_bytes()
        except OSError:
            # Unreadable file: skip it rather than fail at import time.
            continue
        digest.update(payload)
    return digest.hexdigest()[:10]
_static_version = _compute_static_version()
templates.env.globals["sv"] = _static_version
# ── First-run flag ─────────────────────────────────────────────────────────────
# Set in lifespan; cleared when /setup creates the first admin.
_needs_setup: bool = False
# ── Global agent (singleton — shares session history across requests) ─────────
_registry = None
_agent: Agent | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan: start every subsystem, yield, then shut them down.

    Startup, in order: database, branding globals, session secret, first-run
    flag, trusted proxy IPs, stale-run cleanup, brain database, tool registry
    (including MCP tool discovery), the shared Agent, the agent runner, the
    one-time email account migration, and the inbox and Telegram listeners.

    Shutdown runs in a ``finally`` so the listeners, runner and database
    handles are released even when an exception or cancellation reaches the
    generator at ``yield``.
    """
    global _registry, _agent, _needs_setup, _trusted_proxy_ips

    await init_db()
    await _refresh_brand_globals()
    await _ensure_session_secret()
    _needs_setup = await user_count() == 0
    _trusted_proxy_ips = (
        await credential_store.get("system:trusted_proxy_ips") or "127.0.0.1"
    )
    await cleanup_stale_runs()
    await init_brain_db()

    _registry = build_registry()
    from .mcp_client.manager import discover_and_register_mcp_tools
    await discover_and_register_mcp_tools(_registry)
    _agent = Agent(registry=_registry)
    print("[aide] Agent ready.")

    agent_runner.init(_agent)
    await agent_runner.start()
    await _migrate_email_accounts()
    await inbox_listener.start_all()
    telegram_listener.start()

    try:
        async with _session_manager.run():
            yield
    finally:
        inbox_listener.stop_all()
        telegram_listener.stop()
        agent_runner.shutdown()
        await close_brain_db()
        await close_db()
app = FastAPI(title="oAI-Web API", version="0.5", lifespan=lifespan)
# ── Custom OpenAPI schema — adds X-API-Key "Authorize" button in Swagger ──────
def _custom_openapi():
    """
    Build the OpenAPI schema once and cache it on the app.

    The generated schema gets an "ApiKeyAuth" security scheme (X-API-Key
    header) and a top-level security requirement referencing it.
    """
    cached = app.openapi_schema
    if cached:
        return cached

    from fastapi.openapi.utils import get_openapi

    schema = get_openapi(title=app.title, version=app.version, routes=app.routes)
    components = schema.setdefault("components", {})
    components["securitySchemes"] = {
        "ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
    }
    schema["security"] = [{"ApiKeyAuth": []}]

    app.openapi_schema = schema
    return schema
app.openapi = _custom_openapi
# ── Proxy trust ───────────────────────────────────────────────────────────────
_trusted_proxy_ips: str = "127.0.0.1"
class _ProxyTrustMiddleware:
    """
    Lazy wrapper around uvicorn's ProxyHeadersMiddleware.

    The trusted proxy list lives in the module-level ``_trusted_proxy_ips``,
    which ``lifespan`` overwrites with the stored setting during startup. The
    inner middleware is therefore built on the first HTTP/WebSocket request
    rather than at construction time.
    """

    def __init__(self, app):
        self._app = app
        # ProxyHeadersMiddleware instance, created on first http/websocket call.
        self._inner = None

    async def __call__(self, scope, receive, send):
        # The lifespan scope passes through this middleware *before* the
        # startup code has loaded the trusted proxy setting. Building the inner
        # middleware here for that scope would freeze the default value, so
        # non-request scopes are forwarded untouched.
        if scope["type"] not in ("http", "websocket"):
            await self._app(scope, receive, send)
            return

        if self._inner is None:
            from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

            self._inner = ProxyHeadersMiddleware(
                self._app, trusted_hosts=_trusted_proxy_ips
            )
        await self._inner(scope, receive, send)
app.add_middleware(_ProxyTrustMiddleware)
# ── Auth middleware ────────────────────────────────────────────────────────────
#
# All routes require authentication. Two accepted paths:
# 1. User session cookie (aide_user) — set on login, carries identity.
# 2. API key (X-API-Key or Authorization: Bearer) — treated as synthetic admin.
#
# Exempt paths bypass auth entirely (login, setup, static, health, etc.).
# First-run: if no users exist (_needs_setup), all non-exempt paths → /setup.
import hashlib as _hashlib
import hmac as _hmac
import secrets as _secrets
import time as _time
_USER_COOKIE = "aide_user"
_EXEMPT_PATHS = frozenset({"/login", "/login/mfa", "/logout", "/setup", "/health"})
_EXEMPT_PREFIXES = ("/static/", "/brain-mcp/", "/docs", "/redoc", "/openapi.json")
_EXEMPT_API_PATHS = frozenset({"/api/settings/api-key"})
async def _ensure_session_secret() -> str:
    """
    Return the session HMAC secret from the credential store.

    When no secret is stored yet, a new 32-byte hex token is generated, saved
    under "system:session_secret", and returned.
    """
    existing = await credential_store.get("system:session_secret")
    if existing:
        return existing

    generated = _secrets.token_hex(32)
    await credential_store.set(
        "system:session_secret",
        generated,
        description="Web UI session token secret (auto-generated)",
    )
    return generated
def _parse_user_cookie(raw_cookie: str) -> str:
    """
    Pull the session cookie value out of a raw Cookie header string.

    Returns everything after "<_USER_COOKIE>=" for the first matching
    ";"-separated entry, or "" when the cookie is not present.
    """
    prefix = _USER_COOKIE + "="
    entries = (chunk.strip() for chunk in raw_cookie.split(";"))
    return next(
        (entry[len(prefix):] for entry in entries if entry.startswith(prefix)),
        "",
    )
async def _authenticate(headers: dict) -> CurrentUser | None:
    """
    Resolve the caller's identity from raw ASGI headers.

    Two credentials are tried in order:

    1. The user session cookie. It is decoded with the stored session secret,
       and the user must still exist and be active in the database.
    2. An API key from ``X-API-Key`` or ``Authorization: Bearer …``. Its
       SHA-256 hex digest is compared with the stored hash; a match yields
       ``SYNTHETIC_API_ADMIN``.

    Args:
        headers: Mapping of lower-cased header names to values, both bytes.

    Returns:
        The authenticated ``CurrentUser``, or ``None`` if neither credential
        is valid.
    """
    # Try user session cookie
    raw_cookie = headers.get(b"cookie", b"").decode()
    cookie_val = _parse_user_cookie(raw_cookie)
    if cookie_val:
        secret = await credential_store.get("system:session_secret")
        if secret:
            user = decode_session_cookie(cookie_val, secret)
            if user:
                # Verify the user is still active in the DB — catches deactivated
                # accounts whose session cookies haven't expired yet.
                from .users import get_user_by_id as _get_user_by_id

                db_user = await _get_user_by_id(user.id)
                if db_user and db_user.get("is_active", True):
                    return user

    # Try API key
    key_hash = await credential_store.get("system:api_key_hash")
    if key_hash:
        provided = (
            headers.get(b"x-api-key", b"").decode()
            or headers.get(b"authorization", b"").decode().removeprefix("Bearer ").strip()
        )
        if provided:
            digest = _hashlib.sha256(provided.encode()).hexdigest()
            # Constant-time comparison so response timing does not reveal how
            # much of the digest matched. Both sides are encoded to bytes
            # because compare_digest rejects non-ASCII str arguments.
            if _hmac.compare_digest(digest.encode(), str(key_hash).encode()):
                return SYNTHETIC_API_ADMIN
    return None
class _AuthMiddleware:
    """
    Unified authentication middleware. Guards all routes except exempt paths.

    For HTTP and WebSocket scopes:

    * exempt paths and prefixes pass straight through;
    * while ``_needs_setup`` is set, HTTP requests are redirected to /setup
      and WebSockets are closed with code 1008;
    * unauthenticated WebSockets are closed with 1008, unauthenticated /api/
      and /ws/ HTTP requests get a 401 JSON body, and any other page is
      redirected to /login with the original path in ``next``;
    * authenticated requests get the user stored on ``scope["state"]`` and in
      the ``current_user`` ContextVar for the duration of the call.
    """

    def __init__(self, app):
        self._app = app

    async def __call__(self, scope, receive, send):
        scope_type = scope["type"]
        if scope_type not in ("http", "websocket"):
            await self._app(scope, receive, send)
            return

        path: str = scope.get("path", "")

        # Always let exempt paths through
        if (
            path in _EXEMPT_PATHS
            or path in _EXEMPT_API_PATHS
            or any(path.startswith(prefix) for prefix in _EXEMPT_PREFIXES)
        ):
            await self._app(scope, receive, send)
            return

        # First-run: redirect to /setup
        if _needs_setup:
            if scope_type == "websocket":
                await send({"type": "websocket.close", "code": 1008})
                return
            response = RedirectResponse("/setup")
            await response(scope, receive, send)
            return

        # Authenticate
        headers = dict(scope.get("headers", []))
        user = await _authenticate(headers)
        if user is None:
            if scope_type == "websocket":
                await send({"type": "websocket.close", "code": 1008})
                return
            if path.startswith(("/api/", "/ws/")):
                response = JSONResponse(
                    {"error": "Authentication required"}, status_code=401
                )
            else:
                from urllib.parse import quote

                # Percent-encode the path so characters such as "&", "#" or
                # "?" cannot break out of the ``next`` query parameter.
                next_param = f"?next={quote(path, safe='/')}" if path != "/" else ""
                response = RedirectResponse(f"/login{next_param}")
            await response(scope, receive, send)
            return

        # Set user on request state (for templates) and ContextVar (for tools/audit)
        scope.setdefault("state", {})["current_user"] = user
        token = _current_user_var.set(user)
        try:
            await self._app(scope, receive, send)
        finally:
            _current_user_var.reset(token)
app.add_middleware(_AuthMiddleware)
app.mount("/static", StaticFiles(directory=str(BASE_DIR / "web" / "static")), name="static")
app.include_router(api_router, prefix="/api")
# 2nd Brain MCP server — mounted at /brain-mcp (SSE transport)
app.mount("/brain-mcp", create_mcp_app())
# ── Auth helpers ──────────────────────────────────────────────────────────────
def _get_current_user(request: Request) -> CurrentUser | None:
    """Return ``request.state.current_user``, or None when it is not set."""
    try:
        user = request.state.current_user
    except AttributeError:
        user = None
    return user
def _require_admin(request: Request) -> bool:
    """Return True only when the request carries a user whose is_admin is set."""
    user = _get_current_user(request)
    if user is None:
        return False
    return user.is_admin
# ── Login rate limiting ───────────────────────────────────────────────────────
from .login_limiter import is_locked as _login_is_locked
from .login_limiter import record_failure as _record_login_failure
from .login_limiter import clear_failures as _clear_login_failures
def _get_client_ip(request: Request) -> str:
    """
    Best-effort client IP for login rate limiting.

    The first entry of X-Forwarded-For wins when the header is non-empty;
    otherwise the socket peer address is used, or "unknown" if there is none.
    """
    # NOTE(review): the header is read directly from the request, so a client
    # that can reach the app without a proxy can choose its own value here.
    # Confirm deployments always sit behind a proxy that overwrites it.
    header_value = request.headers.get("x-forwarded-for")
    if header_value:
        first_hop, _, _ = header_value.partition(",")
        return first_hop.strip()
    if request.client:
        return request.client.host
    return "unknown"
# ── Login / Logout / Setup ────────────────────────────────────────────────────
@app.get("/login", response_class=HTMLResponse)
async def login_get(request: Request, next: str = "/", error: str = ""):
    """
    Render the login form.

    A request that already carries a user is redirected to "/". The ``error``
    query parameter is an error code; known codes are mapped to a message,
    anything else renders without an error.
    """
    if _get_current_user(request):
        return RedirectResponse("/")

    known_errors = {
        "session_expired": "MFA session expired. Please sign in again.",
        "too_many_attempts": "Too many incorrect codes. Please sign in again.",
    }
    # dict.get returns None for unknown codes and for the empty default.
    message = known_errors.get(error)
    context = {"request": request, "next": next, "error": message}
    return templates.TemplateResponse("login.html", context)
@app.post("/login")
async def login_post(request: Request):
    """
    Handle a login form submission.

    Flow:
      1. Sanitize the ``next`` target so it can only be a local path.
      2. Refuse immediately (HTTP 429) when the client IP is locked out.
      3. Verify username/password against an active user.
         - With a TOTP secret: store a 5-minute MFA challenge, set the
           ``mfa_challenge`` cookie and redirect to /login/mfa.
         - Without: set the session cookie and redirect to ``next``.
      4. Otherwise record the failure and re-render the form with HTTP 401.
    """
    from datetime import datetime, timezone, timedelta
    from .auth import verify_password

    # The module defines no module-level logger, so obtain one here.
    log = logging.getLogger(__name__)

    form = await request.form()
    username = str(form.get("username", "")).strip()
    password = str(form.get("password", ""))
    raw_next = str(form.get("next", "/")).strip() or "/"
    # Reject absolute URLs and protocol-relative URLs to prevent open redirect.
    # "/\host" is rejected too: some browsers treat the backslash as a slash.
    is_local_path = raw_next.startswith("/") and not raw_next.startswith(("//", "/\\"))
    next_url = raw_next if is_local_path else "/"

    ip = _get_client_ip(request)
    locked, lock_kind = _login_is_locked(ip)
    if locked:
        log.warning("[login] blocked IP %s (%s)", ip, lock_kind)
        if lock_kind == "permanent":
            msg = "This IP has been permanently blocked due to repeated login failures. Contact an administrator."
        else:
            msg = "Too many failed attempts. Please try again in 30 minutes."
        return templates.TemplateResponse("login.html", {
            "request": request,
            "next": next_url,
            "error": msg,
        }, status_code=429)

    user = await get_user_by_username(username)
    if user and user["is_active"] and verify_password(password, user["password_hash"]):
        _clear_login_failures(ip)

        # MFA branch: TOTP required
        if user.get("totp_secret"):
            token = _secrets.token_hex(32)
            pool = await _db_pool()
            now = datetime.now(timezone.utc)
            expires = now + timedelta(minutes=5)
            await pool.execute(
                "INSERT INTO mfa_challenges (token, user_id, next_url, created_at, expires_at) "
                "VALUES ($1, $2, $3, $4, $5)",
                token, user["id"], next_url, now, expires,
            )
            response = RedirectResponse("/login/mfa", status_code=303)
            response.set_cookie(
                "mfa_challenge", token,
                httponly=True, samesite="lax", max_age=300, path="/login/mfa",
            )
            return response

        # No MFA — create session directly
        secret = await _ensure_session_secret()
        cookie_val = create_session_cookie(user, secret)
        response = RedirectResponse(next_url, status_code=303)
        response.set_cookie(
            _USER_COOKIE, cookie_val,
            httponly=True, samesite="lax", max_age=2592000, path="/",
        )
        return response

    _record_login_failure(ip)
    return templates.TemplateResponse("login.html", {
        "request": request,
        "next": next_url,
        "error": "Invalid username or password.",
    }, status_code=401)
async def _db_pool():
    """Return the shared database pool; the import is deferred to call time."""
    from .database import get_pool as _acquire_pool

    pool = await _acquire_pool()
    return pool
@app.get("/login/mfa", response_class=HTMLResponse)
async def login_mfa_get(request: Request):
    """
    Render the MFA code form for the challenge named by the mfa_challenge cookie.

    A missing or expired challenge redirects to /login with
    error=session_expired.
    """
    from datetime import datetime, timezone

    challenge_token = request.cookies.get("mfa_challenge", "")
    pool = await _db_pool()
    challenge = await pool.fetchrow(
        "SELECT user_id, next_url, expires_at FROM mfa_challenges WHERE token = $1",
        challenge_token,
    )

    expired = (not challenge) or challenge["expires_at"] < datetime.now(timezone.utc)
    if expired:
        return RedirectResponse("/login?error=session_expired", status_code=303)

    context = {
        "request": request,
        "next": challenge["next_url"],
        "error": None,
    }
    return templates.TemplateResponse("mfa.html", context)
@app.post("/login/mfa")
async def login_mfa_post(request: Request):
    """
    Verify a submitted TOTP code against the pending MFA challenge.

    Outcomes:
      * missing or expired challenge → redirect to /login?error=session_expired
      * user missing or without a TOTP secret → challenge deleted, redirect to /login
      * wrong code → attempt counter incremented; at 5 attempts the challenge
        is deleted and the user is sent back to /login, otherwise the form is
        re-rendered with HTTP 401
      * correct code → challenge deleted, session cookie set, redirect to the
        stored next_url
    """
    from datetime import datetime, timezone
    from .auth import verify_totp
    from .users import get_user_by_id

    form = await request.form()
    code = str(form.get("code", "")).strip()
    token = request.cookies.get("mfa_challenge", "")
    pool = await _db_pool()
    row = await pool.fetchrow(
        "SELECT user_id, next_url, expires_at, attempts FROM mfa_challenges WHERE token = $1", token
    )
    if not row or row["expires_at"] < datetime.now(timezone.utc):
        return RedirectResponse("/login?error=session_expired", status_code=303)

    next_url = row["next_url"] or "/"
    user = await get_user_by_id(row["user_id"])
    if not user or not user.get("totp_secret"):
        await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token)
        return RedirectResponse("/login", status_code=303)

    if not verify_totp(user["totp_secret"], code):
        # Increment inside the database so concurrent wrong guesses cannot all
        # read the same counter value and slip past the attempt limit.
        counted = await pool.fetchrow(
            "UPDATE mfa_challenges SET attempts = COALESCE(attempts, 0) + 1 "
            "WHERE token = $1 RETURNING attempts",
            token,
        )
        if counted is None:
            # The challenge disappeared between the SELECT and the UPDATE.
            return RedirectResponse("/login?error=session_expired", status_code=303)
        if counted["attempts"] >= 5:
            await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token)
            return RedirectResponse("/login?error=too_many_attempts", status_code=303)
        return templates.TemplateResponse("mfa.html", {
            "request": request,
            "next": next_url,
            "error": "Invalid code. Try again.",
        }, status_code=401)

    # Success
    await pool.execute("DELETE FROM mfa_challenges WHERE token = $1", token)
    secret = await _ensure_session_secret()
    cookie_val = create_session_cookie(user, secret)
    response = RedirectResponse(next_url, status_code=303)
    response.set_cookie(
        _USER_COOKIE, cookie_val,
        httponly=True, samesite="lax", max_age=2592000, path="/",
    )
    response.delete_cookie("mfa_challenge", path="/login/mfa")
    return response
@app.get("/logout")
async def logout(request: Request):
    """
    Clear the session cookie and serve a page that wipes client-side state.

    The page removes the persisted session id and model preference from
    localStorage before redirecting to /login, so the next user of the same
    browser cannot restore the previous user's conversation.
    """
    logout_page = """<!doctype html>
<html><head><title>Logging out…</title></head><body>
<script>
localStorage.removeItem("current_session_id");
localStorage.removeItem("preferred-model");
window.location.replace("/login");
</script>
</body></html>"""
    response = HTMLResponse(logout_page)
    response.delete_cookie(_USER_COOKIE, path="/")
    return response
@app.get("/setup", response_class=HTMLResponse)
async def setup_get(request: Request):
    """Render the first-run setup form, or redirect to "/" once setup is done."""
    if _needs_setup:
        context = {"request": request, "errors": [], "username": ""}
        return templates.TemplateResponse("setup.html", context)
    return RedirectResponse("/")
@app.post("/setup")
async def setup_post(request: Request):
    """
    Create the first admin account from the setup form.

    Invalid input re-renders the form with HTTP 400 and the list of problems.
    On success the admin is created, existing data is assigned to it, the
    first-run flag is cleared, and the new admin is logged in.
    """
    global _needs_setup
    if not _needs_setup:
        return RedirectResponse("/", status_code=303)

    form = await request.form()
    username = str(form.get("username", "")).strip()
    password = str(form.get("password", ""))
    confirm = str(form.get("confirm", ""))
    email = str(form.get("email", "")).strip().lower()

    # (failed?, message) pairs, kept in the order the messages are shown.
    checks = (
        (not username, "Username is required."),
        (not email or "@" not in email, "A valid email address is required."),
        (len(password) < 8, "Password must be at least 8 characters."),
        (password != confirm, "Passwords do not match."),
    )
    problems = [message for failed, message in checks if failed]
    if problems:
        context = {
            "request": request,
            "errors": problems,
            "username": username,
            "email": email,
        }
        return templates.TemplateResponse("setup.html", context, status_code=400)

    admin = await create_user(username, password, role="admin", email=email)
    await assign_existing_data_to_admin(admin["id"])
    _needs_setup = False

    secret = await _ensure_session_secret()
    session_cookie = create_session_cookie(admin, secret)
    response = RedirectResponse("/", status_code=303)
    response.set_cookie(
        _USER_COOKIE,
        session_cookie,
        httponly=True,
        samesite="lax",
        max_age=2592000,
        path="/",
    )
    return response
# ── HTML pages ────────────────────────────────────────────────────────────────
async def _ctx(request: Request, **extra):
    """
    Build the template context shared by the HTML pages.

    Always contains request, current_user, theme_css and
    needs_personality_setup. For a logged-in user the theme CSS comes from
    their "theme" setting (or DEFAULT_THEME), and non-admin users are flagged
    until their "personality_setup_done" setting is truthy. Keyword arguments
    are merged in last and override the base keys.
    """
    from .web.themes import get_theme_css, DEFAULT_THEME
    from .database import user_settings_store

    user = _get_current_user(request)
    context = {
        "request": request,
        "current_user": user,
        "theme_css": "",
        "needs_personality_setup": False,
    }
    if user:
        chosen_theme = await user_settings_store.get(user.id, "theme") or DEFAULT_THEME
        context["theme_css"] = get_theme_css(chosen_theme)
        if user.role != "admin":
            setup_done = await user_settings_store.get(user.id, "personality_setup_done")
            context["needs_personality_setup"] = not setup_done

    context.update(extra)
    return context
@app.get("/", response_class=HTMLResponse)
async def chat_page(request: Request, session: str = ""):
    """
    Render the chat page.

    /?session=<id> reopens a saved conversation; a blank or missing value
    gets a fresh UUID as the session id.
    """
    requested = session.strip()
    session_id = requested or str(uuid.uuid4())
    context = await _ctx(request, session_id=session_id)
    return templates.TemplateResponse("chat.html", context)
@app.get("/chats", response_class=HTMLResponse)
async def chats_page(request: Request):
    """Render chats.html with the shared template context."""
    context = await _ctx(request)
    return templates.TemplateResponse("chats.html", context)


@app.get("/agents", response_class=HTMLResponse)
async def agents_page(request: Request):
    """Render agents.html with the shared template context."""
    context = await _ctx(request)
    return templates.TemplateResponse("agents.html", context)


@app.get("/agents/{agent_id}", response_class=HTMLResponse)
async def agent_detail_page(request: Request, agent_id: str):
    """Render agent_detail.html, passing the agent_id path parameter to the template."""
    context = await _ctx(request, agent_id=agent_id)
    return templates.TemplateResponse("agent_detail.html", context)


@app.get("/models", response_class=HTMLResponse)
async def models_page(request: Request):
    """Render models.html with the shared template context."""
    context = await _ctx(request)
    return templates.TemplateResponse("models.html", context)


@app.get("/audit", response_class=HTMLResponse)
async def audit_page(request: Request):
    """Render audit.html with the shared template context."""
    context = await _ctx(request)
    return templates.TemplateResponse("audit.html", context)


@app.get("/help", response_class=HTMLResponse)
async def help_page(request: Request):
    """Render help.html with the shared template context."""
    context = await _ctx(request)
    return templates.TemplateResponse("help.html", context)


@app.get("/files", response_class=HTMLResponse)
async def files_page(request: Request):
    """Render files.html with the shared template context."""
    context = await _ctx(request)
    return templates.TemplateResponse("files.html", context)
@app.get("/settings", response_class=HTMLResponse)
async def settings_page(request: Request):
    """
    Render the settings page.

    Requests without a user are redirected to the login page. Admins
    additionally receive the list of credential keys and the current
    kill-switch state in the template context.
    """
    user = _get_current_user(request)
    if user is None:
        return RedirectResponse("/login?next=/settings")

    context = await _ctx(request)
    if user.is_admin:
        key_rows = await credential_store.list_keys()
        paused_flag = await credential_store.get("system:paused")
        context["credential_keys"] = [entry["key"] for entry in key_rows]
        context["is_paused"] = paused_flag == "1"
    return templates.TemplateResponse("settings.html", context)
@app.get("/admin/users", response_class=HTMLResponse)
async def admin_users_page(request: Request):
    """Render the user administration page; non-admins are redirected to "/"."""
    if _require_admin(request):
        context = await _ctx(request)
        return templates.TemplateResponse("admin_users.html", context)
    return RedirectResponse("/")
# ── Kill switch ───────────────────────────────────────────────────────────────
@app.post("/api/pause")
async def pause_agent(request: Request):
    """
    Engage the kill switch by storing "1" under system:paused.

    Raises:
        HTTPException: 403 when the caller is not an admin.
    """
    # HTTPException is not among the module's top-level fastapi imports, so it
    # is imported here; without it the 403 path fails with a NameError.
    from fastapi import HTTPException

    if not _require_admin(request):
        raise HTTPException(status_code=403, detail="Admin only")
    await credential_store.set("system:paused", "1", description="Kill switch")
    return {"status": "paused"}
@app.post("/api/resume")
async def resume_agent(request: Request):
    """
    Release the kill switch by deleting the system:paused credential.

    Raises:
        HTTPException: 403 when the caller is not an admin.
    """
    # HTTPException is not among the module's top-level fastapi imports, so it
    # is imported here; without it the 403 path fails with a NameError.
    from fastapi import HTTPException

    if not _require_admin(request):
        raise HTTPException(status_code=403, detail="Admin only")
    await credential_store.delete("system:paused")
    return {"status": "running"}
@app.get("/api/status")
async def agent_status():
    """Report whether the kill switch is set and which confirmations are pending."""
    paused_flag = await credential_store.get("system:paused")
    pending = confirmation_manager.list_pending()
    return {"paused": paused_flag == "1", "pending_confirmations": pending}
@app.get("/health")
async def health():
    """Liveness probe; this path is exempt from authentication."""
    payload = {"status": "ok"}
    return payload
# ── WebSocket ─────────────────────────────────────────────────────────────────
def _extract_message_text(content) -> str:
    """Flatten stored message content (plain value or list of blocks) to text."""
    if isinstance(content, list):
        return " ".join(
            block.get("text", "") for block in content if block.get("type") == "text"
        )
    # Empty or missing content becomes "" instead of the literal text "None".
    return str(content) if content else ""


def _build_restore_turns(messages) -> list[dict]:
    """Reduce stored messages to the non-empty user/assistant text turns."""
    turns = []
    for message in messages:
        role = message.get("role")
        if role not in ("user", "assistant"):
            continue
        text = _extract_message_text(message.get("content", "")).strip()
        if text:
            turns.append({"role": role, "text": text})
    return turns


@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Chat WebSocket: streams agent events for one session.

    On connect the endpoint sends, in order:
      1. a "models" message with the models and capabilities for this user;
      2. a "restore" message with prior user/assistant turns, if the session
         has stored history owned by this user (or by nobody).

    Afterwards two coroutines run concurrently: ``receiver`` reads client
    frames ("confirm", "message", "clear") and ``sender`` feeds queued
    messages to the agent and streams its events back. Confirmations are
    handled by the receiver directly so they are not stuck behind a running
    agent turn.
    """
    # The module defines no module-level logger, so obtain one here.
    log = logging.getLogger(__name__)

    await websocket.accept()
    _ws_user = getattr(websocket.state, "current_user", None)
    # NOTE(review): a connection without a user is treated as admin. The auth
    # middleware rejects unauthenticated sockets before this point, but confirm
    # this fail-open default is intended.
    _ws_is_admin = _ws_user.is_admin if _ws_user else True
    _ws_user_id = _ws_user.id if _ws_user else None

    # Send available models immediately on connect (filtered per user's access tier)
    from .providers.models import get_available_models, get_capability_map
    try:
        _models, _default = await get_available_models(user_id=_ws_user_id, is_admin=_ws_is_admin)
        _caps = await get_capability_map(user_id=_ws_user_id, is_admin=_ws_is_admin)
        await websocket.send_json({
            "type": "models",
            "models": _models,
            "default": _default,
            "capabilities": _caps,
        })
    except WebSocketDisconnect:
        return

    # Discover per-user MCP tools (3-E) — discovered once per connection
    _user_mcp_tools: list = []
    if _ws_user_id:
        try:
            from .mcp_client.manager import discover_user_mcp_tools
            _user_mcp_tools = await discover_user_mcp_tools(_ws_user_id)
        except Exception as _e:
            log.warning("Failed to discover user MCP tools: %s", _e)

    # If this session has existing history (reopened chat), send it to the client
    try:
        from .database import get_pool as _get_pool
        _pool = await _get_pool()
        # Only restore if this session belongs to the current user (or is unowned)
        _conv = await _pool.fetchrow(
            "SELECT messages, title, model FROM conversations WHERE id = $1 AND (user_id = $2 OR user_id IS NULL)",
            session_id, _ws_user_id,
        )
        if _conv and _conv["messages"]:
            _msgs = _conv["messages"]
            if isinstance(_msgs, str):
                _msgs = json.loads(_msgs)
            # Build a simplified view: only user + assistant text turns
            _restore_turns = _build_restore_turns(_msgs)
            if _restore_turns:
                await websocket.send_json({
                    "type": "restore",
                    "session_id": session_id,
                    "title": _conv["title"] or "",
                    "model": _conv["model"] or "",
                    "messages": _restore_turns,
                })
    except Exception as _e:
        # Restoring history is best-effort; the chat still works without it.
        log.warning("Failed to send restore event for session %s: %s", session_id, _e)

    # Queue for incoming user messages (so receiver and agent run concurrently)
    msg_queue: asyncio.Queue[dict] = asyncio.Queue()

    async def receiver():
        """Receive messages from client. Confirmations handled immediately."""
        try:
            async for raw in websocket.iter_json():
                kind = raw.get("type")
                if kind == "confirm":
                    confirmation_manager.respond(session_id, raw.get("approved", False))
                elif kind == "message":
                    await msg_queue.put(raw)
                elif kind == "clear":
                    if _agent:
                        _agent.clear_history(session_id)
        except WebSocketDisconnect:
            # Wake the sender so it can leave its loop.
            await msg_queue.put({"type": "_disconnect"})

    async def sender():
        """Process queued messages through the agent, stream events back."""
        while True:
            raw = await msg_queue.get()
            if raw.get("type") == "_disconnect":
                break
            # Tolerate a null/absent content field instead of crashing on .strip().
            content = str(raw.get("content") or "").strip()
            attachments = raw.get("attachments") or None  # list of {media_type, data}
            if not content and not attachments:
                continue
            if _agent is None:
                await websocket.send_json({"type": "error", "message": "Agent not ready."})
                continue
            model = raw.get("model") or None
            try:
                chat_allowed_tools: list[str] | None = None
                if not _ws_is_admin and _registry is not None:
                    # Non-admin users get every registered tool except "bash".
                    all_names = [t.name for t in _registry.all_tools()]
                    chat_allowed_tools = [t for t in all_names if t != "bash"]
                stream = await _agent.run(
                    message=content,
                    session_id=session_id,
                    model=model,
                    allowed_tools=chat_allowed_tools,
                    user_id=_ws_user_id,
                    extra_tools=_user_mcp_tools or None,
                    attachments=attachments,
                )
                async for event in stream:
                    payload = _event_to_dict(event)
                    await websocket.send_json(payload)
            except Exception as e:
                await websocket.send_json({"type": "error", "message": str(e)})

    try:
        await asyncio.gather(receiver(), sender())
    except WebSocketDisconnect:
        pass
def _event_to_dict(event: AgentEvent) -> dict:
    """
    Convert an agent event into the JSON payload sent over the WebSocket.

    Each known event class maps to a dict with a "type" discriminator plus the
    event's fields; an unrecognised event yields {"type": "unknown"}.
    """
    if isinstance(event, TextEvent):
        payload = {"type": "text", "content": event.content}
    elif isinstance(event, ToolStartEvent):
        payload = {
            "type": "tool_start",
            "call_id": event.call_id,
            "tool_name": event.tool_name,
            "arguments": event.arguments,
        }
    elif isinstance(event, ToolDoneEvent):
        payload = {
            "type": "tool_done",
            "call_id": event.call_id,
            "tool_name": event.tool_name,
            "success": event.success,
            "result": event.result_summary,
            "confirmed": event.confirmed,
        }
    elif isinstance(event, ConfirmationRequiredEvent):
        payload = {
            "type": "confirmation_required",
            "call_id": event.call_id,
            "tool_name": event.tool_name,
            "arguments": event.arguments,
            "description": event.description,
        }
    elif isinstance(event, DoneEvent):
        token_usage = {
            "input": event.usage.input_tokens,
            "output": event.usage.output_tokens,
        }
        payload = {
            "type": "done",
            "tool_calls_made": event.tool_calls_made,
            "usage": token_usage,
        }
    elif isinstance(event, ImageEvent):
        payload = {"type": "image", "data_urls": event.data_urls}
    elif isinstance(event, ErrorEvent):
        payload = {"type": "error", "message": event.message}
    else:
        payload = {"type": "unknown"}
    return payload