refactor: move media generation runtimes into core

This commit is contained in:
Peter Steinberger
2026-04-05 15:13:08 +01:00
parent 5da21bc2f7
commit 9f2b760d33
20 changed files with 1062 additions and 179 deletions

View File

@@ -1,37 +1,111 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import { beforeEach, describe, expect, it, vi } from "vitest";
import type { OpenClawConfig } from "../config/config.js";
import type { ImageGenerationProvider } from "../image-generation/types.js";
import {
generateImage,
listRuntimeImageGenerationProviders,
type GenerateImageRuntimeResult,
} from "./runtime.js";
import { generateImage, listRuntimeImageGenerationProviders } from "./runtime.js";
import type { ImageGenerationProvider } from "./types.js";
const mocks = vi.hoisted(() => ({
generateImage: vi.fn<typeof generateImage>(),
listRuntimeImageGenerationProviders: vi.fn<typeof listRuntimeImageGenerationProviders>(),
const mocks = vi.hoisted(() => {
const debug = vi.fn();
return {
createSubsystemLogger: vi.fn(() => ({ debug })),
describeFailoverError: vi.fn(),
getImageGenerationProvider: vi.fn<
(providerId: string, config?: OpenClawConfig) => ImageGenerationProvider | undefined
>(() => undefined),
getProviderEnvVars: vi.fn<(providerId: string) => string[]>(() => []),
isFailoverError: vi.fn<(err: unknown) => boolean>(() => false),
listImageGenerationProviders: vi.fn<(config?: OpenClawConfig) => ImageGenerationProvider[]>(
() => [],
),
parseImageGenerationModelRef: vi.fn<
(raw?: string) => { provider: string; model: string } | undefined
>((raw?: string) => {
const trimmed = raw?.trim();
if (!trimmed) {
return undefined;
}
const slash = trimmed.indexOf("/");
if (slash <= 0 || slash === trimmed.length - 1) {
return undefined;
}
return {
provider: trimmed.slice(0, slash),
model: trimmed.slice(slash + 1),
};
}),
resolveAgentModelFallbackValues: vi.fn<(value: unknown) => string[]>(() => []),
resolveAgentModelPrimaryValue: vi.fn<(value: unknown) => string | undefined>(() => undefined),
debug,
};
});
vi.mock("../agents/failover-error.js", () => ({
describeFailoverError: mocks.describeFailoverError,
isFailoverError: mocks.isFailoverError,
}));
vi.mock("../config/model-input.js", () => ({
resolveAgentModelFallbackValues: mocks.resolveAgentModelFallbackValues,
resolveAgentModelPrimaryValue: mocks.resolveAgentModelPrimaryValue,
}));
vi.mock("../logging/subsystem.js", () => ({
createSubsystemLogger: mocks.createSubsystemLogger,
}));
vi.mock("../secrets/provider-env-vars.js", () => ({
getProviderEnvVars: mocks.getProviderEnvVars,
}));
vi.mock("./model-ref.js", () => ({
parseImageGenerationModelRef: mocks.parseImageGenerationModelRef,
}));
vi.mock("./provider-registry.js", () => ({
getImageGenerationProvider: mocks.getImageGenerationProvider,
listImageGenerationProviders: mocks.listImageGenerationProviders,
}));
vi.mock("../../extensions/image-generation-core/runtime-api.js", () => ({
generateImage: mocks.generateImage,
listRuntimeImageGenerationProviders: mocks.listRuntimeImageGenerationProviders,
}));
describe("image-generation runtime facade", () => {
afterEach(() => {
mocks.generateImage.mockReset();
mocks.listRuntimeImageGenerationProviders.mockReset();
describe("image-generation runtime", () => {
beforeEach(() => {
mocks.createSubsystemLogger.mockClear();
mocks.describeFailoverError.mockReset();
mocks.getImageGenerationProvider.mockReset();
mocks.getProviderEnvVars.mockReset();
mocks.getProviderEnvVars.mockReturnValue([]);
mocks.isFailoverError.mockReset();
mocks.isFailoverError.mockReturnValue(false);
mocks.listImageGenerationProviders.mockReset();
mocks.listImageGenerationProviders.mockReturnValue([]);
mocks.parseImageGenerationModelRef.mockClear();
mocks.resolveAgentModelFallbackValues.mockReset();
mocks.resolveAgentModelFallbackValues.mockReturnValue([]);
mocks.resolveAgentModelPrimaryValue.mockReset();
mocks.resolveAgentModelPrimaryValue.mockReturnValue(undefined);
mocks.debug.mockReset();
});
it("delegates image generation to the image runtime", async () => {
const result: GenerateImageRuntimeResult = {
images: [{ buffer: Buffer.from("png-bytes"), mimeType: "image/png", fileName: "sample.png" }],
provider: "image-plugin",
model: "img-v1",
attempts: [],
it("generates images through the active image-generation provider", async () => {
const authStore = { version: 1, profiles: {} } as const;
let seenAuthStore: unknown;
mocks.resolveAgentModelPrimaryValue.mockReturnValue("image-plugin/img-v1");
const provider: ImageGenerationProvider = {
id: "image-plugin",
capabilities: {
generate: {},
edit: { enabled: false },
},
async generateImage(req: { authStore?: unknown }) {
seenAuthStore = req.authStore;
return {
images: [
{
buffer: Buffer.from("png-bytes"),
mimeType: "image/png",
fileName: "sample.png",
},
],
model: "img-v1",
};
},
};
mocks.generateImage.mockResolvedValue(result);
const params = {
mocks.getImageGenerationProvider.mockReturnValue(provider);
const result = await generateImage({
cfg: {
agents: {
defaults: {
@@ -41,19 +115,58 @@ describe("image-generation runtime facade", () => {
} as OpenClawConfig,
prompt: "draw a cat",
agentDir: "/tmp/agent",
authStore: { version: 1, profiles: {} },
};
authStore,
});
await expect(generateImage(params)).resolves.toBe(result);
expect(mocks.generateImage).toHaveBeenCalledWith(params);
expect(result.provider).toBe("image-plugin");
expect(result.model).toBe("img-v1");
expect(result.attempts).toEqual([]);
expect(seenAuthStore).toEqual(authStore);
expect(result.images).toEqual([
{
buffer: Buffer.from("png-bytes"),
mimeType: "image/png",
fileName: "sample.png",
},
]);
});
it("delegates provider listing to the image runtime", () => {
it("lists runtime image-generation providers through the provider registry", () => {
const providers: ImageGenerationProvider[] = [
{
id: "image-plugin",
defaultModel: "img-v1",
models: ["img-v1", "img-v2"],
capabilities: {
generate: {
supportsResolution: true,
},
edit: {
enabled: true,
maxInputImages: 3,
},
geometry: {
resolutions: ["1K", "2K"],
},
},
generateImage: async () => ({
images: [{ buffer: Buffer.from("png-bytes"), mimeType: "image/png" }],
}),
},
];
mocks.listImageGenerationProviders.mockReturnValue(providers);
expect(listRuntimeImageGenerationProviders({ config: {} as OpenClawConfig })).toEqual(
providers,
);
expect(mocks.listImageGenerationProviders).toHaveBeenCalledWith({} as OpenClawConfig);
});
it("builds a generic config hint without hardcoded provider ids", async () => {
mocks.listImageGenerationProviders.mockReturnValue([
{
id: "vision-one",
defaultModel: "paint-v1",
capabilities: {
generate: {},
edit: { enabled: false },
@@ -62,11 +175,35 @@ describe("image-generation runtime facade", () => {
images: [{ buffer: Buffer.from("png-bytes"), mimeType: "image/png" }],
}),
},
];
mocks.listRuntimeImageGenerationProviders.mockReturnValue(providers);
const params = { config: {} as OpenClawConfig };
{
id: "vision-two",
defaultModel: "paint-v2",
capabilities: {
generate: {},
edit: { enabled: false },
},
generateImage: async () => ({
images: [{ buffer: Buffer.from("png-bytes"), mimeType: "image/png" }],
}),
},
]);
mocks.getProviderEnvVars.mockImplementation((providerId: string) => {
if (providerId === "vision-one") {
return ["VISION_ONE_API_KEY"];
}
if (providerId === "vision-two") {
return ["VISION_TWO_API_KEY"];
}
return [];
});
expect(listRuntimeImageGenerationProviders(params)).toBe(providers);
expect(mocks.listRuntimeImageGenerationProviders).toHaveBeenCalledWith(params);
const promise = generateImage({ cfg: {} as OpenClawConfig, prompt: "draw a cat" });
await expect(promise).rejects.toThrow("No image-generation model configured.");
await expect(promise).rejects.toThrow(
'Set agents.defaults.imageGenerationModel.primary to a provider/model like "vision-one/paint-v1".',
);
await expect(promise).rejects.toThrow("vision-one: VISION_ONE_API_KEY");
await expect(promise).rejects.toThrow("vision-two: VISION_TWO_API_KEY");
});
});

View File

@@ -1,6 +1,186 @@
export {
generateImage,
listRuntimeImageGenerationProviders,
type GenerateImageParams,
type GenerateImageRuntimeResult,
} from "../../extensions/image-generation-core/runtime-api.js";
import type { AuthProfileStore } from "../agents/auth-profiles.js";
import { describeFailoverError, isFailoverError } from "../agents/failover-error.js";
import type { FallbackAttempt } from "../agents/model-fallback.types.js";
import type { OpenClawConfig } from "../config/config.js";
import {
resolveAgentModelFallbackValues,
resolveAgentModelPrimaryValue,
} from "../config/model-input.js";
import { createSubsystemLogger } from "../logging/subsystem.js";
import { getProviderEnvVars } from "../secrets/provider-env-vars.js";
import { parseImageGenerationModelRef } from "./model-ref.js";
import { getImageGenerationProvider, listImageGenerationProviders } from "./provider-registry.js";
import type {
GeneratedImageAsset,
ImageGenerationResolution,
ImageGenerationResult,
ImageGenerationSourceImage,
} from "./types.js";
const log = createSubsystemLogger("image-generation");
export type GenerateImageParams = {
cfg: OpenClawConfig;
prompt: string;
agentDir?: string;
authStore?: AuthProfileStore;
modelOverride?: string;
count?: number;
size?: string;
aspectRatio?: string;
resolution?: ImageGenerationResolution;
inputImages?: ImageGenerationSourceImage[];
};
export type GenerateImageRuntimeResult = {
images: GeneratedImageAsset[];
provider: string;
model: string;
attempts: FallbackAttempt[];
metadata?: Record<string, unknown>;
};
function resolveImageGenerationCandidates(params: {
cfg: OpenClawConfig;
modelOverride?: string;
}): Array<{ provider: string; model: string }> {
const candidates: Array<{ provider: string; model: string }> = [];
const seen = new Set<string>();
const add = (raw: string | undefined) => {
const parsed = parseImageGenerationModelRef(raw);
if (!parsed) {
return;
}
const key = `${parsed.provider}/${parsed.model}`;
if (seen.has(key)) {
return;
}
seen.add(key);
candidates.push(parsed);
};
add(params.modelOverride);
add(resolveAgentModelPrimaryValue(params.cfg.agents?.defaults?.imageGenerationModel));
for (const fallback of resolveAgentModelFallbackValues(
params.cfg.agents?.defaults?.imageGenerationModel,
)) {
add(fallback);
}
return candidates;
}
function throwImageGenerationFailure(params: {
attempts: FallbackAttempt[];
lastError: unknown;
}): never {
if (params.attempts.length <= 1 && params.lastError) {
throw params.lastError;
}
const summary =
params.attempts.length > 0
? params.attempts
.map((attempt) => `${attempt.provider}/${attempt.model}: ${attempt.error}`)
.join(" | ")
: "unknown";
throw new Error(`All image generation models failed (${params.attempts.length}): ${summary}`, {
cause: params.lastError instanceof Error ? params.lastError : undefined,
});
}
function buildNoImageGenerationModelConfiguredMessage(cfg: OpenClawConfig): string {
const providers = listImageGenerationProviders(cfg);
const sampleModel = providers.find(
(provider) => provider.id.trim().length > 0 && provider.defaultModel?.trim(),
);
const sampleRef = sampleModel
? `${sampleModel.id}/${sampleModel.defaultModel}`
: "<provider>/<model>";
const authHints = providers
.flatMap((provider) => {
const envVars = getProviderEnvVars(provider.id);
if (envVars.length === 0) {
return [];
}
return [`${provider.id}: ${envVars.join(" / ")}`];
})
.slice(0, 3);
return [
`No image-generation model configured. Set agents.defaults.imageGenerationModel.primary to a provider/model like "${sampleRef}".`,
authHints.length > 0
? `If you want a specific provider, also configure that provider's auth/API key first (${authHints.join("; ")}).`
: "If you want a specific provider, also configure that provider's auth/API key first.",
].join(" ");
}
export function listRuntimeImageGenerationProviders(params?: { config?: OpenClawConfig }) {
return listImageGenerationProviders(params?.config);
}
export async function generateImage(
params: GenerateImageParams,
): Promise<GenerateImageRuntimeResult> {
const candidates = resolveImageGenerationCandidates({
cfg: params.cfg,
modelOverride: params.modelOverride,
});
if (candidates.length === 0) {
throw new Error(buildNoImageGenerationModelConfiguredMessage(params.cfg));
}
const attempts: FallbackAttempt[] = [];
let lastError: unknown;
for (const candidate of candidates) {
const provider = getImageGenerationProvider(candidate.provider, params.cfg);
if (!provider) {
const error = `No image-generation provider registered for ${candidate.provider}`;
attempts.push({
provider: candidate.provider,
model: candidate.model,
error,
});
lastError = new Error(error);
continue;
}
try {
const result: ImageGenerationResult = await provider.generateImage({
provider: candidate.provider,
model: candidate.model,
prompt: params.prompt,
cfg: params.cfg,
agentDir: params.agentDir,
authStore: params.authStore,
count: params.count,
size: params.size,
aspectRatio: params.aspectRatio,
resolution: params.resolution,
inputImages: params.inputImages,
});
if (!Array.isArray(result.images) || result.images.length === 0) {
throw new Error("Image generation provider returned no images.");
}
return {
images: result.images,
provider: candidate.provider,
model: result.model ?? candidate.model,
attempts,
metadata: result.metadata,
};
} catch (err) {
lastError = err;
const described = isFailoverError(err) ? describeFailoverError(err) : undefined;
attempts.push({
provider: candidate.provider,
model: candidate.model,
error: described?.message ?? (err instanceof Error ? err.message : String(err)),
reason: described?.reason,
status: described?.status,
code: described?.code,
});
log.debug(`image-generation candidate failed: ${candidate.provider}/${candidate.model}`);
}
}
throwImageGenerationFailure({ attempts, lastError });
}

View File

@@ -18,6 +18,11 @@ export type ImageGenerationSourceImage = {
metadata?: Record<string, unknown>;
};
export type ImageGenerationProviderConfiguredContext = {
cfg?: OpenClawConfig;
agentDir?: string;
};
export type ImageGenerationRequest = {
provider: string;
model: string;
@@ -70,5 +75,6 @@ export type ImageGenerationProvider = {
defaultModel?: string;
models?: string[];
capabilities: ImageGenerationProviderCapabilities;
isConfigured?: (ctx: ImageGenerationProviderConfiguredContext) => boolean;
generateImage: (req: ImageGenerationRequest) => Promise<ImageGenerationResult>;
};