Files
openclaw/src/agents/tools/image-generate-tool.ts

650 lines
21 KiB
TypeScript

import { Type } from "@sinclair/typebox";
import type { OpenClawConfig } from "../../config/config.js";
import { loadConfig } from "../../config/config.js";
import { parseImageGenerationModelRef } from "../../image-generation/model-ref.js";
import {
generateImage,
listRuntimeImageGenerationProviders,
} from "../../image-generation/runtime.js";
import type {
ImageGenerationProvider,
ImageGenerationResolution,
ImageGenerationSourceImage,
} from "../../image-generation/types.js";
import { getImageMetadata } from "../../media/image-ops.js";
import { saveMediaBuffer } from "../../media/store.js";
import { loadWebMedia } from "../../media/web-media.js";
import { getProviderEnvVars } from "../../secrets/provider-env-vars.js";
import { resolveUserPath } from "../../utils.js";
import { ToolInputError, readNumberParam, readStringParam } from "./common.js";
import { decodeDataUrl } from "./image-tool.helpers.js";
import {
applyImageGenerationModelConfigDefaults,
resolveMediaToolLocalRoots,
} from "./media-tool-shared.js";
import {
buildToolModelConfigFromCandidates,
coerceToolModelConfig,
hasToolModelConfig,
resolveDefaultModelRef,
type ToolModelConfig,
} from "./model-config.helpers.js";
import {
createSandboxBridgeReadFile,
resolveSandboxedBridgeMediaPath,
type AnyAgentTool,
type SandboxFsBridge,
type ToolFsPolicy,
} from "./tool-runtime.helpers.js";
const DEFAULT_COUNT = 1;
const MAX_COUNT = 4;
const MAX_INPUT_IMAGES = 5;
const DEFAULT_RESOLUTION: ImageGenerationResolution = "1K";
const SUPPORTED_ASPECT_RATIOS = new Set([
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"4:5",
"5:4",
"9:16",
"16:9",
"21:9",
]);
const ImageGenerateToolSchema = Type.Object({
action: Type.Optional(
Type.String({
description:
'Optional action: "generate" (default) or "list" to inspect available providers/models.',
}),
),
prompt: Type.Optional(Type.String({ description: "Image generation prompt." })),
image: Type.Optional(
Type.String({
description: "Optional reference image path or URL for edit mode.",
}),
),
images: Type.Optional(
Type.Array(Type.String(), {
description: `Optional reference images for edit mode (up to ${MAX_INPUT_IMAGES}).`,
}),
),
model: Type.Optional(
Type.String({ description: "Optional provider/model override, e.g. openai/gpt-image-1." }),
),
filename: Type.Optional(
Type.String({
description:
"Optional output filename hint. OpenClaw preserves the basename and saves under its managed media directory.",
}),
),
size: Type.Optional(
Type.String({
description:
"Optional size hint like 1024x1024, 1536x1024, 1024x1536, 1024x1792, or 1792x1024.",
}),
),
aspectRatio: Type.Optional(
Type.String({
description:
"Optional aspect ratio hint: 1:1, 2:3, 3:2, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, or 21:9.",
}),
),
resolution: Type.Optional(
Type.String({
description:
"Optional resolution hint: 1K, 2K, or 4K. Useful for Google edit/generation flows.",
}),
),
count: Type.Optional(
Type.Number({
description: `Optional number of images to request (1-${MAX_COUNT}).`,
minimum: 1,
maximum: MAX_COUNT,
}),
),
});
function getImageGenerationProviderAuthEnvVars(providerId: string): string[] {
return getProviderEnvVars(providerId);
}
function resolveImageGenerationModelCandidates(
cfg: OpenClawConfig | undefined,
): Array<string | undefined> {
const providerDefaults = new Map<string, string>();
for (const provider of listRuntimeImageGenerationProviders({ config: cfg })) {
const providerId = provider.id.trim();
const modelId = provider.defaultModel?.trim();
if (!providerId || !modelId || providerDefaults.has(providerId)) {
continue;
}
providerDefaults.set(providerId, `${providerId}/${modelId}`);
}
const orderedProviders = [
resolveDefaultModelRef(cfg).provider,
"openai",
"google",
...providerDefaults.keys(),
];
const orderedRefs: string[] = [];
const seen = new Set<string>();
for (const providerId of orderedProviders) {
const ref = providerDefaults.get(providerId);
if (!ref || seen.has(ref)) {
continue;
}
seen.add(ref);
orderedRefs.push(ref);
}
return orderedRefs;
}
export function resolveImageGenerationModelConfigForTool(params: {
cfg?: OpenClawConfig;
agentDir?: string;
}): ToolModelConfig | null {
const explicit = coerceToolModelConfig(params.cfg?.agents?.defaults?.imageGenerationModel);
if (hasToolModelConfig(explicit)) {
return explicit;
}
return buildToolModelConfigFromCandidates({
explicit,
agentDir: params.agentDir,
candidates: resolveImageGenerationModelCandidates(params.cfg),
});
}
function resolveAction(args: Record<string, unknown>): "generate" | "list" {
const raw = readStringParam(args, "action");
if (!raw) {
return "generate";
}
const normalized = raw.trim().toLowerCase();
if (normalized === "generate" || normalized === "list") {
return normalized;
}
throw new ToolInputError('action must be "generate" or "list"');
}
function resolveRequestedCount(args: Record<string, unknown>): number {
const count = readNumberParam(args, "count", { integer: true });
if (count === undefined) {
return DEFAULT_COUNT;
}
if (count < 1 || count > MAX_COUNT) {
throw new ToolInputError(`count must be between 1 and ${MAX_COUNT}`);
}
return count;
}
function normalizeResolution(raw: string | undefined): ImageGenerationResolution | undefined {
const normalized = raw?.trim().toUpperCase();
if (!normalized) {
return undefined;
}
if (normalized === "1K" || normalized === "2K" || normalized === "4K") {
return normalized;
}
throw new ToolInputError("resolution must be one of 1K, 2K, or 4K");
}
function normalizeAspectRatio(raw: string | undefined): string | undefined {
const normalized = raw?.trim();
if (!normalized) {
return undefined;
}
if (SUPPORTED_ASPECT_RATIOS.has(normalized)) {
return normalized;
}
throw new ToolInputError(
"aspectRatio must be one of 1:1, 2:3, 3:2, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, or 21:9",
);
}
function normalizeReferenceImages(args: Record<string, unknown>): string[] {
const imageCandidates: string[] = [];
if (typeof args.image === "string") {
imageCandidates.push(args.image);
}
if (Array.isArray(args.images)) {
imageCandidates.push(
...args.images.filter((value): value is string => typeof value === "string"),
);
}
const seen = new Set<string>();
const normalized: string[] = [];
for (const candidate of imageCandidates) {
const trimmed = candidate.trim();
const dedupe = trimmed.startsWith("@") ? trimmed.slice(1).trim() : trimmed;
if (!dedupe || seen.has(dedupe)) {
continue;
}
seen.add(dedupe);
normalized.push(trimmed);
}
if (normalized.length > MAX_INPUT_IMAGES) {
throw new ToolInputError(
`Too many reference images: ${normalized.length} provided, maximum is ${MAX_INPUT_IMAGES}.`,
);
}
return normalized;
}
function resolveSelectedImageGenerationProvider(params: {
config?: OpenClawConfig;
imageGenerationModelConfig: ToolModelConfig;
modelOverride?: string;
}): ImageGenerationProvider | undefined {
const selectedRef =
parseImageGenerationModelRef(params.modelOverride) ??
parseImageGenerationModelRef(params.imageGenerationModelConfig.primary);
if (!selectedRef) {
return undefined;
}
return listRuntimeImageGenerationProviders({ config: params.config }).find(
(provider) =>
provider.id === selectedRef.provider ||
(provider.aliases ?? []).includes(selectedRef.provider),
);
}
function validateImageGenerationCapabilities(params: {
provider: ImageGenerationProvider | undefined;
count: number;
inputImageCount: number;
size?: string;
aspectRatio?: string;
resolution?: ImageGenerationResolution;
}) {
const provider = params.provider;
if (!provider) {
return;
}
const isEdit = params.inputImageCount > 0;
const modeCaps = isEdit ? provider.capabilities.edit : provider.capabilities.generate;
const geometry = provider.capabilities.geometry;
const maxCount = modeCaps.maxCount ?? MAX_COUNT;
if (params.count > maxCount) {
throw new ToolInputError(
`${provider.id} ${isEdit ? "edit" : "generate"} supports at most ${maxCount} output image${maxCount === 1 ? "" : "s"}.`,
);
}
if (isEdit) {
if (!provider.capabilities.edit.enabled) {
throw new ToolInputError(`${provider.id} does not support reference-image edits.`);
}
const maxInputImages = provider.capabilities.edit.maxInputImages ?? MAX_INPUT_IMAGES;
if (params.inputImageCount > maxInputImages) {
throw new ToolInputError(
`${provider.id} edit supports at most ${maxInputImages} reference image${maxInputImages === 1 ? "" : "s"}.`,
);
}
}
if (params.size) {
if (!modeCaps.supportsSize) {
throw new ToolInputError(
`${provider.id} ${isEdit ? "edit" : "generate"} does not support size overrides.`,
);
}
if ((geometry?.sizes?.length ?? 0) > 0 && !geometry?.sizes?.includes(params.size)) {
throw new ToolInputError(
`${provider.id} ${isEdit ? "edit" : "generate"} size must be one of ${geometry?.sizes?.join(", ")}.`,
);
}
}
if (params.aspectRatio) {
if (!modeCaps.supportsAspectRatio) {
throw new ToolInputError(
`${provider.id} ${isEdit ? "edit" : "generate"} does not support aspectRatio overrides.`,
);
}
if (
(geometry?.aspectRatios?.length ?? 0) > 0 &&
!geometry?.aspectRatios?.includes(params.aspectRatio)
) {
throw new ToolInputError(
`${provider.id} ${isEdit ? "edit" : "generate"} aspectRatio must be one of ${geometry?.aspectRatios?.join(", ")}.`,
);
}
}
if (params.resolution) {
if (!modeCaps.supportsResolution) {
throw new ToolInputError(
`${provider.id} ${isEdit ? "edit" : "generate"} does not support resolution overrides.`,
);
}
if (
(geometry?.resolutions?.length ?? 0) > 0 &&
!geometry?.resolutions?.includes(params.resolution)
) {
throw new ToolInputError(
`${provider.id} ${isEdit ? "edit" : "generate"} resolution must be one of ${geometry?.resolutions?.join("/")}.`,
);
}
}
}
type ImageGenerateSandboxConfig = {
root: string;
bridge: SandboxFsBridge;
};
async function loadReferenceImages(params: {
imageInputs: string[];
maxBytes?: number;
workspaceDir?: string;
sandboxConfig: { root: string; bridge: SandboxFsBridge; workspaceOnly: boolean } | null;
}): Promise<
Array<{
sourceImage: ImageGenerationSourceImage;
resolvedImage: string;
rewrittenFrom?: string;
}>
> {
const loaded: Array<{
sourceImage: ImageGenerationSourceImage;
resolvedImage: string;
rewrittenFrom?: string;
}> = [];
for (const imageRawInput of params.imageInputs) {
const trimmed = imageRawInput.trim();
const imageRaw = trimmed.startsWith("@") ? trimmed.slice(1).trim() : trimmed;
if (!imageRaw) {
throw new ToolInputError("image required (empty string in array)");
}
const looksLikeWindowsDrivePath = /^[a-zA-Z]:[\\/]/.test(imageRaw);
const hasScheme = /^[a-z][a-z0-9+.-]*:/i.test(imageRaw);
const isFileUrl = /^file:/i.test(imageRaw);
const isHttpUrl = /^https?:\/\//i.test(imageRaw);
const isDataUrl = /^data:/i.test(imageRaw);
if (hasScheme && !looksLikeWindowsDrivePath && !isFileUrl && !isHttpUrl && !isDataUrl) {
throw new ToolInputError(
`Unsupported image reference: ${imageRawInput}. Use a file path, a file:// URL, a data: URL, or an http(s) URL.`,
);
}
if (params.sandboxConfig && isHttpUrl) {
throw new ToolInputError("Sandboxed image_generate does not allow remote URLs.");
}
const resolvedImage = (() => {
if (params.sandboxConfig) {
return imageRaw;
}
if (imageRaw.startsWith("~")) {
return resolveUserPath(imageRaw);
}
return imageRaw;
})();
const resolvedPathInfo: { resolved: string; rewrittenFrom?: string } = isDataUrl
? { resolved: "" }
: params.sandboxConfig
? await resolveSandboxedBridgeMediaPath({
sandbox: params.sandboxConfig,
mediaPath: resolvedImage,
inboundFallbackDir: "media/inbound",
})
: {
resolved: resolvedImage.startsWith("file://")
? resolvedImage.slice("file://".length)
: resolvedImage,
};
const resolvedPath = isDataUrl ? null : resolvedPathInfo.resolved;
const localRoots = resolveMediaToolLocalRoots(
params.workspaceDir,
{
workspaceOnly: params.sandboxConfig?.workspaceOnly === true,
},
resolvedPath ? [resolvedPath] : undefined,
);
const media = isDataUrl
? decodeDataUrl(resolvedImage)
: params.sandboxConfig
? await loadWebMedia(resolvedPath ?? resolvedImage, {
maxBytes: params.maxBytes,
sandboxValidated: true,
readFile: createSandboxBridgeReadFile({ sandbox: params.sandboxConfig }),
})
: await loadWebMedia(resolvedPath ?? resolvedImage, {
maxBytes: params.maxBytes,
localRoots,
});
if (media.kind !== "image") {
throw new ToolInputError(`Unsupported media type: ${media.kind}`);
}
const mimeType =
("contentType" in media && media.contentType) ||
("mimeType" in media && media.mimeType) ||
"image/png";
loaded.push({
sourceImage: {
buffer: media.buffer,
mimeType,
},
resolvedImage,
...(resolvedPathInfo.rewrittenFrom ? { rewrittenFrom: resolvedPathInfo.rewrittenFrom } : {}),
});
}
return loaded;
}
async function inferResolutionFromInputImages(
images: ImageGenerationSourceImage[],
): Promise<ImageGenerationResolution> {
let maxDimension = 0;
for (const image of images) {
const meta = await getImageMetadata(image.buffer);
const dimension = Math.max(meta?.width ?? 0, meta?.height ?? 0);
maxDimension = Math.max(maxDimension, dimension);
}
if (maxDimension >= 3000) {
return "4K";
}
if (maxDimension >= 1500) {
return "2K";
}
return DEFAULT_RESOLUTION;
}
export function createImageGenerateTool(options?: {
config?: OpenClawConfig;
agentDir?: string;
workspaceDir?: string;
sandbox?: ImageGenerateSandboxConfig;
fsPolicy?: ToolFsPolicy;
}): AnyAgentTool | null {
const cfg = options?.config ?? loadConfig();
const imageGenerationModelConfig = resolveImageGenerationModelConfigForTool({
cfg,
agentDir: options?.agentDir,
});
if (!imageGenerationModelConfig) {
return null;
}
const effectiveCfg =
applyImageGenerationModelConfigDefaults(cfg, imageGenerationModelConfig) ?? cfg;
const sandboxConfig =
options?.sandbox && options.sandbox.root.trim()
? {
root: options.sandbox.root.trim(),
bridge: options.sandbox.bridge,
workspaceOnly: options.fsPolicy?.workspaceOnly === true,
}
: null;
return {
label: "Image Generation",
name: "image_generate",
description:
'Generate new images or edit reference images with the configured or inferred image-generation model. Set agents.defaults.imageGenerationModel.primary to pick a provider/model. If you want openai/*, google/*, fal/*, or another provider, configure that provider auth/API key first. Use action="list" to inspect available providers, models, and auth hints. Generated images are delivered automatically from the tool result as MEDIA paths.',
parameters: ImageGenerateToolSchema,
execute: async (_toolCallId, args) => {
const params = args as Record<string, unknown>;
const action = resolveAction(params);
if (action === "list") {
const providers = listRuntimeImageGenerationProviders({ config: effectiveCfg }).map(
(provider) => ({
id: provider.id,
...(provider.label ? { label: provider.label } : {}),
...(provider.defaultModel ? { defaultModel: provider.defaultModel } : {}),
models: provider.models ?? (provider.defaultModel ? [provider.defaultModel] : []),
authEnvVars: getImageGenerationProviderAuthEnvVars(provider.id),
capabilities: provider.capabilities,
}),
);
const lines = providers.flatMap((provider) => {
const caps: string[] = [];
if (provider.capabilities.edit.enabled) {
const maxRefs = provider.capabilities.edit.maxInputImages;
caps.push(
`editing${typeof maxRefs === "number" ? ` up to ${maxRefs} ref${maxRefs === 1 ? "" : "s"}` : ""}`,
);
}
if ((provider.capabilities.geometry?.resolutions?.length ?? 0) > 0) {
caps.push(`resolutions ${provider.capabilities.geometry?.resolutions?.join("/")}`);
}
if ((provider.capabilities.geometry?.sizes?.length ?? 0) > 0) {
caps.push(`sizes ${provider.capabilities.geometry?.sizes?.join(", ")}`);
}
if ((provider.capabilities.geometry?.aspectRatios?.length ?? 0) > 0) {
caps.push(`aspect ratios ${provider.capabilities.geometry?.aspectRatios?.join(", ")}`);
}
const modelLine =
provider.models.length > 0
? `models: ${provider.models.join(", ")}`
: "models: unknown";
return [
`${provider.id}${provider.defaultModel ? ` (default ${provider.defaultModel})` : ""}`,
` ${modelLine}`,
...(provider.authEnvVars.length > 0
? [` auth: set ${provider.authEnvVars.join(" / ")} to use ${provider.id}/*`]
: []),
...(caps.length > 0 ? [` capabilities: ${caps.join("; ")}`] : []),
];
});
return {
content: [{ type: "text", text: lines.join("\n") }],
details: { providers },
};
}
const prompt = readStringParam(params, "prompt", { required: true });
const imageInputs = normalizeReferenceImages(params);
const model = readStringParam(params, "model");
const filename = readStringParam(params, "filename");
const size = readStringParam(params, "size");
const aspectRatio = normalizeAspectRatio(readStringParam(params, "aspectRatio"));
const explicitResolution = normalizeResolution(readStringParam(params, "resolution"));
const count = resolveRequestedCount(params);
const loadedReferenceImages = await loadReferenceImages({
imageInputs,
workspaceDir: options?.workspaceDir,
sandboxConfig,
});
const inputImages = loadedReferenceImages.map((entry) => entry.sourceImage);
const resolution =
explicitResolution ??
(size
? undefined
: inputImages.length > 0
? await inferResolutionFromInputImages(inputImages)
: undefined);
const selectedProvider = resolveSelectedImageGenerationProvider({
config: effectiveCfg,
imageGenerationModelConfig,
modelOverride: model,
});
validateImageGenerationCapabilities({
provider: selectedProvider,
count,
inputImageCount: inputImages.length,
size,
aspectRatio,
resolution,
});
const result = await generateImage({
cfg: effectiveCfg,
prompt,
agentDir: options?.agentDir,
modelOverride: model,
size,
aspectRatio,
resolution,
count,
inputImages,
});
const savedImages = await Promise.all(
result.images.map((image) =>
saveMediaBuffer(
image.buffer,
image.mimeType,
"tool-image-generation",
undefined,
filename || image.fileName,
),
),
);
const revisedPrompts = result.images
.map((image) => image.revisedPrompt?.trim())
.filter((entry): entry is string => Boolean(entry));
const lines = [
`Generated ${savedImages.length} image${savedImages.length === 1 ? "" : "s"} with ${result.provider}/${result.model}.`,
];
return {
content: [{ type: "text", text: lines.join("\n") }],
details: {
provider: result.provider,
model: result.model,
count: savedImages.length,
media: {
mediaUrls: savedImages.map((image) => image.path),
},
paths: savedImages.map((image) => image.path),
...(imageInputs.length === 1
? {
image: loadedReferenceImages[0]?.resolvedImage,
...(loadedReferenceImages[0]?.rewrittenFrom
? { rewrittenFrom: loadedReferenceImages[0].rewrittenFrom }
: {}),
}
: imageInputs.length > 1
? {
images: loadedReferenceImages.map((entry) => ({
image: entry.resolvedImage,
...(entry.rewrittenFrom ? { rewrittenFrom: entry.rewrittenFrom } : {}),
})),
}
: {}),
...(resolution ? { resolution } : {}),
...(size ? { size } : {}),
...(aspectRatio ? { aspectRatio } : {}),
...(filename ? { filename } : {}),
attempts: result.attempts,
metadata: result.metadata,
...(revisedPrompts.length > 0 ? { revisedPrompts } : {}),
},
};
},
};
}