""" audit.py — Append-only audit log. Every tool call is recorded here BEFORE the result is returned to the agent. All methods are async — callers must await them. """ from __future__ import annotations import json from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Any from .database import _jsonify, get_pool @dataclass class AuditEntry: id: int timestamp: str session_id: str | None tool_name: str arguments: dict | None result_summary: str | None confirmed: bool task_id: str | None user_id: str | None = None class AuditLog: """Write audit records and query them for the UI.""" async def record( self, tool_name: str, arguments: dict[str, Any] | None = None, result_summary: str | None = None, confirmed: bool = False, session_id: str | None = None, task_id: str | None = None, user_id: str | None = None, ) -> int: """Write a tool-call audit record. Returns the new row ID.""" if user_id is None: from .context_vars import current_user as _cu u = _cu.get() if u: user_id = u.id now = datetime.now(timezone.utc).isoformat() # Sanitise arguments for JSONB (convert non-serializable values to strings) args = _jsonify(arguments) if arguments is not None else None pool = await get_pool() row_id: int = await pool.fetchval( """ INSERT INTO audit_log (timestamp, session_id, tool_name, arguments, result_summary, confirmed, task_id, user_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id """, now, session_id, tool_name, args, result_summary, confirmed, task_id, user_id, ) return row_id async def query( self, start: str | None = None, end: str | None = None, tool_name: str | None = None, session_id: str | None = None, task_id: str | None = None, confirmed_only: bool = False, user_id: str | None = None, limit: int = 50, offset: int = 0, ) -> list[AuditEntry]: """Query the audit log. All filters are optional.""" clauses: list[str] = [] params: list[Any] = [] n = 1 if start: sv = start if ("+" in start or start.upper().endswith("Z")) else start + "Z" clauses.append(f"timestamp::timestamptz >= ${n}::timestamptz"); params.append(sv); n += 1 if end: ev = end if ("+" in end or end.upper().endswith("Z")) else end + "Z" clauses.append(f"timestamp::timestamptz <= ${n}::timestamptz"); params.append(ev); n += 1 if tool_name: clauses.append(f"tool_name ILIKE ${n}"); params.append(f"%{tool_name}%"); n += 1 if session_id: clauses.append(f"session_id = ${n}"); params.append(session_id); n += 1 if task_id: clauses.append(f"task_id = ${n}"); params.append(task_id); n += 1 if confirmed_only: clauses.append("confirmed = TRUE") if user_id: clauses.append(f"user_id = ${n}"); params.append(user_id); n += 1 where = ("WHERE " + " AND ".join(clauses)) if clauses else "" params.extend([limit, offset]) pool = await get_pool() rows = await pool.fetch( f""" SELECT id, timestamp, session_id, tool_name, arguments, result_summary, confirmed, task_id, user_id FROM audit_log {where} ORDER BY timestamp::timestamptz DESC LIMIT ${n} OFFSET ${n + 1} """, *params, ) return [ AuditEntry( id=r["id"], timestamp=r["timestamp"], session_id=r["session_id"], tool_name=r["tool_name"], arguments=r["arguments"], # asyncpg deserialises JSONB automatically result_summary=r["result_summary"], confirmed=r["confirmed"], task_id=r["task_id"], user_id=r["user_id"], ) for r in rows ] async def count( self, start: str | None = None, end: str | None = None, tool_name: str | None = None, task_id: str | None = None, session_id: str | None = None, confirmed_only: bool = False, user_id: str | None = None, ) -> int: clauses: list[str] = [] params: list[Any] = [] n = 1 if start: sv = start if ("+" in start or start.upper().endswith("Z")) else start + "Z" clauses.append(f"timestamp::timestamptz >= ${n}::timestamptz"); params.append(sv); n += 1 if end: ev = end if ("+" in end or end.upper().endswith("Z")) else end + "Z" clauses.append(f"timestamp::timestamptz <= ${n}::timestamptz"); params.append(ev); n += 1 if tool_name: clauses.append(f"tool_name ILIKE ${n}"); params.append(f"%{tool_name}%"); n += 1 if task_id: clauses.append(f"task_id = ${n}"); params.append(task_id); n += 1 if session_id: clauses.append(f"session_id = ${n}"); params.append(session_id); n += 1 if confirmed_only: clauses.append("confirmed = TRUE") if user_id: clauses.append(f"user_id = ${n}"); params.append(user_id); n += 1 where = ("WHERE " + " AND ".join(clauses)) if clauses else "" pool = await get_pool() return await pool.fetchval( f"SELECT COUNT(*) FROM audit_log {where}", *params ) or 0 async def purge(self, older_than_days: int | None = None) -> int: """Delete audit records. older_than_days=None deletes all. Returns row count.""" pool = await get_pool() if older_than_days is not None: cutoff = ( datetime.now(timezone.utc) - timedelta(days=older_than_days) ).isoformat() status = await pool.execute( "DELETE FROM audit_log WHERE timestamp < $1", cutoff ) else: status = await pool.execute("DELETE FROM audit_log") from .database import _rowcount return _rowcount(status) # Module-level singleton audit_log = AuditLog()