fix(providers): route image generation through shared transport (#59729)

* fix(providers): route image generation through shared transport

* fix(providers): use normalized minimax image base url

* fix(providers): fail closed on image private routes

* fix(providers): bound shared HTTP fetches
This commit is contained in:
Vincent Koc
2026-04-03 00:32:37 +09:00
committed by GitHub
parent d2ce3e9acc
commit 0ad2dbd307
9 changed files with 499 additions and 160 deletions

View File

@@ -46,6 +46,8 @@ Docs: https://docs.openclaw.ai
- Providers/streaming headers: centralize default and attribution header merging across OpenAI websocket, embedded-runner, and proxy stream paths so provider-specific headers stay consistent and caller overrides only win where intended. (#59542) Thanks @vincentkoc.
- Providers/Anthropic routing: centralize native-vs-proxy endpoint classification for direct Anthropic `service_tier` handling so spoofed or proxied hosts do not inherit native Anthropic defaults. (#59608) Thanks @vincentkoc.
- Providers/transport policy: centralize request auth, proxy, TLS, and header shaping across shared HTTP, stream, and websocket paths, block insecure TLS/runtime transport overrides, and keep proxy-hop TLS separate from target mTLS settings. (#59682) Thanks @vincentkoc.
- Image generation/providers: route OpenAI, MiniMax, and fal image requests through the shared provider HTTP transport path so custom base URLs, guarded private-network routing, and provider request defaults stay aligned with the rest of provider HTTP. Thanks @vincentkoc.
- Image generation/providers: stop inferring private-network access from configured OpenAI, MiniMax, and fal image base URLs, and cap shared HTTP error-body reads so hostile or misconfigured endpoints fail closed without relaxing SSRF policy or buffering unbounded error payloads. Thanks @vincentkoc.
- Browser/host inspection: keep static Chrome inspection helpers out of the activated browser runtime so `openclaw doctor browser` and related checks do not eagerly load the bundled browser plugin. (#59471) Thanks @vincentkoc.
- Gateway/exec loopback: restore legacy-role fallback for empty paired-device token maps and allow silent local role upgrades so local exec and node clients stop failing with pairing-required errors after `2026.3.31`. (#59092) Thanks @openperf.
- Agents/output sanitization: strip namespaced `antml:thinking` blocks from user-visible text so Anthropic-style internal monologue tags do not leak into replies. (#59550) Thanks @obviyus.

View File

@@ -11,23 +11,14 @@ import {
} from "./image-generation-provider.js";
function expectFalJsonPost(params: { call: number; url: string; body: Record<string, unknown> }) {
expect(fetchWithSsrFGuardMock).toHaveBeenNthCalledWith(
params.call,
expect.objectContaining({
url: params.url,
init: expect.objectContaining({
method: "POST",
headers: expect.objectContaining({
Authorization: "Key fal-test-key",
"Content-Type": "application/json",
}),
}),
auditContext: "fal-image-generate",
}),
);
const request = fetchWithSsrFGuardMock.mock.calls[params.call - 1]?.[0];
expect(request).toBeTruthy();
expect(request?.url).toBe(params.url);
expect(request?.auditContext).toBe("fal-image-generate");
expect(request?.init?.method).toBe("POST");
const headers = new Headers(request?.init?.headers);
expect(headers.get("authorization")).toBe("Key fal-test-key");
expect(headers.get("content-type")).toBe("application/json");
expect(JSON.parse(String(request?.init?.body))).toEqual(params.body);
}
@@ -361,17 +352,13 @@ describe("fal image-generation provider", () => {
);
});
it("allows trusted private relay hosts derived from configured baseUrl", async () => {
it("does not auto-whitelist trusted private relay hosts from a configured baseUrl", async () => {
vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({
apiKey: "fal-test-key",
source: "env",
mode: "api-key",
});
_setFalFetchGuardForTesting(fetchWithSsrFGuardMock);
const relayPolicy = {
allowPrivateNetwork: true,
hostnameAllowlist: ["relay.internal", "*.relay.internal"],
};
fetchWithSsrFGuardMock
.mockResolvedValueOnce({
response: new Response(
@@ -415,7 +402,7 @@ describe("fal image-generation provider", () => {
expect.objectContaining({
url: "http://relay.internal:8080/fal-ai/flux/dev",
auditContext: "fal-image-generate",
policy: relayPolicy,
policy: undefined,
}),
);
expect(fetchWithSsrFGuardMock).toHaveBeenNthCalledWith(
@@ -423,7 +410,7 @@ describe("fal image-generation provider", () => {
expect.objectContaining({
url: "http://media.relay.internal/files/generated.png",
auditContext: "fal-image-download",
policy: relayPolicy,
policy: undefined,
}),
);
});

View File

@@ -3,6 +3,10 @@ import type {
ImageGenerationProvider,
} from "openclaw/plugin-sdk/image-generation";
import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime";
import {
assertOkOrThrowHttpError,
resolveProviderHttpRequestConfig,
} from "openclaw/plugin-sdk/provider-http";
import {
buildHostnameAllowlistPolicyFromSuffixAllowlist,
fetchWithSsrFGuard,
@@ -81,40 +85,32 @@ function matchesTrustedHostSuffix(hostname: string, trustedSuffix: string): bool
return normalizedHost === normalizedSuffix || normalizedHost.endsWith(`.${normalizedSuffix}`);
}
function resolveFalNetworkPolicy(
cfg: Parameters<typeof resolveApiKeyForProvider>[0]["cfg"],
): FalNetworkPolicy {
const baseUrl = resolveFalBaseUrl(cfg);
const explicitBaseUrl = cfg?.models?.providers?.fal?.baseUrl?.trim();
function resolveFalNetworkPolicy(params: {
baseUrl: string;
allowPrivateNetwork: boolean;
}): FalNetworkPolicy {
let parsedBaseUrl: URL;
try {
parsedBaseUrl = new URL(baseUrl);
parsedBaseUrl = new URL(params.baseUrl);
} catch {
return {};
}
const hostSuffix = parsedBaseUrl.hostname.trim().toLowerCase();
if (!hostSuffix) {
if (!hostSuffix || !params.allowPrivateNetwork) {
return {};
}
const hostPolicy = buildHostnameAllowlistPolicyFromSuffixAllowlist([hostSuffix]);
const privateNetworkPolicy = explicitBaseUrl
? ssrfPolicyFromAllowPrivateNetwork(true)
: undefined;
const privateNetworkPolicy = ssrfPolicyFromAllowPrivateNetwork(true);
const trustedHostPolicy = mergeSsrFPolicies(hostPolicy, privateNetworkPolicy);
return {
apiPolicy: trustedHostPolicy,
trustedDownloadHostSuffix: explicitBaseUrl ? hostSuffix : undefined,
trustedDownloadPolicy: explicitBaseUrl ? trustedHostPolicy : undefined,
trustedDownloadHostSuffix: hostSuffix,
trustedDownloadPolicy: trustedHostPolicy,
};
}
function resolveFalBaseUrl(cfg: Parameters<typeof resolveApiKeyForProvider>[0]["cfg"]): string {
const direct = cfg?.models?.providers?.fal?.baseUrl?.trim();
return (direct || DEFAULT_FAL_BASE_URL).replace(/\/+$/u, "");
}
function ensureFalModelPath(model: string | undefined, hasInputImages: boolean): string {
const trimmed = model?.trim() || DEFAULT_FAL_IMAGE_MODEL;
if (!hasInputImages) {
@@ -341,7 +337,21 @@ export function buildFalImageGenerationProvider(): ImageGenerationProvider {
hasInputImages,
});
const model = ensureFalModelPath(req.model, hasInputImages);
const networkPolicy = resolveFalNetworkPolicy(req.cfg);
const explicitBaseUrl = req.cfg?.models?.providers?.fal?.baseUrl?.trim();
const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } =
resolveProviderHttpRequestConfig({
baseUrl: explicitBaseUrl,
defaultBaseUrl: DEFAULT_FAL_BASE_URL,
allowPrivateNetwork: false,
defaultHeaders: {
Authorization: `Key ${auth.apiKey}`,
"Content-Type": "application/json",
},
provider: "fal",
capability: "image",
transport: "http",
});
const networkPolicy = resolveFalNetworkPolicy({ baseUrl, allowPrivateNetwork });
const requestBody: Record<string, unknown> = {
prompt: req.prompt,
num_images: req.count ?? 1,
@@ -358,27 +368,20 @@ export function buildFalImageGenerationProvider(): ImageGenerationProvider {
}
requestBody.image_url = toDataUri(input.buffer, input.mimeType);
}
const { response, release } = await falFetchGuard({
url: `${resolveFalBaseUrl(req.cfg)}/${model}`,
url: `${baseUrl}/${model}`,
init: {
method: "POST",
headers: {
Authorization: `Key ${auth.apiKey}`,
"Content-Type": "application/json",
},
headers,
body: JSON.stringify(requestBody),
},
timeoutMs: req.timeoutMs,
policy: networkPolicy.apiPolicy,
dispatcherPolicy,
auditContext: "fal-image-generate",
});
try {
if (!response.ok) {
const text = await response.text().catch(() => "");
throw new Error(
`fal image generation failed (${response.status}): ${text || response.statusText}`,
);
}
await assertOkOrThrowHttpError(response, "fal image generation failed");
const payload = (await response.json()) as FalImageGenerationResponse;
const images: GeneratedImageAsset[] = [];

View File

@@ -0,0 +1,147 @@
import * as providerAuth from "openclaw/plugin-sdk/provider-auth-runtime";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { buildMinimaxImageGenerationProvider } from "./image-generation-provider.js";
describe("minimax image-generation provider", () => {
beforeEach(() => {
vi.clearAllMocks();
});
afterEach(() => {
vi.restoreAllMocks();
});
it("generates PNG buffers through the shared provider HTTP path", async () => {
vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({
apiKey: "minimax-test-key",
source: "env",
mode: "api-key",
});
const fetchMock = vi.fn().mockResolvedValue(
new Response(
JSON.stringify({
data: {
image_base64: [Buffer.from("png-data").toString("base64")],
},
base_resp: { status_code: 0 },
}),
{
status: 200,
headers: { "Content-Type": "application/json" },
},
),
);
vi.stubGlobal("fetch", fetchMock);
const provider = buildMinimaxImageGenerationProvider();
const result = await provider.generateImage({
provider: "minimax",
model: "image-01",
prompt: "draw a cat",
cfg: {},
});
expect(fetchMock).toHaveBeenCalledWith(
"https://api.minimax.io/v1/image_generation",
expect.objectContaining({
method: "POST",
body: JSON.stringify({
model: "image-01",
prompt: "draw a cat",
response_format: "base64",
n: 1,
}),
}),
);
const [, init] = fetchMock.mock.calls[0] as [string, RequestInit];
const headers = new Headers(init.headers);
expect(headers.get("authorization")).toBe("Bearer minimax-test-key");
expect(headers.get("content-type")).toBe("application/json");
expect(result).toEqual({
images: [
{
buffer: Buffer.from("png-data"),
mimeType: "image/png",
fileName: "image-1.png",
},
],
model: "image-01",
});
});
it("uses the configured provider base URL origin", async () => {
vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({
apiKey: "minimax-test-key",
source: "env",
mode: "api-key",
});
const fetchMock = vi.fn().mockResolvedValue(
new Response(
JSON.stringify({
data: {
image_base64: [Buffer.from("png-data").toString("base64")],
},
base_resp: { status_code: 0 },
}),
{
status: 200,
headers: { "Content-Type": "application/json" },
},
),
);
vi.stubGlobal("fetch", fetchMock);
const provider = buildMinimaxImageGenerationProvider();
await provider.generateImage({
provider: "minimax",
model: "image-01",
prompt: "draw a cat",
cfg: {
models: {
providers: {
minimax: {
baseUrl: "https://api.minimax.io/anthropic",
models: [],
},
},
},
},
});
expect(fetchMock).toHaveBeenCalledWith(
"https://api.minimax.io/v1/image_generation",
expect.any(Object),
);
});
it("does not allow private-network routing just because a custom base URL is configured", async () => {
vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({
apiKey: "minimax-test-key",
source: "env",
mode: "api-key",
});
const fetchMock = vi.fn();
vi.stubGlobal("fetch", fetchMock);
const provider = buildMinimaxImageGenerationProvider();
await expect(
provider.generateImage({
provider: "minimax",
model: "image-01",
prompt: "draw a cat",
cfg: {
models: {
providers: {
minimax: {
baseUrl: "http://127.0.0.1:8080/anthropic",
models: [],
},
},
},
},
}),
).rejects.toThrow("Blocked hostname or private/internal/special-use IP address");
expect(fetchMock).not.toHaveBeenCalled();
});
});

View File

@@ -1,5 +1,10 @@
import type { ImageGenerationProvider } from "openclaw/plugin-sdk/image-generation";
import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime";
import {
assertOkOrThrowHttpError,
postJsonRequest,
resolveProviderHttpRequestConfig,
} from "openclaw/plugin-sdk/provider-http";
const DEFAULT_MINIMAX_IMAGE_BASE_URL = "https://api.minimax.io";
const DEFAULT_MODEL = "image-01";
@@ -83,6 +88,23 @@ function buildMinimaxImageProvider(providerId: string): ImageGenerationProvider
}
const baseUrl = resolveMinimaxImageBaseUrl(req.cfg, providerId);
const {
baseUrl: resolvedBaseUrl,
allowPrivateNetwork,
headers,
dispatcherPolicy,
} = resolveProviderHttpRequestConfig({
baseUrl,
defaultBaseUrl: DEFAULT_MINIMAX_IMAGE_BASE_URL,
allowPrivateNetwork: false,
defaultHeaders: {
Authorization: `Bearer ${auth.apiKey}`,
"Content-Type": "application/json",
},
provider: providerId,
capability: "image",
transport: "http",
});
const body: Record<string, unknown> = {
model: req.model || DEFAULT_MODEL,
@@ -102,67 +124,55 @@ function buildMinimaxImageProvider(providerId: string): ImageGenerationProvider
const dataUrl = `data:${mime};base64,${ref.buffer.toString("base64")}`;
body.subject_reference = [{ type: "character", image_file: dataUrl }];
}
const controller = new AbortController();
const timeoutMs = req.timeoutMs;
const timeout =
typeof timeoutMs === "number" && Number.isFinite(timeoutMs) && timeoutMs > 0
? setTimeout(() => controller.abort(), timeoutMs)
: undefined;
const response = await fetch(`${baseUrl}/v1/image_generation`, {
method: "POST",
headers: {
Authorization: `Bearer ${auth.apiKey}`,
"Content-Type": "application/json",
},
body: JSON.stringify(body),
signal: controller.signal,
}).finally(() => {
clearTimeout(timeout);
const { response, release } = await postJsonRequest({
url: `${resolvedBaseUrl}/v1/image_generation`,
headers,
body,
timeoutMs: req.timeoutMs,
fetchFn: fetch,
allowPrivateNetwork,
dispatcherPolicy,
});
try {
await assertOkOrThrowHttpError(response, "MiniMax image generation failed");
if (!response.ok) {
const text = await response.text().catch(() => "");
throw new Error(
`MiniMax image generation failed (${response.status}): ${text || response.statusText}`,
);
const data = (await response.json()) as MinimaxImageApiResponse;
const baseResp = data.base_resp;
if (baseResp && typeof baseResp.status_code === "number" && baseResp.status_code !== 0) {
const msg = baseResp.status_msg ?? "";
throw new Error(`MiniMax image generation API error (${baseResp.status_code}): ${msg}`);
}
const base64Images = data.data?.image_base64 ?? [];
const failedCount = data.metadata?.failed_count ?? 0;
if (base64Images.length === 0) {
const reason =
failedCount > 0 ? `${failedCount} image(s) failed to generate` : "no images returned";
throw new Error(`MiniMax image generation returned no images: ${reason}`);
}
const images = base64Images
.map((b64, index) => {
if (!b64) {
return null;
}
return {
buffer: Buffer.from(b64, "base64"),
mimeType: DEFAULT_OUTPUT_MIME,
fileName: `image-${index + 1}.png`,
};
})
.filter((entry): entry is NonNullable<typeof entry> => entry !== null);
return {
images,
model: req.model || DEFAULT_MODEL,
};
} finally {
await release();
}
const data = (await response.json()) as MinimaxImageApiResponse;
const baseResp = data.base_resp;
if (baseResp && typeof baseResp.status_code === "number" && baseResp.status_code !== 0) {
const msg = baseResp.status_msg ?? "";
throw new Error(`MiniMax image generation API error (${baseResp.status_code}): ${msg}`);
}
const base64Images = data.data?.image_base64 ?? [];
const failedCount = data.metadata?.failed_count ?? 0;
if (base64Images.length === 0) {
const reason =
failedCount > 0 ? `${failedCount} image(s) failed to generate` : "no images returned";
throw new Error(`MiniMax image generation returned no images: ${reason}`);
}
const images = base64Images
.map((b64, index) => {
if (!b64) {
return null;
}
return {
buffer: Buffer.from(b64, "base64"),
mimeType: DEFAULT_OUTPUT_MIME,
fileName: `image-${index + 1}.png`,
};
})
.filter((entry): entry is NonNullable<typeof entry> => entry !== null);
return {
images,
model: req.model || DEFAULT_MODEL,
};
},
};
}

View File

@@ -1,5 +1,10 @@
import type { ImageGenerationProvider } from "openclaw/plugin-sdk/image-generation";
import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime";
import {
assertOkOrThrowHttpError,
postJsonRequest,
resolveProviderHttpRequestConfig,
} from "openclaw/plugin-sdk/provider-http";
import { OPENAI_DEFAULT_IMAGE_MODEL as DEFAULT_OPENAI_IMAGE_MODEL } from "./default-models.js";
const DEFAULT_OPENAI_IMAGE_BASE_URL = "https://api.openai.com/v1";
@@ -57,57 +62,59 @@ export function buildOpenAIImageGenerationProvider(): ImageGenerationProvider {
if (!auth.apiKey) {
throw new Error("OpenAI API key missing");
}
const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } =
resolveProviderHttpRequestConfig({
baseUrl: resolveOpenAIBaseUrl(req.cfg),
defaultBaseUrl: DEFAULT_OPENAI_IMAGE_BASE_URL,
allowPrivateNetwork: false,
defaultHeaders: {
Authorization: `Bearer ${auth.apiKey}`,
"Content-Type": "application/json",
},
provider: "openai",
capability: "image",
transport: "http",
});
const controller = new AbortController();
const timeoutMs = req.timeoutMs;
const timeout =
typeof timeoutMs === "number" && Number.isFinite(timeoutMs) && timeoutMs > 0
? setTimeout(() => controller.abort(), timeoutMs)
: undefined;
const response = await fetch(`${resolveOpenAIBaseUrl(req.cfg)}/images/generations`, {
method: "POST",
headers: {
Authorization: `Bearer ${auth.apiKey}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
const { response, release } = await postJsonRequest({
url: `${baseUrl}/images/generations`,
headers,
body: {
model: req.model || DEFAULT_OPENAI_IMAGE_MODEL,
prompt: req.prompt,
n: req.count ?? 1,
size: req.size ?? DEFAULT_SIZE,
}),
signal: controller.signal,
}).finally(() => {
clearTimeout(timeout);
},
timeoutMs: req.timeoutMs,
fetchFn: fetch,
allowPrivateNetwork,
dispatcherPolicy,
});
try {
await assertOkOrThrowHttpError(response, "OpenAI image generation failed");
if (!response.ok) {
const text = await response.text().catch(() => "");
throw new Error(
`OpenAI image generation failed (${response.status}): ${text || response.statusText}`,
);
const data = (await response.json()) as OpenAIImageApiResponse;
const images = (data.data ?? [])
.map((entry, index) => {
if (!entry.b64_json) {
return null;
}
return {
buffer: Buffer.from(entry.b64_json, "base64"),
mimeType: DEFAULT_OUTPUT_MIME,
fileName: `image-${index + 1}.png`,
...(entry.revised_prompt ? { revisedPrompt: entry.revised_prompt } : {}),
};
})
.filter((entry): entry is NonNullable<typeof entry> => entry !== null);
return {
images,
model: req.model || DEFAULT_OPENAI_IMAGE_MODEL,
};
} finally {
await release();
}
const data = (await response.json()) as OpenAIImageApiResponse;
const images = (data.data ?? [])
.map((entry, index) => {
if (!entry.b64_json) {
return null;
}
return {
buffer: Buffer.from(entry.b64_json, "base64"),
mimeType: DEFAULT_OUTPUT_MIME,
fileName: `image-${index + 1}.png`,
...(entry.revised_prompt ? { revisedPrompt: entry.revised_prompt } : {}),
};
})
.filter((entry): entry is NonNullable<typeof entry> => entry !== null);
return {
images,
model: req.model || DEFAULT_OPENAI_IMAGE_MODEL,
};
},
};
}

View File

@@ -117,6 +117,37 @@ describe("openai plugin", () => {
).rejects.toThrow("does not support reference-image edits");
});
it("does not allow private-network routing just because a custom base URL is configured", async () => {
vi.spyOn(providerAuth, "resolveApiKeyForProvider").mockResolvedValue({
apiKey: "sk-test",
source: "env",
mode: "api-key",
});
const fetchMock = vi.fn();
vi.stubGlobal("fetch", fetchMock);
const provider = buildOpenAIImageGenerationProvider();
await expect(
provider.generateImage({
provider: "openai",
model: "gpt-image-1",
prompt: "draw a cat",
cfg: {
models: {
providers: {
openai: {
baseUrl: "http://127.0.0.1:8080/v1",
models: [],
},
},
},
} satisfies OpenClawConfig,
}),
).rejects.toThrow("Blocked hostname or private/internal/special-use IP address");
expect(fetchMock).not.toHaveBeenCalled();
});
it("bootstraps the env proxy dispatcher before refreshing codex oauth credentials", async () => {
const refreshed = {
access: "next-access",

View File

@@ -1,5 +1,26 @@
import { describe, expect, it } from "vitest";
import { resolveProviderHttpRequestConfig } from "./shared.js";
import { afterEach, describe, expect, it, vi } from "vitest";
const { fetchWithSsrFGuardMock } = vi.hoisted(() => ({
fetchWithSsrFGuardMock: vi.fn(),
}));
vi.mock("../infra/net/fetch-guard.js", async (importOriginal) => {
const actual = await importOriginal<typeof import("../infra/net/fetch-guard.js")>();
return {
...actual,
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
};
});
import {
fetchWithTimeoutGuarded,
readErrorResponse,
resolveProviderHttpRequestConfig,
} from "./shared.js";
afterEach(() => {
vi.clearAllMocks();
});
describe("resolveProviderHttpRequestConfig", () => {
it("preserves explicit caller headers but protects attribution headers", () => {
@@ -108,3 +129,67 @@ describe("resolveProviderHttpRequestConfig", () => {
).toThrow("Missing baseUrl");
});
});
describe("readErrorResponse", () => {
it("caps streamed error bodies instead of buffering the whole response", async () => {
const encoder = new TextEncoder();
let reads = 0;
const response = new Response(
new ReadableStream<Uint8Array>({
pull(controller) {
reads += 1;
controller.enqueue(encoder.encode("a".repeat(2048)));
if (reads >= 10) {
controller.close();
}
},
}),
{
status: 500,
},
);
const detail = await readErrorResponse(response);
expect(detail).toBe(`${"a".repeat(300)}`);
expect(reads).toBe(2);
});
});
describe("fetchWithTimeoutGuarded", () => {
it("applies a default timeout when callers omit one", async () => {
fetchWithSsrFGuardMock.mockResolvedValue({
response: new Response(null, { status: 200 }),
finalUrl: "https://example.com",
release: async () => {},
});
await fetchWithTimeoutGuarded("https://example.com", {}, undefined, fetch);
expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith(
expect.objectContaining({
url: "https://example.com",
timeoutMs: 60_000,
}),
);
});
it("sanitizes auditContext before passing it to the SSRF guard", async () => {
fetchWithSsrFGuardMock.mockResolvedValue({
response: new Response(null, { status: 200 }),
finalUrl: "https://example.com",
release: async () => {},
});
await fetchWithTimeoutGuarded("https://example.com", {}, 5000, fetch, {
auditContext: "provider-http\r\nfal\timage\u001btest",
});
expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith(
expect.objectContaining({
auditContext: "provider-http fal image test",
timeoutMs: 5000,
}),
);
});
});

View File

@@ -16,6 +16,27 @@ export { fetchWithTimeout } from "../utils/fetch-timeout.js";
export { normalizeBaseUrl } from "../agents/provider-request-config.js";
const MAX_ERROR_CHARS = 300;
const MAX_ERROR_RESPONSE_BYTES = 4096;
const DEFAULT_GUARDED_HTTP_TIMEOUT_MS = 60_000;
const MAX_AUDIT_CONTEXT_CHARS = 80;
function resolveGuardedHttpTimeoutMs(timeoutMs: number | undefined): number {
if (typeof timeoutMs !== "number" || !Number.isFinite(timeoutMs) || timeoutMs <= 0) {
return DEFAULT_GUARDED_HTTP_TIMEOUT_MS;
}
return timeoutMs;
}
function sanitizeAuditContext(auditContext: string | undefined): string | undefined {
const cleaned = auditContext
?.replace(/\p{Cc}+/gu, " ")
.replace(/\s+/g, " ")
.trim();
if (!cleaned) {
return undefined;
}
return cleaned.slice(0, MAX_AUDIT_CONTEXT_CHARS);
}
export function resolveProviderHttpRequestConfig(params: {
baseUrl?: string;
@@ -67,24 +88,26 @@ export function resolveProviderHttpRequestConfig(params: {
export async function fetchWithTimeoutGuarded(
url: string,
init: RequestInit,
timeoutMs: number,
timeoutMs: number | undefined,
fetchFn: typeof fetch,
options?: {
ssrfPolicy?: SsrFPolicy;
lookupFn?: LookupFn;
pinDns?: boolean;
dispatcherPolicy?: PinnedDispatcherPolicy;
auditContext?: string;
},
): Promise<GuardedFetchResult> {
return await fetchWithSsrFGuard({
url,
fetchImpl: fetchFn,
init,
timeoutMs,
timeoutMs: resolveGuardedHttpTimeoutMs(timeoutMs),
policy: options?.ssrfPolicy,
lookupFn: options?.lookupFn,
pinDns: options?.pinDns,
dispatcherPolicy: options?.dispatcherPolicy,
auditContext: sanitizeAuditContext(options?.auditContext),
});
}
@@ -92,10 +115,11 @@ export async function postTranscriptionRequest(params: {
url: string;
headers: Headers;
body: BodyInit;
timeoutMs: number;
timeoutMs?: number;
fetchFn: typeof fetch;
allowPrivateNetwork?: boolean;
dispatcherPolicy?: PinnedDispatcherPolicy;
auditContext?: string;
}) {
return fetchWithTimeoutGuarded(
params.url,
@@ -110,6 +134,7 @@ export async function postTranscriptionRequest(params: {
? {
...(params.allowPrivateNetwork ? { ssrfPolicy: { allowPrivateNetwork: true } } : {}),
...(params.dispatcherPolicy ? { dispatcherPolicy: params.dispatcherPolicy } : {}),
...(params.auditContext ? { auditContext: params.auditContext } : {}),
}
: undefined,
);
@@ -119,10 +144,11 @@ export async function postJsonRequest(params: {
url: string;
headers: Headers;
body: unknown;
timeoutMs: number;
timeoutMs?: number;
fetchFn: typeof fetch;
allowPrivateNetwork?: boolean;
dispatcherPolicy?: PinnedDispatcherPolicy;
auditContext?: string;
}) {
return fetchWithTimeoutGuarded(
params.url,
@@ -137,14 +163,49 @@ export async function postJsonRequest(params: {
? {
...(params.allowPrivateNetwork ? { ssrfPolicy: { allowPrivateNetwork: true } } : {}),
...(params.dispatcherPolicy ? { dispatcherPolicy: params.dispatcherPolicy } : {}),
...(params.auditContext ? { auditContext: params.auditContext } : {}),
}
: undefined,
);
}
export async function readErrorResponse(res: Response): Promise<string | undefined> {
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
try {
const text = await res.text();
if (!res.body) {
return undefined;
}
reader = res.body.getReader();
const chunks: Uint8Array[] = [];
let total = 0;
let sawBytes = false;
while (total < MAX_ERROR_RESPONSE_BYTES) {
const { done, value } = await reader.read();
if (done) {
break;
}
if (!value || value.length === 0) {
continue;
}
sawBytes = true;
const remaining = MAX_ERROR_RESPONSE_BYTES - total;
const chunk = value.length <= remaining ? value : value.subarray(0, remaining);
chunks.push(chunk);
total += chunk.length;
if (chunk.length < value.length) {
break;
}
}
if (!sawBytes) {
return undefined;
}
const bytes = new Uint8Array(total);
let offset = 0;
for (const chunk of chunks) {
bytes.set(chunk, offset);
offset += chunk.length;
}
const text = new TextDecoder().decode(bytes);
const collapsed = text.replace(/\s+/g, " ").trim();
if (!collapsed) {
return undefined;
@@ -155,6 +216,12 @@ export async function readErrorResponse(res: Response): Promise<string | undefin
return `${collapsed.slice(0, MAX_ERROR_CHARS)}`;
} catch {
return undefined;
} finally {
try {
await reader?.cancel();
} catch {
// Ignore stream-cancel failures while reporting the original HTTP error.
}
}
}