mirror of
https://github.com/openclaw/openclaw.git
synced 2026-03-23 16:01:17 +00:00
test: harden no-isolate timer and undici seams
This commit is contained in:
@@ -5,19 +5,24 @@ import fs from "node:fs/promises";
|
||||
import net from "node:net";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { setTimeout as nativeSleep } from "node:timers/promises";
|
||||
import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { resolveConfigPath, resolveGatewayLockDir, resolveStateDir } from "../config/paths.js";
|
||||
import { resolveConfigPath, resolveStateDir } from "../config/paths.js";
|
||||
import { acquireGatewayLock, GatewayLockError, type GatewayLockOptions } from "./gateway-lock.js";
|
||||
|
||||
let fixtureRoot = "";
|
||||
let fixtureCount = 0;
|
||||
const realNow = Date.now.bind(Date);
|
||||
|
||||
function resolveTestLockDir() {
|
||||
return path.join(fixtureRoot, "__locks");
|
||||
}
|
||||
|
||||
async function makeEnv() {
|
||||
const dir = path.join(fixtureRoot, `case-${fixtureCount++}`);
|
||||
await fs.mkdir(dir, { recursive: true });
|
||||
const configPath = path.join(dir, "openclaw.json");
|
||||
await fs.writeFile(configPath, "{}", "utf8");
|
||||
await fs.mkdir(resolveGatewayLockDir(), { recursive: true });
|
||||
return {
|
||||
...process.env,
|
||||
OPENCLAW_STATE_DIR: dir,
|
||||
@@ -34,6 +39,11 @@ async function acquireForTest(
|
||||
allowInTests: true,
|
||||
timeoutMs: 30,
|
||||
pollIntervalMs: 2,
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
lockDir: resolveTestLockDir(),
|
||||
...opts,
|
||||
});
|
||||
}
|
||||
@@ -42,7 +52,7 @@ function resolveLockPath(env: NodeJS.ProcessEnv) {
|
||||
const stateDir = resolveStateDir(env);
|
||||
const configPath = resolveConfigPath(env, stateDir);
|
||||
const hash = createHash("sha256").update(configPath).digest("hex").slice(0, 8);
|
||||
const lockDir = resolveGatewayLockDir();
|
||||
const lockDir = resolveTestLockDir();
|
||||
return { lockPath: path.join(lockDir, `gateway.${hash}.lock`), configPath };
|
||||
}
|
||||
|
||||
@@ -145,6 +155,7 @@ describe("gateway lock", () => {
|
||||
beforeEach(() => {
|
||||
// Other suites occasionally leave global spies behind (Date.now, setTimeout, etc.).
|
||||
// This test relies on fake timers advancing Date.now and setTimeout deterministically.
|
||||
vi.useRealTimers();
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
@@ -155,6 +166,7 @@ describe("gateway lock", () => {
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("blocks concurrent acquisition until release", async () => {
|
||||
@@ -174,8 +186,6 @@ describe("gateway lock", () => {
|
||||
});
|
||||
|
||||
it("treats recycled linux pid as stale when start time mismatches", async () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date("2026-02-06T10:05:00.000Z"));
|
||||
const env = await makeEnv();
|
||||
const { lockPath, configPath } = resolveLockPath(env);
|
||||
const payload = createLockPayload({ configPath, startTime: 111 });
|
||||
@@ -274,6 +284,7 @@ describe("gateway lock", () => {
|
||||
const env = await makeEnv();
|
||||
const lock = await acquireGatewayLock({
|
||||
env: { ...env, OPENCLAW_ALLOW_MULTI_GATEWAY: "1", VITEST: "" },
|
||||
lockDir: resolveTestLockDir(),
|
||||
});
|
||||
expect(lock).toBeNull();
|
||||
});
|
||||
@@ -282,6 +293,7 @@ describe("gateway lock", () => {
|
||||
const env = await makeEnv();
|
||||
const lock = await acquireGatewayLock({
|
||||
env: { ...env, VITEST: "1" },
|
||||
lockDir: resolveTestLockDir(),
|
||||
});
|
||||
expect(lock).toBeNull();
|
||||
});
|
||||
|
||||
@@ -33,6 +33,9 @@ export type GatewayLockOptions = {
|
||||
allowInTests?: boolean;
|
||||
platform?: NodeJS.Platform;
|
||||
port?: number;
|
||||
now?: () => number;
|
||||
sleep?: (ms: number) => Promise<void>;
|
||||
lockDir?: string;
|
||||
};
|
||||
|
||||
export class GatewayLockError extends Error {
|
||||
@@ -161,11 +164,10 @@ async function readLockPayload(lockPath: string): Promise<LockPayload | null> {
|
||||
}
|
||||
}
|
||||
|
||||
function resolveGatewayLockPath(env: NodeJS.ProcessEnv) {
|
||||
function resolveGatewayLockPath(env: NodeJS.ProcessEnv, lockDir = resolveGatewayLockDir()) {
|
||||
const stateDir = resolveStateDir(env);
|
||||
const configPath = resolveConfigPath(env, stateDir);
|
||||
const hash = createHash("sha256").update(configPath).digest("hex").slice(0, 8);
|
||||
const lockDir = resolveGatewayLockDir();
|
||||
const lockPath = path.join(lockDir, `gateway.${hash}.lock`);
|
||||
return { lockPath, configPath };
|
||||
}
|
||||
@@ -187,19 +189,22 @@ export async function acquireGatewayLock(
|
||||
const staleMs = opts.staleMs ?? DEFAULT_STALE_MS;
|
||||
const platform = opts.platform ?? process.platform;
|
||||
const port = opts.port;
|
||||
const { lockPath, configPath } = resolveGatewayLockPath(env);
|
||||
const now = opts.now ?? Date.now;
|
||||
const sleep =
|
||||
opts.sleep ?? (async (ms: number) => await new Promise((resolve) => setTimeout(resolve, ms)));
|
||||
const { lockPath, configPath } = resolveGatewayLockPath(env, opts.lockDir);
|
||||
await fs.mkdir(path.dirname(lockPath), { recursive: true });
|
||||
|
||||
const startedAt = Date.now();
|
||||
const startedAt = now();
|
||||
let lastPayload: LockPayload | null = null;
|
||||
|
||||
while (Date.now() - startedAt < timeoutMs) {
|
||||
while (now() - startedAt < timeoutMs) {
|
||||
try {
|
||||
const handle = await fs.open(lockPath, "wx");
|
||||
const startTime = platform === "linux" ? readLinuxStartTime(process.pid) : null;
|
||||
const payload: LockPayload = {
|
||||
pid: process.pid,
|
||||
createdAt: new Date().toISOString(),
|
||||
createdAt: new Date(now()).toISOString(),
|
||||
configPath,
|
||||
};
|
||||
if (typeof startTime === "number" && Number.isFinite(startTime)) {
|
||||
@@ -233,12 +238,12 @@ export async function acquireGatewayLock(
|
||||
let stale = false;
|
||||
if (lastPayload?.createdAt) {
|
||||
const createdAt = Date.parse(lastPayload.createdAt);
|
||||
stale = Number.isFinite(createdAt) ? Date.now() - createdAt > staleMs : false;
|
||||
stale = Number.isFinite(createdAt) ? now() - createdAt > staleMs : false;
|
||||
}
|
||||
if (!stale) {
|
||||
try {
|
||||
const st = await fs.stat(lockPath);
|
||||
stale = Date.now() - st.mtimeMs > staleMs;
|
||||
stale = now() - st.mtimeMs > staleMs;
|
||||
} catch {
|
||||
// On Windows or locked filesystems we may be unable to stat the
|
||||
// lock file even though the existing gateway is still healthy.
|
||||
@@ -253,7 +258,7 @@ export async function acquireGatewayLock(
|
||||
}
|
||||
}
|
||||
|
||||
await new Promise((r) => setTimeout(r, pollIntervalMs));
|
||||
await sleep(pollIntervalMs);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
import { ProxyAgent } from "undici";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { createRequire } from "node:module";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const TEST_GAXIOS_CONSTRUCTOR_OVERRIDE = "__OPENCLAW_TEST_GAXIOS_CONSTRUCTOR__";
|
||||
type FetchLike = (input: RequestInfo | URL, init?: RequestInit) => Promise<Response>;
|
||||
let ProxyAgent: typeof import("undici").ProxyAgent;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
vi.resetModules();
|
||||
const require = createRequire(import.meta.url);
|
||||
({ ProxyAgent } = require("undici") as typeof import("undici"));
|
||||
});
|
||||
|
||||
describe("gaxios fetch compat", () => {
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
Reflect.deleteProperty(globalThis as object, TEST_GAXIOS_CONSTRUCTOR_OVERRIDE);
|
||||
vi.resetModules();
|
||||
vi.restoreAllMocks();
|
||||
|
||||
@@ -2,7 +2,6 @@ import { createRequire } from "node:module";
|
||||
import type { ConnectionOptions } from "node:tls";
|
||||
import { pathToFileURL } from "node:url";
|
||||
import type { Dispatcher } from "undici";
|
||||
import { Agent as UndiciAgent, ProxyAgent } from "undici";
|
||||
|
||||
type ProxyRule = RegExp | URL | string;
|
||||
type TlsCert = ConnectionOptions["cert"];
|
||||
@@ -40,6 +39,11 @@ const TEST_GAXIOS_CONSTRUCTOR_OVERRIDE = "__OPENCLAW_TEST_GAXIOS_CONSTRUCTOR__";
|
||||
|
||||
let installState: "not-installed" | "installing" | "shimmed" | "installed" = "not-installed";
|
||||
|
||||
type UndiciRuntimeDeps = {
|
||||
UndiciAgent: typeof import("undici").Agent;
|
||||
ProxyAgent: typeof import("undici").ProxyAgent;
|
||||
};
|
||||
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null;
|
||||
}
|
||||
@@ -140,6 +144,15 @@ function resolveProxyUri(init: GaxiosFetchRequestInit, url: URL): string | undef
|
||||
return urlMayUseProxy(url, init.noProxy) ? envProxy : undefined;
|
||||
}
|
||||
|
||||
function loadUndiciRuntimeDeps(): UndiciRuntimeDeps {
|
||||
const require = createRequire(import.meta.url);
|
||||
const undici = require("undici") as typeof import("undici");
|
||||
return {
|
||||
ProxyAgent: undici.ProxyAgent,
|
||||
UndiciAgent: undici.Agent,
|
||||
};
|
||||
}
|
||||
|
||||
function buildDispatcher(init: GaxiosFetchRequestInit, url: URL): Dispatcher | undefined {
|
||||
if (init.dispatcher) {
|
||||
return init.dispatcher;
|
||||
@@ -154,6 +167,7 @@ function buildDispatcher(init: GaxiosFetchRequestInit, url: URL): Dispatcher | u
|
||||
const proxyUri =
|
||||
resolveProxyUri(init, url) ?? (hasProxyAgentShape(agent) ? String(agent.proxy) : undefined);
|
||||
if (proxyUri) {
|
||||
const { ProxyAgent } = loadUndiciRuntimeDeps();
|
||||
return new ProxyAgent({
|
||||
requestTls: cert !== undefined || key !== undefined ? { cert, key } : undefined,
|
||||
uri: proxyUri,
|
||||
@@ -161,6 +175,7 @@ function buildDispatcher(init: GaxiosFetchRequestInit, url: URL): Dispatcher | u
|
||||
}
|
||||
|
||||
if (cert !== undefined || key !== undefined) {
|
||||
const { UndiciAgent } = loadUndiciRuntimeDeps();
|
||||
return new UndiciAgent({
|
||||
connect: { cert, key },
|
||||
});
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { EventEmitter } from "node:events";
|
||||
import type { IncomingMessage } from "node:http";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { createMockServerResponse } from "../test-utils/mock-http-response.js";
|
||||
import {
|
||||
installRequestBodyLimitGuard,
|
||||
@@ -104,6 +104,10 @@ function createMockRequest(params: {
|
||||
}
|
||||
|
||||
describe("http body limits", () => {
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it("reads body within max bytes", async () => {
|
||||
const req = createMockRequest({ chunks: ['{"ok":true}'] });
|
||||
await expect(readRequestBodyWithLimit(req, { maxBytes: 1024 })).resolves.toBe('{"ok":true}');
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { IncomingMessage, ServerResponse } from "node:http";
|
||||
import { clearTimeout as clearNodeTimeout, setTimeout as setNodeTimeout } from "node:timers";
|
||||
|
||||
export const DEFAULT_WEBHOOK_MAX_BODY_BYTES = 1024 * 1024;
|
||||
export const DEFAULT_WEBHOOK_BODY_TIMEOUT_MS = 30_000;
|
||||
@@ -147,7 +148,7 @@ export async function readRequestBodyWithLimit(
|
||||
req.removeListener("end", onEnd);
|
||||
req.removeListener("error", onError);
|
||||
req.removeListener("close", onClose);
|
||||
clearTimeout(timer);
|
||||
clearNodeTimeout(timer);
|
||||
};
|
||||
|
||||
const finish = (cb: () => void) => {
|
||||
@@ -163,7 +164,7 @@ export async function readRequestBodyWithLimit(
|
||||
finish(() => reject(error));
|
||||
};
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
const timer = setNodeTimeout(() => {
|
||||
const error = new RequestBodyLimitError({ code: "REQUEST_BODY_TIMEOUT" });
|
||||
if (!req.destroyed) {
|
||||
req.destroy();
|
||||
@@ -289,7 +290,7 @@ export function installRequestBodyLimitGuard(
|
||||
req.removeListener("end", onEnd);
|
||||
req.removeListener("close", onClose);
|
||||
req.removeListener("error", onError);
|
||||
clearTimeout(timer);
|
||||
clearNodeTimeout(timer);
|
||||
};
|
||||
|
||||
const finish = () => {
|
||||
@@ -356,7 +357,7 @@ export function installRequestBodyLimitGuard(
|
||||
finish();
|
||||
};
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
const timer = setNodeTimeout(() => {
|
||||
trip(new RequestBodyLimitError({ code: "REQUEST_BODY_TIMEOUT" }));
|
||||
}, timeoutMs);
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { setTimeout as sleep } from "node:timers/promises";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { createAsyncLock, readJsonFile, writeJsonAtomic, writeTextAtomic } from "./json-files.js";
|
||||
|
||||
@@ -50,7 +51,7 @@ describe("json file helpers", () => {
|
||||
|
||||
const first = withLock(async () => {
|
||||
events.push("first:start");
|
||||
await new Promise((resolve) => setTimeout(resolve, 20));
|
||||
await sleep(20);
|
||||
events.push("first:end");
|
||||
throw new Error("boom");
|
||||
});
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import net from "node:net";
|
||||
import { clearTimeout as clearNodeTimeout, setTimeout as setNodeTimeout } from "node:timers";
|
||||
|
||||
export async function requestJsonlSocket<T>(params: {
|
||||
socketPath: string;
|
||||
@@ -25,7 +26,7 @@ export async function requestJsonlSocket<T>(params: {
|
||||
resolve(value);
|
||||
};
|
||||
|
||||
const timer = setTimeout(() => finish(null), timeoutMs);
|
||||
const timer = setNodeTimeout(() => finish(null), timeoutMs);
|
||||
|
||||
client.on("error", () => finish(null));
|
||||
client.connect(socketPath, () => {
|
||||
@@ -47,7 +48,7 @@ export async function requestJsonlSocket<T>(params: {
|
||||
if (result === undefined) {
|
||||
continue;
|
||||
}
|
||||
clearTimeout(timer);
|
||||
clearNodeTimeout(timer);
|
||||
finish(result);
|
||||
return;
|
||||
} catch {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { EnvHttpProxyAgent, type Dispatcher } from "undici";
|
||||
import type { Dispatcher } from "undici";
|
||||
import { logWarn } from "../../logger.js";
|
||||
import { bindAbortRelay } from "../../utils/fetch-timeout.js";
|
||||
import { hasProxyEnvConfigured } from "./proxy-env.js";
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
SsrFBlockedError,
|
||||
type SsrFPolicy,
|
||||
} from "./ssrf.js";
|
||||
import { loadUndiciRuntimeDeps } from "./undici-runtime.js";
|
||||
|
||||
type FetchLike = (input: RequestInfo | URL, init?: RequestInit) => Promise<Response>;
|
||||
|
||||
@@ -196,6 +197,7 @@ export async function fetchWithSsrFGuard(params: GuardedFetchOptions): Promise<G
|
||||
const canUseTrustedEnvProxy =
|
||||
mode === GUARDED_FETCH_MODE.TRUSTED_ENV_PROXY && hasProxyEnvConfigured();
|
||||
if (canUseTrustedEnvProxy) {
|
||||
const { EnvHttpProxyAgent } = loadUndiciRuntimeDeps();
|
||||
dispatcher = new EnvHttpProxyAgent();
|
||||
} else if (params.pinDns !== false) {
|
||||
dispatcher = createPinnedDispatcher(pinned, params.dispatcherPolicy, params.policy);
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { TEST_UNDICI_RUNTIME_DEPS_KEY } from "./undici-runtime.js";
|
||||
|
||||
const { agentCtor, envHttpProxyAgentCtor, proxyAgentCtor } = vi.hoisted(() => ({
|
||||
agentCtor: vi.fn(function MockAgent(this: { options: unknown }, options: unknown) {
|
||||
@@ -15,21 +16,27 @@ const { agentCtor, envHttpProxyAgentCtor, proxyAgentCtor } = vi.hoisted(() => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("undici", () => ({
|
||||
Agent: agentCtor,
|
||||
EnvHttpProxyAgent: envHttpProxyAgentCtor,
|
||||
ProxyAgent: proxyAgentCtor,
|
||||
}));
|
||||
|
||||
import type { PinnedHostname } from "./ssrf.js";
|
||||
|
||||
let createPinnedDispatcher: typeof import("./ssrf.js").createPinnedDispatcher;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetModules();
|
||||
agentCtor.mockClear();
|
||||
envHttpProxyAgentCtor.mockClear();
|
||||
proxyAgentCtor.mockClear();
|
||||
(globalThis as Record<string, unknown>)[TEST_UNDICI_RUNTIME_DEPS_KEY] = {
|
||||
Agent: agentCtor,
|
||||
EnvHttpProxyAgent: envHttpProxyAgentCtor,
|
||||
ProxyAgent: proxyAgentCtor,
|
||||
};
|
||||
({ createPinnedDispatcher } = await import("./ssrf.js"));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
Reflect.deleteProperty(globalThis as object, TEST_UNDICI_RUNTIME_DEPS_KEY);
|
||||
});
|
||||
|
||||
describe("createPinnedDispatcher", () => {
|
||||
it("uses pinned lookup without overriding global family policy", () => {
|
||||
const lookup = vi.fn() as unknown as PinnedHostname["lookup"];
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { lookup as dnsLookupCb, type LookupAddress } from "node:dns";
|
||||
import { lookup as dnsLookup } from "node:dns/promises";
|
||||
import { Agent, EnvHttpProxyAgent, ProxyAgent, type Dispatcher } from "undici";
|
||||
import type { Dispatcher } from "undici";
|
||||
import {
|
||||
extractEmbeddedIpv4FromIpv6,
|
||||
isBlockedSpecialUseIpv4Address,
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
parseLooseIpAddress,
|
||||
} from "../../shared/net/ip.js";
|
||||
import { normalizeHostname } from "./hostname.js";
|
||||
import { loadUndiciRuntimeDeps } from "./undici-runtime.js";
|
||||
|
||||
type LookupCallback = (
|
||||
err: NodeJS.ErrnoException | null,
|
||||
@@ -400,6 +401,7 @@ export function createPinnedDispatcher(
|
||||
policy?: PinnedDispatcherPolicy,
|
||||
ssrfPolicy?: SsrFPolicy,
|
||||
): Dispatcher {
|
||||
const { Agent, EnvHttpProxyAgent, ProxyAgent } = loadUndiciRuntimeDeps();
|
||||
const lookup = resolvePinnedDispatcherLookup(pinned, policy?.pinnedHostname, ssrfPolicy);
|
||||
|
||||
if (!policy || policy.mode === "direct") {
|
||||
|
||||
34
src/infra/net/undici-runtime.ts
Normal file
34
src/infra/net/undici-runtime.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import { createRequire } from "node:module";
|
||||
|
||||
export const TEST_UNDICI_RUNTIME_DEPS_KEY = "__OPENCLAW_TEST_UNDICI_RUNTIME_DEPS__";
|
||||
|
||||
export type UndiciRuntimeDeps = {
|
||||
Agent: typeof import("undici").Agent;
|
||||
EnvHttpProxyAgent: typeof import("undici").EnvHttpProxyAgent;
|
||||
ProxyAgent: typeof import("undici").ProxyAgent;
|
||||
};
|
||||
|
||||
function isUndiciRuntimeDeps(value: unknown): value is UndiciRuntimeDeps {
|
||||
return (
|
||||
typeof value === "object" &&
|
||||
value !== null &&
|
||||
typeof (value as UndiciRuntimeDeps).Agent === "function" &&
|
||||
typeof (value as UndiciRuntimeDeps).EnvHttpProxyAgent === "function" &&
|
||||
typeof (value as UndiciRuntimeDeps).ProxyAgent === "function"
|
||||
);
|
||||
}
|
||||
|
||||
export function loadUndiciRuntimeDeps(): UndiciRuntimeDeps {
|
||||
const override = (globalThis as Record<string, unknown>)[TEST_UNDICI_RUNTIME_DEPS_KEY];
|
||||
if (isUndiciRuntimeDeps(override)) {
|
||||
return override;
|
||||
}
|
||||
|
||||
const require = createRequire(import.meta.url);
|
||||
const undici = require("undici") as typeof import("undici");
|
||||
return {
|
||||
Agent: undici.Agent,
|
||||
EnvHttpProxyAgent: undici.EnvHttpProxyAgent,
|
||||
ProxyAgent: undici.ProxyAgent,
|
||||
};
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
shouldRetryTelegramTransportFallback,
|
||||
} from "../../extensions/telegram/src/fetch.js";
|
||||
import { makeProxyFetch } from "../../extensions/telegram/src/proxy.js";
|
||||
import { TEST_UNDICI_RUNTIME_DEPS_KEY } from "../infra/net/undici-runtime.js";
|
||||
import { fetchRemoteMedia } from "./fetch.js";
|
||||
|
||||
const undiciMocks = vi.hoisted(() => {
|
||||
@@ -35,6 +36,11 @@ describe("fetchRemoteMedia telegram network policy", () => {
|
||||
undiciMocks.agentCtor.mockClear();
|
||||
undiciMocks.envHttpProxyAgentCtor.mockClear();
|
||||
undiciMocks.proxyAgentCtor.mockClear();
|
||||
(globalThis as Record<string, unknown>)[TEST_UNDICI_RUNTIME_DEPS_KEY] = {
|
||||
Agent: undiciMocks.agentCtor,
|
||||
EnvHttpProxyAgent: undiciMocks.envHttpProxyAgentCtor,
|
||||
ProxyAgent: undiciMocks.proxyAgentCtor,
|
||||
};
|
||||
});
|
||||
|
||||
function createTelegramFetchFailedError(code: string): Error {
|
||||
@@ -44,6 +50,7 @@ describe("fetchRemoteMedia telegram network policy", () => {
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
Reflect.deleteProperty(globalThis as object, TEST_UNDICI_RUNTIME_DEPS_KEY);
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import fs from "node:fs/promises";
|
||||
import { createRequire } from "node:module";
|
||||
import type { AddressInfo } from "node:net";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { afterAll, beforeAll, describe, expect, it, vi } from "vitest";
|
||||
import { afterAll, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
readFileWithinRoot: vi.fn(),
|
||||
@@ -28,19 +29,32 @@ vi.mock("./store.js", async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
const { SafeOpenError } = await import("../infra/fs-safe.js");
|
||||
const { startMediaServer } = await import("./server.js");
|
||||
let SafeOpenError: typeof import("../infra/fs-safe.js").SafeOpenError;
|
||||
let startMediaServer: typeof import("./server.js").startMediaServer;
|
||||
let realFetch: typeof import("undici").fetch;
|
||||
|
||||
describe("media server outside-workspace mapping", () => {
|
||||
let server: Awaited<ReturnType<typeof startMediaServer>>;
|
||||
let port = 0;
|
||||
|
||||
beforeAll(async () => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
vi.resetModules();
|
||||
const require = createRequire(import.meta.url);
|
||||
({ SafeOpenError } = await import("../infra/fs-safe.js"));
|
||||
({ startMediaServer } = await import("./server.js"));
|
||||
({ fetch: realFetch } = require("undici") as typeof import("undici"));
|
||||
mediaDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-media-outside-workspace-"));
|
||||
server = await startMediaServer(0, 1_000);
|
||||
port = (server.address() as AddressInfo).port;
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
mocks.readFileWithinRoot.mockReset();
|
||||
mocks.cleanOldMedia.mockClear();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await new Promise((resolve) => server.close(resolve));
|
||||
await fs.rm(mediaDir, { recursive: true, force: true });
|
||||
@@ -52,7 +66,7 @@ describe("media server outside-workspace mapping", () => {
|
||||
new SafeOpenError("outside-workspace", "file is outside workspace root"),
|
||||
);
|
||||
|
||||
const response = await fetch(`http://127.0.0.1:${port}/media/ok-id`);
|
||||
const response = await realFetch(`http://127.0.0.1:${port}/media/ok-id`);
|
||||
expect(response.status).toBe(400);
|
||||
expect(await response.text()).toBe("file is outside workspace root");
|
||||
});
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import fs from "node:fs/promises";
|
||||
import { createRequire } from "node:module";
|
||||
import type { AddressInfo } from "node:net";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
@@ -16,8 +17,9 @@ vi.mock("./store.js", async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
const { startMediaServer } = await import("./server.js");
|
||||
const { MEDIA_MAX_BYTES } = await import("./store.js");
|
||||
let startMediaServer: typeof import("./server.js").startMediaServer;
|
||||
let MEDIA_MAX_BYTES: typeof import("./store.js").MEDIA_MAX_BYTES;
|
||||
let realFetch: typeof import("undici").fetch;
|
||||
|
||||
async function waitForFileRemoval(filePath: string, maxTicks = 1000) {
|
||||
for (let tick = 0; tick < maxTicks; tick += 1) {
|
||||
@@ -46,6 +48,13 @@ describe("media server", () => {
|
||||
}
|
||||
|
||||
beforeAll(async () => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
vi.resetModules();
|
||||
const require = createRequire(import.meta.url);
|
||||
({ startMediaServer } = await import("./server.js"));
|
||||
({ MEDIA_MAX_BYTES } = await import("./store.js"));
|
||||
({ fetch: realFetch } = require("undici") as typeof import("undici"));
|
||||
MEDIA_DIR = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-media-test-"));
|
||||
server = await startMediaServer(0, 1_000);
|
||||
port = (server.address() as AddressInfo).port;
|
||||
@@ -59,7 +68,7 @@ describe("media server", () => {
|
||||
|
||||
it("serves media and cleans up after send", async () => {
|
||||
const file = await writeMediaFile("file1", "hello");
|
||||
const res = await fetch(mediaUrl("file1"));
|
||||
const res = await realFetch(mediaUrl("file1"));
|
||||
expect(res.status).toBe(200);
|
||||
expect(res.headers.get("x-content-type-options")).toBe("nosniff");
|
||||
expect(await res.text()).toBe("hello");
|
||||
@@ -70,7 +79,7 @@ describe("media server", () => {
|
||||
const file = await writeMediaFile("old", "stale");
|
||||
const past = Date.now() - 10_000;
|
||||
await fs.utimes(file, past / 1000, past / 1000);
|
||||
const res = await fetch(mediaUrl("old"));
|
||||
const res = await realFetch(mediaUrl("old"));
|
||||
expect(res.status).toBe(410);
|
||||
await expect(fs.stat(file)).rejects.toThrow();
|
||||
});
|
||||
@@ -98,7 +107,7 @@ describe("media server", () => {
|
||||
},
|
||||
] as const)("$testName", async (testCase) => {
|
||||
await testCase.setup?.();
|
||||
const res = await fetch(mediaUrl(testCase.mediaPath));
|
||||
const res = await realFetch(mediaUrl(testCase.mediaPath));
|
||||
expect(res.status).toBe(400);
|
||||
expect(await res.text()).toBe("invalid path");
|
||||
});
|
||||
@@ -106,25 +115,25 @@ describe("media server", () => {
|
||||
it("rejects oversized media files", async () => {
|
||||
const file = await writeMediaFile("big", "");
|
||||
await fs.truncate(file, MEDIA_MAX_BYTES + 1);
|
||||
const res = await fetch(mediaUrl("big"));
|
||||
const res = await realFetch(mediaUrl("big"));
|
||||
expect(res.status).toBe(413);
|
||||
expect(await res.text()).toBe("too large");
|
||||
});
|
||||
|
||||
it("returns not found for missing media IDs", async () => {
|
||||
const res = await fetch(mediaUrl("missing-file"));
|
||||
const res = await realFetch(mediaUrl("missing-file"));
|
||||
expect(res.status).toBe(404);
|
||||
expect(res.headers.get("x-content-type-options")).toBe("nosniff");
|
||||
expect(await res.text()).toBe("not found");
|
||||
});
|
||||
|
||||
it("returns 404 when route param is missing (dot path)", async () => {
|
||||
const res = await fetch(mediaUrl("."));
|
||||
const res = await realFetch(mediaUrl("."));
|
||||
expect(res.status).toBe(404);
|
||||
});
|
||||
|
||||
it("rejects overlong media id", async () => {
|
||||
const res = await fetch(mediaUrl(`${"a".repeat(201)}.txt`));
|
||||
const res = await realFetch(mediaUrl(`${"a".repeat(201)}.txt`));
|
||||
expect(res.status).toBe(400);
|
||||
expect(await res.text()).toBe("invalid path");
|
||||
});
|
||||
|
||||
@@ -1,36 +1,16 @@
|
||||
import { ReadableStream } from "node:stream/web";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { VoyageBatchOutputLine, VoyageBatchRequest } from "./batch-voyage.js";
|
||||
import { setTimeout as nativeSleep } from "node:timers/promises";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
runVoyageEmbeddingBatches,
|
||||
type VoyageBatchOutputLine,
|
||||
type VoyageBatchRequest,
|
||||
} from "./batch-voyage.js";
|
||||
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
|
||||
// Mock internal.js if needed, but runWithConcurrency is simple enough to keep real.
|
||||
// We DO need to mock retryAsync to avoid actual delays/retries logic complicating tests
|
||||
vi.mock("../infra/retry.js", () => ({
|
||||
retryAsync: async <T>(fn: () => Promise<T>) => fn(),
|
||||
}));
|
||||
|
||||
vi.mock("./remote-http.js", () => ({
|
||||
withRemoteHttpResponse: vi.fn(),
|
||||
}));
|
||||
const realNow = Date.now.bind(Date);
|
||||
|
||||
describe("runVoyageEmbeddingBatches", () => {
|
||||
let runVoyageEmbeddingBatches: typeof import("./batch-voyage.js").runVoyageEmbeddingBatches;
|
||||
let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
|
||||
let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetModules();
|
||||
vi.clearAllMocks();
|
||||
({ runVoyageEmbeddingBatches } = await import("./batch-voyage.js"));
|
||||
({ withRemoteHttpResponse } = await import("./remote-http.js"));
|
||||
remoteHttpMock = vi.mocked(withRemoteHttpResponse);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
const mockClient: VoyageEmbeddingClient = {
|
||||
baseUrl: "https://api.voyageai.com/v1",
|
||||
headers: { Authorization: "Bearer test-key" },
|
||||
@@ -53,6 +33,9 @@ describe("runVoyageEmbeddingBatches", () => {
|
||||
response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
|
||||
},
|
||||
];
|
||||
const withRemoteHttpResponse = vi.fn();
|
||||
const postJsonWithRetry = vi.fn();
|
||||
const uploadBatchJsonlFile = vi.fn();
|
||||
|
||||
// Create a stream that emits the NDJSON lines
|
||||
const stream = new ReadableStream({
|
||||
@@ -62,41 +45,27 @@ describe("runVoyageEmbeddingBatches", () => {
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files");
|
||||
const uploadBody = params.init?.body;
|
||||
expect(uploadBody).toBeInstanceOf(FormData);
|
||||
expect((uploadBody as FormData).get("purpose")).toBe("batch");
|
||||
return await params.onResponse(
|
||||
new Response(JSON.stringify({ id: "file-123" }), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
}),
|
||||
);
|
||||
uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
|
||||
expect(params.errorPrefix).toBe("voyage batch file upload failed");
|
||||
expect(params.requests).toEqual(mockRequests);
|
||||
return "file-123";
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
postJsonWithRetry.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches");
|
||||
const body = params.init?.body;
|
||||
expect(typeof body).toBe("string");
|
||||
const createBody = JSON.parse(body as string) as {
|
||||
input_file_id: string;
|
||||
completion_window: string;
|
||||
request_params: { model: string; input_type: string };
|
||||
};
|
||||
expect(createBody.input_file_id).toBe("file-123");
|
||||
expect(createBody.completion_window).toBe("12h");
|
||||
expect(createBody.request_params).toEqual({
|
||||
model: "voyage-4-large",
|
||||
input_type: "document",
|
||||
expect(params.body).toMatchObject({
|
||||
input_file_id: "file-123",
|
||||
completion_window: "12h",
|
||||
request_params: {
|
||||
model: "voyage-4-large",
|
||||
input_type: "document",
|
||||
},
|
||||
});
|
||||
return await params.onResponse(
|
||||
new Response(JSON.stringify({ id: "batch-abc", status: "pending" }), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
}),
|
||||
);
|
||||
return {
|
||||
id: "batch-abc",
|
||||
status: "pending",
|
||||
};
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches/batch-abc");
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
@@ -112,7 +81,7 @@ describe("runVoyageEmbeddingBatches", () => {
|
||||
),
|
||||
);
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files/file-out-999/content");
|
||||
return await params.onResponse(
|
||||
new Response(stream as unknown as BodyInit, {
|
||||
@@ -130,15 +99,29 @@ describe("runVoyageEmbeddingBatches", () => {
|
||||
pollIntervalMs: 1, // fast poll
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
deps: {
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
postJsonWithRetry,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
},
|
||||
});
|
||||
|
||||
expect(results.size).toBe(2);
|
||||
expect(results.get("req-1")).toEqual([0.1, 0.1]);
|
||||
expect(results.get("req-2")).toEqual([0.2, 0.2]);
|
||||
expect(remoteHttpMock).toHaveBeenCalledTimes(4);
|
||||
expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
|
||||
expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
|
||||
expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("handles empty lines and stream chunks correctly", async () => {
|
||||
const withRemoteHttpResponse = vi.fn();
|
||||
const postJsonWithRetry = vi.fn();
|
||||
const uploadBatchJsonlFile = vi.fn();
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
const line1 = JSON.stringify({
|
||||
@@ -157,19 +140,13 @@ describe("runVoyageEmbeddingBatches", () => {
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files");
|
||||
return await params.onResponse(new Response(JSON.stringify({ id: "f1" }), { status: 200 }));
|
||||
uploadBatchJsonlFile.mockResolvedValueOnce("f1");
|
||||
postJsonWithRetry.mockResolvedValueOnce({
|
||||
id: "b1",
|
||||
status: "completed",
|
||||
output_file_id: "out1",
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches");
|
||||
return await params.onResponse(
|
||||
new Response(JSON.stringify({ id: "b1", status: "completed", output_file_id: "out1" }), {
|
||||
status: 200,
|
||||
}),
|
||||
);
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files/out1/content");
|
||||
return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
|
||||
});
|
||||
@@ -182,6 +159,15 @@ describe("runVoyageEmbeddingBatches", () => {
|
||||
pollIntervalMs: 1,
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
deps: {
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
postJsonWithRetry,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
},
|
||||
});
|
||||
|
||||
expect(results.get("req-1")).toEqual([1]);
|
||||
|
||||
@@ -40,6 +40,26 @@ export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
|
||||
const VOYAGE_BATCH_COMPLETION_WINDOW = "12h";
|
||||
const VOYAGE_BATCH_MAX_REQUESTS = 50000;
|
||||
|
||||
type VoyageBatchDeps = {
|
||||
now: () => number;
|
||||
sleep: (ms: number) => Promise<void>;
|
||||
postJsonWithRetry: typeof postJsonWithRetry;
|
||||
uploadBatchJsonlFile: typeof uploadBatchJsonlFile;
|
||||
withRemoteHttpResponse: typeof withRemoteHttpResponse;
|
||||
};
|
||||
|
||||
function resolveVoyageBatchDeps(overrides: Partial<VoyageBatchDeps> | undefined): VoyageBatchDeps {
|
||||
return {
|
||||
now: overrides?.now ?? Date.now,
|
||||
sleep:
|
||||
overrides?.sleep ??
|
||||
(async (ms: number) => await new Promise((resolve) => setTimeout(resolve, ms))),
|
||||
postJsonWithRetry: overrides?.postJsonWithRetry ?? postJsonWithRetry,
|
||||
uploadBatchJsonlFile: overrides?.uploadBatchJsonlFile ?? uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse: overrides?.withRemoteHttpResponse ?? withRemoteHttpResponse,
|
||||
};
|
||||
}
|
||||
|
||||
async function assertVoyageResponseOk(res: Response, context: string): Promise<void> {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
@@ -67,16 +87,17 @@ async function submitVoyageBatch(params: {
|
||||
client: VoyageEmbeddingClient;
|
||||
requests: VoyageBatchRequest[];
|
||||
agentId: string;
|
||||
deps: VoyageBatchDeps;
|
||||
}): Promise<VoyageBatchStatus> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.client);
|
||||
const inputFileId = await uploadBatchJsonlFile({
|
||||
const inputFileId = await params.deps.uploadBatchJsonlFile({
|
||||
client: params.client,
|
||||
requests: params.requests,
|
||||
errorPrefix: "voyage batch file upload failed",
|
||||
});
|
||||
|
||||
// 2. Create batch job using Voyage Batches API
|
||||
return await postJsonWithRetry<VoyageBatchStatus>({
|
||||
return await params.deps.postJsonWithRetry<VoyageBatchStatus>({
|
||||
url: `${baseUrl}/batches`,
|
||||
headers: buildBatchHeaders(params.client, { json: true }),
|
||||
ssrfPolicy: params.client.ssrfPolicy,
|
||||
@@ -100,8 +121,9 @@ async function submitVoyageBatch(params: {
|
||||
async function fetchVoyageBatchStatus(params: {
|
||||
client: VoyageEmbeddingClient;
|
||||
batchId: string;
|
||||
deps: VoyageBatchDeps;
|
||||
}): Promise<VoyageBatchStatus> {
|
||||
return await withRemoteHttpResponse(
|
||||
return await params.deps.withRemoteHttpResponse(
|
||||
buildVoyageBatchRequest({
|
||||
client: params.client,
|
||||
path: `batches/${params.batchId}`,
|
||||
@@ -116,9 +138,10 @@ async function fetchVoyageBatchStatus(params: {
|
||||
async function readVoyageBatchError(params: {
|
||||
client: VoyageEmbeddingClient;
|
||||
errorFileId: string;
|
||||
deps: VoyageBatchDeps;
|
||||
}): Promise<string | undefined> {
|
||||
try {
|
||||
return await withRemoteHttpResponse(
|
||||
return await params.deps.withRemoteHttpResponse(
|
||||
buildVoyageBatchRequest({
|
||||
client: params.client,
|
||||
path: `files/${params.errorFileId}/content`,
|
||||
@@ -150,8 +173,9 @@ async function waitForVoyageBatch(params: {
|
||||
timeoutMs: number;
|
||||
debug?: (message: string, data?: Record<string, unknown>) => void;
|
||||
initial?: VoyageBatchStatus;
|
||||
deps: VoyageBatchDeps;
|
||||
}): Promise<BatchCompletionResult> {
|
||||
const start = Date.now();
|
||||
const start = params.deps.now();
|
||||
let current: VoyageBatchStatus | undefined = params.initial;
|
||||
while (true) {
|
||||
const status =
|
||||
@@ -159,6 +183,7 @@ async function waitForVoyageBatch(params: {
|
||||
(await fetchVoyageBatchStatus({
|
||||
client: params.client,
|
||||
batchId: params.batchId,
|
||||
deps: params.deps,
|
||||
}));
|
||||
const state = status.status ?? "unknown";
|
||||
if (state === "completed") {
|
||||
@@ -175,16 +200,17 @@ async function waitForVoyageBatch(params: {
|
||||
await readVoyageBatchError({
|
||||
client: params.client,
|
||||
errorFileId,
|
||||
deps: params.deps,
|
||||
}),
|
||||
});
|
||||
if (!params.wait) {
|
||||
throw new Error(`voyage batch ${params.batchId} still ${state}; wait disabled`);
|
||||
}
|
||||
if (Date.now() - start > params.timeoutMs) {
|
||||
if (params.deps.now() - start > params.timeoutMs) {
|
||||
throw new Error(`voyage batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
|
||||
}
|
||||
params.debug?.(`voyage batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
|
||||
await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
|
||||
await params.deps.sleep(params.pollIntervalMs);
|
||||
current = undefined;
|
||||
}
|
||||
}
|
||||
@@ -194,8 +220,10 @@ export async function runVoyageEmbeddingBatches(
|
||||
client: VoyageEmbeddingClient;
|
||||
agentId: string;
|
||||
requests: VoyageBatchRequest[];
|
||||
deps?: Partial<VoyageBatchDeps>;
|
||||
} & EmbeddingBatchExecutionParams,
|
||||
): Promise<Map<string, number[]>> {
|
||||
const deps = resolveVoyageBatchDeps(params.deps);
|
||||
return await runEmbeddingBatchGroups({
|
||||
...buildEmbeddingBatchGroupOptions(params, {
|
||||
maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
|
||||
@@ -206,6 +234,7 @@ export async function runVoyageEmbeddingBatches(
|
||||
client: params.client,
|
||||
requests: group,
|
||||
agentId: params.agentId,
|
||||
deps,
|
||||
});
|
||||
if (!batchInfo.id) {
|
||||
throw new Error("voyage batch create failed: missing batch id");
|
||||
@@ -233,6 +262,7 @@ export async function runVoyageEmbeddingBatches(
|
||||
timeoutMs: params.timeoutMs,
|
||||
debug: params.debug,
|
||||
initial: batchInfo,
|
||||
deps,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -240,7 +270,7 @@ export async function runVoyageEmbeddingBatches(
|
||||
const errors: string[] = [];
|
||||
const remaining = new Set(group.map((request) => request.custom_id));
|
||||
|
||||
await withRemoteHttpResponse({
|
||||
await deps.withRemoteHttpResponse({
|
||||
url: `${baseUrl}/files/${completed.outputFileId}/content`,
|
||||
ssrfPolicy: params.client.ssrfPolicy,
|
||||
init: {
|
||||
|
||||
@@ -1,14 +1,5 @@
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as authModule from "../agents/model-auth.js";
|
||||
import {
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
createGeminiEmbeddingProvider,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_EMBEDDING_2_MODELS,
|
||||
isGeminiEmbedding2Model,
|
||||
resolveGeminiOutputDimensionality,
|
||||
} from "./embeddings-gemini.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
|
||||
vi.mock("../agents/model-auth.js", async () => {
|
||||
@@ -46,7 +37,31 @@ function magnitude(values: number[]) {
|
||||
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
|
||||
}
|
||||
|
||||
let buildGeminiEmbeddingRequest: typeof import("./embeddings-gemini.js").buildGeminiEmbeddingRequest;
|
||||
let buildGeminiTextEmbeddingRequest: typeof import("./embeddings-gemini.js").buildGeminiTextEmbeddingRequest;
|
||||
let createGeminiEmbeddingProvider: typeof import("./embeddings-gemini.js").createGeminiEmbeddingProvider;
|
||||
let DEFAULT_GEMINI_EMBEDDING_MODEL: typeof import("./embeddings-gemini.js").DEFAULT_GEMINI_EMBEDDING_MODEL;
|
||||
let GEMINI_EMBEDDING_2_MODELS: typeof import("./embeddings-gemini.js").GEMINI_EMBEDDING_2_MODELS;
|
||||
let isGeminiEmbedding2Model: typeof import("./embeddings-gemini.js").isGeminiEmbedding2Model;
|
||||
let resolveGeminiOutputDimensionality: typeof import("./embeddings-gemini.js").resolveGeminiOutputDimensionality;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
vi.resetModules();
|
||||
({
|
||||
buildGeminiEmbeddingRequest,
|
||||
buildGeminiTextEmbeddingRequest,
|
||||
createGeminiEmbeddingProvider,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_EMBEDDING_2_MODELS,
|
||||
isGeminiEmbedding2Model,
|
||||
resolveGeminiOutputDimensionality,
|
||||
} = await import("./embeddings-gemini.js"));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
@@ -1,6 +1,21 @@
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, it, expect, vi } from "vitest";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import { createOllamaEmbeddingProvider } from "./embeddings-ollama.js";
|
||||
|
||||
let createOllamaEmbeddingProvider: typeof import("./embeddings-ollama.js").createOllamaEmbeddingProvider;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
vi.resetModules();
|
||||
({ createOllamaEmbeddingProvider } = await import("./embeddings-ollama.js"));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
vi.unstubAllGlobals();
|
||||
vi.unstubAllEnvs();
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe("embeddings-ollama", () => {
|
||||
it("calls /api/embeddings and returns normalized vectors", async () => {
|
||||
|
||||
@@ -23,6 +23,9 @@ let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").creat
|
||||
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
vi.resetModules();
|
||||
authModule = await import("../agents/model-auth.js");
|
||||
({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
|
||||
await import("./embeddings-voyage.js"));
|
||||
@@ -53,6 +56,7 @@ async function createDefaultVoyageProvider(
|
||||
|
||||
describe("voyage embedding provider", () => {
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { setTimeout as sleep } from "node:timers/promises";
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
@@ -579,13 +580,13 @@ describe("local embedding ensureContext concurrency", () => {
|
||||
throw new Error("transient init failure");
|
||||
}
|
||||
if (params?.initializationDelayMs) {
|
||||
await new Promise((r) => setTimeout(r, params.initializationDelayMs));
|
||||
await sleep(params.initializationDelayMs);
|
||||
}
|
||||
return {
|
||||
loadModel: async (...modelArgs: unknown[]) => {
|
||||
loadModelSpy(...modelArgs);
|
||||
if (params?.initializationDelayMs) {
|
||||
await new Promise((r) => setTimeout(r, params.initializationDelayMs));
|
||||
await sleep(params.initializationDelayMs);
|
||||
}
|
||||
return {
|
||||
createEmbeddingContext: async () => {
|
||||
|
||||
Reference in New Issue
Block a user