mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-06 06:00:43 +00:00
refactor: share ssrf base url policy
This commit is contained in:
@@ -2,7 +2,7 @@ import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
|
||||
import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { ssrfPolicyFromHttpBaseUrlAllowedHostname } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { LMSTUDIO_PROVIDER_ID } from "./defaults.js";
|
||||
import { ensureLmstudioModelLoaded } from "./models.fetch.js";
|
||||
import { resolveLmstudioInferenceBase } from "./models.js";
|
||||
@@ -120,22 +120,6 @@ function createPreloadKey(params: {
|
||||
return `${params.baseUrl}::${params.modelKey}::${params.requestedContextLength ?? "default"}`;
|
||||
}
|
||||
|
||||
function buildLmstudioPreloadSsrFPolicy(baseUrl: string): SsrFPolicy | undefined {
|
||||
const trimmed = baseUrl.trim();
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
const parsed = new URL(trimmed);
|
||||
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
|
||||
return undefined;
|
||||
}
|
||||
return { allowedHostnames: [parsed.hostname] };
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
async function ensureLmstudioModelLoadedBestEffort(params: {
|
||||
baseUrl: string;
|
||||
modelKey: string;
|
||||
@@ -167,7 +151,7 @@ async function ensureLmstudioModelLoadedBestEffort(params: {
|
||||
baseUrl: params.baseUrl,
|
||||
apiKey: runtimeApiKey ?? configuredApiKey,
|
||||
headers,
|
||||
ssrfPolicy: buildLmstudioPreloadSsrFPolicy(params.baseUrl),
|
||||
ssrfPolicy: ssrfPolicyFromHttpBaseUrlAllowedHostname(params.baseUrl),
|
||||
modelKey: params.modelKey,
|
||||
requestedContextLength: params.requestedContextLength,
|
||||
});
|
||||
|
||||
@@ -11,6 +11,10 @@ const { fetchWithSsrFGuardMock } = vi.hoisted(() => ({
|
||||
vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({
|
||||
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
|
||||
formatErrorMessage: (error: unknown) => (error instanceof Error ? error.message : String(error)),
|
||||
ssrfPolicyFromHttpBaseUrlAllowedHostname: (baseUrl: string) => {
|
||||
const parsed = new URL(baseUrl);
|
||||
return { allowedHostnames: [parsed.hostname] };
|
||||
},
|
||||
}));
|
||||
|
||||
let createOllamaEmbeddingProvider: typeof import("./embedding-provider.js").createOllamaEmbeddingProvider;
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
import {
|
||||
fetchWithSsrFGuard,
|
||||
formatErrorMessage,
|
||||
ssrfPolicyFromHttpBaseUrlAllowedHostname,
|
||||
type SsrFPolicy,
|
||||
} from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { resolveOllamaApiBase } from "./provider-models.js";
|
||||
@@ -57,22 +58,6 @@ function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
|
||||
return sanitized.map((value) => value / magnitude);
|
||||
}
|
||||
|
||||
function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined {
|
||||
const trimmed = baseUrl.trim();
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
const parsed = new URL(trimmed);
|
||||
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
|
||||
return undefined;
|
||||
}
|
||||
return { allowedHostnames: [parsed.hostname] };
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
async function withRemoteHttpResponse<T>(params: {
|
||||
url: string;
|
||||
init?: RequestInit;
|
||||
@@ -149,7 +134,7 @@ function resolveOllamaEmbeddingClient(
|
||||
return {
|
||||
baseUrl,
|
||||
headers,
|
||||
ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl),
|
||||
ssrfPolicy: ssrfPolicyFromHttpBaseUrlAllowedHostname(baseUrl),
|
||||
model,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { blockedIpv6MulticastLiterals } from "../../shared/net/ip-test-fixtures.js";
|
||||
import { isBlockedHostnameOrIp, isPrivateIpAddress } from "./ssrf.js";
|
||||
import {
|
||||
isBlockedHostnameOrIp,
|
||||
isPrivateIpAddress,
|
||||
ssrfPolicyFromHttpBaseUrlAllowedHostname,
|
||||
} from "./ssrf.js";
|
||||
|
||||
const privateIpCases = [
|
||||
"198.18.0.1",
|
||||
@@ -106,6 +110,20 @@ describe("ssrf ip classification", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("ssrfPolicyFromHttpBaseUrlAllowedHostname", () => {
|
||||
it("builds an allowed-hostname policy from HTTP base URLs", () => {
|
||||
expect(ssrfPolicyFromHttpBaseUrlAllowedHostname(" https://api.example.com/v1 ")).toEqual({
|
||||
allowedHostnames: ["api.example.com"],
|
||||
});
|
||||
});
|
||||
|
||||
it("ignores empty, invalid, and non-HTTP URLs", () => {
|
||||
expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("")).toBeUndefined();
|
||||
expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("not-a-url")).toBeUndefined();
|
||||
expect(ssrfPolicyFromHttpBaseUrlAllowedHostname("ftp://api.example.com")).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("isBlockedHostnameOrIp", () => {
|
||||
it.each([
|
||||
"localhost.localdomain",
|
||||
|
||||
@@ -44,6 +44,22 @@ export type SsrFPolicy = {
|
||||
hostnameAllowlist?: string[];
|
||||
};
|
||||
|
||||
export function ssrfPolicyFromHttpBaseUrlAllowedHostname(baseUrl: string): SsrFPolicy | undefined {
|
||||
const trimmed = baseUrl.trim();
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
const parsed = new URL(trimmed);
|
||||
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
|
||||
return undefined;
|
||||
}
|
||||
return { allowedHostnames: [parsed.hostname] };
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
const BLOCKED_HOSTNAMES = new Set([
|
||||
"localhost",
|
||||
"localhost.localdomain",
|
||||
|
||||
@@ -1,23 +1,7 @@
|
||||
import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js";
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { ssrfPolicyFromHttpBaseUrlAllowedHostname, type SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
|
||||
export function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined {
|
||||
const trimmed = baseUrl.trim();
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
const parsed = new URL(trimmed);
|
||||
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
|
||||
return undefined;
|
||||
}
|
||||
// Keep policy tied to the configured host so private operator endpoints
|
||||
// continue to work, while cross-host redirects stay blocked.
|
||||
return { allowedHostnames: [parsed.hostname] };
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
export const buildRemoteBaseUrlPolicy = ssrfPolicyFromHttpBaseUrlAllowedHostname;
|
||||
|
||||
export async function withRemoteHttpResponse<T>(params: {
|
||||
url: string;
|
||||
|
||||
@@ -7,6 +7,7 @@ export {
|
||||
isBlockedHostnameOrIp,
|
||||
resolvePinnedHostname,
|
||||
resolvePinnedHostnameWithPolicy,
|
||||
ssrfPolicyFromHttpBaseUrlAllowedHostname,
|
||||
type LookupFn,
|
||||
type SsrFPolicy,
|
||||
} from "../infra/net/ssrf.js";
|
||||
|
||||
Reference in New Issue
Block a user