"""
|
|
brain/database.py — PostgreSQL + pgvector connection pool and schema.
|
|
|
|
Manages the asyncpg connection pool and initialises the thoughts table +
|
|
match_thoughts function on first startup.
|
|
"""
|
|
from __future__ import annotations

import json
import logging
import os
from typing import Any

import asyncpg
|
|
|
|
logger = logging.getLogger(__name__)

# Shared connection pool. Stays None until init_brain_db() succeeds, and is
# reset to None on shutdown or init failure — callers must go through get_pool().
_pool: asyncpg.Pool | None = None

# ── Schema ────────────────────────────────────────────────────────────────────
|
|
|
|
_SCHEMA_SQL = """
|
|
-- pgvector extension
|
|
CREATE EXTENSION IF NOT EXISTS vector;
|
|
|
|
-- Main thoughts table
|
|
CREATE TABLE IF NOT EXISTS thoughts (
|
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
|
content TEXT NOT NULL,
|
|
embedding vector(1536),
|
|
metadata JSONB NOT NULL DEFAULT '{}',
|
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
);
|
|
|
|
-- IVFFlat index for fast approximate nearest-neighbour search.
|
|
-- Created only if it doesn't exist (pg doesn't support IF NOT EXISTS for indexes).
|
|
DO $$
|
|
BEGIN
|
|
IF NOT EXISTS (
|
|
SELECT 1 FROM pg_indexes
|
|
WHERE tablename = 'thoughts' AND indexname = 'thoughts_embedding_idx'
|
|
) THEN
|
|
CREATE INDEX thoughts_embedding_idx
|
|
ON thoughts USING ivfflat (embedding vector_cosine_ops)
|
|
WITH (lists = 100);
|
|
END IF;
|
|
END$$;
|
|
|
|
-- Semantic similarity search function
|
|
CREATE OR REPLACE FUNCTION match_thoughts(
|
|
query_embedding vector(1536),
|
|
match_threshold FLOAT DEFAULT 0.7,
|
|
match_count INT DEFAULT 10
|
|
)
|
|
RETURNS TABLE (
|
|
id UUID,
|
|
content TEXT,
|
|
metadata JSONB,
|
|
similarity FLOAT,
|
|
created_at TIMESTAMPTZ
|
|
)
|
|
LANGUAGE sql STABLE AS $$
|
|
SELECT
|
|
id,
|
|
content,
|
|
metadata,
|
|
1 - (embedding <=> query_embedding) AS similarity,
|
|
created_at
|
|
FROM thoughts
|
|
WHERE 1 - (embedding <=> query_embedding) > match_threshold
|
|
ORDER BY similarity DESC
|
|
LIMIT match_count;
|
|
$$;
|
|
"""
|
|
|
|
|
|
# ── Pool lifecycle ────────────────────────────────────────────────────────────
|
|
|
|
async def init_brain_db() -> None:
    """
    Create the connection pool and initialise the schema.

    Called from main.py lifespan. No-ops gracefully if BRAIN_DB_URL is unset.
    On any failure the partially created pool is closed and ``_pool`` is reset
    to None, so the rest of the app can treat the brain as disabled.
    """
    global _pool
    url = os.getenv("BRAIN_DB_URL")
    if not url:
        logger.info("BRAIN_DB_URL not set — 2nd Brain disabled")
        return
    try:
        _pool = await asyncpg.create_pool(url, min_size=1, max_size=5)
        async with _pool.acquire() as conn:
            await conn.execute(_SCHEMA_SQL)
            # Per-user brain namespace (3-G): add user_id column if it doesn't exist yet
            await conn.execute(
                "ALTER TABLE thoughts ADD COLUMN IF NOT EXISTS user_id TEXT"
            )
        logger.info("Brain DB initialised")
    except Exception as e:
        logger.error("Brain DB init failed: %s", e)
        # Bug fix: if create_pool succeeded but schema init failed, the old code
        # dropped the pool reference without closing it, leaking connections.
        if _pool is not None:
            try:
                await _pool.close()
            except Exception:
                # Best-effort cleanup; the original error is already logged.
                pass
        _pool = None
|
|
|
|
|
|
async def close_brain_db() -> None:
    """Dispose of the shared connection pool, if one was ever created."""
    global _pool
    if _pool is None:
        return
    await _pool.close()
    _pool = None
|
|
|
|
|
|
def get_pool() -> asyncpg.Pool | None:
    """Return the shared asyncpg pool, or None when the brain DB is disabled or failed to init."""
    return _pool
|
|
|
|
|
|
# ── CRUD helpers ──────────────────────────────────────────────────────────────
|
|
|
|
async def insert_thought(
    content: str,
    embedding: list[float],
    metadata: dict,
    user_id: str | None = None,
) -> str:
    """Insert a thought and return its UUID as text.

    Args:
        content: The raw thought text.
        embedding: 1536-dim embedding vector for semantic search.
        metadata: Arbitrary JSON-serialisable metadata (stored as JSONB).
        user_id: Optional per-user namespace; NULL means unscoped.

    Raises:
        RuntimeError: If the brain DB pool is not available.
    """
    pool = get_pool()
    if pool is None:
        raise RuntimeError("Brain DB not available")
    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            """
            INSERT INTO thoughts (content, embedding, metadata, user_id)
            VALUES ($1, $2::vector, $3::jsonb, $4)
            RETURNING id::text
            """,
            content,
            # pgvector accepts the '[x, y, ...]' text form produced by str().
            str(embedding),
            # Fixed: was __import__("json").dumps(...) — use the module import.
            json.dumps(metadata),
            user_id,
        )
    return row["id"]
|
|
|
|
|
|
async def search_thoughts(
    query_embedding: list[float],
    threshold: float = 0.7,
    limit: int = 10,
    user_id: str | None = None,
) -> list[dict]:
    """Return thoughts ranked by semantic similarity, scoped to user_id if set.

    Args:
        query_embedding: 1536-dim query vector.
        threshold: Minimum cosine similarity (exclusive) for a match.
        limit: Maximum number of rows to return.
        user_id: If set, only that user's thoughts are considered.

    Raises:
        RuntimeError: If the brain DB pool is not available.

    Bug fix: the old implementation joined match_thoughts() (which applies
    LIMIT internally) against the user_id filter, so a user-scoped search
    could return fewer than `limit` rows even when enough matches existed.
    The query is now inlined so the user filter is applied before LIMIT.
    """
    pool = get_pool()
    if pool is None:
        raise RuntimeError("Brain DB not available")
    async with pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT id, content, metadata,
                   1 - (embedding <=> $1::vector) AS similarity,
                   created_at
            FROM thoughts
            WHERE ($4::text IS NULL OR user_id = $4::text)
              AND 1 - (embedding <=> $1::vector) > $2
            ORDER BY similarity DESC
            LIMIT $3
            """,
            str(query_embedding),
            threshold,
            limit,
            user_id,
        )
    return [
        {
            "id": str(r["id"]),
            "content": r["content"],
            # asyncpg may return JSONB as str or as a decoded mapping,
            # depending on codec configuration — handle both.
            "metadata": json.loads(r["metadata"]) if isinstance(r["metadata"], str) else dict(r["metadata"]),
            "similarity": round(float(r["similarity"]), 4),
            "created_at": r["created_at"].isoformat(),
        }
        for r in rows
    ]
|
|
|
|
|
|
async def browse_thoughts(
    limit: int = 20,
    type_filter: str | None = None,
    user_id: str | None = None,
) -> list[dict]:
    """Return recent thoughts, optionally filtered by metadata type and user.

    Args:
        limit: Maximum number of rows (newest first).
        type_filter: If set, only thoughts whose metadata->>'type' matches.
        user_id: If set, only that user's thoughts.

    Raises:
        RuntimeError: If the brain DB pool is not available.
    """
    pool = get_pool()
    if pool is None:
        raise RuntimeError("Brain DB not available")
    async with pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT id::text, content, metadata, created_at
            FROM thoughts
            WHERE ($1::text IS NULL OR user_id = $1::text)
              AND ($2::text IS NULL OR metadata->>'type' = $2::text)
            ORDER BY created_at DESC
            LIMIT $3
            """,
            user_id,
            type_filter,
            limit,
        )
    # Idiom fix: use the module-level json import instead of a mid-function
    # `import json as _json`.
    return [
        {
            "id": str(r["id"]),
            "content": r["content"],
            # JSONB may arrive as str or a decoded mapping depending on codecs.
            "metadata": json.loads(r["metadata"]) if isinstance(r["metadata"], str) else dict(r["metadata"]),
            "created_at": r["created_at"].isoformat(),
        }
        for r in rows
    ]
|
|
|
|
|
|
async def get_stats(user_id: str | None = None) -> dict:
    """Aggregate statistics for the thoughts table, scoped to user_id if set.

    Returns a dict with the total row count, a per-type breakdown (from
    metadata->>'type'), and the timestamp of the newest thought.

    Raises:
        RuntimeError: If the brain DB pool is not available.
    """
    pool = get_pool()
    if pool is None:
        raise RuntimeError("Brain DB not available")
    async with pool.acquire() as conn:
        total_count = await conn.fetchval(
            "SELECT COUNT(*) FROM thoughts WHERE ($1::text IS NULL OR user_id = $1::text)",
            user_id,
        )
        type_rows = await conn.fetch(
            """
            SELECT metadata->>'type' AS type, COUNT(*) AS count
            FROM thoughts
            WHERE ($1::text IS NULL OR user_id = $1::text)
            GROUP BY metadata->>'type'
            ORDER BY count DESC
            """,
            user_id,
        )
        latest = await conn.fetchval(
            "SELECT created_at FROM thoughts WHERE ($1::text IS NULL OR user_id = $1::text) ORDER BY created_at DESC LIMIT 1",
            user_id,
        )
    breakdown = []
    for row in type_rows:
        # Rows with no metadata 'type' key surface as NULL → label "unknown".
        breakdown.append({"type": row["type"] or "unknown", "count": row["count"]})
    return {
        "total": total_count,
        "by_type": breakdown,
        "most_recent": latest.isoformat() if latest else None,
    }
|