158 lines
5.8 KiB
Python
158 lines
5.8 KiB
Python
"""
|
|
tools/image_gen_tool.py — AI image generation tool.
|
|
|
|
Calls an image-generation model (via OpenRouter by default) and returns the
|
|
result. If save_path is given the image is written to disk immediately so the
|
|
model doesn't need to handle large base64 blobs.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import logging
|
|
|
|
from .base import BaseTool, ToolResult
|
|
|
|
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Fallback image model. ImageGenTool.execute uses it only when the call passes
# no `model` and the `system:default_image_gen_model` credential is unset.
_DEFAULT_MODEL: str = "openrouter:openai/gpt-5-image"
|
|
|
|
|
|
class ImageGenTool(BaseTool):
    """Generate an image from a text prompt and save it or return it as base64.

    The prompt is sent as a single user message to the provider resolved for
    the requested model. The first image in the response is decoded and then
    either written to ``save_path`` or returned inline as base64 text.
    """

    name = "image_gen"
    description = (
        "Generate an image from a text prompt using an AI image-generation model. "
        "If save_path is provided the image is saved to that path and only the path "
        "is returned (no base64 blob in context). "
        "If save_path is omitted the raw base64 image data is returned so you can "
        "inspect it or pass it to another tool."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "Detailed description of the image to generate",
            },
            "save_path": {
                "type": "string",
                "description": (
                    "Optional absolute file path to save the image to "
                    "(e.g. /data/users/rune/stewie.png). "
                    "Recommended — avoids returning a large base64 blob."
                ),
            },
            "model": {
                "type": "string",
                "description": (
                    "Optional image-generation model ID "
                    "(e.g. openrouter:openai/gpt-5-image, "
                    "openrouter:google/gemini-2.0-flash-exp:free). "
                    "Defaults to the system default image model."
                ),
            },
        },
        "required": ["prompt"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    @staticmethod
    def _decode_image(payload: str) -> tuple[str, str, bytes]:
        """Split an image payload into ``(media_type, base64_text, raw_bytes)``.

        Accepts either a ``data:<media_type>;base64,<data>`` URL or a bare
        base64 string (media type then defaults to ``image/png``).

        Raises:
            ValueError: if the payload is empty, is not a string, is an
                http(s) URL, is a data URL without a ``,`` separator, is not
                valid base64, or decodes to zero bytes.
        """
        if not isinstance(payload, str) or not payload:
            raise ValueError("image payload is empty or not a string")
        # A remote URL is not base64; decoding it would fail or yield garbage.
        if payload.startswith(("http://", "https://")):
            raise ValueError("provider returned an image URL, not inline base64 data")

        media_type = "image/png"
        b64_text = payload
        if payload.startswith("data:"):
            header, sep, b64_text = payload.partition(",")
            if not sep:
                raise ValueError("malformed data URL (missing ',' separator)")
            declared = header[len("data:"):].split(";", 1)[0].strip()
            # Keep the PNG default when the URL declares no media type.
            if declared:
                media_type = declared

        # binascii.Error (bad padding etc.) is a ValueError subclass, so the
        # caller only needs to catch ValueError.
        raw = base64.b64decode(b64_text)
        if not raw:
            raise ValueError("image payload decoded to zero bytes")
        return media_type, b64_text, raw

    async def execute(
        self,
        prompt: str,
        save_path: str = "",
        model: str = "",
        **kwargs,
    ) -> ToolResult:
        """Generate one image for ``prompt``.

        Args:
            prompt: Text description sent to the model as the user message.
            save_path: If non-empty, the decoded image is written here after
                ``assert_path_allowed`` approves the path.
            model: Model id. When empty, falls back to the
                ``system:default_image_gen_model`` credential and then to
                ``_DEFAULT_MODEL``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``ToolResult``. On success ``data`` holds either ``saved_to``
            (save mode) or ``is_image``/``image_data`` (inline mode), plus
            ``size_bytes``, ``media_type``, ``model`` and ``prompt``. Every
            failure handled here is reported as ``success=False`` with an
            ``error`` message.
        """
        # Resolve model: tool arg → credential override → hardcoded default
        if not model:
            from ..database import credential_store

            model = (await credential_store.get("system:default_image_gen_model")) or _DEFAULT_MODEL

        # Resolve provider + bare model id
        from ..context_vars import current_user as _cu

        user = _cu.get()
        user_id = user.id if user else None
        try:
            from ..providers.registry import get_provider_for_model

            provider, bare_model = await get_provider_for_model(model, user_id=user_id)
        except Exception as e:
            return ToolResult(success=False, error=f"Could not resolve image model '{model}': {e}")

        # Call the model with a simple user message containing the prompt
        try:
            response = await provider.chat_async(
                messages=[{"role": "user", "content": prompt}],
                tools=None,
                system="",
                model=bare_model,
                max_tokens=1024,
            )
        except Exception as e:
            logger.error("[image_gen] Provider call failed: %s", e)
            return ToolResult(success=False, error=f"Image generation failed: {e}")

        if not response.images:
            msg = response.text or "(no images returned)"
            logger.warning("[image_gen] No images in response. text=%r", msg)
            return ToolResult(
                success=False,
                error=f"Model did not return any images. Response: {msg[:300]}",
            )

        # Use the first image (most models return exactly one)
        try:
            media_type, img_b64, img_bytes = self._decode_image(response.images[0])
        except ValueError as e:
            logger.warning("[image_gen] Could not decode image payload: %s", e)
            return ToolResult(success=False, error=f"Could not decode image data: {e}")

        # NOTE(review): `model` presumably still carries its provider prefix,
        # so this label may repeat the bare model id. Kept unchanged for
        # output compatibility — confirm the intended format.
        model_label = f"{model}/{bare_model}".strip("/")

        # Save to path if requested
        if save_path:
            from ..security import assert_path_allowed, SecurityError

            try:
                safe_path = await assert_path_allowed(save_path)
            except SecurityError as e:
                return ToolResult(success=False, error=str(e))

            try:
                safe_path.parent.mkdir(parents=True, exist_ok=True)
                safe_path.write_bytes(img_bytes)
            except PermissionError:
                return ToolResult(success=False, error=f"Permission denied: {safe_path}")
            except Exception as e:
                return ToolResult(success=False, error=f"Save failed: {e}")

            logger.info("[image_gen] Saved image to %s (%d bytes)", safe_path, len(img_bytes))
            return ToolResult(
                success=True,
                data={
                    "saved_to": str(safe_path),
                    "size_bytes": len(img_bytes),
                    "media_type": media_type,
                    "model": model_label,
                    "prompt": prompt,
                },
            )

        # Return base64 data (no save_path given)
        logger.info("[image_gen] Returning image data (%d bytes, %s)", len(img_bytes), media_type)
        return ToolResult(
            success=True,
            data={
                "is_image": True,
                "image_data": img_b64,
                "media_type": media_type,
                "size_bytes": len(img_bytes),
                "model": model_label,
                "prompt": prompt,
            },
        )
|