Initial commit
This commit is contained in:
157
server/tools/image_gen_tool.py
Normal file
157
server/tools/image_gen_tool.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
tools/image_gen_tool.py — AI image generation tool.
|
||||
|
||||
Calls an image-generation model (via OpenRouter by default) and returns the
|
||||
result. If save_path is given the image is written to disk immediately so the
|
||||
model doesn't need to handle large base64 blobs.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Fallback image model. Used only when the caller passes no `model` argument and
# the `system:default_image_gen_model` credential is unset or empty.
|
||||
_DEFAULT_MODEL = "openrouter:openai/gpt-5-image"
|
||||
|
||||
|
||||
class ImageGenTool(BaseTool):
    """Generate an image from a text prompt via an image-capable chat model.

    The prompt is sent as a single user message to the provider resolved for
    the requested model. The first image in the response is decoded and then
    either written to ``save_path`` (returning only metadata) or returned
    inline as base64.
    """

    name = "image_gen"
    description = (
        "Generate an image from a text prompt using an AI image-generation model. "
        "If save_path is provided the image is saved to that path and only the path "
        "is returned (no base64 blob in context). "
        "If save_path is omitted the raw base64 image data is returned so you can "
        "inspect it or pass it to another tool."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "Detailed description of the image to generate",
            },
            "save_path": {
                "type": "string",
                "description": (
                    "Optional absolute file path to save the image to "
                    "(e.g. /data/users/rune/stewie.png). "
                    "Recommended — avoids returning a large base64 blob."
                ),
            },
            "model": {
                "type": "string",
                "description": (
                    "Optional image-generation model ID "
                    "(e.g. openrouter:openai/gpt-5-image, "
                    "openrouter:google/gemini-2.0-flash-exp:free). "
                    "Defaults to the system default image model."
                ),
            },
        },
        "required": ["prompt"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    @staticmethod
    def _decode_image(payload: str) -> tuple[str, str, bytes]:
        """Split an image payload into media type, base64 text and raw bytes.

        Args:
            payload: Either a ``data:<media_type>;base64,<data>`` URL or a bare
                base64 string.

        Returns:
            ``(media_type, img_b64, img_bytes)``. ``media_type`` falls back to
            ``"image/png"`` when the payload carries none.

        Raises:
            ValueError: If the payload is not valid base64 or decodes to
                nothing (``binascii.Error`` is a ``ValueError`` subclass).
        """
        media_type = "image/png"
        img_b64 = payload
        if payload.startswith("data:"):
            header, comma, body = payload.partition(",")
            if comma:
                img_b64 = body
                media_type = header[len("data:"):].partition(";")[0] or media_type

        # Strict decoding: the default (validate=False) silently drops
        # non-alphabet characters, so e.g. a plain https:// URL could "decode"
        # into garbage bytes. Whitespace/line breaks are tolerated explicitly.
        img_bytes = base64.b64decode("".join(img_b64.split()), validate=True)
        if not img_bytes:
            raise ValueError("image payload is empty")
        return media_type, img_b64, img_bytes

    async def execute(
        self,
        prompt: str,
        save_path: str = "",
        model: str = "",
        **kwargs,
    ) -> ToolResult:
        """Generate one image for ``prompt`` and save or return it.

        Args:
            prompt: Text sent to the model as the sole user message.
            save_path: If non-empty, the image is written there (after the
                path passes ``assert_path_allowed``) and only metadata is
                returned.
            model: Model ID; when empty, the ``system:default_image_gen_model``
                credential is used, falling back to ``_DEFAULT_MODEL``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``ToolResult``. On success ``data`` holds either ``saved_to`` or
            ``is_image``/``image_data``, plus ``size_bytes``, ``media_type``,
            ``model`` and ``prompt``. Every handled failure is reported as
            ``success=False`` with an ``error`` message.
        """
        # Resolve model: tool arg → credential override → hardcoded default
        if not model:
            from ..database import credential_store

            model = (await credential_store.get("system:default_image_gen_model")) or _DEFAULT_MODEL

        # Resolve provider + bare model id
        from ..context_vars import current_user as _cu

        user = _cu.get()
        user_id = user.id if user else None
        try:
            from ..providers.registry import get_provider_for_model

            provider, bare_model = await get_provider_for_model(model, user_id=user_id)
        except Exception as e:
            return ToolResult(success=False, error=f"Could not resolve image model '{model}': {e}")

        # Call the model with a simple user message containing the prompt
        try:
            response = await provider.chat_async(
                messages=[{"role": "user", "content": prompt}],
                tools=None,
                system="",
                model=bare_model,
                max_tokens=1024,
            )
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("[image_gen] Provider call failed: %s", e)
            return ToolResult(success=False, error=f"Image generation failed: {e}")

        if not response.images:
            msg = response.text or "(no images returned)"
            logger.warning("[image_gen] No images in response. text=%r", msg)
            return ToolResult(
                success=False,
                error=f"Model did not return any images. Response: {msg[:300]}",
            )

        # Use the first image (most models return exactly one).
        # Assumes each entry is a str (data URL or bare base64) — TODO confirm
        # against the provider implementations.
        try:
            media_type, img_b64, img_bytes = self._decode_image(response.images[0])
        except ValueError as e:
            logger.warning("[image_gen] Could not decode image payload: %s", e)
            return ToolResult(
                success=False,
                error=f"Model returned image data that could not be decoded: {e}",
            )

        # NOTE(review): `model` is the full ID and `bare_model` the resolved
        # one, so this label may repeat the model name. Kept unchanged because
        # callers may depend on the current value.
        model_label = f"{model}/{bare_model}".strip("/")

        # Save to path if requested
        if save_path:
            from ..security import assert_path_allowed, SecurityError

            try:
                safe_path = await assert_path_allowed(save_path)
            except SecurityError as e:
                return ToolResult(success=False, error=str(e))

            try:
                safe_path.parent.mkdir(parents=True, exist_ok=True)
                safe_path.write_bytes(img_bytes)
            except PermissionError:
                return ToolResult(success=False, error=f"Permission denied: {safe_path}")
            except Exception as e:
                return ToolResult(success=False, error=f"Save failed: {e}")

            logger.info("[image_gen] Saved image to %s (%d bytes)", safe_path, len(img_bytes))
            return ToolResult(
                success=True,
                data={
                    "saved_to": str(safe_path),
                    "size_bytes": len(img_bytes),
                    "media_type": media_type,
                    "model": model_label,
                    "prompt": prompt,
                },
            )

        # Return base64 data (no save_path given)
        logger.info("[image_gen] Returning image data (%d bytes, %s)", len(img_bytes), media_type)
        return ToolResult(
            success=True,
            data={
                "is_image": True,
                "image_data": img_b64,
                "media_type": media_type,
                "size_bytes": len(img_bytes),
                "model": model_label,
                "prompt": prompt,
            },
        )
|
||||
Reference in New Issue
Block a user