From 0ad2dbd3075437ab873494a8427f75d48418fe39 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Fri, 3 Apr 2026 00:32:37 +0900 Subject: [PATCH] fix(providers): route image generation through shared transport (#59729) * fix(providers): route image generation through shared transport * fix(providers): use normalized minimax image base url * fix(providers): fail closed on image private routes * fix(providers): bound shared HTTP fetches --- CHANGELOG.md | 2 + .../fal/image-generation-provider.test.ts | 31 ++-- extensions/fal/image-generation-provider.ts | 63 ++++---- .../minimax/image-generation-provider.test.ts | 147 ++++++++++++++++++ .../minimax/image-generation-provider.ts | 126 ++++++++------- .../openai/image-generation-provider.ts | 93 ++++++----- extensions/openai/index.test.ts | 31 ++++ src/media-understanding/shared.test.ts | 89 ++++++++++- src/media-understanding/shared.ts | 77 ++++++++- 9 files changed, 499 insertions(+), 160 deletions(-) create mode 100644 extensions/minimax/image-generation-provider.test.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index c800b126981..e4c812f739c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ Docs: https://docs.openclaw.ai - Providers/streaming headers: centralize default and attribution header merging across OpenAI websocket, embedded-runner, and proxy stream paths so provider-specific headers stay consistent and caller overrides only win where intended. (#59542) Thanks @vincentkoc. - Providers/Anthropic routing: centralize native-vs-proxy endpoint classification for direct Anthropic `service_tier` handling so spoofed or proxied hosts do not inherit native Anthropic defaults. (#59608) Thanks @vincentkoc. - Providers/transport policy: centralize request auth, proxy, TLS, and header shaping across shared HTTP, stream, and websocket paths, block insecure TLS/runtime transport overrides, and keep proxy-hop TLS separate from target mTLS settings. (#59682) Thanks @vincentkoc. +- Image generation/providers: route OpenAI, MiniMax, and fal image requests through the shared provider HTTP transport path so custom base URLs, guarded private-network routing, and provider request defaults stay aligned with the rest of provider HTTP. Thanks @vincentkoc. +- Image generation/providers: stop inferring private-network access from configured OpenAI, MiniMax, and fal image base URLs, and cap shared HTTP error-body reads so hostile or misconfigured endpoints fail closed without relaxing SSRF policy or buffering unbounded error payloads. Thanks @vincentkoc. - Browser/host inspection: keep static Chrome inspection helpers out of the activated browser runtime so `openclaw doctor browser` and related checks do not eagerly load the bundled browser plugin. (#59471) Thanks @vincentkoc. - Gateway/exec loopback: restore legacy-role fallback for empty paired-device token maps and allow silent local role upgrades so local exec and node clients stop failing with pairing-required errors after `2026.3.31`. (#59092) Thanks @openperf. - Agents/output sanitization: strip namespaced `antml:thinking` blocks from user-visible text so Anthropic-style internal monologue tags do not leak into replies. (#59550) Thanks @obviyus. diff --git a/extensions/fal/image-generation-provider.test.ts b/extensions/fal/image-generation-provider.test.ts index b9d22fd8ac2..a2e510ed49a 100644 --- a/extensions/fal/image-generation-provider.test.ts +++ b/extensions/fal/image-generation-provider.test.ts @@ -11,23 +11,14 @@ import { } from "./image-generation-provider.js"; function expectFalJsonPost(params: { call: number; url: string; body: Record }) { - expect(fetchWithSsrFGuardMock).toHaveBeenNthCalledWith( - params.call, - expect.objectContaining({ - url: params.url, - init: expect.objectContaining({ - method: "POST", - headers: expect.objectContaining({ - Authorization: "Key fal-test-key", - "Content-Type": "application/json", - }), - }), - auditContext: "fal-image-generate", - }), - ); - const request = fetchWithSsrFGuardMock.mock.calls[params.call - 1]?.[0]; expect(request).toBeTruthy(); + expect(request?.url).toBe(params.url); + expect(request?.auditContext).toBe("fal-image-generate"); + expect(request?.init?.method).toBe("POST"); + const headers = new Headers(request?.init?.headers); + expect(headers.get("authorization")).toBe("Key fal-test-key"); + expect(headers.get("content-type")).toBe("application/json"); expect(JSON.parse(String(request?.init?.body))).toEqual(params.body); } @@ -361,17 +352,13 @@ describe("fal image-generation provider", () => { ); }); - it("allows trusted private relay hosts derived from configured baseUrl", async () => { + it("does not auto-whitelist trusted private relay hosts from a configured baseUrl", async () => { vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({ apiKey: "fal-test-key", source: "env", mode: "api-key", }); _setFalFetchGuardForTesting(fetchWithSsrFGuardMock); - const relayPolicy = { - allowPrivateNetwork: true, - hostnameAllowlist: ["relay.internal", "*.relay.internal"], - }; fetchWithSsrFGuardMock .mockResolvedValueOnce({ response: new Response( @@ -415,7 +402,7 @@ describe("fal image-generation provider", () => { expect.objectContaining({ url: "http://relay.internal:8080/fal-ai/flux/dev", auditContext: "fal-image-generate", - policy: relayPolicy, + policy: undefined, }), ); expect(fetchWithSsrFGuardMock).toHaveBeenNthCalledWith( @@ -423,7 +410,7 @@ describe("fal image-generation provider", () => { expect.objectContaining({ url: "http://media.relay.internal/files/generated.png", auditContext: "fal-image-download", - policy: relayPolicy, + policy: undefined, }), ); }); diff --git a/extensions/fal/image-generation-provider.ts b/extensions/fal/image-generation-provider.ts index 4772adfb863..4b957fbf2af 100644 --- a/extensions/fal/image-generation-provider.ts +++ b/extensions/fal/image-generation-provider.ts @@ -3,6 +3,10 @@ import type { ImageGenerationProvider, } from "openclaw/plugin-sdk/image-generation"; import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; +import { + assertOkOrThrowHttpError, + resolveProviderHttpRequestConfig, +} from "openclaw/plugin-sdk/provider-http"; import { buildHostnameAllowlistPolicyFromSuffixAllowlist, fetchWithSsrFGuard, @@ -81,40 +85,32 @@ function matchesTrustedHostSuffix(hostname: string, trustedSuffix: string): bool return normalizedHost === normalizedSuffix || normalizedHost.endsWith(`.${normalizedSuffix}`); } -function resolveFalNetworkPolicy( - cfg: Parameters[0]["cfg"], -): FalNetworkPolicy { - const baseUrl = resolveFalBaseUrl(cfg); - const explicitBaseUrl = cfg?.models?.providers?.fal?.baseUrl?.trim(); +function resolveFalNetworkPolicy(params: { + baseUrl: string; + allowPrivateNetwork: boolean; +}): FalNetworkPolicy { let parsedBaseUrl: URL; try { - parsedBaseUrl = new URL(baseUrl); + parsedBaseUrl = new URL(params.baseUrl); } catch { return {}; } const hostSuffix = parsedBaseUrl.hostname.trim().toLowerCase(); - if (!hostSuffix) { + if (!hostSuffix || !params.allowPrivateNetwork) { return {}; } const hostPolicy = buildHostnameAllowlistPolicyFromSuffixAllowlist([hostSuffix]); - const privateNetworkPolicy = explicitBaseUrl - ? ssrfPolicyFromAllowPrivateNetwork(true) - : undefined; + const privateNetworkPolicy = ssrfPolicyFromAllowPrivateNetwork(true); const trustedHostPolicy = mergeSsrFPolicies(hostPolicy, privateNetworkPolicy); return { apiPolicy: trustedHostPolicy, - trustedDownloadHostSuffix: explicitBaseUrl ? hostSuffix : undefined, - trustedDownloadPolicy: explicitBaseUrl ? trustedHostPolicy : undefined, + trustedDownloadHostSuffix: hostSuffix, + trustedDownloadPolicy: trustedHostPolicy, }; } -function resolveFalBaseUrl(cfg: Parameters[0]["cfg"]): string { - const direct = cfg?.models?.providers?.fal?.baseUrl?.trim(); - return (direct || DEFAULT_FAL_BASE_URL).replace(/\/+$/u, ""); -} - function ensureFalModelPath(model: string | undefined, hasInputImages: boolean): string { const trimmed = model?.trim() || DEFAULT_FAL_IMAGE_MODEL; if (!hasInputImages) { @@ -341,7 +337,21 @@ export function buildFalImageGenerationProvider(): ImageGenerationProvider { hasInputImages, }); const model = ensureFalModelPath(req.model, hasInputImages); - const networkPolicy = resolveFalNetworkPolicy(req.cfg); + const explicitBaseUrl = req.cfg?.models?.providers?.fal?.baseUrl?.trim(); + const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } = + resolveProviderHttpRequestConfig({ + baseUrl: explicitBaseUrl, + defaultBaseUrl: DEFAULT_FAL_BASE_URL, + allowPrivateNetwork: false, + defaultHeaders: { + Authorization: `Key ${auth.apiKey}`, + "Content-Type": "application/json", + }, + provider: "fal", + capability: "image", + transport: "http", + }); + const networkPolicy = resolveFalNetworkPolicy({ baseUrl, allowPrivateNetwork }); const requestBody: Record = { prompt: req.prompt, num_images: req.count ?? 1, @@ -358,27 +368,20 @@ export function buildFalImageGenerationProvider(): ImageGenerationProvider { } requestBody.image_url = toDataUri(input.buffer, input.mimeType); } - const { response, release } = await falFetchGuard({ - url: `${resolveFalBaseUrl(req.cfg)}/${model}`, + url: `${baseUrl}/${model}`, init: { method: "POST", - headers: { - Authorization: `Key ${auth.apiKey}`, - "Content-Type": "application/json", - }, + headers, body: JSON.stringify(requestBody), }, + timeoutMs: req.timeoutMs, policy: networkPolicy.apiPolicy, + dispatcherPolicy, auditContext: "fal-image-generate", }); try { - if (!response.ok) { - const text = await response.text().catch(() => ""); - throw new Error( - `fal image generation failed (${response.status}): ${text || response.statusText}`, - ); - } + await assertOkOrThrowHttpError(response, "fal image generation failed"); const payload = (await response.json()) as FalImageGenerationResponse; const images: GeneratedImageAsset[] = []; diff --git a/extensions/minimax/image-generation-provider.test.ts b/extensions/minimax/image-generation-provider.test.ts new file mode 100644 index 00000000000..bb1e5bc1575 --- /dev/null +++ b/extensions/minimax/image-generation-provider.test.ts @@ -0,0 +1,147 @@ +import * as providerAuth from "openclaw/plugin-sdk/provider-auth-runtime"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { buildMinimaxImageGenerationProvider } from "./image-generation-provider.js"; + +describe("minimax image-generation provider", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it("generates PNG buffers through the shared provider HTTP path", async () => { + vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({ + apiKey: "minimax-test-key", + source: "env", + mode: "api-key", + }); + const fetchMock = vi.fn().mockResolvedValue( + new Response( + JSON.stringify({ + data: { + image_base64: [Buffer.from("png-data").toString("base64")], + }, + base_resp: { status_code: 0 }, + }), + { + status: 200, + headers: { "Content-Type": "application/json" }, + }, + ), + ); + vi.stubGlobal("fetch", fetchMock); + + const provider = buildMinimaxImageGenerationProvider(); + const result = await provider.generateImage({ + provider: "minimax", + model: "image-01", + prompt: "draw a cat", + cfg: {}, + }); + + expect(fetchMock).toHaveBeenCalledWith( + "https://api.minimax.io/v1/image_generation", + expect.objectContaining({ + method: "POST", + body: JSON.stringify({ + model: "image-01", + prompt: "draw a cat", + response_format: "base64", + n: 1, + }), + }), + ); + const [, init] = fetchMock.mock.calls[0] as [string, RequestInit]; + const headers = new Headers(init.headers); + expect(headers.get("authorization")).toBe("Bearer minimax-test-key"); + expect(headers.get("content-type")).toBe("application/json"); + expect(result).toEqual({ + images: [ + { + buffer: Buffer.from("png-data"), + mimeType: "image/png", + fileName: "image-1.png", + }, + ], + model: "image-01", + }); + }); + + it("uses the configured provider base URL origin", async () => { + vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({ + apiKey: "minimax-test-key", + source: "env", + mode: "api-key", + }); + const fetchMock = vi.fn().mockResolvedValue( + new Response( + JSON.stringify({ + data: { + image_base64: [Buffer.from("png-data").toString("base64")], + }, + base_resp: { status_code: 0 }, + }), + { + status: 200, + headers: { "Content-Type": "application/json" }, + }, + ), + ); + vi.stubGlobal("fetch", fetchMock); + + const provider = buildMinimaxImageGenerationProvider(); + await provider.generateImage({ + provider: "minimax", + model: "image-01", + prompt: "draw a cat", + cfg: { + models: { + providers: { + minimax: { + baseUrl: "https://api.minimax.io/anthropic", + models: [], + }, + }, + }, + }, + }); + + expect(fetchMock).toHaveBeenCalledWith( + "https://api.minimax.io/v1/image_generation", + expect.any(Object), + ); + }); + + it("does not allow private-network routing just because a custom base URL is configured", async () => { + vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({ + apiKey: "minimax-test-key", + source: "env", + mode: "api-key", + }); + const fetchMock = vi.fn(); + vi.stubGlobal("fetch", fetchMock); + + const provider = buildMinimaxImageGenerationProvider(); + await expect( + provider.generateImage({ + provider: "minimax", + model: "image-01", + prompt: "draw a cat", + cfg: { + models: { + providers: { + minimax: { + baseUrl: "http://127.0.0.1:8080/anthropic", + models: [], + }, + }, + }, + }, + }), + ).rejects.toThrow("Blocked hostname or private/internal/special-use IP address"); + + expect(fetchMock).not.toHaveBeenCalled(); + }); +}); diff --git a/extensions/minimax/image-generation-provider.ts b/extensions/minimax/image-generation-provider.ts index d42d3088863..416cc3d31dd 100644 --- a/extensions/minimax/image-generation-provider.ts +++ b/extensions/minimax/image-generation-provider.ts @@ -1,5 +1,10 @@ import type { ImageGenerationProvider } from "openclaw/plugin-sdk/image-generation"; import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; +import { + assertOkOrThrowHttpError, + postJsonRequest, + resolveProviderHttpRequestConfig, +} from "openclaw/plugin-sdk/provider-http"; const DEFAULT_MINIMAX_IMAGE_BASE_URL = "https://api.minimax.io"; const DEFAULT_MODEL = "image-01"; @@ -83,6 +88,23 @@ function buildMinimaxImageProvider(providerId: string): ImageGenerationProvider } const baseUrl = resolveMinimaxImageBaseUrl(req.cfg, providerId); + const { + baseUrl: resolvedBaseUrl, + allowPrivateNetwork, + headers, + dispatcherPolicy, + } = resolveProviderHttpRequestConfig({ + baseUrl, + defaultBaseUrl: DEFAULT_MINIMAX_IMAGE_BASE_URL, + allowPrivateNetwork: false, + defaultHeaders: { + Authorization: `Bearer ${auth.apiKey}`, + "Content-Type": "application/json", + }, + provider: providerId, + capability: "image", + transport: "http", + }); const body: Record = { model: req.model || DEFAULT_MODEL, @@ -102,67 +124,55 @@ function buildMinimaxImageProvider(providerId: string): ImageGenerationProvider const dataUrl = `data:${mime};base64,${ref.buffer.toString("base64")}`; body.subject_reference = [{ type: "character", image_file: dataUrl }]; } - - const controller = new AbortController(); - const timeoutMs = req.timeoutMs; - const timeout = - typeof timeoutMs === "number" && Number.isFinite(timeoutMs) && timeoutMs > 0 - ? setTimeout(() => controller.abort(), timeoutMs) - : undefined; - - const response = await fetch(`${baseUrl}/v1/image_generation`, { - method: "POST", - headers: { - Authorization: `Bearer ${auth.apiKey}`, - "Content-Type": "application/json", - }, - body: JSON.stringify(body), - signal: controller.signal, - }).finally(() => { - clearTimeout(timeout); + const { response, release } = await postJsonRequest({ + url: `${resolvedBaseUrl}/v1/image_generation`, + headers, + body, + timeoutMs: req.timeoutMs, + fetchFn: fetch, + allowPrivateNetwork, + dispatcherPolicy, }); + try { + await assertOkOrThrowHttpError(response, "MiniMax image generation failed"); - if (!response.ok) { - const text = await response.text().catch(() => ""); - throw new Error( - `MiniMax image generation failed (${response.status}): ${text || response.statusText}`, - ); + const data = (await response.json()) as MinimaxImageApiResponse; + + const baseResp = data.base_resp; + if (baseResp && typeof baseResp.status_code === "number" && baseResp.status_code !== 0) { + const msg = baseResp.status_msg ?? ""; + throw new Error(`MiniMax image generation API error (${baseResp.status_code}): ${msg}`); + } + + const base64Images = data.data?.image_base64 ?? []; + const failedCount = data.metadata?.failed_count ?? 0; + + if (base64Images.length === 0) { + const reason = + failedCount > 0 ? `${failedCount} image(s) failed to generate` : "no images returned"; + throw new Error(`MiniMax image generation returned no images: ${reason}`); + } + + const images = base64Images + .map((b64, index) => { + if (!b64) { + return null; + } + return { + buffer: Buffer.from(b64, "base64"), + mimeType: DEFAULT_OUTPUT_MIME, + fileName: `image-${index + 1}.png`, + }; + }) + .filter((entry): entry is NonNullable => entry !== null); + + return { + images, + model: req.model || DEFAULT_MODEL, + }; + } finally { + await release(); } - - const data = (await response.json()) as MinimaxImageApiResponse; - - const baseResp = data.base_resp; - if (baseResp && typeof baseResp.status_code === "number" && baseResp.status_code !== 0) { - const msg = baseResp.status_msg ?? ""; - throw new Error(`MiniMax image generation API error (${baseResp.status_code}): ${msg}`); - } - - const base64Images = data.data?.image_base64 ?? []; - const failedCount = data.metadata?.failed_count ?? 0; - - if (base64Images.length === 0) { - const reason = - failedCount > 0 ? `${failedCount} image(s) failed to generate` : "no images returned"; - throw new Error(`MiniMax image generation returned no images: ${reason}`); - } - - const images = base64Images - .map((b64, index) => { - if (!b64) { - return null; - } - return { - buffer: Buffer.from(b64, "base64"), - mimeType: DEFAULT_OUTPUT_MIME, - fileName: `image-${index + 1}.png`, - }; - }) - .filter((entry): entry is NonNullable => entry !== null); - - return { - images, - model: req.model || DEFAULT_MODEL, - }; }, }; } diff --git a/extensions/openai/image-generation-provider.ts b/extensions/openai/image-generation-provider.ts index c7338a1be4f..c2acb061776 100644 --- a/extensions/openai/image-generation-provider.ts +++ b/extensions/openai/image-generation-provider.ts @@ -1,5 +1,10 @@ import type { ImageGenerationProvider } from "openclaw/plugin-sdk/image-generation"; import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; +import { + assertOkOrThrowHttpError, + postJsonRequest, + resolveProviderHttpRequestConfig, +} from "openclaw/plugin-sdk/provider-http"; import { OPENAI_DEFAULT_IMAGE_MODEL as DEFAULT_OPENAI_IMAGE_MODEL } from "./default-models.js"; const DEFAULT_OPENAI_IMAGE_BASE_URL = "https://api.openai.com/v1"; @@ -57,57 +62,59 @@ export function buildOpenAIImageGenerationProvider(): ImageGenerationProvider { if (!auth.apiKey) { throw new Error("OpenAI API key missing"); } + const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } = + resolveProviderHttpRequestConfig({ + baseUrl: resolveOpenAIBaseUrl(req.cfg), + defaultBaseUrl: DEFAULT_OPENAI_IMAGE_BASE_URL, + allowPrivateNetwork: false, + defaultHeaders: { + Authorization: `Bearer ${auth.apiKey}`, + "Content-Type": "application/json", + }, + provider: "openai", + capability: "image", + transport: "http", + }); - const controller = new AbortController(); - const timeoutMs = req.timeoutMs; - const timeout = - typeof timeoutMs === "number" && Number.isFinite(timeoutMs) && timeoutMs > 0 - ? setTimeout(() => controller.abort(), timeoutMs) - : undefined; - - const response = await fetch(`${resolveOpenAIBaseUrl(req.cfg)}/images/generations`, { - method: "POST", - headers: { - Authorization: `Bearer ${auth.apiKey}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ + const { response, release } = await postJsonRequest({ + url: `${baseUrl}/images/generations`, + headers, + body: { model: req.model || DEFAULT_OPENAI_IMAGE_MODEL, prompt: req.prompt, n: req.count ?? 1, size: req.size ?? DEFAULT_SIZE, - }), - signal: controller.signal, - }).finally(() => { - clearTimeout(timeout); + }, + timeoutMs: req.timeoutMs, + fetchFn: fetch, + allowPrivateNetwork, + dispatcherPolicy, }); + try { + await assertOkOrThrowHttpError(response, "OpenAI image generation failed"); - if (!response.ok) { - const text = await response.text().catch(() => ""); - throw new Error( - `OpenAI image generation failed (${response.status}): ${text || response.statusText}`, - ); + const data = (await response.json()) as OpenAIImageApiResponse; + const images = (data.data ?? []) + .map((entry, index) => { + if (!entry.b64_json) { + return null; + } + return { + buffer: Buffer.from(entry.b64_json, "base64"), + mimeType: DEFAULT_OUTPUT_MIME, + fileName: `image-${index + 1}.png`, + ...(entry.revised_prompt ? { revisedPrompt: entry.revised_prompt } : {}), + }; + }) + .filter((entry): entry is NonNullable => entry !== null); + + return { + images, + model: req.model || DEFAULT_OPENAI_IMAGE_MODEL, + }; + } finally { + await release(); } - - const data = (await response.json()) as OpenAIImageApiResponse; - const images = (data.data ?? []) - .map((entry, index) => { - if (!entry.b64_json) { - return null; - } - return { - buffer: Buffer.from(entry.b64_json, "base64"), - mimeType: DEFAULT_OUTPUT_MIME, - fileName: `image-${index + 1}.png`, - ...(entry.revised_prompt ? { revisedPrompt: entry.revised_prompt } : {}), - }; - }) - .filter((entry): entry is NonNullable => entry !== null); - - return { - images, - model: req.model || DEFAULT_OPENAI_IMAGE_MODEL, - }; }, }; } diff --git a/extensions/openai/index.test.ts b/extensions/openai/index.test.ts index 9c788aab5e2..8534f728447 100644 --- a/extensions/openai/index.test.ts +++ b/extensions/openai/index.test.ts @@ -117,6 +117,37 @@ describe("openai plugin", () => { ).rejects.toThrow("does not support reference-image edits"); }); + it("does not allow private-network routing just because a custom base URL is configured", async () => { + vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({ + apiKey: "sk-test", + source: "env", + mode: "api-key", + }); + const fetchMock = vi.fn(); + vi.stubGlobal("fetch", fetchMock); + + const provider = buildOpenAIImageGenerationProvider(); + await expect( + provider.generateImage({ + provider: "openai", + model: "gpt-image-1", + prompt: "draw a cat", + cfg: { + models: { + providers: { + openai: { + baseUrl: "http://127.0.0.1:8080/v1", + models: [], + }, + }, + }, + } satisfies OpenClawConfig, + }), + ).rejects.toThrow("Blocked hostname or private/internal/special-use IP address"); + + expect(fetchMock).not.toHaveBeenCalled(); + }); + it("bootstraps the env proxy dispatcher before refreshing codex oauth credentials", async () => { const refreshed = { access: "next-access", diff --git a/src/media-understanding/shared.test.ts b/src/media-understanding/shared.test.ts index c118024d123..d72d69133df 100644 --- a/src/media-understanding/shared.test.ts +++ b/src/media-understanding/shared.test.ts @@ -1,5 +1,26 @@ -import { describe, expect, it } from "vitest"; -import { resolveProviderHttpRequestConfig } from "./shared.js"; +import { afterEach, describe, expect, it, vi } from "vitest"; + +const { fetchWithSsrFGuardMock } = vi.hoisted(() => ({ + fetchWithSsrFGuardMock: vi.fn(), +})); + +vi.mock("../infra/net/fetch-guard.js", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + fetchWithSsrFGuard: fetchWithSsrFGuardMock, + }; +}); + +import { + fetchWithTimeoutGuarded, + readErrorResponse, + resolveProviderHttpRequestConfig, +} from "./shared.js"; + +afterEach(() => { + vi.clearAllMocks(); +}); describe("resolveProviderHttpRequestConfig", () => { it("preserves explicit caller headers but protects attribution headers", () => { @@ -108,3 +129,67 @@ describe("resolveProviderHttpRequestConfig", () => { ).toThrow("Missing baseUrl"); }); }); + +describe("readErrorResponse", () => { + it("caps streamed error bodies instead of buffering the whole response", async () => { + const encoder = new TextEncoder(); + let reads = 0; + const response = new Response( + new ReadableStream({ + pull(controller) { + reads += 1; + controller.enqueue(encoder.encode("a".repeat(2048))); + if (reads >= 10) { + controller.close(); + } + }, + }), + { + status: 500, + }, + ); + + const detail = await readErrorResponse(response); + + expect(detail).toBe(`${"a".repeat(300)}…`); + expect(reads).toBe(2); + }); +}); + +describe("fetchWithTimeoutGuarded", () => { + it("applies a default timeout when callers omit one", async () => { + fetchWithSsrFGuardMock.mockResolvedValue({ + response: new Response(null, { status: 200 }), + finalUrl: "https://example.com", + release: async () => {}, + }); + + await fetchWithTimeoutGuarded("https://example.com", {}, undefined, fetch); + + expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://example.com", + timeoutMs: 60_000, + }), + ); + }); + + it("sanitizes auditContext before passing it to the SSRF guard", async () => { + fetchWithSsrFGuardMock.mockResolvedValue({ + response: new Response(null, { status: 200 }), + finalUrl: "https://example.com", + release: async () => {}, + }); + + await fetchWithTimeoutGuarded("https://example.com", {}, 5000, fetch, { + auditContext: "provider-http\r\nfal\timage\u001btest", + }); + + expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith( + expect.objectContaining({ + auditContext: "provider-http fal image test", + timeoutMs: 5000, + }), + ); + }); +}); diff --git a/src/media-understanding/shared.ts b/src/media-understanding/shared.ts index ad1bb689722..2de103cb8ec 100644 --- a/src/media-understanding/shared.ts +++ b/src/media-understanding/shared.ts @@ -16,6 +16,27 @@ export { fetchWithTimeout } from "../utils/fetch-timeout.js"; export { normalizeBaseUrl } from "../agents/provider-request-config.js"; const MAX_ERROR_CHARS = 300; +const MAX_ERROR_RESPONSE_BYTES = 4096; +const DEFAULT_GUARDED_HTTP_TIMEOUT_MS = 60_000; +const MAX_AUDIT_CONTEXT_CHARS = 80; + +function resolveGuardedHttpTimeoutMs(timeoutMs: number | undefined): number { + if (typeof timeoutMs !== "number" || !Number.isFinite(timeoutMs) || timeoutMs <= 0) { + return DEFAULT_GUARDED_HTTP_TIMEOUT_MS; + } + return timeoutMs; +} + +function sanitizeAuditContext(auditContext: string | undefined): string | undefined { + const cleaned = auditContext + ?.replace(/\p{Cc}+/gu, " ") + .replace(/\s+/g, " ") + .trim(); + if (!cleaned) { + return undefined; + } + return cleaned.slice(0, MAX_AUDIT_CONTEXT_CHARS); +} export function resolveProviderHttpRequestConfig(params: { baseUrl?: string; @@ -67,24 +88,26 @@ export function resolveProviderHttpRequestConfig(params: { export async function fetchWithTimeoutGuarded( url: string, init: RequestInit, - timeoutMs: number, + timeoutMs: number | undefined, fetchFn: typeof fetch, options?: { ssrfPolicy?: SsrFPolicy; lookupFn?: LookupFn; pinDns?: boolean; dispatcherPolicy?: PinnedDispatcherPolicy; + auditContext?: string; }, ): Promise { return await fetchWithSsrFGuard({ url, fetchImpl: fetchFn, init, - timeoutMs, + timeoutMs: resolveGuardedHttpTimeoutMs(timeoutMs), policy: options?.ssrfPolicy, lookupFn: options?.lookupFn, pinDns: options?.pinDns, dispatcherPolicy: options?.dispatcherPolicy, + auditContext: sanitizeAuditContext(options?.auditContext), }); } @@ -92,10 +115,11 @@ export async function postTranscriptionRequest(params: { url: string; headers: Headers; body: BodyInit; - timeoutMs: number; + timeoutMs?: number; fetchFn: typeof fetch; allowPrivateNetwork?: boolean; dispatcherPolicy?: PinnedDispatcherPolicy; + auditContext?: string; }) { return fetchWithTimeoutGuarded( params.url, @@ -110,6 +134,7 @@ export async function postTranscriptionRequest(params: { ? { ...(params.allowPrivateNetwork ? { ssrfPolicy: { allowPrivateNetwork: true } } : {}), ...(params.dispatcherPolicy ? { dispatcherPolicy: params.dispatcherPolicy } : {}), + ...(params.auditContext ? { auditContext: params.auditContext } : {}), } : undefined, ); @@ -119,10 +144,11 @@ export async function postJsonRequest(params: { url: string; headers: Headers; body: unknown; - timeoutMs: number; + timeoutMs?: number; fetchFn: typeof fetch; allowPrivateNetwork?: boolean; dispatcherPolicy?: PinnedDispatcherPolicy; + auditContext?: string; }) { return fetchWithTimeoutGuarded( params.url, @@ -137,14 +163,49 @@ export async function postJsonRequest(params: { ? { ...(params.allowPrivateNetwork ? { ssrfPolicy: { allowPrivateNetwork: true } } : {}), ...(params.dispatcherPolicy ? { dispatcherPolicy: params.dispatcherPolicy } : {}), + ...(params.auditContext ? { auditContext: params.auditContext } : {}), } : undefined, ); } export async function readErrorResponse(res: Response): Promise { + let reader: ReadableStreamDefaultReader | undefined; try { - const text = await res.text(); + if (!res.body) { + return undefined; + } + reader = res.body.getReader(); + const chunks: Uint8Array[] = []; + let total = 0; + let sawBytes = false; + while (total < MAX_ERROR_RESPONSE_BYTES) { + const { done, value } = await reader.read(); + if (done) { + break; + } + if (!value || value.length === 0) { + continue; + } + sawBytes = true; + const remaining = MAX_ERROR_RESPONSE_BYTES - total; + const chunk = value.length <= remaining ? value : value.subarray(0, remaining); + chunks.push(chunk); + total += chunk.length; + if (chunk.length < value.length) { + break; + } + } + if (!sawBytes) { + return undefined; + } + const bytes = new Uint8Array(total); + let offset = 0; + for (const chunk of chunks) { + bytes.set(chunk, offset); + offset += chunk.length; + } + const text = new TextDecoder().decode(bytes); const collapsed = text.replace(/\s+/g, " ").trim(); if (!collapsed) { return undefined; @@ -155,6 +216,12 @@ export async function readErrorResponse(res: Response): Promise