refactor: move memory host into sdk package

This commit is contained in:
Peter Steinberger
2026-03-27 04:11:55 +00:00
parent 490b2f881c
commit eebce9e9c7
107 changed files with 166 additions and 178 deletions

View File

@@ -1,6 +1,6 @@
import type { EmbeddingInput } from "../../packages/memory-host-sdk/src/host/embedding-inputs.js";
import type { OpenClawConfig } from "../config/config.js";
import type { SecretInput } from "../config/types.secrets.js";
import type { EmbeddingInput } from "./memory-host/embedding-inputs.js";
export type MemoryEmbeddingBatchChunk = {
text: string;

View File

@@ -1,146 +0,0 @@
import path from "node:path";
import { describe, expect, it } from "vitest";
import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js";
import type { OpenClawConfig } from "../../config/config.js";
import { resolveMemoryBackendConfig } from "./backend-config.js";
// Test suite for resolveMemoryBackendConfig: backend selection, default and
// custom qmd collection resolution, per-agent collection name scoping, and
// update/search-mode overrides. Resolution is pure config handling, so the
// workspace paths here are fake and no filesystem access happens.
describe("resolveMemoryBackendConfig", () => {
  // No `memory` section at all -> builtin backend with default citations.
  it("defaults to builtin backend when config missing", () => {
    const cfg = { agents: { defaults: { workspace: "/tmp/memory-test" } } } as OpenClawConfig;
    const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
    expect(resolved.backend).toBe("builtin");
    expect(resolved.citations).toBe("auto");
    expect(resolved.qmd).toBeUndefined();
  });
  // An empty qmd config yields the three default memory collections plus
  // the documented default knobs (command, search mode, timeouts).
  it("resolves qmd backend with default collections", () => {
    const cfg = {
      agents: { defaults: { workspace: "/tmp/memory-test" } },
      memory: {
        backend: "qmd",
        qmd: {},
      },
    } as OpenClawConfig;
    const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
    expect(resolved.backend).toBe("qmd");
    expect(resolved.qmd?.collections.length).toBeGreaterThanOrEqual(3);
    expect(resolved.qmd?.command).toBe("qmd");
    expect(resolved.qmd?.searchMode).toBe("search");
    expect(resolved.qmd?.update.intervalMs).toBeGreaterThan(0);
    expect(resolved.qmd?.update.waitForBootSync).toBe(false);
    expect(resolved.qmd?.update.commandTimeoutMs).toBe(30_000);
    expect(resolved.qmd?.update.updateTimeoutMs).toBe(120_000);
    expect(resolved.qmd?.update.embedTimeoutMs).toBe(120_000);
    // Default collection names carry the agent id suffix.
    const names = new Set((resolved.qmd?.collections ?? []).map((collection) => collection.name));
    expect(names.has("memory-root-main")).toBe(true);
    expect(names.has("memory-alt-main")).toBe(true);
    expect(names.has("memory-dir-main")).toBe(true);
  });
  // Commands with quoted paths (spaces) must be shell-split, not
  // whitespace-split, so the executable path survives intact.
  it("parses quoted qmd command paths", () => {
    const cfg = {
      agents: { defaults: { workspace: "/tmp/memory-test" } },
      memory: {
        backend: "qmd",
        qmd: {
          command: '"/Applications/QMD Tools/qmd" --flag',
        },
      },
    } as OpenClawConfig;
    const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
    expect(resolved.qmd?.command).toBe("/Applications/QMD Tools/qmd");
  });
  // Relative `paths` entries resolve against the agent's workspace dir.
  it("resolves custom paths relative to workspace", () => {
    const cfg = {
      agents: {
        defaults: { workspace: "/workspace/root" },
        list: [{ id: "main", workspace: "/workspace/root" }],
      },
      memory: {
        backend: "qmd",
        qmd: {
          paths: [
            {
              path: "notes",
              name: "custom-notes",
              pattern: "**/*.md",
            },
          ],
        },
      },
    } as OpenClawConfig;
    const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
    const custom = resolved.qmd?.collections.find((c) => c.name.startsWith("custom-notes"));
    expect(custom).toBeDefined();
    const workspaceRoot = resolveAgentWorkspaceDir(cfg, "main");
    expect(custom?.path).toBe(path.resolve(workspaceRoot, "notes"));
  });
  // Two agents sharing a config must get distinct, agent-suffixed
  // collection names for both default and custom collections.
  it("scopes qmd collection names per agent", () => {
    const cfg = {
      agents: {
        defaults: { workspace: "/workspace/root" },
        list: [
          { id: "main", default: true, workspace: "/workspace/root" },
          { id: "dev", workspace: "/workspace/dev" },
        ],
      },
      memory: {
        backend: "qmd",
        qmd: {
          includeDefaultMemory: true,
          paths: [{ path: "notes", name: "workspace", pattern: "**/*.md" }],
        },
      },
    } as OpenClawConfig;
    const mainResolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
    const devResolved = resolveMemoryBackendConfig({ cfg, agentId: "dev" });
    const mainNames = new Set(
      (mainResolved.qmd?.collections ?? []).map((collection) => collection.name),
    );
    const devNames = new Set(
      (devResolved.qmd?.collections ?? []).map((collection) => collection.name),
    );
    expect(mainNames.has("memory-dir-main")).toBe(true);
    expect(devNames.has("memory-dir-dev")).toBe(true);
    expect(mainNames.has("workspace-main")).toBe(true);
    expect(devNames.has("workspace-dev")).toBe(true);
  });
  // Explicit update overrides replace the defaults verified above.
  it("resolves qmd update timeout overrides", () => {
    const cfg = {
      agents: { defaults: { workspace: "/tmp/memory-test" } },
      memory: {
        backend: "qmd",
        qmd: {
          update: {
            waitForBootSync: true,
            commandTimeoutMs: 12_000,
            updateTimeoutMs: 480_000,
            embedTimeoutMs: 360_000,
          },
        },
      },
    } as OpenClawConfig;
    const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
    expect(resolved.qmd?.update.waitForBootSync).toBe(true);
    expect(resolved.qmd?.update.commandTimeoutMs).toBe(12_000);
    expect(resolved.qmd?.update.updateTimeoutMs).toBe(480_000);
    expect(resolved.qmd?.update.embedTimeoutMs).toBe(360_000);
  });
  it("resolves qmd search mode override", () => {
    const cfg = {
      agents: { defaults: { workspace: "/tmp/memory-test" } },
      memory: {
        backend: "qmd",
        qmd: {
          searchMode: "vsearch",
        },
      },
    } as OpenClawConfig;
    const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
    expect(resolved.qmd?.searchMode).toBe("vsearch");
  });
});

View File

@@ -1,354 +0,0 @@
import path from "node:path";
import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js";
import { parseDurationMs } from "../../cli/parse-duration.js";
import type { OpenClawConfig } from "../../config/config.js";
import type { SessionSendPolicyConfig } from "../../config/types.base.js";
import type {
MemoryBackend,
MemoryCitationsMode,
MemoryQmdConfig,
MemoryQmdIndexPath,
MemoryQmdMcporterConfig,
MemoryQmdSearchMode,
} from "../../config/types.memory.js";
import { resolveUserPath } from "../../utils.js";
import { splitShellArgs } from "../../utils/shell-argv.js";
/** Fully-resolved memory backend selection for one agent. */
export type ResolvedMemoryBackendConfig = {
  backend: MemoryBackend;
  citations: MemoryCitationsMode;
  // Present only when backend === "qmd".
  qmd?: ResolvedQmdConfig;
};
/** One directory + glob pattern indexed as a named qmd collection. */
export type ResolvedQmdCollection = {
  name: string;
  path: string;
  pattern: string;
  // "memory": the default MEMORY.md conventions; "custom": user-configured
  // paths; "sessions": session exports (not constructed in this module).
  kind: "memory" | "custom" | "sessions";
};
/** Timing knobs for qmd index/embedding maintenance (milliseconds). */
export type ResolvedQmdUpdateConfig = {
  intervalMs: number;
  debounceMs: number;
  onBoot: boolean;
  waitForBootSync: boolean;
  embedIntervalMs: number;
  commandTimeoutMs: number;
  updateTimeoutMs: number;
  embedTimeoutMs: number;
};
/** Caps applied to qmd search results before they are injected. */
export type ResolvedQmdLimitsConfig = {
  maxResults: number;
  maxSnippetChars: number;
  maxInjectedChars: number;
  timeoutMs: number;
};
/** Session export settings for the qmd backend. */
export type ResolvedQmdSessionConfig = {
  enabled: boolean;
  exportDir?: string;
  retentionDays?: number;
};
/** mcporter integration settings (running qmd behind an MCP server). */
export type ResolvedQmdMcporterConfig = {
  enabled: boolean;
  serverName: string;
  startDaemon: boolean;
};
/** Complete resolved qmd backend configuration. */
export type ResolvedQmdConfig = {
  command: string;
  mcporter: ResolvedQmdMcporterConfig;
  searchMode: MemoryQmdSearchMode;
  collections: ResolvedQmdCollection[];
  sessions: ResolvedQmdSessionConfig;
  update: ResolvedQmdUpdateConfig;
  limits: ResolvedQmdLimitsConfig;
  includeDefaultMemory: boolean;
  scope?: SessionSendPolicyConfig;
};
const DEFAULT_BACKEND: MemoryBackend = "builtin";
const DEFAULT_CITATIONS: MemoryCitationsMode = "auto";
// Index update cadence and filesystem write debounce.
const DEFAULT_QMD_INTERVAL = "5m";
const DEFAULT_QMD_DEBOUNCE_MS = 15_000;
// Per-search timeout (feeds ResolvedQmdLimitsConfig.timeoutMs).
const DEFAULT_QMD_TIMEOUT_MS = 4_000;
// Defaulting to `query` can be extremely slow on CPU-only systems (query expansion + rerank).
// Prefer a faster mode for interactive use; users can opt into `query` for best recall.
const DEFAULT_QMD_SEARCH_MODE: MemoryQmdSearchMode = "search";
const DEFAULT_QMD_EMBED_INTERVAL = "60m";
// Subprocess timeouts for qmd invocations (milliseconds).
const DEFAULT_QMD_COMMAND_TIMEOUT_MS = 30_000;
const DEFAULT_QMD_UPDATE_TIMEOUT_MS = 120_000;
const DEFAULT_QMD_EMBED_TIMEOUT_MS = 120_000;
const DEFAULT_QMD_LIMITS: ResolvedQmdLimitsConfig = {
  maxResults: 6,
  maxSnippetChars: 700,
  maxInjectedChars: 4_000,
  timeoutMs: DEFAULT_QMD_TIMEOUT_MS,
};
// mcporter is opt-in; when enabled the daemon starts by default.
const DEFAULT_QMD_MCPORTER: ResolvedQmdMcporterConfig = {
  enabled: false,
  serverName: "qmd",
  startDaemon: true,
};
// Default send policy: only direct chats may receive memory content.
const DEFAULT_QMD_SCOPE: SessionSendPolicyConfig = {
  default: "deny",
  rules: [
    {
      action: "allow",
      match: { chatType: "direct" },
    },
  ],
};
/**
 * Normalize an arbitrary string into a collection-name slug: lowercase,
 * runs of non [a-z0-9-] collapsed to a single dash, edge dashes stripped.
 * An input that reduces to nothing falls back to "collection".
 */
function sanitizeName(input: string): string {
  const slug = input
    .toLowerCase()
    .replace(/[^a-z0-9-]+/g, "-")
    .replace(/^-+|-+$/g, "");
  return slug.length > 0 ? slug : "collection";
}
/** Suffix a collection base name with the sanitized agent id. */
function scopeCollectionBase(base: string, agentId: string): string {
  const agentSuffix = sanitizeName(agentId);
  return [base, agentSuffix].join("-");
}
/**
 * Sanitize `base` and make it unique against `existing` by appending the
 * smallest numeric suffix starting at 2 ("name", "name-2", "name-3", …).
 * Mutates `existing` by recording the chosen name.
 */
function ensureUniqueName(base: string, existing: Set<string>): string {
  const root = sanitizeName(base);
  let candidate = root;
  for (let suffix = 2; existing.has(candidate); suffix += 1) {
    candidate = `${root}-${suffix}`;
  }
  existing.add(candidate);
  return candidate;
}
/**
 * Resolve a configured path to an absolute, normalized path. `~`-prefixed
 * and absolute paths go through user-path expansion; anything else is
 * resolved relative to the workspace directory.
 *
 * @throws Error when the trimmed path is empty.
 */
function resolvePath(raw: string, workspaceDir: string): string {
  const value = raw.trim();
  if (value === "") {
    throw new Error("path required");
  }
  const absolute =
    value.startsWith("~") || path.isAbsolute(value)
      ? resolveUserPath(value)
      : path.resolve(workspaceDir, value);
  return path.normalize(absolute);
}
/**
 * Parse a duration string into milliseconds (bare numbers are minutes),
 * falling back to `fallbackSpec` when the value is missing or malformed so
 * a bad config entry never breaks backend resolution.
 *
 * Shared by resolveIntervalMs / resolveEmbedIntervalMs, which previously
 * duplicated this logic verbatim apart from the fallback constant.
 */
function resolveDurationMsWithFallback(raw: string | undefined, fallbackSpec: string): number {
  const value = raw?.trim();
  if (value) {
    try {
      return parseDurationMs(value, { defaultUnit: "m" });
    } catch {
      // Malformed user input: fall through to the default below.
    }
  }
  return parseDurationMs(fallbackSpec, { defaultUnit: "m" });
}
/** Interval between qmd index update passes (default DEFAULT_QMD_INTERVAL). */
function resolveIntervalMs(raw: string | undefined): number {
  return resolveDurationMsWithFallback(raw, DEFAULT_QMD_INTERVAL);
}
/** Interval between qmd embedding passes (default DEFAULT_QMD_EMBED_INTERVAL). */
function resolveEmbedIntervalMs(raw: string | undefined): number {
  return resolveDurationMsWithFallback(raw, DEFAULT_QMD_EMBED_INTERVAL);
}
/** Accept a non-negative finite debounce (floored), else the default. */
function resolveDebounceMs(raw: number | undefined): number {
  if (typeof raw !== "number" || !Number.isFinite(raw) || raw < 0) {
    return DEFAULT_QMD_DEBOUNCE_MS;
  }
  return Math.floor(raw);
}
/** Accept a strictly positive finite timeout (floored), else `fallback`. */
function resolveTimeoutMs(raw: number | undefined, fallback: number): number {
  if (typeof raw !== "number" || !Number.isFinite(raw) || raw <= 0) {
    return fallback;
  }
  return Math.floor(raw);
}
/**
 * Merge user limit overrides onto DEFAULT_QMD_LIMITS. Only strictly
 * positive numbers are accepted; fractional values are floored.
 */
function resolveLimits(raw?: MemoryQmdConfig["limits"]): ResolvedQmdLimitsConfig {
  const limits: ResolvedQmdLimitsConfig = { ...DEFAULT_QMD_LIMITS };
  if (!raw) {
    return limits;
  }
  if (raw.maxResults && raw.maxResults > 0) {
    limits.maxResults = Math.floor(raw.maxResults);
  }
  if (raw.maxSnippetChars && raw.maxSnippetChars > 0) {
    limits.maxSnippetChars = Math.floor(raw.maxSnippetChars);
  }
  if (raw.maxInjectedChars && raw.maxInjectedChars > 0) {
    limits.maxInjectedChars = Math.floor(raw.maxInjectedChars);
  }
  if (raw.timeoutMs && raw.timeoutMs > 0) {
    limits.timeoutMs = Math.floor(raw.timeoutMs);
  }
  return limits;
}
/** Keep a recognized search mode; anything else falls back to the default. */
function resolveSearchMode(raw?: MemoryQmdConfig["searchMode"]): MemoryQmdSearchMode {
  switch (raw) {
    case "search":
    case "vsearch":
    case "query":
      return raw;
    default:
      return DEFAULT_QMD_SEARCH_MODE;
  }
}
/**
 * Resolve session-export settings. `exportDir` is resolved against the
 * workspace; `retentionDays` is kept only when a positive number (floored).
 */
function resolveSessionConfig(
  cfg: MemoryQmdConfig["sessions"],
  workspaceDir: string,
): ResolvedQmdSessionConfig {
  const rawDir = cfg?.exportDir?.trim();
  const days = cfg?.retentionDays;
  return {
    enabled: Boolean(cfg?.enabled),
    exportDir: rawDir ? resolvePath(rawDir, workspaceDir) : undefined,
    retentionDays: days && days > 0 ? Math.floor(days) : undefined,
  };
}
/**
 * Turn user-configured index paths into resolved collections. Entries with
 * an empty or unresolvable path are silently skipped; names default to
 * "custom-<position>", are agent-scoped, and are de-duplicated via
 * `existing` (which is mutated).
 */
function resolveCustomPaths(
  rawPaths: MemoryQmdIndexPath[] | undefined,
  workspaceDir: string,
  existing: Set<string>,
  agentId: string,
): ResolvedQmdCollection[] {
  const resolvedCollections: ResolvedQmdCollection[] = [];
  if (!rawPaths?.length) {
    return resolvedCollections;
  }
  for (const [index, entry] of rawPaths.entries()) {
    const rawPath = entry?.path?.trim();
    if (!rawPath) {
      continue;
    }
    let absolute: string;
    try {
      absolute = resolvePath(rawPath, workspaceDir);
    } catch {
      // Unresolvable entry (e.g. empty after trim): skip it.
      continue;
    }
    const scopedBase = scopeCollectionBase(entry.name?.trim() || `custom-${index + 1}`, agentId);
    resolvedCollections.push({
      name: ensureUniqueName(scopedBase, existing),
      path: absolute,
      pattern: entry.pattern?.trim() || "**/*.md",
      kind: "custom",
    });
  }
  return resolvedCollections;
}
/**
 * Merge mcporter overrides onto DEFAULT_QMD_MCPORTER. Blank server names
 * are ignored; enabling mcporter without an explicit startDaemon override
 * keeps the daemon on.
 */
function resolveMcporterConfig(raw?: MemoryQmdMcporterConfig): ResolvedQmdMcporterConfig {
  const resolved: ResolvedQmdMcporterConfig = { ...DEFAULT_QMD_MCPORTER };
  if (!raw) {
    return resolved;
  }
  const { enabled, serverName, startDaemon } = raw;
  if (enabled !== undefined) {
    resolved.enabled = enabled;
  }
  if (typeof serverName === "string" && serverName.trim()) {
    resolved.serverName = serverName.trim();
  }
  if (startDaemon !== undefined) {
    resolved.startDaemon = startDaemon;
  } else if (resolved.enabled) {
    // Enabled without an explicit override: daemon defaults to on.
    resolved.startDaemon = true;
  }
  return resolved;
}
/**
 * Build the built-in memory collections (MEMORY.md, memory.md, memory/)
 * when `include` is set, scoping each name to the agent and de-duplicating
 * against `existing` (which is mutated, in declaration order).
 */
function resolveDefaultCollections(
  include: boolean,
  workspaceDir: string,
  existing: Set<string>,
  agentId: string,
): ResolvedQmdCollection[] {
  if (!include) {
    return [];
  }
  const defaults = [
    { dir: workspaceDir, pattern: "MEMORY.md", base: "memory-root" },
    { dir: workspaceDir, pattern: "memory.md", base: "memory-alt" },
    { dir: path.join(workspaceDir, "memory"), pattern: "**/*.md", base: "memory-dir" },
  ];
  const collections: ResolvedQmdCollection[] = [];
  for (const { dir, pattern, base } of defaults) {
    collections.push({
      name: ensureUniqueName(scopeCollectionBase(base, agentId), existing),
      path: dir,
      pattern,
      kind: "memory",
    });
  }
  return collections;
}
export function resolveMemoryBackendConfig(params: {
cfg: OpenClawConfig;
agentId: string;
}): ResolvedMemoryBackendConfig {
const backend = params.cfg.memory?.backend ?? DEFAULT_BACKEND;
const citations = params.cfg.memory?.citations ?? DEFAULT_CITATIONS;
if (backend !== "qmd") {
return { backend: "builtin", citations };
}
const workspaceDir = resolveAgentWorkspaceDir(params.cfg, params.agentId);
const qmdCfg = params.cfg.memory?.qmd;
const includeDefaultMemory = qmdCfg?.includeDefaultMemory !== false;
const nameSet = new Set<string>();
const collections = [
...resolveDefaultCollections(includeDefaultMemory, workspaceDir, nameSet, params.agentId),
...resolveCustomPaths(qmdCfg?.paths, workspaceDir, nameSet, params.agentId),
];
const rawCommand = qmdCfg?.command?.trim() || "qmd";
const parsedCommand = splitShellArgs(rawCommand);
const command = parsedCommand?.[0] || rawCommand.split(/\s+/)[0] || "qmd";
const resolved: ResolvedQmdConfig = {
command,
mcporter: resolveMcporterConfig(qmdCfg?.mcporter),
searchMode: resolveSearchMode(qmdCfg?.searchMode),
collections,
includeDefaultMemory,
sessions: resolveSessionConfig(qmdCfg?.sessions, workspaceDir),
update: {
intervalMs: resolveIntervalMs(qmdCfg?.update?.interval),
debounceMs: resolveDebounceMs(qmdCfg?.update?.debounceMs),
onBoot: qmdCfg?.update?.onBoot !== false,
waitForBootSync: qmdCfg?.update?.waitForBootSync === true,
embedIntervalMs: resolveEmbedIntervalMs(qmdCfg?.update?.embedInterval),
commandTimeoutMs: resolveTimeoutMs(
qmdCfg?.update?.commandTimeoutMs,
DEFAULT_QMD_COMMAND_TIMEOUT_MS,
),
updateTimeoutMs: resolveTimeoutMs(
qmdCfg?.update?.updateTimeoutMs,
DEFAULT_QMD_UPDATE_TIMEOUT_MS,
),
embedTimeoutMs: resolveTimeoutMs(
qmdCfg?.update?.embedTimeoutMs,
DEFAULT_QMD_EMBED_TIMEOUT_MS,
),
},
limits: resolveLimits(qmdCfg?.limits),
scope: qmdCfg?.scope ?? DEFAULT_QMD_SCOPE,
};
return {
backend: "qmd",
citations,
qmd: resolved,
};
}

View File

@@ -1,22 +0,0 @@
// Barrel module for the embedding batch helpers: error formatting, HTTP
// retry, output parsing, status handling, shared provider constants, the
// group runner, JSONL upload, and header/base-URL utilities.
export { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js";
export { postJsonWithRetry } from "./batch-http.js";
export { applyEmbeddingBatchOutputLine } from "./batch-output.js";
export {
  resolveBatchCompletionFromStatus,
  resolveCompletedBatchResult,
  throwIfBatchTerminalFailure,
  type BatchCompletionResult,
} from "./batch-status.js";
export {
  EMBEDDING_BATCH_ENDPOINT,
  type EmbeddingBatchStatus,
  type ProviderBatchOutputLine,
} from "./batch-provider-common.js";
export {
  buildEmbeddingBatchGroupOptions,
  runEmbeddingBatchGroups,
  type EmbeddingBatchExecutionParams,
} from "./batch-runner.js";
export { uploadBatchJsonlFile } from "./batch-upload.js";
export { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
export { withRemoteHttpResponse } from "./remote-http.js";

View File

@@ -1,32 +0,0 @@
import { describe, expect, it } from "vitest";
import { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js";
// Tests for the provider-agnostic batch error helpers.
describe("extractBatchErrorMessage", () => {
  // Lines are scanned in order: the first line carrying any error
  // (top-level or nested in the response body) determines the result.
  it("returns the first top-level error message", () => {
    expect(
      extractBatchErrorMessage([
        { response: { body: { error: { message: "nested" } } } },
        { error: { message: "top-level" } },
      ]),
    ).toBe("nested");
  });
  it("falls back to nested response error message", () => {
    expect(
      extractBatchErrorMessage([{ response: { body: { error: { message: "nested-only" } } } }, {}]),
    ).toBe("nested-only");
  });
  // Some providers return text/plain bodies rather than JSON error objects.
  it("accepts plain string response bodies", () => {
    expect(extractBatchErrorMessage([{ response: { body: "provider plain-text error" } }])).toBe(
      "provider plain-text error",
    );
  });
});
describe("formatUnavailableBatchError", () => {
  // Error instances use .message; anything else is stringified.
  it("formats errors and non-error values", () => {
    expect(formatUnavailableBatchError(new Error("boom"))).toBe("error file unavailable: boom");
    expect(formatUnavailableBatchError("unreachable")).toBe("error file unavailable: unreachable");
  });
});

View File

@@ -1,31 +0,0 @@
/**
 * Minimal shape of a provider batch-output line that may carry an error,
 * either top-level or nested in the response body. The body may also be a
 * plain string (some providers return text/plain error bodies).
 */
type BatchOutputErrorLike = {
  error?: { message?: string };
  response?: {
    body?:
      | string
      | {
          error?: { message?: string };
        };
  };
};
/**
 * Pull an error message out of a line's response body. String bodies are
 * returned as-is (empty string -> undefined); object bodies yield their
 * nested error.message when it is a string.
 */
function getResponseErrorMessage(line: BatchOutputErrorLike | undefined): string | undefined {
  const body = line?.response?.body;
  if (typeof body === "string") {
    return body.length > 0 ? body : undefined;
  }
  if (body && typeof body === "object") {
    const message = body.error?.message;
    return typeof message === "string" ? message : undefined;
  }
  return undefined;
}
/**
 * Scan output lines in order and return the first error found, preferring
 * a top-level error.message over one nested in the response body.
 */
export function extractBatchErrorMessage(lines: BatchOutputErrorLike[]): string | undefined {
  for (const line of lines) {
    const topLevel = line.error?.message;
    const nested = getResponseErrorMessage(line);
    if (topLevel || nested) {
      return topLevel ?? nested;
    }
  }
  return undefined;
}
/**
 * Render an "error file unavailable" message from any thrown value.
 * Returns undefined when the value stringifies to an empty message.
 */
export function formatUnavailableBatchError(err: unknown): string | undefined {
  const detail = err instanceof Error ? err.message : String(err);
  if (!detail) {
    return undefined;
  }
  return `error file unavailable: ${detail}`;
}

View File

@@ -1,114 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
// Stub the HTTP transport so upload/create/download calls can be scripted
// per-invocation without real network access.
vi.mock("./remote-http.js", () => ({
  withRemoteHttpResponse: vi.fn(),
}));
// Euclidean (L2) norm of a vector; used to assert embeddings come back
// normalized. Accumulation order matches the original to keep FP results
// bit-identical.
function magnitude(values: number[]) {
  let sumOfSquares = 0;
  for (const value of values) {
    sumOfSquares += value * value;
  }
  return Math.sqrt(sumOfSquares);
}
// End-to-end test of runGeminiEmbeddingBatches against a scripted HTTP
// transport. The three mockImplementationOnce calls below are strictly
// ordered: (1) multipart file upload, (2) batch create (reported COMPLETED
// immediately, so no polling), (3) output file download.
describe("runGeminiEmbeddingBatches", () => {
  let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches;
  let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
  let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
  beforeEach(async () => {
    // Re-import per test so the vi.mock factory state starts fresh.
    vi.resetModules();
    vi.clearAllMocks();
    ({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js"));
    ({ withRemoteHttpResponse } = await import("./remote-http.js"));
    remoteHttpMock = vi.mocked(withRemoteHttpResponse);
  });
  afterEach(() => {
    vi.resetAllMocks();
    vi.unstubAllGlobals();
  });
  const mockClient: GeminiEmbeddingClient = {
    baseUrl: "https://generativelanguage.googleapis.com/v1beta",
    headers: {},
    model: "gemini-embedding-2-preview",
    modelPath: "models/gemini-embedding-2-preview",
    apiKeys: ["test-key"],
    outputDimensionality: 1536,
  };
  it("includes outputDimensionality in batch upload requests", async () => {
    // 1st call: file upload. Assert the JSONL part carries the task type
    // and the requested output dimensionality.
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart");
      const body = params.init?.body;
      if (!(body instanceof Blob)) {
        throw new Error("expected multipart blob body");
      }
      const text = await body.text();
      expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"');
      expect(text).toContain('"outputDimensionality":1536');
      return await params.onResponse(
        new Response(JSON.stringify({ name: "files/file-123" }), {
          status: 200,
          headers: { "Content-Type": "application/json" },
        }),
      );
    });
    // 2nd call: batch create, already terminal with an output file.
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toMatch(/:asyncBatchEmbedContent$/u);
      return await params.onResponse(
        new Response(
          JSON.stringify({
            name: "batches/batch-1",
            state: "COMPLETED",
            outputConfig: { file: "files/output-1" },
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/json" },
          },
        ),
      );
    });
    // 3rd call: download the output file containing one embedding line.
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toMatch(/\/files\/output-1:download$/u);
      return await params.onResponse(
        new Response(
          JSON.stringify({
            key: "req-1",
            response: { embedding: { values: [3, 4] } },
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/jsonl" },
          },
        ),
      );
    });
    const results = await runGeminiEmbeddingBatches({
      gemini: mockClient,
      agentId: "main",
      requests: [
        {
          custom_id: "req-1",
          request: {
            content: { parts: [{ text: "hello world" }] },
            taskType: "RETRIEVAL_DOCUMENT",
            outputDimensionality: 1536,
          },
        },
      ],
      wait: true,
      pollIntervalMs: 1,
      timeoutMs: 1000,
      concurrency: 1,
    });
    // [3, 4] should come back L2-normalized to [0.6, 0.8].
    const embedding = results.get("req-1");
    expect(embedding).toBeDefined();
    expect(embedding?.[0]).toBeCloseTo(0.6, 5);
    expect(embedding?.[1]).toBeCloseTo(0.8, 5);
    expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5);
    expect(remoteHttpMock).toHaveBeenCalledTimes(3);
  });
});

View File

@@ -1,368 +0,0 @@
import {
buildEmbeddingBatchGroupOptions,
runEmbeddingBatchGroups,
type EmbeddingBatchExecutionParams,
} from "./batch-runner.js";
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js";
import { hashText } from "./internal.js";
import { withRemoteHttpResponse } from "./remote-http.js";
/** One embedding request keyed by a caller-assigned custom id. */
export type GeminiBatchRequest = {
  custom_id: string;
  request: GeminiTextEmbeddingRequest;
};
/**
 * Subset of the Gemini batch resource this module reads: name, lifecycle
 * state, the output file reference (its location varies across API
 * response shapes, hence the alternatives), and error details.
 */
export type GeminiBatchStatus = {
  name?: string;
  state?: string;
  outputConfig?: { file?: string; fileId?: string };
  metadata?: {
    output?: {
      responsesFile?: string;
    };
  };
  error?: { message?: string };
};
/**
 * One JSONL line of batch output. Ids and embedding values appear under
 * different keys depending on the response shape, so every variant is
 * optional and the parser tries them in order.
 */
export type GeminiBatchOutputLine = {
  key?: string;
  custom_id?: string;
  request_id?: string;
  embedding?: { values?: number[] };
  response?: {
    embedding?: { values?: number[] };
    error?: { message?: string };
  };
  error?: { message?: string };
};
// Upper bound on requests per submitted batch file (presumably the
// provider's limit — confirm against the Gemini batch API docs).
const GEMINI_BATCH_MAX_REQUESTS = 50000;
/**
 * Derive the Files API upload base from the API base URL. A trailing
 * "/v1beta" segment becomes "/upload/v1beta"; URLs mentioning "/v1beta"
 * elsewhere are returned unchanged (the anchored replace is a no-op);
 * anything else gets "/upload" appended after stripping a trailing slash.
 */
function getGeminiUploadUrl(baseUrl: string): string {
  if (!baseUrl.includes("/v1beta")) {
    const trimmed = baseUrl.replace(/\/$/, "");
    return `${trimmed}/upload`;
  }
  return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");
}
/**
 * Assemble a multipart/related upload body: a JSON metadata part (display
 * name + JSONL mime type) followed by the JSONL payload part. The boundary
 * is derived from the display name hash so it is deterministic per upload.
 */
function buildGeminiUploadBody(params: { jsonl: string; displayName: string }): {
  body: Blob;
  contentType: string;
} {
  const boundary = `openclaw-${hashText(params.displayName)}`;
  const metadata = JSON.stringify({
    file: {
      displayName: params.displayName,
      mimeType: "application/jsonl",
    },
  });
  const open = `--${boundary}\r\n`;
  const payload =
    `${open}Content-Type: application/json; charset=UTF-8\r\n\r\n${metadata}\r\n` +
    `${open}Content-Type: application/jsonl; charset=UTF-8\r\n\r\n${params.jsonl}\r\n` +
    `--${boundary}--\r\n`;
  return {
    body: new Blob([payload], { type: "multipart/related" }),
    contentType: `multipart/related; boundary=${boundary}`,
  };
}
/**
 * Submit one embedding batch: upload the requests as a JSONL file via the
 * Files upload endpoint, then create an asyncBatchEmbedContent batch that
 * references the uploaded file. Returns the batch resource from the create
 * call — it may already be in a terminal state.
 *
 * @throws Error when the upload or create request fails, or when no file
 *   id can be found in the upload response.
 */
async function submitGeminiBatch(params: {
  gemini: GeminiEmbeddingClient;
  requests: GeminiBatchRequest[];
  agentId: string;
}): Promise<GeminiBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.gemini);
  // One JSONL line per request; `key` is echoed back in the output file
  // and used to re-associate embeddings with custom ids.
  const jsonl = params.requests
    .map((request) =>
      JSON.stringify({
        key: request.custom_id,
        request: request.request,
      }),
    )
    .join("\n");
  const displayName = `memory-embeddings-${hashText(String(Date.now()))}`;
  const uploadPayload = buildGeminiUploadBody({ jsonl, displayName });
  const uploadUrl = `${getGeminiUploadUrl(baseUrl)}/files?uploadType=multipart`;
  debugEmbeddingsLog("memory embeddings: gemini batch upload", {
    uploadUrl,
    baseUrl,
    requests: params.requests.length,
  });
  const filePayload = await withRemoteHttpResponse({
    url: uploadUrl,
    ssrfPolicy: params.gemini.ssrfPolicy,
    init: {
      method: "POST",
      headers: {
        // json: false — the multipart content type is set explicitly below.
        ...buildBatchHeaders(params.gemini, { json: false }),
        "Content-Type": uploadPayload.contentType,
      },
      body: uploadPayload.body,
    },
    onResponse: async (fileRes) => {
      if (!fileRes.ok) {
        const text = await fileRes.text();
        throw new Error(`gemini batch file upload failed: ${fileRes.status} ${text}`);
      }
      return (await fileRes.json()) as { name?: string; file?: { name?: string } };
    },
  });
  // The file id may be top-level or nested; accept either response shape.
  const fileId = filePayload.name ?? filePayload.file?.name;
  if (!fileId) {
    throw new Error("gemini batch file upload failed: missing file id");
  }
  const batchBody = {
    batch: {
      displayName: `memory-embeddings-${params.agentId}`,
      inputConfig: {
        file_name: fileId,
      },
    },
  };
  const batchEndpoint = `${baseUrl}/${params.gemini.modelPath}:asyncBatchEmbedContent`;
  debugEmbeddingsLog("memory embeddings: gemini batch create", {
    batchEndpoint,
    fileId,
  });
  return await withRemoteHttpResponse({
    url: batchEndpoint,
    ssrfPolicy: params.gemini.ssrfPolicy,
    init: {
      method: "POST",
      headers: buildBatchHeaders(params.gemini, { json: true }),
      body: JSON.stringify(batchBody),
    },
    onResponse: async (batchRes) => {
      if (batchRes.ok) {
        return (await batchRes.json()) as GeminiBatchStatus;
      }
      const text = await batchRes.text();
      // 404 typically means this model/baseUrl has no batch surface; give
      // an actionable message instead of the raw body.
      if (batchRes.status === 404) {
        throw new Error(
          "gemini batch create failed: 404 (asyncBatchEmbedContent not available for this model/baseUrl). Disable remote.batch.enabled or switch providers.",
        );
      }
      throw new Error(`gemini batch create failed: ${batchRes.status} ${text}`);
    },
  });
}
/**
 * Fetch the current batch resource. Bare names are qualified with the
 * "batches/" prefix before building the status URL.
 *
 * @throws Error on any non-OK HTTP response.
 */
async function fetchGeminiBatchStatus(params: {
  gemini: GeminiEmbeddingClient;
  batchName: string;
}): Promise<GeminiBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.gemini);
  const qualifiedName = params.batchName.startsWith("batches/")
    ? params.batchName
    : `batches/${params.batchName}`;
  const statusUrl = `${baseUrl}/${qualifiedName}`;
  debugEmbeddingsLog("memory embeddings: gemini batch status", { statusUrl });
  return await withRemoteHttpResponse({
    url: statusUrl,
    ssrfPolicy: params.gemini.ssrfPolicy,
    init: {
      headers: buildBatchHeaders(params.gemini, { json: true }),
    },
    onResponse: async (res) => {
      if (res.ok) {
        return (await res.json()) as GeminiBatchStatus;
      }
      const detail = await res.text();
      throw new Error(`gemini batch status failed: ${res.status} ${detail}`);
    },
  });
}
/**
 * Download a batch output file's raw text (JSONL). Bare ids are qualified
 * with the "files/" prefix before building the :download URL.
 *
 * @throws Error on any non-OK HTTP response.
 */
async function fetchGeminiFileContent(params: {
  gemini: GeminiEmbeddingClient;
  fileId: string;
}): Promise<string> {
  const baseUrl = normalizeBatchBaseUrl(params.gemini);
  const qualifiedFile = params.fileId.startsWith("files/")
    ? params.fileId
    : `files/${params.fileId}`;
  const downloadUrl = `${baseUrl}/${qualifiedFile}:download`;
  debugEmbeddingsLog("memory embeddings: gemini batch download", { downloadUrl });
  return await withRemoteHttpResponse({
    url: downloadUrl,
    ssrfPolicy: params.gemini.ssrfPolicy,
    init: {
      headers: buildBatchHeaders(params.gemini, { json: true }),
    },
    onResponse: async (res) => {
      if (res.ok) {
        return await res.text();
      }
      const detail = await res.text();
      throw new Error(`gemini batch file content failed: ${res.status} ${detail}`);
    },
  });
}
/**
 * Parse JSONL batch output into typed lines, skipping blank lines.
 * Whitespace-only input yields an empty array.
 */
function parseGeminiBatchOutput(text: string): GeminiBatchOutputLine[] {
  const lines: GeminiBatchOutputLine[] = [];
  for (const raw of text.split("\n")) {
    const trimmed = raw.trim();
    if (trimmed) {
      lines.push(JSON.parse(trimmed) as GeminiBatchOutputLine);
    }
  }
  return lines;
}
/**
 * Poll a batch until it reaches a terminal state and return its output
 * file id. `initial` lets the caller reuse an already-fetched status for
 * the first iteration instead of re-fetching.
 *
 * @throws Error when the batch fails/cancels/expires, completes without an
 *   output file, exceeds `timeoutMs`, or is still running with `wait` off.
 */
async function waitForGeminiBatch(params: {
  gemini: GeminiEmbeddingClient;
  batchName: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: GeminiBatchStatus;
}): Promise<{ outputFileId: string }> {
  const start = Date.now();
  let current: GeminiBatchStatus | undefined = params.initial;
  while (true) {
    const status =
      current ??
      (await fetchGeminiBatchStatus({
        gemini: params.gemini,
        batchName: params.batchName,
      }));
    const state = status.state ?? "UNKNOWN";
    if (["SUCCEEDED", "COMPLETED", "DONE"].includes(state)) {
      // The output file reference moves between response shapes; try each.
      const outputFileId =
        status.outputConfig?.file ??
        status.outputConfig?.fileId ??
        status.metadata?.output?.responsesFile;
      if (!outputFileId) {
        throw new Error(`gemini batch ${params.batchName} completed without output file`);
      }
      return { outputFileId };
    }
    if (["FAILED", "CANCELLED", "CANCELED", "EXPIRED"].includes(state)) {
      const message = status.error?.message ?? "unknown error";
      throw new Error(`gemini batch ${params.batchName} ${state}: ${message}`);
    }
    if (!params.wait) {
      throw new Error(`gemini batch ${params.batchName} still ${state}; wait disabled`);
    }
    if (Date.now() - start > params.timeoutMs) {
      throw new Error(`gemini batch ${params.batchName} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`gemini batch ${params.batchName} ${state}; waiting ${params.pollIntervalMs}ms`);
    await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
    // Force a fresh status fetch on the next iteration.
    current = undefined;
  }
}
/**
 * Run embedding requests through the Gemini batch API. Requests are split
 * into groups of at most GEMINI_BATCH_MAX_REQUESTS by the shared group
 * runner; each group is submitted, awaited (unless already terminal), its
 * output file downloaded and parsed, and normalized embeddings collected
 * by custom id.
 *
 * @returns Map of custom_id -> normalized embedding vector.
 * @throws Error when creation fails, waiting is disabled while a batch is
 *   still running, any request errored, or responses are missing.
 */
export async function runGeminiEmbeddingBatches(
  params: {
    gemini: GeminiEmbeddingClient;
    agentId: string;
    requests: GeminiBatchRequest[];
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: GEMINI_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: gemini batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      const batchInfo = await submitGeminiBatch({
        gemini: params.gemini,
        requests: group,
        agentId: params.agentId,
      });
      const batchName = batchInfo.name ?? "";
      if (!batchName) {
        throw new Error("gemini batch create failed: missing batch name");
      }
      params.debug?.("memory embeddings: gemini batch created", {
        batchName,
        state: batchInfo.state,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });
      // Fail fast when waiting is disabled and the batch is not yet done.
      if (
        !params.wait &&
        batchInfo.state &&
        !["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state)
      ) {
        throw new Error(
          `gemini batch ${batchName} submitted; enable remote.batch.wait to await completion`,
        );
      }
      // If the create response is already terminal, skip polling entirely.
      const completed =
        batchInfo.state && ["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state)
          ? {
              outputFileId:
                batchInfo.outputConfig?.file ??
                batchInfo.outputConfig?.fileId ??
                batchInfo.metadata?.output?.responsesFile ??
                "",
            }
          : await waitForGeminiBatch({
              gemini: params.gemini,
              batchName,
              wait: params.wait,
              pollIntervalMs: params.pollIntervalMs,
              timeoutMs: params.timeoutMs,
              debug: params.debug,
              initial: batchInfo,
            });
      if (!completed.outputFileId) {
        throw new Error(`gemini batch ${batchName} completed without output file`);
      }
      const content = await fetchGeminiFileContent({
        gemini: params.gemini,
        fileId: completed.outputFileId,
      });
      const outputLines = parseGeminiBatchOutput(content);
      const errors: string[] = [];
      // Track which custom ids have not yet been seen in the output.
      const remaining = new Set(group.map((request) => request.custom_id));
      for (const line of outputLines) {
        const customId = line.key ?? line.custom_id ?? line.request_id;
        if (!customId) {
          continue;
        }
        remaining.delete(customId);
        if (line.error?.message) {
          errors.push(`${customId}: ${line.error.message}`);
          continue;
        }
        if (line.response?.error?.message) {
          errors.push(`${customId}: ${line.response.error.message}`);
          continue;
        }
        const embedding = sanitizeAndNormalizeEmbedding(
          line.embedding?.values ?? line.response?.embedding?.values ?? [],
        );
        if (embedding.length === 0) {
          errors.push(`${customId}: empty embedding`);
          continue;
        }
        byCustomId.set(customId, embedding);
      }
      if (errors.length > 0) {
        throw new Error(`gemini batch ${batchName} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(`gemini batch ${batchName} missing ${remaining.size} embedding responses`);
      }
    },
  });
}

View File

@@ -1,85 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
// Run the operation once, directly, while still letting tests inspect the
// retry options passed as retryAsync's second argument.
vi.mock("../../infra/retry.js", () => ({
  retryAsync: vi.fn(async (run: () => Promise<unknown>) => await run()),
}));
// Stub the underlying JSON POST so no network is involved.
vi.mock("./post-json.js", () => ({
  postJson: vi.fn(),
}));
describe("postJsonWithRetry", () => {
let retryAsyncMock: ReturnType<
typeof vi.mocked<typeof import("../../infra/retry.js").retryAsync>
>;
let postJsonMock: ReturnType<typeof vi.mocked<typeof import("./post-json.js").postJson>>;
let postJsonWithRetry: typeof import("./batch-http.js").postJsonWithRetry;
beforeEach(async () => {
vi.resetModules();
vi.clearAllMocks();
vi.resetModules();
({ postJsonWithRetry } = await import("./batch-http.js"));
const retryModule = await import("../../infra/retry.js");
const postJsonModule = await import("./post-json.js");
retryAsyncMock = vi.mocked(retryModule.retryAsync);
postJsonMock = vi.mocked(postJsonModule.postJson);
});
it("posts JSON and returns parsed response payload", async () => {
postJsonMock.mockImplementationOnce(async (params) => {
return await params.parse({ ok: true, ids: [1, 2] });
});
const result = await postJsonWithRetry<{ ok: boolean; ids: number[] }>({
url: "https://memory.example/v1/batch",
headers: { Authorization: "Bearer test" },
body: { chunks: ["a", "b"] },
errorPrefix: "memory batch failed",
});
expect(result).toEqual({ ok: true, ids: [1, 2] });
expect(postJsonMock).toHaveBeenCalledWith(
expect.objectContaining({
url: "https://memory.example/v1/batch",
headers: { Authorization: "Bearer test" },
body: { chunks: ["a", "b"] },
errorPrefix: "memory batch failed",
attachStatus: true,
}),
);
const retryOptions = retryAsyncMock.mock.calls[0]?.[1] as
| {
attempts: number;
minDelayMs: number;
maxDelayMs: number;
shouldRetry: (err: unknown) => boolean;
}
| undefined;
expect(retryOptions?.attempts).toBe(3);
expect(retryOptions?.minDelayMs).toBe(300);
expect(retryOptions?.maxDelayMs).toBe(2000);
expect(retryOptions?.shouldRetry({ status: 429 })).toBe(true);
expect(retryOptions?.shouldRetry({ status: 503 })).toBe(true);
expect(retryOptions?.shouldRetry({ status: 400 })).toBe(false);
});
it("attaches status to non-ok errors", async () => {
postJsonMock.mockRejectedValueOnce(
Object.assign(new Error("memory batch failed: 503 backend down"), { status: 503 }),
);
await expect(
postJsonWithRetry({
url: "https://memory.example/v1/batch",
headers: {},
body: { chunks: [] },
errorPrefix: "memory batch failed",
}),
).rejects.toMatchObject({
message: expect.stringContaining("memory batch failed: 503 backend down"),
status: 503,
});
});
});

View File

@@ -1,35 +0,0 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { retryAsync } from "../../infra/retry.js";
import { postJson } from "./post-json.js";
/**
 * POST a JSON body and return the parsed response payload, retrying
 * transient failures.
 *
 * Up to 3 attempts with backoff between 300ms and 2s (20% jitter). An
 * attempt is retried only when its error carries an HTTP status of 429 or
 * any 5xx; postJson attaches the status (attachStatus: true) so the
 * predicate can inspect it. All other errors propagate immediately.
 */
export async function postJsonWithRetry<T>(params: {
  url: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  body: unknown;
  errorPrefix: string;
}): Promise<T> {
  const attempt = async (): Promise<T> =>
    await postJson<T>({
      url: params.url,
      headers: params.headers,
      ssrfPolicy: params.ssrfPolicy,
      body: params.body,
      errorPrefix: params.errorPrefix,
      attachStatus: true,
      parse: async (payload) => payload as T,
    });
  const policy = {
    attempts: 3,
    minDelayMs: 300,
    maxDelayMs: 2000,
    jitter: 0.2,
    shouldRetry: (err: unknown) => {
      const status = (err as { status?: number }).status;
      return status === 429 || (typeof status === "number" && status >= 500);
    },
  };
  return await retryAsync(attempt, policy);
}

View File

@@ -1,259 +0,0 @@
import {
applyEmbeddingBatchOutputLine,
buildBatchHeaders,
buildEmbeddingBatchGroupOptions,
EMBEDDING_BATCH_ENDPOINT,
extractBatchErrorMessage,
formatUnavailableBatchError,
normalizeBatchBaseUrl,
postJsonWithRetry,
resolveBatchCompletionFromStatus,
resolveCompletedBatchResult,
runEmbeddingBatchGroups,
throwIfBatchTerminalFailure,
type EmbeddingBatchExecutionParams,
type EmbeddingBatchStatus,
type BatchCompletionResult,
type ProviderBatchOutputLine,
uploadBatchJsonlFile,
withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
/** One JSONL line in an OpenAI embeddings batch input file. */
export type OpenAiBatchRequest = {
  custom_id: string;
  method: "POST";
  url: "/v1/embeddings";
  body: {
    model: string;
    input: string;
  };
};

// Status and output-line shapes are shared across embedding batch providers.
export type OpenAiBatchStatus = EmbeddingBatchStatus;
export type OpenAiBatchOutputLine = ProviderBatchOutputLine;

export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
// Completion window value sent when creating a batch job.
const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
// Requests per batch group; larger inputs are split into multiple groups.
const OPENAI_BATCH_MAX_REQUESTS = 50000;
/**
 * Upload the request JSONL file and create an OpenAI embeddings batch job.
 * Returns the provider's initial batch status payload.
 */
async function submitOpenAiBatch(params: {
  openAi: OpenAiEmbeddingClient;
  requests: OpenAiBatchRequest[];
  agentId: string;
}): Promise<OpenAiBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.openAi);
  // 1. Upload the JSONL input file; its id seeds the batch job.
  const inputFileId = await uploadBatchJsonlFile({
    client: params.openAi,
    requests: params.requests,
    errorPrefix: "openai batch file upload failed",
  });
  // 2. Create the batch job referencing the uploaded input file.
  return await postJsonWithRetry<OpenAiBatchStatus>({
    url: `${baseUrl}/batches`,
    headers: buildBatchHeaders(params.openAi, { json: true }),
    ssrfPolicy: params.openAi.ssrfPolicy,
    body: {
      input_file_id: inputFileId,
      endpoint: OPENAI_BATCH_ENDPOINT,
      completion_window: OPENAI_BATCH_COMPLETION_WINDOW,
      metadata: {
        source: "openclaw-memory",
        agent: params.agentId,
      },
    },
    errorPrefix: "openai batch create failed",
  });
}
/** Fetch the current status payload of a single OpenAI batch. */
async function fetchOpenAiBatchStatus(params: {
  openAi: OpenAiEmbeddingClient;
  batchId: string;
}): Promise<OpenAiBatchStatus> {
  const parseStatus = async (res: Response): Promise<OpenAiBatchStatus> =>
    (await res.json()) as OpenAiBatchStatus;
  return await fetchOpenAiBatchResource({
    openAi: params.openAi,
    path: `/batches/${params.batchId}`,
    errorPrefix: "openai batch status",
    parse: parseStatus,
  });
}
/** Download a batch file's raw text content (JSONL output or error file). */
async function fetchOpenAiFileContent(params: {
  openAi: OpenAiEmbeddingClient;
  fileId: string;
}): Promise<string> {
  const path = `/files/${params.fileId}/content`;
  return await fetchOpenAiBatchResource({
    openAi: params.openAi,
    path,
    errorPrefix: "openai batch file content",
    parse: (res) => res.text(),
  });
}
/**
 * GET an OpenAI batch-related resource under the client's base URL and hand
 * the raw Response to `parse`. Non-OK responses raise an error with the
 * response body appended for diagnosis.
 */
async function fetchOpenAiBatchResource<T>(params: {
  openAi: OpenAiEmbeddingClient;
  path: string;
  errorPrefix: string;
  parse: (res: Response) => Promise<T>;
}): Promise<T> {
  const baseUrl = normalizeBatchBaseUrl(params.openAi);
  return await withRemoteHttpResponse({
    url: `${baseUrl}${params.path}`,
    ssrfPolicy: params.openAi.ssrfPolicy,
    init: {
      headers: buildBatchHeaders(params.openAi, { json: true }),
    },
    onResponse: async (res) => {
      if (!res.ok) {
        const text = await res.text();
        throw new Error(`${params.errorPrefix} failed: ${res.status} ${text}`);
      }
      return await params.parse(res);
    },
  });
}
/**
 * Parse NDJSON batch output into one object per non-empty line.
 * Blank or whitespace-only input yields an empty array.
 */
function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
  const parsed: OpenAiBatchOutputLine[] = [];
  for (const raw of text.split("\n")) {
    const trimmed = raw.trim();
    if (trimmed) {
      parsed.push(JSON.parse(trimmed) as OpenAiBatchOutputLine);
    }
  }
  return parsed;
}
/**
 * Best-effort read of a batch error file: downloads and parses the NDJSON
 * error output and returns whatever message extractBatchErrorMessage derives
 * from the parsed lines. Fetch/parse failures are converted into a formatted
 * "unavailable" note rather than masking the original batch failure.
 */
async function readOpenAiBatchError(params: {
  openAi: OpenAiEmbeddingClient;
  errorFileId: string;
}): Promise<string | undefined> {
  try {
    const content = await fetchOpenAiFileContent({
      openAi: params.openAi,
      fileId: params.errorFileId,
    });
    const lines = parseOpenAiBatchOutput(content);
    return extractBatchErrorMessage(lines);
  } catch (err) {
    return formatUnavailableBatchError(err);
  }
}
/**
 * Poll an OpenAI batch until it completes, fails terminally, or times out.
 * `initial` lets the caller pass an already-fetched status so the first
 * iteration skips a redundant status request.
 *
 * Throws when the batch enters a terminal failure state, when `wait` is
 * false and the batch is not yet complete, or when `timeoutMs` elapses.
 */
async function waitForOpenAiBatch(params: {
  openAi: OpenAiEmbeddingClient;
  batchId: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: OpenAiBatchStatus;
}): Promise<BatchCompletionResult> {
  const start = Date.now();
  let current: OpenAiBatchStatus | undefined = params.initial;
  while (true) {
    // Use the cached status on the first pass; re-fetch on later iterations.
    const status =
      current ??
      (await fetchOpenAiBatchStatus({
        openAi: params.openAi,
        batchId: params.batchId,
      }));
    const state = status.status ?? "unknown";
    if (state === "completed") {
      return resolveBatchCompletionFromStatus({
        provider: "openai",
        batchId: params.batchId,
        status,
      });
    }
    // Terminal failure (failed/expired/cancelled) raises, enriched with
    // error-file detail when available.
    await throwIfBatchTerminalFailure({
      provider: "openai",
      status: { ...status, id: params.batchId },
      readError: async (errorFileId) =>
        await readOpenAiBatchError({
          openAi: params.openAi,
          errorFileId,
        }),
    });
    if (!params.wait) {
      throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`);
    }
    if (Date.now() - start > params.timeoutMs) {
      throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
    await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
    // Force a fresh fetch on the next loop iteration.
    current = undefined;
  }
}
/**
 * Submit OpenAI embedding requests via the Batches API and collect the
 * resulting vectors keyed by custom_id. Requests are split into groups of at
 * most OPENAI_BATCH_MAX_REQUESTS and executed with bounded concurrency.
 * A group fails the whole run when any response is errored, empty, or
 * missing from the batch output.
 */
export async function runOpenAiEmbeddingBatches(
  params: {
    openAi: OpenAiEmbeddingClient;
    agentId: string;
    requests: OpenAiBatchRequest[];
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: OPENAI_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: openai batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      const batchInfo = await submitOpenAiBatch({
        openAi: params.openAi,
        requests: group,
        agentId: params.agentId,
      });
      if (!batchInfo.id) {
        throw new Error("openai batch create failed: missing batch id");
      }
      const batchId = batchInfo.id;
      params.debug?.("memory embeddings: openai batch created", {
        batchId: batchInfo.id,
        status: batchInfo.status,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });
      // Resolve immediately when the provider already reports completion;
      // otherwise poll (or throw when waiting is disabled).
      const completed = await resolveCompletedBatchResult({
        provider: "openai",
        status: batchInfo,
        wait: params.wait,
        waitForBatch: async () =>
          await waitForOpenAiBatch({
            openAi: params.openAi,
            batchId,
            wait: params.wait,
            pollIntervalMs: params.pollIntervalMs,
            timeoutMs: params.timeoutMs,
            debug: params.debug,
            initial: batchInfo,
          }),
      });
      const content = await fetchOpenAiFileContent({
        openAi: params.openAi,
        fileId: completed.outputFileId,
      });
      const outputLines = parseOpenAiBatchOutput(content);
      const errors: string[] = [];
      // Track ids not yet seen so missing responses can be detected.
      const remaining = new Set(group.map((request) => request.custom_id));
      for (const line of outputLines) {
        applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
      }
      if (errors.length > 0) {
        throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(
          `openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
        );
      }
    },
  });
}

View File

@@ -1,82 +0,0 @@
import { describe, expect, it } from "vitest";
import { applyEmbeddingBatchOutputLine } from "./batch-output.js";
describe("applyEmbeddingBatchOutputLine", () => {
  it("stores embedding for successful response", () => {
    const remaining = new Set(["req-1"]);
    const errors: string[] = [];
    const byCustomId = new Map<string, number[]>();
    applyEmbeddingBatchOutputLine({
      line: {
        custom_id: "req-1",
        response: {
          status_code: 200,
          body: { data: [{ embedding: [0.1, 0.2] }] },
        },
      },
      remaining,
      errors,
      byCustomId,
    });
    // Success path: id consumed from `remaining`, vector recorded, no errors.
    expect(remaining.has("req-1")).toBe(false);
    expect(errors).toEqual([]);
    expect(byCustomId.get("req-1")).toEqual([0.1, 0.2]);
  });
  it("records provider error from line.error", () => {
    const remaining = new Set(["req-2"]);
    const errors: string[] = [];
    const byCustomId = new Map<string, number[]>();
    applyEmbeddingBatchOutputLine({
      line: {
        custom_id: "req-2",
        error: { message: "provider failed" },
      },
      remaining,
      errors,
      byCustomId,
    });
    // Errored lines still consume their id but record no embedding.
    expect(remaining.has("req-2")).toBe(false);
    expect(errors).toEqual(["req-2: provider failed"]);
    expect(byCustomId.size).toBe(0);
  });
  it("records non-2xx response errors and empty embedding errors", () => {
    const remaining = new Set(["req-3", "req-4"]);
    const errors: string[] = [];
    const byCustomId = new Map<string, number[]>();
    // HTTP-level failure: message taken from response.body.error.message.
    applyEmbeddingBatchOutputLine({
      line: {
        custom_id: "req-3",
        response: {
          status_code: 500,
          body: { error: { message: "internal" } },
        },
      },
      remaining,
      errors,
      byCustomId,
    });
    // 200 with an empty data array is treated as an empty-embedding error.
    applyEmbeddingBatchOutputLine({
      line: {
        custom_id: "req-4",
        response: {
          status_code: 200,
          body: { data: [] },
        },
      },
      remaining,
      errors,
      byCustomId,
    });
    expect(errors).toEqual(["req-3: internal", "req-4: empty embedding"]);
    expect(byCustomId.size).toBe(0);
  });
});

View File

@@ -1,55 +0,0 @@
/**
 * One line of provider batch output (OpenAI/Voyage-compatible shape).
 * `response.body` is either the parsed embeddings payload or a raw error
 * string from the provider.
 */
export type EmbeddingBatchOutputLine = {
  custom_id?: string;
  error?: { message?: string };
  response?: {
    status_code?: number;
    body?:
      | {
          data?: Array<{
            embedding?: number[];
          }>;
          error?: { message?: string };
        }
      | string;
  };
};
/**
 * Fold a single batch output line into the shared accumulators: removes the
 * line's custom_id from `remaining`, records an error string on failure
 * (provider error, HTTP status >= 400, or empty embedding), or stores the
 * first embedding vector in `byCustomId` on success. Lines without a
 * custom_id are ignored entirely.
 */
export function applyEmbeddingBatchOutputLine(params: {
  line: EmbeddingBatchOutputLine;
  remaining: Set<string>;
  errors: string[];
  byCustomId: Map<string, number[]>;
}) {
  const { line, remaining, errors, byCustomId } = params;
  const id = line.custom_id;
  if (!id) {
    return;
  }
  remaining.delete(id);
  if (line.error?.message) {
    errors.push(`${id}: ${line.error.message}`);
    return;
  }
  const response = line.response;
  if ((response?.status_code ?? 0) >= 400) {
    // Prefer a structured error message; fall back to a raw string body.
    const body = response?.body;
    let detail: string | undefined;
    if (typeof body === "string") {
      detail = body;
    } else if (body) {
      detail = body.error?.message;
    }
    errors.push(`${id}: ${detail ?? "unknown error"}`);
    return;
  }
  const body = response?.body;
  const vector = body && typeof body === "object" ? (body.data?.[0]?.embedding ?? []) : [];
  if (vector.length === 0) {
    errors.push(`${id}: empty embedding`);
    return;
  }
  byCustomId.set(id, vector);
}

View File

@@ -1,12 +0,0 @@
import type { EmbeddingBatchOutputLine } from "./batch-output.js";
/** Provider-agnostic batch job status (OpenAI-compatible field names). */
export type EmbeddingBatchStatus = {
  id?: string;
  status?: string;
  output_file_id?: string | null;
  error_file_id?: string | null;
};

export type ProviderBatchOutputLine = EmbeddingBatchOutputLine;

/** Endpoint value sent when creating embedding batch jobs. */
export const EMBEDDING_BATCH_ENDPOINT = "/v1/embeddings";

View File

@@ -1,64 +0,0 @@
import { splitBatchRequests } from "./batch-utils.js";
import { runWithConcurrency } from "./internal.js";
/** Shared knobs controlling how embedding batch groups are executed. */
export type EmbeddingBatchExecutionParams = {
  wait: boolean;            // poll for completion instead of failing fast
  pollIntervalMs: number;   // delay between status polls
  timeoutMs: number;        // overall per-batch wait budget
  concurrency: number;      // max batch groups in flight at once
  debug?: (message: string, data?: Record<string, unknown>) => void;
};
/**
 * Split `requests` into provider-sized groups and run `runGroup` over them
 * with bounded concurrency, collecting embeddings into one shared map keyed
 * by custom_id. Returns an empty map when there is nothing to do.
 */
export async function runEmbeddingBatchGroups<TRequest>(params: {
  requests: TRequest[];
  maxRequests: number;
  wait: EmbeddingBatchExecutionParams["wait"];
  pollIntervalMs: EmbeddingBatchExecutionParams["pollIntervalMs"];
  timeoutMs: EmbeddingBatchExecutionParams["timeoutMs"];
  concurrency: EmbeddingBatchExecutionParams["concurrency"];
  debugLabel: string;
  debug?: EmbeddingBatchExecutionParams["debug"];
  runGroup: (args: {
    group: TRequest[];
    groupIndex: number;
    groups: number;
    byCustomId: Map<string, number[]>;
  }) => Promise<void>;
}): Promise<Map<string, number[]>> {
  const { requests } = params;
  if (requests.length === 0) {
    return new Map();
  }
  const byCustomId = new Map<string, number[]>();
  const groups = splitBatchRequests(requests, params.maxRequests);
  params.debug?.(params.debugLabel, {
    requests: requests.length,
    groups: groups.length,
    wait: params.wait,
    concurrency: params.concurrency,
    pollIntervalMs: params.pollIntervalMs,
    timeoutMs: params.timeoutMs,
  });
  const tasks = groups.map(
    (group, groupIndex) => () =>
      params.runGroup({ group, groupIndex, groups: groups.length, byCustomId }),
  );
  await runWithConcurrency(tasks, params.concurrency);
  return byCustomId;
}
/**
 * Merge caller execution params with provider-specific batching options into
 * the option bag accepted by runEmbeddingBatchGroups.
 */
export function buildEmbeddingBatchGroupOptions<TRequest>(
  params: { requests: TRequest[] } & EmbeddingBatchExecutionParams,
  options: { maxRequests: number; debugLabel: string },
) {
  const { requests, wait, pollIntervalMs, timeoutMs, concurrency, debug } = params;
  return {
    requests,
    maxRequests: options.maxRequests,
    wait,
    pollIntervalMs,
    timeoutMs,
    concurrency,
    debug,
    debugLabel: options.debugLabel,
  };
}

View File

@@ -1,60 +0,0 @@
import { describe, expect, it } from "vitest";
import {
resolveBatchCompletionFromStatus,
resolveCompletedBatchResult,
throwIfBatchTerminalFailure,
} from "./batch-status.js";
describe("batch-status helpers", () => {
  it("resolves completion payload from completed status", () => {
    expect(
      resolveBatchCompletionFromStatus({
        provider: "openai",
        batchId: "b1",
        status: {
          output_file_id: "out-1",
          error_file_id: "err-1",
        },
      }),
    ).toEqual({
      outputFileId: "out-1",
      errorFileId: "err-1",
    });
  });
  it("throws for terminal failure states", async () => {
    // Error-file content ("bad input") is appended to the failure message.
    await expect(
      throwIfBatchTerminalFailure({
        provider: "voyage",
        status: { id: "b2", status: "failed", error_file_id: "err-file" },
        readError: async () => "bad input",
      }),
    ).rejects.toThrow("voyage batch b2 failed: bad input");
  });
  it("returns completed result directly without waiting", async () => {
    // waitForBatch must NOT be consulted when the status is already completed.
    const waitForBatch = async () => ({ outputFileId: "out-2" });
    const result = await resolveCompletedBatchResult({
      provider: "openai",
      status: {
        id: "b3",
        status: "completed",
        output_file_id: "out-3",
      },
      wait: false,
      waitForBatch,
    });
    expect(result).toEqual({ outputFileId: "out-3", errorFileId: undefined });
  });
  it("throws when wait disabled and batch is not complete", async () => {
    await expect(
      resolveCompletedBatchResult({
        provider: "openai",
        status: { id: "b4", status: "pending" },
        wait: false,
        waitForBatch: async () => ({ outputFileId: "out" }),
      }),
    ).rejects.toThrow("openai batch b4 submitted; enable remote.batch.wait to await completion");
  });
});

View File

@@ -1,69 +0,0 @@
// Batch states that can never progress to "completed". Both "cancelled" and
// "canceled" spellings are accepted.
const TERMINAL_FAILURE_STATES = new Set(["failed", "expired", "cancelled", "canceled"]);

/** Minimal provider-agnostic view of a batch status payload. */
type BatchStatusLike = {
  id?: string;
  status?: string;
  output_file_id?: string | null;
  error_file_id?: string | null;
};

/** File ids produced by a successfully completed batch. */
export type BatchCompletionResult = {
  outputFileId: string;
  errorFileId?: string;
};
/**
 * Convert a completed status payload into a BatchCompletionResult.
 * Throws when the payload lacks an output file id.
 */
export function resolveBatchCompletionFromStatus(params: {
  provider: string;
  batchId: string;
  status: BatchStatusLike;
}): BatchCompletionResult {
  const { provider, batchId, status } = params;
  const outputFileId = status.output_file_id;
  if (!outputFileId) {
    throw new Error(`${provider} batch ${batchId} completed without output file`);
  }
  return { outputFileId, errorFileId: status.error_file_id ?? undefined };
}
/**
 * Throw when the batch is in a terminal failure state (failed / expired /
 * cancelled). When the status carries an error file id, `readError` is
 * consulted to enrich the message; otherwise only the state is reported.
 * Non-terminal states return silently.
 */
export async function throwIfBatchTerminalFailure(params: {
  provider: string;
  status: BatchStatusLike;
  readError: (errorFileId: string) => Promise<string | undefined>;
}): Promise<void> {
  const { provider, status, readError } = params;
  const state = status.status ?? "unknown";
  if (!TERMINAL_FAILURE_STATES.has(state)) {
    return;
  }
  let detail: string | undefined;
  if (status.error_file_id) {
    detail = await readError(status.error_file_id);
  }
  const batchId = status.id ?? "<unknown>";
  throw new Error(`${provider} batch ${batchId} ${state}${detail ? `: ${detail}` : ""}`);
}
/**
 * Resolve a batch's completion outcome. An already-completed status resolves
 * immediately; otherwise `waitForBatch` is awaited — unless waiting is
 * disabled, in which case submission without completion is an error. The
 * resolved result must carry an output file id.
 */
export async function resolveCompletedBatchResult(params: {
  provider: string;
  status: BatchStatusLike;
  wait: boolean;
  waitForBatch: () => Promise<BatchCompletionResult>;
}): Promise<BatchCompletionResult> {
  const { provider, status, wait, waitForBatch } = params;
  const batchId = status.id ?? "<unknown>";
  const isCompleted = status.status === "completed";
  if (!isCompleted && !wait) {
    throw new Error(
      `${provider} batch ${batchId} submitted; enable remote.batch.wait to await completion`,
    );
  }
  const completed = isCompleted
    ? resolveBatchCompletionFromStatus({ provider, batchId, status })
    : await waitForBatch();
  if (!completed.outputFileId) {
    throw new Error(`${provider} batch ${batchId} completed without output file`);
  }
  return completed;
}

View File

@@ -1,44 +0,0 @@
import {
buildBatchHeaders,
normalizeBatchBaseUrl,
type BatchHttpClientConfig,
} from "./batch-utils.js";
import { hashText } from "./internal.js";
import { withRemoteHttpResponse } from "./remote-http.js";
/**
 * Upload batch requests as a JSONL file (multipart form, purpose "batch")
 * and return the provider-assigned file id. Throws on non-OK responses or
 * when the upload payload carries no file id.
 */
export async function uploadBatchJsonlFile(params: {
  client: BatchHttpClientConfig;
  requests: unknown[];
  errorPrefix: string;
}): Promise<string> {
  const { client, requests, errorPrefix } = params;
  const jsonlLines = requests.map((request) => JSON.stringify(request));
  const blob = new Blob([jsonlLines.join("\n")], { type: "application/jsonl" });
  const fileName = `memory-embeddings.${hashText(String(Date.now()))}.jsonl`;
  const form = new FormData();
  form.append("purpose", "batch");
  form.append("file", blob, fileName);
  const payload = await withRemoteHttpResponse({
    url: `${normalizeBatchBaseUrl(client)}/files`,
    ssrfPolicy: client.ssrfPolicy,
    init: {
      method: "POST",
      headers: buildBatchHeaders(client, { json: false }),
      body: form,
    },
    onResponse: async (fileRes) => {
      if (fileRes.ok) {
        return (await fileRes.json()) as { id?: string };
      }
      const text = await fileRes.text();
      throw new Error(`${errorPrefix}: ${fileRes.status} ${text}`);
    },
  });
  if (!payload.id) {
    throw new Error(`${errorPrefix}: missing file id`);
  }
  return payload.id;
}

View File

@@ -1,38 +0,0 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
/** HTTP settings shared by provider batch clients. */
export type BatchHttpClientConfig = {
  baseUrl?: string;
  headers?: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
};
/** Return the client base URL without a trailing slash ("" when unset). */
export function normalizeBatchBaseUrl(client: BatchHttpClientConfig): string {
  const base = client.baseUrl;
  if (base == null) {
    return "";
  }
  return base.endsWith("/") ? base.slice(0, -1) : base;
}
/**
 * Clone the client headers for a batch request. With json=true a default
 * Content-Type of application/json is added unless a non-empty one is
 * already present (either capitalization); with json=false any Content-Type
 * is stripped — presumably so the HTTP runtime can set the multipart
 * boundary itself (TODO confirm against the upload caller).
 */
export function buildBatchHeaders(
  client: Pick<BatchHttpClientConfig, "headers">,
  params: { json: boolean },
): Record<string, string> {
  const headers: Record<string, string> = { ...(client.headers ?? {}) };
  if (!params.json) {
    delete headers["Content-Type"];
    delete headers["content-type"];
    return headers;
  }
  const hasContentType = Boolean(headers["Content-Type"] || headers["content-type"]);
  if (!hasContentType) {
    headers["Content-Type"] = "application/json";
  }
  return headers;
}
/**
 * Split `requests` into groups of at most `maxRequests` items. Returns a
 * single group containing the original array when it already fits.
 *
 * The chunk size is clamped to >= 1: the previous implementation looped
 * forever when maxRequests <= 0 and `requests` was non-empty, because the
 * loop counter never advanced.
 */
export function splitBatchRequests<T>(requests: T[], maxRequests: number): T[][] {
  const size = Math.max(1, Math.floor(maxRequests));
  if (requests.length <= size) {
    return [requests];
  }
  const groups: T[][] = [];
  for (let i = 0; i < requests.length; i += size) {
    groups.push(requests.slice(i, i + size));
  }
  return groups;
}

View File

@@ -1,176 +0,0 @@
import { ReadableStream } from "node:stream/web";
import { setTimeout as nativeSleep } from "node:timers/promises";
import { describe, expect, it, vi } from "vitest";
import {
runVoyageEmbeddingBatches,
type VoyageBatchOutputLine,
type VoyageBatchRequest,
} from "./batch-voyage.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
// Bind the real clock so injected deps can use it even if Date.now is mocked.
const realNow = Date.now.bind(Date);

describe("runVoyageEmbeddingBatches", () => {
  const mockClient: VoyageEmbeddingClient = {
    baseUrl: "https://api.voyageai.com/v1",
    headers: { Authorization: "Bearer test-key" },
    model: "voyage-4-large",
  };
  const mockRequests: VoyageBatchRequest[] = [
    { custom_id: "req-1", body: { input: "text1" } },
    { custom_id: "req-2", body: { input: "text2" } },
  ];
  it("successfully submits batch, waits, and streams results", async () => {
    const outputLines: VoyageBatchOutputLine[] = [
      {
        custom_id: "req-1",
        response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
      },
      {
        custom_id: "req-2",
        response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
      },
    ];
    const withRemoteHttpResponse = vi.fn();
    const postJsonWithRetry = vi.fn();
    const uploadBatchJsonlFile = vi.fn();
    // Create a stream that emits the NDJSON lines
    const stream = new ReadableStream({
      start(controller) {
        const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
        controller.enqueue(new TextEncoder().encode(text));
        controller.close();
      },
    });
    // 1. File upload returns the input file id used to create the batch.
    uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
      expect(params.errorPrefix).toBe("voyage batch file upload failed");
      expect(params.requests).toEqual(mockRequests);
      return "file-123";
    });
    // 2. Batch creation: assert the Voyage-specific request body shape.
    postJsonWithRetry.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/batches");
      expect(params.body).toMatchObject({
        input_file_id: "file-123",
        completion_window: "12h",
        request_params: {
          model: "voyage-4-large",
          input_type: "document",
        },
      });
      return {
        id: "batch-abc",
        status: "pending",
      };
    });
    // 3. First poll reports the batch completed with an output file.
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/batches/batch-abc");
      return await params.onResponse(
        new Response(
          JSON.stringify({
            id: "batch-abc",
            status: "completed",
            output_file_id: "file-out-999",
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/json" },
          },
        ),
      );
    });
    // 4. Output file content is streamed back as NDJSON.
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/files/file-out-999/content");
      return await params.onResponse(
        new Response(stream as unknown as BodyInit, {
          status: 200,
          headers: { "Content-Type": "application/x-ndjson" },
        }),
      );
    });
    const results = await runVoyageEmbeddingBatches({
      client: mockClient,
      agentId: "agent-1",
      requests: mockRequests,
      wait: true,
      pollIntervalMs: 1, // fast poll
      timeoutMs: 1000,
      concurrency: 1,
      deps: {
        now: realNow,
        sleep: async (ms) => {
          await nativeSleep(ms);
        },
        postJsonWithRetry,
        uploadBatchJsonlFile,
        withRemoteHttpResponse,
      },
    });
    expect(results.size).toBe(2);
    expect(results.get("req-1")).toEqual([0.1, 0.1]);
    expect(results.get("req-2")).toEqual([0.2, 0.2]);
    expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
    expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
    expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
  });
  it("handles empty lines and stream chunks correctly", async () => {
    const withRemoteHttpResponse = vi.fn();
    const postJsonWithRetry = vi.fn();
    const uploadBatchJsonlFile = vi.fn();
    const stream = new ReadableStream({
      start(controller) {
        const line1 = JSON.stringify({
          custom_id: "req-1",
          response: { body: { data: [{ embedding: [1] }] } },
        });
        const line2 = JSON.stringify({
          custom_id: "req-2",
          response: { body: { data: [{ embedding: [2] }] } },
        });
        // Split across chunks
        controller.enqueue(new TextEncoder().encode(line1 + "\n"));
        controller.enqueue(new TextEncoder().encode("\n")); // empty line
        controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
        controller.close();
      },
    });
    uploadBatchJsonlFile.mockResolvedValueOnce("f1");
    // Batch reported completed immediately, so no status poll is needed.
    postJsonWithRetry.mockResolvedValueOnce({
      id: "b1",
      status: "completed",
      output_file_id: "out1",
    });
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/files/out1/content");
      return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
    });
    const results = await runVoyageEmbeddingBatches({
      client: mockClient,
      agentId: "a1",
      requests: mockRequests,
      wait: true,
      pollIntervalMs: 1,
      timeoutMs: 1000,
      concurrency: 1,
      deps: {
        now: realNow,
        sleep: async (ms) => {
          await nativeSleep(ms);
        },
        postJsonWithRetry,
        uploadBatchJsonlFile,
        withRemoteHttpResponse,
      },
    });
    expect(results.get("req-1")).toEqual([1]);
    expect(results.get("req-2")).toEqual([2]);
  });
});

View File

@@ -1,315 +0,0 @@
import { createInterface } from "node:readline";
import { Readable } from "node:stream";
import {
applyEmbeddingBatchOutputLine,
buildBatchHeaders,
buildEmbeddingBatchGroupOptions,
EMBEDDING_BATCH_ENDPOINT,
extractBatchErrorMessage,
formatUnavailableBatchError,
normalizeBatchBaseUrl,
postJsonWithRetry,
resolveBatchCompletionFromStatus,
resolveCompletedBatchResult,
runEmbeddingBatchGroups,
throwIfBatchTerminalFailure,
type EmbeddingBatchExecutionParams,
type EmbeddingBatchStatus,
type BatchCompletionResult,
type ProviderBatchOutputLine,
uploadBatchJsonlFile,
withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
/**
 * Voyage Batch API Input Line format.
 * See: https://docs.voyageai.com/docs/batch-inference
 */
export type VoyageBatchRequest = {
  custom_id: string;
  body: {
    input: string | string[];
  };
};

// Status and output-line shapes are shared across embedding batch providers.
export type VoyageBatchStatus = EmbeddingBatchStatus;
export type VoyageBatchOutputLine = ProviderBatchOutputLine;

export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
// Completion window value sent when creating a batch job.
const VOYAGE_BATCH_COMPLETION_WINDOW = "12h";
// Requests per batch group; larger inputs are split into multiple groups.
const VOYAGE_BATCH_MAX_REQUESTS = 50000;

/** Injectable dependencies so tests can fake the clock, sleeping, and HTTP. */
type VoyageBatchDeps = {
  now: () => number;
  sleep: (ms: number) => Promise<void>;
  postJsonWithRetry: typeof postJsonWithRetry;
  uploadBatchJsonlFile: typeof uploadBatchJsonlFile;
  withRemoteHttpResponse: typeof withRemoteHttpResponse;
};
/** Fill in real implementations for any dependency not overridden by the caller. */
function resolveVoyageBatchDeps(overrides: Partial<VoyageBatchDeps> | undefined): VoyageBatchDeps {
  const defaultSleep = async (ms: number) =>
    await new Promise<void>((resolve) => setTimeout(resolve, ms));
  return {
    now: overrides?.now ?? Date.now,
    sleep: overrides?.sleep ?? defaultSleep,
    postJsonWithRetry: overrides?.postJsonWithRetry ?? postJsonWithRetry,
    uploadBatchJsonlFile: overrides?.uploadBatchJsonlFile ?? uploadBatchJsonlFile,
    withRemoteHttpResponse: overrides?.withRemoteHttpResponse ?? withRemoteHttpResponse,
  };
}
/** Throw a contextualized error when a Voyage HTTP response is not OK. */
async function assertVoyageResponseOk(res: Response, context: string): Promise<void> {
  if (res.ok) {
    return;
  }
  const body = await res.text();
  throw new Error(`${context}: ${res.status} ${body}`);
}
/**
 * Assemble the withRemoteHttpResponse argument bag for a GET request under
 * the client's base URL, with standard JSON batch headers.
 */
function buildVoyageBatchRequest<T>(params: {
  client: VoyageEmbeddingClient;
  path: string;
  onResponse: (res: Response) => Promise<T>;
}) {
  const { client, path, onResponse } = params;
  return {
    url: `${normalizeBatchBaseUrl(client)}/${path}`,
    ssrfPolicy: client.ssrfPolicy,
    init: { headers: buildBatchHeaders(client, { json: true }) },
    onResponse,
  };
}
/**
 * Upload the request JSONL file and create a Voyage embeddings batch job.
 * Returns the provider's initial batch status payload.
 */
async function submitVoyageBatch(params: {
  client: VoyageEmbeddingClient;
  requests: VoyageBatchRequest[];
  agentId: string;
  deps: VoyageBatchDeps;
}): Promise<VoyageBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.client);
  // 1. Upload the JSONL input file; its id seeds the batch job.
  const inputFileId = await params.deps.uploadBatchJsonlFile({
    client: params.client,
    requests: params.requests,
    errorPrefix: "voyage batch file upload failed",
  });
  // 2. Create batch job using Voyage Batches API
  return await params.deps.postJsonWithRetry<VoyageBatchStatus>({
    url: `${baseUrl}/batches`,
    headers: buildBatchHeaders(params.client, { json: true }),
    ssrfPolicy: params.client.ssrfPolicy,
    body: {
      input_file_id: inputFileId,
      endpoint: VOYAGE_BATCH_ENDPOINT,
      completion_window: VOYAGE_BATCH_COMPLETION_WINDOW,
      request_params: {
        model: params.client.model,
        input_type: "document",
      },
      metadata: {
        // Was "clawdbot-memory" — now consistent with the OpenAI
        // counterpart's "openclaw-memory" tag after the project rename.
        source: "openclaw-memory",
        agent: params.agentId,
      },
    },
    errorPrefix: "voyage batch create failed",
  });
}
/** Fetch the current status payload of a Voyage batch via injected HTTP deps. */
async function fetchVoyageBatchStatus(params: {
  client: VoyageEmbeddingClient;
  batchId: string;
  deps: VoyageBatchDeps;
}): Promise<VoyageBatchStatus> {
  const request = buildVoyageBatchRequest({
    client: params.client,
    path: `batches/${params.batchId}`,
    onResponse: async (res) => {
      await assertVoyageResponseOk(res, "voyage batch status failed");
      return (await res.json()) as VoyageBatchStatus;
    },
  });
  return await params.deps.withRemoteHttpResponse(request);
}
/**
 * Best-effort read of a Voyage batch error file: downloads the NDJSON error
 * content, parses its non-empty lines, and returns whatever message
 * extractBatchErrorMessage derives from them. Any fetch/parse failure is
 * converted into a formatted "unavailable" note rather than masking the
 * original batch failure.
 */
async function readVoyageBatchError(params: {
  client: VoyageEmbeddingClient;
  errorFileId: string;
  deps: VoyageBatchDeps;
}): Promise<string | undefined> {
  try {
    return await params.deps.withRemoteHttpResponse(
      buildVoyageBatchRequest({
        client: params.client,
        path: `files/${params.errorFileId}/content`,
        onResponse: async (res) => {
          await assertVoyageResponseOk(res, "voyage batch error file content failed");
          const text = await res.text();
          if (!text.trim()) {
            return undefined;
          }
          // NDJSON: one output line object per non-empty line.
          const lines = text
            .split("\n")
            .map((line) => line.trim())
            .filter(Boolean)
            .map((line) => JSON.parse(line) as VoyageBatchOutputLine);
          return extractBatchErrorMessage(lines);
        },
      }),
    );
  } catch (err) {
    return formatUnavailableBatchError(err);
  }
}
/**
 * Poll a Voyage batch until it completes, fails terminally, or times out.
 * `initial` lets the caller pass an already-fetched status so the first
 * iteration skips a redundant status request. The injected deps supply the
 * clock and sleep so tests can run deterministically.
 *
 * Throws when the batch enters a terminal failure state, when `wait` is
 * false and the batch is not yet complete, or when `timeoutMs` elapses.
 */
async function waitForVoyageBatch(params: {
  client: VoyageEmbeddingClient;
  batchId: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: VoyageBatchStatus;
  deps: VoyageBatchDeps;
}): Promise<BatchCompletionResult> {
  const start = params.deps.now();
  let current: VoyageBatchStatus | undefined = params.initial;
  while (true) {
    // Use the cached status on the first pass; re-fetch on later iterations.
    const status =
      current ??
      (await fetchVoyageBatchStatus({
        client: params.client,
        batchId: params.batchId,
        deps: params.deps,
      }));
    const state = status.status ?? "unknown";
    if (state === "completed") {
      return resolveBatchCompletionFromStatus({
        provider: "voyage",
        batchId: params.batchId,
        status,
      });
    }
    // Terminal failure (failed/expired/cancelled) raises, enriched with
    // error-file detail when available.
    await throwIfBatchTerminalFailure({
      provider: "voyage",
      status: { ...status, id: params.batchId },
      readError: async (errorFileId) =>
        await readVoyageBatchError({
          client: params.client,
          errorFileId,
          deps: params.deps,
        }),
    });
    if (!params.wait) {
      throw new Error(`voyage batch ${params.batchId} still ${state}; wait disabled`);
    }
    if (params.deps.now() - start > params.timeoutMs) {
      throw new Error(`voyage batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`voyage batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
    await params.deps.sleep(params.pollIntervalMs);
    // Force a fresh fetch on the next loop iteration.
    current = undefined;
  }
}
/**
 * Submits memory-embedding requests to Voyage's batch API in capped groups,
 * waits for each batch to finish, then streams the JSONL output file and
 * collects one embedding vector per `custom_id`.
 *
 * Throws when batch creation returns no id, a batch ends in a terminal
 * failure or timeout, any output line carries an error, or responses are
 * missing for some submitted requests.
 */
export async function runVoyageEmbeddingBatches(
  params: {
    client: VoyageEmbeddingClient;
    agentId: string;
    requests: VoyageBatchRequest[];
    deps?: Partial<VoyageBatchDeps>;
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  const deps = resolveVoyageBatchDeps(params.deps);
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: voyage batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      const batchInfo = await submitVoyageBatch({
        client: params.client,
        requests: group,
        agentId: params.agentId,
        deps,
      });
      if (!batchInfo.id) {
        throw new Error("voyage batch create failed: missing batch id");
      }
      const batchId = batchInfo.id;
      params.debug?.("memory embeddings: voyage batch created", {
        batchId: batchInfo.id,
        status: batchInfo.status,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });
      // Pass the creation status as `initial` so the poller can skip one fetch.
      const completed = await resolveCompletedBatchResult({
        provider: "voyage",
        status: batchInfo,
        wait: params.wait,
        waitForBatch: async () =>
          await waitForVoyageBatch({
            client: params.client,
            batchId,
            wait: params.wait,
            pollIntervalMs: params.pollIntervalMs,
            timeoutMs: params.timeoutMs,
            debug: params.debug,
            initial: batchInfo,
            deps,
          }),
      });
      const baseUrl = normalizeBatchBaseUrl(params.client);
      const errors: string[] = [];
      // Track custom_ids still awaiting a response so gaps can be reported.
      const remaining = new Set(group.map((request) => request.custom_id));
      await deps.withRemoteHttpResponse({
        url: `${baseUrl}/files/${completed.outputFileId}/content`,
        ssrfPolicy: params.client.ssrfPolicy,
        init: {
          headers: buildBatchHeaders(params.client, { json: true }),
        },
        onResponse: async (contentRes) => {
          if (!contentRes.ok) {
            const text = await contentRes.text();
            throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`);
          }
          if (!contentRes.body) {
            return;
          }
          // Stream the JSONL output line by line instead of buffering the file.
          const reader = createInterface({
            input: Readable.fromWeb(
              contentRes.body as unknown as import("stream/web").ReadableStream,
            ),
            terminal: false,
          });
          for await (const rawLine of reader) {
            if (!rawLine.trim()) {
              continue;
            }
            const line = JSON.parse(rawLine) as VoyageBatchOutputLine;
            applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
          }
        },
      });
      if (errors.length > 0) {
        throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(
          `voyage batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
        );
      }
    },
  });
}

View File

@@ -1,102 +0,0 @@
import { describe, expect, it } from "vitest";
import { enforceEmbeddingMaxInputTokens } from "./embedding-chunk-limits.js";
import { estimateUtf8Bytes } from "./embedding-input-limits.js";
import type { EmbeddingProvider } from "./embeddings.js";
/** Builds a minimal in-memory EmbeddingProvider stub with a fixed input limit. */
function createProvider(maxInputTokens: number): EmbeddingProvider {
  const stub: EmbeddingProvider = {
    id: "mock",
    model: "mock-embed",
    maxInputTokens,
    embedQuery: async () => [0],
    embedBatch: async () => [[0]],
  };
  return stub;
}
/** Builds a provider stub that deliberately omits `maxInputTokens`. */
function createProviderWithoutMaxInputTokens(params: {
  id: string;
  model: string;
}): EmbeddingProvider {
  const { id, model } = params;
  return {
    id,
    model,
    embedQuery: async () => [0],
    embedBatch: async () => [[0]],
  };
}
// Behavior of enforceEmbeddingMaxInputTokens across provider limits and caps.
describe("embedding chunk limits", () => {
  it("splits oversized chunks so each embedding input stays <= maxInputTokens bytes", () => {
    const provider = createProvider(8192);
    const input = {
      startLine: 1,
      endLine: 1,
      text: "x".repeat(9000),
      hash: "ignored",
    };
    const out = enforceEmbeddingMaxInputTokens(provider, [input]);
    // Splitting must preserve content, byte limits, line ranges, and hashes.
    expect(out.length).toBeGreaterThan(1);
    expect(out.map((chunk) => chunk.text).join("")).toBe(input.text);
    expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 8192)).toBe(true);
    expect(out.every((chunk) => chunk.startLine === 1 && chunk.endLine === 1)).toBe(true);
    expect(out.every((chunk) => typeof chunk.hash === "string" && chunk.hash.length > 0)).toBe(
      true,
    );
  });
  it("does not split inside surrogate pairs (emoji)", () => {
    const provider = createProvider(8192);
    const emoji = "😀";
    // Two long emoji runs force multiple splits near surrogate boundaries.
    const inputText = `${emoji.repeat(2100)}\n${emoji.repeat(2100)}`;
    const out = enforceEmbeddingMaxInputTokens(provider, [
      { startLine: 1, endLine: 2, text: inputText, hash: "ignored" },
    ]);
    expect(out.length).toBeGreaterThan(1);
    expect(out.map((chunk) => chunk.text).join("")).toBe(inputText);
    expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 8192)).toBe(true);
    // If we split inside surrogate pairs we'd likely end up with replacement chars.
    expect(out.map((chunk) => chunk.text).join("")).not.toContain("\uFFFD");
  });
  it("uses conservative fallback limits for local providers without declared maxInputTokens", () => {
    const provider = createProviderWithoutMaxInputTokens({
      id: "local",
      model: "unknown-local-embedding",
    });
    const out = enforceEmbeddingMaxInputTokens(provider, [
      {
        startLine: 1,
        endLine: 1,
        text: "x".repeat(3000),
        hash: "ignored",
      },
    ]);
    expect(out.length).toBeGreaterThan(1);
    // 2048 is the conservative fallback for unknown local models.
    expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 2048)).toBe(true);
  });
  it("honors hard safety caps lower than provider maxInputTokens", () => {
    const provider = createProvider(8192);
    const out = enforceEmbeddingMaxInputTokens(
      provider,
      [
        {
          startLine: 1,
          endLine: 1,
          text: "x".repeat(8100),
          hash: "ignored",
        },
      ],
      8000,
    );
    expect(out.length).toBeGreaterThan(1);
    expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 8000)).toBe(true);
  });
});

View File

@@ -1,41 +0,0 @@
import { estimateUtf8Bytes, splitTextToUtf8ByteLimit } from "./embedding-input-limits.js";
import { hasNonTextEmbeddingParts } from "./embedding-inputs.js";
import { resolveEmbeddingMaxInputTokens } from "./embedding-model-limits.js";
import type { EmbeddingProvider } from "./embeddings.js";
import { hashText, type MemoryChunk } from "./internal.js";
/**
 * Ensures every text chunk fits within the effective embedding input limit
 * (provider limit, optionally capped by `hardMaxInputTokens`), splitting
 * oversized chunks into smaller pieces that keep the original line range.
 * Chunks carrying non-text (inline data) parts are passed through untouched.
 */
export function enforceEmbeddingMaxInputTokens(
  provider: EmbeddingProvider,
  chunks: MemoryChunk[],
  hardMaxInputTokens?: number,
): MemoryChunk[] {
  const providerLimit = resolveEmbeddingMaxInputTokens(provider);
  const byteLimit =
    typeof hardMaxInputTokens === "number" && hardMaxInputTokens > 0
      ? Math.min(providerLimit, hardMaxInputTokens)
      : providerLimit;
  return chunks.flatMap((chunk): MemoryChunk[] => {
    // Multimodal chunks cannot be split by text length; keep them intact.
    if (hasNonTextEmbeddingParts(chunk.embeddingInput)) {
      return [chunk];
    }
    if (estimateUtf8Bytes(chunk.text) <= byteLimit) {
      return [chunk];
    }
    // Oversized: split on UTF-8 byte boundaries and re-hash each piece.
    return splitTextToUtf8ByteLimit(chunk.text, byteLimit).map((text) => ({
      startLine: chunk.startLine,
      endLine: chunk.endLine,
      text,
      hash: hashText(text),
      embeddingInput: { text },
    }));
  });
}

View File

@@ -1,85 +0,0 @@
import type { EmbeddingInput } from "./embedding-inputs.js";
// Helpers for enforcing embedding model input size limits.
//
// We use UTF-8 byte length as a conservative upper bound for tokenizer output.
// Tokenizers operate over bytes; a token must contain at least one byte, so
// token_count <= utf8_byte_length.
/** Returns the UTF-8 encoded byte length of `text` (0 for the empty string). */
export function estimateUtf8Bytes(text: string): number {
  return text ? Buffer.byteLength(text, "utf8") : 0;
}
/**
 * Upper-bounds the byte size of a structured embedding input: the sum of
 * UTF-8 bytes across all parts (text, or mime type + payload for inline
 * data). Falls back to the flat `text` field when no parts are present.
 */
export function estimateStructuredEmbeddingInputBytes(input: EmbeddingInput): number {
  const parts = input.parts;
  if (!parts?.length) {
    return estimateUtf8Bytes(input.text);
  }
  let total = 0;
  for (const part of parts) {
    if (part.type === "text") {
      total += estimateUtf8Bytes(part.text);
    } else {
      // Inline data: count both the mime type and the (base64) payload.
      total += estimateUtf8Bytes(part.mimeType) + estimateUtf8Bytes(part.data);
    }
  }
  return total;
}
/**
 * Splits `text` into consecutive pieces whose UTF-8 byte length is at most
 * `maxUtf8Bytes`, never cutting a surrogate pair in half. A non-positive
 * limit, or text already within the limit, returns the text unchanged.
 * Concatenating the returned pieces reproduces the original text.
 */
export function splitTextToUtf8ByteLimit(text: string, maxUtf8Bytes: number): string[] {
  if (maxUtf8Bytes <= 0) {
    return [text];
  }
  if (estimateUtf8Bytes(text) <= maxUtf8Bytes) {
    return [text];
  }
  const parts: string[] = [];
  let cursor = 0;
  while (cursor < text.length) {
    // The number of UTF-16 code units is always <= the number of UTF-8 bytes.
    // This makes `cursor + maxUtf8Bytes` a safe upper bound on the next split point.
    let low = cursor + 1;
    let high = Math.min(text.length, cursor + maxUtf8Bytes);
    let best = cursor;
    // Binary-search the largest end index whose slice still fits the limit.
    while (low <= high) {
      const mid = Math.floor((low + high) / 2);
      const bytes = estimateUtf8Bytes(text.slice(cursor, mid));
      if (bytes <= maxUtf8Bytes) {
        best = mid;
        low = mid + 1;
      } else {
        high = mid - 1;
      }
    }
    // Guarantee forward progress even when a single code unit exceeds the limit.
    if (best <= cursor) {
      best = Math.min(text.length, cursor + 1);
    }
    // Avoid splitting inside a surrogate pair.
    if (
      best < text.length &&
      best > cursor &&
      text.charCodeAt(best - 1) >= 0xd800 &&
      text.charCodeAt(best - 1) <= 0xdbff &&
      text.charCodeAt(best) >= 0xdc00 &&
      text.charCodeAt(best) <= 0xdfff
    ) {
      best -= 1;
    }
    const part = text.slice(cursor, best);
    if (!part) {
      break;
    }
    parts.push(part);
    cursor = best;
  }
  return parts;
}

View File

@@ -1,34 +0,0 @@
// A plain text fragment of an embedding input.
export type EmbeddingInputTextPart = {
  type: "text";
  text: string;
};
// A binary fragment (e.g. image/audio payload) with its MIME type; `data`
// carries the encoded bytes as a string.
export type EmbeddingInputInlineDataPart = {
  type: "inline-data";
  mimeType: string;
  data: string;
};
export type EmbeddingInputPart = EmbeddingInputTextPart | EmbeddingInputInlineDataPart;
// Canonical embedding input: a flat `text` representation plus optional
// structured parts for multimodal-capable providers.
export type EmbeddingInput = {
  text: string;
  parts?: EmbeddingInputPart[];
};
/** Wraps plain text in the canonical `EmbeddingInput` shape (no extra parts). */
export function buildTextEmbeddingInput(text: string): EmbeddingInput {
  const input: EmbeddingInput = { text };
  return input;
}
/** Type guard: narrows a part to the inline-data variant via its discriminant. */
export function isInlineDataEmbeddingInputPart(
  part: EmbeddingInputPart,
): part is EmbeddingInputInlineDataPart {
  switch (part.type) {
    case "inline-data":
      return true;
    default:
      return false;
  }
}
/** True when the input carries at least one inline-data (non-text) part. */
export function hasNonTextEmbeddingParts(input: EmbeddingInput | undefined): boolean {
  const parts = input?.parts ?? [];
  return parts.some(isInlineDataEmbeddingInputPart);
}

View File

@@ -1,41 +0,0 @@
import type { EmbeddingProvider } from "./embeddings.js";
// Fallback limit for unknown remote models (matches the common 8192 limit).
const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192;
// Conservative fallback for local providers without a declared limit.
const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048;
// Best-effort "provider:model" (lowercased) -> max input tokens table.
const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record<string, number> = {
  "openai:text-embedding-3-small": 8192,
  "openai:text-embedding-3-large": 8192,
  "openai:text-embedding-ada-002": 8191,
  "gemini:text-embedding-004": 2048,
  "gemini:gemini-embedding-001": 2048,
  "gemini:gemini-embedding-2-preview": 8192,
  "voyage:voyage-3": 32000,
  "voyage:voyage-3-lite": 16000,
  "voyage:voyage-code-3": 32000,
};
/**
 * Resolves the max embedding input size for a provider: an explicitly
 * declared value wins, then the known "provider:model" table, then
 * conservative per-provider defaults.
 */
export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number {
  if (typeof provider.maxInputTokens === "number") {
    return provider.maxInputTokens;
  }
  // Best-effort lookup; different providers use different limits and we
  // prefer to be conservative when we don't know.
  const key = `${provider.id}:${provider.model}`.toLowerCase();
  const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key];
  if (typeof known === "number") {
    return known;
  }
  // Per-provider fallbacks prevent accidentally applying the OpenAI default
  // to providers with much smaller limits.
  switch (provider.id.toLowerCase()) {
    case "gemini":
      return 2048;
    case "local":
      return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS;
    default:
      return DEFAULT_EMBEDDING_MAX_INPUT_TOKENS;
  }
}

View File

@@ -1,8 +0,0 @@
/**
 * Zeros out non-finite components (NaN/±Infinity), then L2-normalizes the
 * vector. Near-zero vectors are returned unscaled to avoid dividing by ~0.
 */
export function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
  const cleaned: number[] = [];
  let sumOfSquares = 0;
  for (const value of vec) {
    const safe = Number.isFinite(value) ? value : 0;
    cleaned.push(safe);
    sumOfSquares += safe * safe;
  }
  const norm = Math.sqrt(sumOfSquares);
  return norm < 1e-10 ? cleaned : cleaned.map((value) => value / norm);
}

View File

@@ -1,13 +0,0 @@
import { isTruthyEnvValue } from "../../infra/env.js";
import { createSubsystemLogger } from "../../logging/subsystem.js";
// One-time env probe: set OPENCLAW_DEBUG_MEMORY_EMBEDDINGS to enable logging.
const debugEmbeddings = isTruthyEnvValue(process.env.OPENCLAW_DEBUG_MEMORY_EMBEDDINGS);
// Shared subsystem logger for embedding debug output.
const log = createSubsystemLogger("memory/embeddings");
/** Emits a raw debug line (with optional JSON metadata) when the debug env flag is set. */
export function debugEmbeddingsLog(message: string, meta?: Record<string, unknown>): void {
  if (!debugEmbeddings) {
    return;
  }
  log.raw(meta ? `${message} ${JSON.stringify(meta)}` : message);
}

View File

@@ -1,570 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../agents/model-auth.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
// Replace the model-auth module with a test double so no real credentials
// are resolved during these tests.
vi.mock("../../agents/model-auth.js", async () => {
  const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js");
  return createModelAuthMockModule();
});
// Fetch stub that answers every call with a single-embedding Gemini payload.
const createGeminiFetchMock = (embeddingValues = [1, 2, 3]) =>
  vi.fn(async (_input?: unknown, _init?: unknown) => {
    return {
      ok: true,
      status: 200,
      json: async () => ({ embedding: { values: embeddingValues } }),
    };
  });
// Fetch stub that answers with `count` identical embeddings (batch endpoint).
const createGeminiBatchFetchMock = (count: number, embeddingValues = [1, 2, 3]) =>
  vi.fn(async (_input?: unknown, _init?: unknown) => {
    return {
      ok: true,
      status: 200,
      json: async () => ({
        embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })),
      }),
    };
  });
/** Extracts the url/init pair of the first recorded fetch-mock call. */
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
  const firstCall = fetchMock.mock.calls[0] ?? [];
  return { url: firstCall[0], init: firstCall[1] as RequestInit | undefined };
}
/** Parses the JSON request body of the `callIndex`-th fetch-mock call ({} if absent). */
function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) {
  const requestInit = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined;
  const rawBody = (requestInit?.body as string) ?? "{}";
  return JSON.parse(rawBody) as Record<string, unknown>;
}
/** Euclidean (L2) norm of a vector. */
function magnitude(values: number[]) {
  let sumOfSquares = 0;
  for (const value of values) {
    sumOfSquares += value * value;
  }
  return Math.sqrt(sumOfSquares);
}
// Late-bound handles to the module under test; assigned in beforeAll so the
// import happens only after undici is un-mocked and the module cache is reset.
let buildGeminiEmbeddingRequest: typeof import("./embeddings-gemini.js").buildGeminiEmbeddingRequest;
let buildGeminiTextEmbeddingRequest: typeof import("./embeddings-gemini.js").buildGeminiTextEmbeddingRequest;
let createGeminiEmbeddingProvider: typeof import("./embeddings-gemini.js").createGeminiEmbeddingProvider;
let DEFAULT_GEMINI_EMBEDDING_MODEL: typeof import("./embeddings-gemini.js").DEFAULT_GEMINI_EMBEDDING_MODEL;
let GEMINI_EMBEDDING_2_MODELS: typeof import("./embeddings-gemini.js").GEMINI_EMBEDDING_2_MODELS;
let isGeminiEmbedding2Model: typeof import("./embeddings-gemini.js").isGeminiEmbedding2Model;
let resolveGeminiOutputDimensionality: typeof import("./embeddings-gemini.js").resolveGeminiOutputDimensionality;
beforeAll(async () => {
  // Import with a cold module cache and the real undici implementation.
  vi.doUnmock("undici");
  vi.resetModules();
  ({
    buildGeminiEmbeddingRequest,
    buildGeminiTextEmbeddingRequest,
    createGeminiEmbeddingProvider,
    DEFAULT_GEMINI_EMBEDDING_MODEL,
    GEMINI_EMBEDDING_2_MODELS,
    isGeminiEmbedding2Model,
    resolveGeminiOutputDimensionality,
  } = await import("./embeddings-gemini.js"));
});
beforeEach(() => {
  // Each test starts with real timers and the real undici module.
  vi.useRealTimers();
  vi.doUnmock("undici");
});
afterEach(() => {
  // Undo per-test mocks and global stubs (e.g. the stubbed `fetch`).
  vi.doUnmock("undici");
  vi.resetAllMocks();
  vi.unstubAllGlobals();
});
// Makes the mocked auth module resolve a plain API key for the provider.
function mockResolvedProviderKey(apiKey = "test-key") {
  vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
    mode: "api-key",
    source: "test",
    apiKey,
  });
}
// Union of the two fetch-mock factories' return types.
type GeminiFetchMock =
  | ReturnType<typeof createGeminiFetchMock>
  | ReturnType<typeof createGeminiBatchFetchMock>;
// Stubs global fetch, SSRF pinning, and auth, then builds a Gemini embedding
// provider with the given options (options may override the defaults below).
async function createProviderWithFetch(
  fetchMock: GeminiFetchMock,
  options: Partial<Parameters<typeof createGeminiEmbeddingProvider>[0]> & { model: string },
) {
  vi.stubGlobal("fetch", fetchMock);
  mockPublicPinnedHostname();
  mockResolvedProviderKey();
  const created = await createGeminiEmbeddingProvider({
    config: {} as never,
    provider: "gemini",
    fallback: "none",
    ...options,
  });
  return created.provider;
}
// Asserts a [3,4]-derived embedding was L2-normalized to [0.6, 0.8].
function expectNormalizedThreeFourVector(embedding: number[]) {
  const [first, second] = embedding;
  expect(first).toBeCloseTo(0.6, 5);
  expect(second).toBeCloseTo(0.8, 5);
  expect(magnitude(embedding)).toBeCloseTo(1, 5);
}
// Request-builder shape checks (no network involved).
describe("buildGeminiTextEmbeddingRequest", () => {
  it("builds a text embedding request with optional model and dimensions", () => {
    expect(
      buildGeminiTextEmbeddingRequest({
        text: "hello",
        taskType: "RETRIEVAL_DOCUMENT",
        modelPath: "models/gemini-embedding-2-preview",
        outputDimensionality: 1536,
      }),
    ).toEqual({
      model: "models/gemini-embedding-2-preview",
      content: { parts: [{ text: "hello" }] },
      taskType: "RETRIEVAL_DOCUMENT",
      outputDimensionality: 1536,
    });
  });
});
describe("buildGeminiEmbeddingRequest", () => {
  it("builds a multimodal request from structured input parts", () => {
    expect(
      buildGeminiEmbeddingRequest({
        input: {
          text: "Image file: diagram.png",
          parts: [
            { type: "text", text: "Image file: diagram.png" },
            { type: "inline-data", mimeType: "image/png", data: "abc123" },
          ],
        },
        taskType: "RETRIEVAL_DOCUMENT",
        modelPath: "models/gemini-embedding-2-preview",
        outputDimensionality: 1536,
      }),
    ).toEqual({
      model: "models/gemini-embedding-2-preview",
      content: {
        parts: [
          { text: "Image file: diagram.png" },
          { inlineData: { mimeType: "image/png", data: "abc123" } },
        ],
      },
      taskType: "RETRIEVAL_DOCUMENT",
      outputDimensionality: 1536,
    });
  });
});
// ---------- Model detection ----------
describe("isGeminiEmbedding2Model", () => {
  it("returns true for gemini-embedding-2-preview", () => {
    expect(isGeminiEmbedding2Model("gemini-embedding-2-preview")).toBe(true);
  });
  it("returns false for gemini-embedding-001", () => {
    expect(isGeminiEmbedding2Model("gemini-embedding-001")).toBe(false);
  });
  it("returns false for text-embedding-004", () => {
    expect(isGeminiEmbedding2Model("text-embedding-004")).toBe(false);
  });
});
describe("GEMINI_EMBEDDING_2_MODELS", () => {
  it("contains gemini-embedding-2-preview", () => {
    expect(GEMINI_EMBEDDING_2_MODELS.has("gemini-embedding-2-preview")).toBe(true);
  });
});
// ---------- Dimension resolution ----------
describe("resolveGeminiOutputDimensionality", () => {
  it("returns undefined for non-v2 models", () => {
    expect(resolveGeminiOutputDimensionality("gemini-embedding-001")).toBeUndefined();
    expect(resolveGeminiOutputDimensionality("text-embedding-004")).toBeUndefined();
  });
  it("returns 3072 by default for v2 models", () => {
    expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview")).toBe(3072);
  });
  it("accepts valid dimension values", () => {
    expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768)).toBe(768);
    expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 1536)).toBe(1536);
    expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 3072)).toBe(3072);
  });
  it("throws for invalid dimension values", () => {
    expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512)).toThrow(
      /Invalid outputDimensionality 512/,
    );
    expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 1024)).toThrow(
      /Valid values: 768, 1536, 3072/,
    );
  });
});
// ---------- Provider: gemini-embedding-001 (backward compat) ----------
describe("gemini-embedding-001 provider (backward compat)", () => {
  // Older models must not receive the v2-only outputDimensionality parameter.
  it("does NOT include outputDimensionality in embedQuery", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-001",
    });
    await provider.embedQuery("test query");
    const body = parseFetchBody(fetchMock);
    expect(body).not.toHaveProperty("outputDimensionality");
    expect(body.taskType).toBe("RETRIEVAL_QUERY");
    expect(body.content).toEqual({ parts: [{ text: "test query" }] });
  });
  it("does NOT include outputDimensionality in embedBatch", async () => {
    const fetchMock = createGeminiBatchFetchMock(2);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-001",
    });
    await provider.embedBatch(["text1", "text2"]);
    const body = parseFetchBody(fetchMock);
    expect(body).not.toHaveProperty("outputDimensionality");
  });
});
// ---------- Provider: gemini-embedding-2-preview ----------
describe("gemini-embedding-2-preview provider", () => {
  it("includes outputDimensionality in embedQuery request", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    await provider.embedQuery("test query");
    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(3072);
    expect(body.taskType).toBe("RETRIEVAL_QUERY");
    expect(body.content).toEqual({ parts: [{ text: "test query" }] });
  });
  it("normalizes embedQuery response vectors", async () => {
    const fetchMock = createGeminiFetchMock([3, 4]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    const embedding = await provider.embedQuery("test query");
    expectNormalizedThreeFourVector(embedding);
  });
  it("includes outputDimensionality in embedBatch request", async () => {
    const fetchMock = createGeminiBatchFetchMock(2);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    await provider.embedBatch(["text1", "text2"]);
    const body = parseFetchBody(fetchMock);
    expect(body.requests).toEqual([
      {
        model: "models/gemini-embedding-2-preview",
        content: { parts: [{ text: "text1" }] },
        taskType: "RETRIEVAL_DOCUMENT",
        outputDimensionality: 3072,
      },
      {
        model: "models/gemini-embedding-2-preview",
        content: { parts: [{ text: "text2" }] },
        taskType: "RETRIEVAL_DOCUMENT",
        outputDimensionality: 3072,
      },
    ]);
  });
  it("normalizes embedBatch response vectors", async () => {
    const fetchMock = createGeminiBatchFetchMock(2, [3, 4]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    const embeddings = await provider.embedBatch(["text1", "text2"]);
    expect(embeddings).toHaveLength(2);
    for (const embedding of embeddings) {
      expectNormalizedThreeFourVector(embedding);
    }
  });
  it("respects custom outputDimensionality", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
      outputDimensionality: 768,
    });
    await provider.embedQuery("test");
    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(768);
  });
  it("sanitizes and normalizes embedQuery responses", async () => {
    const fetchMock = createGeminiFetchMock([3, 4, Number.NaN]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    await expect(provider.embedQuery("test")).resolves.toEqual([0.6, 0.8, 0]);
  });
  it("uses custom outputDimensionality for each embedBatch request", async () => {
    const fetchMock = createGeminiBatchFetchMock(2);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
      outputDimensionality: 768,
    });
    await provider.embedBatch(["text1", "text2"]);
    const body = parseFetchBody(fetchMock);
    expect(body.requests).toEqual([
      expect.objectContaining({ outputDimensionality: 768 }),
      expect.objectContaining({ outputDimensionality: 768 }),
    ]);
  });
  it("sanitizes and normalizes structured batch responses", async () => {
    const fetchMock = createGeminiBatchFetchMock(1, [0, Number.POSITIVE_INFINITY, 5]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    await expect(
      provider.embedBatchInputs?.([
        {
          text: "Image file: diagram.png",
          parts: [
            { type: "text", text: "Image file: diagram.png" },
            { type: "inline-data", mimeType: "image/png", data: "img" },
          ],
        },
      ]),
    ).resolves.toEqual([[0, 0, 1]]);
  });
  it("supports multimodal embedBatchInputs requests", async () => {
    const fetchMock = createGeminiBatchFetchMock(2);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    expect(provider.embedBatchInputs).toBeDefined();
    await provider.embedBatchInputs?.([
      {
        text: "Image file: diagram.png",
        parts: [
          { type: "text", text: "Image file: diagram.png" },
          { type: "inline-data", mimeType: "image/png", data: "img" },
        ],
      },
      {
        text: "Audio file: note.wav",
        parts: [
          { type: "text", text: "Audio file: note.wav" },
          { type: "inline-data", mimeType: "audio/wav", data: "aud" },
        ],
      },
    ]);
    const body = parseFetchBody(fetchMock);
    expect(body.requests).toEqual([
      {
        model: "models/gemini-embedding-2-preview",
        content: {
          parts: [
            { text: "Image file: diagram.png" },
            { inlineData: { mimeType: "image/png", data: "img" } },
          ],
        },
        taskType: "RETRIEVAL_DOCUMENT",
        outputDimensionality: 3072,
      },
      {
        model: "models/gemini-embedding-2-preview",
        content: {
          parts: [
            { text: "Audio file: note.wav" },
            { inlineData: { mimeType: "audio/wav", data: "aud" } },
          ],
        },
        taskType: "RETRIEVAL_DOCUMENT",
        outputDimensionality: 3072,
      },
    ]);
  });
  it("throws for invalid outputDimensionality", async () => {
    mockResolvedProviderKey();
    await expect(
      createGeminiEmbeddingProvider({
        config: {} as never,
        provider: "gemini",
        model: "gemini-embedding-2-preview",
        fallback: "none",
        outputDimensionality: 512,
      }),
    ).rejects.toThrow(/Invalid outputDimensionality 512/);
  });
  it("sanitizes non-finite values before normalization", async () => {
    const fetchMock = createGeminiFetchMock([
      1,
      Number.NaN,
      Number.POSITIVE_INFINITY,
      Number.NEGATIVE_INFINITY,
    ]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    const embedding = await provider.embedQuery("test");
    expect(embedding).toEqual([1, 0, 0, 0]);
  });
  it("uses correct endpoint URL", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });
    await provider.embedQuery("test");
    const { url } = readFirstFetchRequest(fetchMock);
    expect(url).toBe(
      "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:embedContent",
    );
  });
  it("allows taskType override via options", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
      taskType: "SEMANTIC_SIMILARITY",
    });
    await provider.embedQuery("test");
    const body = parseFetchBody(fetchMock);
    expect(body.taskType).toBe("SEMANTIC_SIMILARITY");
  });
});
// ---------- Model normalization ----------
describe("gemini model normalization", () => {
  // All supported prefixes (models/, gemini/, google/) must resolve to the
  // same underlying model so v2 behavior is applied consistently.
  it("handles models/ prefix for v2 model", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey();
    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "models/gemini-embedding-2-preview",
      fallback: "none",
    });
    await provider.embedQuery("test");
    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(3072);
  });
  it("handles gemini/ prefix for v2 model", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey();
    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "gemini/gemini-embedding-2-preview",
      fallback: "none",
    });
    await provider.embedQuery("test");
    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(3072);
  });
  it("handles google/ prefix for v2 model", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey();
    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "google/gemini-embedding-2-preview",
      fallback: "none",
    });
    await provider.embedQuery("test");
    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(3072);
  });
  it("defaults to gemini-embedding-001 when model is empty", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockResolvedProviderKey();
    const { provider, client } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "",
      fallback: "none",
    });
    expect(client.model).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL);
    expect(provider.model).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL);
  });
  it("returns empty array for blank query text", async () => {
    mockResolvedProviderKey();
    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "gemini-embedding-2-preview",
      fallback: "none",
    });
    const result = await provider.embedQuery(" ");
    expect(result).toEqual([]);
  });
  it("returns empty array for empty batch", async () => {
    mockResolvedProviderKey();
    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "gemini-embedding-2-preview",
      fallback: "none",
    });
    const result = await provider.embedBatch([]);
    expect(result).toEqual([]);
  });
});

View File

@@ -1,336 +0,0 @@
import {
collectProviderApiKeysForExecution,
executeWithApiKeyRotation,
} from "../../agents/api-key-rotation.js";
import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js";
import { parseGeminiAuth } from "../../infra/gemini-auth.js";
import {
DEFAULT_GOOGLE_API_BASE_URL,
normalizeGoogleApiBaseUrl,
} from "../../infra/google-api-base-url.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import type { EmbeddingInput } from "./embedding-inputs.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";
// Resolved connection settings for the Gemini embeddings REST API.
export type GeminiEmbeddingClient = {
  baseUrl: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  model: string;
  // Model identifier with the "models/" prefix, as required by the API.
  modelPath: string;
  // One or more keys; requests rotate through them on auth failures.
  apiKeys: string[];
  outputDimensionality?: number;
};
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
// Per-model max input sizes where they differ from the provider default.
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
  "text-embedding-004": 2048,
};
// --- gemini-embedding-2-preview support ---
export const GEMINI_EMBEDDING_2_MODELS = new Set([
  "gemini-embedding-2-preview",
  // Add the GA model name here once released.
]);
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
// Task types accepted by the Gemini embedContent API.
export type GeminiTaskType =
  | "RETRIEVAL_QUERY"
  | "RETRIEVAL_DOCUMENT"
  | "SEMANTIC_SIMILARITY"
  | "CLASSIFICATION"
  | "CLUSTERING"
  | "QUESTION_ANSWERING"
  | "FACT_VERIFICATION";
export type GeminiTextPart = { text: string };
export type GeminiInlinePart = {
  inlineData: { mimeType: string; data: string };
};
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
// Wire shape of a single embedContent request.
export type GeminiEmbeddingRequest = {
  content: { parts: GeminiPart[] };
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  model?: string;
};
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
export function buildGeminiTextEmbeddingRequest(params: {
  text: string;
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  modelPath?: string;
}): GeminiTextEmbeddingRequest {
  // Delegate to the structured builder with a single implicit text part.
  const { text, taskType, outputDimensionality, modelPath } = params;
  return buildGeminiEmbeddingRequest({
    input: { text },
    taskType,
    outputDimensionality,
    modelPath,
  });
}
export function buildGeminiEmbeddingRequest(params: {
input: EmbeddingInput;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiEmbeddingRequest {
const request: GeminiEmbeddingRequest = {
content: {
parts: params.input.parts?.map((part) =>
part.type === "text"
? ({ text: part.text } satisfies GeminiTextPart)
: ({
inlineData: { mimeType: part.mimeType, data: part.data },
} satisfies GeminiInlinePart),
) ?? [{ text: params.input.text }],
},
taskType: params.taskType,
};
if (params.modelPath) {
request.model = params.modelPath;
}
if (params.outputDimensionality != null) {
request.outputDimensionality = params.outputDimensionality;
}
return request;
}
/**
 * Returns true if the given model name is a gemini-embedding-2 variant that
 * supports `outputDimensionality` and extended task types.
 */
export function isGeminiEmbedding2Model(model: string): boolean {
  const v2Models: ReadonlySet<string> = GEMINI_EMBEDDING_2_MODELS;
  return v2Models.has(model);
}
/**
 * Validate and return the `outputDimensionality` for gemini-embedding-2 models.
 * Returns `undefined` for older models (they don't support the param).
 */
export function resolveGeminiOutputDimensionality(
  model: string,
  requested?: number,
): number | undefined {
  if (!isGeminiEmbedding2Model(model)) {
    return undefined;
  }
  if (requested == null) {
    return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
  }
  const allowed: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
  if (allowed.includes(requested)) {
    return requested;
  }
  throw new Error(
    `Invalid outputDimensionality ${requested} for ${model}. Valid values: ${allowed.join(", ")}`,
  );
}
/**
 * Resolves the remote memory-search API key. The literal values
 * "GOOGLE_API_KEY" / "GEMINI_API_KEY" act as indirection: the actual key is
 * read from the environment variable of that name.
 */
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
  const secret = resolveMemorySecretInputString({
    value: remoteApiKey,
    path: "agents.*.memorySearch.remote.apiKey",
  });
  if (!secret) {
    return undefined;
  }
  const isEnvVarName = secret === "GOOGLE_API_KEY" || secret === "GEMINI_API_KEY";
  return isEnvVarName ? process.env[secret]?.trim() : secret;
}
/**
 * Normalizes a configured Gemini embedding model name.
 *
 * Accepts values with an optional "models/" path prefix and/or a "gemini/" or
 * "google/" routing prefix, returning the bare model id. Blank input — and
 * input that is empty after prefix stripping, e.g. a bare "models/" — falls
 * back to DEFAULT_GEMINI_EMBEDDING_MODEL. (Previously a bare prefix leaked an
 * empty string through to the caller.)
 */
export function normalizeGeminiModel(model: string): string {
  const trimmed = model.trim();
  if (!trimmed) {
    return DEFAULT_GEMINI_EMBEDDING_MODEL;
  }
  let normalized = trimmed.replace(/^models\//, "");
  if (normalized.startsWith("gemini/")) {
    normalized = normalized.slice("gemini/".length);
  } else if (normalized.startsWith("google/")) {
    normalized = normalized.slice("google/".length);
  }
  // A bare prefix leaves nothing usable behind; fall back to the default.
  return normalized || DEFAULT_GEMINI_EMBEDDING_MODEL;
}
/**
 * POSTs an embedding request body to the given Gemini endpoint and returns
 * the raw single (`embedding`) or batch (`embeddings`) payload.
 *
 * Iterates the client's configured API keys via `executeWithApiKeyRotation`;
 * each attempt builds auth headers for the current key with `parseGeminiAuth`
 * and lets client-level header overrides win over auth headers. Non-2xx
 * responses surface as errors carrying the status and body text.
 */
async function fetchGeminiEmbeddingPayload(params: {
  client: GeminiEmbeddingClient;
  endpoint: string;
  body: unknown;
}): Promise<{
  embedding?: { values?: number[] };
  embeddings?: Array<{ values?: number[] }>;
}> {
  return await executeWithApiKeyRotation({
    provider: "google",
    apiKeys: params.client.apiKeys,
    execute: async (apiKey) => {
      const authHeaders = parseGeminiAuth(apiKey);
      // Client overrides are spread last so they take precedence.
      const headers = {
        ...authHeaders.headers,
        ...params.client.headers,
      };
      return await withRemoteHttpResponse({
        url: params.endpoint,
        ssrfPolicy: params.client.ssrfPolicy,
        init: {
          method: "POST",
          headers,
          body: JSON.stringify(params.body),
        },
        onResponse: async (res) => {
          if (!res.ok) {
            const text = await res.text();
            throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
          }
          return (await res.json()) as {
            embedding?: { values?: number[] };
            embeddings?: Array<{ values?: number[] }>;
          };
        },
      });
    },
  });
}
/**
 * Normalizes a configured Gemini base URL: strips trailing slashes and any
 * OpenAI-compat suffix (everything from "/openai" onward), then applies the
 * shared Google API base-URL normalization.
 */
function normalizeGeminiBaseUrl(raw: string): string {
  let base = raw.replace(/\/+$/, "");
  const compatSuffixAt = base.indexOf("/openai");
  if (compatSuffixAt !== -1) {
    base = base.slice(0, compatSuffixAt);
  }
  return normalizeGoogleApiBaseUrl(base);
}
/** Prefixes a bare model id with "models/" unless it is already qualified. */
function buildGeminiModelPath(model: string): string {
  if (model.startsWith("models/")) {
    return model;
  }
  return `models/${model}`;
}
/**
 * Creates the Gemini embedding provider.
 *
 * Resolves the HTTP client once, then exposes:
 * - `embedQuery`: one text via `:embedContent` (RETRIEVAL_QUERY default);
 *   blank text short-circuits to an empty vector with no network call.
 * - `embedBatchInputs`: multimodal inputs via `:batchEmbedContents`
 *   (RETRIEVAL_DOCUMENT default); missing rows become empty vectors.
 * - `embedBatch`: convenience wrapper mapping plain strings onto inputs.
 * `outputDimensionality` is only sent for gemini-embedding-2 models.
 */
export async function createGeminiEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
  const client = await resolveGeminiEmbeddingClient(options);
  const baseUrl = client.baseUrl.replace(/\/$/, "");
  const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
  const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
  const isV2 = isGeminiEmbedding2Model(client.model);
  const outputDimensionality = client.outputDimensionality;
  const embedQuery = async (text: string): Promise<number[]> => {
    if (!text.trim()) {
      return [];
    }
    const payload = await fetchGeminiEmbeddingPayload({
      client,
      endpoint: embedUrl,
      body: buildGeminiTextEmbeddingRequest({
        text,
        taskType: options.taskType ?? "RETRIEVAL_QUERY",
        outputDimensionality: isV2 ? outputDimensionality : undefined,
      }),
    });
    return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
  };
  const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
    if (inputs.length === 0) {
      return [];
    }
    const payload = await fetchGeminiEmbeddingPayload({
      client,
      endpoint: batchUrl,
      body: {
        // Batch requests carry the model per inner request (modelPath),
        // unlike the single embedContent call where it is in the URL.
        requests: inputs.map((input) =>
          buildGeminiEmbeddingRequest({
            input,
            modelPath: client.modelPath,
            taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
            outputDimensionality: isV2 ? outputDimensionality : undefined,
          }),
        ),
      },
    });
    const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
    // Align results by index with the inputs; absent rows become [].
    return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
  };
  const embedBatch = async (texts: string[]): Promise<number[][]> => {
    return await embedBatchInputs(
      texts.map((text) => ({
        text,
      })),
    );
  };
  return {
    provider: {
      id: "gemini",
      model: client.model,
      maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
      embedQuery,
      embedBatch,
      embedBatchInputs,
    },
    client,
  };
}
/**
 * Resolves the Gemini embedding HTTP client from config.
 *
 * Precedence: memory-search remote overrides (apiKey/baseUrl/headers) win
 * over `models.providers.google` config, which wins over built-in defaults.
 * When no remote key is supplied, the regular Google API-key resolution runs
 * and must succeed. Also normalizes the model name, derives its request path,
 * validates `outputDimensionality`, and emits a debug log of the result.
 */
export async function resolveGeminiEmbeddingClient(
  options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
  const remote = options.remote;
  const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
  const remoteBaseUrl = remote?.baseUrl?.trim();
  const apiKey = remoteApiKey
    ? remoteApiKey
    : requireApiKey(
        await resolveApiKeyForProvider({
          provider: "google",
          cfg: options.config,
          agentDir: options.agentDir,
        }),
        "google",
      );
  const providerConfig = options.config.models?.providers?.google;
  const rawBaseUrl =
    remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GOOGLE_API_BASE_URL;
  const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
  const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
  // Remote header overrides win over provider-config headers.
  const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
  const headers: Record<string, string> = {
    ...headerOverrides,
  };
  const apiKeys = collectProviderApiKeysForExecution({
    provider: "google",
    primaryApiKey: apiKey,
  });
  const model = normalizeGeminiModel(options.model);
  const modelPath = buildGeminiModelPath(model);
  const outputDimensionality = resolveGeminiOutputDimensionality(
    model,
    options.outputDimensionality,
  );
  debugEmbeddingsLog("memory embeddings: gemini client", {
    rawBaseUrl,
    baseUrl,
    model,
    modelPath,
    outputDimensionality,
    embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
    batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
  });
  return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
}

View File

@@ -1,19 +0,0 @@
import { describe, expect, it } from "vitest";
import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js";
// Unit tests for normalizeMistralModel: default fallback for blank input,
// "mistral/" routing-prefix stripping (with trimming), and passthrough of
// already-normalized names.
describe("normalizeMistralModel", () => {
  it("returns the default model for empty values", () => {
    expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
    expect(normalizeMistralModel("  ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
  });
  it("strips the mistral/ prefix", () => {
    expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed");
    expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed");
  });
  it("keeps explicit non-prefixed models", () => {
    expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed");
    expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2");
  });
});

View File

@@ -1,51 +0,0 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import {
createRemoteEmbeddingProvider,
resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
/** HTTP client configuration for the Mistral embeddings endpoint. */
export type MistralEmbeddingClient = {
  baseUrl: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  model: string;
};
export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed";
const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1";
/**
 * Normalizes a configured Mistral model name: trims, strips an optional
 * "mistral/" routing prefix, and falls back to the default when blank.
 */
export function normalizeMistralModel(model: string): string {
  return normalizeEmbeddingModelWithPrefixes({
    model,
    prefixes: ["mistral/"],
    defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
  });
}
/**
 * Creates the Mistral embedding provider by resolving the bearer client and
 * wrapping it in the shared OpenAI-compatible remote implementation.
 */
export async function createMistralEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
  const client = await resolveMistralEmbeddingClient(options);
  const provider = createRemoteEmbeddingProvider({
    id: "mistral",
    client,
    errorPrefix: "mistral embeddings failed",
  });
  return { provider, client };
}
/** Resolves base URL, bearer auth headers and model for the Mistral API. */
export async function resolveMistralEmbeddingClient(
  options: EmbeddingProviderOptions,
): Promise<MistralEmbeddingClient> {
  const client = await resolveRemoteEmbeddingClient({
    provider: "mistral",
    options,
    defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL,
    normalizeModel: normalizeMistralModel,
  });
  return client;
}

View File

@@ -1,34 +0,0 @@
import { describe, expect, it } from "vitest";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
// Unit tests for the shared prefix-based model normalizer: default fallback,
// first-match prefix stripping, and passthrough when nothing matches.
describe("normalizeEmbeddingModelWithPrefixes", () => {
  it("returns default model when input is blank", () => {
    expect(
      normalizeEmbeddingModelWithPrefixes({
        model: " ",
        defaultModel: "fallback-model",
        prefixes: ["openai/"],
      }),
    ).toBe("fallback-model");
  });
  it("strips the first matching prefix", () => {
    expect(
      normalizeEmbeddingModelWithPrefixes({
        model: "openai/text-embedding-3-small",
        defaultModel: "fallback-model",
        prefixes: ["openai/"],
      }),
    ).toBe("text-embedding-3-small");
  });
  it("keeps explicit model names when no prefix matches", () => {
    expect(
      normalizeEmbeddingModelWithPrefixes({
        model: "voyage-4-large",
        defaultModel: "fallback-model",
        prefixes: ["voyage/"],
      }),
    ).toBe("voyage-4-large");
  });
});

View File

@@ -1,16 +0,0 @@
/**
 * Shared model-name normalization: trims the input, strips the first matching
 * routing prefix (e.g. "openai/"), and falls back to `defaultModel` when the
 * input is blank. Only one prefix is ever removed.
 */
export function normalizeEmbeddingModelWithPrefixes(params: {
  model: string;
  defaultModel: string;
  prefixes: string[];
}): string {
  const trimmed = params.model.trim();
  if (!trimmed) {
    return params.defaultModel;
  }
  const matched = params.prefixes.find((prefix) => trimmed.startsWith(prefix));
  return matched ? trimmed.slice(matched.length) : trimmed;
}

View File

@@ -1,146 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, it, expect, vi } from "vitest";
import type { OpenClawConfig } from "../../config/config.js";
// Lazily imported so module evaluation happens after hooks are registered.
let createOllamaEmbeddingProvider: typeof import("./embeddings-ollama.js").createOllamaEmbeddingProvider;
beforeAll(async () => {
  ({ createOllamaEmbeddingProvider } = await import("./embeddings-ollama.js"));
});
beforeEach(() => {
  vi.useRealTimers();
  vi.doUnmock("undici");
});
afterEach(() => {
  vi.doUnmock("undici");
  vi.unstubAllGlobals();
  vi.unstubAllEnvs();
  vi.resetAllMocks();
});
// Covers endpoint selection, config precedence for baseUrl/apiKey/headers,
// SecretRef failure modes, and env-key fallback.
describe("embeddings-ollama", () => {
  it("calls /api/embeddings and returns normalized vectors", async () => {
    const fetchMock = vi.fn(
      async () =>
        new Response(JSON.stringify({ embedding: [3, 4] }), {
          status: 200,
          headers: { "content-type": "application/json" },
        }),
    );
    globalThis.fetch = fetchMock as unknown as typeof fetch;
    const { provider } = await createOllamaEmbeddingProvider({
      config: {} as OpenClawConfig,
      provider: "ollama",
      model: "nomic-embed-text",
      fallback: "none",
      remote: { baseUrl: "http://127.0.0.1:11434" },
    });
    const v = await provider.embedQuery("hi");
    expect(fetchMock).toHaveBeenCalledTimes(1);
    // normalized [3,4] => [0.6,0.8]
    expect(v[0]).toBeCloseTo(0.6, 5);
    expect(v[1]).toBeCloseTo(0.8, 5);
  });
  it("resolves baseUrl/apiKey/headers from models.providers.ollama and strips /v1", async () => {
    const fetchMock = vi.fn(
      async () =>
        new Response(JSON.stringify({ embedding: [1, 0] }), {
          status: 200,
          headers: { "content-type": "application/json" },
        }),
    );
    globalThis.fetch = fetchMock as unknown as typeof fetch;
    const { provider } = await createOllamaEmbeddingProvider({
      config: {
        models: {
          providers: {
            ollama: {
              baseUrl: "http://127.0.0.1:11434/v1",
              apiKey: "ollama-\nlocal\r\n", // pragma: allowlist secret
              headers: {
                "X-Provider-Header": "provider",
              },
            },
          },
        },
      } as unknown as OpenClawConfig,
      provider: "ollama",
      model: "",
      fallback: "none",
    });
    await provider.embedQuery("hello");
    expect(fetchMock).toHaveBeenCalledWith(
      "http://127.0.0.1:11434/api/embeddings",
      expect.objectContaining({
        method: "POST",
        headers: expect.objectContaining({
          "Content-Type": "application/json",
          Authorization: "Bearer ollama-local",
          "X-Provider-Header": "provider",
        }),
      }),
    );
  });
  it("fails fast when memory-search remote apiKey is an unresolved SecretRef", async () => {
    await expect(
      createOllamaEmbeddingProvider({
        config: {} as OpenClawConfig,
        provider: "ollama",
        model: "nomic-embed-text",
        fallback: "none",
        remote: {
          baseUrl: "http://127.0.0.1:11434",
          apiKey: { source: "env", provider: "default", id: "OLLAMA_API_KEY" },
        },
      }),
    ).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey: unresolved SecretRef/i);
  });
  it("falls back to env key when models.providers.ollama.apiKey is an unresolved SecretRef", async () => {
    const fetchMock = vi.fn(
      async () =>
        new Response(JSON.stringify({ embedding: [1, 0] }), {
          status: 200,
          headers: { "content-type": "application/json" },
        }),
    );
    globalThis.fetch = fetchMock as unknown as typeof fetch;
    vi.stubEnv("OLLAMA_API_KEY", "ollama-env");
    const { provider } = await createOllamaEmbeddingProvider({
      config: {
        models: {
          providers: {
            ollama: {
              baseUrl: "http://127.0.0.1:11434/v1",
              apiKey: { source: "env", provider: "default", id: "OLLAMA_API_KEY" },
              models: [],
            },
          },
        },
      } as unknown as OpenClawConfig,
      provider: "ollama",
      model: "nomic-embed-text",
      fallback: "none",
    });
    await provider.embedQuery("hello");
    expect(fetchMock).toHaveBeenCalledWith(
      "http://127.0.0.1:11434/api/embeddings",
      expect.objectContaining({
        headers: expect.objectContaining({
          Authorization: "Bearer ollama-env",
        }),
      }),
    );
  });
});

View File

@@ -1,123 +0,0 @@
import { resolveEnvApiKey } from "../../agents/model-auth.js";
import { resolveOllamaApiBase } from "../../agents/ollama-models.js";
import { formatErrorMessage } from "../../infra/errors.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { normalizeOptionalSecretInput } from "../../utils/normalize-secret-input.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";
/** HTTP client configuration plus batch embedder for the Ollama API. */
export type OllamaEmbeddingClient = {
  baseUrl: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  model: string;
  embedBatch: (texts: string[]) => Promise<number[][]>;
};
/** Client configuration before the batch embedder is attached. */
type OllamaEmbeddingClientConfig = Omit<OllamaEmbeddingClient, "embedBatch">;
export const DEFAULT_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text";
/**
 * Normalizes a configured Ollama model name: trims, strips an optional
 * "ollama/" routing prefix, and falls back to the default when blank.
 */
function normalizeOllamaModel(model: string): string {
  return normalizeEmbeddingModelWithPrefixes({
    model,
    prefixes: ["ollama/"],
    defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
  });
}
/**
 * Resolves the Ollama API key with precedence:
 * memory-search remote override → models.providers.ollama.apiKey →
 * environment lookup. Returns undefined when nothing resolves (Ollama
 * commonly runs unauthenticated locally).
 */
function resolveOllamaApiKey(options: EmbeddingProviderOptions): string | undefined {
  const remoteKey = resolveMemorySecretInputString({
    value: options.remote?.apiKey,
    path: "agents.*.memorySearch.remote.apiKey",
  });
  if (remoteKey) {
    return remoteKey;
  }
  const configuredKey = normalizeOptionalSecretInput(
    options.config.models?.providers?.ollama?.apiKey,
  );
  return configuredKey || resolveEnvApiKey("ollama")?.apiKey;
}
/**
 * Builds the Ollama client config: base URL (remote override → provider
 * config → default), normalized model, merged headers (provider headers
 * first, remote overrides last), and bearer auth when an API key resolves.
 */
function resolveOllamaEmbeddingClient(
  options: EmbeddingProviderOptions,
): OllamaEmbeddingClientConfig {
  const providerConfig = options.config.models?.providers?.ollama;
  const baseUrl = resolveOllamaApiBase(
    options.remote?.baseUrl?.trim() || providerConfig?.baseUrl?.trim(),
  );
  const headers: Record<string, string> = {
    "Content-Type": "application/json",
    ...providerConfig?.headers,
    ...options.remote?.headers,
  };
  const apiKey = resolveOllamaApiKey(options);
  if (apiKey) {
    headers.Authorization = `Bearer ${apiKey}`;
  }
  return {
    baseUrl,
    headers,
    ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl),
    model: normalizeOllamaModel(options.model),
  };
}
/**
 * Creates the Ollama embedding provider.
 *
 * Uses the `/api/embeddings` endpoint, which accepts exactly one prompt per
 * request, so batches fan out as parallel single-prompt calls. Vectors are
 * sanitized and L2-normalized before being returned. The client's
 * `embedBatch` additionally flattens failures into readable messages via
 * `formatErrorMessage` while preserving the original error as `cause`.
 */
export async function createOllamaEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OllamaEmbeddingClient }> {
  const client = resolveOllamaEmbeddingClient(options);
  const embedUrl = `${client.baseUrl.replace(/\/$/, "")}/api/embeddings`;
  // Embeds a single prompt and normalizes the resulting vector.
  const embedOne = async (text: string): Promise<number[]> => {
    const json = await withRemoteHttpResponse({
      url: embedUrl,
      ssrfPolicy: client.ssrfPolicy,
      init: {
        method: "POST",
        headers: client.headers,
        body: JSON.stringify({ model: client.model, prompt: text }),
      },
      onResponse: async (res) => {
        if (!res.ok) {
          throw new Error(`Ollama embeddings HTTP ${res.status}: ${await res.text()}`);
        }
        return (await res.json()) as { embedding?: number[] };
      },
    });
    if (!Array.isArray(json.embedding)) {
      throw new Error(`Ollama embeddings response missing embedding[]`);
    }
    return sanitizeAndNormalizeEmbedding(json.embedding);
  };
  const provider: EmbeddingProvider = {
    id: "ollama",
    model: client.model,
    embedQuery: embedOne,
    embedBatch: async (texts: string[]) => {
      // Ollama /api/embeddings accepts one prompt per request.
      return await Promise.all(texts.map(embedOne));
    },
  };
  return {
    provider,
    client: {
      ...client,
      embedBatch: async (texts) => {
        try {
          return await provider.embedBatch(texts);
        } catch (err) {
          // Re-wrap so callers get a readable message but keep the cause.
          throw new Error(formatErrorMessage(err), { cause: err });
        }
      },
    },
  };
}

View File

@@ -1,58 +0,0 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../plugins/provider-model-defaults.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import {
createRemoteEmbeddingProvider,
resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
/** HTTP client configuration for the OpenAI embeddings endpoint. */
export type OpenAiEmbeddingClient = {
  baseUrl: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  model: string;
};
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
export const DEFAULT_OPENAI_EMBEDDING_MODEL = OPENAI_DEFAULT_EMBEDDING_MODEL;
/** Known per-model input limits; unknown models get no limit attached. */
const OPENAI_MAX_INPUT_TOKENS: Record<string, number> = {
  "text-embedding-3-small": 8192,
  "text-embedding-3-large": 8192,
  "text-embedding-ada-002": 8191,
};
/**
 * Normalizes a configured OpenAI model name: trims, strips an optional
 * "openai/" routing prefix, and falls back to the default when blank.
 */
export function normalizeOpenAiModel(model: string): string {
  return normalizeEmbeddingModelWithPrefixes({
    model,
    prefixes: ["openai/"],
    defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
  });
}
/**
 * Creates the OpenAI embedding provider on top of the shared remote
 * implementation, attaching known per-model token limits.
 */
export async function createOpenAiEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
  const client = await resolveOpenAiEmbeddingClient(options);
  const provider = createRemoteEmbeddingProvider({
    id: "openai",
    client,
    errorPrefix: "openai embeddings failed",
    maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model],
  });
  return { provider, client };
}
/** Resolves base URL, bearer auth headers and model for the OpenAI API. */
export async function resolveOpenAiEmbeddingClient(
  options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> {
  const client = await resolveRemoteEmbeddingClient({
    provider: "openai",
    options,
    defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
    normalizeModel: normalizeOpenAiModel,
  });
  return client;
}

View File

@@ -1,39 +0,0 @@
import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import type { EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";
/** Remote embedding providers that authenticate with a plain bearer token. */
export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral";
/**
 * Resolves base URL, auth and headers for a bearer-token embedding API.
 * Precedence: memory-search remote overrides → models.providers.<id> config
 * → built-in defaults. When no remote key is supplied, the provider's
 * regular API-key resolution is used and must succeed.
 */
export async function resolveRemoteEmbeddingBearerClient(params: {
  provider: RemoteEmbeddingProviderId;
  options: EmbeddingProviderOptions;
  defaultBaseUrl: string;
}): Promise<{ baseUrl: string; headers: Record<string, string>; ssrfPolicy?: SsrFPolicy }> {
  const { options } = params;
  const remote = options.remote;
  const overrideKey = resolveMemorySecretInputString({
    value: remote?.apiKey,
    path: "agents.*.memorySearch.remote.apiKey",
  });
  const providerConfig = options.config.models?.providers?.[params.provider];
  // Only fall back to the provider's API-key resolution when no override.
  const apiKey = overrideKey
    ? overrideKey
    : requireApiKey(
        await resolveApiKeyForProvider({
          provider: params.provider,
          cfg: options.config,
          agentDir: options.agentDir,
        }),
        params.provider,
      );
  const baseUrl =
    remote?.baseUrl?.trim() || providerConfig?.baseUrl?.trim() || params.defaultBaseUrl;
  // Overrides are spread last: provider-config headers, then remote headers.
  const headers: Record<string, string> = {
    "Content-Type": "application/json",
    Authorization: `Bearer ${apiKey}`,
    ...providerConfig?.headers,
    ...remote?.headers,
  };
  return { baseUrl, headers, ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl) };
}

View File

@@ -1,56 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
// Hoisted so the vi.doMock factory below can reference it safely.
const postJsonMock = vi.hoisted(() => vi.fn());
type EmbeddingsRemoteFetchModule = typeof import("./embeddings-remote-fetch.js");
let fetchRemoteEmbeddingVectors: EmbeddingsRemoteFetchModule["fetchRemoteEmbeddingVectors"];
// Verifies response mapping (data[i].embedding → vectors, missing → []) and
// error pass-through from the underlying postJson helper.
describe("fetchRemoteEmbeddingVectors", () => {
  beforeEach(async () => {
    vi.resetModules();
    vi.doMock("./post-json.js", () => ({
      postJson: postJsonMock,
    }));
    ({ fetchRemoteEmbeddingVectors } = await import("./embeddings-remote-fetch.js"));
    postJsonMock.mockReset();
  });
  it("maps remote embedding response data to vectors", async () => {
    postJsonMock.mockImplementationOnce(async (params) => {
      return await params.parse({
        data: [{ embedding: [0.1, 0.2] }, {}, { embedding: [0.3] }],
      });
    });
    const vectors = await fetchRemoteEmbeddingVectors({
      url: "https://memory.example/v1/embeddings",
      headers: { Authorization: "Bearer test" },
      body: { input: ["one", "two", "three"] },
      errorPrefix: "embedding fetch failed",
    });
    expect(vectors).toEqual([[0.1, 0.2], [], [0.3]]);
    expect(postJsonMock).toHaveBeenCalledWith(
      expect.objectContaining({
        url: "https://memory.example/v1/embeddings",
        headers: { Authorization: "Bearer test" },
        body: { input: ["one", "two", "three"] },
        errorPrefix: "embedding fetch failed",
      }),
    );
  });
  it("throws a status-rich error on non-ok responses", async () => {
    postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden"));
    await expect(
      fetchRemoteEmbeddingVectors({
        url: "https://memory.example/v1/embeddings",
        headers: {},
        body: { input: ["one"] },
        errorPrefix: "embedding fetch failed",
      }),
    ).rejects.toThrow("embedding fetch failed: 403 forbidden");
  });
});

View File

@@ -1,25 +0,0 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { postJson } from "./post-json.js";
/**
 * POSTs an OpenAI-style embeddings request and maps `data[i].embedding` to
 * one vector per input; entries without an embedding become empty arrays.
 */
export async function fetchRemoteEmbeddingVectors(params: {
  url: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  body: unknown;
  errorPrefix: string;
}): Promise<number[][]> {
  const { url, headers, ssrfPolicy, body, errorPrefix } = params;
  return await postJson({
    url,
    headers,
    ssrfPolicy,
    body,
    errorPrefix,
    parse: (payload) => {
      const entries =
        (payload as { data?: Array<{ embedding?: number[] }> }).data ?? [];
      return entries.map((entry) => entry.embedding ?? []);
    },
  });
}

View File

@@ -1,63 +0,0 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import {
resolveRemoteEmbeddingBearerClient,
type RemoteEmbeddingProviderId,
} from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
/** Shared HTTP client configuration for OpenAI-compatible embedding APIs. */
export type RemoteEmbeddingClient = {
  baseUrl: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  model: string;
};
/**
 * Wraps an OpenAI-compatible `/embeddings` endpoint as an EmbeddingProvider.
 * Queries are sent as a one-element batch; empty batches short-circuit
 * without a network call. `maxInputTokens` is attached only when known.
 */
export function createRemoteEmbeddingProvider(params: {
  id: string;
  client: RemoteEmbeddingClient;
  errorPrefix: string;
  maxInputTokens?: number;
}): EmbeddingProvider {
  const endpoint = `${params.client.baseUrl.replace(/\/$/, "")}/embeddings`;
  const embedMany = async (input: string[]): Promise<number[][]> => {
    if (input.length === 0) {
      return [];
    }
    return await fetchRemoteEmbeddingVectors({
      url: endpoint,
      headers: params.client.headers,
      ssrfPolicy: params.client.ssrfPolicy,
      body: { model: params.client.model, input },
      errorPrefix: params.errorPrefix,
    });
  };
  return {
    id: params.id,
    model: params.client.model,
    ...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}),
    embedQuery: async (text) => {
      const vectors = await embedMany([text]);
      return vectors[0] ?? [];
    },
    embedBatch: embedMany,
  };
}
/**
 * Resolves bearer client config for a remote provider and attaches the
 * normalized model name.
 */
export async function resolveRemoteEmbeddingClient(params: {
  provider: RemoteEmbeddingProviderId;
  options: EmbeddingProviderOptions;
  defaultBaseUrl: string;
  normalizeModel: (model: string) => string;
}): Promise<RemoteEmbeddingClient> {
  const bearer = await resolveRemoteEmbeddingBearerClient({
    provider: params.provider,
    options: params.options,
    defaultBaseUrl: params.defaultBaseUrl,
  });
  return { ...bearer, model: params.normalizeModel(params.options.model) };
}

View File

@@ -1,153 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
// Module-level mock so auth resolution is controllable per test.
vi.mock("../../agents/model-auth.js", async () => {
  const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js");
  return createModelAuthMockModule();
});
// Default fetch stub returning a single 3-dim embedding.
const createFetchMock = () => {
  const fetchMock = vi.fn<FetchMock>(
    async (_input: RequestInfo | URL, _init?: RequestInit) =>
      new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), {
        status: 200,
        headers: { "Content-Type": "application/json" },
      }),
  );
  return withFetchPreconnect(fetchMock);
};
// Re-imported fresh in beforeEach after vi.resetModules().
let authModule: typeof import("../../agents/model-auth.js");
let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
beforeEach(async () => {
  vi.useRealTimers();
  vi.doUnmock("undici");
  vi.resetModules();
  authModule = await import("../../agents/model-auth.js");
  ({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
    await import("./embeddings-voyage.js"));
});
function mockVoyageApiKey() {
  vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
    apiKey: "voyage-key-123",
    mode: "api-key",
    source: "test",
  });
}
// Stubs fetch + SSRF pinning + auth, then builds a provider for `model`.
async function createDefaultVoyageProvider(
  model: string,
  fetchMock: ReturnType<typeof createFetchMock>,
) {
  vi.stubGlobal("fetch", fetchMock);
  mockPublicPinnedHostname();
  mockVoyageApiKey();
  return createVoyageEmbeddingProvider({
    config: {} as never,
    provider: "voyage",
    model,
    fallback: "none",
  });
}
describe("voyage embedding provider", () => {
  afterEach(() => {
    vi.doUnmock("undici");
    vi.resetAllMocks();
    vi.unstubAllGlobals();
  });
  it("configures client with correct defaults and headers", async () => {
    const fetchMock = createFetchMock();
    const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
    await result.provider.embedQuery("test query");
    expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith(
      expect.objectContaining({ provider: "voyage" }),
    );
    const call = fetchMock.mock.calls[0];
    expect(call).toBeDefined();
    const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
    expect(url).toBe("https://api.voyageai.com/v1/embeddings");
    const headers = (init?.headers ?? {}) as Record<string, string>;
    expect(headers.Authorization).toBe("Bearer voyage-key-123");
    expect(headers["Content-Type"]).toBe("application/json");
    const body = JSON.parse(init?.body as string);
    expect(body).toEqual({
      model: "voyage-4-large",
      input: ["test query"],
      input_type: "query",
    });
  });
  it("respects remote overrides for baseUrl and apiKey", async () => {
    const fetchMock = createFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    const result = await createVoyageEmbeddingProvider({
      config: {} as never,
      provider: "voyage",
      model: "voyage-4-lite",
      fallback: "none",
      remote: {
        baseUrl: "https://example.com",
        apiKey: "remote-override-key",
        headers: { "X-Custom": "123" },
      },
    });
    await result.provider.embedQuery("test");
    const call = fetchMock.mock.calls[0];
    expect(call).toBeDefined();
    const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
    expect(url).toBe("https://example.com/embeddings");
    const headers = (init?.headers ?? {}) as Record<string, string>;
    expect(headers.Authorization).toBe("Bearer remote-override-key");
    expect(headers["X-Custom"]).toBe("123");
  });
  it("passes input_type=document for embedBatch", async () => {
    const fetchMock = withFetchPreconnect(
      vi.fn<FetchMock>(
        async (_input: RequestInfo | URL, _init?: RequestInit) =>
          new Response(
            JSON.stringify({
              data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
            }),
            { status: 200, headers: { "Content-Type": "application/json" } },
          ),
      ),
    );
    const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
    await result.provider.embedBatch(["doc1", "doc2"]);
    const call = fetchMock.mock.calls[0];
    expect(call).toBeDefined();
    const [, init] = call as [RequestInfo | URL, RequestInit | undefined];
    const body = JSON.parse(init?.body as string);
    expect(body).toEqual({
      model: "voyage-4-large",
      input: ["doc1", "doc2"],
      input_type: "document",
    });
  });
  it("normalizes model names", async () => {
    expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2");
    expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large");
    expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite");
    expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default
  });
});

View File

@@ -1,82 +0,0 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
/** HTTP client configuration for the Voyage embeddings endpoint. */
export type VoyageEmbeddingClient = {
  baseUrl: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  model: string;
};
export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large";
const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1";
/** Known per-model input limits; unknown models get no limit attached. */
const VOYAGE_MAX_INPUT_TOKENS: Record<string, number> = {
  "voyage-3": 32000,
  "voyage-3-lite": 16000,
  "voyage-code-3": 32000,
};
/**
 * Normalizes a configured Voyage model name: trims, strips an optional
 * "voyage/" routing prefix, and falls back to the default when blank.
 */
export function normalizeVoyageModel(model: string): string {
  return normalizeEmbeddingModelWithPrefixes({
    model,
    prefixes: ["voyage/"],
    defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
  });
}
/**
 * Creates the Voyage embedding provider.
 *
 * Voyage's API distinguishes retrieval roles via `input_type`: queries are
 * embedded with "query" and documents with "document". Empty batches
 * short-circuit without a network call. `maxInputTokens` is only attached
 * for models listed in VOYAGE_MAX_INPUT_TOKENS.
 */
export async function createVoyageEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
  const client = await resolveVoyageEmbeddingClient(options);
  const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
  const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => {
    if (input.length === 0) {
      return [];
    }
    const body: { model: string; input: string[]; input_type?: "query" | "document" } = {
      model: client.model,
      input,
    };
    // Omit the field entirely when no role is specified.
    if (input_type) {
      body.input_type = input_type;
    }
    return await fetchRemoteEmbeddingVectors({
      url,
      headers: client.headers,
      ssrfPolicy: client.ssrfPolicy,
      body,
      errorPrefix: "voyage embeddings failed",
    });
  };
  return {
    provider: {
      id: "voyage",
      model: client.model,
      maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
      embedQuery: async (text) => {
        const [vec] = await embed([text], "query");
        return vec ?? [];
      },
      embedBatch: async (texts) => embed(texts, "document"),
    },
    client,
  };
}
/** Resolves bearer client config and normalized model for the Voyage API. */
export async function resolveVoyageEmbeddingClient(
  options: EmbeddingProviderOptions,
): Promise<VoyageEmbeddingClient> {
  const bearer = await resolveRemoteEmbeddingBearerClient({
    provider: "voyage",
    options,
    defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
  });
  return { ...bearer, model: normalizeVoyageModel(options.model) };
}

View File

@@ -1,740 +0,0 @@
import { setTimeout as sleep } from "node:timers/promises";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
const createFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
}));
const createGeminiFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ embedding: { values: [1, 2, 3] } }),
}));
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
const [url, init] = fetchMock.mock.calls[0] ?? [];
return { url, init: init as RequestInit | undefined };
}
// Typed handles to modules that are re-imported in beforeEach (after
// vi.resetModules) so each test sees a fresh module graph; `typeof import`
// keeps the handles fully typed without a static import.
type EmbeddingsModule = typeof import("./embeddings.js");
type AuthModule = typeof import("../../agents/model-auth.js");
type ResolvedProviderAuth = Awaited<ReturnType<AuthModule["resolveApiKeyForProvider"]>>;
let authModule: AuthModule;
let nodeLlamaModule: typeof import("./node-llama.js");
let createEmbeddingProvider: EmbeddingsModule["createEmbeddingProvider"];
let DEFAULT_LOCAL_MODEL: EmbeddingsModule["DEFAULT_LOCAL_MODEL"];
// Reset the module registry, re-import the modules under test, and install
// spies before embeddings.js is (re-)imported, so mocks from a previous test
// cannot leak into the next one.
beforeEach(async () => {
  vi.resetModules();
  authModule = await import("../../agents/model-auth.js");
  nodeLlamaModule = await import("./node-llama.js");
  vi.spyOn(authModule, "resolveApiKeyForProvider");
  vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
  ({ createEmbeddingProvider, DEFAULT_LOCAL_MODEL } = await import("./embeddings.js"));
});
// Restore all mocks and any stubbed globals (fetch, env vars) between tests.
afterEach(() => {
  vi.resetAllMocks();
  vi.unstubAllGlobals();
});
// Narrows a provider-creation result to its non-null provider, failing the
// test immediately when no provider was produced.
function requireProvider(result: Awaited<ReturnType<EmbeddingsModule["createEmbeddingProvider"]>>) {
  const { provider } = result;
  if (provider) {
    return provider;
  }
  throw new Error("Expected embedding provider");
}
// Makes resolveApiKeyForProvider resolve with a fixed API key for any provider.
function mockResolvedProviderKey(apiKey = "provider-key") {
  const resolved = { apiKey, mode: "api-key" as const, source: "test" };
  vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue(resolved);
}
// Simulates node-llama-cpp being absent: importNodeLlamaCpp rejects with the
// same error shape Node produces for an unresolvable module.
function mockMissingLocalEmbeddingDependency() {
  const missing = new Error("Cannot find package 'node-llama-cpp'") as Error & { code?: string };
  missing.code = "ERR_MODULE_NOT_FOUND";
  vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockRejectedValue(missing);
}
// Requests a local embedding provider with an optional remote fallback.
function createLocalProvider(options?: { fallback?: "none" | "openai" }) {
  const fallback = options?.fallback ?? "none";
  return createEmbeddingProvider({
    config: {} as never,
    provider: "local",
    model: "text-embedding-3-small",
    fallback,
  });
}
// Asserts that "auto" mode picked the expected remote provider and returns it.
function expectAutoSelectedProvider(
  result: Awaited<ReturnType<EmbeddingsModule["createEmbeddingProvider"]>>,
  expectedId: "openai" | "gemini" | "mistral",
) {
  expect(result.requestedProvider).toBe("auto");
  const selected = requireProvider(result);
  expect(selected.id).toBe(expectedId);
  return selected;
}
// Requests provider selection in "auto" mode with no fallback configured.
function createAutoProvider(model = "") {
  const request = {
    config: {} as never,
    provider: "auto",
    model,
    fallback: "none",
  } as const;
  return createEmbeddingProvider(request);
}
// Overrides supplied via memorySearch.remote (baseUrl/apiKey/headers) must
// take precedence over provider-level config while still merging headers.
describe("embedding provider remote overrides", () => {
  it("uses remote baseUrl/apiKey and merges headers", async () => {
    const fetchMock = createFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey("provider-key");
    const cfg = {
      models: {
        providers: {
          openai: {
            baseUrl: "https://api.openai.com/v1",
            headers: {
              "X-Provider": "p",
              "X-Shared": "provider",
            },
          },
        },
      },
    };
    const result = await createEmbeddingProvider({
      config: cfg as never,
      provider: "openai",
      remote: {
        baseUrl: "https://example.com/v1",
        apiKey: " remote-key ",
        headers: {
          "X-Shared": "remote",
          "X-Remote": "r",
        },
      },
      model: "text-embedding-3-small",
      fallback: "openai",
    });
    const provider = requireProvider(result);
    await provider.embedQuery("hello");
    // A remote apiKey short-circuits provider auth resolution entirely.
    expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
    const url = fetchMock.mock.calls[0]?.[0];
    const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
    expect(url).toBe("https://example.com/v1/embeddings");
    const headers = (init?.headers ?? {}) as Record<string, string>;
    // The key is trimmed; remote headers win over provider headers on conflict.
    expect(headers.Authorization).toBe("Bearer remote-key");
    expect(headers["Content-Type"]).toBe("application/json");
    expect(headers["X-Provider"]).toBe("p");
    expect(headers["X-Shared"]).toBe("remote");
    expect(headers["X-Remote"]).toBe("r");
  });
  it("falls back to resolved api key when remote apiKey is blank", async () => {
    const fetchMock = createFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey("provider-key");
    const cfg = {
      models: {
        providers: {
          openai: {
            baseUrl: "https://api.openai.com/v1",
          },
        },
      },
    };
    const result = await createEmbeddingProvider({
      config: cfg as never,
      provider: "openai",
      remote: {
        baseUrl: "https://example.com/v1",
        apiKey: " ",
      },
      model: "text-embedding-3-small",
      fallback: "openai",
    });
    const provider = requireProvider(result);
    await provider.embedQuery("hello");
    // A whitespace-only remote key is ignored; normal auth resolution runs.
    expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1);
    const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
    const headers = (init?.headers as Record<string, string>) ?? {};
    expect(headers.Authorization).toBe("Bearer provider-key");
  });
  it("builds Gemini embeddings requests with api key header", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey("provider-key");
    const cfg = {
      models: {
        providers: {
          google: {
            baseUrl: "https://generativelanguage.googleapis.com/v1beta",
          },
        },
      },
    };
    const result = await createEmbeddingProvider({
      config: cfg as never,
      provider: "gemini",
      remote: {
        apiKey: "gemini-key",
      },
      model: "text-embedding-004",
      fallback: "openai",
    });
    const provider = requireProvider(result);
    await provider.embedQuery("hello");
    const { url, init } = readFirstFetchRequest(fetchMock);
    // Gemini uses the :embedContent endpoint and an x-goog-api-key header
    // instead of OpenAI-style bearer auth.
    expect(url).toBe(
      "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent",
    );
    const headers = (init?.headers ?? {}) as Record<string, string>;
    expect(headers["x-goog-api-key"]).toBe("gemini-key");
    expect(headers["Content-Type"]).toBe("application/json");
  });
  it("fails fast when Gemini remote apiKey is an unresolved SecretRef", async () => {
    await expect(
      createEmbeddingProvider({
        config: {} as never,
        provider: "gemini",
        remote: {
          apiKey: { source: "env", provider: "default", id: "GEMINI_API_KEY" },
        },
        model: "text-embedding-004",
        fallback: "openai",
      }),
    ).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey:/i);
  });
  it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    vi.stubEnv("GEMINI_API_KEY", "env-gemini-key");
    const result = await createEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      remote: {
        apiKey: "GEMINI_API_KEY", // pragma: allowlist secret
      },
      model: "text-embedding-004",
      fallback: "openai",
    });
    const provider = requireProvider(result);
    await provider.embedQuery("hello");
    const { init } = readFirstFetchRequest(fetchMock);
    const headers = (init?.headers ?? {}) as Record<string, string>;
    // An env-var name as apiKey resolves to the variable's value.
    expect(headers["x-goog-api-key"]).toBe("env-gemini-key");
  });
  it("builds Mistral embeddings requests with bearer auth", async () => {
    const fetchMock = createFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey("provider-key");
    const cfg = {
      models: {
        providers: {
          mistral: {
            baseUrl: "https://api.mistral.ai/v1",
          },
        },
      },
    };
    const result = await createEmbeddingProvider({
      config: cfg as never,
      provider: "mistral",
      remote: {
        apiKey: "mistral-key", // pragma: allowlist secret
      },
      model: "mistral/mistral-embed",
      fallback: "none",
    });
    const provider = requireProvider(result);
    await provider.embedQuery("hello");
    const { url, init } = readFirstFetchRequest(fetchMock);
    expect(url).toBe("https://api.mistral.ai/v1/embeddings");
    const headers = (init?.headers ?? {}) as Record<string, string>;
    expect(headers.Authorization).toBe("Bearer mistral-key");
    // The "mistral/" provider prefix is stripped from the wire model name.
    const payload = JSON.parse((init?.body as string | undefined) ?? "{}") as { model?: string };
    expect(payload.model).toBe("mistral-embed");
  });
});
// "auto" mode probes remote providers in a fixed priority order and picks the
// first one whose API key resolves.
describe("embedding provider auto selection", () => {
  it("keeps explicit model when openai is selected", async () => {
    const fetchMock = vi.fn(async (_input?: unknown, _init?: unknown) => ({
      ok: true,
      status: 200,
      json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
    }));
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
      if (provider === "openai") {
        return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
      }
      throw new Error(`Unexpected provider ${provider}`);
    });
    const result = await createEmbeddingProvider({
      config: {} as never,
      provider: "auto",
      model: "text-embedding-3-small",
      fallback: "none",
    });
    expect(result.requestedProvider).toBe("auto");
    const provider = requireProvider(result);
    expect(provider.id).toBe("openai");
    await provider.embedQuery("hello");
    const url = fetchMock.mock.calls[0]?.[0];
    const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
    expect(url).toBe("https://api.openai.com/v1/embeddings");
    // The explicitly requested model survives auto-selection.
    const payload = JSON.parse(init?.body as string) as { model?: string };
    expect(payload.model).toBe("text-embedding-3-small");
  });
  it("selects the first available remote provider in auto mode", async () => {
    // Table-driven: each case mocks which providers have resolvable keys and
    // asserts which provider/endpoint auto mode settles on.
    const cases: Array<{
      name: string;
      expectedProvider: "openai" | "gemini" | "mistral";
      fetchMockFactory: typeof createFetchMock | typeof createGeminiFetchMock;
      resolveApiKey: (provider: string) => ResolvedProviderAuth;
      expectedUrl: string;
    }> = [
      {
        name: "openai first",
        expectedProvider: "openai" as const,
        fetchMockFactory: createFetchMock,
        resolveApiKey(provider: string): ResolvedProviderAuth {
          if (provider === "openai") {
            return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
          }
          throw new Error(`No API key found for provider "${provider}".`);
        },
        expectedUrl: "https://api.openai.com/v1/embeddings",
      },
      {
        name: "gemini fallback",
        expectedProvider: "gemini" as const,
        fetchMockFactory: createGeminiFetchMock,
        resolveApiKey(provider: string): ResolvedProviderAuth {
          if (provider === "openai") {
            throw new Error('No API key found for provider "openai".');
          }
          if (provider === "google") {
            return {
              apiKey: "gemini-key",
              source: "env: GEMINI_API_KEY",
              mode: "api-key" as const,
            };
          }
          throw new Error(`Unexpected provider ${provider}`);
        },
        expectedUrl: `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_EMBEDDING_MODEL}:embedContent`,
      },
      {
        name: "mistral after earlier misses",
        expectedProvider: "mistral" as const,
        fetchMockFactory: createFetchMock,
        resolveApiKey(provider: string): ResolvedProviderAuth {
          if (provider === "mistral") {
            return {
              apiKey: "mistral-key",
              source: "env: MISTRAL_API_KEY",
              mode: "api-key" as const,
            };
          }
          throw new Error(`No API key found for provider "${provider}".`);
        },
        expectedUrl: "https://api.mistral.ai/v1/embeddings",
      },
    ];
    for (const testCase of cases) {
      // Manual reset between iterations — lifecycle hooks only run per test.
      vi.resetAllMocks();
      vi.unstubAllGlobals();
      const fetchMock = testCase.fetchMockFactory();
      vi.stubGlobal("fetch", fetchMock);
      mockPublicPinnedHostname();
      vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) =>
        testCase.resolveApiKey(provider),
      );
      const result = await createAutoProvider();
      const provider = expectAutoSelectedProvider(result, testCase.expectedProvider);
      await provider.embedQuery("hello");
      const [url] = fetchMock.mock.calls[0] ?? [];
      expect(url, testCase.name).toBe(testCase.expectedUrl);
    }
  });
});
// Behavior when the optional local backend (node-llama-cpp) is unavailable:
// either fall back to a remote provider or fail with actionable guidance.
describe("embedding provider local fallback", () => {
  it("falls back to openai when node-llama-cpp is missing", async () => {
    mockMissingLocalEmbeddingDependency();
    const fetchMock = createFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockResolvedProviderKey("provider-key");
    const result = await createLocalProvider({ fallback: "openai" });
    const provider = requireProvider(result);
    expect(provider.id).toBe("openai");
    // The result records why/where the fallback happened.
    expect(result.fallbackFrom).toBe("local");
    expect(result.fallbackReason).toContain("node-llama-cpp");
  });
  it("throws a helpful error when local is requested and fallback is none", async () => {
    mockMissingLocalEmbeddingDependency();
    await expect(createLocalProvider()).rejects.toThrow(/optional dependency node-llama-cpp/i);
  });
  it("mentions every remote provider in local setup guidance", async () => {
    mockMissingLocalEmbeddingDependency();
    await expect(createLocalProvider()).rejects.toThrow(/provider = "gemini"/i);
    await expect(createLocalProvider()).rejects.toThrow(/provider = "mistral"/i);
  });
});
// Local embeddings must be sanitized (non-finite values zeroed) and
// L2-normalized before being returned.
describe("local embedding normalization", () => {
  async function createLocalProviderForTest() {
    return createEmbeddingProvider({
      config: {} as never,
      provider: "local",
      model: "",
      fallback: "none",
    });
  }
  // Stubs node-llama-cpp so every getEmbeddingFor call yields `vector`.
  function mockSingleLocalEmbeddingVector(
    vector: number[],
    resolveModelFile: (modelPath: string, modelDirectory?: string) => Promise<string> = async () =>
      "/fake/model.gguf",
  ): void {
    vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
      getLlama: async () => ({
        loadModel: vi.fn().mockResolvedValue({
          createEmbeddingContext: vi.fn().mockResolvedValue({
            getEmbeddingFor: vi.fn().mockResolvedValue({
              vector: new Float32Array(vector),
            }),
          }),
        }),
      }),
      resolveModelFile,
      LlamaLogLevel: { error: 0 },
    } as never);
  }
  it("normalizes local embeddings to magnitude ~1.0", async () => {
    const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9];
    const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf");
    mockSingleLocalEmbeddingVector(unnormalizedVector, resolveModelFileMock);
    const result = await createLocalProviderForTest();
    const provider = requireProvider(result);
    const embedding = await provider.embedQuery("test query");
    const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
    expect(magnitude).toBeCloseTo(1.0, 5);
    // An empty model option means the default local model is resolved.
    expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
  });
  it("handles zero vector without division by zero", async () => {
    const zeroVector = [0, 0, 0, 0];
    mockSingleLocalEmbeddingVector(zeroVector);
    const result = await createLocalProviderForTest();
    const provider = requireProvider(result);
    const embedding = await provider.embedQuery("test");
    expect(embedding).toEqual([0, 0, 0, 0]);
    expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
  });
  it("sanitizes non-finite values before normalization", async () => {
    const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
    mockSingleLocalEmbeddingVector(nonFiniteVector);
    const result = await createLocalProviderForTest();
    const provider = requireProvider(result);
    const embedding = await provider.embedQuery("test");
    // NaN/±Infinity collapse to 0; the remaining finite value dominates.
    expect(embedding).toEqual([1, 0, 0, 0]);
    expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
  });
  it("normalizes batch embeddings to magnitude ~1.0", async () => {
    const unnormalizedVectors = [
      [2.35, 3.45, 0.63, 4.3],
      [10.0, 0.0, 0.0, 0.0],
      [1.0, 1.0, 1.0, 1.0],
    ];
    vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
      getLlama: async () => ({
        loadModel: vi.fn().mockResolvedValue({
          createEmbeddingContext: vi.fn().mockResolvedValue({
            getEmbeddingFor: vi
              .fn()
              .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) })
              .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) })
              .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }),
          }),
        }),
      }),
      resolveModelFile: async () => "/fake/model.gguf",
      LlamaLogLevel: { error: 0 },
    } as never);
    const result = await createLocalProviderForTest();
    const provider = requireProvider(result);
    const embeddings = await provider.embedBatch(["text1", "text2", "text3"]);
    for (const embedding of embeddings) {
      const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
      expect(magnitude).toBeCloseTo(1.0, 5);
    }
  });
});
// ensureContext must initialize llama/model/context exactly once even under
// concurrent callers, and must permit a retry after a transient failure.
describe("local embedding ensureContext concurrency", () => {
  beforeEach(() => {
    vi.resetModules();
    vi.doUnmock("./node-llama.js");
  });
  afterEach(() => {
    vi.resetModules();
    vi.doUnmock("./node-llama.js");
  });
  // Builds a local provider whose node-llama-cpp init path is instrumented
  // with spies, optional artificial latency, and an optional one-shot failure.
  async function setupLocalProviderWithMockedInit(params?: {
    initializationDelayMs?: number;
    failFirstGetLlama?: boolean;
  }) {
    const getLlamaSpy = vi.fn();
    const loadModelSpy = vi.fn();
    const createContextSpy = vi.fn();
    let shouldFail = params?.failFirstGetLlama ?? false;
    const nodeLlamaModule = await import("./node-llama.js");
    vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
      getLlama: async (...args: unknown[]) => {
        getLlamaSpy(...args);
        if (shouldFail) {
          // Fail only the first call so a retry can succeed.
          shouldFail = false;
          throw new Error("transient init failure");
        }
        if (params?.initializationDelayMs) {
          await sleep(params.initializationDelayMs);
        }
        return {
          loadModel: async (...modelArgs: unknown[]) => {
            loadModelSpy(...modelArgs);
            if (params?.initializationDelayMs) {
              await sleep(params.initializationDelayMs);
            }
            return {
              createEmbeddingContext: async () => {
                createContextSpy();
                return {
                  getEmbeddingFor: vi.fn().mockResolvedValue({
                    vector: new Float32Array([1, 0, 0, 0]),
                  }),
                };
              },
            };
          },
        };
      },
      resolveModelFile: async () => "/fake/model.gguf",
      LlamaLogLevel: { error: 0 },
    } as never);
    const { createEmbeddingProvider } = await import("./embeddings.js");
    const result = await createEmbeddingProvider({
      config: {} as never,
      provider: "local",
      model: "",
      fallback: "none",
    });
    return {
      provider: requireProvider(result),
      getLlamaSpy,
      loadModelSpy,
      createContextSpy,
    };
  }
  it("loads the model only once when embedBatch is called concurrently", async () => {
    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
      await setupLocalProviderWithMockedInit({
        initializationDelayMs: 50,
      });
    const results = await Promise.all([
      provider.embedBatch(["text1"]),
      provider.embedBatch(["text2"]),
      provider.embedBatch(["text3"]),
      provider.embedBatch(["text4"]),
    ]);
    expect(results).toHaveLength(4);
    for (const embeddings of results) {
      expect(embeddings).toHaveLength(1);
      expect(embeddings[0]).toHaveLength(4);
    }
    // Four concurrent callers share a single initialization.
    expect(getLlamaSpy).toHaveBeenCalledTimes(1);
    expect(loadModelSpy).toHaveBeenCalledTimes(1);
    expect(createContextSpy).toHaveBeenCalledTimes(1);
  });
  it("retries initialization after a transient ensureContext failure", async () => {
    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
      await setupLocalProviderWithMockedInit({
        failFirstGetLlama: true,
      });
    await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure");
    const recovered = await provider.embedBatch(["second"]);
    expect(recovered).toHaveLength(1);
    expect(recovered[0]).toHaveLength(4);
    // getLlama runs twice (fail + retry); later stages still run once.
    expect(getLlamaSpy).toHaveBeenCalledTimes(2);
    expect(loadModelSpy).toHaveBeenCalledTimes(1);
    expect(createContextSpy).toHaveBeenCalledTimes(1);
  });
  it("shares initialization when embedQuery and embedBatch start concurrently", async () => {
    const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
      await setupLocalProviderWithMockedInit({
        initializationDelayMs: 50,
      });
    const [queryA, batch, queryB] = await Promise.all([
      provider.embedQuery("query-a"),
      provider.embedBatch(["batch-a", "batch-b"]),
      provider.embedQuery("query-b"),
    ]);
    expect(queryA).toHaveLength(4);
    expect(batch).toHaveLength(2);
    expect(queryB).toHaveLength(4);
    expect(batch[0]).toHaveLength(4);
    expect(batch[1]).toHaveLength(4);
    expect(getLlamaSpy).toHaveBeenCalledTimes(1);
    expect(loadModelSpy).toHaveBeenCalledTimes(1);
    expect(createContextSpy).toHaveBeenCalledTimes(1);
  });
});
// When no embeddings provider can authenticate, createEmbeddingProvider must
// return a null provider (FTS-only mode) instead of throwing.
describe("FTS-only fallback when no provider available", () => {
  beforeEach(async () => {
    authModule = await import("../../agents/model-auth.js");
    ({ createEmbeddingProvider, DEFAULT_LOCAL_MODEL } = await import("./embeddings.js"));
  });
  it("returns null provider when all requested auth paths fail", async () => {
    // Every provider's key resolution fails with the "missing key" message.
    vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
      new Error("No API key found for provider"),
    );
    for (const testCase of [
      {
        name: "auto mode",
        options: {
          config: {} as never,
          provider: "auto" as const,
          model: "",
          fallback: "none" as const,
        },
        requestedProvider: "auto",
        fallbackFrom: undefined,
        reasonIncludes: "No API key",
      },
      {
        name: "explicit provider only",
        options: {
          config: {} as never,
          provider: "openai" as const,
          model: "text-embedding-3-small",
          fallback: "none" as const,
        },
        requestedProvider: "openai",
        fallbackFrom: undefined,
        reasonIncludes: "No API key",
      },
      {
        name: "primary and fallback",
        options: {
          config: {} as never,
          provider: "openai" as const,
          model: "text-embedding-3-small",
          fallback: "gemini" as const,
        },
        requestedProvider: "openai",
        fallbackFrom: "openai",
        reasonIncludes: "Fallback to gemini failed",
      },
    ]) {
      const result = await createEmbeddingProvider(testCase.options);
      expect(result.provider, testCase.name).toBeNull();
      expect(result.requestedProvider, testCase.name).toBe(testCase.requestedProvider);
      expect(result.fallbackFrom, testCase.name).toBe(testCase.fallbackFrom);
      expect(result.providerUnavailableReason, testCase.name).toContain(testCase.reasonIncludes);
    }
  });
});

View File

@@ -1,324 +0,0 @@
import fsSync from "node:fs";
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
import type { OpenClawConfig } from "../../config/config.js";
import type { SecretInput } from "../../config/types.secrets.js";
import { formatErrorMessage } from "../../infra/errors.js";
import { resolveUserPath } from "../../utils.js";
import type { EmbeddingInput } from "./embedding-inputs.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import {
createGeminiEmbeddingProvider,
type GeminiEmbeddingClient,
type GeminiTaskType,
} from "./embeddings-gemini.js";
import {
createMistralEmbeddingProvider,
type MistralEmbeddingClient,
} from "./embeddings-mistral.js";
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
import { importNodeLlamaCpp } from "./node-llama.js";
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
/** Minimal embedding backend contract shared by all provider implementations. */
export type EmbeddingProvider = {
  id: string;
  model: string;
  // Optional per-model input budget; see e.g. the Voyage provider's table.
  maxInputTokens?: number;
  embedQuery: (text: string) => Promise<number[]>;
  embedBatch: (texts: string[]) => Promise<number[][]>;
  // Optional multimodal batch path; providers without it embed text only.
  embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
};
/** Concrete provider identifiers. */
export type EmbeddingProviderId = "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama";
/** What callers may request: a concrete provider or automatic selection. */
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
/** Fallback target when the primary provider fails ("none" disables fallback). */
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
// Remote providers considered for auto-selection when provider === "auto".
// Ollama is intentionally excluded here so that "auto" mode does not
// implicitly assume a local Ollama instance is available.
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;
/**
 * Outcome of provider creation. `provider` is null when no backend could be
 * set up (e.g. all API keys missing), letting callers degrade to FTS-only
 * search; the per-provider client fields expose the resolved HTTP client.
 */
export type EmbeddingProviderResult = {
  provider: EmbeddingProvider | null;
  requestedProvider: EmbeddingProviderRequest;
  fallbackFrom?: EmbeddingProviderId;
  fallbackReason?: string;
  providerUnavailableReason?: string;
  openAi?: OpenAiEmbeddingClient;
  gemini?: GeminiEmbeddingClient;
  voyage?: VoyageEmbeddingClient;
  mistral?: MistralEmbeddingClient;
  ollama?: OllamaEmbeddingClient;
};
/** Inputs for createEmbeddingProvider and the per-provider factories. */
export type EmbeddingProviderOptions = {
  config: OpenClawConfig;
  agentDir?: string;
  provider: EmbeddingProviderRequest;
  // Overrides for remote HTTP providers (base URL, API key, extra headers).
  remote?: {
    baseUrl?: string;
    apiKey?: SecretInput;
    headers?: Record<string, string>;
  };
  model: string;
  fallback: EmbeddingProviderFallback;
  // Local (node-llama-cpp) backend settings.
  local?: {
    modelPath?: string;
    modelCacheDir?: string;
  };
  /** Gemini embedding-2: output vector dimensions (768, 1536, or 3072). */
  outputDimensionality?: number;
  /** Gemini: override the default task type sent with embedding requests. */
  taskType?: GeminiTaskType;
};
/** Default GGUF model reference used by the local backend. */
export const DEFAULT_LOCAL_MODEL =
  "hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";
// Auto mode only considers the local provider when the configured model path
// points at an existing regular file on disk (not an hf:/http(s): reference).
function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
  const trimmedPath = options.local?.modelPath?.trim();
  if (!trimmedPath) {
    return false;
  }
  // Remote model references would require a download; never auto-select those.
  if (/^(hf:|https?:)/i.test(trimmedPath)) {
    return false;
  }
  try {
    const stat = fsSync.statSync(resolveUserPath(trimmedPath));
    return stat.isFile();
  } catch {
    // Missing path or stat failure: local is not auto-selectable.
    return false;
  }
}
// True when the error text matches the "no API key" failure raised by key
// resolution; used to distinguish auth gaps (degradable) from hard errors.
function isMissingApiKeyError(err: unknown): boolean {
  return formatErrorMessage(err).includes("No API key found for provider");
}
/**
 * Creates the local embedding provider backed by node-llama-cpp.
 *
 * Llama/model/context initialization is deferred until the first embed call
 * and shared across concurrent callers via a single in-flight promise.
 */
export async function createLocalEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<EmbeddingProvider> {
  const modelPath = options.local?.modelPath?.trim() || DEFAULT_LOCAL_MODEL;
  const modelCacheDir = options.local?.modelCacheDir?.trim();
  // Lazy-load node-llama-cpp to keep startup light unless local is enabled.
  const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp();
  let llama: Llama | null = null;
  let embeddingModel: LlamaModel | null = null;
  let embeddingContext: LlamaEmbeddingContext | null = null;
  let initPromise: Promise<LlamaEmbeddingContext> | null = null;
  // Initializes llama -> model -> context exactly once; concurrent callers
  // await the same initPromise. On failure, initPromise is cleared so a later
  // call can retry (partial progress in llama/embeddingModel is kept).
  const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
    if (embeddingContext) {
      return embeddingContext;
    }
    if (initPromise) {
      return initPromise;
    }
    initPromise = (async () => {
      try {
        if (!llama) {
          llama = await getLlama({ logLevel: LlamaLogLevel.error });
        }
        if (!embeddingModel) {
          const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
          embeddingModel = await llama.loadModel({ modelPath: resolved });
        }
        if (!embeddingContext) {
          embeddingContext = await embeddingModel.createEmbeddingContext();
        }
        return embeddingContext;
      } catch (err) {
        // Clear the shared promise so the next caller retries initialization.
        initPromise = null;
        throw err;
      }
    })();
    return initPromise;
  };
  return {
    id: "local",
    model: modelPath,
    embedQuery: async (text) => {
      const ctx = await ensureContext();
      const embedding = await ctx.getEmbeddingFor(text);
      // Vectors are sanitized (non-finite values zeroed) and L2-normalized.
      return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
    },
    embedBatch: async (texts) => {
      const ctx = await ensureContext();
      const embeddings = await Promise.all(
        texts.map(async (text) => {
          const embedding = await ctx.getEmbeddingFor(text);
          return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
        }),
      );
      return embeddings;
    },
  };
}
/**
 * Resolves an embedding provider per the requested provider/fallback policy.
 *
 * - "auto": try local (only when a local model file exists), then each remote
 *   provider in priority order; missing-API-key failures are collected and
 *   yield a null provider (FTS-only mode), other errors are fatal.
 * - explicit provider: try it, then the configured fallback; when both fail
 *   solely due to missing API keys, degrade to a null provider.
 */
export async function createEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
  const requestedProvider = options.provider;
  const fallback = options.fallback;
  // Instantiates one concrete provider and exposes its typed client (if any).
  const createProvider = async (id: EmbeddingProviderId) => {
    if (id === "local") {
      const provider = await createLocalEmbeddingProvider(options);
      return { provider };
    }
    if (id === "ollama") {
      const { provider, client } = await createOllamaEmbeddingProvider(options);
      return { provider, ollama: client };
    }
    if (id === "gemini") {
      const { provider, client } = await createGeminiEmbeddingProvider(options);
      return { provider, gemini: client };
    }
    if (id === "voyage") {
      const { provider, client } = await createVoyageEmbeddingProvider(options);
      return { provider, voyage: client };
    }
    if (id === "mistral") {
      const { provider, client } = await createMistralEmbeddingProvider(options);
      return { provider, mistral: client };
    }
    const { provider, client } = await createOpenAiEmbeddingProvider(options);
    return { provider, openAi: client };
  };
  // Local failures get the richer setup guidance; remote ones a plain message.
  const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
    provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);
  if (requestedProvider === "auto") {
    const missingKeyErrors: string[] = [];
    let localError: string | null = null;
    if (canAutoSelectLocal(options)) {
      try {
        const local = await createProvider("local");
        return { ...local, requestedProvider };
      } catch (err) {
        localError = formatLocalSetupError(err);
      }
    }
    for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
      try {
        const result = await createProvider(provider);
        return { ...result, requestedProvider };
      } catch (err) {
        const message = formatPrimaryError(err, provider);
        if (isMissingApiKeyError(err)) {
          missingKeyErrors.push(message);
          continue;
        }
        // Non-auth errors (e.g., network) are still fatal
        const wrapped = new Error(message) as Error & { cause?: unknown };
        wrapped.cause = err;
        throw wrapped;
      }
    }
    // All providers failed due to missing API keys - return null provider for FTS-only mode
    const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
    const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
    return {
      provider: null,
      requestedProvider,
      providerUnavailableReason: reason,
    };
  }
  try {
    const primary = await createProvider(requestedProvider);
    return { ...primary, requestedProvider };
  } catch (primaryErr) {
    const reason = formatPrimaryError(primaryErr, requestedProvider);
    if (fallback && fallback !== "none" && fallback !== requestedProvider) {
      try {
        const fallbackResult = await createProvider(fallback);
        return {
          ...fallbackResult,
          requestedProvider,
          fallbackFrom: requestedProvider,
          fallbackReason: reason,
        };
      } catch (fallbackErr) {
        // Both primary and fallback failed - check if it's auth-related
        const fallbackReason = formatErrorMessage(fallbackErr);
        const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
        if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
          // Both failed due to missing API keys - return null for FTS-only mode
          return {
            provider: null,
            requestedProvider,
            fallbackFrom: requestedProvider,
            fallbackReason: reason,
            providerUnavailableReason: combinedReason,
          };
        }
        // Non-auth errors are still fatal
        const wrapped = new Error(combinedReason) as Error & { cause?: unknown };
        wrapped.cause = fallbackErr;
        throw wrapped;
      }
    }
    // No fallback configured - check if we should degrade to FTS-only
    if (isMissingApiKeyError(primaryErr)) {
      return {
        provider: null,
        requestedProvider,
        providerUnavailableReason: reason,
      };
    }
    const wrapped = new Error(reason) as Error & { cause?: unknown };
    wrapped.cause = primaryErr;
    throw wrapped;
  }
}
// Detects the specific failure mode of the optional node-llama-cpp package
// being absent: a module-resolution error whose message names the package.
function isNodeLlamaCppMissing(err: unknown): boolean {
  if (err instanceof Error) {
    const { code } = err as Error & { code?: unknown };
    return code === "ERR_MODULE_NOT_FOUND" && err.message.includes("node-llama-cpp");
  }
  return false;
}
// Builds the multi-line guidance shown when local embeddings cannot start,
// tailored to whether node-llama-cpp itself is missing.
// NOTE(review): when the dependency is present (missing === false) the list
// renders steps 1) and 3) with no 2) — confirm whether that is intentional
// before changing these strings; tests match them by regex.
function formatLocalSetupError(err: unknown): string {
  const detail = formatErrorMessage(err);
  const missing = isNodeLlamaCppMissing(err);
  // Falsy entries (undefined/null) are removed by the trailing filter(Boolean).
  return [
    "Local embeddings unavailable.",
    missing
      ? "Reason: optional dependency node-llama-cpp is missing (or failed to install)."
      : detail
        ? `Reason: ${detail}`
        : undefined,
    missing && detail ? `Detail: ${detail}` : null,
    "To enable local embeddings:",
    "1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)",
    missing
      ? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
      : null,
    "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
    ...REMOTE_EMBEDDING_PROVIDER_IDS.map(
      (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
    ),
  ]
    .filter(Boolean)
    .join("\n");
}

View File

@@ -1,31 +0,0 @@
import type { Stats } from "node:fs";
import fs from "node:fs/promises";
// Result of statRegularFile: either the path is absent, or its lstat result.
export type RegularFileStatResult = { missing: true } | { missing: false; stat: Stats };
/**
 * Type guard for the "file does not exist" errno (ENOENT) raised by node:fs.
 */
export function isFileMissingError(
  err: unknown,
): err is NodeJS.ErrnoException & { code: "ENOENT" } {
  if (typeof err !== "object" || err === null) {
    return false;
  }
  return "code" in err && (err as Partial<NodeJS.ErrnoException>).code === "ENOENT";
}
/**
 * lstats absPath and returns its stats, or { missing: true } when the path
 * does not exist (ENOENT). Any other filesystem error propagates unchanged.
 *
 * Uses lstat (not stat) so a symlink is inspected rather than followed, and
 * rejects symlinks and non-regular files (directories, sockets, ...).
 */
export async function statRegularFile(absPath: string): Promise<RegularFileStatResult> {
  let stat: Stats;
  try {
    stat = await fs.lstat(absPath);
  } catch (err) {
    if (isFileMissingError(err)) {
      return { missing: true };
    }
    throw err;
  }
  // NOTE(review): the "path required" message reads oddly for the symlink /
  // non-regular-file case — confirm whether callers match on it before
  // changing the wording.
  if (stat.isSymbolicLink() || !stat.isFile()) {
    throw new Error("path required");
  }
  return { missing: false, stat };
}

View File

@@ -1,314 +0,0 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import {
buildMultimodalChunkForIndexing,
buildFileEntry,
chunkMarkdown,
listMemoryFiles,
normalizeExtraMemoryPaths,
remapChunkLines,
} from "./internal.js";
import {
DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES,
type MemoryMultimodalSettings,
} from "./multimodal.js";
// Registers beforeEach/afterEach hooks that create and remove a fresh temp
// directory per test; the returned getter yields the current directory path.
function setupTempDirLifecycle(prefix: string): () => string {
  let currentDir = "";
  beforeEach(async () => {
    currentDir = await fs.mkdtemp(path.join(os.tmpdir(), prefix));
  });
  afterEach(async () => {
    await fs.rm(currentDir, { recursive: true, force: true });
  });
  return function tempDirGetter() {
    return currentDir;
  };
}
// normalizeExtraMemoryPaths should yield absolute, de-duplicated paths.
describe("normalizeExtraMemoryPaths", () => {
  it("trims, resolves, and dedupes paths", () => {
    const workspaceDir = path.join(os.tmpdir(), "memory-test-workspace");
    const absPath = path.resolve(path.sep, "shared-notes");
    const result = normalizeExtraMemoryPaths(workspaceDir, [
      " notes ",
      "./notes",
      absPath,
      absPath,
      "",
    ]);
    // " notes " and "./notes" collapse to the same workspace-relative entry;
    // duplicate absolute paths and empty strings are dropped.
    expect(result).toEqual([path.resolve(workspaceDir, "notes"), absPath]);
  });
});
// Discovery behavior of listMemoryFiles: default MEMORY.md / memory/ locations,
// extra paths (dirs and single files), symlink rejection, dedupe, and multimodal filtering.
describe("listMemoryFiles", () => {
  const getTmpDir = setupTempDirLifecycle("memory-test-");
  // Multimodal settings used by the media-file test below; other tests pass none.
  const multimodal: MemoryMultimodalSettings = {
    enabled: true,
    modalities: ["image", "audio"],
    maxFileBytes: DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES,
  };
  it("includes files from additional paths (directory)", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const extraDir = path.join(tmpDir, "extra-notes");
    await fs.mkdir(extraDir, { recursive: true });
    await fs.writeFile(path.join(extraDir, "note1.md"), "# Note 1");
    await fs.writeFile(path.join(extraDir, "note2.md"), "# Note 2");
    // Non-markdown files in extra dirs are ignored when multimodal is off.
    await fs.writeFile(path.join(extraDir, "ignore.txt"), "Not a markdown file");
    const files = await listMemoryFiles(tmpDir, [extraDir]);
    expect(files).toHaveLength(3);
    expect(files.some((file) => file.endsWith("MEMORY.md"))).toBe(true);
    expect(files.some((file) => file.endsWith("note1.md"))).toBe(true);
    expect(files.some((file) => file.endsWith("note2.md"))).toBe(true);
    expect(files.some((file) => file.endsWith("ignore.txt"))).toBe(false);
  });
  it("includes files from additional paths (single file)", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const singleFile = path.join(tmpDir, "standalone.md");
    await fs.writeFile(singleFile, "# Standalone");
    const files = await listMemoryFiles(tmpDir, [singleFile]);
    expect(files).toHaveLength(2);
    expect(files.some((file) => file.endsWith("standalone.md"))).toBe(true);
  });
  it("handles relative paths in additional paths", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const extraDir = path.join(tmpDir, "subdir");
    await fs.mkdir(extraDir, { recursive: true });
    await fs.writeFile(path.join(extraDir, "nested.md"), "# Nested");
    // "subdir" is resolved relative to the workspace dir.
    const files = await listMemoryFiles(tmpDir, ["subdir"]);
    expect(files).toHaveLength(2);
    expect(files.some((file) => file.endsWith("nested.md"))).toBe(true);
  });
  it("ignores non-existent additional paths", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const files = await listMemoryFiles(tmpDir, ["/does/not/exist"]);
    expect(files).toHaveLength(1);
  });
  it("ignores symlinked files and directories", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const extraDir = path.join(tmpDir, "extra");
    await fs.mkdir(extraDir, { recursive: true });
    await fs.writeFile(path.join(extraDir, "note.md"), "# Note");
    const targetFile = path.join(tmpDir, "target.md");
    await fs.writeFile(targetFile, "# Target");
    const linkFile = path.join(extraDir, "linked.md");
    const targetDir = path.join(tmpDir, "target-dir");
    await fs.mkdir(targetDir, { recursive: true });
    await fs.writeFile(path.join(targetDir, "nested.md"), "# Nested");
    const linkDir = path.join(tmpDir, "linked-dir");
    // Symlink creation can fail without privileges (e.g. Windows); in that
    // case the symlink assertions are skipped rather than failing the test.
    let symlinksOk = true;
    try {
      await fs.symlink(targetFile, linkFile, "file");
      await fs.symlink(targetDir, linkDir, "dir");
    } catch (err) {
      const code = (err as NodeJS.ErrnoException).code;
      if (code === "EPERM" || code === "EACCES") {
        symlinksOk = false;
      } else {
        throw err;
      }
    }
    const files = await listMemoryFiles(tmpDir, [extraDir, linkDir]);
    expect(files.some((file) => file.endsWith("note.md"))).toBe(true);
    if (symlinksOk) {
      expect(files.some((file) => file.endsWith("linked.md"))).toBe(false);
      expect(files.some((file) => file.endsWith("nested.md"))).toBe(false);
    }
  });
  it("dedupes overlapping extra paths that resolve to the same file", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    // The workspace dir, ".", and the explicit file path all reach MEMORY.md.
    const files = await listMemoryFiles(tmpDir, [tmpDir, ".", path.join(tmpDir, "MEMORY.md")]);
    const memoryMatches = files.filter((file) => file.endsWith("MEMORY.md"));
    expect(memoryMatches).toHaveLength(1);
  });
  it("includes image and audio files from extra paths when multimodal is enabled", async () => {
    const tmpDir = getTmpDir();
    const extraDir = path.join(tmpDir, "media");
    await fs.mkdir(extraDir, { recursive: true });
    await fs.writeFile(path.join(extraDir, "diagram.png"), Buffer.from("png"));
    await fs.writeFile(path.join(extraDir, "note.wav"), Buffer.from("wav"));
    await fs.writeFile(path.join(extraDir, "ignore.bin"), Buffer.from("bin"));
    const files = await listMemoryFiles(tmpDir, [extraDir], multimodal);
    expect(files.some((file) => file.endsWith("diagram.png"))).toBe(true);
    expect(files.some((file) => file.endsWith("note.wav"))).toBe(true);
    expect(files.some((file) => file.endsWith("ignore.bin"))).toBe(false);
  });
});
// buildFileEntry metadata behavior: vanished files, markdown entries, multimodal
// classification, and the size/hash guards used by lazy multimodal indexing.
describe("buildFileEntry", () => {
  const getTmpDir = setupTempDirLifecycle("memory-build-entry-");
  const multimodal: MemoryMultimodalSettings = {
    enabled: true,
    modalities: ["image", "audio"],
    maxFileBytes: DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES,
  };
  it("returns null when the file disappears before reading", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "ghost.md");
    await fs.writeFile(target, "ghost", "utf-8");
    await fs.rm(target);
    const entry = await buildFileEntry(target, tmpDir);
    expect(entry).toBeNull();
  });
  it("returns metadata when the file exists", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "note.md");
    await fs.writeFile(target, "hello", "utf-8");
    const entry = await buildFileEntry(target, tmpDir);
    expect(entry).not.toBeNull();
    // path is workspace-relative with forward slashes.
    expect(entry?.path).toBe("note.md");
    expect(entry?.size).toBeGreaterThan(0);
  });
  it("returns multimodal metadata for eligible image files", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "diagram.png");
    await fs.writeFile(target, Buffer.from("png"));
    const entry = await buildFileEntry(target, tmpDir, multimodal);
    expect(entry).toMatchObject({
      path: "diagram.png",
      kind: "multimodal",
      modality: "image",
      mimeType: "image/png",
      contentText: "Image file: diagram.png",
    });
  });
  it("builds a multimodal chunk lazily for indexing", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "diagram.png");
    await fs.writeFile(target, Buffer.from("png"));
    const entry = await buildFileEntry(target, tmpDir, multimodal);
    const built = await buildMultimodalChunkForIndexing(entry!);
    // The embedding input pairs the text label with the file's inline bytes.
    expect(built?.chunk.embeddingInput?.parts).toEqual([
      { type: "text", text: "Image file: diagram.png" },
      expect.objectContaining({ type: "inline-data", mimeType: "image/png" }),
    ]);
    expect(built?.structuredInputBytes).toBeGreaterThan(0);
  });
  it("skips lazy multimodal indexing when the file grows after discovery", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "diagram.png");
    await fs.writeFile(target, Buffer.from("png"));
    const entry = await buildFileEntry(target, tmpDir, multimodal);
    // Size mismatch versus the recorded entry must abort indexing.
    await fs.writeFile(target, Buffer.alloc(entry!.size + 32, 1));
    await expect(buildMultimodalChunkForIndexing(entry!)).resolves.toBeNull();
  });
  it("skips lazy multimodal indexing when file bytes change after discovery", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "diagram.png");
    await fs.writeFile(target, Buffer.from("png"));
    const entry = await buildFileEntry(target, tmpDir, multimodal);
    // Same size but different content: dataHash mismatch must abort indexing.
    await fs.writeFile(target, Buffer.from("gif"));
    await expect(buildMultimodalChunkForIndexing(entry!)).resolves.toBeNull();
  });
});
describe("chunkMarkdown", () => {
  it("splits overly long lines into max-sized chunks", () => {
    // maxChars inside chunkMarkdown is tokens * 4; a single line longer than
    // several multiples of that must be split, never exceeded.
    const tokenBudget = 400;
    const charLimit = tokenBudget * 4;
    const longLine = "a".repeat(charLimit * 3 + 25);
    const produced = chunkMarkdown(longLine, { tokens: tokenBudget, overlap: 0 });
    expect(produced.length).toBeGreaterThan(1);
    for (const piece of produced) {
      expect(piece.text.length).toBeLessThanOrEqual(charLimit);
    }
  });
});
// remapChunkLines translates chunk line numbers from flattened-content
// coordinates back to original source-file (JSONL) coordinates via a lineMap.
describe("remapChunkLines", () => {
  it("remaps chunk line numbers using a lineMap", () => {
    // Simulate 5 content lines that came from JSONL lines [4, 6, 7, 10, 13] (1-indexed)
    const lineMap = [4, 6, 7, 10, 13];
    // Create chunks from content that has 5 lines
    const content = "User: Hello\nAssistant: Hi\nUser: Question\nAssistant: Answer\nUser: Thanks";
    const chunks = chunkMarkdown(content, { tokens: 400, overlap: 0 });
    expect(chunks.length).toBeGreaterThan(0);
    // Before remapping, startLine/endLine reference content line numbers (1-indexed)
    expect(chunks[0].startLine).toBe(1);
    // Remap
    remapChunkLines(chunks, lineMap);
    // After remapping, line numbers should reference original JSONL lines
    // Content line 1 → JSONL line 4, content line 5 → JSONL line 13
    expect(chunks[0].startLine).toBe(4);
    const lastChunk = chunks[chunks.length - 1];
    expect(lastChunk.endLine).toBe(13);
  });
  it("preserves original line numbers when lineMap is undefined", () => {
    // An undefined lineMap is a no-op: chunks keep their content-relative lines.
    const content = "Line one\nLine two\nLine three";
    const chunks = chunkMarkdown(content, { tokens: 400, overlap: 0 });
    const originalStart = chunks[0].startLine;
    const originalEnd = chunks[chunks.length - 1].endLine;
    remapChunkLines(chunks, undefined);
    expect(chunks[0].startLine).toBe(originalStart);
    expect(chunks[chunks.length - 1].endLine).toBe(originalEnd);
  });
  it("handles multi-chunk content with correct remapping", () => {
    // Use small chunk size to force multiple chunks
    // lineMap: 10 content lines from JSONL lines [2, 5, 8, 11, 14, 17, 20, 23, 26, 29]
    const lineMap = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29];
    const contentLines = lineMap.map((_, i) =>
      i % 2 === 0 ? `User: Message ${i}` : `Assistant: Reply ${i}`,
    );
    const content = contentLines.join("\n");
    // Use very small chunk size to force splitting
    const chunks = chunkMarkdown(content, { tokens: 10, overlap: 0 });
    expect(chunks.length).toBeGreaterThan(1);
    remapChunkLines(chunks, lineMap);
    // First chunk should start at JSONL line 2
    expect(chunks[0].startLine).toBe(2);
    // Last chunk should end at JSONL line 29
    expect(chunks[chunks.length - 1].endLine).toBe(29);
    // Each chunk's startLine should be ≤ its endLine
    for (const chunk of chunks) {
      expect(chunk.startLine).toBeLessThanOrEqual(chunk.endLine);
    }
  });
});

View File

@@ -1,482 +0,0 @@
import crypto from "node:crypto";
import fsSync from "node:fs";
import fs from "node:fs/promises";
import path from "node:path";
import { detectMime } from "../../media/mime.js";
import { runTasksWithConcurrency } from "../../utils/run-with-concurrency.js";
import { estimateStructuredEmbeddingInputBytes } from "./embedding-input-limits.js";
import { buildTextEmbeddingInput, type EmbeddingInput } from "./embedding-inputs.js";
import { isFileMissingError } from "./fs-utils.js";
import {
buildMemoryMultimodalLabel,
classifyMemoryMultimodalPath,
type MemoryMultimodalModality,
type MemoryMultimodalSettings,
} from "./multimodal.js";
/** Discovered memory file plus the metadata needed to detect changes between indexing runs. */
export type MemoryFileEntry = {
  path: string; // workspace-relative, forward slashes
  absPath: string;
  mtimeMs: number;
  size: number;
  hash: string; // markdown: content hash; multimodal: hash over path/label/mime/dataHash
  dataHash?: string; // raw-bytes hash, set only for multimodal entries
  kind?: "markdown" | "multimodal";
  contentText?: string; // multimodal text label used as embedding text
  modality?: MemoryMultimodalModality;
  mimeType?: string;
};
/** One indexable slice of a memory file, with 1-indexed line bounds. */
export type MemoryChunk = {
  startLine: number;
  endLine: number;
  text: string;
  hash: string;
  embeddingInput?: EmbeddingInput;
};
/** A multimodal chunk plus the estimated serialized size of its structured embedding input. */
export type MultimodalMemoryChunk = {
  chunk: MemoryChunk;
  structuredInputBytes: number;
};
// Sentinel used when callers pass no multimodal settings: everything disabled.
const DISABLED_MULTIMODAL_SETTINGS: MemoryMultimodalSettings = {
  enabled: false,
  modalities: [],
  maxFileBytes: 0,
};
/** Best-effort recursive mkdir; always returns the given path unchanged. */
export function ensureDir(dir: string): string {
  try {
    fsSync.mkdirSync(dir, { recursive: true });
  } catch {
    // best-effort: callers proceed even if creation failed
  }
  return dir;
}
/**
 * Normalize a workspace-relative path: forward slashes only, with any leading
 * run of "." / "/" separators stripped.
 *
 * Fix: convert backslashes to forward slashes BEFORE stripping the leading
 * "[./]+" run. The original order left a leading "/" behind for Windows-style
 * inputs such as ".\memory\notes.md" (→ "/memory/notes.md"), which then failed
 * the "memory/" prefix checks in isMemoryPath.
 */
export function normalizeRelPath(value: string): string {
  const forwardSlashed = value.trim().replace(/\\/g, "/");
  return forwardSlashed.replace(/^[./]+/, "");
}
/**
 * Trim, resolve (relative paths against the workspace dir), and dedupe a list
 * of extra memory paths. Empty entries are dropped; insertion order is kept.
 */
export function normalizeExtraMemoryPaths(workspaceDir: string, extraPaths?: string[]): string[] {
  if (!extraPaths?.length) {
    return [];
  }
  const unique = new Set<string>();
  for (const raw of extraPaths) {
    const trimmed = raw.trim();
    if (!trimmed) {
      continue;
    }
    const resolved = path.isAbsolute(trimmed)
      ? path.resolve(trimmed)
      : path.resolve(workspaceDir, trimmed);
    unique.add(resolved);
  }
  return [...unique];
}
/** True when a workspace-relative path names the MEMORY.md file or lives under memory/. */
export function isMemoryPath(relPath: string): boolean {
  const normalized = normalizeRelPath(relPath);
  if (!normalized) {
    return false;
  }
  return (
    normalized === "MEMORY.md" ||
    normalized === "memory.md" ||
    normalized.startsWith("memory/")
  );
}
/** Markdown files are always eligible; other extensions only via multimodal classification. */
function isAllowedMemoryFilePath(filePath: string, multimodal?: MemoryMultimodalSettings): boolean {
  if (filePath.endsWith(".md")) {
    return true;
  }
  const settings = multimodal ?? DISABLED_MULTIMODAL_SETTINGS;
  return classifyMemoryMultimodalPath(filePath, settings) !== null;
}
/** Recursively collect allowed memory files under dir, skipping symlinks entirely. */
async function walkDir(dir: string, files: string[], multimodal?: MemoryMultimodalSettings) {
  const entries = await fs.readdir(dir, { withFileTypes: true });
  for (const entry of entries) {
    if (entry.isSymbolicLink()) {
      continue;
    }
    const full = path.join(dir, entry.name);
    if (entry.isDirectory()) {
      await walkDir(full, files, multimodal);
    } else if (entry.isFile() && isAllowedMemoryFilePath(full, multimodal)) {
      files.push(full);
    }
  }
}
/**
 * Enumerate all memory files for a workspace: MEMORY.md / memory.md at the
 * workspace root, everything under memory/, plus any extra paths (files or
 * directories). Symlinks are skipped at every level. Results are deduped by
 * realpath, preserving first-seen order.
 */
export async function listMemoryFiles(
  workspaceDir: string,
  extraPaths?: string[],
  multimodal?: MemoryMultimodalSettings,
): Promise<string[]> {
  const result: string[] = [];
  const memoryFile = path.join(workspaceDir, "MEMORY.md");
  const altMemoryFile = path.join(workspaceDir, "memory.md");
  const memoryDir = path.join(workspaceDir, "memory");
  // Add a root-level markdown file if it exists and is a regular (non-symlink) file.
  const addMarkdownFile = async (absPath: string) => {
    try {
      const stat = await fs.lstat(absPath);
      if (stat.isSymbolicLink() || !stat.isFile()) {
        return;
      }
      if (!absPath.endsWith(".md")) {
        return;
      }
      result.push(absPath);
    } catch {}
  };
  await addMarkdownFile(memoryFile);
  await addMarkdownFile(altMemoryFile);
  // Walk the memory/ directory; note multimodal is NOT passed here, so only
  // markdown files are collected from the default location.
  try {
    const dirStat = await fs.lstat(memoryDir);
    if (!dirStat.isSymbolicLink() && dirStat.isDirectory()) {
      await walkDir(memoryDir, result);
    }
  } catch {}
  const normalizedExtraPaths = normalizeExtraMemoryPaths(workspaceDir, extraPaths);
  if (normalizedExtraPaths.length > 0) {
    for (const inputPath of normalizedExtraPaths) {
      try {
        const stat = await fs.lstat(inputPath);
        if (stat.isSymbolicLink()) {
          continue;
        }
        if (stat.isDirectory()) {
          // Extra directories DO honor multimodal settings.
          await walkDir(inputPath, result, multimodal);
          continue;
        }
        if (stat.isFile() && isAllowedMemoryFilePath(inputPath, multimodal)) {
          result.push(inputPath);
        }
      } catch {}
    }
  }
  if (result.length <= 1) {
    return result;
  }
  // Dedupe by realpath so overlapping inputs (workspace dir + explicit file)
  // yield each file once; the originally-listed path is kept in the output.
  const seen = new Set<string>();
  const deduped: string[] = [];
  for (const entry of result) {
    let key = entry;
    try {
      key = await fs.realpath(entry);
    } catch {}
    if (seen.has(key)) {
      continue;
    }
    seen.add(key);
    deduped.push(entry);
  }
  return deduped;
}
/** Hex-encoded SHA-256 digest of a UTF-8 string. */
export function hashText(value: string): string {
  const hasher = crypto.createHash("sha256");
  hasher.update(value);
  return hasher.digest("hex");
}
/**
 * Build a MemoryFileEntry for a discovered file, or null when the file has
 * vanished or is ineligible. Markdown files get a content hash; files that
 * classify as multimodal are read, MIME-sniffed, and get a combined hash over
 * path/label/mime/dataHash so any change invalidates the chunk.
 */
export async function buildFileEntry(
  absPath: string,
  workspaceDir: string,
  multimodal?: MemoryMultimodalSettings,
): Promise<MemoryFileEntry | null> {
  let stat;
  try {
    stat = await fs.stat(absPath);
  } catch (err) {
    if (isFileMissingError(err)) {
      return null; // file disappeared between discovery and stat
    }
    throw err;
  }
  const normalizedPath = path.relative(workspaceDir, absPath).replace(/\\/g, "/");
  const multimodalSettings = multimodal ?? DISABLED_MULTIMODAL_SETTINGS;
  const modality = classifyMemoryMultimodalPath(absPath, multimodalSettings);
  if (modality) {
    // Oversized media files are silently skipped rather than indexed.
    if (stat.size > multimodalSettings.maxFileBytes) {
      return null;
    }
    let buffer: Buffer;
    try {
      buffer = await fs.readFile(absPath);
    } catch (err) {
      if (isFileMissingError(err)) {
        return null;
      }
      throw err;
    }
    // Sniff MIME from the first 512 bytes; it must agree with the modality
    // implied by the extension (e.g. "image/..." for an image extension).
    const mimeType = await detectMime({ buffer: buffer.subarray(0, 512), filePath: absPath });
    if (!mimeType || !mimeType.startsWith(`${modality}/`)) {
      return null;
    }
    const contentText = buildMemoryMultimodalLabel(modality, normalizedPath);
    const dataHash = crypto.createHash("sha256").update(buffer).digest("hex");
    // Hash every identity-bearing field so a change in any of them re-indexes the file.
    const chunkHash = hashText(
      JSON.stringify({
        path: normalizedPath,
        contentText,
        mimeType,
        dataHash,
      }),
    );
    return {
      path: normalizedPath,
      absPath,
      mtimeMs: stat.mtimeMs,
      size: stat.size,
      hash: chunkHash,
      dataHash,
      kind: "multimodal",
      contentText,
      modality,
      mimeType,
    };
  }
  let content: string;
  try {
    content = await fs.readFile(absPath, "utf-8");
  } catch (err) {
    if (isFileMissingError(err)) {
      return null;
    }
    throw err;
  }
  const hash = hashText(content);
  return {
    path: normalizedPath,
    absPath,
    mtimeMs: stat.mtimeMs,
    size: stat.size,
    hash,
    kind: "markdown",
  };
}
/**
 * Re-read a multimodal file's bytes for embedding, guarding against the file
 * having changed since discovery: null when the entry is not multimodal, the
 * file is gone, its size differs, or its bytes no longer match dataHash.
 */
async function loadMultimodalEmbeddingInput(
  entry: Pick<
    MemoryFileEntry,
    "absPath" | "contentText" | "mimeType" | "kind" | "size" | "dataHash"
  >,
): Promise<EmbeddingInput | null> {
  if (entry.kind !== "multimodal" || !entry.contentText || !entry.mimeType) {
    return null;
  }
  let stat;
  try {
    stat = await fs.stat(entry.absPath);
  } catch (err) {
    if (isFileMissingError(err)) {
      return null;
    }
    throw err;
  }
  // Cheap size check before reading the whole file.
  if (stat.size !== entry.size) {
    return null;
  }
  let buffer: Buffer;
  try {
    buffer = await fs.readFile(entry.absPath);
  } catch (err) {
    if (isFileMissingError(err)) {
      return null;
    }
    throw err;
  }
  // Same size but different bytes still invalidates the entry.
  const dataHash = crypto.createHash("sha256").update(buffer).digest("hex");
  if (entry.dataHash && entry.dataHash !== dataHash) {
    return null;
  }
  return {
    text: entry.contentText,
    parts: [
      { type: "text", text: entry.contentText },
      {
        type: "inline-data",
        mimeType: entry.mimeType,
        data: buffer.toString("base64"),
      },
    ],
  };
}
export async function buildMultimodalChunkForIndexing(
entry: Pick<
MemoryFileEntry,
"absPath" | "contentText" | "mimeType" | "kind" | "hash" | "size" | "dataHash"
>,
): Promise<MultimodalMemoryChunk | null> {
const embeddingInput = await loadMultimodalEmbeddingInput(entry);
if (!embeddingInput) {
return null;
}
return {
chunk: {
startLine: 1,
endLine: 1,
text: entry.contentText ?? embeddingInput.text,
hash: entry.hash,
embeddingInput,
},
structuredInputBytes: estimateStructuredEmbeddingInputBytes(embeddingInput),
};
}
/**
 * Split markdown content into line-aligned chunks of roughly chunking.tokens
 * tokens (approximated as 4 chars/token), with chunking.overlap tokens of
 * trailing lines carried into the next chunk. Lines longer than the budget are
 * split into max-sized segments that keep their original line number.
 */
export function chunkMarkdown(
  content: string,
  chunking: { tokens: number; overlap: number },
): MemoryChunk[] {
  const lines = content.split("\n");
  if (lines.length === 0) {
    return [];
  }
  // 4 chars ≈ 1 token; enforce a small floor so tiny budgets still progress.
  const maxChars = Math.max(32, chunking.tokens * 4);
  const overlapChars = Math.max(0, chunking.overlap * 4);
  const chunks: MemoryChunk[] = [];
  let current: Array<{ line: string; lineNo: number }> = [];
  let currentChars = 0;
  // Emit the accumulated lines as one chunk (no-op when empty).
  const flush = () => {
    if (current.length === 0) {
      return;
    }
    const firstEntry = current[0];
    const lastEntry = current[current.length - 1];
    if (!firstEntry || !lastEntry) {
      return;
    }
    const text = current.map((entry) => entry.line).join("\n");
    const startLine = firstEntry.lineNo;
    const endLine = lastEntry.lineNo;
    chunks.push({
      startLine,
      endLine,
      text,
      hash: hashText(text),
      embeddingInput: buildTextEmbeddingInput(text),
    });
  };
  // Keep the trailing lines totalling at least overlapChars as the seed of the
  // next chunk; with zero overlap the accumulator is simply reset.
  const carryOverlap = () => {
    if (overlapChars <= 0 || current.length === 0) {
      current = [];
      currentChars = 0;
      return;
    }
    let acc = 0;
    const kept: Array<{ line: string; lineNo: number }> = [];
    for (let i = current.length - 1; i >= 0; i -= 1) {
      const entry = current[i];
      if (!entry) {
        continue;
      }
      acc += entry.line.length + 1; // +1 accounts for the joining newline
      kept.unshift(entry);
      if (acc >= overlapChars) {
        break;
      }
    }
    current = kept;
    currentChars = kept.reduce((sum, entry) => sum + entry.line.length + 1, 0);
  };
  for (let i = 0; i < lines.length; i += 1) {
    const line = lines[i] ?? "";
    const lineNo = i + 1;
    // Pre-split overlong lines so no single segment can exceed maxChars.
    const segments: string[] = [];
    if (line.length === 0) {
      segments.push("");
    } else {
      for (let start = 0; start < line.length; start += maxChars) {
        segments.push(line.slice(start, start + maxChars));
      }
    }
    for (const segment of segments) {
      const lineSize = segment.length + 1;
      if (currentChars + lineSize > maxChars && current.length > 0) {
        flush();
        carryOverlap();
      }
      current.push({ line: segment, lineNo });
      currentChars += lineSize;
    }
  }
  flush(); // emit the final partial chunk
  return chunks;
}
/**
 * Remap chunk startLine/endLine from content-relative positions to original
 * source file positions. lineMap[i] is the 1-indexed source line for the
 * 0-indexed content line i; lines without a mapping keep their current value.
 *
 * Used for session JSONL files whose messages are flattened to plain text
 * before chunking — without remapping, stored line numbers would point into
 * the flattened text rather than the original JSONL file.
 */
export function remapChunkLines(chunks: MemoryChunk[], lineMap: number[] | undefined): void {
  if (!lineMap?.length) {
    return;
  }
  // Chunk lines are 1-indexed; lineMap is 0-indexed by content line.
  const toSourceLine = (contentLine: number): number => lineMap[contentLine - 1] ?? contentLine;
  for (const chunk of chunks) {
    chunk.startLine = toSourceLine(chunk.startLine);
    chunk.endLine = toSourceLine(chunk.endLine);
  }
}
/** Parse a JSON-encoded embedding vector; returns [] for invalid JSON or non-array payloads. */
export function parseEmbedding(raw: string): number[] {
  let parsed: unknown;
  try {
    parsed = JSON.parse(raw);
  } catch {
    return [];
  }
  return Array.isArray(parsed) ? (parsed as number[]) : [];
}
/**
 * Cosine similarity over the overlapping prefix of two vectors.
 * Returns 0 when either vector is empty or has zero magnitude.
 */
export function cosineSimilarity(a: number[], b: number[]): number {
  if (a.length === 0 || b.length === 0) {
    return 0;
  }
  const len = Math.min(a.length, b.length);
  let dot = 0;
  let magA = 0;
  let magB = 0;
  for (let i = 0; i < len; i += 1) {
    const x = a[i] ?? 0;
    const y = b[i] ?? 0;
    dot += x * y;
    magA += x * x;
    magB += y * y;
  }
  if (magA === 0 || magB === 0) {
    return 0;
  }
  // Note: sqrt(magA) * sqrt(magB) (not sqrt(magA * magB)) to match float behavior.
  return dot / (Math.sqrt(magA) * Math.sqrt(magB));
}
/**
 * Run async tasks with a concurrency cap, stopping on the first failure and
 * rethrowing it; on success, returns results in task order.
 */
export async function runWithConcurrency<T>(
  tasks: Array<() => Promise<T>>,
  limit: number,
): Promise<T[]> {
  const outcome = await runTasksWithConcurrency({
    tasks,
    limit,
    errorMode: "stop",
  });
  if (outcome.hasError) {
    throw outcome.firstError;
  }
  return outcome.results;
}

View File

@@ -1,99 +0,0 @@
import type { DatabaseSync } from "node:sqlite";
/**
 * Create (idempotently) the memory-index SQLite schema: meta/files/chunks
 * tables, the optional embedding cache, and an optional FTS5 virtual table.
 * Returns whether FTS is usable; FTS5 creation failures are captured rather
 * than thrown so search can degrade to vector-only mode.
 */
export function ensureMemoryIndexSchema(params: {
  db: DatabaseSync;
  embeddingCacheTable: string;
  cacheEnabled: boolean;
  ftsTable: string;
  ftsEnabled: boolean;
}): { ftsAvailable: boolean; ftsError?: string } {
  params.db.exec(`
    CREATE TABLE IF NOT EXISTS meta (
      key TEXT PRIMARY KEY,
      value TEXT NOT NULL
    );
  `);
  params.db.exec(`
    CREATE TABLE IF NOT EXISTS files (
      path TEXT PRIMARY KEY,
      source TEXT NOT NULL DEFAULT 'memory',
      hash TEXT NOT NULL,
      mtime INTEGER NOT NULL,
      size INTEGER NOT NULL
    );
  `);
  params.db.exec(`
    CREATE TABLE IF NOT EXISTS chunks (
      id TEXT PRIMARY KEY,
      path TEXT NOT NULL,
      source TEXT NOT NULL DEFAULT 'memory',
      start_line INTEGER NOT NULL,
      end_line INTEGER NOT NULL,
      hash TEXT NOT NULL,
      model TEXT NOT NULL,
      text TEXT NOT NULL,
      embedding TEXT NOT NULL,
      updated_at INTEGER NOT NULL
    );
  `);
  if (params.cacheEnabled) {
    // Embedding cache keyed by (provider, model, provider_key, content hash).
    params.db.exec(`
      CREATE TABLE IF NOT EXISTS ${params.embeddingCacheTable} (
        provider TEXT NOT NULL,
        model TEXT NOT NULL,
        provider_key TEXT NOT NULL,
        hash TEXT NOT NULL,
        embedding TEXT NOT NULL,
        dims INTEGER,
        updated_at INTEGER NOT NULL,
        PRIMARY KEY (provider, model, provider_key, hash)
      );
    `);
    // updated_at index supports age-based cache eviction queries.
    params.db.exec(
      `CREATE INDEX IF NOT EXISTS idx_embedding_cache_updated_at ON ${params.embeddingCacheTable}(updated_at);`,
    );
  }
  let ftsAvailable = false;
  let ftsError: string | undefined;
  if (params.ftsEnabled) {
    // FTS5 may be absent from the SQLite build; fail soft and report the error.
    try {
      params.db.exec(
        `CREATE VIRTUAL TABLE IF NOT EXISTS ${params.ftsTable} USING fts5(\n` +
          `  text,\n` +
          `  id UNINDEXED,\n` +
          `  path UNINDEXED,\n` +
          `  source UNINDEXED,\n` +
          `  model UNINDEXED,\n` +
          `  start_line UNINDEXED,\n` +
          `  end_line UNINDEXED\n` +
          `);`,
      );
      ftsAvailable = true;
    } catch (err) {
      const message = err instanceof Error ? err.message : String(err);
      ftsAvailable = false;
      ftsError = message;
    }
  }
  // Migration for databases created before the 'source' column existed.
  ensureColumn(params.db, "files", "source", "TEXT NOT NULL DEFAULT 'memory'");
  ensureColumn(params.db, "chunks", "source", "TEXT NOT NULL DEFAULT 'memory'");
  params.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_path ON chunks(path);`);
  params.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_source ON chunks(source);`);
  return { ftsAvailable, ...(ftsError ? { ftsError } : {}) };
}
/** Idempotent migration helper: add a column to files/chunks if not already present. */
function ensureColumn(
  db: DatabaseSync,
  table: "files" | "chunks",
  column: string,
  definition: string,
): void {
  const existing = db.prepare(`PRAGMA table_info(${table})`).all() as Array<{ name: string }>;
  const alreadyPresent = existing.some((row) => row.name === column);
  if (!alreadyPresent) {
    db.exec(`ALTER TABLE ${table} ADD COLUMN ${column} ${definition}`);
  }
}

View File

@@ -1,118 +0,0 @@
// Per-modality label prefix and the (lowercase) file extensions that map to it.
const MEMORY_MULTIMODAL_SPECS = {
  image: {
    labelPrefix: "Image file",
    extensions: [".jpg", ".jpeg", ".png", ".webp", ".gif", ".heic", ".heif"],
  },
  audio: {
    labelPrefix: "Audio file",
    extensions: [".mp3", ".wav", ".ogg", ".opus", ".m4a", ".aac", ".flac"],
  },
} as const;
export type MemoryMultimodalModality = keyof typeof MEMORY_MULTIMODAL_SPECS;
// All supported modalities, derived from the spec table above.
export const MEMORY_MULTIMODAL_MODALITIES = Object.keys(
  MEMORY_MULTIMODAL_SPECS,
) as MemoryMultimodalModality[];
/** A single modality, or "all" as a config shorthand for every modality. */
export type MemoryMultimodalSelection = MemoryMultimodalModality | "all";
/** Normalized multimodal configuration used throughout the memory indexer. */
export type MemoryMultimodalSettings = {
  enabled: boolean;
  modalities: MemoryMultimodalModality[];
  maxFileBytes: number;
};
// Default size cap for media files eligible for multimodal indexing (10 MiB).
export const DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES = 10 * 1024 * 1024;
/**
 * Resolve a configured modality selection to a concrete list: undefined or
 * any "all" entry expands to every modality; otherwise dedupe known values.
 */
export function normalizeMemoryMultimodalModalities(
  raw: MemoryMultimodalSelection[] | undefined,
): MemoryMultimodalModality[] {
  if (raw === undefined || raw.includes("all")) {
    return [...MEMORY_MULTIMODAL_MODALITIES];
  }
  const picked: MemoryMultimodalModality[] = [];
  for (const value of raw) {
    if ((value === "image" || value === "audio") && !picked.includes(value)) {
      picked.push(value);
    }
  }
  return picked;
}
/**
 * Normalize raw multimodal config: enabled only when explicitly true,
 * modalities expanded/deduped (empty when disabled), and maxFileBytes
 * coerced to a positive integer with a sane default.
 */
export function normalizeMemoryMultimodalSettings(raw: {
  enabled?: boolean;
  modalities?: MemoryMultimodalSelection[];
  maxFileBytes?: number;
}): MemoryMultimodalSettings {
  const enabled = raw.enabled === true;
  let maxFileBytes = DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES;
  if (typeof raw.maxFileBytes === "number" && Number.isFinite(raw.maxFileBytes)) {
    maxFileBytes = Math.max(1, Math.floor(raw.maxFileBytes));
  }
  const modalities = enabled ? normalizeMemoryMultimodalModalities(raw.modalities) : [];
  return { enabled, modalities, maxFileBytes };
}
/** Multimodal indexing is active only when enabled AND at least one modality is selected. */
export function isMemoryMultimodalEnabled(settings: MemoryMultimodalSettings): boolean {
  if (!settings.enabled) {
    return false;
  }
  return settings.modalities.length > 0;
}
/** File extensions (lowercase, dot-prefixed) recognized for the given modality. */
export function getMemoryMultimodalExtensions(
  modality: MemoryMultimodalModality,
): readonly string[] {
  const spec = MEMORY_MULTIMODAL_SPECS[modality];
  return spec.extensions;
}
/** Human-readable label used as the text stand-in for a media file, e.g. "Image file: a/b.png". */
export function buildMemoryMultimodalLabel(
  modality: MemoryMultimodalModality,
  normalizedPath: string,
): string {
  const { labelPrefix } = MEMORY_MULTIMODAL_SPECS[modality];
  return `${labelPrefix}: ${normalizedPath}`;
}
/**
 * Build a case-insensitive glob for a file extension, e.g. ".md" -> "*.[mM][dD]".
 * Blank/empty input yields the match-everything glob "*".
 */
export function buildCaseInsensitiveExtensionGlob(extension: string): string {
  const normalized = extension.trim().replace(/^\./, "").toLowerCase();
  if (!normalized) {
    return "*";
  }
  let pattern = "*.";
  for (const char of normalized) {
    // Per-character bracket class covering both cases (identical chars repeat, e.g. "[33]").
    pattern += `[${char.toLowerCase()}${char.toUpperCase()}]`;
  }
  return pattern;
}
/**
 * Classify a file path by its extension against the enabled modalities,
 * or null when multimodal is disabled or no extension matches.
 */
export function classifyMemoryMultimodalPath(
  filePath: string,
  settings: MemoryMultimodalSettings,
): MemoryMultimodalModality | null {
  if (!isMemoryMultimodalEnabled(settings)) {
    return null;
  }
  const lower = filePath.trim().toLowerCase();
  const match = settings.modalities.find((modality) =>
    getMemoryMultimodalExtensions(modality).some((extension) => lower.endsWith(extension)),
  );
  return match ?? null;
}
/** Strip a leading "models/" and then a leading "gemini/" or "google/" from a model id. */
export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
  let normalized = model.trim();
  if (normalized.length === 0) {
    return "";
  }
  normalized = normalized.replace(/^models\//, "");
  return normalized.replace(/^(gemini|google)\//, "");
}
/** Only the gemini provider with the gemini-embedding-2-preview model supports multimodal embeddings here. */
export function supportsMemoryMultimodalEmbeddings(params: {
  provider: string;
  model: string;
}): boolean {
  if (params.provider !== "gemini") {
    return false;
  }
  const normalized = normalizeGeminiEmbeddingModelForMemory(params.model);
  return normalized === "gemini-embedding-2-preview";
}

View File

@@ -1,3 +0,0 @@
// Lazy dynamic import of node-llama-cpp so the module is only loaded on demand
// (presumably because it is an optional native dependency — confirm with callers).
export async function importNodeLlamaCpp() {
  return import("node-llama-cpp");
}

View File

@@ -1,59 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
// Replace remote-http before post-json is imported so postJson's HTTP layer is mocked.
vi.mock("./remote-http.js", () => ({
  withRemoteHttpResponse: vi.fn(),
}));
// Re-imported fresh in each beforeEach (after vi.resetModules) to pick up the mock binding.
let postJson: typeof import("./post-json.js").postJson;
let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
// Behavior of postJson over a mocked withRemoteHttpResponse: JSON parsing on
// success and error construction (with optional status) on non-2xx responses.
describe("postJson", () => {
  let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
  beforeEach(async () => {
    // Reset the module registry once, then re-import so each test gets fresh
    // module instances bound to the mock. (The original called vi.resetModules()
    // twice back to back; the duplicate call was redundant and is removed.)
    vi.resetModules();
    vi.clearAllMocks();
    ({ postJson } = await import("./post-json.js"));
    ({ withRemoteHttpResponse } = await import("./remote-http.js"));
    remoteHttpMock = vi.mocked(withRemoteHttpResponse);
  });
  it("parses JSON payload on successful response", async () => {
    remoteHttpMock.mockImplementationOnce(async (params) => {
      return await params.onResponse(
        new Response(JSON.stringify({ data: [{ embedding: [1, 2] }] }), { status: 200 }),
      );
    });
    const result = await postJson({
      url: "https://memory.example/v1/post",
      headers: { Authorization: "Bearer test" },
      body: { input: ["x"] },
      errorPrefix: "post failed",
      parse: (payload) => payload,
    });
    expect(result).toEqual({ data: [{ embedding: [1, 2] }] });
  });
  it("attaches status to thrown error when requested", async () => {
    remoteHttpMock.mockImplementationOnce(async (params) => {
      return await params.onResponse(new Response("bad gateway", { status: 502 }));
    });
    await expect(
      postJson({
        url: "https://memory.example/v1/post",
        headers: {},
        body: {},
        errorPrefix: "post failed",
        attachStatus: true,
        parse: () => ({}),
      }),
    ).rejects.toMatchObject({
      message: expect.stringContaining("post failed: 502 bad gateway"),
      status: 502,
    });
  });
});

View File

@@ -1,35 +0,0 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { withRemoteHttpResponse } from "./remote-http.js";
/**
 * POST a JSON body and parse the JSON response via the caller-supplied parser.
 * Non-2xx responses throw Error("<errorPrefix>: <status> <body>"), optionally
 * carrying the HTTP status on err.status when attachStatus is set.
 */
export async function postJson<T>(params: {
  url: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  body: unknown;
  errorPrefix: string;
  attachStatus?: boolean;
  parse: (payload: unknown) => T | Promise<T>;
}): Promise<T> {
  const handleResponse = async (res: Response): Promise<T> => {
    if (res.ok) {
      return await params.parse(await res.json());
    }
    const text = await res.text();
    const err = new Error(`${params.errorPrefix}: ${res.status} ${text}`) as Error & {
      status?: number;
    };
    if (params.attachStatus) {
      err.status = res.status;
    }
    throw err;
  };
  return await withRemoteHttpResponse({
    url: params.url,
    ssrfPolicy: params.ssrfPolicy,
    init: {
      method: "POST",
      headers: params.headers,
      body: JSON.stringify(params.body),
    },
    onResponse: handleResponse,
  });
}

View File

@@ -1,91 +0,0 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { resolveCliSpawnInvocation } from "./qmd-process.js";
// Windows spawn resolution for the qmd CLI: npm .cmd shims must be unwrapped
// to a direct node invocation (never executed through a shell).
describe("resolveCliSpawnInvocation", () => {
  let tempDir = "";
  let platformSpy: { mockRestore(): void } | null = null;
  // PATH/PATHEXT are mutated per-test and restored afterwards.
  const originalPath = process.env.PATH;
  const originalPathExt = process.env.PATHEXT;
  beforeEach(async () => {
    tempDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-qmd-win-spawn-"));
    // Force the Windows resolution code path regardless of the host OS.
    platformSpy = vi.spyOn(process, "platform", "get").mockReturnValue("win32");
  });
  afterEach(async () => {
    platformSpy?.mockRestore();
    process.env.PATH = originalPath;
    process.env.PATHEXT = originalPathExt;
    if (tempDir) {
      await fs.rm(tempDir, { recursive: true, force: true });
      tempDir = "";
    }
  });
  it("unwraps npm cmd shims to a direct node entrypoint", async () => {
    // Fake npm layout: node_modules/.bin/qmd.cmd shim plus the package's
    // package.json "bin" entry pointing at dist/cli.js.
    const binDir = path.join(tempDir, "node_modules", ".bin");
    const packageDir = path.join(tempDir, "node_modules", "qmd");
    const scriptPath = path.join(packageDir, "dist", "cli.js");
    await fs.mkdir(path.dirname(scriptPath), { recursive: true });
    await fs.mkdir(binDir, { recursive: true });
    await fs.writeFile(path.join(binDir, "qmd.cmd"), "@echo off\r\n", "utf8");
    await fs.writeFile(
      path.join(packageDir, "package.json"),
      JSON.stringify({ name: "qmd", version: "0.0.0", bin: { qmd: "dist/cli.js" } }),
      "utf8",
    );
    await fs.writeFile(scriptPath, "module.exports = {};\n", "utf8");
    process.env.PATH = `${binDir};${originalPath ?? ""}`;
    process.env.PATHEXT = ".CMD;.EXE";
    const invocation = resolveCliSpawnInvocation({
      command: "qmd",
      args: ["query", "hello"],
      env: process.env,
      packageName: "qmd",
    });
    // The shim is replaced by "node dist/cli.js <args>" with no shell involved.
    expect(invocation.command).toBe(process.execPath);
    expect(invocation.argv).toEqual([scriptPath, "query", "hello"]);
    expect(invocation.shell).not.toBe(true);
    expect(invocation.windowsHide).toBe(true);
  });
  it("fails closed when a Windows cmd shim cannot be resolved without shell execution", async () => {
    // A .cmd shim with no discoverable package entrypoint must throw rather
    // than fall back to running through cmd.exe.
    const binDir = path.join(tempDir, "bad-bin");
    await fs.mkdir(binDir, { recursive: true });
    await fs.writeFile(path.join(binDir, "qmd.cmd"), "@echo off\r\nREM no entrypoint\r\n", "utf8");
    process.env.PATH = `${binDir};${originalPath ?? ""}`;
    process.env.PATHEXT = ".CMD;.EXE";
    expect(() =>
      resolveCliSpawnInvocation({
        command: "qmd",
        args: ["query", "hello"],
        env: process.env,
        packageName: "qmd",
      }),
    ).toThrow(/without shell execution/);
  });
  it("keeps bare commands bare when no Windows wrapper exists on PATH", () => {
    process.env.PATH = originalPath ?? "";
    process.env.PATHEXT = ".CMD;.EXE";
    const invocation = resolveCliSpawnInvocation({
      command: "qmd",
      args: ["query", "hello"],
      env: process.env,
      packageName: "qmd",
    });
    expect(invocation.command).toBe("qmd");
    expect(invocation.argv).toEqual(["query", "hello"]);
    expect(invocation.shell).not.toBe(true);
  });
});

View File

@@ -1,108 +0,0 @@
import { spawn } from "node:child_process";
import {
materializeWindowsSpawnProgram,
resolveWindowsSpawnProgram,
} from "../../plugin-sdk/windows-spawn.js";
/** Fully-resolved spawn parameters for a CLI tool (after Windows-shim handling). */
export type CliSpawnInvocation = {
  // Program to execute; may be `process.execPath` when an npm shim was unwrapped.
  command: string;
  // Arguments, including any unwrapped script entrypoint prepended.
  argv: string[];
  // When true, spawn through the platform shell; the resolver avoids this.
  shell?: boolean;
  // Hide the console window on Windows.
  windowsHide?: boolean;
};
/**
 * Resolve how to spawn a CLI command, unwrapping Windows npm `.cmd` shims into
 * a direct `node <entrypoint>` invocation. `allowShellFallback: false` makes
 * the resolver throw rather than ever return a `shell: true` invocation.
 */
export function resolveCliSpawnInvocation(params: {
  command: string;
  args: string[];
  env: NodeJS.ProcessEnv;
  packageName: string;
}): CliSpawnInvocation {
  const program = resolveWindowsSpawnProgram({
    command: params.command,
    platform: process.platform,
    env: params.env,
    execPath: process.execPath,
    packageName: params.packageName,
    allowShellFallback: false,
  });
  return materializeWindowsSpawnProgram(program, params.args);
}
/**
 * Run a CLI command to completion, capturing size-capped stdout/stderr.
 * Rejects on spawn error, non-zero exit, timeout, or (unless stdout is being
 * discarded) when either stream exceeded `maxOutputChars`.
 */
export async function runCliCommand(params: {
  commandSummary: string;
  spawnInvocation: CliSpawnInvocation;
  env: NodeJS.ProcessEnv;
  cwd: string;
  timeoutMs?: number;
  maxOutputChars: number;
  discardStdout?: boolean;
}): Promise<{ stdout: string; stderr: string }> {
  return await new Promise((resolve, reject) => {
    const child = spawn(params.spawnInvocation.command, params.spawnInvocation.argv, {
      env: params.env,
      cwd: params.cwd,
      shell: params.spawnInvocation.shell,
      windowsHide: params.spawnInvocation.windowsHide,
    });
    let stdout = "";
    let stderr = "";
    let stdoutTruncated = false;
    let stderrTruncated = false;
    const discardStdout = params.discardStdout === true;
    // Hard timeout: SIGKILL the child and reject. The later 'close' callback
    // may still fire, but resolve/reject are no-ops on a settled promise.
    const timer = params.timeoutMs
      ? setTimeout(() => {
          child.kill("SIGKILL");
          reject(new Error(`${params.commandSummary} timed out after ${params.timeoutMs}ms`));
        }, params.timeoutMs)
      : null;
    child.stdout.on("data", (data) => {
      if (discardStdout) {
        return;
      }
      // Keep only the most recent maxOutputChars and remember if any were dropped.
      const next = appendOutputWithCap(stdout, data.toString("utf8"), params.maxOutputChars);
      stdout = next.text;
      stdoutTruncated = stdoutTruncated || next.truncated;
    });
    child.stderr.on("data", (data) => {
      const next = appendOutputWithCap(stderr, data.toString("utf8"), params.maxOutputChars);
      stderr = next.text;
      stderrTruncated = stderrTruncated || next.truncated;
    });
    child.on("error", (err) => {
      if (timer) {
        clearTimeout(timer);
      }
      reject(err);
    });
    child.on("close", (code) => {
      if (timer) {
        clearTimeout(timer);
      }
      // NOTE(review): in discardStdout mode stderr truncation is also ignored
      // here, even though stderr still feeds the failure message below —
      // confirm this is intentional.
      if (!discardStdout && (stdoutTruncated || stderrTruncated)) {
        reject(
          new Error(
            `${params.commandSummary} produced too much output (limit ${params.maxOutputChars} chars)`,
          ),
        );
        return;
      }
      if (code === 0) {
        resolve({ stdout, stderr });
      } else {
        reject(new Error(`${params.commandSummary} failed (code ${code}): ${stderr || stdout}`));
      }
    });
  });
}
/**
 * Append `chunk` to `current`, retaining at most the trailing `maxChars`
 * characters. Reports whether anything was dropped.
 */
function appendOutputWithCap(
  current: string,
  chunk: string,
  maxChars: number,
): { text: string; truncated: boolean } {
  const combined = `${current}${chunk}`;
  const truncated = combined.length > maxChars;
  return {
    text: truncated ? combined.slice(-maxChars) : combined,
    truncated,
  };
}

View File

@@ -1,48 +0,0 @@
import { describe, expect, it } from "vitest";
import { parseQmdQueryJson } from "./qmd-query-parser.js";
describe("parseQmdQueryJson", () => {
it("parses clean qmd JSON output", () => {
const results = parseQmdQueryJson('[{"docid":"abc","score":1,"snippet":"@@ -1,1\\none"}]', "");
expect(results).toEqual([
{
docid: "abc",
score: 1,
snippet: "@@ -1,1\none",
},
]);
});
it("extracts embedded result arrays from noisy stdout", () => {
const results = parseQmdQueryJson(
`initializing
{"payload":"ok"}
[{"docid":"abc","score":0.5}]
complete`,
"",
);
expect(results).toEqual([{ docid: "abc", score: 0.5 }]);
});
it("treats plain-text no-results from stderr as an empty result set", () => {
const results = parseQmdQueryJson("", "No results found\n");
expect(results).toEqual([]);
});
it("treats prefixed no-results marker output as an empty result set", () => {
expect(parseQmdQueryJson("warning: no results found", "")).toEqual([]);
expect(parseQmdQueryJson("", "[qmd] warning: no results found\n")).toEqual([]);
});
it("does not treat arbitrary non-marker text as no-results output", () => {
expect(() =>
parseQmdQueryJson("warning: search completed; no results found for this query", ""),
).toThrow(/qmd query returned invalid JSON/i);
});
it("throws when stdout cannot be interpreted as qmd JSON", () => {
expect(() => parseQmdQueryJson("this is not json", "")).toThrow(
/qmd query returned invalid JSON/i,
);
});
});

View File

@@ -1,121 +0,0 @@
import { createSubsystemLogger } from "../../logging/subsystem.js";
const log = createSubsystemLogger("memory");
/**
 * One search hit as emitted by `qmd query --json`. All fields are optional
 * because qmd output varies by version and search mode.
 */
export type QmdQueryResult = {
  docid?: string;
  score?: number;
  collection?: string;
  file?: string;
  snippet?: string;
  body?: string;
};
/**
 * Parse qmd's `query --json` output into results.
 * Accepts a clean JSON array, a JSON array embedded in noisy stdout, or a
 * plain-text "no results found" marker (on stdout, or on stderr when stdout
 * is empty) which maps to []. Anything else logs a warning and throws.
 */
export function parseQmdQueryJson(stdout: string, stderr: string): QmdQueryResult[] {
  const trimmedStdout = stdout.trim();
  const trimmedStderr = stderr.trim();
  const stdoutIsMarker = trimmedStdout.length > 0 && isQmdNoResultsOutput(trimmedStdout);
  const stderrIsMarker = trimmedStderr.length > 0 && isQmdNoResultsOutput(trimmedStderr);
  // stderr's marker only counts when stdout carried nothing at all.
  if (stdoutIsMarker || (!trimmedStdout && stderrIsMarker)) {
    return [];
  }
  if (!trimmedStdout) {
    const context = trimmedStderr ? ` (stderr: ${summarizeQmdStderr(trimmedStderr)})` : "";
    const message = `stdout empty${context}`;
    log.warn(`qmd query returned invalid JSON: ${message}`);
    throw new Error(`qmd query returned invalid JSON: ${message}`);
  }
  try {
    // Fast path: the whole stdout is the JSON result array.
    const parsed = parseQmdQueryResultArray(trimmedStdout);
    if (parsed !== null) {
      return parsed;
    }
    // Fallback: locate the first balanced JSON array inside noisy output.
    const noisyPayload = extractFirstJsonArray(trimmedStdout);
    if (!noisyPayload) {
      throw new Error("qmd query JSON response was not an array");
    }
    const fallback = parseQmdQueryResultArray(noisyPayload);
    if (fallback !== null) {
      return fallback;
    }
    throw new Error("qmd query JSON response was not an array");
  } catch (err) {
    // Errors thrown in the try block above are deliberately re-wrapped here
    // with a uniform prefix and the original error attached as `cause`.
    const message = err instanceof Error ? err.message : String(err);
    log.warn(`qmd query returned invalid JSON: ${message}`);
    throw new Error(`qmd query returned invalid JSON: ${message}`, { cause: err });
  }
}
// Marker pattern for qmd's "no results" output: an optional bracketed tag
// (e.g. `[qmd]`) followed by one or more `warn:`/`warning:`/`info:`/`error:`/
// `qmd:` labels, then the literal phrase.
const QMD_NO_RESULTS_LINE_RE =
  /^(?:\[[^\]]+\]\s*)?(?:(?:warn(?:ing)?|info|error|qmd)\s*:\s*)+no results found\.?$/;

/** True when any normalized (trimmed, lowercased, whitespace-collapsed) line of `raw` is a "no results" marker. */
function isQmdNoResultsOutput(raw: string): boolean {
  for (const rawLine of raw.split(/\r?\n/)) {
    const line = rawLine.trim().toLowerCase().replace(/\s+/g, " ");
    if (line.length === 0) {
      continue;
    }
    if (isQmdNoResultsLine(line)) {
      return true;
    }
  }
  return false;
}

/** True when a single pre-normalized line is the "no results found" marker. */
function isQmdNoResultsLine(line: string): boolean {
  return (
    line === "no results found" ||
    line === "no results found." ||
    QMD_NO_RESULTS_LINE_RE.test(line)
  );
}
/** Clamp stderr text to 120 characters for inclusion in error messages. */
function summarizeQmdStderr(raw: string): string {
  if (raw.length <= 120) {
    return raw;
  }
  return `${raw.slice(0, 117)}...`;
}
/**
 * Parse `raw` as JSON and return the result array, or null when the text is
 * not valid JSON or the payload is not an array.
 */
function parseQmdQueryResultArray(raw: string): QmdQueryResult[] | null {
  let parsed: unknown;
  try {
    parsed = JSON.parse(raw);
  } catch {
    return null;
  }
  return Array.isArray(parsed) ? (parsed as QmdQueryResult[]) : null;
}
/**
 * Return the first balanced top-level JSON array embedded in `raw`, or null.
 * Tracks JSON string context so brackets inside string literals are ignored.
 */
function extractFirstJsonArray(raw: string): string | null {
  const start = raw.indexOf("[");
  if (start < 0) {
    return null;
  }
  let depth = 0;
  let inString = false;
  let escaped = false;
  for (let i = start; i < raw.length; i += 1) {
    const char = raw[i];
    if (char === undefined) {
      break;
    }
    if (inString) {
      // Inside a string literal: only care about escape sequences and the
      // closing quote; brackets here are plain data.
      if (escaped) {
        escaped = false;
      } else if (char === "\\") {
        escaped = true;
      } else if (char === '"') {
        inString = false;
      }
      continue;
    }
    switch (char) {
      case '"':
        inString = true;
        break;
      case "[":
        depth += 1;
        break;
      case "]":
        depth -= 1;
        if (depth === 0) {
          return raw.slice(start, i + 1);
        }
        break;
      default:
        break;
    }
  }
  // Ran off the end without closing the array.
  return null;
}

View File

@@ -1,54 +0,0 @@
import { describe, expect, it } from "vitest";
import type { ResolvedQmdConfig } from "./backend-config.js";
import { deriveQmdScopeChannel, deriveQmdScopeChatType, isQmdScopeAllowed } from "./qmd-scope.js";
describe("qmd scope", () => {
const allowDirect: ResolvedQmdConfig["scope"] = {
default: "deny",
rules: [{ action: "allow", match: { chatType: "direct" } }],
};
it("derives channel and chat type from canonical keys once", () => {
expect(deriveQmdScopeChannel("Workspace:group:123")).toBe("workspace");
expect(deriveQmdScopeChatType("Workspace:group:123")).toBe("group");
});
it("derives channel and chat type from stored key suffixes", () => {
expect(deriveQmdScopeChannel("agent:agent-1:workspace:channel:chan-123")).toBe("workspace");
expect(deriveQmdScopeChatType("agent:agent-1:workspace:channel:chan-123")).toBe("channel");
});
it("treats parsed keys with no chat prefix as direct", () => {
expect(deriveQmdScopeChannel("agent:agent-1:peer-direct")).toBeUndefined();
expect(deriveQmdScopeChatType("agent:agent-1:peer-direct")).toBe("direct");
expect(isQmdScopeAllowed(allowDirect, "agent:agent-1:peer-direct")).toBe(true);
expect(isQmdScopeAllowed(allowDirect, "agent:agent-1:peer:group:abc")).toBe(false);
});
it("applies scoped key-prefix checks against normalized key", () => {
const scope: ResolvedQmdConfig["scope"] = {
default: "deny",
rules: [{ action: "allow", match: { keyPrefix: "workspace:" } }],
};
expect(isQmdScopeAllowed(scope, "agent:agent-1:workspace:group:123")).toBe(true);
expect(isQmdScopeAllowed(scope, "agent:agent-1:other:group:123")).toBe(false);
});
it("supports rawKeyPrefix matches for agent-prefixed keys", () => {
const scope: ResolvedQmdConfig["scope"] = {
default: "allow",
rules: [{ action: "deny", match: { rawKeyPrefix: "agent:main:discord:" } }],
};
expect(isQmdScopeAllowed(scope, "agent:main:discord:channel:c123")).toBe(false);
expect(isQmdScopeAllowed(scope, "agent:main:slack:channel:c123")).toBe(true);
});
it("keeps legacy agent-prefixed keyPrefix rules working", () => {
const scope: ResolvedQmdConfig["scope"] = {
default: "allow",
rules: [{ action: "deny", match: { keyPrefix: "agent:main:discord:" } }],
};
expect(isQmdScopeAllowed(scope, "agent:main:discord:channel:c123")).toBe(false);
expect(isQmdScopeAllowed(scope, "agent:main:slack:channel:c123")).toBe(true);
});
});

View File

@@ -1,106 +0,0 @@
import { parseAgentSessionKey } from "../../sessions/session-key-utils.js";
import type { ResolvedQmdConfig } from "./backend-config.js";
/** Scope facts derived from a session key, used for rule matching. */
type ParsedQmdSessionScope = {
  // Channel segment of a canonical key (e.g. "workspace"), when present.
  channel?: string;
  // Chat kind; "direct" when the key carries no group/channel marker.
  chatType?: "channel" | "group" | "direct";
  // Lowercased key with any `agent:<id>:` prefix stripped.
  normalizedKey?: string;
};
/**
 * Evaluate scope rules against a session key: the first rule whose matchers
 * all hit wins and its action decides; with no match, `scope.default`
 * (falling back to "allow") applies. An absent scope config allows everything.
 */
export function isQmdScopeAllowed(scope: ResolvedQmdConfig["scope"], sessionKey?: string): boolean {
  if (!scope) {
    return true;
  }
  const parsed = parseQmdSessionScope(sessionKey);
  const channel = parsed.channel;
  const chatType = parsed.chatType;
  const normalizedKey = parsed.normalizedKey ?? "";
  // Raw key keeps the `agent:<id>:` wrapper, unlike the normalized key.
  const rawKey = sessionKey?.trim().toLowerCase() ?? "";
  for (const rule of scope.rules ?? []) {
    if (!rule) {
      continue;
    }
    const match = rule.match ?? {};
    // Every present matcher must hit; any miss skips to the next rule.
    if (match.channel && match.channel !== channel) {
      continue;
    }
    if (match.chatType && match.chatType !== chatType) {
      continue;
    }
    const normalizedPrefix = match.keyPrefix?.trim().toLowerCase() || undefined;
    const rawPrefix = match.rawKeyPrefix?.trim().toLowerCase() || undefined;
    if (rawPrefix && !rawKey.startsWith(rawPrefix)) {
      continue;
    }
    if (normalizedPrefix) {
      // Backward compat: older configs used `keyPrefix: "agent:<id>:..."` to match raw keys.
      const isLegacyRaw = normalizedPrefix.startsWith("agent:");
      if (isLegacyRaw) {
        if (!rawKey.startsWith(normalizedPrefix)) {
          continue;
        }
      } else if (!normalizedKey.startsWith(normalizedPrefix)) {
        continue;
      }
    }
    return rule.action === "allow";
  }
  const fallback = scope.default ?? "allow";
  return fallback === "allow";
}
/** Channel derived from a session key (e.g. "workspace"); undefined when the key has no canonical channel segment. */
export function deriveQmdScopeChannel(key?: string): string | undefined {
  return parseQmdSessionScope(key).channel;
}
/** Chat type derived from a session key; "direct" when no group/channel marker exists, undefined for unusable keys. */
export function deriveQmdScopeChatType(key?: string): "channel" | "group" | "direct" | undefined {
  return parseQmdSessionScope(key).chatType;
}
/**
 * Derive channel/chat-type scope info from a session key.
 * Canonical keys look like `<channel>:<chatType>:<id>`; otherwise the chat
 * type is inferred from `:group:` / `:channel:` segments, defaulting to
 * "direct". Empty/subagent keys yield an empty result.
 */
function parseQmdSessionScope(key?: string): ParsedQmdSessionScope {
  const normalized = normalizeQmdSessionKey(key);
  if (!normalized) {
    return {};
  }
  const parts = normalized.split(":").filter(Boolean);
  let chatType: ParsedQmdSessionScope["chatType"];
  if (
    parts.length >= 2 &&
    (parts[1] === "group" || parts[1] === "channel" || parts[1] === "direct" || parts[1] === "dm")
  ) {
    // Canonical form: first segment is the channel, second the chat type.
    // "dm" and "direct" collapse to "direct" via the ?? fallback below.
    if (parts.includes("group")) {
      chatType = "group";
    } else if (parts.includes("channel")) {
      chatType = "channel";
    }
    return {
      normalizedKey: normalized,
      channel: parts[0]?.toLowerCase(),
      chatType: chatType ?? "direct",
    };
  }
  // Non-canonical keys: infer chat type from embedded segments only.
  if (normalized.includes(":group:")) {
    return { normalizedKey: normalized, chatType: "group" };
  }
  if (normalized.includes(":channel:")) {
    return { normalizedKey: normalized, chatType: "channel" };
  }
  return { normalizedKey: normalized, chatType: "direct" };
}
/**
 * Normalize a session key: strip any `agent:<id>:` wrapper, lowercase, and
 * reject empty or `subagent:`-prefixed keys (subagent sessions are excluded
 * from scoping).
 */
function normalizeQmdSessionKey(key?: string): string | undefined {
  if (!key) {
    return undefined;
  }
  const trimmed = key.trim();
  if (!trimmed) {
    return undefined;
  }
  // parseAgentSessionKey yields the wrapper-free remainder when the key is
  // agent-prefixed; otherwise fall back to the trimmed input.
  const parsed = parseAgentSessionKey(trimmed);
  const normalized = (parsed?.rest ?? trimmed).toLowerCase();
  if (normalized.startsWith("subagent:")) {
    return undefined;
  }
  return normalized;
}

View File

@@ -1,199 +0,0 @@
import { describe, expect, it } from "vitest";
import { expandQueryForFts, extractKeywords } from "./query-expansion.js";
describe("extractKeywords", () => {
it("extracts keywords from English conversational query", () => {
const keywords = extractKeywords("that thing we discussed about the API");
expect(keywords).toContain("discussed");
expect(keywords).toContain("api");
// Should not include stop words
expect(keywords).not.toContain("that");
expect(keywords).not.toContain("thing");
expect(keywords).not.toContain("we");
expect(keywords).not.toContain("about");
expect(keywords).not.toContain("the");
});
it("extracts keywords from Chinese conversational query", () => {
const keywords = extractKeywords("之前讨论的那个方案");
expect(keywords).toContain("讨论");
expect(keywords).toContain("方案");
// Should not include stop words
expect(keywords).not.toContain("之前");
expect(keywords).not.toContain("的");
expect(keywords).not.toContain("那个");
});
it("extracts keywords from mixed language query", () => {
const keywords = extractKeywords("昨天讨论的 API design");
expect(keywords).toContain("讨论");
expect(keywords).toContain("api");
expect(keywords).toContain("design");
});
it("returns specific technical terms", () => {
const keywords = extractKeywords("what was the solution for the CFR bug");
expect(keywords).toContain("solution");
expect(keywords).toContain("cfr");
expect(keywords).toContain("bug");
});
it("extracts keywords from Korean conversational query", () => {
const keywords = extractKeywords("어제 논의한 배포 전략");
expect(keywords).toContain("논의한");
expect(keywords).toContain("배포");
expect(keywords).toContain("전략");
// Should not include stop words
expect(keywords).not.toContain("어제");
});
it("strips Korean particles to extract stems", () => {
const keywords = extractKeywords("서버에서 발생한 에러를 확인");
expect(keywords).toContain("서버");
expect(keywords).toContain("에러");
expect(keywords).toContain("확인");
});
it("filters Korean stop words including inflected forms", () => {
const keywords = extractKeywords("나는 그리고 그래서");
expect(keywords).not.toContain("나");
expect(keywords).not.toContain("나는");
expect(keywords).not.toContain("그리고");
expect(keywords).not.toContain("그래서");
});
it("filters inflected Korean stop words not explicitly listed", () => {
const keywords = extractKeywords("그녀는 우리는");
expect(keywords).not.toContain("그녀는");
expect(keywords).not.toContain("우리는");
expect(keywords).not.toContain("그녀");
expect(keywords).not.toContain("우리");
});
it("does not produce bogus single-char stems from particle stripping", () => {
const keywords = extractKeywords("논의");
expect(keywords).toContain("논의");
expect(keywords).not.toContain("논");
});
it("strips longest Korean trailing particles first", () => {
const keywords = extractKeywords("기능으로 설명");
expect(keywords).toContain("기능");
expect(keywords).not.toContain("기능으");
});
it("keeps stripped ASCII stems for mixed Korean tokens", () => {
const keywords = extractKeywords("API를 배포했다");
expect(keywords).toContain("api");
expect(keywords).toContain("배포했다");
});
it("handles mixed Korean and English query", () => {
const keywords = extractKeywords("API 배포에 대한 논의");
expect(keywords).toContain("api");
expect(keywords).toContain("배포");
expect(keywords).toContain("논의");
});
it("extracts keywords from Japanese conversational query", () => {
const keywords = extractKeywords("昨日話したデプロイ戦略");
expect(keywords).toContain("デプロイ");
expect(keywords).toContain("戦略");
expect(keywords).not.toContain("昨日");
});
it("handles mixed Japanese and English query", () => {
const keywords = extractKeywords("昨日話したAPIのバグ");
expect(keywords).toContain("api");
expect(keywords).toContain("バグ");
expect(keywords).not.toContain("した");
});
it("filters Japanese stop words", () => {
const keywords = extractKeywords("これ それ そして どう");
expect(keywords).not.toContain("これ");
expect(keywords).not.toContain("それ");
expect(keywords).not.toContain("そして");
expect(keywords).not.toContain("どう");
});
it("extracts keywords from Spanish conversational query", () => {
const keywords = extractKeywords("ayer hablamos sobre la estrategia de despliegue");
expect(keywords).toContain("estrategia");
expect(keywords).toContain("despliegue");
expect(keywords).not.toContain("ayer");
expect(keywords).not.toContain("sobre");
});
it("extracts keywords from Portuguese conversational query", () => {
const keywords = extractKeywords("ontem falamos sobre a estratégia de implantação");
expect(keywords).toContain("estratégia");
expect(keywords).toContain("implantação");
expect(keywords).not.toContain("ontem");
expect(keywords).not.toContain("sobre");
});
it("filters Spanish and Portuguese question stop words", () => {
const keywords = extractKeywords("cómo cuando donde porquê quando onde");
expect(keywords).not.toContain("cómo");
expect(keywords).not.toContain("cuando");
expect(keywords).not.toContain("donde");
expect(keywords).not.toContain("porquê");
expect(keywords).not.toContain("quando");
expect(keywords).not.toContain("onde");
});
it("extracts keywords from Arabic conversational query", () => {
const keywords = extractKeywords("بالأمس ناقشنا استراتيجية النشر");
expect(keywords).toContain("ناقشنا");
expect(keywords).toContain("استراتيجية");
expect(keywords).toContain("النشر");
expect(keywords).not.toContain("بالأمس");
});
it("filters Arabic question stop words", () => {
const keywords = extractKeywords("كيف متى أين ماذا");
expect(keywords).not.toContain("كيف");
expect(keywords).not.toContain("متى");
expect(keywords).not.toContain("أين");
expect(keywords).not.toContain("ماذا");
});
it("handles empty query", () => {
expect(extractKeywords("")).toEqual([]);
expect(extractKeywords(" ")).toEqual([]);
});
it("handles query with only stop words", () => {
const keywords = extractKeywords("the a an is are");
expect(keywords.length).toBe(0);
});
it("removes duplicate keywords", () => {
const keywords = extractKeywords("test test testing");
const testCount = keywords.filter((k) => k === "test").length;
expect(testCount).toBe(1);
});
});
describe("expandQueryForFts", () => {
it("returns original query and extracted keywords", () => {
const result = expandQueryForFts("that API we discussed");
expect(result.original).toBe("that API we discussed");
expect(result.keywords).toContain("api");
expect(result.keywords).toContain("discussed");
});
it("builds expanded OR query for FTS", () => {
const result = expandQueryForFts("the solution for bugs");
expect(result.expanded).toContain("OR");
expect(result.expanded).toContain("solution");
expect(result.expanded).toContain("bugs");
});
it("returns original query when no keywords extracted", () => {
const result = expandQueryForFts("the");
expect(result.keywords.length).toBe(0);
expect(result.expanded).toBe("the");
});
});

View File

@@ -1,810 +0,0 @@
/**
* Query expansion for FTS-only search mode.
*
* When no embedding provider is available, we fall back to FTS (full-text search).
* FTS works best with specific keywords, but users often ask conversational queries
* like "that thing we discussed yesterday" or "之前讨论的那个方案".
*
* This module extracts meaningful keywords from such queries to improve FTS results.
*/
// Common stop words that don't add search value
const STOP_WORDS_EN = new Set([
// Articles and determiners
"a",
"an",
"the",
"this",
"that",
"these",
"those",
// Pronouns
"i",
"me",
"my",
"we",
"our",
"you",
"your",
"he",
"she",
"it",
"they",
"them",
// Common verbs
"is",
"are",
"was",
"were",
"be",
"been",
"being",
"have",
"has",
"had",
"do",
"does",
"did",
"will",
"would",
"could",
"should",
"can",
"may",
"might",
// Prepositions
"in",
"on",
"at",
"to",
"for",
"of",
"with",
"by",
"from",
"about",
"into",
"through",
"during",
"before",
"after",
"above",
"below",
"between",
"under",
"over",
// Conjunctions
"and",
"or",
"but",
"if",
"then",
"because",
"as",
"while",
"when",
"where",
"what",
"which",
"who",
"how",
"why",
// Time references (vague, not useful for FTS)
"yesterday",
"today",
"tomorrow",
"earlier",
"later",
"recently",
"before",
"ago",
"just",
"now",
// Vague references
"thing",
"things",
"stuff",
"something",
"anything",
"everything",
"nothing",
// Question words
"please",
"help",
"find",
"show",
"get",
"tell",
"give",
]);
const STOP_WORDS_ES = new Set([
// Articles and determiners
"el",
"la",
"los",
"las",
"un",
"una",
"unos",
"unas",
"este",
"esta",
"ese",
"esa",
// Pronouns
"yo",
"me",
"mi",
"nosotros",
"nosotras",
"tu",
"tus",
"usted",
"ustedes",
"ellos",
"ellas",
// Prepositions and conjunctions
"de",
"del",
"a",
"en",
"con",
"por",
"para",
"sobre",
"entre",
"y",
"o",
"pero",
"si",
"porque",
"como",
// Common verbs / auxiliaries
"es",
"son",
"fue",
"fueron",
"ser",
"estar",
"haber",
"tener",
"hacer",
// Time references (vague)
"ayer",
"hoy",
"mañana",
"antes",
"despues",
"después",
"ahora",
"recientemente",
// Question/request words
"que",
"qué",
"cómo",
"cuando",
"cuándo",
"donde",
"dónde",
"porqué",
"favor",
"ayuda",
]);
const STOP_WORDS_PT = new Set([
// Articles and determiners
"o",
"a",
"os",
"as",
"um",
"uma",
"uns",
"umas",
"este",
"esta",
"esse",
"essa",
// Pronouns
"eu",
"me",
"meu",
"minha",
"nos",
"nós",
"você",
"vocês",
"ele",
"ela",
"eles",
"elas",
// Prepositions and conjunctions
"de",
"do",
"da",
"em",
"com",
"por",
"para",
"sobre",
"entre",
"e",
"ou",
"mas",
"se",
"porque",
"como",
// Common verbs / auxiliaries
"é",
"são",
"foi",
"foram",
"ser",
"estar",
"ter",
"fazer",
// Time references (vague)
"ontem",
"hoje",
"amanhã",
"antes",
"depois",
"agora",
"recentemente",
// Question/request words
"que",
"quê",
"quando",
"onde",
"porquê",
"favor",
"ajuda",
]);
const STOP_WORDS_AR = new Set([
// Articles and connectors
"ال",
"و",
"أو",
"لكن",
"ثم",
"بل",
// Pronouns / references
"أنا",
"نحن",
"هو",
"هي",
"هم",
"هذا",
"هذه",
"ذلك",
"تلك",
"هنا",
"هناك",
// Common prepositions
"من",
"إلى",
"الى",
"في",
"على",
"عن",
"مع",
"بين",
"ل",
"ب",
"ك",
// Common auxiliaries / vague verbs
"كان",
"كانت",
"يكون",
"تكون",
"صار",
"أصبح",
"يمكن",
"ممكن",
// Time references (vague)
"بالأمس",
"امس",
"اليوم",
"غدا",
"الآن",
"قبل",
"بعد",
"مؤخرا",
// Question/request words
"لماذا",
"كيف",
"ماذا",
"متى",
"أين",
"هل",
"من فضلك",
"فضلا",
"ساعد",
]);
const STOP_WORDS_KO = new Set([
// Particles (조사)
"은",
"는",
"이",
"가",
"을",
"를",
"의",
"에",
"에서",
"로",
"으로",
"와",
"과",
"도",
"만",
"까지",
"부터",
"한테",
"에게",
"께",
"처럼",
"같이",
"보다",
"마다",
"밖에",
"대로",
// Pronouns (대명사)
"나",
"나는",
"내가",
"나를",
"너",
"우리",
"저",
"저희",
"그",
"그녀",
"그들",
"이것",
"저것",
"그것",
"여기",
"저기",
"거기",
// Common verbs / auxiliaries (일반 동사/보조 동사)
"있다",
"없다",
"하다",
"되다",
"이다",
"아니다",
"보다",
"주다",
"오다",
"가다",
// Nouns (의존 명사 / vague)
"것",
"거",
"등",
"수",
"때",
"곳",
"중",
"분",
// Adverbs
"잘",
"더",
"또",
"매우",
"정말",
"아주",
"많이",
"너무",
"좀",
// Conjunctions
"그리고",
"하지만",
"그래서",
"그런데",
"그러나",
"또는",
"그러면",
// Question words
"왜",
"어떻게",
"뭐",
"언제",
"어디",
"누구",
"무엇",
"어떤",
// Time (vague)
"어제",
"오늘",
"내일",
"최근",
"지금",
"아까",
"나중",
"전에",
// Request words
"제발",
"부탁",
]);
// Common Korean trailing particles to strip from words for tokenization
// Sorted by descending length so longest-match-first is guaranteed.
const KO_TRAILING_PARTICLES = [
"에서",
"으로",
"에게",
"한테",
"처럼",
"같이",
"보다",
"까지",
"부터",
"마다",
"밖에",
"대로",
"은",
"는",
"이",
"가",
"을",
"를",
"의",
"에",
"로",
"와",
"과",
"도",
"만",
].toSorted((a, b) => b.length - a.length);
/**
 * Strip the longest matching Korean trailing particle (조사) from `token`.
 * Returns the remaining stem, or null when no particle matches or when
 * stripping would consume the whole token. KO_TRAILING_PARTICLES is sorted
 * longest-first, so the first hit is the longest match.
 */
function stripKoreanTrailingParticle(token: string): string | null {
  const particle = KO_TRAILING_PARTICLES.find(
    (p) => token.length > p.length && token.endsWith(p),
  );
  return particle ? token.slice(0, token.length - particle.length) : null;
}
/**
 * Decide whether a particle-stripped stem is worth keeping as a keyword.
 * Hangul stems must be at least two syllables (avoids bogus stems such as
 * "논의" -> "논"); otherwise only plain ASCII word characters qualify
 * (keeps "API를" -> "api").
 */
function isUsefulKoreanStem(stem: string): boolean {
  const containsHangul = /[\uac00-\ud7af]/.test(stem);
  if (containsHangul) {
    return stem.length >= 2;
  }
  return /^[a-z0-9_]+$/i.test(stem);
}
const STOP_WORDS_JA = new Set([
// Pronouns and references
"これ",
"それ",
"あれ",
"この",
"その",
"あの",
"ここ",
"そこ",
"あそこ",
// Common auxiliaries / vague verbs
"する",
"した",
"して",
"です",
"ます",
"いる",
"ある",
"なる",
"できる",
// Particles / connectors
"の",
"こと",
"もの",
"ため",
"そして",
"しかし",
"また",
"でも",
"から",
"まで",
"より",
"だけ",
// Question words
"なぜ",
"どう",
"何",
"いつ",
"どこ",
"誰",
"どれ",
// Time (vague)
"昨日",
"今日",
"明日",
"最近",
"今",
"さっき",
"前",
"後",
]);
const STOP_WORDS_ZH = new Set([
// Pronouns
"我",
"我们",
"你",
"你们",
"他",
"她",
"它",
"他们",
"这",
"那",
"这个",
"那个",
"这些",
"那些",
// Auxiliary words
"的",
"了",
"着",
"过",
"得",
"地",
"吗",
"呢",
"吧",
"啊",
"呀",
"嘛",
"啦",
// Verbs (common, vague)
"是",
"有",
"在",
"被",
"把",
"给",
"让",
"用",
"到",
"去",
"来",
"做",
"说",
"看",
"找",
"想",
"要",
"能",
"会",
"可以",
// Prepositions and conjunctions
"和",
"与",
"或",
"但",
"但是",
"因为",
"所以",
"如果",
"虽然",
"而",
"也",
"都",
"就",
"还",
"又",
"再",
"才",
"只",
// Time (vague)
"之前",
"以前",
"之后",
"以后",
"刚才",
"现在",
"昨天",
"今天",
"明天",
"最近",
// Vague references
"东西",
"事情",
"事",
"什么",
"哪个",
"哪些",
"怎么",
"为什么",
"多少",
// Question/request words
"请",
"帮",
"帮忙",
"告诉",
]);
export function isQueryStopWordToken(token: string): boolean {
return (
STOP_WORDS_EN.has(token) ||
STOP_WORDS_ES.has(token) ||
STOP_WORDS_PT.has(token) ||
STOP_WORDS_AR.has(token) ||
STOP_WORDS_ZH.has(token) ||
STOP_WORDS_KO.has(token) ||
STOP_WORDS_JA.has(token)
);
}
/**
 * Check whether a token is a plausible search keyword.
 * Rejects empty tokens, short pure-ASCII words (under 3 chars, likely stop
 * words or fragments), pure numbers, and punctuation/symbol-only tokens.
 */
function isValidKeyword(token: string): boolean {
  if (!token) {
    return false;
  }
  const isShortAsciiWord = /^[a-zA-Z]+$/.test(token) && token.length < 3;
  const isNumericOnly = /^\d+$/.test(token);
  const isPunctuationOnly = /^[\p{P}\p{S}]+$/u.test(token);
  return !isShortAsciiWord && !isNumericOnly && !isPunctuationOnly;
}
/**
 * Simple tokenizer that handles English, Chinese, Korean, and Japanese text.
 * For Chinese, we do character-based splitting since we don't have a proper segmenter.
 * For English, we split on whitespace and punctuation.
 */
function tokenize(text: string): string[] {
  const tokens: string[] = [];
  const normalized = text.toLowerCase().trim();
  // Split into segments (English words, Chinese character sequences, etc.)
  const segments = normalized.split(/[\s\p{P}]+/u).filter(Boolean);
  for (const segment of segments) {
    // Japanese text often mixes scripts (kanji/kana/ASCII) without spaces.
    // Extract script-specific chunks so technical terms like "API" / "バグ" are retained.
    if (/[\u3040-\u30ff]/.test(segment)) {
      const jpParts =
        segment.match(/[a-z0-9_]+|[\u30a0-\u30ffー]+|[\u4e00-\u9fff]+|[\u3040-\u309f]{2,}/g) ?? [];
      for (const part of jpParts) {
        if (/^[\u4e00-\u9fff]+$/.test(part)) {
          // Kanji runs: emit the run itself plus character bigrams for
          // phrase matching, mirroring the Chinese branch below.
          tokens.push(part);
          for (let i = 0; i < part.length - 1; i++) {
            tokens.push(part[i] + part[i + 1]);
          }
        } else {
          tokens.push(part);
        }
      }
    } else if (/[\u4e00-\u9fff]/.test(segment)) {
      // Check if segment contains CJK characters (Chinese)
      // For Chinese, extract character n-grams (unigrams and bigrams)
      const chars = Array.from(segment).filter((c) => /[\u4e00-\u9fff]/.test(c));
      // Add individual characters
      tokens.push(...chars);
      // Add bigrams for better phrase matching
      for (let i = 0; i < chars.length - 1; i++) {
        tokens.push(chars[i] + chars[i + 1]);
      }
    } else if (/[\uac00-\ud7af\u3131-\u3163]/.test(segment)) {
      // For Korean (Hangul syllables and jamo), keep the word as-is unless it is
      // effectively a stop word once trailing particles are removed.
      const stem = stripKoreanTrailingParticle(segment);
      const stemIsStopWord = stem !== null && STOP_WORDS_KO.has(stem);
      if (!STOP_WORDS_KO.has(segment) && !stemIsStopWord) {
        tokens.push(segment);
      }
      // Also emit particle-stripped stems when they are useful keywords.
      if (stem && !STOP_WORDS_KO.has(stem) && isUsefulKoreanStem(stem)) {
        tokens.push(stem);
      }
    } else {
      // For non-CJK, keep as single token
      tokens.push(segment);
    }
  }
  return tokens;
}
/**
 * Extract meaningful search keywords from a conversational query, preserving
 * first-seen order and dropping stop words, invalid tokens, and duplicates.
 *
 * Examples:
 * - "that thing we discussed about the API" → ["discussed", "api"]
 * - "之前讨论的那个方案" → keywords including "讨论", "方案"
 * - "what was the solution for the bug" → ["solution", "bug"]
 */
export function extractKeywords(query: string): string[] {
  const seen = new Set<string>();
  const keywords: string[] = [];
  for (const token of tokenize(query)) {
    if (seen.has(token) || isQueryStopWordToken(token) || !isValidKeyword(token)) {
      continue;
    }
    seen.add(token);
    keywords.push(token);
  }
  return keywords;
}
/**
 * Expand a query for FTS search: keep the original text and OR-join extracted
 * keywords so both exact matches and keyword matches are found.
 *
 * @param query - User's original query
 * @returns The trimmed original, extracted keywords, and the expanded query
 */
export function expandQueryForFts(query: string): {
  original: string;
  keywords: string[];
  expanded: string;
} {
  const original = query.trim();
  const keywords = extractKeywords(original);
  // No keywords (e.g. stop-word-only query): fall back to the original as-is.
  const expanded =
    keywords.length === 0 ? original : [original, ...keywords].join(" OR ");
  return { original, keywords, expanded };
}
/**
 * Type for an optional LLM-based query expander: maps a query to semantic
 * keywords.
 */
export type LlmQueryExpander = (query: string) => Promise<string[]>;
/**
 * Expand a query with optional LLM assistance: prefer the LLM expander when
 * provided, falling back to local keyword extraction when it is absent,
 * throws, or returns no keywords.
 */
export async function expandQueryWithLlm(
  query: string,
  llmExpander?: LlmQueryExpander,
): Promise<string[]> {
  if (!llmExpander) {
    return extractKeywords(query);
  }
  try {
    const llmKeywords = await llmExpander(query);
    if (llmKeywords.length > 0) {
      return llmKeywords;
    }
  } catch {
    // LLM expansion failed; fall through to local extraction.
  }
  return extractKeywords(query);
}

View File

@@ -1,96 +0,0 @@
import fs from "node:fs/promises";
import path from "node:path";
import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js";
import { resolveMemorySearchConfig } from "../../agents/memory-search.js";
import type { OpenClawConfig } from "../../config/config.js";
import { isFileMissingError, statRegularFile } from "./fs-utils.js";
import { isMemoryPath, normalizeExtraMemoryPaths } from "./internal.js";
/**
 * Read a memory markdown file, restricted to memory paths inside the
 * workspace or to explicitly configured extra paths (symlinks excluded).
 * Returns empty text (not an error) when the file is missing. Optional
 * `from` (1-based) and `lines` select a line window of the content.
 */
export async function readMemoryFile(params: {
  workspaceDir: string;
  extraPaths?: string[];
  relPath: string;
  from?: number;
  lines?: number;
}): Promise<{ text: string; path: string }> {
  const rawPath = params.relPath.trim();
  if (!rawPath) {
    throw new Error("path required");
  }
  // Resolve relative paths against the workspace; absolute paths stand alone.
  const absPath = path.isAbsolute(rawPath)
    ? path.resolve(rawPath)
    : path.resolve(params.workspaceDir, rawPath);
  const relPath = path.relative(params.workspaceDir, absPath).replace(/\\/g, "/");
  // Inside the workspace means a non-empty relative path that does not escape upward.
  const inWorkspace = relPath.length > 0 && !relPath.startsWith("..") && !path.isAbsolute(relPath);
  const allowedWorkspace = inWorkspace && isMemoryPath(relPath);
  let allowedAdditional = false;
  if (!allowedWorkspace && (params.extraPaths?.length ?? 0) > 0) {
    const additionalPaths = normalizeExtraMemoryPaths(params.workspaceDir, params.extraPaths);
    for (const additionalPath of additionalPaths) {
      try {
        // lstat so symlinked entries are detected (and rejected), not followed.
        const stat = await fs.lstat(additionalPath);
        if (stat.isSymbolicLink()) {
          continue;
        }
        if (stat.isDirectory()) {
          // A directory entry grants access to itself and anything beneath it.
          if (absPath === additionalPath || absPath.startsWith(`${additionalPath}${path.sep}`)) {
            allowedAdditional = true;
            break;
          }
          continue;
        }
        // A file entry only grants access to that exact markdown file.
        if (stat.isFile() && absPath === additionalPath && absPath.endsWith(".md")) {
          allowedAdditional = true;
          break;
        }
      } catch {} // best-effort: unreadable extra paths are simply skipped
    }
  }
  // NOTE(review): the "path required" message is reused below for disallowed
  // and non-.md paths — a more specific message would aid debugging, but
  // callers may match on the text, so confirm before changing.
  if (!allowedWorkspace && !allowedAdditional) {
    throw new Error("path required");
  }
  if (!absPath.endsWith(".md")) {
    throw new Error("path required");
  }
  const statResult = await statRegularFile(absPath);
  if (statResult.missing) {
    return { text: "", path: relPath };
  }
  let content: string;
  try {
    content = await fs.readFile(absPath, "utf-8");
  } catch (err) {
    if (isFileMissingError(err)) {
      // File vanished between stat and read: treat as missing.
      return { text: "", path: relPath };
    }
    throw err;
  }
  if (!params.from && !params.lines) {
    return { text: content, path: relPath };
  }
  // Clamp to a sane 1-based window; defaults cover the whole file.
  const fileLines = content.split("\n");
  const start = Math.max(1, params.from ?? 1);
  const count = Math.max(1, params.lines ?? fileLines.length);
  const slice = fileLines.slice(start - 1, start - 1 + count);
  return { text: slice.join("\n"), path: relPath };
}
/**
 * Read a memory file on behalf of an agent, using the agent's resolved
 * workspace directory and configured extra memory paths.
 *
 * @throws Error("memory search disabled") when the agent has no memory
 *   search configuration.
 */
export async function readAgentMemoryFile(params: {
  cfg: OpenClawConfig;
  agentId: string;
  relPath: string;
  from?: number;
  lines?: number;
}): Promise<{ text: string; path: string }> {
  const settings = resolveMemorySearchConfig(params.cfg, params.agentId);
  if (!settings) {
    throw new Error("memory search disabled");
  }
  const workspaceDir = resolveAgentWorkspaceDir(params.cfg, params.agentId);
  return await readMemoryFile({
    workspaceDir,
    extraPaths: settings.extraPaths,
    relPath: params.relPath,
    from: params.from,
    lines: params.lines,
  });
}

View File

@@ -1,40 +0,0 @@
import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
/**
 * Build an SSRF policy pinned to the hostname of a configured base URL.
 * Returns undefined for blank, unparseable, or non-http(s) URLs, meaning
 * "no host-specific policy".
 */
export function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined {
  const candidate = baseUrl.trim();
  if (!candidate) {
    return undefined;
  }
  let parsed: URL;
  try {
    parsed = new URL(candidate);
  } catch {
    return undefined;
  }
  const isHttp = parsed.protocol === "http:" || parsed.protocol === "https:";
  if (!isHttp) {
    return undefined;
  }
  // Pin the policy to the configured host so private operator endpoints keep
  // working while cross-host redirects stay blocked.
  return { allowedHostnames: [parsed.hostname] };
}
/**
 * Perform an SSRF-guarded HTTP request against a remote memory endpoint and
 * hand the response to the supplied callback, releasing the guard's
 * resources once the callback settles.
 */
export async function withRemoteHttpResponse<T>(params: {
  url: string;
  init?: RequestInit;
  ssrfPolicy?: SsrFPolicy;
  auditContext?: string;
  onResponse: (response: Response) => Promise<T>;
}): Promise<T> {
  const guarded = await fetchWithSsrFGuard({
    url: params.url,
    init: params.init,
    policy: params.ssrfPolicy,
    auditContext: params.auditContext ?? "memory-remote",
  });
  try {
    return await params.onResponse(guarded.response);
  } finally {
    await guarded.release();
  }
}

View File

@@ -1,18 +0,0 @@
import {
hasConfiguredSecretInput,
normalizeResolvedSecretInputString,
} from "../../config/types.secrets.js";
/**
 * Whether the given value represents a configured secret input.
 * Thin re-export-style wrapper over the config-layer helper.
 */
export function hasConfiguredMemorySecretInput(value: unknown): boolean {
  return hasConfiguredSecretInput(value);
}
/**
 * Resolve a secret input value to its string form, or undefined when not set.
 * Delegates to the config-layer normalizer; `path` identifies the config
 * location — presumably for diagnostics in the underlying helper (confirm
 * in types.secrets.ts).
 */
export function resolveMemorySecretInputString(params: {
  value: unknown;
  path: string;
}): string | undefined {
  return normalizeResolvedSecretInputString({
    value: params.value,
    path: params.path,
  });
}

View File

@@ -1,87 +0,0 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { buildSessionEntry } from "./session-files.js";
// Tests for buildSessionEntry: verifies message extraction and the lineMap
// that ties each extracted content line back to its 1-indexed JSONL line.
describe("buildSessionEntry", () => {
  let tmpDir: string;
  // Fresh temp dir per test so session fixtures never collide.
  beforeEach(async () => {
    tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "session-entry-test-"));
  });
  afterEach(async () => {
    await fs.rm(tmpDir, { recursive: true, force: true });
  });
  it("returns lineMap tracking original JSONL line numbers", async () => {
    // Simulate a real session JSONL file with metadata records interspersed
    // Lines 1-3: non-message metadata records
    // Line 4: user message
    // Line 5: metadata
    // Line 6: assistant message
    // Line 7: user message
    const jsonlLines = [
      JSON.stringify({ type: "custom", customType: "model-snapshot", data: {} }),
      JSON.stringify({ type: "custom", customType: "openclaw.cache-ttl", data: {} }),
      JSON.stringify({ type: "session-meta", agentId: "test" }),
      JSON.stringify({ type: "message", message: { role: "user", content: "Hello world" } }),
      JSON.stringify({ type: "custom", customType: "tool-result", data: {} }),
      JSON.stringify({
        type: "message",
        message: { role: "assistant", content: "Hi there, how can I help?" },
      }),
      JSON.stringify({ type: "message", message: { role: "user", content: "Tell me a joke" } }),
    ];
    const filePath = path.join(tmpDir, "session.jsonl");
    await fs.writeFile(filePath, jsonlLines.join("\n"));
    const entry = await buildSessionEntry(filePath);
    expect(entry).not.toBeNull();
    // The content should have 3 lines (3 message records)
    const contentLines = entry!.content.split("\n");
    expect(contentLines).toHaveLength(3);
    expect(contentLines[0]).toContain("User: Hello world");
    expect(contentLines[1]).toContain("Assistant: Hi there");
    expect(contentLines[2]).toContain("User: Tell me a joke");
    // lineMap should map each content line to its original JSONL line (1-indexed)
    // Content line 0 → JSONL line 4 (the first user message)
    // Content line 1 → JSONL line 6 (the assistant message)
    // Content line 2 → JSONL line 7 (the second user message)
    expect(entry!.lineMap).toBeDefined();
    expect(entry!.lineMap).toEqual([4, 6, 7]);
  });
  it("returns empty lineMap when no messages are found", async () => {
    // Metadata-only file: entry is still returned, with empty content/lineMap.
    const jsonlLines = [
      JSON.stringify({ type: "custom", customType: "model-snapshot", data: {} }),
      JSON.stringify({ type: "session-meta", agentId: "test" }),
    ];
    const filePath = path.join(tmpDir, "empty-session.jsonl");
    await fs.writeFile(filePath, jsonlLines.join("\n"));
    const entry = await buildSessionEntry(filePath);
    expect(entry).not.toBeNull();
    expect(entry!.content).toBe("");
    expect(entry!.lineMap).toEqual([]);
  });
  it("skips blank lines and invalid JSON without breaking lineMap", async () => {
    // JSONL lines 1, 2, and 4 are skipped; messages sit on lines 3 and 5.
    const jsonlLines = [
      "",
      "not valid json",
      JSON.stringify({ type: "message", message: { role: "user", content: "First" } }),
      "",
      JSON.stringify({ type: "message", message: { role: "assistant", content: "Second" } }),
    ];
    const filePath = path.join(tmpDir, "gaps.jsonl");
    await fs.writeFile(filePath, jsonlLines.join("\n"));
    const entry = await buildSessionEntry(filePath);
    expect(entry).not.toBeNull();
    expect(entry!.lineMap).toEqual([3, 5]);
  });
});

View File

@@ -1,131 +0,0 @@
import fs from "node:fs/promises";
import path from "node:path";
import { resolveSessionTranscriptsDirForAgent } from "../../config/sessions/paths.js";
import { redactSensitiveText } from "../../logging/redact.js";
import { createSubsystemLogger } from "../../logging/subsystem.js";
import { hashText } from "./internal.js";
const log = createSubsystemLogger("memory");
/** Indexed representation of one session transcript (.jsonl) file. */
export type SessionFileEntry = {
  // Workspace-relative form, "sessions/<basename>" (see sessionPathForFile).
  path: string;
  absPath: string;
  mtimeMs: number;
  size: number;
  // Hash over content plus the lineMap, so either changing marks the file dirty.
  hash: string;
  // One "User: …" / "Assistant: …" line per extracted message.
  content: string;
  /** Maps each content line (0-indexed) to its 1-indexed JSONL source line. */
  lineMap: number[];
};
/**
 * List absolute paths of the .jsonl session transcripts for an agent.
 * Returns an empty list when the transcripts directory is missing or unreadable.
 */
export async function listSessionFilesForAgent(agentId: string): Promise<string[]> {
  const dir = resolveSessionTranscriptsDirForAgent(agentId);
  try {
    const dirEntries = await fs.readdir(dir, { withFileTypes: true });
    const files: string[] = [];
    for (const dirEntry of dirEntries) {
      if (dirEntry.isFile() && dirEntry.name.endsWith(".jsonl")) {
        files.push(path.join(dir, dirEntry.name));
      }
    }
    return files;
  } catch {
    // No transcripts directory means no sessions to index.
    return [];
  }
}
/** Map an absolute session file path to its stored "sessions/<name>" form. */
export function sessionPathForFile(absPath: string): string {
  const fileName = path.basename(absPath);
  const joined = path.join("sessions", fileName);
  // Normalize Windows separators so stored paths are stable across platforms.
  return joined.replace(/\\/g, "/");
}
// Collapse newlines and whitespace runs into single spaces, then trim.
function normalizeSessionText(value: string): string {
  const collapsedNewlines = value.replace(/\s*\n+\s*/g, " ");
  return collapsedNewlines.replace(/\s+/g, " ").trim();
}
/**
 * Extract readable text from a session message's content.
 * Accepts either a plain string or an array of content blocks (only blocks
 * of type "text" contribute). Returns null when nothing usable remains.
 */
export function extractSessionText(content: unknown): string | null {
  if (typeof content === "string") {
    const text = normalizeSessionText(content);
    return text || null;
  }
  if (!Array.isArray(content)) {
    return null;
  }
  const collected: string[] = [];
  for (const entry of content) {
    if (!entry || typeof entry !== "object") {
      continue;
    }
    const block = entry as { type?: unknown; text?: unknown };
    if (block.type !== "text" || typeof block.text !== "string") {
      continue;
    }
    const text = normalizeSessionText(block.text);
    if (text) {
      collected.push(text);
    }
  }
  return collected.length > 0 ? collected.join(" ") : null;
}
/**
 * Build an indexed entry for one session JSONL file.
 *
 * Extracts user/assistant message text (redacted), producing one content
 * line per message plus a lineMap from each content line to its 1-indexed
 * JSONL source line. Blank, malformed, or non-message lines are skipped.
 * Returns null when the file cannot be read.
 */
export async function buildSessionEntry(absPath: string): Promise<SessionFileEntry | null> {
  try {
    const stat = await fs.stat(absPath);
    const raw = await fs.readFile(absPath, "utf-8");
    const collected: string[] = [];
    const lineMap: number[] = [];
    const rawLines = raw.split("\n");
    for (let lineIdx = 0; lineIdx < rawLines.length; lineIdx++) {
      const line = rawLines[lineIdx];
      if (!line.trim()) {
        continue;
      }
      let parsed: unknown;
      try {
        parsed = JSON.parse(line);
      } catch {
        // Malformed JSONL line; skip without disturbing the line mapping.
        continue;
      }
      if (!parsed || typeof parsed !== "object") {
        continue;
      }
      if ((parsed as { type?: unknown }).type !== "message") {
        continue;
      }
      const message = (parsed as { message?: unknown }).message as
        | { role?: unknown; content?: unknown }
        | undefined;
      const role = message?.role;
      // Only user/assistant turns are indexed (non-string roles fall out here).
      if (role !== "user" && role !== "assistant") {
        continue;
      }
      const text = extractSessionText(message?.content);
      if (!text) {
        continue;
      }
      const safe = redactSensitiveText(text, { mode: "tools" });
      const label = role === "user" ? "User" : "Assistant";
      collected.push(`${label}: ${safe}`);
      lineMap.push(lineIdx + 1);
    }
    const content = collected.join("\n");
    return {
      path: sessionPathForFile(absPath),
      absPath,
      mtimeMs: stat.mtimeMs,
      size: stat.size,
      // Hash covers both content and lineMap so mapping changes mark dirty.
      hash: hashText(content + "\n" + lineMap.join(",")),
      content,
      lineMap,
    };
  } catch (err) {
    log.debug(`Failed reading session file ${absPath}: ${String(err)}`);
    return null;
  }
}

View File

@@ -1,24 +0,0 @@
import type { DatabaseSync } from "node:sqlite";
/**
 * Load the sqlite-vec extension into the given database connection.
 *
 * Uses an explicitly configured extension path when provided, otherwise the
 * loadable path bundled with the sqlite-vec package. Extension loading is
 * enabled only for the duration of the load and switched back off in a
 * finally block, so the connection does not remain open to loading further
 * arbitrary extensions afterwards.
 *
 * @param params.db - Open node:sqlite connection (must allow extensions —
 *   see node:sqlite `allowExtension`; confirm at the call site).
 * @param params.extensionPath - Optional override path for the extension.
 * @returns ok:true with the resolved extensionPath, or ok:false with an
 *   error message (e.g. when sqlite-vec is not installed).
 */
export async function loadSqliteVecExtension(params: {
  db: DatabaseSync;
  extensionPath?: string;
}): Promise<{ ok: boolean; extensionPath?: string; error?: string }> {
  let loadEnabled = false;
  try {
    const sqliteVec = await import("sqlite-vec");
    const resolvedPath = params.extensionPath?.trim() ? params.extensionPath.trim() : undefined;
    const extensionPath = resolvedPath ?? sqliteVec.getLoadablePath();
    params.db.enableLoadExtension(true);
    loadEnabled = true;
    if (resolvedPath) {
      params.db.loadExtension(extensionPath);
    } else {
      sqliteVec.load(params.db);
    }
    return { ok: true, extensionPath };
  } catch (err) {
    const message = err instanceof Error ? err.message : String(err);
    return { ok: false, error: message };
  } finally {
    if (loadEnabled) {
      // Security hardening: re-disable extension loading after the one-time
      // vec load so the connection cannot load further extensions.
      try {
        params.db.enableLoadExtension(false);
      } catch {
        // Best-effort; ignore if the driver rejects the toggle.
      }
    }
  }
}

View File

@@ -1,19 +0,0 @@
import { createRequire } from "node:module";
import { installProcessWarningFilter } from "../../infra/warning-filter.js";
const require = createRequire(import.meta.url);
/**
 * Load the builtin node:sqlite module, after installing the process warning
 * filter (node:sqlite is experimental and emits warnings).
 *
 * @throws an actionable Error (with `cause`) when the running Node
 *   distribution ships without the builtin SQLite module.
 */
export function requireNodeSqlite(): typeof import("node:sqlite") {
  installProcessWarningFilter();
  try {
    return require("node:sqlite") as typeof import("node:sqlite");
  } catch (err) {
    const reason = err instanceof Error ? err.message : String(err);
    // Surface an actionable error instead of the generic
    // "unknown builtin module" message.
    throw new Error(
      `SQLite support is unavailable in this Node runtime (missing node:sqlite). ${reason}`,
      { cause: err },
    );
  }
}

View File

@@ -1,45 +0,0 @@
export type Tone = "ok" | "warn" | "muted";
/**
 * Derive the display tone/state for the vector-search subsystem.
 * Disabled wins over availability; undefined availability maps to "unknown".
 */
export function resolveMemoryVectorState(vector: { enabled: boolean; available?: boolean }): {
  tone: Tone;
  state: "ready" | "unavailable" | "disabled" | "unknown";
} {
  if (!vector.enabled) {
    return { tone: "muted", state: "disabled" };
  }
  switch (vector.available) {
    case true:
      return { tone: "ok", state: "ready" };
    case false:
      return { tone: "warn", state: "unavailable" };
    default:
      return { tone: "muted", state: "unknown" };
  }
}
/** Derive the display tone/state for full-text search: disabled, ready, or unavailable. */
export function resolveMemoryFtsState(fts: { enabled: boolean; available: boolean }): {
  tone: Tone;
  state: "ready" | "unavailable" | "disabled";
} {
  if (!fts.enabled) {
    return { tone: "muted", state: "disabled" };
  }
  if (fts.available) {
    return { tone: "ok", state: "ready" };
  }
  return { tone: "warn", state: "unavailable" };
}
/** One-line cache summary for display: "cache off" or "cache on" with an optional entry count. */
export function resolveMemoryCacheSummary(cache: { enabled: boolean; entries?: number }): {
  tone: Tone;
  text: string;
} {
  if (!cache.enabled) {
    return { tone: "muted", text: "cache off" };
  }
  const countSuffix = typeof cache.entries === "number" ? ` (${cache.entries})` : "";
  return { tone: "ok", text: `cache on${countSuffix}` };
}
/** Derive the display tone/state for the cache toggle. */
export function resolveMemoryCacheState(cache: { enabled: boolean }): {
  tone: Tone;
  state: "enabled" | "disabled";
} {
  if (cache.enabled) {
    return { tone: "ok", state: "enabled" };
  }
  return { tone: "muted", state: "disabled" };
}

View File

@@ -1,14 +0,0 @@
import { vi } from "vitest";
import * as ssrf from "../../../infra/net/ssrf.js";
/**
 * Test helper: stub pinned-hostname resolution so any hostname resolves to a
 * fixed public IPv4 address, bypassing real DNS while keeping the pinned
 * lookup shape intact. Returns the vitest spy for later restoration.
 */
export function mockPublicPinnedHostname() {
  return vi.spyOn(ssrf, "resolvePinnedHostnameWithPolicy").mockImplementation(async (hostname) => {
    // Canonicalize the hostname the same way the real resolver is expected to.
    const canonical = hostname.trim().toLowerCase().replace(/\.$/, "");
    const pinnedAddresses = ["93.184.216.34"];
    return {
      hostname: canonical,
      addresses: pinnedAddresses,
      lookup: ssrf.createPinnedLookup({ hostname: canonical, addresses: pinnedAddresses }),
    };
  });
}

View File

@@ -1,81 +0,0 @@
/** Origin of indexed content: "memory" files or "sessions" transcripts. */
export type MemorySource = "memory" | "sessions";
/** A single search hit with its location, score, and display snippet. */
export type MemorySearchResult = {
  path: string;
  // Line range of the matched span within the source file.
  startLine: number;
  endLine: number;
  score: number;
  snippet: string;
  source: MemorySource;
  // Optional preformatted citation string; format is provider-dependent.
  citation?: string;
};
/** Result of probing whether embedding generation currently works. */
export type MemoryEmbeddingProbeResult = {
  ok: boolean;
  error?: string;
};
/** Progress payload emitted during a sync run. */
export type MemorySyncProgressUpdate = {
  completed: number;
  total: number;
  label?: string;
};
/**
 * Snapshot of a memory provider's configuration and health.
 * Optional fields are provider-dependent; absence means "not applicable" or
 * "unknown" — consumers should not assume they are populated.
 */
export type MemoryProviderStatus = {
  backend: "builtin" | "qmd";
  provider: string;
  model?: string;
  // Provider originally requested when a fallback occurred (see `fallback`).
  requestedProvider?: string;
  files?: number;
  chunks?: number;
  dirty?: boolean;
  workspaceDir?: string;
  dbPath?: string;
  extraPaths?: string[];
  sources?: MemorySource[];
  sourceCounts?: Array<{ source: MemorySource; files: number; chunks: number }>;
  cache?: { enabled: boolean; entries?: number; maxEntries?: number };
  // Full-text search status; `error` explains unavailability when present.
  fts?: { enabled: boolean; available: boolean; error?: string };
  fallback?: { from: string; reason?: string };
  // Vector search status (sqlite-vec style loadable extension fields).
  vector?: {
    enabled: boolean;
    available?: boolean;
    extensionPath?: string;
    loadError?: string;
    dims?: number;
  };
  // Batch embedding configuration and failure counters.
  batch?: {
    enabled: boolean;
    failures: number;
    limit: number;
    wait: boolean;
    concurrency: number;
    pollIntervalMs: number;
    timeoutMs: number;
    lastError?: string;
    lastProvider?: string;
  };
  // Provider-specific extras not covered by the typed fields above.
  custom?: Record<string, unknown>;
};
/** Contract implemented by memory search backends (builtin and qmd). */
export interface MemorySearchManager {
  /** Search indexed memory; opts bound result count and minimum score. */
  search(
    query: string,
    opts?: { maxResults?: number; minScore?: number; sessionKey?: string },
  ): Promise<MemorySearchResult[]>;
  /** Read a memory file, optionally a line slice starting at `from`. */
  readFile(params: {
    relPath: string;
    from?: number;
    lines?: number;
  }): Promise<{ text: string; path: string }>;
  /** Current status snapshot for display/diagnostics. */
  status(): MemoryProviderStatus;
  /**
   * Optional: (re)index sources. `force` presumably bypasses freshness
   * checks — confirm semantics per backend implementation.
   */
  sync?(params?: {
    reason?: string;
    force?: boolean;
    sessionFiles?: string[];
    progress?: (update: MemorySyncProgressUpdate) => void;
  }): Promise<void>;
  /** Probe whether embedding generation currently works. */
  probeEmbeddingAvailability(): Promise<MemoryEmbeddingProbeResult>;
  /** Probe whether vector search is currently usable. */
  probeVectorAvailability(): Promise<boolean>;
  /** Optional: release resources held by the provider. */
  close?(): Promise<void>;
}

View File

@@ -4,7 +4,7 @@ import type {
MemoryEmbeddingProbeResult,
MemoryProviderStatus,
MemorySyncProgressUpdate,
} from "../plugins/memory-host/types.js";
} from "../plugin-sdk/memory-core-host-engine-storage.js";
export type MemoryPromptSectionBuilder = (params: {
availableTools: Set<string>;