183 lines
6.3 KiB
Python
183 lines
6.3 KiB
Python
"""
|
|
audit.py — Append-only audit log.
|
|
|
|
Every tool call is recorded here BEFORE the result is returned to the agent.
|
|
All methods are async — callers must await them.
|
|
"""
|
|
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any

from .database import _jsonify, get_pool
|
|
|
|
|
|
@dataclass
class AuditEntry:
    """A single ``audit_log`` row, as built by ``AuditLog.query``.

    Field order matches the column order selected by the query, so entries can
    also be built positionally.
    """

    id: int  # row id (the value ``AuditLog.record`` returns)
    timestamp: str  # ISO-8601 text written by ``AuditLog.record``
    session_id: str | None
    tool_name: str
    arguments: dict[str, Any] | None  # JSON-sanitised call arguments, if any
    result_summary: str | None
    confirmed: bool
    task_id: str | None
    user_id: str | None = None  # optional so older call sites can omit it
|
|
|
|
|
|
class AuditLog:
    """Write audit records and query them for the UI.

    ``record`` stores timestamps as ISO-8601 strings. Every filter, sort and
    purge casts them to ``timestamptz`` so rows are compared as instants
    rather than as text.
    """

    # A negative UTC offset that follows a time of day, e.g. "10:00:00-05:00",
    # "10:00-0530" or "10:00:00.5 -05". Requiring the time in front keeps the
    # date's own "-" separators from being mistaken for an offset.
    _NEGATIVE_OFFSET_RE = re.compile(
        r"\d{1,2}:\d{2}(?::\d{2}(?:[.,]\d+)?)?\s*-\d{2}(?::?\d{2})?$"
    )

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _with_default_utc(cls, value: str) -> str:
        """Return *value* with a ``Z`` suffix unless it already carries a UTC offset.

        Bounds given without an offset are treated as UTC. A ``+``, a trailing
        ``Z``/``z``, or a ``-HH[:MM]`` after the time of day all count as an
        explicit offset. Without the last check, ``...T10:00:00-05:00`` would
        get a second zone appended (``-05:00Z``).
        """
        if "+" in value or value.upper().endswith("Z"):
            return value
        if cls._NEGATIVE_OFFSET_RE.search(value):
            return value
        return value + "Z"

    @classmethod
    def _build_filters(
        cls,
        *,
        start: str | None = None,
        end: str | None = None,
        tool_name: str | None = None,
        session_id: str | None = None,
        task_id: str | None = None,
        confirmed_only: bool = False,
        user_id: str | None = None,
    ) -> tuple[str, list[Any]]:
        """Build the shared ``WHERE`` clause and its positional parameters.

        Falsy filter values (``None``, ``""``) are ignored. Placeholders are
        numbered ``$1..$k`` in the order the parameters are returned, so callers
        can append further parameters starting at ``$k+1``.
        """
        clauses: list[str] = []
        params: list[Any] = []

        def bind(value: Any) -> str:
            # Record the parameter and return its positional placeholder.
            params.append(value)
            return f"${len(params)}"

        # Bounds are bound as text and parsed by Postgres. A bare
        # `$n::timestamptz` gives the parameter a timestamptz type, which
        # asyncpg's default codec accepts only as a datetime object, not a str.
        if start:
            ph = bind(cls._with_default_utc(start))
            clauses.append(f"timestamp::timestamptz >= {ph}::text::timestamptz")
        if end:
            ph = bind(cls._with_default_utc(end))
            clauses.append(f"timestamp::timestamptz <= {ph}::text::timestamptz")
        if tool_name:
            pattern = f"%{tool_name}%"
            clauses.append(f"tool_name ILIKE {bind(pattern)}")
        if session_id:
            clauses.append(f"session_id = {bind(session_id)}")
        if task_id:
            clauses.append(f"task_id = {bind(task_id)}")
        if confirmed_only:
            clauses.append("confirmed = TRUE")
        if user_id:
            clauses.append(f"user_id = {bind(user_id)}")

        where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
        return where, params

    @staticmethod
    def _decode_arguments(raw: Any) -> Any:
        """Return the ``arguments`` column as Python data.

        asyncpg returns jsonb as ``str`` unless the pool registers a JSON codec.
        Strings are therefore decoded here, and anything already decoded is
        passed through. Text that is not valid JSON is returned unchanged
        rather than failing the whole query.
        """
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except ValueError:
                return raw
        return raw

    # --------------------------------------------------------------- public API

    async def record(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        result_summary: str | None = None,
        confirmed: bool = False,
        session_id: str | None = None,
        task_id: str | None = None,
        user_id: str | None = None,
    ) -> int:
        """Write a tool-call audit record and return the new row ID.

        Args:
            tool_name: Name of the tool that was called.
            arguments: Call arguments; passed through ``_jsonify`` before insert.
            result_summary: Optional short description of the result.
            confirmed: Stored in the ``confirmed`` column, which
                ``query(confirmed_only=True)`` filters on.
            session_id: Optional session identifier.
            task_id: Optional task identifier.
            user_id: Acting user. When omitted, the ``id`` of the user held in
                the ``current_user`` context variable is used, if one is set.

        Returns:
            The ``id`` of the inserted row.
        """
        if user_id is None:
            from .context_vars import current_user as _cu

            u = _cu.get()
            if u:
                user_id = u.id

        now = datetime.now(timezone.utc).isoformat()
        # Sanitise arguments for JSONB (convert non-serializable values to strings)
        args = _jsonify(arguments) if arguments is not None else None

        pool = await get_pool()
        row_id: int = await pool.fetchval(
            """
            INSERT INTO audit_log
                (timestamp, session_id, tool_name, arguments, result_summary, confirmed, task_id, user_id)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
            RETURNING id
            """,
            now, session_id, tool_name, args, result_summary, confirmed, task_id, user_id,
        )
        return row_id

    async def query(
        self,
        start: str | None = None,
        end: str | None = None,
        tool_name: str | None = None,
        session_id: str | None = None,
        task_id: str | None = None,
        confirmed_only: bool = False,
        user_id: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[AuditEntry]:
        """Return audit entries newest-first. All filters are optional.

        Args:
            start: Inclusive lower bound on ``timestamp``. An ISO-8601 value
                without an offset is treated as UTC.
            end: Inclusive upper bound on ``timestamp``, same rules as *start*.
            tool_name: Case-insensitive substring match on the tool name.
            session_id: Exact match.
            task_id: Exact match.
            confirmed_only: Only rows with ``confirmed = TRUE``.
            user_id: Exact match.
            limit: Maximum number of rows, applied after ordering.
            offset: Number of rows to skip, applied after ordering.
        """
        where, params = self._build_filters(
            start=start,
            end=end,
            tool_name=tool_name,
            session_id=session_id,
            task_id=task_id,
            confirmed_only=confirmed_only,
            user_id=user_id,
        )
        # LIMIT/OFFSET take the next two placeholder slots after the filters.
        limit_ph = len(params) + 1
        params.extend([limit, offset])

        pool = await get_pool()
        rows = await pool.fetch(
            f"""
            SELECT id, timestamp, session_id, tool_name, arguments,
                   result_summary, confirmed, task_id, user_id
            FROM audit_log
            {where}
            ORDER BY timestamp::timestamptz DESC
            LIMIT ${limit_ph} OFFSET ${limit_ph + 1}
            """,
            *params,
        )
        return [
            AuditEntry(
                id=r["id"],
                timestamp=r["timestamp"],
                session_id=r["session_id"],
                tool_name=r["tool_name"],
                arguments=self._decode_arguments(r["arguments"]),
                result_summary=r["result_summary"],
                confirmed=r["confirmed"],
                task_id=r["task_id"],
                user_id=r["user_id"],
            )
            for r in rows
        ]

    async def count(
        self,
        start: str | None = None,
        end: str | None = None,
        tool_name: str | None = None,
        task_id: str | None = None,
        session_id: str | None = None,
        confirmed_only: bool = False,
        user_id: str | None = None,
    ) -> int:
        """Return how many rows match the filters; same semantics as ``query``."""
        where, params = self._build_filters(
            start=start,
            end=end,
            tool_name=tool_name,
            session_id=session_id,
            task_id=task_id,
            confirmed_only=confirmed_only,
            user_id=user_id,
        )
        pool = await get_pool()
        return await pool.fetchval(
            f"SELECT COUNT(*) FROM audit_log {where}", *params
        ) or 0

    async def purge(self, older_than_days: int | None = None) -> int:
        """Delete audit records and return the deleted-row count.

        Args:
            older_than_days: When given, delete only rows whose timestamp is
                earlier than now (UTC) minus this many days. ``None`` deletes
                every row.

        Returns:
            The row count that ``_rowcount`` extracts from the command status.
        """
        pool = await get_pool()
        if older_than_days is None:
            status = await pool.execute("DELETE FROM audit_log")
        else:
            cutoff = (
                datetime.now(timezone.utc) - timedelta(days=older_than_days)
            ).isoformat()
            # Compare as instants, like query()/count(), rather than as raw text.
            status = await pool.execute(
                "DELETE FROM audit_log "
                "WHERE timestamp::timestamptz < $1::text::timestamptz",
                cutoff,
            )
        from .database import _rowcount

        return _rowcount(status)
|
|
|
|
|
|
# Shared module-level instance. AuditLog defines no __init__ and keeps no
# per-instance state, so every caller can use this one object.
audit_log: AuditLog = AuditLog()
|