refactor: share ssrf base url policy

This commit is contained in:
Peter Steinberger
2026-04-20 23:13:27 +01:00
parent 85450b3da9
commit da5a6b68bd
7 changed files with 46 additions and 54 deletions

View File

@@ -2,7 +2,7 @@ import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry";
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
import { ssrfPolicyFromHttpBaseUrlAllowedHostname } from "openclaw/plugin-sdk/ssrf-runtime";
import { LMSTUDIO_PROVIDER_ID } from "./defaults.js";
import { ensureLmstudioModelLoaded } from "./models.fetch.js";
import { resolveLmstudioInferenceBase } from "./models.js";
@@ -120,22 +120,6 @@ function createPreloadKey(params: {
return `${params.baseUrl}::${params.modelKey}::${params.requestedContextLength ?? "default"}`;
}
function buildLmstudioPreloadSsrFPolicy(baseUrl: string): SsrFPolicy | undefined {
const trimmed = baseUrl.trim();
if (!trimmed) {
return undefined;
}
try {
const parsed = new URL(trimmed);
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
return undefined;
}
return { allowedHostnames: [parsed.hostname] };
} catch {
return undefined;
}
}
async function ensureLmstudioModelLoadedBestEffort(params: {
baseUrl: string;
modelKey: string;
@@ -167,7 +151,7 @@ async function ensureLmstudioModelLoadedBestEffort(params: {
baseUrl: params.baseUrl,
apiKey: runtimeApiKey ?? configuredApiKey,
headers,
ssrfPolicy: buildLmstudioPreloadSsrFPolicy(params.baseUrl),
ssrfPolicy: ssrfPolicyFromHttpBaseUrlAllowedHostname(params.baseUrl),
modelKey: params.modelKey,
requestedContextLength: params.requestedContextLength,
});

View File

@@ -11,6 +11,10 @@ const { fetchWithSsrFGuardMock } = vi.hoisted(() => ({
vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
formatErrorMessage: (error: unknown) => (error instanceof Error ? error.message : String(error)),
ssrfPolicyFromHttpBaseUrlAllowedHostname: (baseUrl: string) => {
const parsed = new URL(baseUrl);
return { allowedHostnames: [parsed.hostname] };
},
}));
let createOllamaEmbeddingProvider: typeof import("./embedding-provider.js").createOllamaEmbeddingProvider;

View File

@@ -8,6 +8,7 @@ import {
import {
fetchWithSsrFGuard,
formatErrorMessage,
ssrfPolicyFromHttpBaseUrlAllowedHostname,
type SsrFPolicy,
} from "openclaw/plugin-sdk/ssrf-runtime";
import { resolveOllamaApiBase } from "./provider-models.js";
@@ -57,22 +58,6 @@ function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
return sanitized.map((value) => value / magnitude);
}
function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined {
const trimmed = baseUrl.trim();
if (!trimmed) {
return undefined;
}
try {
const parsed = new URL(trimmed);
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
return undefined;
}
return { allowedHostnames: [parsed.hostname] };
} catch {
return undefined;
}
}
async function withRemoteHttpResponse<T>(params: {
url: string;
init?: RequestInit;
@@ -149,7 +134,7 @@ function resolveOllamaEmbeddingClient(
return {
baseUrl,
headers,
ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl),
ssrfPolicy: ssrfPolicyFromHttpBaseUrlAllowedHostname(baseUrl),
model,
};
}

View File

@@ -1,6 +1,10 @@
import { describe, expect, it } from "vitest";
import { blockedIpv6MulticastLiterals } from "../../shared/net/ip-test-fixtures.js";
import { isBlockedHostnameOrIp, isPrivateIpAddress } from "./ssrf.js";
import {
isBlockedHostnameOrIp,
isPrivateIpAddress,
ssrfPolicyFromHttpBaseUrlAllowedHostname,
} from "./ssrf.js";
const privateIpCases = [
"198.18.0.1",
@@ -106,6 +110,20 @@ describe("ssrf ip classification", () => {
});
});
describe("ssrfPolicyFromHttpBaseUrlAllowedHostname", () => {
it("builds an allowed-hostname policy from HTTP base URLs", () => {
expect(ssrfPolicyFromHttpBaseUrlAllowedHostname(" https://api.example.com/v1 ")).toEqual({
allowedHostnames: ["api.example.com"],
});
});
it("ignores empty, invalid, and non-HTTP URLs", () => {
expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("")).toBeUndefined();
expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("not-a-url")).toBeUndefined();
expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("ftp://api.example.com")).toBeUndefined();
});
});
describe("isBlockedHostnameOrIp", () => {
it.each([
"localhost.localdomain",

View File

@@ -44,6 +44,22 @@ export type SsrFPolicy = {
hostnameAllowlist?: string[];
};
export function ssrfPolicyFromHttpBaseUrlAllowedHostname(baseUrl: string): SsrFPolicy | undefined {
const trimmed = baseUrl.trim();
if (!trimmed) {
return undefined;
}
try {
const parsed = new URL(trimmed);
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
return undefined;
}
return { allowedHostnames: [parsed.hostname] };
} catch {
return undefined;
}
}
const BLOCKED_HOSTNAMES = new Set([
"localhost",
"localhost.localdomain",

View File

@@ -1,23 +1,7 @@
import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { ssrfPolicyFromHttpBaseUrlAllowedHostname, type SsrFPolicy } from "../../infra/net/ssrf.js";
export function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined {
const trimmed = baseUrl.trim();
if (!trimmed) {
return undefined;
}
try {
const parsed = new URL(trimmed);
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
return undefined;
}
// Keep policy tied to the configured host so private operator endpoints
// continue to work, while cross-host redirects stay blocked.
return { allowedHostnames: [parsed.hostname] };
} catch {
return undefined;
}
}
export const buildRemoteBaseUrlPolicy = ssrfPolicyFromHttpBaseUrlAllowedHostname;
export async function withRemoteHttpResponse<T>(params: {
url: string;

View File

@@ -7,6 +7,7 @@ export {
isBlockedHostnameOrIp,
resolvePinnedHostname,
resolvePinnedHostnameWithPolicy,
ssrfPolicyFromHttpBaseUrlAllowedHostname,
type LookupFn,
type SsrFPolicy,
} from "../infra/net/ssrf.js";