diff --git a/extensions/lmstudio/src/stream.ts b/extensions/lmstudio/src/stream.ts index ae2567cfb1e..e0c70b7ca94 100644 --- a/extensions/lmstudio/src/stream.ts +++ b/extensions/lmstudio/src/stream.ts @@ -2,7 +2,7 @@ import type { StreamFn } from "@mariozechner/pi-agent-core"; import { streamSimple } from "@mariozechner/pi-ai"; import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core"; import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry"; -import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime"; +import { ssrfPolicyFromHttpBaseUrlAllowedHostname } from "openclaw/plugin-sdk/ssrf-runtime"; import { LMSTUDIO_PROVIDER_ID } from "./defaults.js"; import { ensureLmstudioModelLoaded } from "./models.fetch.js"; import { resolveLmstudioInferenceBase } from "./models.js"; @@ -120,22 +120,6 @@ function createPreloadKey(params: { return `${params.baseUrl}::${params.modelKey}::${params.requestedContextLength ?? "default"}`; } -function buildLmstudioPreloadSsrFPolicy(baseUrl: string): SsrFPolicy | undefined { - const trimmed = baseUrl.trim(); - if (!trimmed) { - return undefined; - } - try { - const parsed = new URL(trimmed); - if (parsed.protocol !== "http:" && parsed.protocol !== "https:") { - return undefined; - } - return { allowedHostnames: [parsed.hostname] }; - } catch { - return undefined; - } -} - async function ensureLmstudioModelLoadedBestEffort(params: { baseUrl: string; modelKey: string; @@ -167,7 +151,7 @@ async function ensureLmstudioModelLoadedBestEffort(params: { baseUrl: params.baseUrl, apiKey: runtimeApiKey ?? configuredApiKey, headers, - ssrfPolicy: buildLmstudioPreloadSsrFPolicy(params.baseUrl), + ssrfPolicy: ssrfPolicyFromHttpBaseUrlAllowedHostname(params.baseUrl), modelKey: params.modelKey, requestedContextLength: params.requestedContextLength, }); diff --git a/extensions/ollama/src/embedding-provider.test.ts b/extensions/ollama/src/embedding-provider.test.ts index cb3c1bbac39..a85f1b7d273 100644 --- a/extensions/ollama/src/embedding-provider.test.ts +++ b/extensions/ollama/src/embedding-provider.test.ts @@ -11,6 +11,10 @@ const { fetchWithSsrFGuardMock } = vi.hoisted(() => ({ vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({ fetchWithSsrFGuard: fetchWithSsrFGuardMock, formatErrorMessage: (error: unknown) => (error instanceof Error ? error.message : String(error)), + ssrfPolicyFromHttpBaseUrlAllowedHostname: (baseUrl: string) => { + const parsed = new URL(baseUrl); + return { allowedHostnames: [parsed.hostname] }; + }, })); let createOllamaEmbeddingProvider: typeof import("./embedding-provider.js").createOllamaEmbeddingProvider; diff --git a/extensions/ollama/src/embedding-provider.ts b/extensions/ollama/src/embedding-provider.ts index 7c6bef7607a..5ea33222c6d 100644 --- a/extensions/ollama/src/embedding-provider.ts +++ b/extensions/ollama/src/embedding-provider.ts @@ -8,6 +8,7 @@ import { import { fetchWithSsrFGuard, formatErrorMessage, + ssrfPolicyFromHttpBaseUrlAllowedHostname, type SsrFPolicy, } from "openclaw/plugin-sdk/ssrf-runtime"; import { resolveOllamaApiBase } from "./provider-models.js"; @@ -57,22 +58,6 @@ function sanitizeAndNormalizeEmbedding(vec: number[]): number[] { return sanitized.map((value) => value / magnitude); } -function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined { - const trimmed = baseUrl.trim(); - if (!trimmed) { - return undefined; - } - try { - const parsed = new URL(trimmed); - if (parsed.protocol !== "http:" && parsed.protocol !== "https:") { - return undefined; - } - return { allowedHostnames: [parsed.hostname] }; - } catch { - return undefined; - } -} - async function withRemoteHttpResponse(params: { url: string; init?: RequestInit; @@ -149,7 +134,7 @@ function resolveOllamaEmbeddingClient( return { baseUrl, headers, - ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl), + ssrfPolicy: ssrfPolicyFromHttpBaseUrlAllowedHostname(baseUrl), model, }; } diff --git a/src/infra/net/ssrf.test.ts b/src/infra/net/ssrf.test.ts index 2d0f4a7527a..e5e97ba4a3d 100644 --- a/src/infra/net/ssrf.test.ts +++ b/src/infra/net/ssrf.test.ts @@ -1,6 +1,10 @@ import { describe, expect, it } from "vitest"; import { blockedIpv6MulticastLiterals } from "../../shared/net/ip-test-fixtures.js"; -import { isBlockedHostnameOrIp, isPrivateIpAddress } from "./ssrf.js"; +import { + isBlockedHostnameOrIp, + isPrivateIpAddress, + ssrfPolicyFromHttpBaseUrlAllowedHostname, +} from "./ssrf.js"; const privateIpCases = [ "198.18.0.1", @@ -106,6 +110,20 @@ describe("ssrf ip classification", () => { }); }); +describe("ssrfPolicyFromHttpBaseUrlAllowedHostname", () => { + it("builds an allowed-hostname policy from HTTP base URLs", () => { + expect(ssrfPolicyFromHttpBaseUrlAllowedHostname(" https://api.example.com/v1 ")).toEqual({ + allowedHostnames: ["api.example.com"], + }); + }); + + it("ignores empty, invalid, and non-HTTP URLs", () => { + expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("")).toBeUndefined(); + expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("not-a-url")).toBeUndefined(); + expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("ftp://api.example.com")).toBeUndefined(); + }); +}); + describe("isBlockedHostnameOrIp", () => { it.each([ "localhost.localdomain", diff --git a/src/infra/net/ssrf.ts b/src/infra/net/ssrf.ts index d8acd365530..0df8fb0096e 100644 --- a/src/infra/net/ssrf.ts +++ b/src/infra/net/ssrf.ts @@ -44,6 +44,22 @@ export type SsrFPolicy = { hostnameAllowlist?: string[]; }; +export function ssrfPolicyFromHttpBaseUrlAllowedHostname(baseUrl: string): SsrFPolicy | undefined { + const trimmed = baseUrl.trim(); + if (!trimmed) { + return undefined; + } + try { + const parsed = new URL(trimmed); + if (parsed.protocol !== "http:" && parsed.protocol !== "https:") { + return undefined; + } + return { allowedHostnames: [parsed.hostname] }; + } catch { + return undefined; + } +} + const BLOCKED_HOSTNAMES = new Set([ "localhost", "localhost.localdomain", diff --git a/src/memory-host-sdk/host/remote-http.ts b/src/memory-host-sdk/host/remote-http.ts index e28d7cf10c4..7385125beda 100644 --- a/src/memory-host-sdk/host/remote-http.ts +++ b/src/memory-host-sdk/host/remote-http.ts @@ -1,23 +1,7 @@ import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js"; -import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import { ssrfPolicyFromHttpBaseUrlAllowedHostname, type SsrFPolicy } from "../../infra/net/ssrf.js"; -export function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined { - const trimmed = baseUrl.trim(); - if (!trimmed) { - return undefined; - } - try { - const parsed = new URL(trimmed); - if (parsed.protocol !== "http:" && parsed.protocol !== "https:") { - return undefined; - } - // Keep policy tied to the configured host so private operator endpoints - // continue to work, while cross-host redirects stay blocked. - return { allowedHostnames: [parsed.hostname] }; - } catch { - return undefined; - } -} +export const buildRemoteBaseUrlPolicy = ssrfPolicyFromHttpBaseUrlAllowedHostname; export async function withRemoteHttpResponse(params: { url: string; diff --git a/src/plugin-sdk/ssrf-runtime.ts b/src/plugin-sdk/ssrf-runtime.ts index 1295d3a60a7..119c6b82410 100644 --- a/src/plugin-sdk/ssrf-runtime.ts +++ b/src/plugin-sdk/ssrf-runtime.ts @@ -7,6 +7,7 @@ export { isBlockedHostnameOrIp, resolvePinnedHostname, resolvePinnedHostnameWithPolicy, + ssrfPolicyFromHttpBaseUrlAllowedHostname, type LookupFn, type SsrFPolicy, } from "../infra/net/ssrf.js";