Files
oai-web/server/tools/image_gen_tool.py
2026-04-08 12:43:24 +02:00

158 lines
5.8 KiB
Python

"""
tools/image_gen_tool.py — AI image generation tool.
Calls an image-generation model (via OpenRouter by default) and returns the
result. If save_path is given the image is written to disk immediately so the
model doesn't need to handle large base64 blobs.
"""
from __future__ import annotations
import base64
import logging
from .base import BaseTool, ToolResult
logger = logging.getLogger(__name__)
# Fallback image-generation model ID. ImageGenTool.execute uses it only when
# the caller passes no `model` argument AND the credential
# `system:default_image_gen_model` yields no (truthy) value.
_DEFAULT_MODEL = "openrouter:openai/gpt-5-image"
class ImageGenTool(BaseTool):
    """Tool that asks an image-generation model for an image from a text prompt.

    The first image in the provider response is decoded and either written to
    ``save_path`` (only the path and metadata are returned) or handed back as
    base64 text in the tool result.
    """

    name = "image_gen"
    description = (
        "Generate an image from a text prompt using an AI image-generation model. "
        "If save_path is provided the image is saved to that path and only the path "
        "is returned (no base64 blob in context). "
        "If save_path is omitted the raw base64 image data is returned so you can "
        "inspect it or pass it to another tool."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "Detailed description of the image to generate",
            },
            "save_path": {
                "type": "string",
                "description": (
                    "Optional absolute file path to save the image to "
                    "(e.g. /data/users/rune/stewie.png). "
                    "Recommended — avoids returning a large base64 blob."
                ),
            },
            "model": {
                "type": "string",
                "description": (
                    "Optional image-generation model ID "
                    "(e.g. openrouter:openai/gpt-5-image, "
                    "openrouter:google/gemini-2.0-flash-exp:free). "
                    "Defaults to the system default image model."
                ),
            },
        },
        "required": ["prompt"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    @staticmethod
    def _decode_image_payload(image: str) -> tuple[str, str, bytes]:
        """Split a provider image string into media type, base64 text and bytes.

        Args:
            image: Either a ``data:<media_type>;base64,<data>`` URL or a bare
                base64 string.

        Returns:
            ``(media_type, base64_text, raw_bytes)``. ``media_type`` falls back
            to ``"image/png"`` when the input is not a data URL or declares no
            media type. ``base64_text`` is the payload exactly as received.

        Raises:
            ValueError: If a data URL has no ``,`` separator, the payload is
                not valid base64, or it decodes to zero bytes.
        """
        media_type = "image/png"
        payload = image
        if image.startswith("data:"):
            header, sep, payload = image.partition(",")
            if not sep:
                raise ValueError("malformed data URL (missing ',' separator)")
            # header is "data:<media_type>[;param...]"; keep only the media type
            declared = header[len("data:"):].split(";", 1)[0].strip()
            if declared:
                media_type = declared
        # Tolerate line breaks / spaces inside the payload, but reject any other
        # non-base64 character instead of silently decoding to garbage bytes
        # (e.g. when the provider returned a plain https:// URL).
        compact = "".join(payload.split())
        # binascii.Error raised here is a ValueError subclass.
        raw = base64.b64decode(compact, validate=True)
        if not raw:
            raise ValueError("image payload is empty")
        return media_type, payload, raw

    async def execute(
        self,
        prompt: str,
        save_path: str = "",
        model: str = "",
        **kwargs,
    ) -> ToolResult:
        """Generate an image for ``prompt`` and save or return it.

        Args:
            prompt: Text sent to the model as a single user message.
            save_path: If non-empty, the image is written to this path after it
                passes ``assert_path_allowed``; parent directories are created.
            model: Model ID; when empty, the credential
                ``system:default_image_gen_model`` is used, then
                ``_DEFAULT_MODEL``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``ToolResult``. On success ``data`` holds ``saved_to`` (when
            saved) or ``is_image``/``image_data`` (base64 text), plus
            ``media_type``, ``size_bytes``, ``model`` and ``prompt``. Model
            resolution, provider, decoding, path-security and write failures
            are reported as ``success=False`` with an ``error`` message.
        """
        # Resolve model: tool arg → credential override → hardcoded default
        if not model:
            from ..database import credential_store

            model = (await credential_store.get("system:default_image_gen_model")) or _DEFAULT_MODEL

        # Resolve provider + bare model id
        from ..context_vars import current_user as _cu

        user = _cu.get()
        user_id = user.id if user else None
        try:
            from ..providers.registry import get_provider_for_model

            provider, bare_model = await get_provider_for_model(model, user_id=user_id)
        except Exception as e:
            return ToolResult(success=False, error=f"Could not resolve image model '{model}': {e}")

        # Call the model with a simple user message containing the prompt
        try:
            response = await provider.chat_async(
                messages=[{"role": "user", "content": prompt}],
                tools=None,
                system="",
                model=bare_model,
                max_tokens=1024,
            )
        except Exception as e:
            logger.error("[image_gen] Provider call failed: %s", e)
            return ToolResult(success=False, error=f"Image generation failed: {e}")

        if not response.images:
            msg = response.text or "(no images returned)"
            logger.warning("[image_gen] No images in response. text=%r", msg)
            return ToolResult(
                success=False,
                error=f"Model did not return any images. Response: {msg[:300]}",
            )

        # Use the first image (most models return exactly one)
        try:
            media_type, img_b64, img_bytes = self._decode_image_payload(response.images[0])
        except ValueError as e:
            # Report undecodable payloads as a tool failure rather than letting
            # the exception escape or writing corrupt bytes to disk.
            logger.warning("[image_gen] Could not decode image payload: %s", e)
            return ToolResult(
                success=False,
                error=f"Model returned an image that could not be decoded: {e}",
            )

        model_label = f"{model}/{bare_model}".strip("/")

        # Save to path if requested
        if save_path:
            from ..security import assert_path_allowed, SecurityError

            try:
                safe_path = await assert_path_allowed(save_path)
            except SecurityError as e:
                return ToolResult(success=False, error=str(e))
            try:
                safe_path.parent.mkdir(parents=True, exist_ok=True)
                safe_path.write_bytes(img_bytes)
            except PermissionError:
                return ToolResult(success=False, error=f"Permission denied: {safe_path}")
            except Exception as e:
                # Tool boundary: any other write failure becomes a failed result.
                return ToolResult(success=False, error=f"Save failed: {e}")
            logger.info("[image_gen] Saved image to %s (%d bytes)", safe_path, len(img_bytes))
            return ToolResult(
                success=True,
                data={
                    "saved_to": str(safe_path),
                    "size_bytes": len(img_bytes),
                    "media_type": media_type,
                    "model": model_label,
                    "prompt": prompt,
                },
            )

        # Return base64 data (no save_path given)
        logger.info("[image_gen] Returning image data (%d bytes, %s)", len(img_bytes), media_type)
        return ToolResult(
            success=True,
            data={
                "is_image": True,
                "image_data": img_b64,
                "media_type": media_type,
                "size_bytes": len(img_bytes),
                "model": model_label,
                "prompt": prompt,
            },
        )