mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-06 18:20:44 +00:00
refactor: share OpenAI-compatible image provider
This commit is contained in:
@@ -7,11 +7,17 @@ const {
|
||||
postMultipartRequestMock,
|
||||
resolveApiKeyForProviderMock,
|
||||
resolveProviderHttpRequestConfigMock,
|
||||
createProviderOperationDeadlineMock,
|
||||
resolveProviderOperationTimeoutMsMock,
|
||||
} = vi.hoisted(() => ({
|
||||
assertOkOrThrowHttpErrorMock: vi.fn(async () => {}),
|
||||
postJsonRequestMock: vi.fn(),
|
||||
postMultipartRequestMock: vi.fn(),
|
||||
resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "deepinfra-key" })),
|
||||
createProviderOperationDeadlineMock: vi.fn((params: Record<string, unknown>) => params),
|
||||
resolveProviderOperationTimeoutMsMock: vi.fn(
|
||||
(params: Record<string, unknown>) => params.defaultTimeoutMs,
|
||||
),
|
||||
resolveProviderHttpRequestConfigMock: vi.fn((params: Record<string, unknown>) => ({
|
||||
baseUrl: params.baseUrl ?? params.defaultBaseUrl ?? "https://api.deepinfra.com/v1/openai",
|
||||
allowPrivateNetwork: false,
|
||||
@@ -26,9 +32,11 @@ vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/provider-http", () => ({
|
||||
assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock,
|
||||
createProviderOperationDeadline: createProviderOperationDeadlineMock,
|
||||
postJsonRequest: postJsonRequestMock,
|
||||
postMultipartRequest: postMultipartRequestMock,
|
||||
resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock,
|
||||
resolveProviderOperationTimeoutMs: resolveProviderOperationTimeoutMsMock,
|
||||
sanitizeConfiguredModelProviderRequest: vi.fn((request) => request),
|
||||
}));
|
||||
|
||||
|
||||
@@ -1,18 +1,8 @@
|
||||
import type { OpenClawConfig } from "openclaw/plugin-sdk/config-types";
|
||||
import type { ImageGenerationProvider } from "openclaw/plugin-sdk/image-generation";
|
||||
import {
|
||||
createOpenAiCompatibleImageGenerationProvider,
|
||||
imageSourceUploadFileName,
|
||||
parseOpenAiCompatibleImageResponse,
|
||||
type ImageGenerationProvider,
|
||||
} from "openclaw/plugin-sdk/image-generation";
|
||||
import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth";
|
||||
import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime";
|
||||
import {
|
||||
assertOkOrThrowHttpError,
|
||||
postJsonRequest,
|
||||
postMultipartRequest,
|
||||
resolveProviderHttpRequestConfig,
|
||||
sanitizeConfiguredModelProviderRequest,
|
||||
} from "openclaw/plugin-sdk/provider-http";
|
||||
import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
|
||||
import {
|
||||
DEEPINFRA_BASE_URL,
|
||||
@@ -26,35 +16,12 @@ import {
|
||||
const DEEPINFRA_IMAGE_SIZES = ["512x512", "1024x1024", "1024x1792", "1792x1024"] as const;
|
||||
const MAX_DEEPINFRA_INPUT_IMAGES = 1;
|
||||
|
||||
type DeepInfraProviderConfig = NonNullable<
|
||||
NonNullable<OpenClawConfig["models"]>["providers"]
|
||||
>[string];
|
||||
|
||||
type DeepInfraImageApiResponse = {
|
||||
data?: Array<{
|
||||
b64_json?: string;
|
||||
revised_prompt?: string;
|
||||
url?: string;
|
||||
}>;
|
||||
};
|
||||
|
||||
function resolveDeepInfraProviderConfig(
|
||||
cfg: OpenClawConfig | undefined,
|
||||
): DeepInfraProviderConfig | undefined {
|
||||
return cfg?.models?.providers?.deepinfra;
|
||||
}
|
||||
|
||||
export function buildDeepInfraImageGenerationProvider(): ImageGenerationProvider {
|
||||
return {
|
||||
return createOpenAiCompatibleImageGenerationProvider({
|
||||
id: "deepinfra",
|
||||
label: "DeepInfra",
|
||||
defaultModel: DEFAULT_DEEPINFRA_IMAGE_MODEL,
|
||||
models: [...DEEPINFRA_IMAGE_MODELS],
|
||||
isConfigured: ({ agentDir }) =>
|
||||
isProviderApiKeyConfigured({
|
||||
provider: "deepinfra",
|
||||
agentDir,
|
||||
}),
|
||||
capabilities: {
|
||||
generate: {
|
||||
maxCount: 4,
|
||||
@@ -74,111 +41,49 @@ export function buildDeepInfraImageGenerationProvider(): ImageGenerationProvider
|
||||
sizes: [...DEEPINFRA_IMAGE_SIZES],
|
||||
},
|
||||
},
|
||||
async generateImage(req) {
|
||||
const inputImages = req.inputImages ?? [];
|
||||
const isEdit = inputImages.length > 0;
|
||||
if (inputImages.length > MAX_DEEPINFRA_INPUT_IMAGES) {
|
||||
throw new Error("DeepInfra image editing supports one reference image.");
|
||||
defaultBaseUrl: DEEPINFRA_BASE_URL,
|
||||
normalizeModel: normalizeDeepInfraModelRef,
|
||||
resolveBaseUrl: ({ providerConfig }) =>
|
||||
normalizeDeepInfraBaseUrl(providerConfig?.baseUrl, DEEPINFRA_BASE_URL),
|
||||
resolveAllowPrivateNetwork: () => false,
|
||||
useConfiguredRequest: true,
|
||||
resolveCount: ({ req, mode }) => (mode === "edit" ? 1 : (req.count ?? 1)),
|
||||
buildGenerateRequest: ({ req, model, count }) => ({
|
||||
kind: "json",
|
||||
body: {
|
||||
model,
|
||||
prompt: req.prompt,
|
||||
n: count,
|
||||
size: normalizeOptionalString(req.size) ?? DEFAULT_DEEPINFRA_IMAGE_SIZE,
|
||||
response_format: "b64_json",
|
||||
},
|
||||
}),
|
||||
buildEditRequest: ({ req, inputImages, model, count }) => {
|
||||
const image = inputImages[0];
|
||||
if (!image) {
|
||||
throw new Error("DeepInfra image edit missing reference image.");
|
||||
}
|
||||
const auth = await resolveApiKeyForProvider({
|
||||
provider: "deepinfra",
|
||||
cfg: req.cfg,
|
||||
agentDir: req.agentDir,
|
||||
store: req.authStore,
|
||||
});
|
||||
if (!auth.apiKey) {
|
||||
throw new Error("DeepInfra API key missing");
|
||||
}
|
||||
|
||||
const providerConfig = resolveDeepInfraProviderConfig(req.cfg);
|
||||
const resolvedBaseUrl = normalizeDeepInfraBaseUrl(
|
||||
providerConfig?.baseUrl,
|
||||
DEEPINFRA_BASE_URL,
|
||||
const form = new FormData();
|
||||
form.set("model", model);
|
||||
form.set("prompt", req.prompt);
|
||||
form.set("n", String(count));
|
||||
form.set("size", normalizeOptionalString(req.size) ?? DEFAULT_DEEPINFRA_IMAGE_SIZE);
|
||||
form.set("response_format", "b64_json");
|
||||
const mimeType = normalizeOptionalString(image.mimeType) ?? "image/png";
|
||||
form.append(
|
||||
"image",
|
||||
new Blob([new Uint8Array(image.buffer)], { type: mimeType }),
|
||||
imageSourceUploadFileName({ image, index: 0 }),
|
||||
);
|
||||
const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } =
|
||||
resolveProviderHttpRequestConfig({
|
||||
baseUrl: resolvedBaseUrl,
|
||||
defaultBaseUrl: DEEPINFRA_BASE_URL,
|
||||
allowPrivateNetwork: false,
|
||||
request: sanitizeConfiguredModelProviderRequest(providerConfig?.request),
|
||||
defaultHeaders: {
|
||||
Authorization: `Bearer ${auth.apiKey}`,
|
||||
},
|
||||
provider: "deepinfra",
|
||||
capability: "image",
|
||||
transport: "http",
|
||||
});
|
||||
|
||||
const model = normalizeDeepInfraModelRef(req.model, DEFAULT_DEEPINFRA_IMAGE_MODEL);
|
||||
const count = isEdit ? 1 : (req.count ?? 1);
|
||||
const size = normalizeOptionalString(req.size) ?? DEFAULT_DEEPINFRA_IMAGE_SIZE;
|
||||
const endpoint = isEdit ? "images/edits" : "images/generations";
|
||||
const request = isEdit
|
||||
? (() => {
|
||||
const form = new FormData();
|
||||
form.set("model", model);
|
||||
form.set("prompt", req.prompt);
|
||||
form.set("n", String(count));
|
||||
form.set("size", size);
|
||||
form.set("response_format", "b64_json");
|
||||
const image = inputImages[0];
|
||||
if (!image) {
|
||||
throw new Error("DeepInfra image edit missing reference image.");
|
||||
}
|
||||
const mimeType = normalizeOptionalString(image.mimeType) ?? "image/png";
|
||||
form.append(
|
||||
"image",
|
||||
new Blob([new Uint8Array(image.buffer)], { type: mimeType }),
|
||||
imageSourceUploadFileName({ image, index: 0 }),
|
||||
);
|
||||
const multipartHeaders = new Headers(headers);
|
||||
multipartHeaders.delete("Content-Type");
|
||||
return postMultipartRequest({
|
||||
url: `${baseUrl}/${endpoint}`,
|
||||
headers: multipartHeaders,
|
||||
body: form,
|
||||
timeoutMs: req.timeoutMs,
|
||||
fetchFn: fetch,
|
||||
allowPrivateNetwork,
|
||||
dispatcherPolicy,
|
||||
});
|
||||
})()
|
||||
: postJsonRequest({
|
||||
url: `${baseUrl}/${endpoint}`,
|
||||
headers: new Headers({
|
||||
...Object.fromEntries(headers.entries()),
|
||||
"Content-Type": "application/json",
|
||||
}),
|
||||
body: {
|
||||
model,
|
||||
prompt: req.prompt,
|
||||
n: count,
|
||||
size,
|
||||
response_format: "b64_json",
|
||||
},
|
||||
timeoutMs: req.timeoutMs,
|
||||
fetchFn: fetch,
|
||||
allowPrivateNetwork,
|
||||
dispatcherPolicy,
|
||||
});
|
||||
|
||||
const { response, release } = await request;
|
||||
try {
|
||||
await assertOkOrThrowHttpError(
|
||||
response,
|
||||
isEdit ? "DeepInfra image edit failed" : "DeepInfra image generation failed",
|
||||
);
|
||||
const images = parseOpenAiCompatibleImageResponse(
|
||||
(await response.json()) as DeepInfraImageApiResponse,
|
||||
{ defaultMimeType: "image/jpeg", sniffMimeType: true },
|
||||
);
|
||||
if (images.length === 0) {
|
||||
throw new Error("DeepInfra image response did not include generated image data");
|
||||
}
|
||||
return { images, model };
|
||||
} finally {
|
||||
await release();
|
||||
}
|
||||
return { kind: "multipart", form };
|
||||
},
|
||||
};
|
||||
response: { defaultMimeType: "image/jpeg", sniffMimeType: true },
|
||||
tooManyInputImagesError: "DeepInfra image editing supports one reference image.",
|
||||
missingApiKeyError: "DeepInfra API key missing",
|
||||
emptyResponseError: "DeepInfra image response did not include generated image data",
|
||||
failureLabels: {
|
||||
generate: "DeepInfra image generation failed",
|
||||
edit: "DeepInfra image edit failed",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -4,19 +4,27 @@ import { buildLitellmImageGenerationProvider } from "./image-generation-provider
|
||||
const {
|
||||
resolveApiKeyForProviderMock,
|
||||
postJsonRequestMock,
|
||||
postMultipartRequestMock,
|
||||
assertOkOrThrowHttpErrorMock,
|
||||
createProviderOperationDeadlineMock,
|
||||
resolveProviderHttpRequestConfigMock,
|
||||
resolveProviderOperationTimeoutMsMock,
|
||||
sanitizeConfiguredModelProviderRequestMock,
|
||||
} = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "litellm-key" })),
|
||||
postJsonRequestMock: vi.fn(),
|
||||
postMultipartRequestMock: vi.fn(),
|
||||
assertOkOrThrowHttpErrorMock: vi.fn(async () => {}),
|
||||
createProviderOperationDeadlineMock: vi.fn((params: Record<string, unknown>) => params),
|
||||
resolveProviderHttpRequestConfigMock: vi.fn((params) => ({
|
||||
baseUrl: params.baseUrl ?? params.defaultBaseUrl,
|
||||
allowPrivateNetwork: Boolean(params.allowPrivateNetwork ?? params.request?.allowPrivateNetwork),
|
||||
headers: new Headers(params.defaultHeaders),
|
||||
dispatcherPolicy: undefined as unknown,
|
||||
})),
|
||||
resolveProviderOperationTimeoutMsMock: vi.fn(
|
||||
(params: Record<string, unknown>) => params.defaultTimeoutMs,
|
||||
),
|
||||
sanitizeConfiguredModelProviderRequestMock: vi.fn((request) => request),
|
||||
}));
|
||||
|
||||
@@ -26,8 +34,11 @@ vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/provider-http", () => ({
|
||||
assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock,
|
||||
createProviderOperationDeadline: createProviderOperationDeadlineMock,
|
||||
postJsonRequest: postJsonRequestMock,
|
||||
postMultipartRequest: postMultipartRequestMock,
|
||||
resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock,
|
||||
resolveProviderOperationTimeoutMs: resolveProviderOperationTimeoutMsMock,
|
||||
sanitizeConfiguredModelProviderRequest: sanitizeConfiguredModelProviderRequestMock,
|
||||
}));
|
||||
|
||||
|
||||
@@ -1,17 +1,10 @@
|
||||
import type { OpenClawConfig } from "openclaw/plugin-sdk/config-types";
|
||||
import type { ImageGenerationProvider } from "openclaw/plugin-sdk/image-generation";
|
||||
import {
|
||||
parseOpenAiCompatibleImageResponse,
|
||||
createOpenAiCompatibleImageGenerationProvider,
|
||||
type ImageGenerationProvider,
|
||||
type ImageGenerationSourceImage,
|
||||
toImageDataUrl,
|
||||
} from "openclaw/plugin-sdk/image-generation";
|
||||
import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth";
|
||||
import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime";
|
||||
import {
|
||||
assertOkOrThrowHttpError,
|
||||
postJsonRequest,
|
||||
resolveProviderHttpRequestConfig,
|
||||
sanitizeConfiguredModelProviderRequest,
|
||||
} from "openclaw/plugin-sdk/provider-http";
|
||||
import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
|
||||
import { LITELLM_BASE_URL } from "./onboard.js";
|
||||
|
||||
@@ -46,6 +39,10 @@ function resolveConfiguredLitellmBaseUrl(cfg: OpenClawConfig | undefined): strin
|
||||
return normalizeOptionalString(resolveLitellmProviderConfig(cfg)?.baseUrl) ?? LITELLM_BASE_URL;
|
||||
}
|
||||
|
||||
function imageToDataUrl(image: ImageGenerationSourceImage): string {
|
||||
return toImageDataUrl({ buffer: image.buffer, mimeType: image.mimeType });
|
||||
}
|
||||
|
||||
// LiteLLM's default proxy is loopback. Auto-enable private-network access only
|
||||
// for loopback-style hosts; LAN/custom private endpoints should use the
|
||||
// explicit models.providers.litellm.request.allowPrivateNetwork opt-in.
|
||||
@@ -85,24 +82,12 @@ function shouldAutoAllowPrivateLitellmEndpoint(baseUrl: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
type LitellmImageApiResponse = {
|
||||
data?: Array<{
|
||||
b64_json?: string;
|
||||
revised_prompt?: string;
|
||||
}>;
|
||||
};
|
||||
|
||||
export function buildLitellmImageGenerationProvider(): ImageGenerationProvider {
|
||||
return {
|
||||
return createOpenAiCompatibleImageGenerationProvider({
|
||||
id: "litellm",
|
||||
label: "LiteLLM",
|
||||
defaultModel: DEFAULT_LITELLM_IMAGE_MODEL,
|
||||
models: [DEFAULT_LITELLM_IMAGE_MODEL],
|
||||
isConfigured: ({ agentDir }) =>
|
||||
isProviderApiKeyConfigured({
|
||||
provider: "litellm",
|
||||
agentDir,
|
||||
}),
|
||||
capabilities: {
|
||||
generate: {
|
||||
maxCount: 4,
|
||||
@@ -122,84 +107,36 @@ export function buildLitellmImageGenerationProvider(): ImageGenerationProvider {
|
||||
sizes: [...LITELLM_SUPPORTED_SIZES],
|
||||
},
|
||||
},
|
||||
async generateImage(req) {
|
||||
const inputImages = req.inputImages ?? [];
|
||||
const isEdit = inputImages.length > 0;
|
||||
const auth = await resolveApiKeyForProvider({
|
||||
provider: "litellm",
|
||||
cfg: req.cfg,
|
||||
agentDir: req.agentDir,
|
||||
store: req.authStore,
|
||||
});
|
||||
if (!auth.apiKey) {
|
||||
throw new Error("LiteLLM API key missing");
|
||||
}
|
||||
const providerConfig = resolveLitellmProviderConfig(req.cfg);
|
||||
const resolvedBaseUrl = resolveConfiguredLitellmBaseUrl(req.cfg);
|
||||
const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } =
|
||||
resolveProviderHttpRequestConfig({
|
||||
baseUrl: resolvedBaseUrl,
|
||||
defaultBaseUrl: LITELLM_BASE_URL,
|
||||
allowPrivateNetwork: shouldAutoAllowPrivateLitellmEndpoint(resolvedBaseUrl)
|
||||
? true
|
||||
: undefined,
|
||||
request: sanitizeConfiguredModelProviderRequest(providerConfig?.request),
|
||||
defaultHeaders: {
|
||||
Authorization: `Bearer ${auth.apiKey}`,
|
||||
},
|
||||
provider: "litellm",
|
||||
capability: "image",
|
||||
transport: "http",
|
||||
});
|
||||
|
||||
const model = req.model || DEFAULT_LITELLM_IMAGE_MODEL;
|
||||
const count = req.count ?? 1;
|
||||
const size = req.size ?? DEFAULT_SIZE;
|
||||
|
||||
const jsonHeaders = new Headers(headers);
|
||||
jsonHeaders.set("Content-Type", "application/json");
|
||||
const endpoint = isEdit ? "images/edits" : "images/generations";
|
||||
const body = isEdit
|
||||
? {
|
||||
model,
|
||||
prompt: req.prompt,
|
||||
n: count,
|
||||
size,
|
||||
images: inputImages.map((image) => ({
|
||||
image_url: toImageDataUrl(image),
|
||||
})),
|
||||
}
|
||||
: {
|
||||
model,
|
||||
prompt: req.prompt,
|
||||
n: count,
|
||||
size,
|
||||
};
|
||||
const { response, release } = await postJsonRequest({
|
||||
url: `${baseUrl}/${endpoint}`,
|
||||
headers: jsonHeaders,
|
||||
body,
|
||||
timeoutMs: req.timeoutMs,
|
||||
fetchFn: fetch,
|
||||
allowPrivateNetwork,
|
||||
dispatcherPolicy,
|
||||
});
|
||||
try {
|
||||
await assertOkOrThrowHttpError(
|
||||
response,
|
||||
isEdit ? "LiteLLM image edit failed" : "LiteLLM image generation failed",
|
||||
);
|
||||
|
||||
const data = (await response.json()) as LitellmImageApiResponse;
|
||||
const images = parseOpenAiCompatibleImageResponse(data);
|
||||
|
||||
return {
|
||||
images,
|
||||
model,
|
||||
};
|
||||
} finally {
|
||||
await release();
|
||||
}
|
||||
defaultBaseUrl: LITELLM_BASE_URL,
|
||||
resolveBaseUrl: ({ req }) => resolveConfiguredLitellmBaseUrl(req.cfg),
|
||||
resolveAllowPrivateNetwork: ({ baseUrl }) =>
|
||||
shouldAutoAllowPrivateLitellmEndpoint(baseUrl) ? true : undefined,
|
||||
useConfiguredRequest: true,
|
||||
buildGenerateRequest: ({ req, model, count }) => ({
|
||||
kind: "json",
|
||||
body: {
|
||||
model,
|
||||
prompt: req.prompt,
|
||||
n: count,
|
||||
size: req.size ?? DEFAULT_SIZE,
|
||||
},
|
||||
}),
|
||||
buildEditRequest: ({ req, inputImages, model, count }) => ({
|
||||
kind: "json",
|
||||
body: {
|
||||
model,
|
||||
prompt: req.prompt,
|
||||
n: count,
|
||||
size: req.size ?? DEFAULT_SIZE,
|
||||
images: inputImages.map((image) => ({
|
||||
image_url: imageToDataUrl(image),
|
||||
})),
|
||||
},
|
||||
}),
|
||||
missingApiKeyError: "LiteLLM API key missing",
|
||||
failureLabels: {
|
||||
generate: "LiteLLM image generation failed",
|
||||
edit: "LiteLLM image edit failed",
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ import type {
|
||||
ImageGenerationOutputFormat,
|
||||
ImageGenerationProvider,
|
||||
ImageGenerationResult,
|
||||
ImageGenerationSourceImage,
|
||||
} from "openclaw/plugin-sdk/image-generation";
|
||||
import {
|
||||
parseOpenAiCompatibleImageResponse,
|
||||
toImageDataUrl,
|
||||
} from "openclaw/plugin-sdk/image-generation";
|
||||
import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
|
||||
import { resolveClosestSize } from "openclaw/plugin-sdk/media-generation-runtime";
|
||||
@@ -388,11 +391,6 @@ function inferImageUploadFileName(params: {
|
||||
return `image-${params.index + 1}.${ext}`;
|
||||
}
|
||||
|
||||
function toOpenAIDataUrl(image: ImageGenerationSourceImage): string {
|
||||
const mimeType = image.mimeType?.trim() || DEFAULT_OUTPUT_MIME;
|
||||
return `data:${mimeType};base64,${Buffer.from(image.buffer).toString("base64")}`;
|
||||
}
|
||||
|
||||
async function readResponseBodyText(response: Response): Promise<string> {
|
||||
if (!response.body) {
|
||||
const text = await response.text();
|
||||
@@ -643,7 +641,7 @@ async function generateOpenAICodexImage(params: {
|
||||
{ type: "input_text", text: req.prompt },
|
||||
...inputImages.map((image) => ({
|
||||
type: "input_image",
|
||||
image_url: toOpenAIDataUrl(image),
|
||||
image_url: toImageDataUrl({ buffer: image.buffer, mimeType: image.mimeType }),
|
||||
detail: "auto",
|
||||
})),
|
||||
];
|
||||
@@ -876,21 +874,13 @@ export function buildOpenAIImageGenerationProvider(): ImageGenerationProvider {
|
||||
|
||||
const data = (await response.json()) as OpenAIImageApiResponse;
|
||||
const output = resolveOutputMime(req.outputFormat);
|
||||
const images = (data.data ?? [])
|
||||
.map((entry, index) => {
|
||||
if (!entry.b64_json) {
|
||||
return null;
|
||||
}
|
||||
return Object.assign(
|
||||
{
|
||||
buffer: Buffer.from(entry.b64_json, `base64`),
|
||||
mimeType: output.mimeType,
|
||||
fileName: `image-${index + 1}.${output.extension}`,
|
||||
},
|
||||
entry.revised_prompt ? { revisedPrompt: entry.revised_prompt } : {},
|
||||
);
|
||||
})
|
||||
.filter((entry): entry is NonNullable<typeof entry> => entry !== null);
|
||||
const images = parseOpenAiCompatibleImageResponse(data, {
|
||||
defaultMimeType: output.mimeType,
|
||||
}).map((image, index) =>
|
||||
Object.assign(image, {
|
||||
fileName: `image-${index + 1}.${output.extension}`,
|
||||
}),
|
||||
);
|
||||
|
||||
return {
|
||||
images,
|
||||
|
||||
@@ -4,13 +4,16 @@ import { buildXaiImageGenerationProvider } from "./image-generation-provider.js"
|
||||
const {
|
||||
resolveApiKeyForProviderMock,
|
||||
postJsonRequestMock,
|
||||
postMultipartRequestMock,
|
||||
assertOkOrThrowHttpErrorMock,
|
||||
resolveProviderHttpRequestConfigMock,
|
||||
createProviderOperationDeadlineMock,
|
||||
resolveProviderOperationTimeoutMsMock,
|
||||
sanitizeConfiguredModelProviderRequestMock,
|
||||
} = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "xai-key" })),
|
||||
postJsonRequestMock: vi.fn(),
|
||||
postMultipartRequestMock: vi.fn(),
|
||||
assertOkOrThrowHttpErrorMock: vi.fn(async () => {}),
|
||||
resolveProviderHttpRequestConfigMock: vi.fn((params: Record<string, unknown>) => ({
|
||||
baseUrl: params.baseUrl ?? params.defaultBaseUrl ?? "https://api.x.ai/v1",
|
||||
@@ -25,6 +28,7 @@ const {
|
||||
resolveProviderOperationTimeoutMsMock: vi.fn(
|
||||
(params: Record<string, unknown>) => params.defaultTimeoutMs ?? 60000,
|
||||
),
|
||||
sanitizeConfiguredModelProviderRequestMock: vi.fn((request) => request),
|
||||
}));
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({
|
||||
@@ -35,8 +39,10 @@ vi.mock("openclaw/plugin-sdk/provider-http", () => ({
|
||||
assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock,
|
||||
createProviderOperationDeadline: createProviderOperationDeadlineMock,
|
||||
postJsonRequest: postJsonRequestMock,
|
||||
postMultipartRequest: postMultipartRequestMock,
|
||||
resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock,
|
||||
resolveProviderOperationTimeoutMs: resolveProviderOperationTimeoutMsMock,
|
||||
sanitizeConfiguredModelProviderRequest: sanitizeConfiguredModelProviderRequestMock,
|
||||
}));
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/text-runtime", () => ({
|
||||
@@ -54,6 +60,7 @@ describe("xai image generation provider", () => {
|
||||
resolveProviderHttpRequestConfigMock.mockClear();
|
||||
createProviderOperationDeadlineMock.mockClear();
|
||||
resolveProviderOperationTimeoutMsMock.mockClear();
|
||||
sanitizeConfiguredModelProviderRequestMock.mockClear();
|
||||
});
|
||||
|
||||
it("builds provider with correct models, default, and capabilities", () => {
|
||||
@@ -174,4 +181,50 @@ describe("xai image generation provider", () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("uses the plural xAI images payload for multiple edit inputs", async () => {
|
||||
postJsonRequestMock.mockResolvedValue({
|
||||
response: {
|
||||
json: async () => ({
|
||||
data: [
|
||||
{
|
||||
b64_json: Buffer.from("edited").toString("base64"),
|
||||
mime_type: "image/png",
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
release: vi.fn(async () => {}),
|
||||
});
|
||||
|
||||
const provider = buildXaiImageGenerationProvider();
|
||||
await provider.generateImage({
|
||||
provider: "xai",
|
||||
model: "grok-imagine-image",
|
||||
prompt: "Combine the references",
|
||||
inputImages: [
|
||||
{ buffer: Buffer.from("first"), mimeType: "image/png" },
|
||||
{ buffer: Buffer.from("second"), mimeType: "image/jpeg" },
|
||||
],
|
||||
cfg: {},
|
||||
} as any);
|
||||
|
||||
expect(postJsonRequestMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
url: expect.stringContaining("/images/edits"),
|
||||
body: expect.objectContaining({
|
||||
images: [
|
||||
{
|
||||
url: expect.stringContaining("data:image/png;base64,"),
|
||||
type: "image_url",
|
||||
},
|
||||
{
|
||||
url: expect.stringContaining("data:image/jpeg;base64,"),
|
||||
type: "image_url",
|
||||
},
|
||||
],
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
import type {
|
||||
ImageGenerationProvider,
|
||||
ImageGenerationRequest,
|
||||
ImageGenerationResult,
|
||||
ImageGenerationSourceImage,
|
||||
} from "openclaw/plugin-sdk/image-generation";
|
||||
import {
|
||||
parseOpenAiCompatibleImageResponse,
|
||||
createOpenAiCompatibleImageGenerationProvider,
|
||||
toImageDataUrl,
|
||||
} from "openclaw/plugin-sdk/image-generation";
|
||||
import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth";
|
||||
import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime";
|
||||
import {
|
||||
assertOkOrThrowHttpError,
|
||||
createProviderOperationDeadline,
|
||||
postJsonRequest,
|
||||
resolveProviderHttpRequestConfig,
|
||||
resolveProviderOperationTimeoutMs,
|
||||
} from "openclaw/plugin-sdk/provider-http";
|
||||
import {
|
||||
normalizeOptionalLowercaseString,
|
||||
normalizeOptionalString,
|
||||
@@ -26,16 +17,8 @@ const DEFAULT_TIMEOUT_MS = 60_000;
|
||||
|
||||
const XAI_SUPPORTED_ASPECT_RATIOS = ["1:1", "16:9", "9:16", "4:3", "3:4", "2:3", "3:2"] as const;
|
||||
|
||||
type XaiImageApiResponse = {
|
||||
data?: Array<{
|
||||
b64_json?: string;
|
||||
mime_type?: string;
|
||||
revised_prompt?: string;
|
||||
}>;
|
||||
};
|
||||
|
||||
function resolveImageForEdit(
|
||||
input: { url?: string; buffer?: Buffer; mimeType?: string } | undefined,
|
||||
input: (ImageGenerationSourceImage & { url?: string }) | undefined,
|
||||
): string {
|
||||
if (!input) {
|
||||
throw new Error("xAI image edit requires an input image.");
|
||||
@@ -50,44 +33,42 @@ function resolveImageForEdit(
|
||||
return toImageDataUrl({ buffer: input.buffer, mimeType: input.mimeType });
|
||||
}
|
||||
|
||||
function isEdit(req: ImageGenerationRequest): boolean {
|
||||
return (req.inputImages?.length ?? 0) > 0;
|
||||
}
|
||||
|
||||
function resolveXaiImageBaseUrl(req: ImageGenerationRequest): string {
|
||||
return normalizeOptionalString(req.cfg?.models?.providers?.xai?.baseUrl) ?? XAI_BASE_URL;
|
||||
}
|
||||
|
||||
function buildBody(req: ImageGenerationRequest, edit: boolean): Record<string, unknown> {
|
||||
const model = normalizeOptionalString(req.model) ?? XAI_DEFAULT_IMAGE_MODEL;
|
||||
const count = req.count ?? 1;
|
||||
function buildBody(params: {
|
||||
req: ImageGenerationRequest;
|
||||
inputImages: ImageGenerationSourceImage[];
|
||||
model: string;
|
||||
count: number;
|
||||
}): Record<string, unknown> {
|
||||
const body: Record<string, unknown> = {
|
||||
model,
|
||||
prompt: req.prompt,
|
||||
n: Math.min(count, 4),
|
||||
model: params.model,
|
||||
prompt: params.req.prompt,
|
||||
n: Math.min(params.count, 4),
|
||||
response_format: "b64_json" as const,
|
||||
};
|
||||
|
||||
const aspect = normalizeOptionalString(req.aspectRatio);
|
||||
const aspect = normalizeOptionalString(params.req.aspectRatio);
|
||||
if (aspect && (XAI_SUPPORTED_ASPECT_RATIOS as readonly string[]).includes(aspect)) {
|
||||
body.aspect_ratio = aspect;
|
||||
}
|
||||
|
||||
const resolution = normalizeOptionalLowercaseString(req.resolution);
|
||||
const resolution = normalizeOptionalLowercaseString(params.req.resolution);
|
||||
if (resolution) {
|
||||
body.resolution = resolution;
|
||||
}
|
||||
|
||||
if (edit) {
|
||||
const inputImages = req.inputImages ?? [];
|
||||
if (inputImages.length > 1) {
|
||||
body.images = inputImages.map((input) => ({
|
||||
if (params.inputImages.length > 0) {
|
||||
if (params.inputImages.length > 1) {
|
||||
body.images = params.inputImages.map((input) => ({
|
||||
url: resolveImageForEdit(input),
|
||||
type: "image_url",
|
||||
}));
|
||||
} else {
|
||||
body.image = {
|
||||
url: resolveImageForEdit(inputImages[0]),
|
||||
url: resolveImageForEdit(params.inputImages[0]),
|
||||
type: "image_url",
|
||||
};
|
||||
}
|
||||
@@ -97,16 +78,11 @@ function buildBody(req: ImageGenerationRequest, edit: boolean): Record<string, u
|
||||
}
|
||||
|
||||
export function buildXaiImageGenerationProvider(): ImageGenerationProvider {
|
||||
return {
|
||||
return createOpenAiCompatibleImageGenerationProvider({
|
||||
id: "xai",
|
||||
label: "xAI",
|
||||
defaultModel: XAI_DEFAULT_IMAGE_MODEL,
|
||||
models: [...XAI_IMAGE_MODELS],
|
||||
isConfigured: ({ agentDir }) =>
|
||||
isProviderApiKeyConfigured({
|
||||
provider: "xai",
|
||||
agentDir,
|
||||
}),
|
||||
capabilities: {
|
||||
generate: {
|
||||
maxCount: 4,
|
||||
@@ -127,72 +103,22 @@ export function buildXaiImageGenerationProvider(): ImageGenerationProvider {
|
||||
resolutions: ["1K", "2K"],
|
||||
},
|
||||
},
|
||||
async generateImage(req: ImageGenerationRequest): Promise<ImageGenerationResult> {
|
||||
const edit = isEdit(req);
|
||||
const auth = await resolveApiKeyForProvider({
|
||||
provider: "xai",
|
||||
cfg: req.cfg,
|
||||
agentDir: req.agentDir,
|
||||
store: req.authStore,
|
||||
});
|
||||
if (!auth.apiKey) {
|
||||
throw new Error("xAI API key missing");
|
||||
}
|
||||
|
||||
const fetchFn = fetch;
|
||||
const deadline = createProviderOperationDeadline({
|
||||
timeoutMs: req.timeoutMs,
|
||||
label: edit ? "xAI image edit" : "xAI image generation",
|
||||
});
|
||||
const {
|
||||
baseUrl: resolvedBaseUrl,
|
||||
allowPrivateNetwork,
|
||||
headers,
|
||||
dispatcherPolicy,
|
||||
} = resolveProviderHttpRequestConfig({
|
||||
baseUrl: resolveXaiImageBaseUrl(req),
|
||||
defaultBaseUrl: XAI_BASE_URL,
|
||||
allowPrivateNetwork: false,
|
||||
defaultHeaders: {
|
||||
Authorization: `Bearer ${auth.apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
provider: "xai",
|
||||
capability: "image",
|
||||
transport: "http",
|
||||
});
|
||||
|
||||
const body = buildBody(req, edit);
|
||||
const endpoint = edit ? "/images/edits" : "/images/generations";
|
||||
const { response, release } = await postJsonRequest({
|
||||
url: `${resolvedBaseUrl}${endpoint}`,
|
||||
headers,
|
||||
body,
|
||||
timeoutMs: resolveProviderOperationTimeoutMs({
|
||||
deadline,
|
||||
defaultTimeoutMs: DEFAULT_TIMEOUT_MS,
|
||||
}),
|
||||
fetchFn,
|
||||
allowPrivateNetwork,
|
||||
dispatcherPolicy,
|
||||
});
|
||||
|
||||
try {
|
||||
await assertOkOrThrowHttpError(
|
||||
response,
|
||||
edit ? "xAI image edit failed" : "xAI image generation failed",
|
||||
);
|
||||
|
||||
const payload = (await response.json()) as XaiImageApiResponse;
|
||||
const images = parseOpenAiCompatibleImageResponse(payload);
|
||||
|
||||
return {
|
||||
images,
|
||||
model: normalizeOptionalString(req.model) ?? XAI_DEFAULT_IMAGE_MODEL,
|
||||
};
|
||||
} finally {
|
||||
await release();
|
||||
}
|
||||
defaultBaseUrl: XAI_BASE_URL,
|
||||
resolveBaseUrl: ({ req }) => resolveXaiImageBaseUrl(req),
|
||||
resolveAllowPrivateNetwork: () => false,
|
||||
defaultTimeoutMs: DEFAULT_TIMEOUT_MS,
|
||||
buildGenerateRequest: ({ req, inputImages, model, count }) => ({
|
||||
kind: "json",
|
||||
body: buildBody({ req, inputImages, model, count }),
|
||||
}),
|
||||
buildEditRequest: ({ req, inputImages, model, count }) => ({
|
||||
kind: "json",
|
||||
body: buildBody({ req, inputImages, model, count }),
|
||||
}),
|
||||
missingApiKeyError: "xAI API key missing",
|
||||
failureLabels: {
|
||||
generate: "xAI image generation failed",
|
||||
edit: "xAI image edit failed",
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user