Files
oai-web/server/audit.py
2026-04-08 12:43:24 +02:00

183 lines
6.3 KiB
Python

"""
audit.py — Append-only audit log.
Every tool call is recorded here BEFORE the result is returned to the agent.
All methods are async — callers must await them.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any
from .database import _jsonify, get_pool
@dataclass
class AuditEntry:
    """A single audit-log row in the shape handed back by ``AuditLog.query``.

    The fields mirror the columns selected from the ``audit_log`` table.
    Field order is significant: instances may be built positionally.
    """

    # Row identifier (the value ``RETURNING id`` yields on insert).
    id: int
    # Time of the call as a string; ``AuditLog.record`` writes an ISO-8601
    # UTC timestamp here.
    timestamp: str
    # Session the call belonged to, if one was recorded.
    session_id: str | None
    # Name of the tool that was invoked.
    tool_name: str
    # Tool arguments as stored in the ``arguments`` column, or None.
    arguments: dict | None
    # Short textual summary of the tool result, if one was recorded.
    result_summary: str | None
    # Value of the ``confirmed`` column for this call.
    confirmed: bool
    # Task the call belonged to, if one was recorded.
    task_id: str | None
    # User the call is attributed to; optional, so older positional
    # constructions without it keep working.
    user_id: str | None = None
class AuditLog:
    """Write audit records and query them for the UI.

    The class keeps no per-instance state; every method obtains the
    connection pool through ``get_pool()`` and must be awaited.
    """

    @staticmethod
    def _with_utc_default(value: str) -> str:
        """Return *value* unchanged if it names a UTC offset, else append ``Z``.

        A timestamp counts as zone-aware when it ends in ``Z``/``z``,
        contains a ``+``, or has a ``-`` in its time portion (the part after
        the first ``T`` or space).  Checking only the time portion keeps the
        hyphens of the date (``2026-04-08``) from being mistaken for an
        offset, while still recognising negative offsets such as ``-05:00``.
        """
        upper = value.upper()
        if upper.endswith("Z") or "+" in value:
            return value
        # Everything after the first date/time separator; "" for date-only input.
        time_part = upper.replace(" ", "T").partition("T")[2]
        if "-" in time_part:
            return value
        return value + "Z"

    @staticmethod
    def _build_filters(
        start: str | None = None,
        end: str | None = None,
        tool_name: str | None = None,
        session_id: str | None = None,
        task_id: str | None = None,
        confirmed_only: bool = False,
        user_id: str | None = None,
    ) -> tuple[str, list[Any]]:
        """Build the shared WHERE clause for ``query`` and ``count``.

        Falsy filters (``None`` or empty string) are skipped.

        Returns:
            ``(where, params)`` where *where* is either ``""`` or a string
            starting with ``WHERE``, and *params* holds the values for its
            ``$1..$k`` placeholders in order.  The next free placeholder is
            therefore ``len(params) + 1``.
        """
        clauses: list[str] = []
        params: list[Any] = []

        def bind(template: str, value: Any) -> None:
            # Placeholders are numbered by position in ``params``.
            params.append(value)
            clauses.append(template.format(n=len(params)))

        if start:
            bind(
                "timestamp::timestamptz >= ${n}::timestamptz",
                AuditLog._with_utc_default(start),
            )
        if end:
            bind(
                "timestamp::timestamptz <= ${n}::timestamptz",
                AuditLog._with_utc_default(end),
            )
        if tool_name:
            # Case-insensitive substring match.
            bind("tool_name ILIKE ${n}", f"%{tool_name}%")
        if session_id:
            bind("session_id = ${n}", session_id)
        if task_id:
            bind("task_id = ${n}", task_id)
        if confirmed_only:
            clauses.append("confirmed = TRUE")
        if user_id:
            bind("user_id = ${n}", user_id)

        where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
        return where, params

    async def record(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        result_summary: str | None = None,
        confirmed: bool = False,
        session_id: str | None = None,
        task_id: str | None = None,
        user_id: str | None = None,
    ) -> int:
        """Write a tool-call audit record. Returns the new row ID.

        When *user_id* is not given it is taken from
        ``context_vars.current_user`` if that yields a truthy value.
        The row is stamped with the current UTC time in ISO-8601 form.
        """
        if user_id is None:
            # NOTE(review): imported lazily, presumably to avoid a circular
            # import — confirm before moving to module level.
            from .context_vars import current_user as _cu

            user = _cu.get()
            if user:
                user_id = user.id

        now = datetime.now(timezone.utc).isoformat()
        # Sanitise arguments for JSONB (convert non-serializable values to strings)
        args = _jsonify(arguments) if arguments is not None else None

        pool = await get_pool()
        row_id: int = await pool.fetchval(
            """
            INSERT INTO audit_log
                (timestamp, session_id, tool_name, arguments, result_summary, confirmed, task_id, user_id)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
            RETURNING id
            """,
            now, session_id, tool_name, args, result_summary, confirmed, task_id, user_id,
        )
        return row_id

    async def query(
        self,
        start: str | None = None,
        end: str | None = None,
        tool_name: str | None = None,
        session_id: str | None = None,
        task_id: str | None = None,
        confirmed_only: bool = False,
        user_id: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[AuditEntry]:
        """Query the audit log, newest first. All filters are optional.

        *start* / *end* are inclusive timestamp bounds; values without a UTC
        offset are treated as UTC.  *tool_name* is a case-insensitive
        substring match; the remaining filters are exact matches.
        *limit* and *offset* page through the result.
        """
        where, params = self._build_filters(
            start=start,
            end=end,
            tool_name=tool_name,
            session_id=session_id,
            task_id=task_id,
            confirmed_only=confirmed_only,
            user_id=user_id,
        )
        # LIMIT/OFFSET take the two placeholders after the filter values.
        limit_pos = len(params) + 1
        params.extend([limit, offset])

        pool = await get_pool()
        rows = await pool.fetch(
            f"""
            SELECT id, timestamp, session_id, tool_name, arguments,
                   result_summary, confirmed, task_id, user_id
            FROM audit_log
            {where}
            ORDER BY timestamp::timestamptz DESC
            LIMIT ${limit_pos} OFFSET ${limit_pos + 1}
            """,
            *params,
        )
        return [
            AuditEntry(
                id=r["id"],
                timestamp=r["timestamp"],
                session_id=r["session_id"],
                tool_name=r["tool_name"],
                arguments=r["arguments"],  # asyncpg deserialises JSONB automatically
                result_summary=r["result_summary"],
                confirmed=r["confirmed"],
                task_id=r["task_id"],
                user_id=r["user_id"],
            )
            for r in rows
        ]

    async def count(
        self,
        start: str | None = None,
        end: str | None = None,
        tool_name: str | None = None,
        task_id: str | None = None,
        session_id: str | None = None,
        confirmed_only: bool = False,
        user_id: str | None = None,
    ) -> int:
        """Return how many rows match the given filters.

        Filters have the same meaning as in ``query``; both methods share
        one WHERE-clause builder so their results stay consistent.
        """
        where, params = self._build_filters(
            start=start,
            end=end,
            tool_name=tool_name,
            session_id=session_id,
            task_id=task_id,
            confirmed_only=confirmed_only,
            user_id=user_id,
        )
        pool = await get_pool()
        return await pool.fetchval(
            f"SELECT COUNT(*) FROM audit_log {where}", *params
        ) or 0

    async def purge(self, older_than_days: int | None = None) -> int:
        """Delete audit records. older_than_days=None deletes all. Returns row count.

        Raises:
            ValueError: if *older_than_days* is negative.  A negative value
                would put the cutoff in the future and silently wipe the
                whole log, which is almost certainly a caller bug.
        """
        if older_than_days is not None and older_than_days < 0:
            raise ValueError("older_than_days must be >= 0 or None")

        pool = await get_pool()
        if older_than_days is not None:
            cutoff = (
                datetime.now(timezone.utc) - timedelta(days=older_than_days)
            ).isoformat()
            # NOTE(review): unlike query()/count(), this compares the raw
            # column value with the ISO string, without a timestamptz cast.
            status = await pool.execute(
                "DELETE FROM audit_log WHERE timestamp < $1", cutoff
            )
        else:
            status = await pool.execute("DELETE FROM audit_log")

        from .database import _rowcount

        return _rowcount(status)
# Module-level singleton: one shared AuditLog instance created at import time.
# AuditLog keeps no per-instance state, so a single object is sufficient.
audit_log = AuditLog()