259 lines
9.5 KiB
Python
259 lines
9.5 KiB
Python
"""
|
|
users.py — User CRUD operations (async, PostgreSQL).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from .auth import hash_password
|
|
from .database import _rowcount, get_pool
|
|
|
|
_PROJECT_ROOT = Path(__file__).parent.parent


# Placeholder text written in place of each personal USER.md section when a
# template is generated for a new user. Keys are "## " heading titles, lower-cased.
_SECTION_PLACEHOLDERS = {
    "identity": (
        "<!-- Fill in your details -->\n"
        "- **Name**: \n"
        "- **Location**: \n"
        "- **Timezone**: \n"
    ),
    "people": "<!-- Add people important to you and their relationship to you -->\n",
    "context and background": "<!-- Describe your role, profession, and relevant background -->\n",
    "hobbies and interests": "<!-- Add your hobbies and interests -->\n",
}

# Sections in USER.md that contain personal identity info — cleared for new users.
# Derived from the placeholder table so every blanked section has its own placeholder.
_PERSONAL_SECTIONS = set(_SECTION_PLACEHOLDERS)


def _make_user_template(user_md_content: str) -> str:
    """Blank out the personal sections of a USER.md document.

    The text is cut into chunks, each starting at a line that begins with "## ".
    A chunk whose heading (trimmed, case-insensitive) names a personal section is
    replaced by the heading plus that section's placeholder. Every other chunk is
    kept verbatim, including any text before the first heading. The joined result
    has surrounding whitespace stripped.
    """
    import re

    def _blank_if_personal(chunk: str) -> str:
        heading = re.match(r"^## (.+)", chunk)
        if heading is None:
            return chunk
        title = heading.group(1).strip()
        key = title.lower()
        if key not in _PERSONAL_SECTIONS:
            return chunk
        placeholder = _SECTION_PLACEHOLDERS.get(key, "<!-- Add your details here -->\n")
        return f"## {title}\n\n{placeholder}\n"

    # A zero-width split keeps each "## " heading attached to the body that follows it.
    chunks = re.split(r"(?=^## )", user_md_content, flags=re.MULTILINE)
    return "".join(_blank_if_personal(chunk) for chunk in chunks).strip()
|
|
|
|
logger = logging.getLogger(__name__)


def _now() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
|
|
async def _provision_user_folder(username: str, soul_content: str = "", user_content: str = "") -> None:
|
|
"""Create {users_base_folder}/{username}/ if the admin has configured a base folder.
|
|
Also writes soul.md and user.md there as editable copies.
|
|
"""
|
|
from .database import credential_store
|
|
base = await credential_store.get("system:users_base_folder")
|
|
if not base:
|
|
return
|
|
user_folder = Path(base.rstrip("/")) / username
|
|
try:
|
|
user_folder.mkdir(parents=True, exist_ok=True)
|
|
logger.info("[users] Provisioned user folder: %s", user_folder)
|
|
if soul_content:
|
|
(user_folder / "soul.md").write_text(soul_content, encoding="utf-8")
|
|
if user_content:
|
|
(user_folder / "user.md").write_text(user_content, encoding="utf-8")
|
|
except Exception as e:
|
|
logger.warning("[users] Could not provision user folder %s: %s", user_folder, e)
|
|
|
|
|
|
async def get_user_folder(user_id: str) -> str | None:
|
|
"""Return the provisioned folder path for a user, or None if not configured."""
|
|
from .database import credential_store
|
|
base = await credential_store.get("system:users_base_folder")
|
|
if not base:
|
|
return None
|
|
user = await get_user_by_id(user_id)
|
|
if not user:
|
|
return None
|
|
import os
|
|
return os.path.join(base.rstrip("/"), user["username"])
|
|
|
|
|
|
def _read_project_file(name: str) -> str:
    """Return the UTF-8 text of ``_PROJECT_ROOT/name``, or "" if it cannot be read.

    A missing file is expected and silently yields "". Other read or decode errors
    are logged and also yield "". create_user inserts the users row before seeding,
    so raising here would leave an account without its personality settings.
    """
    path = _PROJECT_ROOT / name
    try:
        return path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return ""
    except (OSError, UnicodeDecodeError) as e:
        logger.warning("[users] Could not read %s: %s", path, e)
        return ""


async def _seed_user_personality(user_id: str, username: str, role: str) -> None:
    """Seed per-user personality settings from global files on account creation.

    - "personality_soul" is set to the stripped SOUL.md text when it is non-empty.
    - Admins get "personality_user" set to USER.md verbatim (stripped), and
      "personality_setup_done" set to "1" so the setup nag is suppressed.
    - Other roles get USER.md with personal sections blanked (see
      ``_make_user_template``), so the nag is shown.

    Finally, the same contents are mirrored into the user's folder (best-effort).
    """
    from .database import user_settings_store

    # Always seed SOUL.md content so the user can customise it.
    soul_content = _read_project_file("SOUL.md").strip()
    if soul_content:
        await user_settings_store.set(user_id, "personality_soul", soul_content)

    raw_user_md = _read_project_file("USER.md")
    if role == "admin":
        # Admin gets the real USER.md verbatim; suppress nag.
        user_content = raw_user_md.strip()
        await user_settings_store.set(user_id, "personality_setup_done", "1")
    else:
        # Non-admin: seed from USER.md with personal sections blanked -> nag shown.
        user_content = _make_user_template(raw_user_md) if raw_user_md else ""

    if user_content:
        await user_settings_store.set(user_id, "personality_user", user_content)

    await _provision_user_folder(username, soul_content, user_content)
|
|
|
|
|
|
async def _sync_email_whitelist(pool, old_email: str | None, new_email: str | None) -> None:
    """Mirror a user's email change into ``email_whitelist``.

    A truthy *new_email* is inserted with daily_limit 0, unless it is already present.
    When *old_email* is set and differs from *new_email*, it is deleted from the
    whitelist, but only if no row in ``users`` still has that email.
    """
    if new_email:
        await pool.execute(
            "INSERT INTO email_whitelist (email, daily_limit, created_at) VALUES ($1, 0, $2) "
            "ON CONFLICT (email) DO NOTHING",
            new_email,
            _now(),
        )

    if not old_email or old_email == new_email:
        return

    # Another account may still be using the previous address; keep it whitelisted then.
    if await pool.fetchval("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)", old_email):
        return
    await pool.execute("DELETE FROM email_whitelist WHERE email = $1", old_email)
|
|
|
|
|
|
async def create_user(username: str, password: str, role: str = "user", email: str = "") -> dict:
    """Insert a new active user, seed its personality data, and return the stored row.

    An empty *email* is stored as NULL. The whitelist is updated for the new address
    before personality seeding runs. The returned dict is whatever get_user_by_id
    loads for the new id.
    """
    new_id = str(uuid.uuid4())
    timestamp = _now()
    password_hash = hash_password(password)
    stored_email = email or None

    pool = await get_pool()
    await pool.execute(
        """
        INSERT INTO users (id, username, password_hash, role, is_active, email, created_at, updated_at)
        VALUES ($1, $2, $3, $4, TRUE, $5, $6, $6)
        """,
        new_id,
        username,
        password_hash,
        role,
        stored_email,
        timestamp,
    )
    await _sync_email_whitelist(pool, None, stored_email)
    await _seed_user_personality(new_id, username, role)
    return await get_user_by_id(new_id)
|
|
|
|
|
|
async def get_user_by_id(user_id: str) -> dict | None:
    """Load one user by primary key, or return None if there is no such row.

    The dict includes ``totp_secret`` but not ``password_hash``.
    """
    pool = await get_pool()
    record = await pool.fetchrow(
        "SELECT id, username, email, display_name, role, is_active, totp_secret, created_at, updated_at FROM users WHERE id = $1",
        user_id,
    )
    if not record:
        return None
    return dict(record)
|
|
|
|
|
|
async def get_user_by_username(username: str) -> dict | None:
    """Load one user by username, for credential checks, or return None if absent.

    Unlike get_user_by_id, the dict also contains ``password_hash`` (alongside
    ``totp_secret``) so the caller can verify the password and TOTP code.
    """
    pool = await get_pool()
    record = await pool.fetchrow(
        """
        SELECT id, username, email, display_name, password_hash, role, is_active, totp_secret, created_at, updated_at
        FROM users WHERE username = $1
        """,
        username,
    )
    if not record:
        return None
    return dict(record)
|
|
|
|
|
|
async def list_users() -> list[dict]:
    """Return all users, oldest first, safe for listing.

    ``totp_secret`` is removed from each dict and replaced by a boolean
    ``mfa_enabled`` (True when a secret is set).
    """
    pool = await get_pool()
    records = await pool.fetch(
        "SELECT id, username, email, display_name, role, is_active, totp_secret, created_at, updated_at FROM users ORDER BY created_at ASC"
    )

    def _public_view(record) -> dict:
        user = dict(record)
        secret = user.pop("totp_secret")
        user["mfa_enabled"] = secret is not None
        return user

    return [_public_view(record) for record in records]
|
|
|
|
|
|
async def update_user(user_id: str, **fields) -> bool:
    """Update columns of one user row; return True if a row was updated.

    Each keyword maps directly to a ``users`` column, with these special cases:
      * ``password=`` is hashed with hash_password and stored as ``password_hash``.
      * ``email=`` stores a falsy value as NULL, matching create_user. When the
        update succeeds, ``email_whitelist`` is synced from the old to the new address.
    ``updated_at`` is always set to the current time.

    Raises:
        ValueError: if a field name is not a plain ASCII identifier. Column names
            are interpolated into the SQL text (only values are bound as
            parameters), so arbitrary names would allow SQL injection.
    """
    bad_names = [name for name in fields if not (name.isascii() and name.isidentifier())]
    if bad_names:
        raise ValueError(f"Invalid user field name(s): {bad_names!r}")

    pool = await get_pool()

    email_changing = "email" in fields
    old_email = None
    if email_changing:
        fields["email"] = fields["email"] or None
        # Fetch the current email before the update so the whitelist can be synced.
        old_row = await pool.fetchrow("SELECT email FROM users WHERE id = $1", user_id)
        old_email = old_row["email"] if old_row else None

    if "password" in fields:
        fields["password_hash"] = hash_password(fields.pop("password"))
    fields["updated_at"] = _now()

    assignments = [f"{column} = ${i}" for i, column in enumerate(fields, start=1)]
    values = [*fields.values(), user_id]  # user_id is the last positional parameter
    status = await pool.execute(
        f"UPDATE users SET {', '.join(assignments)} WHERE id = ${len(values)}",
        *values,
    )

    updated = _rowcount(status) > 0
    if updated and email_changing:
        await _sync_email_whitelist(pool, old_email, fields["email"])
    return updated
|
|
|
|
|
|
async def delete_user(user_id: str) -> bool:
    """Delete a user; return True if a row was removed, False if none matched.

    References in agents, conversations and audit_log are set to NULL in the same
    transaction as the DELETE. The deleted row's email is captured with RETURNING,
    so it cannot go stale between a separate read and the delete. It is then
    dropped from the whitelist if no other user still uses it.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Nullify FK references before deleting.
            await conn.execute("UPDATE agents SET owner_user_id = NULL WHERE owner_user_id = $1", user_id)
            await conn.execute("UPDATE conversations SET user_id = NULL WHERE user_id = $1", user_id)
            await conn.execute("UPDATE audit_log SET user_id = NULL WHERE user_id = $1", user_id)
            deleted = await conn.fetchrow("DELETE FROM users WHERE id = $1 RETURNING email", user_id)

    if deleted is None:
        return False

    # Runs after the connection is released, so it does not hold two pool connections at once.
    await _sync_email_whitelist(pool, deleted["email"], None)
    return True
|
|
|
|
|
|
async def user_count() -> int:
    """Return the number of rows in ``users``; a NULL/empty result is reported as 0."""
    pool = await get_pool()
    total = await pool.fetchval("SELECT COUNT(*) FROM users")
    return total or 0
|
|
|
|
|
|
async def assign_existing_data_to_admin(admin_id: str) -> None:
    """Assign all existing NULL-owner data to the first admin. Called after /setup.

    Ownerless rows in agents, conversations and audit_log are claimed in a single
    transaction. A failure part-way through therefore leaves no table re-owned,
    instead of a partial assignment. This mirrors the transaction delete_user uses
    when clearing the same columns.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            await conn.execute(
                "UPDATE agents SET owner_user_id = $1 WHERE owner_user_id IS NULL", admin_id
            )
            await conn.execute(
                "UPDATE conversations SET user_id = $1 WHERE user_id IS NULL", admin_id
            )
            await conn.execute(
                "UPDATE audit_log SET user_id = $1 WHERE user_id IS NULL", admin_id
            )
|