226 lines
7.0 KiB
Python
226 lines
7.0 KiB
Python
"""
|
|
agents/tasks.py — Agent and agent run CRUD operations (async).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from ..database import _rowcount, get_pool
|
|
|
|
|
|
def _now() -> str:
|
|
return datetime.now(timezone.utc).isoformat()
|
|
|
|
|
|
def _agent_row(row) -> dict:
|
|
"""Convert asyncpg Record to a plain dict, normalising JSONB fields."""
|
|
d = dict(row)
|
|
# allowed_tools: JSONB column, but SQLite-migrated rows may have stored a
|
|
# JSON string instead of a JSON array — asyncpg then returns a str.
|
|
at = d.get("allowed_tools")
|
|
if isinstance(at, str):
|
|
try:
|
|
d["allowed_tools"] = json.loads(at)
|
|
except (json.JSONDecodeError, ValueError):
|
|
d["allowed_tools"] = None
|
|
return d
|
|
|
|
|
|
# ── Agents ────────────────────────────────────────────────────────────────────
|
|
|
|
async def create_agent(
    name: str,
    prompt: str,
    model: str,
    description: str = "",
    can_create_subagents: bool = False,
    allowed_tools: list[str] | None = None,
    schedule: str | None = None,
    enabled: bool = True,
    parent_agent_id: str | None = None,
    created_by: str = "user",
    max_tool_calls: int | None = None,
    prompt_mode: str = "combined",
    owner_user_id: str | None = None,
) -> dict:
    """Insert a new agent row and return it as read back from the database.

    A fresh UUID4 string is generated for the id. ``created_at`` and
    ``updated_at`` both receive the same UTC ISO-8601 timestamp.
    ``allowed_tools`` is handed to the JSONB column as a Python list, with no
    ``json.dumps``.

    Returns:
        Whatever ``get_agent`` yields for the newly generated id.
    """
    new_id = str(uuid.uuid4())
    timestamp = _now()
    pool = await get_pool()

    row_values = (
        new_id, name, description, prompt, model,
        can_create_subagents,
        allowed_tools,  # JSONB: the list is sent as-is
        schedule, enabled,
        parent_agent_id, created_by,
        timestamp, timestamp,  # created_at, updated_at
        max_tool_calls, prompt_mode, owner_user_id,
    )
    await pool.execute(
        """
        INSERT INTO agents
            (id, name, description, prompt, model, can_create_subagents,
             allowed_tools, schedule, enabled, parent_agent_id, created_by,
             created_at, updated_at, max_tool_calls, prompt_mode, owner_user_id)
        VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16)
        """,
        *row_values,
    )
    return await get_agent(new_id)
|
|
|
|
|
|
async def list_agents(
    include_subagents: bool = True,
    owner_user_id: str | None = None,
) -> list[dict]:
    """Return agents, newest first, each annotated with ``last_run_at``.

    Args:
        include_subagents: When False, only top-level agents
            (``parent_agent_id IS NULL``) are returned.
        owner_user_id: When not None, restrict the result to agents owned
            by this user.

    Returns:
        One dict per agent, normalised by ``_agent_row``, plus a
        ``last_run_at`` key. That key holds the greatest ``started_at`` among
        the agent's runs, or ``None`` if no run has a ``started_at``.
    """
    pool = await get_pool()
    clauses: list[str] = []
    params: list[Any] = []

    if not include_subagents:
        clauses.append("a.parent_agent_id IS NULL")
    if owner_user_id is not None:
        params.append(owner_user_id)
        clauses.append(f"a.owner_user_id = ${len(params)}")

    where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
    # MAX() skips NULL started_at values. ORDER BY started_at DESC LIMIT 1
    # would pick a NULL first, because PostgreSQL sorts NULLs first under
    # DESC, and so report no last run even when the agent has run.
    rows = await pool.fetch(
        f"""
        SELECT a.*,
               (SELECT MAX(r.started_at)
                  FROM agent_runs r
                 WHERE r.agent_id = a.id) AS last_run_at
          FROM agents a {where}
         ORDER BY a.created_at DESC
        """,
        *params,
    )
    return [_agent_row(r) for r in rows]
|
|
|
|
|
|
async def get_agent(agent_id: str) -> dict | None:
    """Look up one agent by primary key.

    Returns:
        The row normalised by ``_agent_row``, or ``None`` when no agent has
        ``agent_id``.
    """
    pool = await get_pool()
    record = await pool.fetchrow("SELECT * FROM agents WHERE id = $1", agent_id)
    if not record:
        return None
    return _agent_row(record)
|
|
|
|
|
|
async def update_agent(agent_id: str, **fields) -> dict | None:
    """Update columns of an agent and return the refreshed row.

    Each keyword argument names an ``agents`` column to set. ``updated_at``
    is always overwritten with the current UTC timestamp. Values are bound as
    query parameters: PostgreSQL BOOLEAN accepts Python ``bool`` directly and
    JSONB accepts a Python list directly, so no conversion is done here.

    Returns:
        The updated agent, normalised by ``_agent_row``, or ``None`` if no
        agent with ``agent_id`` exists.

    Raises:
        ValueError: if a keyword is not a plain identifier. Column names must
            be interpolated into the SQL text because they cannot be bound as
            parameters, so any other string could inject SQL.
    """
    fields["updated_at"] = _now()

    bad_names = [name for name in fields if not name.isidentifier()]
    if bad_names:
        raise ValueError(f"Invalid agent column name(s): {bad_names!r}")

    assignments = ", ".join(
        f"{column} = ${i}" for i, column in enumerate(fields, start=1)
    )
    values: list[Any] = [*fields.values(), agent_id]

    pool = await get_pool()
    # One statement replaces check-exists / update / re-read. An unknown
    # agent_id matches no row, so fetchrow returns None.
    row = await pool.fetchrow(
        f"UPDATE agents SET {assignments} WHERE id = ${len(values)} RETURNING *",
        *values,
    )
    return _agent_row(row) if row else None
|
|
|
|
|
|
async def delete_agent(agent_id: str) -> bool:
    """Delete an agent after detaching everything that references it.

    All of the following happen in a single transaction on one pooled
    connection:

    1. The agent's rows in ``agent_runs`` are deleted.
    2. Sub-agents pointing at it get ``parent_agent_id`` cleared.
    3. ``scheduled_tasks`` pointing at it get ``agent_id`` cleared.
    4. The ``agents`` row itself is deleted.

    Returns:
        True if an agent row was deleted, False if none matched.
    """
    detach_statements = (
        "DELETE FROM agent_runs WHERE agent_id = $1",
        "UPDATE agents SET parent_agent_id = NULL WHERE parent_agent_id = $1",
        "UPDATE scheduled_tasks SET agent_id = NULL WHERE agent_id = $1",
    )
    pool = await get_pool()
    async with pool.acquire() as conn, conn.transaction():
        for statement in detach_statements:
            await conn.execute(statement, agent_id)
        outcome = await conn.execute("DELETE FROM agents WHERE id = $1", agent_id)
        deleted = _rowcount(outcome) > 0
    return deleted
|
|
|
|
|
|
# ── Agent runs ────────────────────────────────────────────────────────────────
|
|
|
|
async def create_run(agent_id: str) -> dict:
    """Record the start of a run for ``agent_id`` and return the new run row.

    The run gets a fresh UUID4 string id, ``started_at`` set to the current
    UTC ISO-8601 timestamp, and status ``'running'``.

    Returns:
        Whatever ``get_run`` yields for the new id.
    """
    new_run_id = str(uuid.uuid4())
    started = _now()
    pool = await get_pool()
    insert_sql = (
        "INSERT INTO agent_runs (id, agent_id, started_at, status) "
        "VALUES ($1, $2, $3, 'running')"
    )
    await pool.execute(insert_sql, new_run_id, agent_id, started)
    return await get_run(new_run_id)
|
|
|
|
|
|
async def finish_run(
    run_id: str,
    status: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
    result: str | None = None,
    error: str | None = None,
) -> dict | None:
    """Close out a run by stamping ``ended_at`` and storing its outcome.

    ``ended_at`` is set to the current UTC ISO-8601 timestamp. The given
    ``status``, token counts, ``result`` and ``error`` are written onto the
    run identified by ``run_id``.

    Returns:
        The run as re-read via ``get_run``, or ``None`` if ``run_id`` does not
        exist.
    """
    ended_at = _now()
    pool = await get_pool()
    outcome = (ended_at, status, input_tokens, output_tokens, result, error)
    await pool.execute(
        """
        UPDATE agent_runs
           SET ended_at = $1, status = $2, input_tokens = $3,
               output_tokens = $4, result = $5, error = $6
         WHERE id = $7
        """,
        *outcome,
        run_id,
    )
    return await get_run(run_id)
|
|
|
|
|
|
async def get_run(run_id: str) -> dict | None:
    """Fetch a single agent run by id.

    Returns:
        The run as a plain dict, or ``None`` when no run has ``run_id``.
    """
    pool = await get_pool()
    record = await pool.fetchrow("SELECT * FROM agent_runs WHERE id = $1", run_id)
    if not record:
        return None
    return dict(record)
|
|
|
|
|
|
async def cleanup_stale_runs() -> int:
    """Mark every run still in ``'running'`` state as ``'error'``.

    Such runs are treated as interrupted by a server restart. Each one gets
    ``ended_at`` set to the current UTC timestamp and an explanatory
    ``error`` message.

    Returns:
        The number of runs updated, as parsed by ``_rowcount``.
    """
    stamp = _now()
    pool = await get_pool()
    command_tag = await pool.execute(
        """
        UPDATE agent_runs
           SET status = 'error', ended_at = $1, error = 'Interrupted by server restart'
         WHERE status = 'running'
        """,
        stamp,
    )
    return _rowcount(command_tag)
|
|
|
|
|
|
async def list_runs(
    agent_id: str | None = None,
    since: str | None = None,
    status: str | None = None,
    limit: int = 200,
) -> list[dict]:
    """List agent runs, most recently started first.

    Args:
        agent_id: If truthy, only runs of this agent are returned.
        since: If truthy, only runs with ``started_at >= since`` are returned.
        status: If truthy, only runs with this status are returned.
        limit: Maximum number of rows to return.

    Returns:
        Each matching run as a plain dict.
    """
    # (SQL prefix, value) pairs; a filter applies only when its value is truthy.
    filters = (
        ("agent_id = ", agent_id),
        ("started_at >= ", since),
        ("status = ", status),
    )
    conditions: list[str] = []
    args: list[Any] = []
    for prefix, value in filters:
        if value:
            args.append(value)
            conditions.append(f"{prefix}${len(args)}")

    where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
    args.append(limit)

    pool = await get_pool()
    rows = await pool.fetch(
        f"SELECT * FROM agent_runs {where} ORDER BY started_at DESC LIMIT ${len(args)}",
        *args,
    )
    return [dict(r) for r in rows]
|