mirror of
https://github.com/openclaw/openclaw.git
synced 2026-03-29 19:01:44 +00:00
refactor: move memory host into sdk package
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import type { EmbeddingInput } from "../../packages/memory-host-sdk/src/host/embedding-inputs.js";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import type { SecretInput } from "../config/types.secrets.js";
|
||||
import type { EmbeddingInput } from "./memory-host/embedding-inputs.js";
|
||||
|
||||
export type MemoryEmbeddingBatchChunk = {
|
||||
text: string;
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
import path from "node:path";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import { resolveMemoryBackendConfig } from "./backend-config.js";
|
||||
|
||||
describe("resolveMemoryBackendConfig", () => {
|
||||
it("defaults to builtin backend when config missing", () => {
|
||||
const cfg = { agents: { defaults: { workspace: "/tmp/memory-test" } } } as OpenClawConfig;
|
||||
const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
|
||||
expect(resolved.backend).toBe("builtin");
|
||||
expect(resolved.citations).toBe("auto");
|
||||
expect(resolved.qmd).toBeUndefined();
|
||||
});
|
||||
|
||||
it("resolves qmd backend with default collections", () => {
|
||||
const cfg = {
|
||||
agents: { defaults: { workspace: "/tmp/memory-test" } },
|
||||
memory: {
|
||||
backend: "qmd",
|
||||
qmd: {},
|
||||
},
|
||||
} as OpenClawConfig;
|
||||
const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
|
||||
expect(resolved.backend).toBe("qmd");
|
||||
expect(resolved.qmd?.collections.length).toBeGreaterThanOrEqual(3);
|
||||
expect(resolved.qmd?.command).toBe("qmd");
|
||||
expect(resolved.qmd?.searchMode).toBe("search");
|
||||
expect(resolved.qmd?.update.intervalMs).toBeGreaterThan(0);
|
||||
expect(resolved.qmd?.update.waitForBootSync).toBe(false);
|
||||
expect(resolved.qmd?.update.commandTimeoutMs).toBe(30_000);
|
||||
expect(resolved.qmd?.update.updateTimeoutMs).toBe(120_000);
|
||||
expect(resolved.qmd?.update.embedTimeoutMs).toBe(120_000);
|
||||
const names = new Set((resolved.qmd?.collections ?? []).map((collection) => collection.name));
|
||||
expect(names.has("memory-root-main")).toBe(true);
|
||||
expect(names.has("memory-alt-main")).toBe(true);
|
||||
expect(names.has("memory-dir-main")).toBe(true);
|
||||
});
|
||||
|
||||
it("parses quoted qmd command paths", () => {
|
||||
const cfg = {
|
||||
agents: { defaults: { workspace: "/tmp/memory-test" } },
|
||||
memory: {
|
||||
backend: "qmd",
|
||||
qmd: {
|
||||
command: '"/Applications/QMD Tools/qmd" --flag',
|
||||
},
|
||||
},
|
||||
} as OpenClawConfig;
|
||||
const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
|
||||
expect(resolved.qmd?.command).toBe("/Applications/QMD Tools/qmd");
|
||||
});
|
||||
|
||||
it("resolves custom paths relative to workspace", () => {
|
||||
const cfg = {
|
||||
agents: {
|
||||
defaults: { workspace: "/workspace/root" },
|
||||
list: [{ id: "main", workspace: "/workspace/root" }],
|
||||
},
|
||||
memory: {
|
||||
backend: "qmd",
|
||||
qmd: {
|
||||
paths: [
|
||||
{
|
||||
path: "notes",
|
||||
name: "custom-notes",
|
||||
pattern: "**/*.md",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
} as OpenClawConfig;
|
||||
const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
|
||||
const custom = resolved.qmd?.collections.find((c) => c.name.startsWith("custom-notes"));
|
||||
expect(custom).toBeDefined();
|
||||
const workspaceRoot = resolveAgentWorkspaceDir(cfg, "main");
|
||||
expect(custom?.path).toBe(path.resolve(workspaceRoot, "notes"));
|
||||
});
|
||||
|
||||
it("scopes qmd collection names per agent", () => {
|
||||
const cfg = {
|
||||
agents: {
|
||||
defaults: { workspace: "/workspace/root" },
|
||||
list: [
|
||||
{ id: "main", default: true, workspace: "/workspace/root" },
|
||||
{ id: "dev", workspace: "/workspace/dev" },
|
||||
],
|
||||
},
|
||||
memory: {
|
||||
backend: "qmd",
|
||||
qmd: {
|
||||
includeDefaultMemory: true,
|
||||
paths: [{ path: "notes", name: "workspace", pattern: "**/*.md" }],
|
||||
},
|
||||
},
|
||||
} as OpenClawConfig;
|
||||
const mainResolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
|
||||
const devResolved = resolveMemoryBackendConfig({ cfg, agentId: "dev" });
|
||||
const mainNames = new Set(
|
||||
(mainResolved.qmd?.collections ?? []).map((collection) => collection.name),
|
||||
);
|
||||
const devNames = new Set(
|
||||
(devResolved.qmd?.collections ?? []).map((collection) => collection.name),
|
||||
);
|
||||
expect(mainNames.has("memory-dir-main")).toBe(true);
|
||||
expect(devNames.has("memory-dir-dev")).toBe(true);
|
||||
expect(mainNames.has("workspace-main")).toBe(true);
|
||||
expect(devNames.has("workspace-dev")).toBe(true);
|
||||
});
|
||||
|
||||
it("resolves qmd update timeout overrides", () => {
|
||||
const cfg = {
|
||||
agents: { defaults: { workspace: "/tmp/memory-test" } },
|
||||
memory: {
|
||||
backend: "qmd",
|
||||
qmd: {
|
||||
update: {
|
||||
waitForBootSync: true,
|
||||
commandTimeoutMs: 12_000,
|
||||
updateTimeoutMs: 480_000,
|
||||
embedTimeoutMs: 360_000,
|
||||
},
|
||||
},
|
||||
},
|
||||
} as OpenClawConfig;
|
||||
const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
|
||||
expect(resolved.qmd?.update.waitForBootSync).toBe(true);
|
||||
expect(resolved.qmd?.update.commandTimeoutMs).toBe(12_000);
|
||||
expect(resolved.qmd?.update.updateTimeoutMs).toBe(480_000);
|
||||
expect(resolved.qmd?.update.embedTimeoutMs).toBe(360_000);
|
||||
});
|
||||
|
||||
it("resolves qmd search mode override", () => {
|
||||
const cfg = {
|
||||
agents: { defaults: { workspace: "/tmp/memory-test" } },
|
||||
memory: {
|
||||
backend: "qmd",
|
||||
qmd: {
|
||||
searchMode: "vsearch",
|
||||
},
|
||||
},
|
||||
} as OpenClawConfig;
|
||||
const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" });
|
||||
expect(resolved.qmd?.searchMode).toBe("vsearch");
|
||||
});
|
||||
});
|
||||
@@ -1,354 +0,0 @@
|
||||
import path from "node:path";
|
||||
import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js";
|
||||
import { parseDurationMs } from "../../cli/parse-duration.js";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import type { SessionSendPolicyConfig } from "../../config/types.base.js";
|
||||
import type {
|
||||
MemoryBackend,
|
||||
MemoryCitationsMode,
|
||||
MemoryQmdConfig,
|
||||
MemoryQmdIndexPath,
|
||||
MemoryQmdMcporterConfig,
|
||||
MemoryQmdSearchMode,
|
||||
} from "../../config/types.memory.js";
|
||||
import { resolveUserPath } from "../../utils.js";
|
||||
import { splitShellArgs } from "../../utils/shell-argv.js";
|
||||
|
||||
export type ResolvedMemoryBackendConfig = {
|
||||
backend: MemoryBackend;
|
||||
citations: MemoryCitationsMode;
|
||||
qmd?: ResolvedQmdConfig;
|
||||
};
|
||||
|
||||
export type ResolvedQmdCollection = {
|
||||
name: string;
|
||||
path: string;
|
||||
pattern: string;
|
||||
kind: "memory" | "custom" | "sessions";
|
||||
};
|
||||
|
||||
export type ResolvedQmdUpdateConfig = {
|
||||
intervalMs: number;
|
||||
debounceMs: number;
|
||||
onBoot: boolean;
|
||||
waitForBootSync: boolean;
|
||||
embedIntervalMs: number;
|
||||
commandTimeoutMs: number;
|
||||
updateTimeoutMs: number;
|
||||
embedTimeoutMs: number;
|
||||
};
|
||||
|
||||
export type ResolvedQmdLimitsConfig = {
|
||||
maxResults: number;
|
||||
maxSnippetChars: number;
|
||||
maxInjectedChars: number;
|
||||
timeoutMs: number;
|
||||
};
|
||||
|
||||
export type ResolvedQmdSessionConfig = {
|
||||
enabled: boolean;
|
||||
exportDir?: string;
|
||||
retentionDays?: number;
|
||||
};
|
||||
|
||||
export type ResolvedQmdMcporterConfig = {
|
||||
enabled: boolean;
|
||||
serverName: string;
|
||||
startDaemon: boolean;
|
||||
};
|
||||
|
||||
export type ResolvedQmdConfig = {
|
||||
command: string;
|
||||
mcporter: ResolvedQmdMcporterConfig;
|
||||
searchMode: MemoryQmdSearchMode;
|
||||
collections: ResolvedQmdCollection[];
|
||||
sessions: ResolvedQmdSessionConfig;
|
||||
update: ResolvedQmdUpdateConfig;
|
||||
limits: ResolvedQmdLimitsConfig;
|
||||
includeDefaultMemory: boolean;
|
||||
scope?: SessionSendPolicyConfig;
|
||||
};
|
||||
|
||||
const DEFAULT_BACKEND: MemoryBackend = "builtin";
|
||||
const DEFAULT_CITATIONS: MemoryCitationsMode = "auto";
|
||||
const DEFAULT_QMD_INTERVAL = "5m";
|
||||
const DEFAULT_QMD_DEBOUNCE_MS = 15_000;
|
||||
const DEFAULT_QMD_TIMEOUT_MS = 4_000;
|
||||
// Defaulting to `query` can be extremely slow on CPU-only systems (query expansion + rerank).
|
||||
// Prefer a faster mode for interactive use; users can opt into `query` for best recall.
|
||||
const DEFAULT_QMD_SEARCH_MODE: MemoryQmdSearchMode = "search";
|
||||
const DEFAULT_QMD_EMBED_INTERVAL = "60m";
|
||||
const DEFAULT_QMD_COMMAND_TIMEOUT_MS = 30_000;
|
||||
const DEFAULT_QMD_UPDATE_TIMEOUT_MS = 120_000;
|
||||
const DEFAULT_QMD_EMBED_TIMEOUT_MS = 120_000;
|
||||
const DEFAULT_QMD_LIMITS: ResolvedQmdLimitsConfig = {
|
||||
maxResults: 6,
|
||||
maxSnippetChars: 700,
|
||||
maxInjectedChars: 4_000,
|
||||
timeoutMs: DEFAULT_QMD_TIMEOUT_MS,
|
||||
};
|
||||
const DEFAULT_QMD_MCPORTER: ResolvedQmdMcporterConfig = {
|
||||
enabled: false,
|
||||
serverName: "qmd",
|
||||
startDaemon: true,
|
||||
};
|
||||
|
||||
const DEFAULT_QMD_SCOPE: SessionSendPolicyConfig = {
|
||||
default: "deny",
|
||||
rules: [
|
||||
{
|
||||
action: "allow",
|
||||
match: { chatType: "direct" },
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
function sanitizeName(input: string): string {
|
||||
const lower = input.toLowerCase().replace(/[^a-z0-9-]+/g, "-");
|
||||
const trimmed = lower.replace(/^-+|-+$/g, "");
|
||||
return trimmed || "collection";
|
||||
}
|
||||
|
||||
function scopeCollectionBase(base: string, agentId: string): string {
|
||||
return `${base}-${sanitizeName(agentId)}`;
|
||||
}
|
||||
|
||||
function ensureUniqueName(base: string, existing: Set<string>): string {
|
||||
let name = sanitizeName(base);
|
||||
if (!existing.has(name)) {
|
||||
existing.add(name);
|
||||
return name;
|
||||
}
|
||||
let suffix = 2;
|
||||
while (existing.has(`${name}-${suffix}`)) {
|
||||
suffix += 1;
|
||||
}
|
||||
const unique = `${name}-${suffix}`;
|
||||
existing.add(unique);
|
||||
return unique;
|
||||
}
|
||||
|
||||
function resolvePath(raw: string, workspaceDir: string): string {
|
||||
const trimmed = raw.trim();
|
||||
if (!trimmed) {
|
||||
throw new Error("path required");
|
||||
}
|
||||
if (trimmed.startsWith("~") || path.isAbsolute(trimmed)) {
|
||||
return path.normalize(resolveUserPath(trimmed));
|
||||
}
|
||||
return path.normalize(path.resolve(workspaceDir, trimmed));
|
||||
}
|
||||
|
||||
function resolveIntervalMs(raw: string | undefined): number {
|
||||
const value = raw?.trim();
|
||||
if (!value) {
|
||||
return parseDurationMs(DEFAULT_QMD_INTERVAL, { defaultUnit: "m" });
|
||||
}
|
||||
try {
|
||||
return parseDurationMs(value, { defaultUnit: "m" });
|
||||
} catch {
|
||||
return parseDurationMs(DEFAULT_QMD_INTERVAL, { defaultUnit: "m" });
|
||||
}
|
||||
}
|
||||
|
||||
function resolveEmbedIntervalMs(raw: string | undefined): number {
|
||||
const value = raw?.trim();
|
||||
if (!value) {
|
||||
return parseDurationMs(DEFAULT_QMD_EMBED_INTERVAL, { defaultUnit: "m" });
|
||||
}
|
||||
try {
|
||||
return parseDurationMs(value, { defaultUnit: "m" });
|
||||
} catch {
|
||||
return parseDurationMs(DEFAULT_QMD_EMBED_INTERVAL, { defaultUnit: "m" });
|
||||
}
|
||||
}
|
||||
|
||||
function resolveDebounceMs(raw: number | undefined): number {
|
||||
if (typeof raw === "number" && Number.isFinite(raw) && raw >= 0) {
|
||||
return Math.floor(raw);
|
||||
}
|
||||
return DEFAULT_QMD_DEBOUNCE_MS;
|
||||
}
|
||||
|
||||
function resolveTimeoutMs(raw: number | undefined, fallback: number): number {
|
||||
if (typeof raw === "number" && Number.isFinite(raw) && raw > 0) {
|
||||
return Math.floor(raw);
|
||||
}
|
||||
return fallback;
|
||||
}
|
||||
|
||||
function resolveLimits(raw?: MemoryQmdConfig["limits"]): ResolvedQmdLimitsConfig {
|
||||
const parsed: ResolvedQmdLimitsConfig = { ...DEFAULT_QMD_LIMITS };
|
||||
if (raw?.maxResults && raw.maxResults > 0) {
|
||||
parsed.maxResults = Math.floor(raw.maxResults);
|
||||
}
|
||||
if (raw?.maxSnippetChars && raw.maxSnippetChars > 0) {
|
||||
parsed.maxSnippetChars = Math.floor(raw.maxSnippetChars);
|
||||
}
|
||||
if (raw?.maxInjectedChars && raw.maxInjectedChars > 0) {
|
||||
parsed.maxInjectedChars = Math.floor(raw.maxInjectedChars);
|
||||
}
|
||||
if (raw?.timeoutMs && raw.timeoutMs > 0) {
|
||||
parsed.timeoutMs = Math.floor(raw.timeoutMs);
|
||||
}
|
||||
return parsed;
|
||||
}
|
||||
|
||||
function resolveSearchMode(raw?: MemoryQmdConfig["searchMode"]): MemoryQmdSearchMode {
|
||||
if (raw === "search" || raw === "vsearch" || raw === "query") {
|
||||
return raw;
|
||||
}
|
||||
return DEFAULT_QMD_SEARCH_MODE;
|
||||
}
|
||||
|
||||
function resolveSessionConfig(
|
||||
cfg: MemoryQmdConfig["sessions"],
|
||||
workspaceDir: string,
|
||||
): ResolvedQmdSessionConfig {
|
||||
const enabled = Boolean(cfg?.enabled);
|
||||
const exportDirRaw = cfg?.exportDir?.trim();
|
||||
const exportDir = exportDirRaw ? resolvePath(exportDirRaw, workspaceDir) : undefined;
|
||||
const retentionDays =
|
||||
cfg?.retentionDays && cfg.retentionDays > 0 ? Math.floor(cfg.retentionDays) : undefined;
|
||||
return {
|
||||
enabled,
|
||||
exportDir,
|
||||
retentionDays,
|
||||
};
|
||||
}
|
||||
|
||||
function resolveCustomPaths(
|
||||
rawPaths: MemoryQmdIndexPath[] | undefined,
|
||||
workspaceDir: string,
|
||||
existing: Set<string>,
|
||||
agentId: string,
|
||||
): ResolvedQmdCollection[] {
|
||||
if (!rawPaths?.length) {
|
||||
return [];
|
||||
}
|
||||
const collections: ResolvedQmdCollection[] = [];
|
||||
rawPaths.forEach((entry, index) => {
|
||||
const trimmedPath = entry?.path?.trim();
|
||||
if (!trimmedPath) {
|
||||
return;
|
||||
}
|
||||
let resolved: string;
|
||||
try {
|
||||
resolved = resolvePath(trimmedPath, workspaceDir);
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
const pattern = entry.pattern?.trim() || "**/*.md";
|
||||
const baseName = scopeCollectionBase(entry.name?.trim() || `custom-${index + 1}`, agentId);
|
||||
const name = ensureUniqueName(baseName, existing);
|
||||
collections.push({
|
||||
name,
|
||||
path: resolved,
|
||||
pattern,
|
||||
kind: "custom",
|
||||
});
|
||||
});
|
||||
return collections;
|
||||
}
|
||||
|
||||
function resolveMcporterConfig(raw?: MemoryQmdMcporterConfig): ResolvedQmdMcporterConfig {
|
||||
const parsed: ResolvedQmdMcporterConfig = { ...DEFAULT_QMD_MCPORTER };
|
||||
if (!raw) {
|
||||
return parsed;
|
||||
}
|
||||
if (raw.enabled !== undefined) {
|
||||
parsed.enabled = raw.enabled;
|
||||
}
|
||||
if (typeof raw.serverName === "string" && raw.serverName.trim()) {
|
||||
parsed.serverName = raw.serverName.trim();
|
||||
}
|
||||
if (raw.startDaemon !== undefined) {
|
||||
parsed.startDaemon = raw.startDaemon;
|
||||
}
|
||||
// When enabled, default startDaemon to true.
|
||||
if (parsed.enabled && raw.startDaemon === undefined) {
|
||||
parsed.startDaemon = true;
|
||||
}
|
||||
return parsed;
|
||||
}
|
||||
|
||||
function resolveDefaultCollections(
|
||||
include: boolean,
|
||||
workspaceDir: string,
|
||||
existing: Set<string>,
|
||||
agentId: string,
|
||||
): ResolvedQmdCollection[] {
|
||||
if (!include) {
|
||||
return [];
|
||||
}
|
||||
const entries: Array<{ path: string; pattern: string; base: string }> = [
|
||||
{ path: workspaceDir, pattern: "MEMORY.md", base: "memory-root" },
|
||||
{ path: workspaceDir, pattern: "memory.md", base: "memory-alt" },
|
||||
{ path: path.join(workspaceDir, "memory"), pattern: "**/*.md", base: "memory-dir" },
|
||||
];
|
||||
return entries.map((entry) => ({
|
||||
name: ensureUniqueName(scopeCollectionBase(entry.base, agentId), existing),
|
||||
path: entry.path,
|
||||
pattern: entry.pattern,
|
||||
kind: "memory",
|
||||
}));
|
||||
}
|
||||
|
||||
export function resolveMemoryBackendConfig(params: {
|
||||
cfg: OpenClawConfig;
|
||||
agentId: string;
|
||||
}): ResolvedMemoryBackendConfig {
|
||||
const backend = params.cfg.memory?.backend ?? DEFAULT_BACKEND;
|
||||
const citations = params.cfg.memory?.citations ?? DEFAULT_CITATIONS;
|
||||
if (backend !== "qmd") {
|
||||
return { backend: "builtin", citations };
|
||||
}
|
||||
|
||||
const workspaceDir = resolveAgentWorkspaceDir(params.cfg, params.agentId);
|
||||
const qmdCfg = params.cfg.memory?.qmd;
|
||||
const includeDefaultMemory = qmdCfg?.includeDefaultMemory !== false;
|
||||
const nameSet = new Set<string>();
|
||||
const collections = [
|
||||
...resolveDefaultCollections(includeDefaultMemory, workspaceDir, nameSet, params.agentId),
|
||||
...resolveCustomPaths(qmdCfg?.paths, workspaceDir, nameSet, params.agentId),
|
||||
];
|
||||
|
||||
const rawCommand = qmdCfg?.command?.trim() || "qmd";
|
||||
const parsedCommand = splitShellArgs(rawCommand);
|
||||
const command = parsedCommand?.[0] || rawCommand.split(/\s+/)[0] || "qmd";
|
||||
const resolved: ResolvedQmdConfig = {
|
||||
command,
|
||||
mcporter: resolveMcporterConfig(qmdCfg?.mcporter),
|
||||
searchMode: resolveSearchMode(qmdCfg?.searchMode),
|
||||
collections,
|
||||
includeDefaultMemory,
|
||||
sessions: resolveSessionConfig(qmdCfg?.sessions, workspaceDir),
|
||||
update: {
|
||||
intervalMs: resolveIntervalMs(qmdCfg?.update?.interval),
|
||||
debounceMs: resolveDebounceMs(qmdCfg?.update?.debounceMs),
|
||||
onBoot: qmdCfg?.update?.onBoot !== false,
|
||||
waitForBootSync: qmdCfg?.update?.waitForBootSync === true,
|
||||
embedIntervalMs: resolveEmbedIntervalMs(qmdCfg?.update?.embedInterval),
|
||||
commandTimeoutMs: resolveTimeoutMs(
|
||||
qmdCfg?.update?.commandTimeoutMs,
|
||||
DEFAULT_QMD_COMMAND_TIMEOUT_MS,
|
||||
),
|
||||
updateTimeoutMs: resolveTimeoutMs(
|
||||
qmdCfg?.update?.updateTimeoutMs,
|
||||
DEFAULT_QMD_UPDATE_TIMEOUT_MS,
|
||||
),
|
||||
embedTimeoutMs: resolveTimeoutMs(
|
||||
qmdCfg?.update?.embedTimeoutMs,
|
||||
DEFAULT_QMD_EMBED_TIMEOUT_MS,
|
||||
),
|
||||
},
|
||||
limits: resolveLimits(qmdCfg?.limits),
|
||||
scope: qmdCfg?.scope ?? DEFAULT_QMD_SCOPE,
|
||||
};
|
||||
|
||||
return {
|
||||
backend: "qmd",
|
||||
citations,
|
||||
qmd: resolved,
|
||||
};
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
export { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js";
|
||||
export { postJsonWithRetry } from "./batch-http.js";
|
||||
export { applyEmbeddingBatchOutputLine } from "./batch-output.js";
|
||||
export {
|
||||
resolveBatchCompletionFromStatus,
|
||||
resolveCompletedBatchResult,
|
||||
throwIfBatchTerminalFailure,
|
||||
type BatchCompletionResult,
|
||||
} from "./batch-status.js";
|
||||
export {
|
||||
EMBEDDING_BATCH_ENDPOINT,
|
||||
type EmbeddingBatchStatus,
|
||||
type ProviderBatchOutputLine,
|
||||
} from "./batch-provider-common.js";
|
||||
export {
|
||||
buildEmbeddingBatchGroupOptions,
|
||||
runEmbeddingBatchGroups,
|
||||
type EmbeddingBatchExecutionParams,
|
||||
} from "./batch-runner.js";
|
||||
export { uploadBatchJsonlFile } from "./batch-upload.js";
|
||||
export { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
|
||||
export { withRemoteHttpResponse } from "./remote-http.js";
|
||||
@@ -1,32 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js";
|
||||
|
||||
describe("extractBatchErrorMessage", () => {
|
||||
it("returns the first top-level error message", () => {
|
||||
expect(
|
||||
extractBatchErrorMessage([
|
||||
{ response: { body: { error: { message: "nested" } } } },
|
||||
{ error: { message: "top-level" } },
|
||||
]),
|
||||
).toBe("nested");
|
||||
});
|
||||
|
||||
it("falls back to nested response error message", () => {
|
||||
expect(
|
||||
extractBatchErrorMessage([{ response: { body: { error: { message: "nested-only" } } } }, {}]),
|
||||
).toBe("nested-only");
|
||||
});
|
||||
|
||||
it("accepts plain string response bodies", () => {
|
||||
expect(extractBatchErrorMessage([{ response: { body: "provider plain-text error" } }])).toBe(
|
||||
"provider plain-text error",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatUnavailableBatchError", () => {
|
||||
it("formats errors and non-error values", () => {
|
||||
expect(formatUnavailableBatchError(new Error("boom"))).toBe("error file unavailable: boom");
|
||||
expect(formatUnavailableBatchError("unreachable")).toBe("error file unavailable: unreachable");
|
||||
});
|
||||
});
|
||||
@@ -1,31 +0,0 @@
|
||||
type BatchOutputErrorLike = {
|
||||
error?: { message?: string };
|
||||
response?: {
|
||||
body?:
|
||||
| string
|
||||
| {
|
||||
error?: { message?: string };
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
function getResponseErrorMessage(line: BatchOutputErrorLike | undefined): string | undefined {
|
||||
const body = line?.response?.body;
|
||||
if (typeof body === "string") {
|
||||
return body || undefined;
|
||||
}
|
||||
if (!body || typeof body !== "object") {
|
||||
return undefined;
|
||||
}
|
||||
return typeof body.error?.message === "string" ? body.error.message : undefined;
|
||||
}
|
||||
|
||||
export function extractBatchErrorMessage(lines: BatchOutputErrorLike[]): string | undefined {
|
||||
const first = lines.find((line) => line.error?.message || getResponseErrorMessage(line));
|
||||
return first?.error?.message ?? getResponseErrorMessage(first);
|
||||
}
|
||||
|
||||
export function formatUnavailableBatchError(err: unknown): string | undefined {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
return message ? `error file unavailable: ${message}` : undefined;
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
|
||||
vi.mock("./remote-http.js", () => ({
|
||||
withRemoteHttpResponse: vi.fn(),
|
||||
}));
|
||||
|
||||
function magnitude(values: number[]) {
|
||||
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
|
||||
}
|
||||
|
||||
describe("runGeminiEmbeddingBatches", () => {
|
||||
let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches;
|
||||
let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
|
||||
let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetModules();
|
||||
vi.clearAllMocks();
|
||||
({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js"));
|
||||
({ withRemoteHttpResponse } = await import("./remote-http.js"));
|
||||
remoteHttpMock = vi.mocked(withRemoteHttpResponse);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
const mockClient: GeminiEmbeddingClient = {
|
||||
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||
headers: {},
|
||||
model: "gemini-embedding-2-preview",
|
||||
modelPath: "models/gemini-embedding-2-preview",
|
||||
apiKeys: ["test-key"],
|
||||
outputDimensionality: 1536,
|
||||
};
|
||||
|
||||
it("includes outputDimensionality in batch upload requests", async () => {
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart");
|
||||
const body = params.init?.body;
|
||||
if (!(body instanceof Blob)) {
|
||||
throw new Error("expected multipart blob body");
|
||||
}
|
||||
const text = await body.text();
|
||||
expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"');
|
||||
expect(text).toContain('"outputDimensionality":1536');
|
||||
return await params.onResponse(
|
||||
new Response(JSON.stringify({ name: "files/file-123" }), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toMatch(/:asyncBatchEmbedContent$/u);
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
name: "batches/batch-1",
|
||||
state: "COMPLETED",
|
||||
outputConfig: { file: "files/output-1" },
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toMatch(/\/files\/output-1:download$/u);
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
key: "req-1",
|
||||
response: { embedding: { values: [3, 4] } },
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/jsonl" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
const results = await runGeminiEmbeddingBatches({
|
||||
gemini: mockClient,
|
||||
agentId: "main",
|
||||
requests: [
|
||||
{
|
||||
custom_id: "req-1",
|
||||
request: {
|
||||
content: { parts: [{ text: "hello world" }] },
|
||||
taskType: "RETRIEVAL_DOCUMENT",
|
||||
outputDimensionality: 1536,
|
||||
},
|
||||
},
|
||||
],
|
||||
wait: true,
|
||||
pollIntervalMs: 1,
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
});
|
||||
|
||||
const embedding = results.get("req-1");
|
||||
expect(embedding).toBeDefined();
|
||||
expect(embedding?.[0]).toBeCloseTo(0.6, 5);
|
||||
expect(embedding?.[1]).toBeCloseTo(0.8, 5);
|
||||
expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5);
|
||||
expect(remoteHttpMock).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
});
|
||||
@@ -1,368 +0,0 @@
|
||||
import {
|
||||
buildEmbeddingBatchGroupOptions,
|
||||
runEmbeddingBatchGroups,
|
||||
type EmbeddingBatchExecutionParams,
|
||||
} from "./batch-runner.js";
|
||||
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js";
|
||||
import { hashText } from "./internal.js";
|
||||
import { withRemoteHttpResponse } from "./remote-http.js";
|
||||
|
||||
export type GeminiBatchRequest = {
|
||||
custom_id: string;
|
||||
request: GeminiTextEmbeddingRequest;
|
||||
};
|
||||
|
||||
export type GeminiBatchStatus = {
|
||||
name?: string;
|
||||
state?: string;
|
||||
outputConfig?: { file?: string; fileId?: string };
|
||||
metadata?: {
|
||||
output?: {
|
||||
responsesFile?: string;
|
||||
};
|
||||
};
|
||||
error?: { message?: string };
|
||||
};
|
||||
|
||||
export type GeminiBatchOutputLine = {
|
||||
key?: string;
|
||||
custom_id?: string;
|
||||
request_id?: string;
|
||||
embedding?: { values?: number[] };
|
||||
response?: {
|
||||
embedding?: { values?: number[] };
|
||||
error?: { message?: string };
|
||||
};
|
||||
error?: { message?: string };
|
||||
};
|
||||
|
||||
const GEMINI_BATCH_MAX_REQUESTS = 50000;
|
||||
function getGeminiUploadUrl(baseUrl: string): string {
|
||||
if (baseUrl.includes("/v1beta")) {
|
||||
return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");
|
||||
}
|
||||
return `${baseUrl.replace(/\/$/, "")}/upload`;
|
||||
}
|
||||
|
||||
function buildGeminiUploadBody(params: { jsonl: string; displayName: string }): {
|
||||
body: Blob;
|
||||
contentType: string;
|
||||
} {
|
||||
const boundary = `openclaw-${hashText(params.displayName)}`;
|
||||
const jsonPart = JSON.stringify({
|
||||
file: {
|
||||
displayName: params.displayName,
|
||||
mimeType: "application/jsonl",
|
||||
},
|
||||
});
|
||||
const delimiter = `--${boundary}\r\n`;
|
||||
const closeDelimiter = `--${boundary}--\r\n`;
|
||||
const parts = [
|
||||
`${delimiter}Content-Type: application/json; charset=UTF-8\r\n\r\n${jsonPart}\r\n`,
|
||||
`${delimiter}Content-Type: application/jsonl; charset=UTF-8\r\n\r\n${params.jsonl}\r\n`,
|
||||
closeDelimiter,
|
||||
];
|
||||
const body = new Blob([parts.join("")], { type: "multipart/related" });
|
||||
return {
|
||||
body,
|
||||
contentType: `multipart/related; boundary=${boundary}`,
|
||||
};
|
||||
}
|
||||
|
||||
async function submitGeminiBatch(params: {
|
||||
gemini: GeminiEmbeddingClient;
|
||||
requests: GeminiBatchRequest[];
|
||||
agentId: string;
|
||||
}): Promise<GeminiBatchStatus> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.gemini);
|
||||
const jsonl = params.requests
|
||||
.map((request) =>
|
||||
JSON.stringify({
|
||||
key: request.custom_id,
|
||||
request: request.request,
|
||||
}),
|
||||
)
|
||||
.join("\n");
|
||||
const displayName = `memory-embeddings-${hashText(String(Date.now()))}`;
|
||||
const uploadPayload = buildGeminiUploadBody({ jsonl, displayName });
|
||||
|
||||
const uploadUrl = `${getGeminiUploadUrl(baseUrl)}/files?uploadType=multipart`;
|
||||
debugEmbeddingsLog("memory embeddings: gemini batch upload", {
|
||||
uploadUrl,
|
||||
baseUrl,
|
||||
requests: params.requests.length,
|
||||
});
|
||||
const filePayload = await withRemoteHttpResponse({
|
||||
url: uploadUrl,
|
||||
ssrfPolicy: params.gemini.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: {
|
||||
...buildBatchHeaders(params.gemini, { json: false }),
|
||||
"Content-Type": uploadPayload.contentType,
|
||||
},
|
||||
body: uploadPayload.body,
|
||||
},
|
||||
onResponse: async (fileRes) => {
|
||||
if (!fileRes.ok) {
|
||||
const text = await fileRes.text();
|
||||
throw new Error(`gemini batch file upload failed: ${fileRes.status} ${text}`);
|
||||
}
|
||||
return (await fileRes.json()) as { name?: string; file?: { name?: string } };
|
||||
},
|
||||
});
|
||||
const fileId = filePayload.name ?? filePayload.file?.name;
|
||||
if (!fileId) {
|
||||
throw new Error("gemini batch file upload failed: missing file id");
|
||||
}
|
||||
|
||||
const batchBody = {
|
||||
batch: {
|
||||
displayName: `memory-embeddings-${params.agentId}`,
|
||||
inputConfig: {
|
||||
file_name: fileId,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const batchEndpoint = `${baseUrl}/${params.gemini.modelPath}:asyncBatchEmbedContent`;
|
||||
debugEmbeddingsLog("memory embeddings: gemini batch create", {
|
||||
batchEndpoint,
|
||||
fileId,
|
||||
});
|
||||
return await withRemoteHttpResponse({
|
||||
url: batchEndpoint,
|
||||
ssrfPolicy: params.gemini.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: buildBatchHeaders(params.gemini, { json: true }),
|
||||
body: JSON.stringify(batchBody),
|
||||
},
|
||||
onResponse: async (batchRes) => {
|
||||
if (batchRes.ok) {
|
||||
return (await batchRes.json()) as GeminiBatchStatus;
|
||||
}
|
||||
const text = await batchRes.text();
|
||||
if (batchRes.status === 404) {
|
||||
throw new Error(
|
||||
"gemini batch create failed: 404 (asyncBatchEmbedContent not available for this model/baseUrl). Disable remote.batch.enabled or switch providers.",
|
||||
);
|
||||
}
|
||||
throw new Error(`gemini batch create failed: ${batchRes.status} ${text}`);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchGeminiBatchStatus(params: {
|
||||
gemini: GeminiEmbeddingClient;
|
||||
batchName: string;
|
||||
}): Promise<GeminiBatchStatus> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.gemini);
|
||||
const name = params.batchName.startsWith("batches/")
|
||||
? params.batchName
|
||||
: `batches/${params.batchName}`;
|
||||
const statusUrl = `${baseUrl}/${name}`;
|
||||
debugEmbeddingsLog("memory embeddings: gemini batch status", { statusUrl });
|
||||
return await withRemoteHttpResponse({
|
||||
url: statusUrl,
|
||||
ssrfPolicy: params.gemini.ssrfPolicy,
|
||||
init: {
|
||||
headers: buildBatchHeaders(params.gemini, { json: true }),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`gemini batch status failed: ${res.status} ${text}`);
|
||||
}
|
||||
return (await res.json()) as GeminiBatchStatus;
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchGeminiFileContent(params: {
|
||||
gemini: GeminiEmbeddingClient;
|
||||
fileId: string;
|
||||
}): Promise<string> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.gemini);
|
||||
const file = params.fileId.startsWith("files/") ? params.fileId : `files/${params.fileId}`;
|
||||
const downloadUrl = `${baseUrl}/${file}:download`;
|
||||
debugEmbeddingsLog("memory embeddings: gemini batch download", { downloadUrl });
|
||||
return await withRemoteHttpResponse({
|
||||
url: downloadUrl,
|
||||
ssrfPolicy: params.gemini.ssrfPolicy,
|
||||
init: {
|
||||
headers: buildBatchHeaders(params.gemini, { json: true }),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`gemini batch file content failed: ${res.status} ${text}`);
|
||||
}
|
||||
return await res.text();
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function parseGeminiBatchOutput(text: string): GeminiBatchOutputLine[] {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
return text
|
||||
.split("\n")
|
||||
.map((line) => line.trim())
|
||||
.filter(Boolean)
|
||||
.map((line) => JSON.parse(line) as GeminiBatchOutputLine);
|
||||
}
|
||||
|
||||
/**
 * Poll a Gemini batch until it reaches a terminal state and return its
 * output file id.
 *
 * Throws when the batch fails/cancels/expires, when `wait` is false while the
 * batch is still running, or once `timeoutMs` has elapsed. `initial` lets the
 * caller seed the first iteration with an already-fetched status so the
 * create response is not immediately re-fetched.
 */
async function waitForGeminiBatch(params: {
  gemini: GeminiEmbeddingClient;
  batchName: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: GeminiBatchStatus;
}): Promise<{ outputFileId: string }> {
  const start = Date.now();
  let current: GeminiBatchStatus | undefined = params.initial;
  while (true) {
    const status =
      current ??
      (await fetchGeminiBatchStatus({
        gemini: params.gemini,
        batchName: params.batchName,
      }));
    const state = status.state ?? "UNKNOWN";
    if (["SUCCEEDED", "COMPLETED", "DONE"].includes(state)) {
      // The output file id is looked up in several locations — presumably
      // to cover multiple API response shapes; confirm against Gemini docs.
      const outputFileId =
        status.outputConfig?.file ??
        status.outputConfig?.fileId ??
        status.metadata?.output?.responsesFile;
      if (!outputFileId) {
        throw new Error(`gemini batch ${params.batchName} completed without output file`);
      }
      return { outputFileId };
    }
    // Terminal failure states (both spellings of "cancelled" are handled).
    if (["FAILED", "CANCELLED", "CANCELED", "EXPIRED"].includes(state)) {
      const message = status.error?.message ?? "unknown error";
      throw new Error(`gemini batch ${params.batchName} ${state}: ${message}`);
    }
    if (!params.wait) {
      throw new Error(`gemini batch ${params.batchName} still ${state}; wait disabled`);
    }
    if (Date.now() - start > params.timeoutMs) {
      throw new Error(`gemini batch ${params.batchName} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`gemini batch ${params.batchName} ${state}; waiting ${params.pollIntervalMs}ms`);
    await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
    // Force a fresh status fetch on the next iteration.
    current = undefined;
  }
}
|
||||
|
||||
/**
 * Run embedding requests through Gemini's async batch API.
 *
 * Requests are split into groups of at most GEMINI_BATCH_MAX_REQUESTS; each
 * group is uploaded, submitted, polled to completion, and its JSONL output
 * decoded into the shared custom_id -> embedding map that is returned.
 * Throws if any line fails or if any request is missing from the output.
 */
export async function runGeminiEmbeddingBatches(
  params: {
    gemini: GeminiEmbeddingClient;
    agentId: string;
    requests: GeminiBatchRequest[];
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: GEMINI_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: gemini batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      const batchInfo = await submitGeminiBatch({
        gemini: params.gemini,
        requests: group,
        agentId: params.agentId,
      });
      const batchName = batchInfo.name ?? "";
      if (!batchName) {
        throw new Error("gemini batch create failed: missing batch name");
      }

      params.debug?.("memory embeddings: gemini batch created", {
        batchName,
        state: batchInfo.state,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });

      // Fail fast when waiting is disabled and the batch is not already done.
      if (
        !params.wait &&
        batchInfo.state &&
        !["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state)
      ) {
        throw new Error(
          `gemini batch ${batchName} submitted; enable remote.batch.wait to await completion`,
        );
      }

      // If the create response already reports completion, read the output
      // file id from it directly; otherwise poll until terminal.
      const completed =
        batchInfo.state && ["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state)
          ? {
              outputFileId:
                batchInfo.outputConfig?.file ??
                batchInfo.outputConfig?.fileId ??
                batchInfo.metadata?.output?.responsesFile ??
                "",
            }
          : await waitForGeminiBatch({
              gemini: params.gemini,
              batchName,
              wait: params.wait,
              pollIntervalMs: params.pollIntervalMs,
              timeoutMs: params.timeoutMs,
              debug: params.debug,
              initial: batchInfo,
            });
      if (!completed.outputFileId) {
        throw new Error(`gemini batch ${batchName} completed without output file`);
      }

      const content = await fetchGeminiFileContent({
        gemini: params.gemini,
        fileId: completed.outputFileId,
      });
      const outputLines = parseGeminiBatchOutput(content);
      const errors: string[] = [];
      // Track which request ids have not yet appeared in the output.
      const remaining = new Set(group.map((request) => request.custom_id));

      for (const line of outputLines) {
        // Correlation id may be reported under several keys.
        const customId = line.key ?? line.custom_id ?? line.request_id;
        if (!customId) {
          continue;
        }
        remaining.delete(customId);
        if (line.error?.message) {
          errors.push(`${customId}: ${line.error.message}`);
          continue;
        }
        if (line.response?.error?.message) {
          errors.push(`${customId}: ${line.response.error.message}`);
          continue;
        }
        const embedding = sanitizeAndNormalizeEmbedding(
          line.embedding?.values ?? line.response?.embedding?.values ?? [],
        );
        if (embedding.length === 0) {
          errors.push(`${customId}: empty embedding`);
          continue;
        }
        byCustomId.set(customId, embedding);
      }

      if (errors.length > 0) {
        throw new Error(`gemini batch ${batchName} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(`gemini batch ${batchName} missing ${remaining.size} embedding responses`);
      }
    },
  });
}
|
||||
@@ -1,85 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
vi.mock("../../infra/retry.js", () => ({
|
||||
retryAsync: vi.fn(async (run: () => Promise<unknown>) => await run()),
|
||||
}));
|
||||
|
||||
vi.mock("./post-json.js", () => ({
|
||||
postJson: vi.fn(),
|
||||
}));
|
||||
|
||||
describe("postJsonWithRetry", () => {
|
||||
let retryAsyncMock: ReturnType<
|
||||
typeof vi.mocked<typeof import("../../infra/retry.js").retryAsync>
|
||||
>;
|
||||
let postJsonMock: ReturnType<typeof vi.mocked<typeof import("./post-json.js").postJson>>;
|
||||
let postJsonWithRetry: typeof import("./batch-http.js").postJsonWithRetry;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetModules();
|
||||
vi.clearAllMocks();
|
||||
vi.resetModules();
|
||||
({ postJsonWithRetry } = await import("./batch-http.js"));
|
||||
const retryModule = await import("../../infra/retry.js");
|
||||
const postJsonModule = await import("./post-json.js");
|
||||
retryAsyncMock = vi.mocked(retryModule.retryAsync);
|
||||
postJsonMock = vi.mocked(postJsonModule.postJson);
|
||||
});
|
||||
|
||||
it("posts JSON and returns parsed response payload", async () => {
|
||||
postJsonMock.mockImplementationOnce(async (params) => {
|
||||
return await params.parse({ ok: true, ids: [1, 2] });
|
||||
});
|
||||
|
||||
const result = await postJsonWithRetry<{ ok: boolean; ids: number[] }>({
|
||||
url: "https://memory.example/v1/batch",
|
||||
headers: { Authorization: "Bearer test" },
|
||||
body: { chunks: ["a", "b"] },
|
||||
errorPrefix: "memory batch failed",
|
||||
});
|
||||
|
||||
expect(result).toEqual({ ok: true, ids: [1, 2] });
|
||||
expect(postJsonMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
url: "https://memory.example/v1/batch",
|
||||
headers: { Authorization: "Bearer test" },
|
||||
body: { chunks: ["a", "b"] },
|
||||
errorPrefix: "memory batch failed",
|
||||
attachStatus: true,
|
||||
}),
|
||||
);
|
||||
|
||||
const retryOptions = retryAsyncMock.mock.calls[0]?.[1] as
|
||||
| {
|
||||
attempts: number;
|
||||
minDelayMs: number;
|
||||
maxDelayMs: number;
|
||||
shouldRetry: (err: unknown) => boolean;
|
||||
}
|
||||
| undefined;
|
||||
expect(retryOptions?.attempts).toBe(3);
|
||||
expect(retryOptions?.minDelayMs).toBe(300);
|
||||
expect(retryOptions?.maxDelayMs).toBe(2000);
|
||||
expect(retryOptions?.shouldRetry({ status: 429 })).toBe(true);
|
||||
expect(retryOptions?.shouldRetry({ status: 503 })).toBe(true);
|
||||
expect(retryOptions?.shouldRetry({ status: 400 })).toBe(false);
|
||||
});
|
||||
|
||||
it("attaches status to non-ok errors", async () => {
|
||||
postJsonMock.mockRejectedValueOnce(
|
||||
Object.assign(new Error("memory batch failed: 503 backend down"), { status: 503 }),
|
||||
);
|
||||
|
||||
await expect(
|
||||
postJsonWithRetry({
|
||||
url: "https://memory.example/v1/batch",
|
||||
headers: {},
|
||||
body: { chunks: [] },
|
||||
errorPrefix: "memory batch failed",
|
||||
}),
|
||||
).rejects.toMatchObject({
|
||||
message: expect.stringContaining("memory batch failed: 503 backend down"),
|
||||
status: 503,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,35 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { retryAsync } from "../../infra/retry.js";
|
||||
import { postJson } from "./post-json.js";
|
||||
|
||||
export async function postJsonWithRetry<T>(params: {
|
||||
url: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
body: unknown;
|
||||
errorPrefix: string;
|
||||
}): Promise<T> {
|
||||
return await retryAsync(
|
||||
async () => {
|
||||
return await postJson<T>({
|
||||
url: params.url,
|
||||
headers: params.headers,
|
||||
ssrfPolicy: params.ssrfPolicy,
|
||||
body: params.body,
|
||||
errorPrefix: params.errorPrefix,
|
||||
attachStatus: true,
|
||||
parse: async (payload) => payload as T,
|
||||
});
|
||||
},
|
||||
{
|
||||
attempts: 3,
|
||||
minDelayMs: 300,
|
||||
maxDelayMs: 2000,
|
||||
jitter: 0.2,
|
||||
shouldRetry: (err) => {
|
||||
const status = (err as { status?: number }).status;
|
||||
return status === 429 || (typeof status === "number" && status >= 500);
|
||||
},
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -1,259 +0,0 @@
|
||||
import {
|
||||
applyEmbeddingBatchOutputLine,
|
||||
buildBatchHeaders,
|
||||
buildEmbeddingBatchGroupOptions,
|
||||
EMBEDDING_BATCH_ENDPOINT,
|
||||
extractBatchErrorMessage,
|
||||
formatUnavailableBatchError,
|
||||
normalizeBatchBaseUrl,
|
||||
postJsonWithRetry,
|
||||
resolveBatchCompletionFromStatus,
|
||||
resolveCompletedBatchResult,
|
||||
runEmbeddingBatchGroups,
|
||||
throwIfBatchTerminalFailure,
|
||||
type EmbeddingBatchExecutionParams,
|
||||
type EmbeddingBatchStatus,
|
||||
type BatchCompletionResult,
|
||||
type ProviderBatchOutputLine,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
} from "./batch-embedding-common.js";
|
||||
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
|
||||
// One line of the JSONL input file for OpenAI's /v1/batches API: each entry
// describes a single embeddings call, keyed by custom_id for correlation.
export type OpenAiBatchRequest = {
  custom_id: string;
  method: "POST";
  url: "/v1/embeddings";
  body: {
    model: string;
    input: string;
  };
};

// OpenAI batches reuse the shared embedding-batch status/output shapes.
export type OpenAiBatchStatus = EmbeddingBatchStatus;
export type OpenAiBatchOutputLine = ProviderBatchOutputLine;

export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
// Completion window requested at batch creation.
const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
// Per-batch request cap — presumably mirrors OpenAI's documented batch
// limit; confirm against the Batch API reference before changing.
const OPENAI_BATCH_MAX_REQUESTS = 50000;
|
||||
|
||||
/**
 * Upload the request JSONL to the OpenAI Files API, then create a batch job
 * over it and return the initial batch status.
 */
async function submitOpenAiBatch(params: {
  openAi: OpenAiEmbeddingClient;
  requests: OpenAiBatchRequest[];
  agentId: string;
}): Promise<OpenAiBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.openAi);
  const inputFileId = await uploadBatchJsonlFile({
    client: params.openAi,
    requests: params.requests,
    errorPrefix: "openai batch file upload failed",
  });

  return await postJsonWithRetry<OpenAiBatchStatus>({
    url: `${baseUrl}/batches`,
    headers: buildBatchHeaders(params.openAi, { json: true }),
    ssrfPolicy: params.openAi.ssrfPolicy,
    body: {
      input_file_id: inputFileId,
      endpoint: OPENAI_BATCH_ENDPOINT,
      completion_window: OPENAI_BATCH_COMPLETION_WINDOW,
      // Metadata tags the batch for later identification in the dashboard.
      metadata: {
        source: "openclaw-memory",
        agent: params.agentId,
      },
    },
    errorPrefix: "openai batch create failed",
  });
}
|
||||
|
||||
async function fetchOpenAiBatchStatus(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
batchId: string;
|
||||
}): Promise<OpenAiBatchStatus> {
|
||||
return await fetchOpenAiBatchResource({
|
||||
openAi: params.openAi,
|
||||
path: `/batches/${params.batchId}`,
|
||||
errorPrefix: "openai batch status",
|
||||
parse: async (res) => (await res.json()) as OpenAiBatchStatus,
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchOpenAiFileContent(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
fileId: string;
|
||||
}): Promise<string> {
|
||||
return await fetchOpenAiBatchResource({
|
||||
openAi: params.openAi,
|
||||
path: `/files/${params.fileId}/content`,
|
||||
errorPrefix: "openai batch file content",
|
||||
parse: async (res) => await res.text(),
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchOpenAiBatchResource<T>(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
path: string;
|
||||
errorPrefix: string;
|
||||
parse: (res: Response) => Promise<T>;
|
||||
}): Promise<T> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.openAi);
|
||||
return await withRemoteHttpResponse({
|
||||
url: `${baseUrl}${params.path}`,
|
||||
ssrfPolicy: params.openAi.ssrfPolicy,
|
||||
init: {
|
||||
headers: buildBatchHeaders(params.openAi, { json: true }),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`${params.errorPrefix} failed: ${res.status} ${text}`);
|
||||
}
|
||||
return await params.parse(res);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
|
||||
if (!text.trim()) {
|
||||
return [];
|
||||
}
|
||||
return text
|
||||
.split("\n")
|
||||
.map((line) => line.trim())
|
||||
.filter(Boolean)
|
||||
.map((line) => JSON.parse(line) as OpenAiBatchOutputLine);
|
||||
}
|
||||
|
||||
async function readOpenAiBatchError(params: {
|
||||
openAi: OpenAiEmbeddingClient;
|
||||
errorFileId: string;
|
||||
}): Promise<string | undefined> {
|
||||
try {
|
||||
const content = await fetchOpenAiFileContent({
|
||||
openAi: params.openAi,
|
||||
fileId: params.errorFileId,
|
||||
});
|
||||
const lines = parseOpenAiBatchOutput(content);
|
||||
return extractBatchErrorMessage(lines);
|
||||
} catch (err) {
|
||||
return formatUnavailableBatchError(err);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Poll an OpenAI batch until it completes and return its output/error file
 * ids. Terminal failure states raise (with the error file's first message
 * when readable); `wait: false` and `timeoutMs` also abort with errors.
 * `initial` seeds the first iteration with an already-fetched status.
 */
async function waitForOpenAiBatch(params: {
  openAi: OpenAiEmbeddingClient;
  batchId: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: OpenAiBatchStatus;
}): Promise<BatchCompletionResult> {
  const start = Date.now();
  let current: OpenAiBatchStatus | undefined = params.initial;
  while (true) {
    const status =
      current ??
      (await fetchOpenAiBatchStatus({
        openAi: params.openAi,
        batchId: params.batchId,
      }));
    const state = status.status ?? "unknown";
    if (state === "completed") {
      return resolveBatchCompletionFromStatus({
        provider: "openai",
        batchId: params.batchId,
        status,
      });
    }
    // Raise on failed/expired/cancelled states, enriching the message with
    // the error file's content when one is available.
    await throwIfBatchTerminalFailure({
      provider: "openai",
      status: { ...status, id: params.batchId },
      readError: async (errorFileId) =>
        await readOpenAiBatchError({
          openAi: params.openAi,
          errorFileId,
        }),
    });
    if (!params.wait) {
      throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`);
    }
    if (Date.now() - start > params.timeoutMs) {
      throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
    await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
    // Force a fresh status fetch on the next iteration.
    current = undefined;
  }
}
|
||||
|
||||
/**
 * Run embedding requests through OpenAI's Batch API.
 *
 * Requests are split into groups of at most OPENAI_BATCH_MAX_REQUESTS; each
 * group is uploaded, submitted, awaited, and its JSONL output folded into the
 * returned custom_id -> embedding map. Throws if any output line fails or if
 * any request id is missing from the output.
 */
export async function runOpenAiEmbeddingBatches(
  params: {
    openAi: OpenAiEmbeddingClient;
    agentId: string;
    requests: OpenAiBatchRequest[];
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: OPENAI_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: openai batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      const batchInfo = await submitOpenAiBatch({
        openAi: params.openAi,
        requests: group,
        agentId: params.agentId,
      });
      if (!batchInfo.id) {
        throw new Error("openai batch create failed: missing batch id");
      }
      const batchId = batchInfo.id;

      params.debug?.("memory embeddings: openai batch created", {
        batchId: batchInfo.id,
        status: batchInfo.status,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });

      // Either the create response already reports completion or we poll
      // the batch to a terminal state (shared wait/timeout semantics).
      const completed = await resolveCompletedBatchResult({
        provider: "openai",
        status: batchInfo,
        wait: params.wait,
        waitForBatch: async () =>
          await waitForOpenAiBatch({
            openAi: params.openAi,
            batchId,
            wait: params.wait,
            pollIntervalMs: params.pollIntervalMs,
            timeoutMs: params.timeoutMs,
            debug: params.debug,
            initial: batchInfo,
          }),
      });

      const content = await fetchOpenAiFileContent({
        openAi: params.openAi,
        fileId: completed.outputFileId,
      });
      const outputLines = parseOpenAiBatchOutput(content);
      const errors: string[] = [];
      // Track which request ids have not yet appeared in the output.
      const remaining = new Set(group.map((request) => request.custom_id));

      for (const line of outputLines) {
        applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
      }

      if (errors.length > 0) {
        throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(
          `openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
        );
      }
    },
  });
}
|
||||
@@ -1,82 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { applyEmbeddingBatchOutputLine } from "./batch-output.js";
|
||||
|
||||
// Unit tests for applyEmbeddingBatchOutputLine: success stores embeddings,
// failures are collected as "<custom_id>: <reason>" strings, and every
// recognized line is removed from the `remaining` set.
describe("applyEmbeddingBatchOutputLine", () => {
  it("stores embedding for successful response", () => {
    const remaining = new Set(["req-1"]);
    const errors: string[] = [];
    const byCustomId = new Map<string, number[]>();

    applyEmbeddingBatchOutputLine({
      line: {
        custom_id: "req-1",
        response: {
          status_code: 200,
          body: { data: [{ embedding: [0.1, 0.2] }] },
        },
      },
      remaining,
      errors,
      byCustomId,
    });

    expect(remaining.has("req-1")).toBe(false);
    expect(errors).toEqual([]);
    expect(byCustomId.get("req-1")).toEqual([0.1, 0.2]);
  });

  it("records provider error from line.error", () => {
    const remaining = new Set(["req-2"]);
    const errors: string[] = [];
    const byCustomId = new Map<string, number[]>();

    applyEmbeddingBatchOutputLine({
      line: {
        custom_id: "req-2",
        error: { message: "provider failed" },
      },
      remaining,
      errors,
      byCustomId,
    });

    expect(remaining.has("req-2")).toBe(false);
    expect(errors).toEqual(["req-2: provider failed"]);
    expect(byCustomId.size).toBe(0);
  });

  it("records non-2xx response errors and empty embedding errors", () => {
    const remaining = new Set(["req-3", "req-4"]);
    const errors: string[] = [];
    const byCustomId = new Map<string, number[]>();

    // HTTP-level failure carried in the response body.
    applyEmbeddingBatchOutputLine({
      line: {
        custom_id: "req-3",
        response: {
          status_code: 500,
          body: { error: { message: "internal" } },
        },
      },
      remaining,
      errors,
      byCustomId,
    });

    // 200 response whose data array is empty.
    applyEmbeddingBatchOutputLine({
      line: {
        custom_id: "req-4",
        response: {
          status_code: 200,
          body: { data: [] },
        },
      },
      remaining,
      errors,
      byCustomId,
    });

    expect(errors).toEqual(["req-3: internal", "req-4: empty embedding"]);
    expect(byCustomId.size).toBe(0);
  });
});
|
||||
@@ -1,55 +0,0 @@
|
||||
// Shape of one JSONL line in a provider's batch output file. `body` may be a
// structured embeddings payload or a raw error string.
export type EmbeddingBatchOutputLine = {
  // Correlation id copied from the submitted request.
  custom_id?: string;
  // Line-level (transport) failure for this request.
  error?: { message?: string };
  response?: {
    status_code?: number;
    body?:
      | {
          data?: Array<{
            embedding?: number[];
          }>;
          error?: { message?: string };
        }
      | string;
  };
};
|
||||
|
||||
export function applyEmbeddingBatchOutputLine(params: {
|
||||
line: EmbeddingBatchOutputLine;
|
||||
remaining: Set<string>;
|
||||
errors: string[];
|
||||
byCustomId: Map<string, number[]>;
|
||||
}) {
|
||||
const customId = params.line.custom_id;
|
||||
if (!customId) {
|
||||
return;
|
||||
}
|
||||
params.remaining.delete(customId);
|
||||
|
||||
const errorMessage = params.line.error?.message;
|
||||
if (errorMessage) {
|
||||
params.errors.push(`${customId}: ${errorMessage}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const response = params.line.response;
|
||||
const statusCode = response?.status_code ?? 0;
|
||||
if (statusCode >= 400) {
|
||||
const messageFromObject =
|
||||
response?.body && typeof response.body === "object"
|
||||
? (response.body as { error?: { message?: string } }).error?.message
|
||||
: undefined;
|
||||
const messageFromString = typeof response?.body === "string" ? response.body : undefined;
|
||||
params.errors.push(`${customId}: ${messageFromObject ?? messageFromString ?? "unknown error"}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const data =
|
||||
response?.body && typeof response.body === "object" ? (response.body.data ?? []) : [];
|
||||
const embedding = data[0]?.embedding ?? [];
|
||||
if (embedding.length === 0) {
|
||||
params.errors.push(`${customId}: empty embedding`);
|
||||
return;
|
||||
}
|
||||
params.byCustomId.set(customId, embedding);
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
import type { EmbeddingBatchOutputLine } from "./batch-output.js";
|
||||
|
||||
// Common status shape returned by OpenAI-compatible batch APIs.
export type EmbeddingBatchStatus = {
  id?: string;
  status?: string;
  output_file_id?: string | null;
  error_file_id?: string | null;
};

export type ProviderBatchOutputLine = EmbeddingBatchOutputLine;

// Endpoint targeted by every batch request line.
export const EMBEDDING_BATCH_ENDPOINT = "/v1/embeddings";
|
||||
@@ -1,64 +0,0 @@
|
||||
import { splitBatchRequests } from "./batch-utils.js";
|
||||
import { runWithConcurrency } from "./internal.js";
|
||||
|
||||
// Execution knobs shared by every provider batch runner.
export type EmbeddingBatchExecutionParams = {
  wait: boolean; // poll until completion (vs fail fast after submit)
  pollIntervalMs: number; // delay between status polls
  timeoutMs: number; // overall per-batch deadline
  concurrency: number; // batch groups processed in parallel
  debug?: (message: string, data?: Record<string, unknown>) => void;
};
|
||||
|
||||
/**
 * Split requests into groups of at most `maxRequests`, run `runGroup` over
 * them with bounded concurrency, and return the shared custom_id -> embedding
 * map every group writes into. Returns an empty map for an empty request
 * list without invoking any helper.
 */
export async function runEmbeddingBatchGroups<TRequest>(params: {
  requests: TRequest[];
  maxRequests: number;
  wait: EmbeddingBatchExecutionParams["wait"];
  pollIntervalMs: EmbeddingBatchExecutionParams["pollIntervalMs"];
  timeoutMs: EmbeddingBatchExecutionParams["timeoutMs"];
  concurrency: EmbeddingBatchExecutionParams["concurrency"];
  debugLabel: string;
  debug?: EmbeddingBatchExecutionParams["debug"];
  runGroup: (args: {
    group: TRequest[];
    groupIndex: number;
    groups: number;
    byCustomId: Map<string, number[]>;
  }) => Promise<void>;
}): Promise<Map<string, number[]>> {
  if (params.requests.length === 0) {
    return new Map();
  }
  const groups = splitBatchRequests(params.requests, params.maxRequests);
  // All groups write into the same map; custom_ids are unique across groups.
  const byCustomId = new Map<string, number[]>();
  const tasks = groups.map((group, groupIndex) => async () => {
    await params.runGroup({ group, groupIndex, groups: groups.length, byCustomId });
  });

  params.debug?.(params.debugLabel, {
    requests: params.requests.length,
    groups: groups.length,
    wait: params.wait,
    concurrency: params.concurrency,
    pollIntervalMs: params.pollIntervalMs,
    timeoutMs: params.timeoutMs,
  });

  await runWithConcurrency(tasks, params.concurrency);
  return byCustomId;
}
|
||||
|
||||
export function buildEmbeddingBatchGroupOptions<TRequest>(
|
||||
params: { requests: TRequest[] } & EmbeddingBatchExecutionParams,
|
||||
options: { maxRequests: number; debugLabel: string },
|
||||
) {
|
||||
return {
|
||||
requests: params.requests,
|
||||
maxRequests: options.maxRequests,
|
||||
wait: params.wait,
|
||||
pollIntervalMs: params.pollIntervalMs,
|
||||
timeoutMs: params.timeoutMs,
|
||||
concurrency: params.concurrency,
|
||||
debug: params.debug,
|
||||
debugLabel: options.debugLabel,
|
||||
};
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
resolveBatchCompletionFromStatus,
|
||||
resolveCompletedBatchResult,
|
||||
throwIfBatchTerminalFailure,
|
||||
} from "./batch-status.js";
|
||||
|
||||
// Unit tests for the provider-agnostic batch-status helpers.
describe("batch-status helpers", () => {
  it("resolves completion payload from completed status", () => {
    expect(
      resolveBatchCompletionFromStatus({
        provider: "openai",
        batchId: "b1",
        status: {
          output_file_id: "out-1",
          error_file_id: "err-1",
        },
      }),
    ).toEqual({
      outputFileId: "out-1",
      errorFileId: "err-1",
    });
  });

  it("throws for terminal failure states", async () => {
    await expect(
      throwIfBatchTerminalFailure({
        provider: "voyage",
        status: { id: "b2", status: "failed", error_file_id: "err-file" },
        readError: async () => "bad input",
      }),
    ).rejects.toThrow("voyage batch b2 failed: bad input");
  });

  it("returns completed result directly without waiting", async () => {
    // waitForBatch must not be needed when status is already "completed".
    const waitForBatch = async () => ({ outputFileId: "out-2" });
    const result = await resolveCompletedBatchResult({
      provider: "openai",
      status: {
        id: "b3",
        status: "completed",
        output_file_id: "out-3",
      },
      wait: false,
      waitForBatch,
    });
    expect(result).toEqual({ outputFileId: "out-3", errorFileId: undefined });
  });

  it("throws when wait disabled and batch is not complete", async () => {
    await expect(
      resolveCompletedBatchResult({
        provider: "openai",
        status: { id: "b4", status: "pending" },
        wait: false,
        waitForBatch: async () => ({ outputFileId: "out" }),
      }),
    ).rejects.toThrow("openai batch b4 submitted; enable remote.batch.wait to await completion");
  });
});
|
||||
@@ -1,69 +0,0 @@
|
||||
// Batch states that will never progress to success; includes both spellings
// of "cancelled" seen across providers.
const TERMINAL_FAILURE_STATES = new Set(["failed", "expired", "cancelled", "canceled"]);

// Structural subset of a provider batch status used by these helpers.
type BatchStatusLike = {
  id?: string;
  status?: string;
  output_file_id?: string | null;
  error_file_id?: string | null;
};

// File ids of a completed batch's output (and optional error) artifacts.
export type BatchCompletionResult = {
  outputFileId: string;
  errorFileId?: string;
};
|
||||
|
||||
export function resolveBatchCompletionFromStatus(params: {
|
||||
provider: string;
|
||||
batchId: string;
|
||||
status: BatchStatusLike;
|
||||
}): BatchCompletionResult {
|
||||
if (!params.status.output_file_id) {
|
||||
throw new Error(`${params.provider} batch ${params.batchId} completed without output file`);
|
||||
}
|
||||
return {
|
||||
outputFileId: params.status.output_file_id,
|
||||
errorFileId: params.status.error_file_id ?? undefined,
|
||||
};
|
||||
}
|
||||
|
||||
export async function throwIfBatchTerminalFailure(params: {
|
||||
provider: string;
|
||||
status: BatchStatusLike;
|
||||
readError: (errorFileId: string) => Promise<string | undefined>;
|
||||
}): Promise<void> {
|
||||
const state = params.status.status ?? "unknown";
|
||||
if (!TERMINAL_FAILURE_STATES.has(state)) {
|
||||
return;
|
||||
}
|
||||
const detail = params.status.error_file_id
|
||||
? await params.readError(params.status.error_file_id)
|
||||
: undefined;
|
||||
const suffix = detail ? `: ${detail}` : "";
|
||||
throw new Error(`${params.provider} batch ${params.status.id ?? "<unknown>"} ${state}${suffix}`);
|
||||
}
|
||||
|
||||
export async function resolveCompletedBatchResult(params: {
|
||||
provider: string;
|
||||
status: BatchStatusLike;
|
||||
wait: boolean;
|
||||
waitForBatch: () => Promise<BatchCompletionResult>;
|
||||
}): Promise<BatchCompletionResult> {
|
||||
const batchId = params.status.id ?? "<unknown>";
|
||||
if (!params.wait && params.status.status !== "completed") {
|
||||
throw new Error(
|
||||
`${params.provider} batch ${batchId} submitted; enable remote.batch.wait to await completion`,
|
||||
);
|
||||
}
|
||||
const completed =
|
||||
params.status.status === "completed"
|
||||
? resolveBatchCompletionFromStatus({
|
||||
provider: params.provider,
|
||||
batchId,
|
||||
status: params.status,
|
||||
})
|
||||
: await params.waitForBatch();
|
||||
if (!completed.outputFileId) {
|
||||
throw new Error(`${params.provider} batch ${batchId} completed without output file`);
|
||||
}
|
||||
return completed;
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
import {
|
||||
buildBatchHeaders,
|
||||
normalizeBatchBaseUrl,
|
||||
type BatchHttpClientConfig,
|
||||
} from "./batch-utils.js";
|
||||
import { hashText } from "./internal.js";
|
||||
import { withRemoteHttpResponse } from "./remote-http.js";
|
||||
|
||||
/**
 * Serialize batch requests to JSONL and upload them to the provider's
 * `/files` endpoint with purpose "batch". Returns the uploaded file's id.
 * Throws (prefixed with `errorPrefix`) on a non-2xx response or when the
 * response carries no file id.
 */
export async function uploadBatchJsonlFile(params: {
  client: BatchHttpClientConfig;
  requests: unknown[];
  errorPrefix: string;
}): Promise<string> {
  const baseUrl = normalizeBatchBaseUrl(params.client);
  const jsonl = params.requests.map((request) => JSON.stringify(request)).join("\n");
  const form = new FormData();
  form.append("purpose", "batch");
  // Filename includes a hash of the current time — presumably only to keep
  // upload names unique; confirm nothing downstream parses it.
  form.append(
    "file",
    new Blob([jsonl], { type: "application/jsonl" }),
    `memory-embeddings.${hashText(String(Date.now()))}.jsonl`,
  );

  const filePayload = await withRemoteHttpResponse({
    url: `${baseUrl}/files`,
    ssrfPolicy: params.client.ssrfPolicy,
    init: {
      method: "POST",
      // json: false drops any Content-Type so fetch sets the multipart
      // boundary itself.
      headers: buildBatchHeaders(params.client, { json: false }),
      body: form,
    },
    onResponse: async (fileRes) => {
      if (!fileRes.ok) {
        const text = await fileRes.text();
        throw new Error(`${params.errorPrefix}: ${fileRes.status} ${text}`);
      }
      return (await fileRes.json()) as { id?: string };
    },
  });
  if (!filePayload.id) {
    throw new Error(`${params.errorPrefix}: missing file id`);
  }
  return filePayload.id;
}
|
||||
@@ -1,38 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
|
||||
/** Minimal HTTP client settings shared by provider batch endpoints. */
export type BatchHttpClientConfig = {
  /** API base URL; a trailing slash is stripped by `normalizeBatchBaseUrl`. */
  baseUrl?: string;
  /** Extra request headers (e.g. authorization). */
  headers?: Record<string, string>;
  /** Optional SSRF policy forwarded to the remote HTTP helper. */
  ssrfPolicy?: SsrFPolicy;
};
|
||||
|
||||
export function normalizeBatchBaseUrl(client: BatchHttpClientConfig): string {
|
||||
return client.baseUrl?.replace(/\/$/, "") ?? "";
|
||||
}
|
||||
|
||||
export function buildBatchHeaders(
|
||||
client: Pick<BatchHttpClientConfig, "headers">,
|
||||
params: { json: boolean },
|
||||
): Record<string, string> {
|
||||
const headers = client.headers ? { ...client.headers } : {};
|
||||
if (params.json) {
|
||||
if (!headers["Content-Type"] && !headers["content-type"]) {
|
||||
headers["Content-Type"] = "application/json";
|
||||
}
|
||||
} else {
|
||||
delete headers["Content-Type"];
|
||||
delete headers["content-type"];
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
export function splitBatchRequests<T>(requests: T[], maxRequests: number): T[][] {
|
||||
if (requests.length <= maxRequests) {
|
||||
return [requests];
|
||||
}
|
||||
const groups: T[][] = [];
|
||||
for (let i = 0; i < requests.length; i += maxRequests) {
|
||||
groups.push(requests.slice(i, i + maxRequests));
|
||||
}
|
||||
return groups;
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
import { ReadableStream } from "node:stream/web";
|
||||
import { setTimeout as nativeSleep } from "node:timers/promises";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
runVoyageEmbeddingBatches,
|
||||
type VoyageBatchOutputLine,
|
||||
type VoyageBatchRequest,
|
||||
} from "./batch-voyage.js";
|
||||
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
|
||||
const realNow = Date.now.bind(Date);
|
||||
|
||||
describe("runVoyageEmbeddingBatches", () => {
|
||||
const mockClient: VoyageEmbeddingClient = {
|
||||
baseUrl: "https://api.voyageai.com/v1",
|
||||
headers: { Authorization: "Bearer test-key" },
|
||||
model: "voyage-4-large",
|
||||
};
|
||||
|
||||
const mockRequests: VoyageBatchRequest[] = [
|
||||
{ custom_id: "req-1", body: { input: "text1" } },
|
||||
{ custom_id: "req-2", body: { input: "text2" } },
|
||||
];
|
||||
|
||||
it("successfully submits batch, waits, and streams results", async () => {
|
||||
const outputLines: VoyageBatchOutputLine[] = [
|
||||
{
|
||||
custom_id: "req-1",
|
||||
response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
|
||||
},
|
||||
{
|
||||
custom_id: "req-2",
|
||||
response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
|
||||
},
|
||||
];
|
||||
const withRemoteHttpResponse = vi.fn();
|
||||
const postJsonWithRetry = vi.fn();
|
||||
const uploadBatchJsonlFile = vi.fn();
|
||||
|
||||
// Create a stream that emits the NDJSON lines
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
|
||||
controller.enqueue(new TextEncoder().encode(text));
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
|
||||
expect(params.errorPrefix).toBe("voyage batch file upload failed");
|
||||
expect(params.requests).toEqual(mockRequests);
|
||||
return "file-123";
|
||||
});
|
||||
postJsonWithRetry.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches");
|
||||
expect(params.body).toMatchObject({
|
||||
input_file_id: "file-123",
|
||||
completion_window: "12h",
|
||||
request_params: {
|
||||
model: "voyage-4-large",
|
||||
input_type: "document",
|
||||
},
|
||||
});
|
||||
return {
|
||||
id: "batch-abc",
|
||||
status: "pending",
|
||||
};
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/batches/batch-abc");
|
||||
return await params.onResponse(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
id: "batch-abc",
|
||||
status: "completed",
|
||||
output_file_id: "file-out-999",
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files/file-out-999/content");
|
||||
return await params.onResponse(
|
||||
new Response(stream as unknown as BodyInit, {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/x-ndjson" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
const results = await runVoyageEmbeddingBatches({
|
||||
client: mockClient,
|
||||
agentId: "agent-1",
|
||||
requests: mockRequests,
|
||||
wait: true,
|
||||
pollIntervalMs: 1, // fast poll
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
deps: {
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
postJsonWithRetry,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
},
|
||||
});
|
||||
|
||||
expect(results.size).toBe(2);
|
||||
expect(results.get("req-1")).toEqual([0.1, 0.1]);
|
||||
expect(results.get("req-2")).toEqual([0.2, 0.2]);
|
||||
expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
|
||||
expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
|
||||
expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("handles empty lines and stream chunks correctly", async () => {
|
||||
const withRemoteHttpResponse = vi.fn();
|
||||
const postJsonWithRetry = vi.fn();
|
||||
const uploadBatchJsonlFile = vi.fn();
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
const line1 = JSON.stringify({
|
||||
custom_id: "req-1",
|
||||
response: { body: { data: [{ embedding: [1] }] } },
|
||||
});
|
||||
const line2 = JSON.stringify({
|
||||
custom_id: "req-2",
|
||||
response: { body: { data: [{ embedding: [2] }] } },
|
||||
});
|
||||
|
||||
// Split across chunks
|
||||
controller.enqueue(new TextEncoder().encode(line1 + "\n"));
|
||||
controller.enqueue(new TextEncoder().encode("\n")); // empty line
|
||||
controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
uploadBatchJsonlFile.mockResolvedValueOnce("f1");
|
||||
postJsonWithRetry.mockResolvedValueOnce({
|
||||
id: "b1",
|
||||
status: "completed",
|
||||
output_file_id: "out1",
|
||||
});
|
||||
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
|
||||
expect(params.url).toContain("/files/out1/content");
|
||||
return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
|
||||
});
|
||||
|
||||
const results = await runVoyageEmbeddingBatches({
|
||||
client: mockClient,
|
||||
agentId: "a1",
|
||||
requests: mockRequests,
|
||||
wait: true,
|
||||
pollIntervalMs: 1,
|
||||
timeoutMs: 1000,
|
||||
concurrency: 1,
|
||||
deps: {
|
||||
now: realNow,
|
||||
sleep: async (ms) => {
|
||||
await nativeSleep(ms);
|
||||
},
|
||||
postJsonWithRetry,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
},
|
||||
});
|
||||
|
||||
expect(results.get("req-1")).toEqual([1]);
|
||||
expect(results.get("req-2")).toEqual([2]);
|
||||
});
|
||||
});
|
||||
@@ -1,315 +0,0 @@
|
||||
import { createInterface } from "node:readline";
|
||||
import { Readable } from "node:stream";
|
||||
import {
|
||||
applyEmbeddingBatchOutputLine,
|
||||
buildBatchHeaders,
|
||||
buildEmbeddingBatchGroupOptions,
|
||||
EMBEDDING_BATCH_ENDPOINT,
|
||||
extractBatchErrorMessage,
|
||||
formatUnavailableBatchError,
|
||||
normalizeBatchBaseUrl,
|
||||
postJsonWithRetry,
|
||||
resolveBatchCompletionFromStatus,
|
||||
resolveCompletedBatchResult,
|
||||
runEmbeddingBatchGroups,
|
||||
throwIfBatchTerminalFailure,
|
||||
type EmbeddingBatchExecutionParams,
|
||||
type EmbeddingBatchStatus,
|
||||
type BatchCompletionResult,
|
||||
type ProviderBatchOutputLine,
|
||||
uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse,
|
||||
} from "./batch-embedding-common.js";
|
||||
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
|
||||
/**
 * Voyage Batch API input line format: one JSONL line per embedding request.
 * See: https://docs.voyageai.com/docs/batch-inference
 */
export type VoyageBatchRequest = {
  /** Caller-chosen id used to correlate output lines with requests. */
  custom_id: string;
  body: {
    /** Text (or texts) to embed. */
    input: string | string[];
  };
};
|
||||
|
||||
// Voyage batch status/output lines reuse the provider-agnostic shapes.
export type VoyageBatchStatus = EmbeddingBatchStatus;
export type VoyageBatchOutputLine = ProviderBatchOutputLine;

export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
// Completion window requested when creating a batch job.
const VOYAGE_BATCH_COMPLETION_WINDOW = "12h";
// Upper bound on requests per submitted batch; larger inputs are split into groups.
const VOYAGE_BATCH_MAX_REQUESTS = 50000;
|
||||
|
||||
/** Injectable dependencies for Voyage batch calls (overridable in tests). */
type VoyageBatchDeps = {
  /** Clock used for timeout accounting. */
  now: () => number;
  /** Delay helper used between status polls. */
  sleep: (ms: number) => Promise<void>;
  postJsonWithRetry: typeof postJsonWithRetry;
  uploadBatchJsonlFile: typeof uploadBatchJsonlFile;
  withRemoteHttpResponse: typeof withRemoteHttpResponse;
};
|
||||
|
||||
function resolveVoyageBatchDeps(overrides: Partial<VoyageBatchDeps> | undefined): VoyageBatchDeps {
|
||||
return {
|
||||
now: overrides?.now ?? Date.now,
|
||||
sleep:
|
||||
overrides?.sleep ??
|
||||
(async (ms: number) => await new Promise((resolve) => setTimeout(resolve, ms))),
|
||||
postJsonWithRetry: overrides?.postJsonWithRetry ?? postJsonWithRetry,
|
||||
uploadBatchJsonlFile: overrides?.uploadBatchJsonlFile ?? uploadBatchJsonlFile,
|
||||
withRemoteHttpResponse: overrides?.withRemoteHttpResponse ?? withRemoteHttpResponse,
|
||||
};
|
||||
}
|
||||
|
||||
async function assertVoyageResponseOk(res: Response, context: string): Promise<void> {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`${context}: ${res.status} ${text}`);
|
||||
}
|
||||
}
|
||||
|
||||
function buildVoyageBatchRequest<T>(params: {
|
||||
client: VoyageEmbeddingClient;
|
||||
path: string;
|
||||
onResponse: (res: Response) => Promise<T>;
|
||||
}) {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.client);
|
||||
return {
|
||||
url: `${baseUrl}/${params.path}`,
|
||||
ssrfPolicy: params.client.ssrfPolicy,
|
||||
init: {
|
||||
headers: buildBatchHeaders(params.client, { json: true }),
|
||||
},
|
||||
onResponse: params.onResponse,
|
||||
};
|
||||
}
|
||||
|
||||
async function submitVoyageBatch(params: {
|
||||
client: VoyageEmbeddingClient;
|
||||
requests: VoyageBatchRequest[];
|
||||
agentId: string;
|
||||
deps: VoyageBatchDeps;
|
||||
}): Promise<VoyageBatchStatus> {
|
||||
const baseUrl = normalizeBatchBaseUrl(params.client);
|
||||
const inputFileId = await params.deps.uploadBatchJsonlFile({
|
||||
client: params.client,
|
||||
requests: params.requests,
|
||||
errorPrefix: "voyage batch file upload failed",
|
||||
});
|
||||
|
||||
// 2. Create batch job using Voyage Batches API
|
||||
return await params.deps.postJsonWithRetry<VoyageBatchStatus>({
|
||||
url: `${baseUrl}/batches`,
|
||||
headers: buildBatchHeaders(params.client, { json: true }),
|
||||
ssrfPolicy: params.client.ssrfPolicy,
|
||||
body: {
|
||||
input_file_id: inputFileId,
|
||||
endpoint: VOYAGE_BATCH_ENDPOINT,
|
||||
completion_window: VOYAGE_BATCH_COMPLETION_WINDOW,
|
||||
request_params: {
|
||||
model: params.client.model,
|
||||
input_type: "document",
|
||||
},
|
||||
metadata: {
|
||||
source: "clawdbot-memory",
|
||||
agent: params.agentId,
|
||||
},
|
||||
},
|
||||
errorPrefix: "voyage batch create failed",
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchVoyageBatchStatus(params: {
|
||||
client: VoyageEmbeddingClient;
|
||||
batchId: string;
|
||||
deps: VoyageBatchDeps;
|
||||
}): Promise<VoyageBatchStatus> {
|
||||
return await params.deps.withRemoteHttpResponse(
|
||||
buildVoyageBatchRequest({
|
||||
client: params.client,
|
||||
path: `batches/${params.batchId}`,
|
||||
onResponse: async (res) => {
|
||||
await assertVoyageResponseOk(res, "voyage batch status failed");
|
||||
return (await res.json()) as VoyageBatchStatus;
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
async function readVoyageBatchError(params: {
|
||||
client: VoyageEmbeddingClient;
|
||||
errorFileId: string;
|
||||
deps: VoyageBatchDeps;
|
||||
}): Promise<string | undefined> {
|
||||
try {
|
||||
return await params.deps.withRemoteHttpResponse(
|
||||
buildVoyageBatchRequest({
|
||||
client: params.client,
|
||||
path: `files/${params.errorFileId}/content`,
|
||||
onResponse: async (res) => {
|
||||
await assertVoyageResponseOk(res, "voyage batch error file content failed");
|
||||
const text = await res.text();
|
||||
if (!text.trim()) {
|
||||
return undefined;
|
||||
}
|
||||
const lines = text
|
||||
.split("\n")
|
||||
.map((line) => line.trim())
|
||||
.filter(Boolean)
|
||||
.map((line) => JSON.parse(line) as VoyageBatchOutputLine);
|
||||
return extractBatchErrorMessage(lines);
|
||||
},
|
||||
}),
|
||||
);
|
||||
} catch (err) {
|
||||
return formatUnavailableBatchError(err);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Poll a Voyage batch until it completes, fails terminally, times out, or
 * waiting is disabled. Returns the completion result on success.
 *
 * `initial` (when provided) is consumed as the first observed status so the
 * caller's submit response is not re-fetched on the first iteration.
 */
async function waitForVoyageBatch(params: {
  client: VoyageEmbeddingClient;
  batchId: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: VoyageBatchStatus;
  deps: VoyageBatchDeps;
}): Promise<BatchCompletionResult> {
  const start = params.deps.now();
  let current: VoyageBatchStatus | undefined = params.initial;
  while (true) {
    // Use the cached status on the first pass, otherwise re-fetch.
    const status =
      current ??
      (await fetchVoyageBatchStatus({
        client: params.client,
        batchId: params.batchId,
        deps: params.deps,
      }));
    const state = status.status ?? "unknown";
    if (state === "completed") {
      return resolveBatchCompletionFromStatus({
        provider: "voyage",
        batchId: params.batchId,
        status,
      });
    }
    // Terminal failure states throw here, attaching error-file detail when available.
    await throwIfBatchTerminalFailure({
      provider: "voyage",
      status: { ...status, id: params.batchId },
      readError: async (errorFileId) =>
        await readVoyageBatchError({
          client: params.client,
          errorFileId,
          deps: params.deps,
        }),
    });
    if (!params.wait) {
      throw new Error(`voyage batch ${params.batchId} still ${state}; wait disabled`);
    }
    if (params.deps.now() - start > params.timeoutMs) {
      throw new Error(`voyage batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`voyage batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
    await params.deps.sleep(params.pollIntervalMs);
    current = undefined; // force a fresh status fetch on the next iteration
  }
}
|
||||
|
||||
/**
 * Execute embedding requests through the Voyage Batch API.
 *
 * Requests are split into provider-size groups; each group is uploaded as a
 * JSONL file, submitted as a batch job, awaited (subject to `wait` /
 * `timeoutMs`), and the NDJSON output file is streamed line-by-line into the
 * per-custom_id result map. Throws if any request errored or produced no
 * output line.
 */
export async function runVoyageEmbeddingBatches(
  params: {
    client: VoyageEmbeddingClient;
    agentId: string;
    requests: VoyageBatchRequest[];
    deps?: Partial<VoyageBatchDeps>;
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  const deps = resolveVoyageBatchDeps(params.deps);
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: voyage batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      // 1. Upload the group's JSONL and create the batch job.
      const batchInfo = await submitVoyageBatch({
        client: params.client,
        requests: group,
        agentId: params.agentId,
        deps,
      });
      if (!batchInfo.id) {
        throw new Error("voyage batch create failed: missing batch id");
      }
      const batchId = batchInfo.id;

      params.debug?.("memory embeddings: voyage batch created", {
        batchId: batchInfo.id,
        status: batchInfo.status,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });

      // 2. Wait for completion (or fail fast when waiting is disabled).
      const completed = await resolveCompletedBatchResult({
        provider: "voyage",
        status: batchInfo,
        wait: params.wait,
        waitForBatch: async () =>
          await waitForVoyageBatch({
            client: params.client,
            batchId,
            wait: params.wait,
            pollIntervalMs: params.pollIntervalMs,
            timeoutMs: params.timeoutMs,
            debug: params.debug,
            initial: batchInfo,
            deps,
          }),
      });

      // 3. Stream the NDJSON output file, collecting embeddings per custom_id.
      const baseUrl = normalizeBatchBaseUrl(params.client);
      const errors: string[] = [];
      // Requests that have not yet produced an output line.
      const remaining = new Set(group.map((request) => request.custom_id));

      await deps.withRemoteHttpResponse({
        url: `${baseUrl}/files/${completed.outputFileId}/content`,
        ssrfPolicy: params.client.ssrfPolicy,
        init: {
          headers: buildBatchHeaders(params.client, { json: true }),
        },
        onResponse: async (contentRes) => {
          if (!contentRes.ok) {
            const text = await contentRes.text();
            throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`);
          }

          if (!contentRes.body) {
            return;
          }
          // readline handles lines split across chunks and a missing final newline.
          const reader = createInterface({
            input: Readable.fromWeb(
              contentRes.body as unknown as import("stream/web").ReadableStream,
            ),
            terminal: false,
          });

          for await (const rawLine of reader) {
            if (!rawLine.trim()) {
              continue;
            }
            const line = JSON.parse(rawLine) as VoyageBatchOutputLine;
            applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
          }
        },
      });

      if (errors.length > 0) {
        throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(
          `voyage batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
        );
      }
    },
  });
}
|
||||
@@ -1,102 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { enforceEmbeddingMaxInputTokens } from "./embedding-chunk-limits.js";
|
||||
import { estimateUtf8Bytes } from "./embedding-input-limits.js";
|
||||
import type { EmbeddingProvider } from "./embeddings.js";
|
||||
|
||||
// Build a minimal stub EmbeddingProvider with an explicit input-size limit.
function createProvider(maxInputTokens: number): EmbeddingProvider {
  return {
    id: "mock",
    model: "mock-embed",
    maxInputTokens,
    // The embed calls are not exercised by these tests; return fixed vectors.
    embedQuery: async () => [0],
    embedBatch: async () => [[0]],
  };
}
|
||||
|
||||
// Build a stub provider with NO declared maxInputTokens, so the code under
// test must fall back to its per-provider conservative defaults.
function createProviderWithoutMaxInputTokens(params: {
  id: string;
  model: string;
}): EmbeddingProvider {
  return {
    id: params.id,
    model: params.model,
    embedQuery: async () => [0],
    embedBatch: async () => [[0]],
  };
}
|
||||
|
||||
// Covers enforceEmbeddingMaxInputTokens: splitting, surrogate safety,
// fallback limits, and hard caps.
describe("embedding chunk limits", () => {
  it("splits oversized chunks so each embedding input stays <= maxInputTokens bytes", () => {
    const provider = createProvider(8192);
    const input = {
      startLine: 1,
      endLine: 1,
      text: "x".repeat(9000),
      hash: "ignored",
    };

    const out = enforceEmbeddingMaxInputTokens(provider, [input]);
    expect(out.length).toBeGreaterThan(1);
    // Concatenated pieces must reproduce the original text exactly.
    expect(out.map((chunk) => chunk.text).join("")).toBe(input.text);
    expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 8192)).toBe(true);
    // Line span is preserved on every piece; hashes are freshly computed.
    expect(out.every((chunk) => chunk.startLine === 1 && chunk.endLine === 1)).toBe(true);
    expect(out.every((chunk) => typeof chunk.hash === "string" && chunk.hash.length > 0)).toBe(
      true,
    );
  });

  it("does not split inside surrogate pairs (emoji)", () => {
    const provider = createProvider(8192);
    const emoji = "😀";
    const inputText = `${emoji.repeat(2100)}\n${emoji.repeat(2100)}`;

    const out = enforceEmbeddingMaxInputTokens(provider, [
      { startLine: 1, endLine: 2, text: inputText, hash: "ignored" },
    ]);

    expect(out.length).toBeGreaterThan(1);
    expect(out.map((chunk) => chunk.text).join("")).toBe(inputText);
    expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 8192)).toBe(true);

    // If we split inside surrogate pairs we'd likely end up with replacement chars.
    expect(out.map((chunk) => chunk.text).join("")).not.toContain("\uFFFD");
  });

  it("uses conservative fallback limits for local providers without declared maxInputTokens", () => {
    const provider = createProviderWithoutMaxInputTokens({
      id: "local",
      model: "unknown-local-embedding",
    });

    const out = enforceEmbeddingMaxInputTokens(provider, [
      {
        startLine: 1,
        endLine: 1,
        text: "x".repeat(3000),
        hash: "ignored",
      },
    ]);

    // 2048 is the local-provider fallback limit.
    expect(out.length).toBeGreaterThan(1);
    expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 2048)).toBe(true);
  });

  it("honors hard safety caps lower than provider maxInputTokens", () => {
    const provider = createProvider(8192);
    const out = enforceEmbeddingMaxInputTokens(
      provider,
      [
        {
          startLine: 1,
          endLine: 1,
          text: "x".repeat(8100),
          hash: "ignored",
        },
      ],
      8000,
    );

    expect(out.length).toBeGreaterThan(1);
    expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 8000)).toBe(true);
  });
});
|
||||
@@ -1,41 +0,0 @@
|
||||
import { estimateUtf8Bytes, splitTextToUtf8ByteLimit } from "./embedding-input-limits.js";
|
||||
import { hasNonTextEmbeddingParts } from "./embedding-inputs.js";
|
||||
import { resolveEmbeddingMaxInputTokens } from "./embedding-model-limits.js";
|
||||
import type { EmbeddingProvider } from "./embeddings.js";
|
||||
import { hashText, type MemoryChunk } from "./internal.js";
|
||||
|
||||
export function enforceEmbeddingMaxInputTokens(
|
||||
provider: EmbeddingProvider,
|
||||
chunks: MemoryChunk[],
|
||||
hardMaxInputTokens?: number,
|
||||
): MemoryChunk[] {
|
||||
const providerMaxInputTokens = resolveEmbeddingMaxInputTokens(provider);
|
||||
const maxInputTokens =
|
||||
typeof hardMaxInputTokens === "number" && hardMaxInputTokens > 0
|
||||
? Math.min(providerMaxInputTokens, hardMaxInputTokens)
|
||||
: providerMaxInputTokens;
|
||||
const out: MemoryChunk[] = [];
|
||||
|
||||
for (const chunk of chunks) {
|
||||
if (hasNonTextEmbeddingParts(chunk.embeddingInput)) {
|
||||
out.push(chunk);
|
||||
continue;
|
||||
}
|
||||
if (estimateUtf8Bytes(chunk.text) <= maxInputTokens) {
|
||||
out.push(chunk);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const text of splitTextToUtf8ByteLimit(chunk.text, maxInputTokens)) {
|
||||
out.push({
|
||||
startLine: chunk.startLine,
|
||||
endLine: chunk.endLine,
|
||||
text,
|
||||
hash: hashText(text),
|
||||
embeddingInput: { text },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
|
||||
// Helpers for enforcing embedding model input size limits.
|
||||
//
|
||||
// We use UTF-8 byte length as a conservative upper bound for tokenizer output.
|
||||
// Tokenizers operate over bytes; a token must contain at least one byte, so
|
||||
// token_count <= utf8_byte_length.
|
||||
|
||||
export function estimateUtf8Bytes(text: string): number {
|
||||
if (!text) {
|
||||
return 0;
|
||||
}
|
||||
return Buffer.byteLength(text, "utf8");
|
||||
}
|
||||
|
||||
export function estimateStructuredEmbeddingInputBytes(input: EmbeddingInput): number {
|
||||
if (!input.parts?.length) {
|
||||
return estimateUtf8Bytes(input.text);
|
||||
}
|
||||
let total = 0;
|
||||
for (const part of input.parts) {
|
||||
if (part.type === "text") {
|
||||
total += estimateUtf8Bytes(part.text);
|
||||
continue;
|
||||
}
|
||||
total += estimateUtf8Bytes(part.mimeType);
|
||||
total += estimateUtf8Bytes(part.data);
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
export function splitTextToUtf8ByteLimit(text: string, maxUtf8Bytes: number): string[] {
|
||||
if (maxUtf8Bytes <= 0) {
|
||||
return [text];
|
||||
}
|
||||
if (estimateUtf8Bytes(text) <= maxUtf8Bytes) {
|
||||
return [text];
|
||||
}
|
||||
|
||||
const parts: string[] = [];
|
||||
let cursor = 0;
|
||||
while (cursor < text.length) {
|
||||
// The number of UTF-16 code units is always <= the number of UTF-8 bytes.
|
||||
// This makes `cursor + maxUtf8Bytes` a safe upper bound on the next split point.
|
||||
let low = cursor + 1;
|
||||
let high = Math.min(text.length, cursor + maxUtf8Bytes);
|
||||
let best = cursor;
|
||||
|
||||
while (low <= high) {
|
||||
const mid = Math.floor((low + high) / 2);
|
||||
const bytes = estimateUtf8Bytes(text.slice(cursor, mid));
|
||||
if (bytes <= maxUtf8Bytes) {
|
||||
best = mid;
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid - 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (best <= cursor) {
|
||||
best = Math.min(text.length, cursor + 1);
|
||||
}
|
||||
|
||||
// Avoid splitting inside a surrogate pair.
|
||||
if (
|
||||
best < text.length &&
|
||||
best > cursor &&
|
||||
text.charCodeAt(best - 1) >= 0xd800 &&
|
||||
text.charCodeAt(best - 1) <= 0xdbff &&
|
||||
text.charCodeAt(best) >= 0xdc00 &&
|
||||
text.charCodeAt(best) <= 0xdfff
|
||||
) {
|
||||
best -= 1;
|
||||
}
|
||||
|
||||
const part = text.slice(cursor, best);
|
||||
if (!part) {
|
||||
break;
|
||||
}
|
||||
parts.push(part);
|
||||
cursor = best;
|
||||
}
|
||||
|
||||
return parts;
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
/** Plain-text segment of an embedding input. */
export type EmbeddingInputTextPart = {
  type: "text";
  text: string;
};

/** Non-text segment carried as string data plus its MIME type. */
export type EmbeddingInputInlineDataPart = {
  type: "inline-data";
  /** MIME type describing `data`. */
  mimeType: string;
  /** Encoded payload; encoding is provider-defined — presumably base64, confirm against producers. */
  data: string;
};

export type EmbeddingInputPart = EmbeddingInputTextPart | EmbeddingInputInlineDataPart;

/** Embedding input: flat text plus an optional structured multi-part form. */
export type EmbeddingInput = {
  text: string;
  parts?: EmbeddingInputPart[];
};
|
||||
|
||||
export function buildTextEmbeddingInput(text: string): EmbeddingInput {
|
||||
return { text };
|
||||
}
|
||||
|
||||
export function isInlineDataEmbeddingInputPart(
|
||||
part: EmbeddingInputPart,
|
||||
): part is EmbeddingInputInlineDataPart {
|
||||
return part.type === "inline-data";
|
||||
}
|
||||
|
||||
export function hasNonTextEmbeddingParts(input: EmbeddingInput | undefined): boolean {
|
||||
if (!input?.parts?.length) {
|
||||
return false;
|
||||
}
|
||||
return input.parts.some((part) => isInlineDataEmbeddingInputPart(part));
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
import type { EmbeddingProvider } from "./embeddings.js";
|
||||
|
||||
// Fallback limit when neither the provider nor the table below knows better.
const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192;
// Conservative default for "local" embedding providers.
const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048;

// Best-effort per-model input limits, keyed by lowercased "provider:model".
const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record<string, number> = {
  "openai:text-embedding-3-small": 8192,
  "openai:text-embedding-3-large": 8192,
  "openai:text-embedding-ada-002": 8191,
  "gemini:text-embedding-004": 2048,
  "gemini:gemini-embedding-001": 2048,
  "gemini:gemini-embedding-2-preview": 8192,
  "voyage:voyage-3": 32000,
  "voyage:voyage-3-lite": 16000,
  "voyage:voyage-code-3": 32000,
};
|
||||
|
||||
export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number {
|
||||
if (typeof provider.maxInputTokens === "number") {
|
||||
return provider.maxInputTokens;
|
||||
}
|
||||
|
||||
// Provider/model mapping is best-effort; different providers use different
|
||||
// limits and we prefer to be conservative when we don't know.
|
||||
const key = `${provider.id}:${provider.model}`.toLowerCase();
|
||||
const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key];
|
||||
if (typeof known === "number") {
|
||||
return known;
|
||||
}
|
||||
|
||||
// Provider-specific conservative fallbacks. This prevents us from accidentally
|
||||
// using the OpenAI default for providers with much smaller limits.
|
||||
if (provider.id.toLowerCase() === "gemini") {
|
||||
return 2048;
|
||||
}
|
||||
if (provider.id.toLowerCase() === "local") {
|
||||
return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS;
|
||||
}
|
||||
|
||||
return DEFAULT_EMBEDDING_MAX_INPUT_TOKENS;
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
export function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
|
||||
const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
|
||||
const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
|
||||
if (magnitude < 1e-10) {
|
||||
return sanitized;
|
||||
}
|
||||
return sanitized.map((value) => value / magnitude);
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
import { isTruthyEnvValue } from "../../infra/env.js";
|
||||
import { createSubsystemLogger } from "../../logging/subsystem.js";
|
||||
|
||||
// Gate verbose embedding logs behind an env flag so normal runs stay quiet.
const debugEmbeddings = isTruthyEnvValue(process.env.OPENCLAW_DEBUG_MEMORY_EMBEDDINGS);
const log = createSubsystemLogger("memory/embeddings");
|
||||
|
||||
export function debugEmbeddingsLog(message: string, meta?: Record<string, unknown>): void {
|
||||
if (!debugEmbeddings) {
|
||||
return;
|
||||
}
|
||||
const suffix = meta ? ` ${JSON.stringify(meta)}` : "";
|
||||
log.raw(`${message}${suffix}`);
|
||||
}
|
||||
@@ -1,570 +0,0 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as authModule from "../../agents/model-auth.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
|
||||
vi.mock("../../agents/model-auth.js", async () => {
|
||||
const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js");
|
||||
return createModelAuthMockModule();
|
||||
});
|
||||
|
||||
const createGeminiFetchMock = (embeddingValues = [1, 2, 3]) =>
|
||||
vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ embedding: { values: embeddingValues } }),
|
||||
}));
|
||||
|
||||
const createGeminiBatchFetchMock = (count: number, embeddingValues = [1, 2, 3]) =>
|
||||
vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({
|
||||
embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })),
|
||||
}),
|
||||
}));
|
||||
|
||||
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
|
||||
const [url, init] = fetchMock.mock.calls[0] ?? [];
|
||||
return { url, init: init as RequestInit | undefined };
|
||||
}
|
||||
|
||||
function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) {
|
||||
const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined;
|
||||
return JSON.parse((init?.body as string) ?? "{}") as Record<string, unknown>;
|
||||
}
|
||||
|
||||
function magnitude(values: number[]) {
|
||||
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
|
||||
}
|
||||
|
||||
// The module under test is imported lazily (in beforeAll) so the undici
// unmock and module reset below take effect before it is evaluated.
let buildGeminiEmbeddingRequest: typeof import("./embeddings-gemini.js").buildGeminiEmbeddingRequest;
let buildGeminiTextEmbeddingRequest: typeof import("./embeddings-gemini.js").buildGeminiTextEmbeddingRequest;
let createGeminiEmbeddingProvider: typeof import("./embeddings-gemini.js").createGeminiEmbeddingProvider;
let DEFAULT_GEMINI_EMBEDDING_MODEL: typeof import("./embeddings-gemini.js").DEFAULT_GEMINI_EMBEDDING_MODEL;
let GEMINI_EMBEDDING_2_MODELS: typeof import("./embeddings-gemini.js").GEMINI_EMBEDDING_2_MODELS;
let isGeminiEmbedding2Model: typeof import("./embeddings-gemini.js").isGeminiEmbedding2Model;
let resolveGeminiOutputDimensionality: typeof import("./embeddings-gemini.js").resolveGeminiOutputDimensionality;

beforeAll(async () => {
  // Undo any suite-level undici mock, then force a fresh module graph so the
  // dynamic import below picks up the unmocked dependency.
  vi.doUnmock("undici");
  vi.resetModules();
  ({
    buildGeminiEmbeddingRequest,
    buildGeminiTextEmbeddingRequest,
    createGeminiEmbeddingProvider,
    DEFAULT_GEMINI_EMBEDDING_MODEL,
    GEMINI_EMBEDDING_2_MODELS,
    isGeminiEmbedding2Model,
    resolveGeminiOutputDimensionality,
  } = await import("./embeddings-gemini.js"));
});

beforeEach(() => {
  vi.useRealTimers();
  vi.doUnmock("undici");
});

afterEach(() => {
  // Reset all mock state and stubbed globals (incl. stubbed fetch) per test.
  vi.doUnmock("undici");
  vi.resetAllMocks();
  vi.unstubAllGlobals();
});
|
||||
|
||||
function mockResolvedProviderKey(apiKey = "test-key") {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey,
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
}
|
||||
|
||||
type GeminiFetchMock =
|
||||
| ReturnType<typeof createGeminiFetchMock>
|
||||
| ReturnType<typeof createGeminiBatchFetchMock>;
|
||||
|
||||
async function createProviderWithFetch(
|
||||
fetchMock: GeminiFetchMock,
|
||||
options: Partial<Parameters<typeof createGeminiEmbeddingProvider>[0]> & { model: string },
|
||||
) {
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey();
|
||||
const { provider } = await createGeminiEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
fallback: "none",
|
||||
...options,
|
||||
});
|
||||
return provider;
|
||||
}
|
||||
|
||||
function expectNormalizedThreeFourVector(embedding: number[]) {
|
||||
expect(embedding[0]).toBeCloseTo(0.6, 5);
|
||||
expect(embedding[1]).toBeCloseTo(0.8, 5);
|
||||
expect(magnitude(embedding)).toBeCloseTo(1, 5);
|
||||
}
|
||||
|
||||
// Text-only request builder: model path and dimensions pass through verbatim.
describe("buildGeminiTextEmbeddingRequest", () => {
  it("builds a text embedding request with optional model and dimensions", () => {
    expect(
      buildGeminiTextEmbeddingRequest({
        text: "hello",
        taskType: "RETRIEVAL_DOCUMENT",
        modelPath: "models/gemini-embedding-2-preview",
        outputDimensionality: 1536,
      }),
    ).toEqual({
      model: "models/gemini-embedding-2-preview",
      content: { parts: [{ text: "hello" }] },
      taskType: "RETRIEVAL_DOCUMENT",
      outputDimensionality: 1536,
    });
  });
});
|
||||
|
||||
// Structured (multimodal) request builder: "inline-data" parts become
// Gemini `inlineData` parts; "text" parts map straight through.
describe("buildGeminiEmbeddingRequest", () => {
  it("builds a multimodal request from structured input parts", () => {
    expect(
      buildGeminiEmbeddingRequest({
        input: {
          text: "Image file: diagram.png",
          parts: [
            { type: "text", text: "Image file: diagram.png" },
            { type: "inline-data", mimeType: "image/png", data: "abc123" },
          ],
        },
        taskType: "RETRIEVAL_DOCUMENT",
        modelPath: "models/gemini-embedding-2-preview",
        outputDimensionality: 1536,
      }),
    ).toEqual({
      model: "models/gemini-embedding-2-preview",
      content: {
        parts: [
          { text: "Image file: diagram.png" },
          { inlineData: { mimeType: "image/png", data: "abc123" } },
        ],
      },
      taskType: "RETRIEVAL_DOCUMENT",
      outputDimensionality: 1536,
    });
  });
});
|
||||
|
||||
// ---------- Model detection ----------
|
||||
|
||||
// Only v2-preview models are detected; legacy embedding models are not.
describe("isGeminiEmbedding2Model", () => {
  it("returns true for gemini-embedding-2-preview", () => {
    expect(isGeminiEmbedding2Model("gemini-embedding-2-preview")).toBe(true);
  });

  it("returns false for gemini-embedding-001", () => {
    expect(isGeminiEmbedding2Model("gemini-embedding-001")).toBe(false);
  });

  it("returns false for text-embedding-004", () => {
    expect(isGeminiEmbedding2Model("text-embedding-004")).toBe(false);
  });
});
|
||||
|
||||
// Sanity check on the exported model set itself.
describe("GEMINI_EMBEDDING_2_MODELS", () => {
  it("contains gemini-embedding-2-preview", () => {
    expect(GEMINI_EMBEDDING_2_MODELS.has("gemini-embedding-2-preview")).toBe(true);
  });
});
|
||||
|
||||
// ---------- Dimension resolution ----------
|
||||
|
||||
// Dimension resolution: undefined for legacy models, 3072 default for v2,
// only 768/1536/3072 accepted explicitly.
describe("resolveGeminiOutputDimensionality", () => {
  it("returns undefined for non-v2 models", () => {
    expect(resolveGeminiOutputDimensionality("gemini-embedding-001")).toBeUndefined();
    expect(resolveGeminiOutputDimensionality("text-embedding-004")).toBeUndefined();
  });

  it("returns 3072 by default for v2 models", () => {
    expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview")).toBe(3072);
  });

  it("accepts valid dimension values", () => {
    expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768)).toBe(768);
    expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 1536)).toBe(1536);
    expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 3072)).toBe(3072);
  });

  it("throws for invalid dimension values", () => {
    expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512)).toThrow(
      /Invalid outputDimensionality 512/,
    );
    expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 1024)).toThrow(
      /Valid values: 768, 1536, 3072/,
    );
  });
});
|
||||
|
||||
// ---------- Provider: gemini-embedding-001 (backward compat) ----------
|
||||
|
||||
// Backward compatibility: pre-v2 models must keep the legacy request shape —
// no outputDimensionality field may appear in any request body.
describe("gemini-embedding-001 provider (backward compat)", () => {
  it("does NOT include outputDimensionality in embedQuery", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-001",
    });

    await provider.embedQuery("test query");

    const body = parseFetchBody(fetchMock);
    expect(body).not.toHaveProperty("outputDimensionality");
    expect(body.taskType).toBe("RETRIEVAL_QUERY");
    expect(body.content).toEqual({ parts: [{ text: "test query" }] });
  });

  it("does NOT include outputDimensionality in embedBatch", async () => {
    const fetchMock = createGeminiBatchFetchMock(2);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-001",
    });

    await provider.embedBatch(["text1", "text2"]);

    const body = parseFetchBody(fetchMock);
    expect(body).not.toHaveProperty("outputDimensionality");
  });
});
|
||||
|
||||
// ---------- Provider: gemini-embedding-2-preview ----------
|
||||
|
||||
// v2-preview provider: outputDimensionality is always sent (default 3072),
// responses are sanitized (non-finite -> 0) and L2-normalized, and
// multimodal batch inputs are supported via embedBatchInputs.
describe("gemini-embedding-2-preview provider", () => {
  it("includes outputDimensionality in embedQuery request", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    await provider.embedQuery("test query");

    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(3072);
    expect(body.taskType).toBe("RETRIEVAL_QUERY");
    expect(body.content).toEqual({ parts: [{ text: "test query" }] });
  });

  it("normalizes embedQuery response vectors", async () => {
    const fetchMock = createGeminiFetchMock([3, 4]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    const embedding = await provider.embedQuery("test query");

    expectNormalizedThreeFourVector(embedding);
  });

  it("includes outputDimensionality in embedBatch request", async () => {
    const fetchMock = createGeminiBatchFetchMock(2);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    await provider.embedBatch(["text1", "text2"]);

    // Batch requests carry the full model path per entry.
    const body = parseFetchBody(fetchMock);
    expect(body.requests).toEqual([
      {
        model: "models/gemini-embedding-2-preview",
        content: { parts: [{ text: "text1" }] },
        taskType: "RETRIEVAL_DOCUMENT",
        outputDimensionality: 3072,
      },
      {
        model: "models/gemini-embedding-2-preview",
        content: { parts: [{ text: "text2" }] },
        taskType: "RETRIEVAL_DOCUMENT",
        outputDimensionality: 3072,
      },
    ]);
  });

  it("normalizes embedBatch response vectors", async () => {
    const fetchMock = createGeminiBatchFetchMock(2, [3, 4]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    const embeddings = await provider.embedBatch(["text1", "text2"]);

    expect(embeddings).toHaveLength(2);
    for (const embedding of embeddings) {
      expectNormalizedThreeFourVector(embedding);
    }
  });

  it("respects custom outputDimensionality", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
      outputDimensionality: 768,
    });

    await provider.embedQuery("test");

    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(768);
  });

  it("sanitizes and normalizes embedQuery responses", async () => {
    // NaN is zeroed before normalization, so [3, 4, NaN] -> [0.6, 0.8, 0].
    const fetchMock = createGeminiFetchMock([3, 4, Number.NaN]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    await expect(provider.embedQuery("test")).resolves.toEqual([0.6, 0.8, 0]);
  });

  it("uses custom outputDimensionality for each embedBatch request", async () => {
    const fetchMock = createGeminiBatchFetchMock(2);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
      outputDimensionality: 768,
    });

    await provider.embedBatch(["text1", "text2"]);

    const body = parseFetchBody(fetchMock);
    expect(body.requests).toEqual([
      expect.objectContaining({ outputDimensionality: 768 }),
      expect.objectContaining({ outputDimensionality: 768 }),
    ]);
  });

  it("sanitizes and normalizes structured batch responses", async () => {
    // Infinity is zeroed, then [0, 0, 5] normalizes to [0, 0, 1].
    const fetchMock = createGeminiBatchFetchMock(1, [0, Number.POSITIVE_INFINITY, 5]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    await expect(
      provider.embedBatchInputs?.([
        {
          text: "Image file: diagram.png",
          parts: [
            { type: "text", text: "Image file: diagram.png" },
            { type: "inline-data", mimeType: "image/png", data: "img" },
          ],
        },
      ]),
    ).resolves.toEqual([[0, 0, 1]]);
  });

  it("supports multimodal embedBatchInputs requests", async () => {
    const fetchMock = createGeminiBatchFetchMock(2);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    expect(provider.embedBatchInputs).toBeDefined();
    await provider.embedBatchInputs?.([
      {
        text: "Image file: diagram.png",
        parts: [
          { type: "text", text: "Image file: diagram.png" },
          { type: "inline-data", mimeType: "image/png", data: "img" },
        ],
      },
      {
        text: "Audio file: note.wav",
        parts: [
          { type: "text", text: "Audio file: note.wav" },
          { type: "inline-data", mimeType: "audio/wav", data: "aud" },
        ],
      },
    ]);

    const body = parseFetchBody(fetchMock);
    expect(body.requests).toEqual([
      {
        model: "models/gemini-embedding-2-preview",
        content: {
          parts: [
            { text: "Image file: diagram.png" },
            { inlineData: { mimeType: "image/png", data: "img" } },
          ],
        },
        taskType: "RETRIEVAL_DOCUMENT",
        outputDimensionality: 3072,
      },
      {
        model: "models/gemini-embedding-2-preview",
        content: {
          parts: [
            { text: "Audio file: note.wav" },
            { inlineData: { mimeType: "audio/wav", data: "aud" } },
          ],
        },
        taskType: "RETRIEVAL_DOCUMENT",
        outputDimensionality: 3072,
      },
    ]);
  });

  it("throws for invalid outputDimensionality", async () => {
    mockResolvedProviderKey();

    await expect(
      createGeminiEmbeddingProvider({
        config: {} as never,
        provider: "gemini",
        model: "gemini-embedding-2-preview",
        fallback: "none",
        outputDimensionality: 512,
      }),
    ).rejects.toThrow(/Invalid outputDimensionality 512/);
  });

  it("sanitizes non-finite values before normalization", async () => {
    const fetchMock = createGeminiFetchMock([
      1,
      Number.NaN,
      Number.POSITIVE_INFINITY,
      Number.NEGATIVE_INFINITY,
    ]);
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    const embedding = await provider.embedQuery("test");

    expect(embedding).toEqual([1, 0, 0, 0]);
  });

  it("uses correct endpoint URL", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
    });

    await provider.embedQuery("test");

    const { url } = readFirstFetchRequest(fetchMock);
    expect(url).toBe(
      "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:embedContent",
    );
  });

  it("allows taskType override via options", async () => {
    const fetchMock = createGeminiFetchMock();
    const provider = await createProviderWithFetch(fetchMock, {
      model: "gemini-embedding-2-preview",
      taskType: "SEMANTIC_SIMILARITY",
    });

    await provider.embedQuery("test");

    const body = parseFetchBody(fetchMock);
    expect(body.taskType).toBe("SEMANTIC_SIMILARITY");
  });
});
|
||||
|
||||
// ---------- Model normalization ----------
|
||||
|
||||
// Model-name normalization at the provider boundary: "models/", "gemini/"
// and "google/" prefixes all resolve to the same v2 model; blank names fall
// back to the default model; blank/empty inputs short-circuit to [].
describe("gemini model normalization", () => {
  it("handles models/ prefix for v2 model", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey();

    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "models/gemini-embedding-2-preview",
      fallback: "none",
    });

    await provider.embedQuery("test");

    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(3072);
  });

  it("handles gemini/ prefix for v2 model", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey();

    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "gemini/gemini-embedding-2-preview",
      fallback: "none",
    });

    await provider.embedQuery("test");

    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(3072);
  });

  it("handles google/ prefix for v2 model", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockPublicPinnedHostname();
    mockResolvedProviderKey();

    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "google/gemini-embedding-2-preview",
      fallback: "none",
    });

    await provider.embedQuery("test");

    const body = parseFetchBody(fetchMock);
    expect(body.outputDimensionality).toBe(3072);
  });

  it("defaults to gemini-embedding-001 when model is empty", async () => {
    const fetchMock = createGeminiFetchMock();
    vi.stubGlobal("fetch", fetchMock);
    mockResolvedProviderKey();

    const { provider, client } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "",
      fallback: "none",
    });

    expect(client.model).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL);
    expect(provider.model).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL);
  });

  it("returns empty array for blank query text", async () => {
    mockResolvedProviderKey();

    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "gemini-embedding-2-preview",
      fallback: "none",
    });

    // No fetch stub needed: blank input never reaches the network.
    const result = await provider.embedQuery(" ");
    expect(result).toEqual([]);
  });

  it("returns empty array for empty batch", async () => {
    mockResolvedProviderKey();

    const { provider } = await createGeminiEmbeddingProvider({
      config: {} as never,
      provider: "gemini",
      model: "gemini-embedding-2-preview",
      fallback: "none",
    });

    const result = await provider.embedBatch([]);
    expect(result).toEqual([]);
  });
});
|
||||
@@ -1,336 +0,0 @@
|
||||
import {
|
||||
collectProviderApiKeysForExecution,
|
||||
executeWithApiKeyRotation,
|
||||
} from "../../agents/api-key-rotation.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js";
|
||||
import { parseGeminiAuth } from "../../infra/gemini-auth.js";
|
||||
import {
|
||||
DEFAULT_GOOGLE_API_BASE_URL,
|
||||
normalizeGoogleApiBaseUrl,
|
||||
} from "../../infra/google-api-base-url.js";
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { debugEmbeddingsLog } from "./embeddings-debug.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
|
||||
/** Resolved connection + model settings for calling the Gemini embeddings API. */
export type GeminiEmbeddingClient = {
  baseUrl: string;
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  model: string;
  // Full API resource path, e.g. "models/gemini-embedding-001".
  modelPath: string;
  apiKeys: string[];
  // Only set for gemini-embedding-2 models; undefined for legacy models.
  outputDimensionality?: number;
};

export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
// Known per-model input-token limits; models absent here have no recorded limit.
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
  "text-embedding-004": 2048,
};

// --- gemini-embedding-2-preview support ---

// Models that accept `outputDimensionality` and the extended task types.
export const GEMINI_EMBEDDING_2_MODELS = new Set([
  "gemini-embedding-2-preview",
  // Add the GA model name here once released.
]);

const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;

/** Task types accepted by the Gemini embeddings API. */
export type GeminiTaskType =
  | "RETRIEVAL_QUERY"
  | "RETRIEVAL_DOCUMENT"
  | "SEMANTIC_SIMILARITY"
  | "CLASSIFICATION"
  | "CLUSTERING"
  | "QUESTION_ANSWERING"
  | "FACT_VERIFICATION";

export type GeminiTextPart = { text: string };
export type GeminiInlinePart = {
  inlineData: { mimeType: string; data: string };
};
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
/** Wire shape of a single Gemini `embedContent` request. */
export type GeminiEmbeddingRequest = {
  content: { parts: GeminiPart[] };
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  // Only required in batch requests, where each entry names its model.
  model?: string;
};
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
|
||||
|
||||
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
|
||||
export function buildGeminiTextEmbeddingRequest(params: {
|
||||
text: string;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiTextEmbeddingRequest {
|
||||
return buildGeminiEmbeddingRequest({
|
||||
input: { text: params.text },
|
||||
taskType: params.taskType,
|
||||
outputDimensionality: params.outputDimensionality,
|
||||
modelPath: params.modelPath,
|
||||
});
|
||||
}
|
||||
|
||||
export function buildGeminiEmbeddingRequest(params: {
|
||||
input: EmbeddingInput;
|
||||
taskType: GeminiTaskType;
|
||||
outputDimensionality?: number;
|
||||
modelPath?: string;
|
||||
}): GeminiEmbeddingRequest {
|
||||
const request: GeminiEmbeddingRequest = {
|
||||
content: {
|
||||
parts: params.input.parts?.map((part) =>
|
||||
part.type === "text"
|
||||
? ({ text: part.text } satisfies GeminiTextPart)
|
||||
: ({
|
||||
inlineData: { mimeType: part.mimeType, data: part.data },
|
||||
} satisfies GeminiInlinePart),
|
||||
) ?? [{ text: params.input.text }],
|
||||
},
|
||||
taskType: params.taskType,
|
||||
};
|
||||
if (params.modelPath) {
|
||||
request.model = params.modelPath;
|
||||
}
|
||||
if (params.outputDimensionality != null) {
|
||||
request.outputDimensionality = params.outputDimensionality;
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
/**
 * Returns true if the given model name is a gemini-embedding-2 variant that
 * supports `outputDimensionality` and extended task types.
 *
 * @param model - A normalized model name (no "models/"/"gemini/"/"google/" prefix).
 */
export function isGeminiEmbedding2Model(model: string): boolean {
  return GEMINI_EMBEDDING_2_MODELS.has(model);
}
|
||||
|
||||
/**
|
||||
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
|
||||
* Returns `undefined` for older models (they don't support the param).
|
||||
*/
|
||||
export function resolveGeminiOutputDimensionality(
|
||||
model: string,
|
||||
requested?: number,
|
||||
): number | undefined {
|
||||
if (!isGeminiEmbedding2Model(model)) {
|
||||
return undefined;
|
||||
}
|
||||
if (requested == null) {
|
||||
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
|
||||
}
|
||||
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
|
||||
if (!valid.includes(requested)) {
|
||||
throw new Error(
|
||||
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
|
||||
);
|
||||
}
|
||||
return requested;
|
||||
}
|
||||
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
|
||||
const trimmed = resolveMemorySecretInputString({
|
||||
value: remoteApiKey,
|
||||
path: "agents.*.memorySearch.remote.apiKey",
|
||||
});
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") {
|
||||
return process.env[trimmed]?.trim();
|
||||
}
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
export function normalizeGeminiModel(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) {
|
||||
return DEFAULT_GEMINI_EMBEDDING_MODEL;
|
||||
}
|
||||
const withoutPrefix = trimmed.replace(/^models\//, "");
|
||||
if (withoutPrefix.startsWith("gemini/")) {
|
||||
return withoutPrefix.slice("gemini/".length);
|
||||
}
|
||||
if (withoutPrefix.startsWith("google/")) {
|
||||
return withoutPrefix.slice("google/".length);
|
||||
}
|
||||
return withoutPrefix;
|
||||
}
|
||||
|
||||
/**
 * POST an embedding request to the Gemini API, rotating through the client's
 * API keys on failure and enforcing the client's SSRF policy on the request.
 *
 * @returns The raw response payload: `embedding` for single requests,
 *   `embeddings` for batch requests.
 * @throws Error with the HTTP status and response text on a non-OK response.
 */
async function fetchGeminiEmbeddingPayload(params: {
  client: GeminiEmbeddingClient;
  endpoint: string;
  body: unknown;
}): Promise<{
  embedding?: { values?: number[] };
  embeddings?: Array<{ values?: number[] }>;
}> {
  return await executeWithApiKeyRotation({
    provider: "google",
    apiKeys: params.client.apiKeys,
    execute: async (apiKey) => {
      const authHeaders = parseGeminiAuth(apiKey);
      // Client headers take precedence over derived auth headers.
      const headers = {
        ...authHeaders.headers,
        ...params.client.headers,
      };
      return await withRemoteHttpResponse({
        url: params.endpoint,
        ssrfPolicy: params.client.ssrfPolicy,
        init: {
          method: "POST",
          headers,
          body: JSON.stringify(params.body),
        },
        onResponse: async (res) => {
          if (!res.ok) {
            const text = await res.text();
            throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
          }
          return (await res.json()) as {
            embedding?: { values?: number[] };
            embeddings?: Array<{ values?: number[] }>;
          };
        },
      });
    },
  });
}
|
||||
|
||||
function normalizeGeminiBaseUrl(raw: string): string {
|
||||
const trimmed = raw.replace(/\/+$/, "");
|
||||
const openAiIndex = trimmed.indexOf("/openai");
|
||||
if (openAiIndex > -1) {
|
||||
return normalizeGoogleApiBaseUrl(trimmed.slice(0, openAiIndex));
|
||||
}
|
||||
return normalizeGoogleApiBaseUrl(trimmed);
|
||||
}
|
||||
|
||||
function buildGeminiModelPath(model: string): string {
|
||||
return model.startsWith("models/") ? model : `models/${model}`;
|
||||
}
|
||||
|
||||
/**
 * Create an embedding provider backed by the Gemini embeddings API.
 *
 * Resolves auth/base-URL/model settings into a client, then exposes
 * `embedQuery` (single text), `embedBatch` (texts), and `embedBatchInputs`
 * (structured/multimodal inputs). All result vectors are sanitized and
 * L2-normalized. `outputDimensionality` is only sent for v2 models.
 *
 * @returns Both the provider and the resolved client (useful for tests).
 */
export async function createGeminiEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
  const client = await resolveGeminiEmbeddingClient(options);
  const baseUrl = client.baseUrl.replace(/\/$/, "");
  const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
  const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
  const isV2 = isGeminiEmbedding2Model(client.model);
  const outputDimensionality = client.outputDimensionality;

  // Single-text embedding; blank input short-circuits without a network call.
  const embedQuery = async (text: string): Promise<number[]> => {
    if (!text.trim()) {
      return [];
    }
    const payload = await fetchGeminiEmbeddingPayload({
      client,
      endpoint: embedUrl,
      body: buildGeminiTextEmbeddingRequest({
        text,
        taskType: options.taskType ?? "RETRIEVAL_QUERY",
        // Legacy models reject the parameter, so omit it unless v2.
        outputDimensionality: isV2 ? outputDimensionality : undefined,
      }),
    });
    return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
  };

  // Structured (possibly multimodal) batch embedding via batchEmbedContents.
  const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
    if (inputs.length === 0) {
      return [];
    }
    const payload = await fetchGeminiEmbeddingPayload({
      client,
      endpoint: batchUrl,
      body: {
        requests: inputs.map((input) =>
          buildGeminiEmbeddingRequest({
            input,
            // Batch entries each carry the full model resource path.
            modelPath: client.modelPath,
            taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
            outputDimensionality: isV2 ? outputDimensionality : undefined,
          }),
        ),
      },
    });
    const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
    // Map by index so a short/missing response yields empty (-> []) vectors
    // rather than throwing.
    return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
  };

  // Text-only batch is the structured batch with bare-text inputs.
  const embedBatch = async (texts: string[]): Promise<number[][]> => {
    return await embedBatchInputs(
      texts.map((text) => ({
        text,
      })),
    );
  };

  return {
    provider: {
      id: "gemini",
      model: client.model,
      maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
      embedQuery,
      embedBatch,
      embedBatchInputs,
    },
    client,
  };
}
|
||||
|
||||
/**
 * Resolve all settings needed to call the Gemini embeddings API: API key(s)
 * (remote override first, otherwise provider auth), base URL (remote override
 * -> provider config -> default), headers, SSRF policy, normalized model
 * name/path, and the validated output dimensionality.
 *
 * @throws Via `requireApiKey` when no API key can be resolved, and via
 *   `resolveGeminiOutputDimensionality` for invalid dimension requests.
 */
export async function resolveGeminiEmbeddingClient(
  options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
  const remote = options.remote;
  const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
  const remoteBaseUrl = remote?.baseUrl?.trim();

  // A remote-config key wins; otherwise fall back to standard provider auth.
  const apiKey = remoteApiKey
    ? remoteApiKey
    : requireApiKey(
        await resolveApiKeyForProvider({
          provider: "google",
          cfg: options.config,
          agentDir: options.agentDir,
        }),
        "google",
      );

  const providerConfig = options.config.models?.providers?.google;
  const rawBaseUrl =
    remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GOOGLE_API_BASE_URL;
  const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
  const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
  // remote headers override provider-config headers on key collision.
  const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers)
  const headers: Record<string, string> = {
    ...headerOverrides,
  };
  const apiKeys = collectProviderApiKeysForExecution({
    provider: "google",
    primaryApiKey: apiKey,
  });
  const model = normalizeGeminiModel(options.model);
  const modelPath = buildGeminiModelPath(model);
  // Validates the requested dimensionality (throws on invalid v2 values).
  const outputDimensionality = resolveGeminiOutputDimensionality(
    model,
    options.outputDimensionality,
  );
  debugEmbeddingsLog("memory embeddings: gemini client", {
    rawBaseUrl,
    baseUrl,
    model,
    modelPath,
    outputDimensionality,
    embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
    batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
  });
  return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
}
|
||||
@@ -1,19 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js";
|
||||
|
||||
// Mistral model-name normalization: blank -> default, "mistral/" prefix
// stripped (with trimming), anything else passed through unchanged.
describe("normalizeMistralModel", () => {
  it("returns the default model for empty values", () => {
    expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
    expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
  });

  it("strips the mistral/ prefix", () => {
    expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed");
    expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed");
  });

  it("keeps explicit non-prefixed models", () => {
    expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed");
    expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2");
  });
});
|
||||
@@ -1,51 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import {
|
||||
createRemoteEmbeddingProvider,
|
||||
resolveRemoteEmbeddingClient,
|
||||
} from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type MistralEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed";
|
||||
const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1";
|
||||
|
||||
export function normalizeMistralModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
prefixes: ["mistral/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createMistralEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
|
||||
const client = await resolveMistralEmbeddingClient(options);
|
||||
|
||||
return {
|
||||
provider: createRemoteEmbeddingProvider({
|
||||
id: "mistral",
|
||||
client,
|
||||
errorPrefix: "mistral embeddings failed",
|
||||
}),
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveMistralEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<MistralEmbeddingClient> {
|
||||
return await resolveRemoteEmbeddingClient({
|
||||
provider: "mistral",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL,
|
||||
normalizeModel: normalizeMistralModel,
|
||||
});
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
|
||||
describe("normalizeEmbeddingModelWithPrefixes", () => {
|
||||
it("returns default model when input is blank", () => {
|
||||
expect(
|
||||
normalizeEmbeddingModelWithPrefixes({
|
||||
model: " ",
|
||||
defaultModel: "fallback-model",
|
||||
prefixes: ["openai/"],
|
||||
}),
|
||||
).toBe("fallback-model");
|
||||
});
|
||||
|
||||
it("strips the first matching prefix", () => {
|
||||
expect(
|
||||
normalizeEmbeddingModelWithPrefixes({
|
||||
model: "openai/text-embedding-3-small",
|
||||
defaultModel: "fallback-model",
|
||||
prefixes: ["openai/"],
|
||||
}),
|
||||
).toBe("text-embedding-3-small");
|
||||
});
|
||||
|
||||
it("keeps explicit model names when no prefix matches", () => {
|
||||
expect(
|
||||
normalizeEmbeddingModelWithPrefixes({
|
||||
model: "voyage-4-large",
|
||||
defaultModel: "fallback-model",
|
||||
prefixes: ["voyage/"],
|
||||
}),
|
||||
).toBe("voyage-4-large");
|
||||
});
|
||||
});
|
||||
@@ -1,16 +0,0 @@
|
||||
export function normalizeEmbeddingModelWithPrefixes(params: {
|
||||
model: string;
|
||||
defaultModel: string;
|
||||
prefixes: string[];
|
||||
}): string {
|
||||
const trimmed = params.model.trim();
|
||||
if (!trimmed) {
|
||||
return params.defaultModel;
|
||||
}
|
||||
for (const prefix of params.prefixes) {
|
||||
if (trimmed.startsWith(prefix)) {
|
||||
return trimmed.slice(prefix.length);
|
||||
}
|
||||
}
|
||||
return trimmed;
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, it, expect, vi } from "vitest";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
|
||||
let createOllamaEmbeddingProvider: typeof import("./embeddings-ollama.js").createOllamaEmbeddingProvider;
|
||||
|
||||
beforeAll(async () => {
|
||||
({ createOllamaEmbeddingProvider } = await import("./embeddings-ollama.js"));
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
vi.unstubAllGlobals();
|
||||
vi.unstubAllEnvs();
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe("embeddings-ollama", () => {
|
||||
it("calls /api/embeddings and returns normalized vectors", async () => {
|
||||
const fetchMock = vi.fn(
|
||||
async () =>
|
||||
new Response(JSON.stringify({ embedding: [3, 4] }), {
|
||||
status: 200,
|
||||
headers: { "content-type": "application/json" },
|
||||
}),
|
||||
);
|
||||
globalThis.fetch = fetchMock as unknown as typeof fetch;
|
||||
|
||||
const { provider } = await createOllamaEmbeddingProvider({
|
||||
config: {} as OpenClawConfig,
|
||||
provider: "ollama",
|
||||
model: "nomic-embed-text",
|
||||
fallback: "none",
|
||||
remote: { baseUrl: "http://127.0.0.1:11434" },
|
||||
});
|
||||
|
||||
const v = await provider.embedQuery("hi");
|
||||
expect(fetchMock).toHaveBeenCalledTimes(1);
|
||||
// normalized [3,4] => [0.6,0.8]
|
||||
expect(v[0]).toBeCloseTo(0.6, 5);
|
||||
expect(v[1]).toBeCloseTo(0.8, 5);
|
||||
});
|
||||
|
||||
it("resolves baseUrl/apiKey/headers from models.providers.ollama and strips /v1", async () => {
|
||||
const fetchMock = vi.fn(
|
||||
async () =>
|
||||
new Response(JSON.stringify({ embedding: [1, 0] }), {
|
||||
status: 200,
|
||||
headers: { "content-type": "application/json" },
|
||||
}),
|
||||
);
|
||||
globalThis.fetch = fetchMock as unknown as typeof fetch;
|
||||
|
||||
const { provider } = await createOllamaEmbeddingProvider({
|
||||
config: {
|
||||
models: {
|
||||
providers: {
|
||||
ollama: {
|
||||
baseUrl: "http://127.0.0.1:11434/v1",
|
||||
apiKey: "ollama-\nlocal\r\n", // pragma: allowlist secret
|
||||
headers: {
|
||||
"X-Provider-Header": "provider",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as OpenClawConfig,
|
||||
provider: "ollama",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
expect(fetchMock).toHaveBeenCalledWith(
|
||||
"http://127.0.0.1:11434/api/embeddings",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
headers: expect.objectContaining({
|
||||
"Content-Type": "application/json",
|
||||
Authorization: "Bearer ollama-local",
|
||||
"X-Provider-Header": "provider",
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("fails fast when memory-search remote apiKey is an unresolved SecretRef", async () => {
|
||||
await expect(
|
||||
createOllamaEmbeddingProvider({
|
||||
config: {} as OpenClawConfig,
|
||||
provider: "ollama",
|
||||
model: "nomic-embed-text",
|
||||
fallback: "none",
|
||||
remote: {
|
||||
baseUrl: "http://127.0.0.1:11434",
|
||||
apiKey: { source: "env", provider: "default", id: "OLLAMA_API_KEY" },
|
||||
},
|
||||
}),
|
||||
).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey: unresolved SecretRef/i);
|
||||
});
|
||||
|
||||
it("falls back to env key when models.providers.ollama.apiKey is an unresolved SecretRef", async () => {
|
||||
const fetchMock = vi.fn(
|
||||
async () =>
|
||||
new Response(JSON.stringify({ embedding: [1, 0] }), {
|
||||
status: 200,
|
||||
headers: { "content-type": "application/json" },
|
||||
}),
|
||||
);
|
||||
globalThis.fetch = fetchMock as unknown as typeof fetch;
|
||||
vi.stubEnv("OLLAMA_API_KEY", "ollama-env");
|
||||
|
||||
const { provider } = await createOllamaEmbeddingProvider({
|
||||
config: {
|
||||
models: {
|
||||
providers: {
|
||||
ollama: {
|
||||
baseUrl: "http://127.0.0.1:11434/v1",
|
||||
apiKey: { source: "env", provider: "default", id: "OLLAMA_API_KEY" },
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as OpenClawConfig,
|
||||
provider: "ollama",
|
||||
model: "nomic-embed-text",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
expect(fetchMock).toHaveBeenCalledWith(
|
||||
"http://127.0.0.1:11434/api/embeddings",
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer ollama-env",
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -1,123 +0,0 @@
|
||||
import { resolveEnvApiKey } from "../../agents/model-auth.js";
|
||||
import { resolveOllamaApiBase } from "../../agents/ollama-models.js";
|
||||
import { formatErrorMessage } from "../../infra/errors.js";
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { normalizeOptionalSecretInput } from "../../utils/normalize-secret-input.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
|
||||
export type OllamaEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
};
|
||||
type OllamaEmbeddingClientConfig = Omit<OllamaEmbeddingClient, "embedBatch">;
|
||||
|
||||
export const DEFAULT_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text";
|
||||
|
||||
function normalizeOllamaModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
prefixes: ["ollama/"],
|
||||
});
|
||||
}
|
||||
|
||||
function resolveOllamaApiKey(options: EmbeddingProviderOptions): string | undefined {
|
||||
const remoteApiKey = resolveMemorySecretInputString({
|
||||
value: options.remote?.apiKey,
|
||||
path: "agents.*.memorySearch.remote.apiKey",
|
||||
});
|
||||
if (remoteApiKey) {
|
||||
return remoteApiKey;
|
||||
}
|
||||
const providerApiKey = normalizeOptionalSecretInput(
|
||||
options.config.models?.providers?.ollama?.apiKey,
|
||||
);
|
||||
if (providerApiKey) {
|
||||
return providerApiKey;
|
||||
}
|
||||
return resolveEnvApiKey("ollama")?.apiKey;
|
||||
}
|
||||
|
||||
function resolveOllamaEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): OllamaEmbeddingClientConfig {
|
||||
const providerConfig = options.config.models?.providers?.ollama;
|
||||
const rawBaseUrl = options.remote?.baseUrl?.trim() || providerConfig?.baseUrl?.trim();
|
||||
const baseUrl = resolveOllamaApiBase(rawBaseUrl);
|
||||
const model = normalizeOllamaModel(options.model);
|
||||
const headerOverrides = Object.assign({}, providerConfig?.headers, options.remote?.headers);
|
||||
const headers: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
...headerOverrides,
|
||||
};
|
||||
const apiKey = resolveOllamaApiKey(options);
|
||||
if (apiKey) {
|
||||
headers.Authorization = `Bearer ${apiKey}`;
|
||||
}
|
||||
return {
|
||||
baseUrl,
|
||||
headers,
|
||||
ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl),
|
||||
model,
|
||||
};
|
||||
}
|
||||
|
||||
export async function createOllamaEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: OllamaEmbeddingClient }> {
|
||||
const client = resolveOllamaEmbeddingClient(options);
|
||||
const embedUrl = `${client.baseUrl.replace(/\/$/, "")}/api/embeddings`;
|
||||
|
||||
const embedOne = async (text: string): Promise<number[]> => {
|
||||
const json = await withRemoteHttpResponse({
|
||||
url: embedUrl,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: client.headers,
|
||||
body: JSON.stringify({ model: client.model, prompt: text }),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
throw new Error(`Ollama embeddings HTTP ${res.status}: ${await res.text()}`);
|
||||
}
|
||||
return (await res.json()) as { embedding?: number[] };
|
||||
},
|
||||
});
|
||||
if (!Array.isArray(json.embedding)) {
|
||||
throw new Error(`Ollama embeddings response missing embedding[]`);
|
||||
}
|
||||
return sanitizeAndNormalizeEmbedding(json.embedding);
|
||||
};
|
||||
|
||||
const provider: EmbeddingProvider = {
|
||||
id: "ollama",
|
||||
model: client.model,
|
||||
embedQuery: embedOne,
|
||||
embedBatch: async (texts: string[]) => {
|
||||
// Ollama /api/embeddings accepts one prompt per request.
|
||||
return await Promise.all(texts.map(embedOne));
|
||||
},
|
||||
};
|
||||
|
||||
return {
|
||||
provider,
|
||||
client: {
|
||||
...client,
|
||||
embedBatch: async (texts) => {
|
||||
try {
|
||||
return await provider.embedBatch(texts);
|
||||
} catch (err) {
|
||||
throw new Error(formatErrorMessage(err), { cause: err });
|
||||
}
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../plugins/provider-model-defaults.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import {
|
||||
createRemoteEmbeddingProvider,
|
||||
resolveRemoteEmbeddingClient,
|
||||
} from "./embeddings-remote-provider.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type OpenAiEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
|
||||
export const DEFAULT_OPENAI_EMBEDDING_MODEL = OPENAI_DEFAULT_EMBEDDING_MODEL;
|
||||
const OPENAI_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"text-embedding-3-small": 8192,
|
||||
"text-embedding-3-large": 8192,
|
||||
"text-embedding-ada-002": 8191,
|
||||
};
|
||||
|
||||
export function normalizeOpenAiModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
prefixes: ["openai/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createOpenAiEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
|
||||
const client = await resolveOpenAiEmbeddingClient(options);
|
||||
|
||||
return {
|
||||
provider: createRemoteEmbeddingProvider({
|
||||
id: "openai",
|
||||
client,
|
||||
errorPrefix: "openai embeddings failed",
|
||||
maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model],
|
||||
}),
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveOpenAiEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<OpenAiEmbeddingClient> {
|
||||
return await resolveRemoteEmbeddingClient({
|
||||
provider: "openai",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
|
||||
normalizeModel: normalizeOpenAiModel,
|
||||
});
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js";
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import type { EmbeddingProviderOptions } from "./embeddings.js";
|
||||
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
|
||||
import { resolveMemorySecretInputString } from "./secret-input.js";
|
||||
|
||||
export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral";
|
||||
|
||||
export async function resolveRemoteEmbeddingBearerClient(params: {
|
||||
provider: RemoteEmbeddingProviderId;
|
||||
options: EmbeddingProviderOptions;
|
||||
defaultBaseUrl: string;
|
||||
}): Promise<{ baseUrl: string; headers: Record<string, string>; ssrfPolicy?: SsrFPolicy }> {
|
||||
const remote = params.options.remote;
|
||||
const remoteApiKey = resolveMemorySecretInputString({
|
||||
value: remote?.apiKey,
|
||||
path: "agents.*.memorySearch.remote.apiKey",
|
||||
});
|
||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||
const providerConfig = params.options.config.models?.providers?.[params.provider];
|
||||
const apiKey = remoteApiKey
|
||||
? remoteApiKey
|
||||
: requireApiKey(
|
||||
await resolveApiKeyForProvider({
|
||||
provider: params.provider,
|
||||
cfg: params.options.config,
|
||||
agentDir: params.options.agentDir,
|
||||
}),
|
||||
params.provider,
|
||||
);
|
||||
const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || params.defaultBaseUrl;
|
||||
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
|
||||
const headers: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...headerOverrides,
|
||||
};
|
||||
return { baseUrl, headers, ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl) };
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const postJsonMock = vi.hoisted(() => vi.fn());
|
||||
|
||||
type EmbeddingsRemoteFetchModule = typeof import("./embeddings-remote-fetch.js");
|
||||
|
||||
let fetchRemoteEmbeddingVectors: EmbeddingsRemoteFetchModule["fetchRemoteEmbeddingVectors"];
|
||||
|
||||
describe("fetchRemoteEmbeddingVectors", () => {
|
||||
beforeEach(async () => {
|
||||
vi.resetModules();
|
||||
vi.doMock("./post-json.js", () => ({
|
||||
postJson: postJsonMock,
|
||||
}));
|
||||
({ fetchRemoteEmbeddingVectors } = await import("./embeddings-remote-fetch.js"));
|
||||
postJsonMock.mockReset();
|
||||
});
|
||||
|
||||
it("maps remote embedding response data to vectors", async () => {
|
||||
postJsonMock.mockImplementationOnce(async (params) => {
|
||||
return await params.parse({
|
||||
data: [{ embedding: [0.1, 0.2] }, {}, { embedding: [0.3] }],
|
||||
});
|
||||
});
|
||||
|
||||
const vectors = await fetchRemoteEmbeddingVectors({
|
||||
url: "https://memory.example/v1/embeddings",
|
||||
headers: { Authorization: "Bearer test" },
|
||||
body: { input: ["one", "two", "three"] },
|
||||
errorPrefix: "embedding fetch failed",
|
||||
});
|
||||
|
||||
expect(vectors).toEqual([[0.1, 0.2], [], [0.3]]);
|
||||
expect(postJsonMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
url: "https://memory.example/v1/embeddings",
|
||||
headers: { Authorization: "Bearer test" },
|
||||
body: { input: ["one", "two", "three"] },
|
||||
errorPrefix: "embedding fetch failed",
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("throws a status-rich error on non-ok responses", async () => {
|
||||
postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden"));
|
||||
|
||||
await expect(
|
||||
fetchRemoteEmbeddingVectors({
|
||||
url: "https://memory.example/v1/embeddings",
|
||||
headers: {},
|
||||
body: { input: ["one"] },
|
||||
errorPrefix: "embedding fetch failed",
|
||||
}),
|
||||
).rejects.toThrow("embedding fetch failed: 403 forbidden");
|
||||
});
|
||||
});
|
||||
@@ -1,25 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { postJson } from "./post-json.js";
|
||||
|
||||
export async function fetchRemoteEmbeddingVectors(params: {
|
||||
url: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
body: unknown;
|
||||
errorPrefix: string;
|
||||
}): Promise<number[][]> {
|
||||
return await postJson({
|
||||
url: params.url,
|
||||
headers: params.headers,
|
||||
ssrfPolicy: params.ssrfPolicy,
|
||||
body: params.body,
|
||||
errorPrefix: params.errorPrefix,
|
||||
parse: (payload) => {
|
||||
const typedPayload = payload as {
|
||||
data?: Array<{ embedding?: number[] }>;
|
||||
};
|
||||
const data = typedPayload.data ?? [];
|
||||
return data.map((entry) => entry.embedding ?? []);
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import {
|
||||
resolveRemoteEmbeddingBearerClient,
|
||||
type RemoteEmbeddingProviderId,
|
||||
} from "./embeddings-remote-client.js";
|
||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type RemoteEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
export function createRemoteEmbeddingProvider(params: {
|
||||
id: string;
|
||||
client: RemoteEmbeddingClient;
|
||||
errorPrefix: string;
|
||||
maxInputTokens?: number;
|
||||
}): EmbeddingProvider {
|
||||
const { client } = params;
|
||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
|
||||
const embed = async (input: string[]): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
return await fetchRemoteEmbeddingVectors({
|
||||
url,
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
body: { model: client.model, input },
|
||||
errorPrefix: params.errorPrefix,
|
||||
});
|
||||
};
|
||||
|
||||
return {
|
||||
id: params.id,
|
||||
model: client.model,
|
||||
...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}),
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text]);
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: embed,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveRemoteEmbeddingClient(params: {
|
||||
provider: RemoteEmbeddingProviderId;
|
||||
options: EmbeddingProviderOptions;
|
||||
defaultBaseUrl: string;
|
||||
normalizeModel: (model: string) => string;
|
||||
}): Promise<RemoteEmbeddingClient> {
|
||||
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||
provider: params.provider,
|
||||
options: params.options,
|
||||
defaultBaseUrl: params.defaultBaseUrl,
|
||||
});
|
||||
const model = params.normalizeModel(params.options.model);
|
||||
return { baseUrl, headers, ssrfPolicy, model };
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
|
||||
vi.mock("../../agents/model-auth.js", async () => {
|
||||
const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js");
|
||||
return createModelAuthMockModule();
|
||||
});
|
||||
|
||||
const createFetchMock = () => {
|
||||
const fetchMock = vi.fn<FetchMock>(
|
||||
async (_input: RequestInfo | URL, _init?: RequestInit) =>
|
||||
new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
}),
|
||||
);
|
||||
return withFetchPreconnect(fetchMock);
|
||||
};
|
||||
|
||||
let authModule: typeof import("../../agents/model-auth.js");
|
||||
let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
|
||||
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.useRealTimers();
|
||||
vi.doUnmock("undici");
|
||||
vi.resetModules();
|
||||
authModule = await import("../../agents/model-auth.js");
|
||||
({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
|
||||
await import("./embeddings-voyage.js"));
|
||||
});
|
||||
|
||||
function mockVoyageApiKey() {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "voyage-key-123",
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
}
|
||||
|
||||
async function createDefaultVoyageProvider(
|
||||
model: string,
|
||||
fetchMock: ReturnType<typeof createFetchMock>,
|
||||
) {
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
mockVoyageApiKey();
|
||||
return createVoyageEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "voyage",
|
||||
model,
|
||||
fallback: "none",
|
||||
});
|
||||
}
|
||||
|
||||
describe("voyage embedding provider", () => {
|
||||
afterEach(() => {
|
||||
vi.doUnmock("undici");
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("configures client with correct defaults and headers", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
|
||||
|
||||
await result.provider.embedQuery("test query");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ provider: "voyage" }),
|
||||
);
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
expect(url).toBe("https://api.voyageai.com/v1/embeddings");
|
||||
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer voyage-key-123");
|
||||
expect(headers["Content-Type"]).toBe("application/json");
|
||||
|
||||
const body = JSON.parse(init?.body as string);
|
||||
expect(body).toEqual({
|
||||
model: "voyage-4-large",
|
||||
input: ["test query"],
|
||||
input_type: "query",
|
||||
});
|
||||
});
|
||||
|
||||
it("respects remote overrides for baseUrl and apiKey", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
|
||||
const result = await createVoyageEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "voyage",
|
||||
model: "voyage-4-lite",
|
||||
fallback: "none",
|
||||
remote: {
|
||||
baseUrl: "https://example.com",
|
||||
apiKey: "remote-override-key",
|
||||
headers: { "X-Custom": "123" },
|
||||
},
|
||||
});
|
||||
|
||||
await result.provider.embedQuery("test");
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
expect(url).toBe("https://example.com/embeddings");
|
||||
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-override-key");
|
||||
expect(headers["X-Custom"]).toBe("123");
|
||||
});
|
||||
|
||||
it("passes input_type=document for embedBatch", async () => {
|
||||
const fetchMock = withFetchPreconnect(
|
||||
vi.fn<FetchMock>(
|
||||
async (_input: RequestInfo | URL, _init?: RequestInit) =>
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
),
|
||||
),
|
||||
);
|
||||
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
|
||||
|
||||
await result.provider.embedBatch(["doc1", "doc2"]);
|
||||
|
||||
const call = fetchMock.mock.calls[0];
|
||||
expect(call).toBeDefined();
|
||||
const [, init] = call as [RequestInfo | URL, RequestInit | undefined];
|
||||
const body = JSON.parse(init?.body as string);
|
||||
expect(body).toEqual({
|
||||
model: "voyage-4-large",
|
||||
input: ["doc1", "doc2"],
|
||||
input_type: "document",
|
||||
});
|
||||
});
|
||||
|
||||
it("normalizes model names", async () => {
|
||||
expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2");
|
||||
expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large");
|
||||
expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite");
|
||||
expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default
|
||||
});
|
||||
});
|
||||
@@ -1,82 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
|
||||
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
|
||||
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type VoyageEmbeddingClient = {
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
model: string;
|
||||
};
|
||||
|
||||
export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large";
|
||||
const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1";
|
||||
const VOYAGE_MAX_INPUT_TOKENS: Record<string, number> = {
|
||||
"voyage-3": 32000,
|
||||
"voyage-3-lite": 16000,
|
||||
"voyage-code-3": 32000,
|
||||
};
|
||||
|
||||
export function normalizeVoyageModel(model: string): string {
|
||||
return normalizeEmbeddingModelWithPrefixes({
|
||||
model,
|
||||
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
|
||||
prefixes: ["voyage/"],
|
||||
});
|
||||
}
|
||||
|
||||
export async function createVoyageEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
|
||||
const client = await resolveVoyageEmbeddingClient(options);
|
||||
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
|
||||
const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const body: { model: string; input: string[]; input_type?: "query" | "document" } = {
|
||||
model: client.model,
|
||||
input,
|
||||
};
|
||||
if (input_type) {
|
||||
body.input_type = input_type;
|
||||
}
|
||||
|
||||
return await fetchRemoteEmbeddingVectors({
|
||||
url,
|
||||
headers: client.headers,
|
||||
ssrfPolicy: client.ssrfPolicy,
|
||||
body,
|
||||
errorPrefix: "voyage embeddings failed",
|
||||
});
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: "voyage",
|
||||
model: client.model,
|
||||
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
|
||||
embedQuery: async (text) => {
|
||||
const [vec] = await embed([text], "query");
|
||||
return vec ?? [];
|
||||
},
|
||||
embedBatch: async (texts) => embed(texts, "document"),
|
||||
},
|
||||
client,
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveVoyageEmbeddingClient(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<VoyageEmbeddingClient> {
|
||||
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
|
||||
provider: "voyage",
|
||||
options,
|
||||
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
|
||||
});
|
||||
const model = normalizeVoyageModel(options.model);
|
||||
return { baseUrl, headers, ssrfPolicy, model };
|
||||
}
|
||||
@@ -1,740 +0,0 @@
|
||||
import { setTimeout as sleep } from "node:timers/promises";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
|
||||
const createFetchMock = () =>
|
||||
vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
|
||||
}));
|
||||
|
||||
const createGeminiFetchMock = () =>
|
||||
vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ embedding: { values: [1, 2, 3] } }),
|
||||
}));
|
||||
|
||||
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
|
||||
const [url, init] = fetchMock.mock.calls[0] ?? [];
|
||||
return { url, init: init as RequestInit | undefined };
|
||||
}
|
||||
|
||||
type EmbeddingsModule = typeof import("./embeddings.js");
|
||||
type AuthModule = typeof import("../../agents/model-auth.js");
|
||||
type ResolvedProviderAuth = Awaited<ReturnType<AuthModule["resolveApiKeyForProvider"]>>;
|
||||
|
||||
let authModule: AuthModule;
|
||||
let nodeLlamaModule: typeof import("./node-llama.js");
|
||||
let createEmbeddingProvider: EmbeddingsModule["createEmbeddingProvider"];
|
||||
let DEFAULT_LOCAL_MODEL: EmbeddingsModule["DEFAULT_LOCAL_MODEL"];
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetModules();
|
||||
authModule = await import("../../agents/model-auth.js");
|
||||
nodeLlamaModule = await import("./node-llama.js");
|
||||
vi.spyOn(authModule, "resolveApiKeyForProvider");
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
|
||||
({ createEmbeddingProvider, DEFAULT_LOCAL_MODEL } = await import("./embeddings.js"));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
function requireProvider(result: Awaited<ReturnType<EmbeddingsModule["createEmbeddingProvider"]>>) {
|
||||
if (!result.provider) {
|
||||
throw new Error("Expected embedding provider");
|
||||
}
|
||||
return result.provider;
|
||||
}
|
||||
|
||||
function mockResolvedProviderKey(apiKey = "provider-key") {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey,
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
}
|
||||
|
||||
function mockMissingLocalEmbeddingDependency() {
|
||||
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockRejectedValue(
|
||||
Object.assign(new Error("Cannot find package 'node-llama-cpp'"), {
|
||||
code: "ERR_MODULE_NOT_FOUND",
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
function createLocalProvider(options?: { fallback?: "none" | "openai" }) {
|
||||
return createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: options?.fallback ?? "none",
|
||||
});
|
||||
}
|
||||
|
||||
function expectAutoSelectedProvider(
|
||||
result: Awaited<ReturnType<EmbeddingsModule["createEmbeddingProvider"]>>,
|
||||
expectedId: "openai" | "gemini" | "mistral",
|
||||
) {
|
||||
expect(result.requestedProvider).toBe("auto");
|
||||
const provider = requireProvider(result);
|
||||
expect(provider.id).toBe(expectedId);
|
||||
return provider;
|
||||
}
|
||||
|
||||
function createAutoProvider(model = "") {
|
||||
return createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "auto",
|
||||
model,
|
||||
fallback: "none",
|
||||
});
|
||||
}
|
||||
|
||||
describe("embedding provider remote overrides", () => {
|
||||
it("uses remote baseUrl/apiKey and merges headers", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
headers: {
|
||||
"X-Provider": "p",
|
||||
"X-Shared": "provider",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://example.com/v1",
|
||||
apiKey: " remote-key ",
|
||||
headers: {
|
||||
"X-Shared": "remote",
|
||||
"X-Remote": "r",
|
||||
},
|
||||
},
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
|
||||
const url = fetchMock.mock.calls[0]?.[0];
|
||||
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
|
||||
expect(url).toBe("https://example.com/v1/embeddings");
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer remote-key");
|
||||
expect(headers["Content-Type"]).toBe("application/json");
|
||||
expect(headers["X-Provider"]).toBe("p");
|
||||
expect(headers["X-Shared"]).toBe("remote");
|
||||
expect(headers["X-Remote"]).toBe("r");
|
||||
});
|
||||
|
||||
it("falls back to resolved api key when remote apiKey is blank", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://example.com/v1",
|
||||
apiKey: " ",
|
||||
},
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1);
|
||||
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
|
||||
const headers = (init?.headers as Record<string, string>) ?? {};
|
||||
expect(headers.Authorization).toBe("Bearer provider-key");
|
||||
});
|
||||
|
||||
it("builds Gemini embeddings requests with api key header", async () => {
|
||||
const fetchMock = createGeminiFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
google: {
|
||||
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "gemini",
|
||||
remote: {
|
||||
apiKey: "gemini-key",
|
||||
},
|
||||
model: "text-embedding-004",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
const { url, init } = readFirstFetchRequest(fetchMock);
|
||||
expect(url).toBe(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent",
|
||||
);
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers["x-goog-api-key"]).toBe("gemini-key");
|
||||
expect(headers["Content-Type"]).toBe("application/json");
|
||||
});
|
||||
|
||||
it("fails fast when Gemini remote apiKey is an unresolved SecretRef", async () => {
|
||||
await expect(
|
||||
createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
remote: {
|
||||
apiKey: { source: "env", provider: "default", id: "GEMINI_API_KEY" },
|
||||
},
|
||||
model: "text-embedding-004",
|
||||
fallback: "openai",
|
||||
}),
|
||||
).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey:/i);
|
||||
});
|
||||
|
||||
it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => {
|
||||
const fetchMock = createGeminiFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
vi.stubEnv("GEMINI_API_KEY", "env-gemini-key");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
remote: {
|
||||
apiKey: "GEMINI_API_KEY", // pragma: allowlist secret
|
||||
},
|
||||
model: "text-embedding-004",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
const { init } = readFirstFetchRequest(fetchMock);
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers["x-goog-api-key"]).toBe("env-gemini-key");
|
||||
});
|
||||
|
||||
it("builds Mistral embeddings requests with bearer auth", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
mistral: {
|
||||
baseUrl: "https://api.mistral.ai/v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "mistral",
|
||||
remote: {
|
||||
apiKey: "mistral-key", // pragma: allowlist secret
|
||||
},
|
||||
model: "mistral/mistral-embed",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
const { url, init } = readFirstFetchRequest(fetchMock);
|
||||
expect(url).toBe("https://api.mistral.ai/v1/embeddings");
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer mistral-key");
|
||||
const payload = JSON.parse((init?.body as string | undefined) ?? "{}") as { model?: string };
|
||||
expect(payload.model).toBe("mistral-embed");
|
||||
});
|
||||
});
|
||||
|
||||
describe("embedding provider auto selection", () => {
|
||||
it("keeps explicit model when openai is selected", async () => {
|
||||
const fetchMock = vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
|
||||
}));
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
|
||||
if (provider === "openai") {
|
||||
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
|
||||
}
|
||||
throw new Error(`Unexpected provider ${provider}`);
|
||||
});
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "auto",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
expect(result.requestedProvider).toBe("auto");
|
||||
const provider = requireProvider(result);
|
||||
expect(provider.id).toBe("openai");
|
||||
await provider.embedQuery("hello");
|
||||
const url = fetchMock.mock.calls[0]?.[0];
|
||||
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
|
||||
expect(url).toBe("https://api.openai.com/v1/embeddings");
|
||||
const payload = JSON.parse(init?.body as string) as { model?: string };
|
||||
expect(payload.model).toBe("text-embedding-3-small");
|
||||
});
|
||||
|
||||
it("selects the first available remote provider in auto mode", async () => {
|
||||
const cases: Array<{
|
||||
name: string;
|
||||
expectedProvider: "openai" | "gemini" | "mistral";
|
||||
fetchMockFactory: typeof createFetchMock | typeof createGeminiFetchMock;
|
||||
resolveApiKey: (provider: string) => ResolvedProviderAuth;
|
||||
expectedUrl: string;
|
||||
}> = [
|
||||
{
|
||||
name: "openai first",
|
||||
expectedProvider: "openai" as const,
|
||||
fetchMockFactory: createFetchMock,
|
||||
resolveApiKey(provider: string): ResolvedProviderAuth {
|
||||
if (provider === "openai") {
|
||||
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
|
||||
}
|
||||
throw new Error(`No API key found for provider "${provider}".`);
|
||||
},
|
||||
expectedUrl: "https://api.openai.com/v1/embeddings",
|
||||
},
|
||||
{
|
||||
name: "gemini fallback",
|
||||
expectedProvider: "gemini" as const,
|
||||
fetchMockFactory: createGeminiFetchMock,
|
||||
resolveApiKey(provider: string): ResolvedProviderAuth {
|
||||
if (provider === "openai") {
|
||||
throw new Error('No API key found for provider "openai".');
|
||||
}
|
||||
if (provider === "google") {
|
||||
return {
|
||||
apiKey: "gemini-key",
|
||||
source: "env: GEMINI_API_KEY",
|
||||
mode: "api-key" as const,
|
||||
};
|
||||
}
|
||||
throw new Error(`Unexpected provider ${provider}`);
|
||||
},
|
||||
expectedUrl: `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_EMBEDDING_MODEL}:embedContent`,
|
||||
},
|
||||
{
|
||||
name: "mistral after earlier misses",
|
||||
expectedProvider: "mistral" as const,
|
||||
fetchMockFactory: createFetchMock,
|
||||
resolveApiKey(provider: string): ResolvedProviderAuth {
|
||||
if (provider === "mistral") {
|
||||
return {
|
||||
apiKey: "mistral-key",
|
||||
source: "env: MISTRAL_API_KEY",
|
||||
mode: "api-key" as const,
|
||||
};
|
||||
}
|
||||
throw new Error(`No API key found for provider "${provider}".`);
|
||||
},
|
||||
expectedUrl: "https://api.mistral.ai/v1/embeddings",
|
||||
},
|
||||
];
|
||||
|
||||
for (const testCase of cases) {
|
||||
vi.resetAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
const fetchMock = testCase.fetchMockFactory();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
mockPublicPinnedHostname();
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) =>
|
||||
testCase.resolveApiKey(provider),
|
||||
);
|
||||
|
||||
const result = await createAutoProvider();
|
||||
const provider = expectAutoSelectedProvider(result, testCase.expectedProvider);
|
||||
await provider.embedQuery("hello");
|
||||
const [url] = fetchMock.mock.calls[0] ?? [];
|
||||
expect(url, testCase.name).toBe(testCase.expectedUrl);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("embedding provider local fallback", () => {
|
||||
it("falls back to openai when node-llama-cpp is missing", async () => {
|
||||
mockMissingLocalEmbeddingDependency();
|
||||
|
||||
const fetchMock = createFetchMock();
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const result = await createLocalProvider({ fallback: "openai" });
|
||||
|
||||
const provider = requireProvider(result);
|
||||
expect(provider.id).toBe("openai");
|
||||
expect(result.fallbackFrom).toBe("local");
|
||||
expect(result.fallbackReason).toContain("node-llama-cpp");
|
||||
});
|
||||
|
||||
it("throws a helpful error when local is requested and fallback is none", async () => {
|
||||
mockMissingLocalEmbeddingDependency();
|
||||
await expect(createLocalProvider()).rejects.toThrow(/optional dependency node-llama-cpp/i);
|
||||
});
|
||||
|
||||
it("mentions every remote provider in local setup guidance", async () => {
|
||||
mockMissingLocalEmbeddingDependency();
|
||||
await expect(createLocalProvider()).rejects.toThrow(/provider = "gemini"/i);
|
||||
await expect(createLocalProvider()).rejects.toThrow(/provider = "mistral"/i);
|
||||
});
|
||||
});
|
||||
|
||||
describe("local embedding normalization", () => {
|
||||
async function createLocalProviderForTest() {
|
||||
return createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
}
|
||||
|
||||
function mockSingleLocalEmbeddingVector(
|
||||
vector: number[],
|
||||
resolveModelFile: (modelPath: string, modelDirectory?: string) => Promise<string> = async () =>
|
||||
"/fake/model.gguf",
|
||||
): void {
|
||||
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array(vector),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile,
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
}
|
||||
|
||||
it("normalizes local embeddings to magnitude ~1.0", async () => {
|
||||
const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9];
|
||||
const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf");
|
||||
|
||||
mockSingleLocalEmbeddingVector(unnormalizedVector, resolveModelFileMock);
|
||||
|
||||
const result = await createLocalProviderForTest();
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const embedding = await provider.embedQuery("test query");
|
||||
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
|
||||
|
||||
expect(magnitude).toBeCloseTo(1.0, 5);
|
||||
expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
|
||||
});
|
||||
|
||||
it("handles zero vector without division by zero", async () => {
|
||||
const zeroVector = [0, 0, 0, 0];
|
||||
|
||||
mockSingleLocalEmbeddingVector(zeroVector);
|
||||
|
||||
const result = await createLocalProviderForTest();
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const embedding = await provider.embedQuery("test");
|
||||
|
||||
expect(embedding).toEqual([0, 0, 0, 0]);
|
||||
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
|
||||
});
|
||||
|
||||
it("sanitizes non-finite values before normalization", async () => {
|
||||
const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
|
||||
|
||||
mockSingleLocalEmbeddingVector(nonFiniteVector);
|
||||
|
||||
const result = await createLocalProviderForTest();
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const embedding = await provider.embedQuery("test");
|
||||
|
||||
expect(embedding).toEqual([1, 0, 0, 0]);
|
||||
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
|
||||
});
|
||||
|
||||
it("normalizes batch embeddings to magnitude ~1.0", async () => {
|
||||
const unnormalizedVectors = [
|
||||
[2.35, 3.45, 0.63, 4.3],
|
||||
[10.0, 0.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
];
|
||||
|
||||
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) })
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) })
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const result = await createLocalProviderForTest();
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const embeddings = await provider.embedBatch(["text1", "text2", "text3"]);
|
||||
|
||||
for (const embedding of embeddings) {
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
|
||||
expect(magnitude).toBeCloseTo(1.0, 5);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("local embedding ensureContext concurrency", () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules();
|
||||
vi.doUnmock("./node-llama.js");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetModules();
|
||||
vi.doUnmock("./node-llama.js");
|
||||
});
|
||||
|
||||
async function setupLocalProviderWithMockedInit(params?: {
|
||||
initializationDelayMs?: number;
|
||||
failFirstGetLlama?: boolean;
|
||||
}) {
|
||||
const getLlamaSpy = vi.fn();
|
||||
const loadModelSpy = vi.fn();
|
||||
const createContextSpy = vi.fn();
|
||||
let shouldFail = params?.failFirstGetLlama ?? false;
|
||||
|
||||
const nodeLlamaModule = await import("./node-llama.js");
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
|
||||
getLlama: async (...args: unknown[]) => {
|
||||
getLlamaSpy(...args);
|
||||
if (shouldFail) {
|
||||
shouldFail = false;
|
||||
throw new Error("transient init failure");
|
||||
}
|
||||
if (params?.initializationDelayMs) {
|
||||
await sleep(params.initializationDelayMs);
|
||||
}
|
||||
return {
|
||||
loadModel: async (...modelArgs: unknown[]) => {
|
||||
loadModelSpy(...modelArgs);
|
||||
if (params?.initializationDelayMs) {
|
||||
await sleep(params.initializationDelayMs);
|
||||
}
|
||||
return {
|
||||
createEmbeddingContext: async () => {
|
||||
createContextSpy();
|
||||
return {
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array([1, 0, 0, 0]),
|
||||
}),
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const { createEmbeddingProvider } = await import("./embeddings.js");
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
return {
|
||||
provider: requireProvider(result),
|
||||
getLlamaSpy,
|
||||
loadModelSpy,
|
||||
createContextSpy,
|
||||
};
|
||||
}
|
||||
|
||||
it("loads the model only once when embedBatch is called concurrently", async () => {
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
initializationDelayMs: 50,
|
||||
});
|
||||
|
||||
const results = await Promise.all([
|
||||
provider.embedBatch(["text1"]),
|
||||
provider.embedBatch(["text2"]),
|
||||
provider.embedBatch(["text3"]),
|
||||
provider.embedBatch(["text4"]),
|
||||
]);
|
||||
|
||||
expect(results).toHaveLength(4);
|
||||
for (const embeddings of results) {
|
||||
expect(embeddings).toHaveLength(1);
|
||||
expect(embeddings[0]).toHaveLength(4);
|
||||
}
|
||||
|
||||
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
|
||||
expect(loadModelSpy).toHaveBeenCalledTimes(1);
|
||||
expect(createContextSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("retries initialization after a transient ensureContext failure", async () => {
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
failFirstGetLlama: true,
|
||||
});
|
||||
|
||||
await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure");
|
||||
|
||||
const recovered = await provider.embedBatch(["second"]);
|
||||
expect(recovered).toHaveLength(1);
|
||||
expect(recovered[0]).toHaveLength(4);
|
||||
|
||||
expect(getLlamaSpy).toHaveBeenCalledTimes(2);
|
||||
expect(loadModelSpy).toHaveBeenCalledTimes(1);
|
||||
expect(createContextSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("shares initialization when embedQuery and embedBatch start concurrently", async () => {
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
initializationDelayMs: 50,
|
||||
});
|
||||
|
||||
const [queryA, batch, queryB] = await Promise.all([
|
||||
provider.embedQuery("query-a"),
|
||||
provider.embedBatch(["batch-a", "batch-b"]),
|
||||
provider.embedQuery("query-b"),
|
||||
]);
|
||||
|
||||
expect(queryA).toHaveLength(4);
|
||||
expect(batch).toHaveLength(2);
|
||||
expect(queryB).toHaveLength(4);
|
||||
expect(batch[0]).toHaveLength(4);
|
||||
expect(batch[1]).toHaveLength(4);
|
||||
|
||||
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
|
||||
expect(loadModelSpy).toHaveBeenCalledTimes(1);
|
||||
expect(createContextSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("FTS-only fallback when no provider available", () => {
|
||||
beforeEach(async () => {
|
||||
authModule = await import("../../agents/model-auth.js");
|
||||
({ createEmbeddingProvider, DEFAULT_LOCAL_MODEL } = await import("./embeddings.js"));
|
||||
});
|
||||
|
||||
it("returns null provider when all requested auth paths fail", async () => {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
|
||||
new Error("No API key found for provider"),
|
||||
);
|
||||
|
||||
for (const testCase of [
|
||||
{
|
||||
name: "auto mode",
|
||||
options: {
|
||||
config: {} as never,
|
||||
provider: "auto" as const,
|
||||
model: "",
|
||||
fallback: "none" as const,
|
||||
},
|
||||
requestedProvider: "auto",
|
||||
fallbackFrom: undefined,
|
||||
reasonIncludes: "No API key",
|
||||
},
|
||||
{
|
||||
name: "explicit provider only",
|
||||
options: {
|
||||
config: {} as never,
|
||||
provider: "openai" as const,
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "none" as const,
|
||||
},
|
||||
requestedProvider: "openai",
|
||||
fallbackFrom: undefined,
|
||||
reasonIncludes: "No API key",
|
||||
},
|
||||
{
|
||||
name: "primary and fallback",
|
||||
options: {
|
||||
config: {} as never,
|
||||
provider: "openai" as const,
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "gemini" as const,
|
||||
},
|
||||
requestedProvider: "openai",
|
||||
fallbackFrom: "openai",
|
||||
reasonIncludes: "Fallback to gemini failed",
|
||||
},
|
||||
]) {
|
||||
const result = await createEmbeddingProvider(testCase.options);
|
||||
expect(result.provider, testCase.name).toBeNull();
|
||||
expect(result.requestedProvider, testCase.name).toBe(testCase.requestedProvider);
|
||||
expect(result.fallbackFrom, testCase.name).toBe(testCase.fallbackFrom);
|
||||
expect(result.providerUnavailableReason, testCase.name).toContain(testCase.reasonIncludes);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1,324 +0,0 @@
|
||||
import fsSync from "node:fs";
|
||||
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import type { SecretInput } from "../../config/types.secrets.js";
|
||||
import { formatErrorMessage } from "../../infra/errors.js";
|
||||
import { resolveUserPath } from "../../utils.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import {
|
||||
createGeminiEmbeddingProvider,
|
||||
type GeminiEmbeddingClient,
|
||||
type GeminiTaskType,
|
||||
} from "./embeddings-gemini.js";
|
||||
import {
|
||||
createMistralEmbeddingProvider,
|
||||
type MistralEmbeddingClient,
|
||||
} from "./embeddings-mistral.js";
|
||||
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
|
||||
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
import { importNodeLlamaCpp } from "./node-llama.js";
|
||||
|
||||
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
|
||||
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
|
||||
|
||||
export type EmbeddingProvider = {
|
||||
id: string;
|
||||
model: string;
|
||||
maxInputTokens?: number;
|
||||
embedQuery: (text: string) => Promise<number[]>;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderId = "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama";
|
||||
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
|
||||
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
|
||||
|
||||
// Remote providers considered for auto-selection when provider === "auto".
|
||||
// Ollama is intentionally excluded here so that "auto" mode does not
|
||||
// implicitly assume a local Ollama instance is available.
|
||||
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;
|
||||
|
||||
export type EmbeddingProviderResult = {
|
||||
provider: EmbeddingProvider | null;
|
||||
requestedProvider: EmbeddingProviderRequest;
|
||||
fallbackFrom?: EmbeddingProviderId;
|
||||
fallbackReason?: string;
|
||||
providerUnavailableReason?: string;
|
||||
openAi?: OpenAiEmbeddingClient;
|
||||
gemini?: GeminiEmbeddingClient;
|
||||
voyage?: VoyageEmbeddingClient;
|
||||
mistral?: MistralEmbeddingClient;
|
||||
ollama?: OllamaEmbeddingClient;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderOptions = {
|
||||
config: OpenClawConfig;
|
||||
agentDir?: string;
|
||||
provider: EmbeddingProviderRequest;
|
||||
remote?: {
|
||||
baseUrl?: string;
|
||||
apiKey?: SecretInput;
|
||||
headers?: Record<string, string>;
|
||||
};
|
||||
model: string;
|
||||
fallback: EmbeddingProviderFallback;
|
||||
local?: {
|
||||
modelPath?: string;
|
||||
modelCacheDir?: string;
|
||||
};
|
||||
/** Gemini embedding-2: output vector dimensions (768, 1536, or 3072). */
|
||||
outputDimensionality?: number;
|
||||
/** Gemini: override the default task type sent with embedding requests. */
|
||||
taskType?: GeminiTaskType;
|
||||
};
|
||||
|
||||
export const DEFAULT_LOCAL_MODEL =
|
||||
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";
|
||||
|
||||
function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
|
||||
const modelPath = options.local?.modelPath?.trim();
|
||||
if (!modelPath) {
|
||||
return false;
|
||||
}
|
||||
if (/^(hf:|https?:)/i.test(modelPath)) {
|
||||
return false;
|
||||
}
|
||||
const resolved = resolveUserPath(modelPath);
|
||||
try {
|
||||
return fsSync.statSync(resolved).isFile();
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function isMissingApiKeyError(err: unknown): boolean {
|
||||
const message = formatErrorMessage(err);
|
||||
return message.includes("No API key found for provider");
|
||||
}
|
||||
|
||||
export async function createLocalEmbeddingProvider(
|
||||
options: EmbeddingProviderOptions,
|
||||
): Promise<EmbeddingProvider> {
|
||||
const modelPath = options.local?.modelPath?.trim() || DEFAULT_LOCAL_MODEL;
|
||||
const modelCacheDir = options.local?.modelCacheDir?.trim();
|
||||
|
||||
// Lazy-load node-llama-cpp to keep startup light unless local is enabled.
|
||||
const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp();
|
||||
|
||||
let llama: Llama | null = null;
|
||||
let embeddingModel: LlamaModel | null = null;
|
||||
let embeddingContext: LlamaEmbeddingContext | null = null;
|
||||
let initPromise: Promise<LlamaEmbeddingContext> | null = null;
|
||||
|
||||
const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
|
||||
if (embeddingContext) {
|
||||
return embeddingContext;
|
||||
}
|
||||
if (initPromise) {
|
||||
return initPromise;
|
||||
}
|
||||
initPromise = (async () => {
|
||||
try {
|
||||
if (!llama) {
|
||||
llama = await getLlama({ logLevel: LlamaLogLevel.error });
|
||||
}
|
||||
if (!embeddingModel) {
|
||||
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
|
||||
embeddingModel = await llama.loadModel({ modelPath: resolved });
|
||||
}
|
||||
if (!embeddingContext) {
|
||||
embeddingContext = await embeddingModel.createEmbeddingContext();
|
||||
}
|
||||
return embeddingContext;
|
||||
} catch (err) {
|
||||
initPromise = null;
|
||||
throw err;
|
||||
}
|
||||
})();
|
||||
return initPromise;
|
||||
};
|
||||
|
||||
return {
|
||||
id: "local",
|
||||
model: modelPath,
|
||||
embedQuery: async (text) => {
|
||||
const ctx = await ensureContext();
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
},
|
||||
embedBatch: async (texts) => {
|
||||
const ctx = await ensureContext();
|
||||
const embeddings = await Promise.all(
|
||||
texts.map(async (text) => {
|
||||
const embedding = await ctx.getEmbeddingFor(text);
|
||||
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
|
||||
}),
|
||||
);
|
||||
return embeddings;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Resolve and instantiate an embedding provider from `options.provider`.
 *
 * Resolution rules:
 * - `"auto"`: try the local provider first (when `canAutoSelectLocal` allows),
 *   then each remote provider in `REMOTE_EMBEDDING_PROVIDER_IDS` order.
 *   Missing-API-key failures are collected; if every candidate fails that way
 *   the function degrades gracefully to FTS-only mode (`provider: null` with
 *   `providerUnavailableReason`). Any non-auth failure (e.g. network) aborts.
 * - explicit provider: try it, then the configured `fallback` (unless it is
 *   `"none"` or identical to the primary). Missing-API-key failures degrade
 *   to FTS-only; other failures rethrow with the original error as `cause`.
 *
 * @param options Provider id, optional fallback id, and per-provider settings.
 * @returns The provider plus the underlying client keyed by provider name,
 *   or a `provider: null` result describing why embeddings are unavailable.
 */
export async function createEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
  const requestedProvider = options.provider;
  const fallback = options.fallback;

  // Instantiate one concrete provider; the result also exposes the raw
  // client under a provider-specific key (ollama/gemini/voyage/...).
  // OpenAI is the implicit default when no other id matches.
  const createProvider = async (id: EmbeddingProviderId) => {
    if (id === "local") {
      const provider = await createLocalEmbeddingProvider(options);
      return { provider };
    }
    if (id === "ollama") {
      const { provider, client } = await createOllamaEmbeddingProvider(options);
      return { provider, ollama: client };
    }
    if (id === "gemini") {
      const { provider, client } = await createGeminiEmbeddingProvider(options);
      return { provider, gemini: client };
    }
    if (id === "voyage") {
      const { provider, client } = await createVoyageEmbeddingProvider(options);
      return { provider, voyage: client };
    }
    if (id === "mistral") {
      const { provider, client } = await createMistralEmbeddingProvider(options);
      return { provider, mistral: client };
    }
    const { provider, client } = await createOpenAiEmbeddingProvider(options);
    return { provider, openAi: client };
  };

  // Local failures get the richer setup guide; remote failures a plain message.
  const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
    provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);

  if (requestedProvider === "auto") {
    const missingKeyErrors: string[] = [];
    let localError: string | null = null;

    // Prefer local embeddings when auto-selection deems them viable; a local
    // failure is recorded but never fatal in auto mode.
    if (canAutoSelectLocal(options)) {
      try {
        const local = await createProvider("local");
        return { ...local, requestedProvider };
      } catch (err) {
        localError = formatLocalSetupError(err);
      }
    }

    // Walk remote providers in priority order; only missing-key errors allow
    // continuing to the next candidate.
    for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
      try {
        const result = await createProvider(provider);
        return { ...result, requestedProvider };
      } catch (err) {
        const message = formatPrimaryError(err, provider);
        if (isMissingApiKeyError(err)) {
          missingKeyErrors.push(message);
          continue;
        }
        // Non-auth errors (e.g., network) are still fatal
        const wrapped = new Error(message) as Error & { cause?: unknown };
        wrapped.cause = err;
        throw wrapped;
      }
    }

    // All providers failed due to missing API keys - return null provider for FTS-only mode
    const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
    const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
    return {
      provider: null,
      requestedProvider,
      providerUnavailableReason: reason,
    };
  }

  // Explicit provider requested: primary attempt, then optional fallback.
  try {
    const primary = await createProvider(requestedProvider);
    return { ...primary, requestedProvider };
  } catch (primaryErr) {
    const reason = formatPrimaryError(primaryErr, requestedProvider);
    if (fallback && fallback !== "none" && fallback !== requestedProvider) {
      try {
        const fallbackResult = await createProvider(fallback);
        // Record that we fell back, so callers can surface the degradation.
        return {
          ...fallbackResult,
          requestedProvider,
          fallbackFrom: requestedProvider,
          fallbackReason: reason,
        };
      } catch (fallbackErr) {
        // Both primary and fallback failed - check if it's auth-related
        const fallbackReason = formatErrorMessage(fallbackErr);
        const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
        if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
          // Both failed due to missing API keys - return null for FTS-only mode
          return {
            provider: null,
            requestedProvider,
            fallbackFrom: requestedProvider,
            fallbackReason: reason,
            providerUnavailableReason: combinedReason,
          };
        }
        // Non-auth errors are still fatal
        const wrapped = new Error(combinedReason) as Error & { cause?: unknown };
        wrapped.cause = fallbackErr;
        throw wrapped;
      }
    }
    // No fallback configured - check if we should degrade to FTS-only
    if (isMissingApiKeyError(primaryErr)) {
      return {
        provider: null,
        requestedProvider,
        providerUnavailableReason: reason,
      };
    }
    const wrapped = new Error(reason) as Error & { cause?: unknown };
    wrapped.cause = primaryErr;
    throw wrapped;
  }
}
|
||||
|
||||
function isNodeLlamaCppMissing(err: unknown): boolean {
|
||||
if (!(err instanceof Error)) {
|
||||
return false;
|
||||
}
|
||||
const code = (err as Error & { code?: unknown }).code;
|
||||
if (code === "ERR_MODULE_NOT_FOUND") {
|
||||
return err.message.includes("node-llama-cpp");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function formatLocalSetupError(err: unknown): string {
|
||||
const detail = formatErrorMessage(err);
|
||||
const missing = isNodeLlamaCppMissing(err);
|
||||
return [
|
||||
"Local embeddings unavailable.",
|
||||
missing
|
||||
? "Reason: optional dependency node-llama-cpp is missing (or failed to install)."
|
||||
: detail
|
||||
? `Reason: ${detail}`
|
||||
: undefined,
|
||||
missing && detail ? `Detail: ${detail}` : null,
|
||||
"To enable local embeddings:",
|
||||
"1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)",
|
||||
missing
|
||||
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
|
||||
: null,
|
||||
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
|
||||
...REMOTE_EMBEDDING_PROVIDER_IDS.map(
|
||||
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
|
||||
),
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join("\n");
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
import type { Stats } from "node:fs";
|
||||
import fs from "node:fs/promises";
|
||||
|
||||
export type RegularFileStatResult = { missing: true } | { missing: false; stat: Stats };
|
||||
|
||||
export function isFileMissingError(
|
||||
err: unknown,
|
||||
): err is NodeJS.ErrnoException & { code: "ENOENT" } {
|
||||
return Boolean(
|
||||
err &&
|
||||
typeof err === "object" &&
|
||||
"code" in err &&
|
||||
(err as Partial<NodeJS.ErrnoException>).code === "ENOENT",
|
||||
);
|
||||
}
|
||||
|
||||
export async function statRegularFile(absPath: string): Promise<RegularFileStatResult> {
|
||||
let stat: Stats;
|
||||
try {
|
||||
stat = await fs.lstat(absPath);
|
||||
} catch (err) {
|
||||
if (isFileMissingError(err)) {
|
||||
return { missing: true };
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
if (stat.isSymbolicLink() || !stat.isFile()) {
|
||||
throw new Error("path required");
|
||||
}
|
||||
return { missing: false, stat };
|
||||
}
|
||||
@@ -1,314 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { afterEach, beforeEach, describe, expect, it } from "vitest";
|
||||
import {
|
||||
buildMultimodalChunkForIndexing,
|
||||
buildFileEntry,
|
||||
chunkMarkdown,
|
||||
listMemoryFiles,
|
||||
normalizeExtraMemoryPaths,
|
||||
remapChunkLines,
|
||||
} from "./internal.js";
|
||||
import {
|
||||
DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES,
|
||||
type MemoryMultimodalSettings,
|
||||
} from "./multimodal.js";
|
||||
|
||||
function setupTempDirLifecycle(prefix: string): () => string {
|
||||
let tmpDir = "";
|
||||
beforeEach(async () => {
|
||||
tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), prefix));
|
||||
});
|
||||
afterEach(async () => {
|
||||
await fs.rm(tmpDir, { recursive: true, force: true });
|
||||
});
|
||||
return () => tmpDir;
|
||||
}
|
||||
|
||||
describe("normalizeExtraMemoryPaths", () => {
  it("trims, resolves, and dedupes paths", () => {
    const workspaceDir = path.join(os.tmpdir(), "memory-test-workspace");
    // Platform-neutral absolute path (e.g. "/shared-notes" on POSIX).
    const absPath = path.resolve(path.sep, "shared-notes");
    // Input mixes: padded relative, dot-relative (same target), a duplicated
    // absolute path, and an empty string that must be dropped.
    const result = normalizeExtraMemoryPaths(workspaceDir, [
      " notes ",
      "./notes",
      absPath,
      absPath,
      "",
    ]);
    // " notes " and "./notes" collapse to one workspace-relative entry;
    // the absolute path survives once; the empty string vanishes.
    expect(result).toEqual([path.resolve(workspaceDir, "notes"), absPath]);
  });
});
|
||||
|
||||
describe("listMemoryFiles", () => {
  const getTmpDir = setupTempDirLifecycle("memory-test-");
  // Multimodal settings shared by the image/audio discovery tests below.
  const multimodal: MemoryMultimodalSettings = {
    enabled: true,
    modalities: ["image", "audio"],
    maxFileBytes: DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES,
  };

  // Extra directory: every .md inside is discovered, other extensions are not.
  it("includes files from additional paths (directory)", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const extraDir = path.join(tmpDir, "extra-notes");
    await fs.mkdir(extraDir, { recursive: true });
    await fs.writeFile(path.join(extraDir, "note1.md"), "# Note 1");
    await fs.writeFile(path.join(extraDir, "note2.md"), "# Note 2");
    await fs.writeFile(path.join(extraDir, "ignore.txt"), "Not a markdown file");

    const files = await listMemoryFiles(tmpDir, [extraDir]);
    expect(files).toHaveLength(3);
    expect(files.some((file) => file.endsWith("MEMORY.md"))).toBe(true);
    expect(files.some((file) => file.endsWith("note1.md"))).toBe(true);
    expect(files.some((file) => file.endsWith("note2.md"))).toBe(true);
    expect(files.some((file) => file.endsWith("ignore.txt"))).toBe(false);
  });

  // An extra path may point directly at one markdown file.
  it("includes files from additional paths (single file)", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const singleFile = path.join(tmpDir, "standalone.md");
    await fs.writeFile(singleFile, "# Standalone");

    const files = await listMemoryFiles(tmpDir, [singleFile]);
    expect(files).toHaveLength(2);
    expect(files.some((file) => file.endsWith("standalone.md"))).toBe(true);
  });

  // Relative extra paths resolve against the workspace directory.
  it("handles relative paths in additional paths", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const extraDir = path.join(tmpDir, "subdir");
    await fs.mkdir(extraDir, { recursive: true });
    await fs.writeFile(path.join(extraDir, "nested.md"), "# Nested");

    const files = await listMemoryFiles(tmpDir, ["subdir"]);
    expect(files).toHaveLength(2);
    expect(files.some((file) => file.endsWith("nested.md"))).toBe(true);
  });

  // Missing extra paths are silently skipped, not errors.
  it("ignores non-existent additional paths", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");

    const files = await listMemoryFiles(tmpDir, ["/does/not/exist"]);
    expect(files).toHaveLength(1);
  });

  // Symlinks (file or directory) are never followed. The test degrades
  // gracefully on platforms where creating symlinks needs privileges
  // (e.g. Windows without developer mode): EPERM/EACCES skips the
  // symlink-specific assertions.
  it("ignores symlinked files and directories", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const extraDir = path.join(tmpDir, "extra");
    await fs.mkdir(extraDir, { recursive: true });
    await fs.writeFile(path.join(extraDir, "note.md"), "# Note");

    const targetFile = path.join(tmpDir, "target.md");
    await fs.writeFile(targetFile, "# Target");
    const linkFile = path.join(extraDir, "linked.md");

    const targetDir = path.join(tmpDir, "target-dir");
    await fs.mkdir(targetDir, { recursive: true });
    await fs.writeFile(path.join(targetDir, "nested.md"), "# Nested");
    const linkDir = path.join(tmpDir, "linked-dir");

    let symlinksOk = true;
    try {
      await fs.symlink(targetFile, linkFile, "file");
      await fs.symlink(targetDir, linkDir, "dir");
    } catch (err) {
      const code = (err as NodeJS.ErrnoException).code;
      if (code === "EPERM" || code === "EACCES") {
        symlinksOk = false;
      } else {
        throw err;
      }
    }

    const files = await listMemoryFiles(tmpDir, [extraDir, linkDir]);
    expect(files.some((file) => file.endsWith("note.md"))).toBe(true);
    if (symlinksOk) {
      expect(files.some((file) => file.endsWith("linked.md"))).toBe(false);
      expect(files.some((file) => file.endsWith("nested.md"))).toBe(false);
    }
  });

  // The same physical file reachable via several extra paths appears once
  // (realpath-based dedupe).
  it("dedupes overlapping extra paths that resolve to the same file", async () => {
    const tmpDir = getTmpDir();
    await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory");
    const files = await listMemoryFiles(tmpDir, [tmpDir, ".", path.join(tmpDir, "MEMORY.md")]);
    const memoryMatches = files.filter((file) => file.endsWith("MEMORY.md"));
    expect(memoryMatches).toHaveLength(1);
  });

  // Image/audio files in extra paths are picked up only with multimodal
  // settings; unknown extensions stay excluded.
  it("includes image and audio files from extra paths when multimodal is enabled", async () => {
    const tmpDir = getTmpDir();
    const extraDir = path.join(tmpDir, "media");
    await fs.mkdir(extraDir, { recursive: true });
    await fs.writeFile(path.join(extraDir, "diagram.png"), Buffer.from("png"));
    await fs.writeFile(path.join(extraDir, "note.wav"), Buffer.from("wav"));
    await fs.writeFile(path.join(extraDir, "ignore.bin"), Buffer.from("bin"));

    const files = await listMemoryFiles(tmpDir, [extraDir], multimodal);
    expect(files.some((file) => file.endsWith("diagram.png"))).toBe(true);
    expect(files.some((file) => file.endsWith("note.wav"))).toBe(true);
    expect(files.some((file) => file.endsWith("ignore.bin"))).toBe(false);
  });
});
|
||||
|
||||
describe("buildFileEntry", () => {
  const getTmpDir = setupTempDirLifecycle("memory-build-entry-");
  const multimodal: MemoryMultimodalSettings = {
    enabled: true,
    modalities: ["image", "audio"],
    maxFileBytes: DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES,
  };

  // ENOENT between discovery and stat/read is tolerated as null, not thrown.
  it("returns null when the file disappears before reading", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "ghost.md");
    await fs.writeFile(target, "ghost", "utf-8");
    await fs.rm(target);
    const entry = await buildFileEntry(target, tmpDir);
    expect(entry).toBeNull();
  });

  // Happy path: workspace-relative path and non-zero size are reported.
  it("returns metadata when the file exists", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "note.md");
    await fs.writeFile(target, "hello", "utf-8");
    const entry = await buildFileEntry(target, tmpDir);
    expect(entry).not.toBeNull();
    expect(entry?.path).toBe("note.md");
    expect(entry?.size).toBeGreaterThan(0);
  });

  // A .png classified as image yields a multimodal entry with a sniffed MIME
  // type and a human-readable contentText label.
  it("returns multimodal metadata for eligible image files", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "diagram.png");
    await fs.writeFile(target, Buffer.from("png"));

    const entry = await buildFileEntry(target, tmpDir, multimodal);

    expect(entry).toMatchObject({
      path: "diagram.png",
      kind: "multimodal",
      modality: "image",
      mimeType: "image/png",
      contentText: "Image file: diagram.png",
    });
  });

  // The embedding input (label text + base64 inline data) is built lazily,
  // at indexing time, from the entry captured at discovery.
  it("builds a multimodal chunk lazily for indexing", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "diagram.png");
    await fs.writeFile(target, Buffer.from("png"));

    const entry = await buildFileEntry(target, tmpDir, multimodal);
    const built = await buildMultimodalChunkForIndexing(entry!);

    expect(built?.chunk.embeddingInput?.parts).toEqual([
      { type: "text", text: "Image file: diagram.png" },
      expect.objectContaining({ type: "inline-data", mimeType: "image/png" }),
    ]);
    expect(built?.structuredInputBytes).toBeGreaterThan(0);
  });

  // Size drift since discovery invalidates the lazy build (returns null).
  it("skips lazy multimodal indexing when the file grows after discovery", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "diagram.png");
    await fs.writeFile(target, Buffer.from("png"));

    const entry = await buildFileEntry(target, tmpDir, multimodal);
    await fs.writeFile(target, Buffer.alloc(entry!.size + 32, 1));

    await expect(buildMultimodalChunkForIndexing(entry!)).resolves.toBeNull();
  });

  // Same-size content change is caught by the dataHash comparison.
  it("skips lazy multimodal indexing when file bytes change after discovery", async () => {
    const tmpDir = getTmpDir();
    const target = path.join(tmpDir, "diagram.png");
    await fs.writeFile(target, Buffer.from("png"));

    const entry = await buildFileEntry(target, tmpDir, multimodal);
    await fs.writeFile(target, Buffer.from("gif"));

    await expect(buildMultimodalChunkForIndexing(entry!)).resolves.toBeNull();
  });
});
|
||||
|
||||
describe("chunkMarkdown", () => {
  it("splits overly long lines into max-sized chunks", () => {
    const chunkTokens = 400;
    // chunkMarkdown budgets ~4 characters per token.
    const maxChars = chunkTokens * 4;
    // One giant line, longer than three full chunks, with no newlines.
    const content = "a".repeat(maxChars * 3 + 25);
    const chunks = chunkMarkdown(content, { tokens: chunkTokens, overlap: 0 });
    expect(chunks.length).toBeGreaterThan(1);
    // No produced chunk may exceed the character budget.
    for (const chunk of chunks) {
      expect(chunk.text.length).toBeLessThanOrEqual(maxChars);
    }
  });
});
|
||||
|
||||
describe("remapChunkLines", () => {
  it("remaps chunk line numbers using a lineMap", () => {
    // Simulate 5 content lines that came from JSONL lines [4, 6, 7, 10, 13] (1-indexed)
    const lineMap = [4, 6, 7, 10, 13];

    // Create chunks from content that has 5 lines
    const content = "User: Hello\nAssistant: Hi\nUser: Question\nAssistant: Answer\nUser: Thanks";
    const chunks = chunkMarkdown(content, { tokens: 400, overlap: 0 });
    expect(chunks.length).toBeGreaterThan(0);

    // Before remapping, startLine/endLine reference content line numbers (1-indexed)
    expect(chunks[0].startLine).toBe(1);

    // Remap
    remapChunkLines(chunks, lineMap);

    // After remapping, line numbers should reference original JSONL lines
    // Content line 1 → JSONL line 4, content line 5 → JSONL line 13
    expect(chunks[0].startLine).toBe(4);
    const lastChunk = chunks[chunks.length - 1];
    expect(lastChunk.endLine).toBe(13);
  });

  // No lineMap means the remap is a no-op; line numbers must be untouched.
  it("preserves original line numbers when lineMap is undefined", () => {
    const content = "Line one\nLine two\nLine three";
    const chunks = chunkMarkdown(content, { tokens: 400, overlap: 0 });
    const originalStart = chunks[0].startLine;
    const originalEnd = chunks[chunks.length - 1].endLine;

    remapChunkLines(chunks, undefined);

    expect(chunks[0].startLine).toBe(originalStart);
    expect(chunks[chunks.length - 1].endLine).toBe(originalEnd);
  });

  it("handles multi-chunk content with correct remapping", () => {
    // Use small chunk size to force multiple chunks
    // lineMap: 10 content lines from JSONL lines [2, 5, 8, 11, 14, 17, 20, 23, 26, 29]
    const lineMap = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29];
    const contentLines = lineMap.map((_, i) =>
      i % 2 === 0 ? `User: Message ${i}` : `Assistant: Reply ${i}`,
    );
    const content = contentLines.join("\n");

    // Use very small chunk size to force splitting
    const chunks = chunkMarkdown(content, { tokens: 10, overlap: 0 });
    expect(chunks.length).toBeGreaterThan(1);

    remapChunkLines(chunks, lineMap);

    // First chunk should start at JSONL line 2
    expect(chunks[0].startLine).toBe(2);
    // Last chunk should end at JSONL line 29
    expect(chunks[chunks.length - 1].endLine).toBe(29);

    // Each chunk's startLine should be ≤ its endLine
    for (const chunk of chunks) {
      expect(chunk.startLine).toBeLessThanOrEqual(chunk.endLine);
    }
  });
});
|
||||
@@ -1,482 +0,0 @@
|
||||
import crypto from "node:crypto";
|
||||
import fsSync from "node:fs";
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { detectMime } from "../../media/mime.js";
|
||||
import { runTasksWithConcurrency } from "../../utils/run-with-concurrency.js";
|
||||
import { estimateStructuredEmbeddingInputBytes } from "./embedding-input-limits.js";
|
||||
import { buildTextEmbeddingInput, type EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { isFileMissingError } from "./fs-utils.js";
|
||||
import {
|
||||
buildMemoryMultimodalLabel,
|
||||
classifyMemoryMultimodalPath,
|
||||
type MemoryMultimodalModality,
|
||||
type MemoryMultimodalSettings,
|
||||
} from "./multimodal.js";
|
||||
|
||||
/** Metadata captured for one discovered memory file at scan time. */
export type MemoryFileEntry = {
  // Workspace-relative path with forward slashes.
  path: string;
  // Absolute filesystem path.
  absPath: string;
  // mtime (ms) and byte size at discovery, used later to detect drift.
  mtimeMs: number;
  size: number;
  // Content hash for markdown; metadata-composite hash for multimodal files.
  hash: string;
  // SHA-256 of the raw bytes; only set for multimodal entries.
  dataHash?: string;
  kind?: "markdown" | "multimodal";
  // Human-readable label used as the text side of a multimodal embedding.
  contentText?: string;
  modality?: MemoryMultimodalModality;
  mimeType?: string;
};

/** One embeddable slice of a memory file, with 1-indexed line range. */
export type MemoryChunk = {
  startLine: number;
  endLine: number;
  text: string;
  hash: string;
  // Structured input (text and/or inline data) handed to the embedder.
  embeddingInput?: EmbeddingInput;
};

/** A lazily built multimodal chunk plus its estimated payload size. */
export type MultimodalMemoryChunk = {
  chunk: MemoryChunk;
  structuredInputBytes: number;
};

// Sentinel used when no multimodal settings are supplied: nothing matches.
const DISABLED_MULTIMODAL_SETTINGS: MemoryMultimodalSettings = {
  enabled: false,
  modalities: [],
  maxFileBytes: 0,
};
|
||||
|
||||
export function ensureDir(dir: string): string {
|
||||
try {
|
||||
fsSync.mkdirSync(dir, { recursive: true });
|
||||
} catch {}
|
||||
return dir;
|
||||
}
|
||||
|
||||
export function normalizeRelPath(value: string): string {
|
||||
const trimmed = value.trim().replace(/^[./]+/, "");
|
||||
return trimmed.replace(/\\/g, "/");
|
||||
}
|
||||
|
||||
export function normalizeExtraMemoryPaths(workspaceDir: string, extraPaths?: string[]): string[] {
|
||||
if (!extraPaths?.length) {
|
||||
return [];
|
||||
}
|
||||
const resolved = extraPaths
|
||||
.map((value) => value.trim())
|
||||
.filter(Boolean)
|
||||
.map((value) =>
|
||||
path.isAbsolute(value) ? path.resolve(value) : path.resolve(workspaceDir, value),
|
||||
);
|
||||
return Array.from(new Set(resolved));
|
||||
}
|
||||
|
||||
export function isMemoryPath(relPath: string): boolean {
|
||||
const normalized = normalizeRelPath(relPath);
|
||||
if (!normalized) {
|
||||
return false;
|
||||
}
|
||||
if (normalized === "MEMORY.md" || normalized === "memory.md") {
|
||||
return true;
|
||||
}
|
||||
return normalized.startsWith("memory/");
|
||||
}
|
||||
|
||||
function isAllowedMemoryFilePath(filePath: string, multimodal?: MemoryMultimodalSettings): boolean {
|
||||
if (filePath.endsWith(".md")) {
|
||||
return true;
|
||||
}
|
||||
return (
|
||||
classifyMemoryMultimodalPath(filePath, multimodal ?? DISABLED_MULTIMODAL_SETTINGS) !== null
|
||||
);
|
||||
}
|
||||
|
||||
async function walkDir(dir: string, files: string[], multimodal?: MemoryMultimodalSettings) {
|
||||
const entries = await fs.readdir(dir, { withFileTypes: true });
|
||||
for (const entry of entries) {
|
||||
const full = path.join(dir, entry.name);
|
||||
if (entry.isSymbolicLink()) {
|
||||
continue;
|
||||
}
|
||||
if (entry.isDirectory()) {
|
||||
await walkDir(full, files, multimodal);
|
||||
continue;
|
||||
}
|
||||
if (!entry.isFile()) {
|
||||
continue;
|
||||
}
|
||||
if (!isAllowedMemoryFilePath(full, multimodal)) {
|
||||
continue;
|
||||
}
|
||||
files.push(full);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Enumerate all files that feed the memory index for a workspace:
 * MEMORY.md / memory.md at the workspace root, everything under the
 * workspace `memory/` directory, plus any configured extra paths
 * (directories are walked, single files added directly). Symlinks are
 * skipped throughout, and the final list is deduped by realpath so
 * overlapping extra paths contribute each physical file only once.
 *
 * @param workspaceDir Absolute workspace root.
 * @param extraPaths Optional additional files/directories (relative entries
 *   resolve against the workspace).
 * @param multimodal Optional settings enabling image/audio discovery — these
 *   only apply to extra paths here.
 * @returns Absolute paths, first-seen order preserved.
 */
export async function listMemoryFiles(
  workspaceDir: string,
  extraPaths?: string[],
  multimodal?: MemoryMultimodalSettings,
): Promise<string[]> {
  const result: string[] = [];
  const memoryFile = path.join(workspaceDir, "MEMORY.md");
  const altMemoryFile = path.join(workspaceDir, "memory.md");
  const memoryDir = path.join(workspaceDir, "memory");

  // Add a single path if it exists, is a real (non-symlink) file, and is
  // markdown; any stat failure is treated as "not present".
  const addMarkdownFile = async (absPath: string) => {
    try {
      const stat = await fs.lstat(absPath);
      if (stat.isSymbolicLink() || !stat.isFile()) {
        return;
      }
      if (!absPath.endsWith(".md")) {
        return;
      }
      result.push(absPath);
    } catch {}
  };

  await addMarkdownFile(memoryFile);
  await addMarkdownFile(altMemoryFile);
  try {
    const dirStat = await fs.lstat(memoryDir);
    if (!dirStat.isSymbolicLink() && dirStat.isDirectory()) {
      // NOTE(review): `multimodal` is not forwarded here, so the workspace
      // memory/ tree only yields markdown while extra paths may yield
      // image/audio too — confirm this asymmetry is intentional.
      await walkDir(memoryDir, result);
    }
  } catch {}

  const normalizedExtraPaths = normalizeExtraMemoryPaths(workspaceDir, extraPaths);
  if (normalizedExtraPaths.length > 0) {
    for (const inputPath of normalizedExtraPaths) {
      try {
        const stat = await fs.lstat(inputPath);
        if (stat.isSymbolicLink()) {
          continue;
        }
        if (stat.isDirectory()) {
          await walkDir(inputPath, result, multimodal);
          continue;
        }
        if (stat.isFile() && isAllowedMemoryFilePath(inputPath, multimodal)) {
          result.push(inputPath);
        }
      } catch {}
    }
  }
  // Fast path: nothing to dedupe.
  if (result.length <= 1) {
    return result;
  }
  // Dedupe by canonical (realpath) identity; keep the original spelling of
  // the first occurrence. realpath failures fall back to the literal path.
  const seen = new Set<string>();
  const deduped: string[] = [];
  for (const entry of result) {
    let key = entry;
    try {
      key = await fs.realpath(entry);
    } catch {}
    if (seen.has(key)) {
      continue;
    }
    seen.add(key);
    deduped.push(entry);
  }
  return deduped;
}
|
||||
|
||||
export function hashText(value: string): string {
|
||||
return crypto.createHash("sha256").update(value).digest("hex");
|
||||
}
|
||||
|
||||
/**
 * Build the index entry for one memory file, tolerating races where the file
 * vanishes between discovery and read (returns null instead of throwing).
 *
 * Multimodal files (as classified by extension + settings) are read fully so
 * their MIME type can be sniffed and their bytes hashed; oversized files or
 * MIME mismatches are rejected (null). Everything else is read as UTF-8
 * markdown and hashed by content.
 *
 * @returns The entry, or null when the file is gone/ineligible.
 */
export async function buildFileEntry(
  absPath: string,
  workspaceDir: string,
  multimodal?: MemoryMultimodalSettings,
): Promise<MemoryFileEntry | null> {
  let stat;
  try {
    stat = await fs.stat(absPath);
  } catch (err) {
    if (isFileMissingError(err)) {
      return null;
    }
    throw err;
  }
  // Workspace-relative path, normalized to forward slashes.
  const normalizedPath = path.relative(workspaceDir, absPath).replace(/\\/g, "/");
  const multimodalSettings = multimodal ?? DISABLED_MULTIMODAL_SETTINGS;
  const modality = classifyMemoryMultimodalPath(absPath, multimodalSettings);
  if (modality) {
    // Enforce the configured size cap before reading the whole file.
    if (stat.size > multimodalSettings.maxFileBytes) {
      return null;
    }
    let buffer: Buffer;
    try {
      buffer = await fs.readFile(absPath);
    } catch (err) {
      if (isFileMissingError(err)) {
        return null;
      }
      throw err;
    }
    // Sniff from the first 512 bytes; the detected MIME's major type must
    // match the extension-derived modality (e.g. "image/").
    const mimeType = await detectMime({ buffer: buffer.subarray(0, 512), filePath: absPath });
    if (!mimeType || !mimeType.startsWith(`${modality}/`)) {
      return null;
    }
    const contentText = buildMemoryMultimodalLabel(modality, normalizedPath);
    const dataHash = crypto.createHash("sha256").update(buffer).digest("hex");
    // Composite hash over the metadata that defines the chunk identity, so
    // any of path/label/MIME/bytes changing invalidates the entry.
    const chunkHash = hashText(
      JSON.stringify({
        path: normalizedPath,
        contentText,
        mimeType,
        dataHash,
      }),
    );
    return {
      path: normalizedPath,
      absPath,
      mtimeMs: stat.mtimeMs,
      size: stat.size,
      hash: chunkHash,
      dataHash,
      kind: "multimodal",
      contentText,
      modality,
      mimeType,
    };
  }
  // Markdown path: hash the decoded text content.
  let content: string;
  try {
    content = await fs.readFile(absPath, "utf-8");
  } catch (err) {
    if (isFileMissingError(err)) {
      return null;
    }
    throw err;
  }
  const hash = hashText(content);
  return {
    path: normalizedPath,
    absPath,
    mtimeMs: stat.mtimeMs,
    size: stat.size,
    hash,
    kind: "markdown",
  };
}
|
||||
|
||||
/**
 * Lazily load the embedding payload (label text + base64 inline data) for a
 * multimodal entry captured earlier. Returns null — never throws — when the
 * entry is not multimodal, the file has disappeared, or the file drifted
 * since discovery (size mismatch, or byte hash mismatch when a dataHash was
 * recorded), so stale entries are silently skipped rather than mis-indexed.
 */
async function loadMultimodalEmbeddingInput(
  entry: Pick<
    MemoryFileEntry,
    "absPath" | "contentText" | "mimeType" | "kind" | "size" | "dataHash"
  >,
): Promise<EmbeddingInput | null> {
  if (entry.kind !== "multimodal" || !entry.contentText || !entry.mimeType) {
    return null;
  }
  let stat;
  try {
    stat = await fs.stat(entry.absPath);
  } catch (err) {
    if (isFileMissingError(err)) {
      return null;
    }
    throw err;
  }
  // Cheap drift check before reading: the size must match discovery time.
  if (stat.size !== entry.size) {
    return null;
  }
  let buffer: Buffer;
  try {
    buffer = await fs.readFile(entry.absPath);
  } catch (err) {
    if (isFileMissingError(err)) {
      return null;
    }
    throw err;
  }
  // Strong drift check: same-size content changes are caught by the hash.
  const dataHash = crypto.createHash("sha256").update(buffer).digest("hex");
  if (entry.dataHash && entry.dataHash !== dataHash) {
    return null;
  }
  return {
    text: entry.contentText,
    parts: [
      { type: "text", text: entry.contentText },
      {
        type: "inline-data",
        mimeType: entry.mimeType,
        data: buffer.toString("base64"),
      },
    ],
  };
}
|
||||
|
||||
/**
 * Wrap a multimodal entry's lazily loaded embedding input into a single
 * one-line MemoryChunk for indexing, along with an estimate of the
 * structured payload size. Returns null when the input cannot be loaded
 * (non-multimodal entry, missing file, or content drift).
 */
export async function buildMultimodalChunkForIndexing(
  entry: Pick<
    MemoryFileEntry,
    "absPath" | "contentText" | "mimeType" | "kind" | "hash" | "size" | "dataHash"
  >,
): Promise<MultimodalMemoryChunk | null> {
  const embeddingInput = await loadMultimodalEmbeddingInput(entry);
  if (!embeddingInput) {
    return null;
  }
  return {
    // A multimodal file is represented as a single chunk on "line" 1.
    chunk: {
      startLine: 1,
      endLine: 1,
      text: entry.contentText ?? embeddingInput.text,
      hash: entry.hash,
      embeddingInput,
    },
    structuredInputBytes: estimateStructuredEmbeddingInputBytes(embeddingInput),
  };
}
|
||||
|
||||
/**
 * Split markdown text into embeddable chunks along line boundaries.
 *
 * Budgets are expressed in tokens and converted at ~4 chars/token (with a
 * 32-char floor). Lines longer than the budget are first sliced into
 * budget-sized segments that keep the original line number; segments are
 * then packed greedily into chunks, and when `overlap` > 0 the tail of each
 * flushed chunk is carried into the next one.
 *
 * @param content Raw markdown text.
 * @param chunking Target chunk size and overlap, both in tokens.
 * @returns Chunks with 1-indexed start/end lines, hash, and embedding input.
 */
export function chunkMarkdown(
  content: string,
  chunking: { tokens: number; overlap: number },
): MemoryChunk[] {
  const lines = content.split("\n");
  // NOTE: split("\n") always yields at least one element, so this guard is
  // effectively unreachable; kept as a defensive check.
  if (lines.length === 0) {
    return [];
  }
  const maxChars = Math.max(32, chunking.tokens * 4);
  const overlapChars = Math.max(0, chunking.overlap * 4);
  const chunks: MemoryChunk[] = [];

  // Accumulator of pending (segment, original line number) pairs, plus the
  // running character count (each segment costs length + 1 for the newline).
  let current: Array<{ line: string; lineNo: number }> = [];
  let currentChars = 0;

  // Emit the accumulated segments as one chunk (no-op when empty).
  const flush = () => {
    if (current.length === 0) {
      return;
    }
    const firstEntry = current[0];
    const lastEntry = current[current.length - 1];
    if (!firstEntry || !lastEntry) {
      return;
    }
    const text = current.map((entry) => entry.line).join("\n");
    const startLine = firstEntry.lineNo;
    const endLine = lastEntry.lineNo;
    chunks.push({
      startLine,
      endLine,
      text,
      hash: hashText(text),
      embeddingInput: buildTextEmbeddingInput(text),
    });
  };

  // After a flush, seed the next chunk with the trailing segments of the
  // previous one until at least `overlapChars` are carried over.
  const carryOverlap = () => {
    if (overlapChars <= 0 || current.length === 0) {
      current = [];
      currentChars = 0;
      return;
    }
    let acc = 0;
    const kept: Array<{ line: string; lineNo: number }> = [];
    for (let i = current.length - 1; i >= 0; i -= 1) {
      const entry = current[i];
      if (!entry) {
        continue;
      }
      acc += entry.line.length + 1;
      kept.unshift(entry);
      if (acc >= overlapChars) {
        break;
      }
    }
    current = kept;
    currentChars = kept.reduce((sum, entry) => sum + entry.line.length + 1, 0);
  };

  for (let i = 0; i < lines.length; i += 1) {
    const line = lines[i] ?? "";
    const lineNo = i + 1;
    // Pre-slice over-long lines so no single segment can exceed the budget;
    // an empty line still contributes one empty segment.
    const segments: string[] = [];
    if (line.length === 0) {
      segments.push("");
    } else {
      for (let start = 0; start < line.length; start += maxChars) {
        segments.push(line.slice(start, start + maxChars));
      }
    }
    for (const segment of segments) {
      const lineSize = segment.length + 1;
      // Flush before overflowing the budget (but never flush an empty chunk).
      if (currentChars + lineSize > maxChars && current.length > 0) {
        flush();
        carryOverlap();
      }
      current.push({ line: segment, lineNo });
      currentChars += lineSize;
    }
  }
  flush();
  return chunks;
}
|
||||
|
||||
/**
|
||||
* Remap chunk startLine/endLine from content-relative positions to original
|
||||
* source file positions using a lineMap. Each entry in lineMap gives the
|
||||
* 1-indexed source line for the corresponding 0-indexed content line.
|
||||
*
|
||||
* This is used for session JSONL files where buildSessionEntry() flattens
|
||||
* messages into a plain-text string before chunking. Without remapping the
|
||||
* stored line numbers would reference positions in the flattened text rather
|
||||
* than the original JSONL file.
|
||||
*/
|
||||
export function remapChunkLines(chunks: MemoryChunk[], lineMap: number[] | undefined): void {
|
||||
if (!lineMap || lineMap.length === 0) {
|
||||
return;
|
||||
}
|
||||
for (const chunk of chunks) {
|
||||
// startLine/endLine are 1-indexed; lineMap is 0-indexed by content line
|
||||
chunk.startLine = lineMap[chunk.startLine - 1] ?? chunk.startLine;
|
||||
chunk.endLine = lineMap[chunk.endLine - 1] ?? chunk.endLine;
|
||||
}
|
||||
}
|
||||
|
||||
export function parseEmbedding(raw: string): number[] {
|
||||
try {
|
||||
const parsed = JSON.parse(raw) as number[];
|
||||
return Array.isArray(parsed) ? parsed : [];
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export function cosineSimilarity(a: number[], b: number[]): number {
|
||||
if (a.length === 0 || b.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
const len = Math.min(a.length, b.length);
|
||||
let dot = 0;
|
||||
let normA = 0;
|
||||
let normB = 0;
|
||||
for (let i = 0; i < len; i += 1) {
|
||||
const av = a[i] ?? 0;
|
||||
const bv = b[i] ?? 0;
|
||||
dot += av * bv;
|
||||
normA += av * av;
|
||||
normB += bv * bv;
|
||||
}
|
||||
if (normA === 0 || normB === 0) {
|
||||
return 0;
|
||||
}
|
||||
return dot / (Math.sqrt(normA) * Math.sqrt(normB));
|
||||
}
|
||||
|
||||
export async function runWithConcurrency<T>(
|
||||
tasks: Array<() => Promise<T>>,
|
||||
limit: number,
|
||||
): Promise<T[]> {
|
||||
const { results, firstError, hasError } = await runTasksWithConcurrency({
|
||||
tasks,
|
||||
limit,
|
||||
errorMode: "stop",
|
||||
});
|
||||
if (hasError) {
|
||||
throw firstError;
|
||||
}
|
||||
return results;
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
import type { DatabaseSync } from "node:sqlite";
|
||||
|
||||
/**
 * Create (or migrate) the SQLite schema that backs the memory index.
 *
 * Tables:
 * - `meta`: key/value metadata for the index.
 * - `files`: one row per indexed file (hash/mtime/size used for change detection).
 * - `chunks`: one row per embedded chunk, embedding stored as JSON text.
 * - optional embedding-cache table (name supplied by caller) when `cacheEnabled`.
 * - optional FTS5 virtual table for keyword search when `ftsEnabled`.
 *
 * Every statement uses `IF NOT EXISTS`, so calling this on an existing
 * database is idempotent. FTS creation is wrapped in try/catch because the
 * host SQLite build may not ship the fts5 extension; that failure is
 * reported in the return value rather than thrown.
 *
 * NOTE(review): `embeddingCacheTable` and `ftsTable` are interpolated
 * directly into SQL, so they must come from trusted config, never user input.
 *
 * @returns `ftsAvailable`, plus `ftsError` with the creation error message
 *          when FTS was requested but could not be created.
 */
export function ensureMemoryIndexSchema(params: {
  db: DatabaseSync;
  embeddingCacheTable: string;
  cacheEnabled: boolean;
  ftsTable: string;
  ftsEnabled: boolean;
}): { ftsAvailable: boolean; ftsError?: string } {
  params.db.exec(`
    CREATE TABLE IF NOT EXISTS meta (
      key TEXT PRIMARY KEY,
      value TEXT NOT NULL
    );
  `);
  params.db.exec(`
    CREATE TABLE IF NOT EXISTS files (
      path TEXT PRIMARY KEY,
      source TEXT NOT NULL DEFAULT 'memory',
      hash TEXT NOT NULL,
      mtime INTEGER NOT NULL,
      size INTEGER NOT NULL
    );
  `);
  params.db.exec(`
    CREATE TABLE IF NOT EXISTS chunks (
      id TEXT PRIMARY KEY,
      path TEXT NOT NULL,
      source TEXT NOT NULL DEFAULT 'memory',
      start_line INTEGER NOT NULL,
      end_line INTEGER NOT NULL,
      hash TEXT NOT NULL,
      model TEXT NOT NULL,
      text TEXT NOT NULL,
      embedding TEXT NOT NULL,
      updated_at INTEGER NOT NULL
    );
  `);
  if (params.cacheEnabled) {
    // Cache rows are keyed by (provider, model, provider_key, hash) so the
    // same text hash can be cached per provider/model combination.
    params.db.exec(`
      CREATE TABLE IF NOT EXISTS ${params.embeddingCacheTable} (
        provider TEXT NOT NULL,
        model TEXT NOT NULL,
        provider_key TEXT NOT NULL,
        hash TEXT NOT NULL,
        embedding TEXT NOT NULL,
        dims INTEGER,
        updated_at INTEGER NOT NULL,
        PRIMARY KEY (provider, model, provider_key, hash)
      );
    `);
    // Index supports age-based cache eviction queries.
    params.db.exec(
      `CREATE INDEX IF NOT EXISTS idx_embedding_cache_updated_at ON ${params.embeddingCacheTable}(updated_at);`,
    );
  }

  let ftsAvailable = false;
  let ftsError: string | undefined;
  if (params.ftsEnabled) {
    try {
      // Only `text` is tokenized; the remaining columns are stored UNINDEXED
      // so result rows can be joined back to chunks without a second lookup.
      params.db.exec(
        `CREATE VIRTUAL TABLE IF NOT EXISTS ${params.ftsTable} USING fts5(\n` +
          ` text,\n` +
          ` id UNINDEXED,\n` +
          ` path UNINDEXED,\n` +
          ` source UNINDEXED,\n` +
          ` model UNINDEXED,\n` +
          ` start_line UNINDEXED,\n` +
          ` end_line UNINDEXED\n` +
          `);`,
      );
      ftsAvailable = true;
    } catch (err) {
      // fts5 may be missing from the SQLite build; degrade gracefully and
      // surface the reason to the caller.
      const message = err instanceof Error ? err.message : String(err);
      ftsAvailable = false;
      ftsError = message;
    }
  }

  // Migration for databases created before the `source` column existed.
  ensureColumn(params.db, "files", "source", "TEXT NOT NULL DEFAULT 'memory'");
  ensureColumn(params.db, "chunks", "source", "TEXT NOT NULL DEFAULT 'memory'");
  params.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_path ON chunks(path);`);
  params.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_source ON chunks(source);`);

  // Omit ftsError entirely (rather than setting it to undefined) when FTS worked.
  return { ftsAvailable, ...(ftsError ? { ftsError } : {}) };
}
|
||||
|
||||
function ensureColumn(
|
||||
db: DatabaseSync,
|
||||
table: "files" | "chunks",
|
||||
column: string,
|
||||
definition: string,
|
||||
): void {
|
||||
const rows = db.prepare(`PRAGMA table_info(${table})`).all() as Array<{ name: string }>;
|
||||
if (rows.some((row) => row.name === column)) {
|
||||
return;
|
||||
}
|
||||
db.exec(`ALTER TABLE ${table} ADD COLUMN ${column} ${definition}`);
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
// Static catalog of supported non-text modalities: a human-readable label
// prefix plus the lowercase, dot-prefixed file extensions used to detect
// files of that modality.
const MEMORY_MULTIMODAL_SPECS = {
  image: {
    labelPrefix: "Image file",
    extensions: [".jpg", ".jpeg", ".png", ".webp", ".gif", ".heic", ".heif"],
  },
  audio: {
    labelPrefix: "Audio file",
    extensions: [".mp3", ".wav", ".ogg", ".opus", ".m4a", ".aac", ".flac"],
  },
} as const;

// Modality names are derived from the spec keys so the catalog above stays
// the single source of truth.
export type MemoryMultimodalModality = keyof typeof MEMORY_MULTIMODAL_SPECS;
export const MEMORY_MULTIMODAL_MODALITIES = Object.keys(
  MEMORY_MULTIMODAL_SPECS,
) as MemoryMultimodalModality[];
// Config values may name a single modality or "all" as a wildcard.
export type MemoryMultimodalSelection = MemoryMultimodalModality | "all";

// Normalized multimodal configuration consumed by the indexer.
export type MemoryMultimodalSettings = {
  // Master switch; when false, modalities is normalized to [].
  enabled: boolean;
  // Concrete modalities ("all" already expanded).
  modalities: MemoryMultimodalModality[];
  // Per-file size cap in bytes for multimodal embedding input.
  maxFileBytes: number;
};

// Default per-file size cap: 10 MiB.
export const DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES = 10 * 1024 * 1024;
|
||||
|
||||
export function normalizeMemoryMultimodalModalities(
|
||||
raw: MemoryMultimodalSelection[] | undefined,
|
||||
): MemoryMultimodalModality[] {
|
||||
if (raw === undefined || raw.includes("all")) {
|
||||
return [...MEMORY_MULTIMODAL_MODALITIES];
|
||||
}
|
||||
const normalized = new Set<MemoryMultimodalModality>();
|
||||
for (const value of raw) {
|
||||
if (value === "image" || value === "audio") {
|
||||
normalized.add(value);
|
||||
}
|
||||
}
|
||||
return Array.from(normalized);
|
||||
}
|
||||
|
||||
export function normalizeMemoryMultimodalSettings(raw: {
|
||||
enabled?: boolean;
|
||||
modalities?: MemoryMultimodalSelection[];
|
||||
maxFileBytes?: number;
|
||||
}): MemoryMultimodalSettings {
|
||||
const enabled = raw.enabled === true;
|
||||
const maxFileBytes =
|
||||
typeof raw.maxFileBytes === "number" && Number.isFinite(raw.maxFileBytes)
|
||||
? Math.max(1, Math.floor(raw.maxFileBytes))
|
||||
: DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES;
|
||||
return {
|
||||
enabled,
|
||||
modalities: enabled ? normalizeMemoryMultimodalModalities(raw.modalities) : [],
|
||||
maxFileBytes,
|
||||
};
|
||||
}
|
||||
|
||||
export function isMemoryMultimodalEnabled(settings: MemoryMultimodalSettings): boolean {
|
||||
return settings.enabled && settings.modalities.length > 0;
|
||||
}
|
||||
|
||||
export function getMemoryMultimodalExtensions(
|
||||
modality: MemoryMultimodalModality,
|
||||
): readonly string[] {
|
||||
return MEMORY_MULTIMODAL_SPECS[modality].extensions;
|
||||
}
|
||||
|
||||
export function buildMemoryMultimodalLabel(
|
||||
modality: MemoryMultimodalModality,
|
||||
normalizedPath: string,
|
||||
): string {
|
||||
return `${MEMORY_MULTIMODAL_SPECS[modality].labelPrefix}: ${normalizedPath}`;
|
||||
}
|
||||
|
||||
export function buildCaseInsensitiveExtensionGlob(extension: string): string {
|
||||
const normalized = extension.trim().replace(/^\./, "").toLowerCase();
|
||||
if (!normalized) {
|
||||
return "*";
|
||||
}
|
||||
const parts = Array.from(normalized, (char) => `[${char.toLowerCase()}${char.toUpperCase()}]`);
|
||||
return `*.${parts.join("")}`;
|
||||
}
|
||||
|
||||
export function classifyMemoryMultimodalPath(
|
||||
filePath: string,
|
||||
settings: MemoryMultimodalSettings,
|
||||
): MemoryMultimodalModality | null {
|
||||
if (!isMemoryMultimodalEnabled(settings)) {
|
||||
return null;
|
||||
}
|
||||
const lower = filePath.trim().toLowerCase();
|
||||
for (const modality of settings.modalities) {
|
||||
for (const extension of getMemoryMultimodalExtensions(modality)) {
|
||||
if (lower.endsWith(extension)) {
|
||||
return modality;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) {
|
||||
return "";
|
||||
}
|
||||
return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, "");
|
||||
}
|
||||
|
||||
export function supportsMemoryMultimodalEmbeddings(params: {
|
||||
provider: string;
|
||||
model: string;
|
||||
}): boolean {
|
||||
if (params.provider !== "gemini") {
|
||||
return false;
|
||||
}
|
||||
return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview";
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
/**
 * Lazily import the optional `node-llama-cpp` dependency.
 *
 * Using a dynamic import keeps the module out of the startup path; callers
 * only pay its load cost (and only need it installed) when local llama
 * embeddings are actually used.
 */
export async function importNodeLlamaCpp() {
  return import("node-llama-cpp");
}
|
||||
@@ -1,59 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
vi.mock("./remote-http.js", () => ({
|
||||
withRemoteHttpResponse: vi.fn(),
|
||||
}));
|
||||
|
||||
let postJson: typeof import("./post-json.js").postJson;
|
||||
let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
|
||||
|
||||
describe("postJson", () => {
|
||||
let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetModules();
|
||||
vi.clearAllMocks();
|
||||
vi.resetModules();
|
||||
({ postJson } = await import("./post-json.js"));
|
||||
({ withRemoteHttpResponse } = await import("./remote-http.js"));
|
||||
remoteHttpMock = vi.mocked(withRemoteHttpResponse);
|
||||
});
|
||||
|
||||
it("parses JSON payload on successful response", async () => {
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
return await params.onResponse(
|
||||
new Response(JSON.stringify({ data: [{ embedding: [1, 2] }] }), { status: 200 }),
|
||||
);
|
||||
});
|
||||
|
||||
const result = await postJson({
|
||||
url: "https://memory.example/v1/post",
|
||||
headers: { Authorization: "Bearer test" },
|
||||
body: { input: ["x"] },
|
||||
errorPrefix: "post failed",
|
||||
parse: (payload) => payload,
|
||||
});
|
||||
|
||||
expect(result).toEqual({ data: [{ embedding: [1, 2] }] });
|
||||
});
|
||||
|
||||
it("attaches status to thrown error when requested", async () => {
|
||||
remoteHttpMock.mockImplementationOnce(async (params) => {
|
||||
return await params.onResponse(new Response("bad gateway", { status: 502 }));
|
||||
});
|
||||
|
||||
await expect(
|
||||
postJson({
|
||||
url: "https://memory.example/v1/post",
|
||||
headers: {},
|
||||
body: {},
|
||||
errorPrefix: "post failed",
|
||||
attachStatus: true,
|
||||
parse: () => ({}),
|
||||
}),
|
||||
).rejects.toMatchObject({
|
||||
message: expect.stringContaining("post failed: 502 bad gateway"),
|
||||
status: 502,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,35 +0,0 @@
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
import { withRemoteHttpResponse } from "./remote-http.js";
|
||||
|
||||
export async function postJson<T>(params: {
|
||||
url: string;
|
||||
headers: Record<string, string>;
|
||||
ssrfPolicy?: SsrFPolicy;
|
||||
body: unknown;
|
||||
errorPrefix: string;
|
||||
attachStatus?: boolean;
|
||||
parse: (payload: unknown) => T | Promise<T>;
|
||||
}): Promise<T> {
|
||||
return await withRemoteHttpResponse({
|
||||
url: params.url,
|
||||
ssrfPolicy: params.ssrfPolicy,
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: params.headers,
|
||||
body: JSON.stringify(params.body),
|
||||
},
|
||||
onResponse: async (res) => {
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
const err = new Error(`${params.errorPrefix}: ${res.status} ${text}`) as Error & {
|
||||
status?: number;
|
||||
};
|
||||
if (params.attachStatus) {
|
||||
err.status = res.status;
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
return await params.parse(await res.json());
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { resolveCliSpawnInvocation } from "./qmd-process.js";
|
||||
|
||||
// Windows spawn-resolution tests: process.platform is mocked to "win32" so
// the .cmd shim handling runs regardless of the host OS. PATH/PATHEXT are
// mutated per test and restored afterwards.
describe("resolveCliSpawnInvocation", () => {
  let tempDir = "";
  // Handle for the mocked process.platform getter; restored in afterEach.
  let platformSpy: { mockRestore(): void } | null = null;
  // Original env captured once so tests cannot leak PATH changes.
  const originalPath = process.env.PATH;
  const originalPathExt = process.env.PATHEXT;

  beforeEach(async () => {
    tempDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-qmd-win-spawn-"));
    platformSpy = vi.spyOn(process, "platform", "get").mockReturnValue("win32");
  });

  afterEach(async () => {
    platformSpy?.mockRestore();
    process.env.PATH = originalPath;
    process.env.PATHEXT = originalPathExt;
    if (tempDir) {
      await fs.rm(tempDir, { recursive: true, force: true });
      tempDir = "";
    }
  });

  // Happy path: a node_modules/.bin cmd shim is resolved to the package's
  // real JS entrypoint and executed with the current node binary.
  it("unwraps npm cmd shims to a direct node entrypoint", async () => {
    const binDir = path.join(tempDir, "node_modules", ".bin");
    const packageDir = path.join(tempDir, "node_modules", "qmd");
    const scriptPath = path.join(packageDir, "dist", "cli.js");
    await fs.mkdir(path.dirname(scriptPath), { recursive: true });
    await fs.mkdir(binDir, { recursive: true });
    await fs.writeFile(path.join(binDir, "qmd.cmd"), "@echo off\r\n", "utf8");
    await fs.writeFile(
      path.join(packageDir, "package.json"),
      JSON.stringify({ name: "qmd", version: "0.0.0", bin: { qmd: "dist/cli.js" } }),
      "utf8",
    );
    await fs.writeFile(scriptPath, "module.exports = {};\n", "utf8");

    // Windows-style ';' PATH separator, matching the mocked platform.
    process.env.PATH = `${binDir};${originalPath ?? ""}`;
    process.env.PATHEXT = ".CMD;.EXE";

    const invocation = resolveCliSpawnInvocation({
      command: "qmd",
      args: ["query", "hello"],
      env: process.env,
      packageName: "qmd",
    });

    expect(invocation.command).toBe(process.execPath);
    expect(invocation.argv).toEqual([scriptPath, "query", "hello"]);
    expect(invocation.shell).not.toBe(true);
    expect(invocation.windowsHide).toBe(true);
  });

  // Security posture: an unresolvable shim must throw rather than fall back
  // to shell execution.
  it("fails closed when a Windows cmd shim cannot be resolved without shell execution", async () => {
    const binDir = path.join(tempDir, "bad-bin");
    await fs.mkdir(binDir, { recursive: true });
    await fs.writeFile(path.join(binDir, "qmd.cmd"), "@echo off\r\nREM no entrypoint\r\n", "utf8");

    process.env.PATH = `${binDir};${originalPath ?? ""}`;
    process.env.PATHEXT = ".CMD;.EXE";

    expect(() =>
      resolveCliSpawnInvocation({
        command: "qmd",
        args: ["query", "hello"],
        env: process.env,
        packageName: "qmd",
      }),
    ).toThrow(/without shell execution/);
  });

  // No wrapper on PATH: the command is passed through untouched.
  it("keeps bare commands bare when no Windows wrapper exists on PATH", () => {
    process.env.PATH = originalPath ?? "";
    process.env.PATHEXT = ".CMD;.EXE";

    const invocation = resolveCliSpawnInvocation({
      command: "qmd",
      args: ["query", "hello"],
      env: process.env,
      packageName: "qmd",
    });

    expect(invocation.command).toBe("qmd");
    expect(invocation.argv).toEqual(["query", "hello"]);
    expect(invocation.shell).not.toBe(true);
  });
});
|
||||
@@ -1,108 +0,0 @@
|
||||
import { spawn } from "node:child_process";
|
||||
import {
|
||||
materializeWindowsSpawnProgram,
|
||||
resolveWindowsSpawnProgram,
|
||||
} from "../../plugin-sdk/windows-spawn.js";
|
||||
|
||||
/**
 * Fully-resolved spawn parameters for a CLI child process.
 * `shell` and `windowsHide` are forwarded verbatim to
 * `child_process.spawn` options.
 */
export type CliSpawnInvocation = {
  // Executable to launch (possibly node itself when a .cmd shim was unwrapped).
  command: string;
  // Arguments passed to the executable.
  argv: string[];
  shell?: boolean;
  windowsHide?: boolean;
};
|
||||
|
||||
export function resolveCliSpawnInvocation(params: {
|
||||
command: string;
|
||||
args: string[];
|
||||
env: NodeJS.ProcessEnv;
|
||||
packageName: string;
|
||||
}): CliSpawnInvocation {
|
||||
const program = resolveWindowsSpawnProgram({
|
||||
command: params.command,
|
||||
platform: process.platform,
|
||||
env: params.env,
|
||||
execPath: process.execPath,
|
||||
packageName: params.packageName,
|
||||
allowShellFallback: false,
|
||||
});
|
||||
return materializeWindowsSpawnProgram(program, params.args);
|
||||
}
|
||||
|
||||
/**
 * Spawn a CLI child process and collect its output.
 *
 * Resolves with captured stdout/stderr on exit code 0; rejects on spawn
 * error, non-zero exit, timeout (the child is SIGKILLed), or when either
 * stream exceeded `maxOutputChars` (unless stdout is being discarded).
 * Output is capped per stream by appendOutputWithCap, which keeps the
 * trailing `maxOutputChars` characters.
 *
 * NOTE(review): after a timeout rejection the 'close' handler may still run
 * and call resolve/reject again — harmless, since a settled Promise ignores
 * later calls.
 */
export async function runCliCommand(params: {
  commandSummary: string;
  spawnInvocation: CliSpawnInvocation;
  env: NodeJS.ProcessEnv;
  cwd: string;
  timeoutMs?: number;
  maxOutputChars: number;
  discardStdout?: boolean;
}): Promise<{ stdout: string; stderr: string }> {
  return await new Promise((resolve, reject) => {
    const child = spawn(params.spawnInvocation.command, params.spawnInvocation.argv, {
      env: params.env,
      cwd: params.cwd,
      shell: params.spawnInvocation.shell,
      windowsHide: params.spawnInvocation.windowsHide,
    });
    let stdout = "";
    let stderr = "";
    let stdoutTruncated = false;
    let stderrTruncated = false;
    const discardStdout = params.discardStdout === true;
    // Timeout (when configured) kills the child outright and rejects.
    const timer = params.timeoutMs
      ? setTimeout(() => {
          child.kill("SIGKILL");
          reject(new Error(`${params.commandSummary} timed out after ${params.timeoutMs}ms`));
        }, params.timeoutMs)
      : null;
    child.stdout.on("data", (data) => {
      if (discardStdout) {
        return;
      }
      const next = appendOutputWithCap(stdout, data.toString("utf8"), params.maxOutputChars);
      stdout = next.text;
      // Truncation is sticky: once over the cap, the run is reported as oversized.
      stdoutTruncated = stdoutTruncated || next.truncated;
    });
    child.stderr.on("data", (data) => {
      const next = appendOutputWithCap(stderr, data.toString("utf8"), params.maxOutputChars);
      stderr = next.text;
      stderrTruncated = stderrTruncated || next.truncated;
    });
    child.on("error", (err) => {
      if (timer) {
        clearTimeout(timer);
      }
      reject(err);
    });
    child.on("close", (code) => {
      if (timer) {
        clearTimeout(timer);
      }
      // Oversized output is an error even on exit code 0 — partial output
      // could silently mislead callers.
      if (!discardStdout && (stdoutTruncated || stderrTruncated)) {
        reject(
          new Error(
            `${params.commandSummary} produced too much output (limit ${params.maxOutputChars} chars)`,
          ),
        );
        return;
      }
      if (code === 0) {
        resolve({ stdout, stderr });
      } else {
        // Prefer stderr in the failure message; fall back to stdout.
        reject(new Error(`${params.commandSummary} failed (code ${code}): ${stderr || stdout}`));
      }
    });
  });
}
|
||||
|
||||
function appendOutputWithCap(
|
||||
current: string,
|
||||
chunk: string,
|
||||
maxChars: number,
|
||||
): { text: string; truncated: boolean } {
|
||||
const appended = current + chunk;
|
||||
if (appended.length <= maxChars) {
|
||||
return { text: appended, truncated: false };
|
||||
}
|
||||
return { text: appended.slice(-maxChars), truncated: true };
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { parseQmdQueryJson } from "./qmd-query-parser.js";
|
||||
|
||||
// parseQmdQueryJson tests: cover clean JSON, noisy stdout recovery,
// "no results" marker handling on both streams, and invalid-JSON failures.
describe("parseQmdQueryJson", () => {
  it("parses clean qmd JSON output", () => {
    const results = parseQmdQueryJson('[{"docid":"abc","score":1,"snippet":"@@ -1,1\\none"}]', "");
    expect(results).toEqual([
      {
        docid: "abc",
        score: 1,
        // Escaped \n in the JSON input becomes a real newline after parsing.
        snippet: "@@ -1,1\none",
      },
    ]);
  });

  // The parser must recover the embedded JSON array even when qmd prints
  // progress lines and unrelated JSON objects around it.
  it("extracts embedded result arrays from noisy stdout", () => {
    const results = parseQmdQueryJson(
      `initializing
{"payload":"ok"}
[{"docid":"abc","score":0.5}]
complete`,
      "",
    );
    expect(results).toEqual([{ docid: "abc", score: 0.5 }]);
  });

  it("treats plain-text no-results from stderr as an empty result set", () => {
    const results = parseQmdQueryJson("", "No results found\n");
    expect(results).toEqual([]);
  });

  it("treats prefixed no-results marker output as an empty result set", () => {
    expect(parseQmdQueryJson("warning: no results found", "")).toEqual([]);
    expect(parseQmdQueryJson("", "[qmd] warning: no results found\n")).toEqual([]);
  });

  // Guard against over-matching: the marker must be the whole line, not a
  // substring of a longer message.
  it("does not treat arbitrary non-marker text as no-results output", () => {
    expect(() =>
      parseQmdQueryJson("warning: search completed; no results found for this query", ""),
    ).toThrow(/qmd query returned invalid JSON/i);
  });

  it("throws when stdout cannot be interpreted as qmd JSON", () => {
    expect(() => parseQmdQueryJson("this is not json", "")).toThrow(
      /qmd query returned invalid JSON/i,
    );
  });
});
|
||||
@@ -1,121 +0,0 @@
|
||||
import { createSubsystemLogger } from "../../logging/subsystem.js";
|
||||
|
||||
// Subsystem-scoped logger: qmd parsing warnings go to the "memory" channel.
const log = createSubsystemLogger("memory");

/**
 * Shape of a single result row emitted by `qmd query` in JSON mode.
 * All fields are optional — presumably qmd omits fields depending on search
 * mode/version (verify against the qmd CLI docs) — so callers must handle
 * missing values defensively.
 */
export type QmdQueryResult = {
  docid?: string;
  score?: number;
  collection?: string;
  file?: string;
  snippet?: string;
  body?: string;
};
|
||||
|
||||
/**
 * Parse the stdout/stderr of a `qmd query` invocation into result rows.
 *
 * Handling, in order:
 * 1. A "no results found" marker on stdout — or on stderr when stdout is
 *    empty — yields an empty result set.
 * 2. Otherwise empty stdout is an error (stderr is summarized into the message).
 * 3. stdout is parsed as a JSON array; if that fails, the first balanced
 *    JSON array embedded in noisy stdout is extracted and parsed instead.
 * 4. Any failure is logged and rethrown wrapped in a
 *    "qmd query returned invalid JSON" error with `cause` set.
 */
export function parseQmdQueryJson(stdout: string, stderr: string): QmdQueryResult[] {
  const trimmedStdout = stdout.trim();
  const trimmedStderr = stderr.trim();
  const stdoutIsMarker = trimmedStdout.length > 0 && isQmdNoResultsOutput(trimmedStdout);
  const stderrIsMarker = trimmedStderr.length > 0 && isQmdNoResultsOutput(trimmedStderr);
  // stderr's marker only counts when stdout produced nothing at all.
  if (stdoutIsMarker || (!trimmedStdout && stderrIsMarker)) {
    return [];
  }
  if (!trimmedStdout) {
    const context = trimmedStderr ? ` (stderr: ${summarizeQmdStderr(trimmedStderr)})` : "";
    const message = `stdout empty${context}`;
    log.warn(`qmd query returned invalid JSON: ${message}`);
    throw new Error(`qmd query returned invalid JSON: ${message}`);
  }
  try {
    // Fast path: stdout is exactly a JSON array.
    const parsed = parseQmdQueryResultArray(trimmedStdout);
    if (parsed !== null) {
      return parsed;
    }
    // Fallback: pull the first balanced array out of noisy output.
    const noisyPayload = extractFirstJsonArray(trimmedStdout);
    if (!noisyPayload) {
      throw new Error("qmd query JSON response was not an array");
    }
    const fallback = parseQmdQueryResultArray(noisyPayload);
    if (fallback !== null) {
      return fallback;
    }
    throw new Error("qmd query JSON response was not an array");
  } catch (err) {
    const message = err instanceof Error ? err.message : String(err);
    log.warn(`qmd query returned invalid JSON: ${message}`);
    // Preserve the original failure for diagnostics via `cause`.
    throw new Error(`qmd query returned invalid JSON: ${message}`, { cause: err });
  }
}
|
||||
|
||||
function isQmdNoResultsOutput(raw: string): boolean {
|
||||
const lines = raw
|
||||
.split(/\r?\n/)
|
||||
.map((line) => line.trim().toLowerCase().replace(/\s+/g, " "))
|
||||
.filter((line) => line.length > 0);
|
||||
return lines.some((line) => isQmdNoResultsLine(line));
|
||||
}
|
||||
|
||||
function isQmdNoResultsLine(line: string): boolean {
|
||||
if (line === "no results found" || line === "no results found.") {
|
||||
return true;
|
||||
}
|
||||
return /^(?:\[[^\]]+\]\s*)?(?:(?:warn(?:ing)?|info|error|qmd)\s*:\s*)+no results found\.?$/.test(
|
||||
line,
|
||||
);
|
||||
}
|
||||
|
||||
function summarizeQmdStderr(raw: string): string {
|
||||
return raw.length <= 120 ? raw : `${raw.slice(0, 117)}...`;
|
||||
}
|
||||
|
||||
function parseQmdQueryResultArray(raw: string): QmdQueryResult[] | null {
|
||||
try {
|
||||
const parsed = JSON.parse(raw) as unknown;
|
||||
if (!Array.isArray(parsed)) {
|
||||
return null;
|
||||
}
|
||||
return parsed as QmdQueryResult[];
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function extractFirstJsonArray(raw: string): string | null {
|
||||
const start = raw.indexOf("[");
|
||||
if (start < 0) {
|
||||
return null;
|
||||
}
|
||||
let depth = 0;
|
||||
let inString = false;
|
||||
let escaped = false;
|
||||
for (let i = start; i < raw.length; i += 1) {
|
||||
const char = raw[i];
|
||||
if (char === undefined) {
|
||||
break;
|
||||
}
|
||||
if (inString) {
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (char === "\\") {
|
||||
escaped = true;
|
||||
} else if (char === '"') {
|
||||
inString = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (char === '"') {
|
||||
inString = true;
|
||||
continue;
|
||||
}
|
||||
if (char === "[") {
|
||||
depth += 1;
|
||||
} else if (char === "]") {
|
||||
depth -= 1;
|
||||
if (depth === 0) {
|
||||
return raw.slice(start, i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { ResolvedQmdConfig } from "./backend-config.js";
|
||||
import { deriveQmdScopeChannel, deriveQmdScopeChatType, isQmdScopeAllowed } from "./qmd-scope.js";
|
||||
|
||||
// qmd scope tests: session-key parsing (channel / chat type derivation) and
// allow/deny rule evaluation, including legacy keyPrefix behavior.
describe("qmd scope", () => {
  // Deny-by-default scope that only allows direct chats.
  const allowDirect: ResolvedQmdConfig["scope"] = {
    default: "deny",
    rules: [{ action: "allow", match: { chatType: "direct" } }],
  };

  it("derives channel and chat type from canonical keys once", () => {
    // Channel names are lowercased during normalization.
    expect(deriveQmdScopeChannel("Workspace:group:123")).toBe("workspace");
    expect(deriveQmdScopeChatType("Workspace:group:123")).toBe("group");
  });

  it("derives channel and chat type from stored key suffixes", () => {
    // The `agent:<id>:` prefix is stripped before the channel is derived.
    expect(deriveQmdScopeChannel("agent:agent-1:workspace:channel:chan-123")).toBe("workspace");
    expect(deriveQmdScopeChatType("agent:agent-1:workspace:channel:chan-123")).toBe("channel");
  });

  it("treats parsed keys with no chat prefix as direct", () => {
    expect(deriveQmdScopeChannel("agent:agent-1:peer-direct")).toBeUndefined();
    expect(deriveQmdScopeChatType("agent:agent-1:peer-direct")).toBe("direct");
    expect(isQmdScopeAllowed(allowDirect, "agent:agent-1:peer-direct")).toBe(true);
    expect(isQmdScopeAllowed(allowDirect, "agent:agent-1:peer:group:abc")).toBe(false);
  });

  it("applies scoped key-prefix checks against normalized key", () => {
    // keyPrefix (non-legacy form) matches the agent-stripped key.
    const scope: ResolvedQmdConfig["scope"] = {
      default: "deny",
      rules: [{ action: "allow", match: { keyPrefix: "workspace:" } }],
    };
    expect(isQmdScopeAllowed(scope, "agent:agent-1:workspace:group:123")).toBe(true);
    expect(isQmdScopeAllowed(scope, "agent:agent-1:other:group:123")).toBe(false);
  });

  it("supports rawKeyPrefix matches for agent-prefixed keys", () => {
    // rawKeyPrefix always matches against the untouched session key.
    const scope: ResolvedQmdConfig["scope"] = {
      default: "allow",
      rules: [{ action: "deny", match: { rawKeyPrefix: "agent:main:discord:" } }],
    };
    expect(isQmdScopeAllowed(scope, "agent:main:discord:channel:c123")).toBe(false);
    expect(isQmdScopeAllowed(scope, "agent:main:slack:channel:c123")).toBe(true);
  });

  it("keeps legacy agent-prefixed keyPrefix rules working", () => {
    // Backward compat: keyPrefix values starting with "agent:" are matched
    // against the raw key, like older configs expected.
    const scope: ResolvedQmdConfig["scope"] = {
      default: "allow",
      rules: [{ action: "deny", match: { keyPrefix: "agent:main:discord:" } }],
    };
    expect(isQmdScopeAllowed(scope, "agent:main:discord:channel:c123")).toBe(false);
    expect(isQmdScopeAllowed(scope, "agent:main:slack:channel:c123")).toBe(true);
  });
});
|
||||
@@ -1,106 +0,0 @@
|
||||
import { parseAgentSessionKey } from "../../sessions/session-key-utils.js";
|
||||
import type { ResolvedQmdConfig } from "./backend-config.js";
|
||||
|
||||
/**
 * Scope attributes derived from a session key by parseQmdSessionScope.
 * `channel`/`chatType` drive rule matching in isQmdScopeAllowed;
 * `normalizedKey` is the lowercased key with any `agent:<id>:` prefix
 * stripped. All fields are absent for unusable keys (empty or subagent).
 */
type ParsedQmdSessionScope = {
  channel?: string;
  chatType?: "channel" | "group" | "direct";
  normalizedKey?: string;
};
|
||||
|
||||
/**
 * Evaluate whether qmd memory access is allowed for a session key.
 *
 * Rules are evaluated in order; the FIRST rule whose every specified match
 * criterion (channel, chatType, keyPrefix, rawKeyPrefix) holds decides the
 * outcome via its action. When no rule matches, `scope.default` applies
 * (allow when unset). A missing scope allows everything.
 *
 * Prefix semantics: `rawKeyPrefix` matches the untouched (trimmed,
 * lowercased) session key; `keyPrefix` matches the normalized key with the
 * `agent:<id>:` prefix stripped — except legacy values that themselves
 * start with "agent:", which are matched against the raw key for backward
 * compatibility.
 */
export function isQmdScopeAllowed(scope: ResolvedQmdConfig["scope"], sessionKey?: string): boolean {
  if (!scope) {
    return true;
  }
  const parsed = parseQmdSessionScope(sessionKey);
  const channel = parsed.channel;
  const chatType = parsed.chatType;
  const normalizedKey = parsed.normalizedKey ?? "";
  const rawKey = sessionKey?.trim().toLowerCase() ?? "";
  for (const rule of scope.rules ?? []) {
    if (!rule) {
      continue;
    }
    const match = rule.match ?? {};
    // Each `continue` means "this rule does not apply; try the next one".
    if (match.channel && match.channel !== channel) {
      continue;
    }
    if (match.chatType && match.chatType !== chatType) {
      continue;
    }
    // `|| undefined` collapses empty/whitespace-only prefixes to "not set".
    const normalizedPrefix = match.keyPrefix?.trim().toLowerCase() || undefined;
    const rawPrefix = match.rawKeyPrefix?.trim().toLowerCase() || undefined;

    if (rawPrefix && !rawKey.startsWith(rawPrefix)) {
      continue;
    }
    if (normalizedPrefix) {
      // Backward compat: older configs used `keyPrefix: "agent:<id>:..."` to match raw keys.
      const isLegacyRaw = normalizedPrefix.startsWith("agent:");
      if (isLegacyRaw) {
        if (!rawKey.startsWith(normalizedPrefix)) {
          continue;
        }
      } else if (!normalizedKey.startsWith(normalizedPrefix)) {
        continue;
      }
    }
    // All specified criteria matched: this rule decides.
    return rule.action === "allow";
  }
  const fallback = scope.default ?? "allow";
  return fallback === "allow";
}
|
||||
|
||||
export function deriveQmdScopeChannel(key?: string): string | undefined {
|
||||
return parseQmdSessionScope(key).channel;
|
||||
}
|
||||
|
||||
export function deriveQmdScopeChatType(key?: string): "channel" | "group" | "direct" | undefined {
|
||||
return parseQmdSessionScope(key).chatType;
|
||||
}
|
||||
|
||||
/**
 * Derive channel / chat type / normalized key from a session key.
 *
 * After normalization (agent prefix stripped, lowercased), two shapes are
 * handled:
 * - "<channel>:<chatKind>:..." where the second segment is one of
 *   group/channel/direct/dm: the first segment becomes `channel` and the
 *   chat type is taken from the segments (dm/direct fall back to "direct").
 * - otherwise the key has no channel; chat type is inferred from a
 *   ":group:" or ":channel:" substring, defaulting to "direct".
 *
 * Empty or subagent keys produce an empty result.
 */
function parseQmdSessionScope(key?: string): ParsedQmdSessionScope {
  const normalized = normalizeQmdSessionKey(key);
  if (!normalized) {
    return {};
  }
  // filter(Boolean) drops empty segments from doubled/leading colons.
  const parts = normalized.split(":").filter(Boolean);
  let chatType: ParsedQmdSessionScope["chatType"];
  if (
    parts.length >= 2 &&
    (parts[1] === "group" || parts[1] === "channel" || parts[1] === "direct" || parts[1] === "dm")
  ) {
    // NOTE(review): `includes` scans all segments, not just parts[1] — a
    // later "group"/"channel" segment also sets the chat type here.
    if (parts.includes("group")) {
      chatType = "group";
    } else if (parts.includes("channel")) {
      chatType = "channel";
    }
    return {
      normalizedKey: normalized,
      channel: parts[0]?.toLowerCase(),
      chatType: chatType ?? "direct",
    };
  }
  // No canonical channel prefix: infer chat type from embedded markers.
  if (normalized.includes(":group:")) {
    return { normalizedKey: normalized, chatType: "group" };
  }
  if (normalized.includes(":channel:")) {
    return { normalizedKey: normalized, chatType: "channel" };
  }
  return { normalizedKey: normalized, chatType: "direct" };
}
|
||||
|
||||
function normalizeQmdSessionKey(key?: string): string | undefined {
|
||||
if (!key) {
|
||||
return undefined;
|
||||
}
|
||||
const trimmed = key.trim();
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
const parsed = parseAgentSessionKey(trimmed);
|
||||
const normalized = (parsed?.rest ?? trimmed).toLowerCase();
|
||||
if (normalized.startsWith("subagent:")) {
|
||||
return undefined;
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { expandQueryForFts, extractKeywords } from "./query-expansion.js";
|
||||
|
||||
// Coverage for the multilingual keyword extractor used by FTS-only search:
// per-language extraction, stop-word filtering, Korean particle stripping,
// and degenerate inputs (empty / stop-words-only / duplicates).
describe("extractKeywords", () => {
  it("extracts keywords from English conversational query", () => {
    const keywords = extractKeywords("that thing we discussed about the API");
    expect(keywords).toContain("discussed");
    expect(keywords).toContain("api");
    // Should not include stop words
    expect(keywords).not.toContain("that");
    expect(keywords).not.toContain("thing");
    expect(keywords).not.toContain("we");
    expect(keywords).not.toContain("about");
    expect(keywords).not.toContain("the");
  });

  it("extracts keywords from Chinese conversational query", () => {
    const keywords = extractKeywords("之前讨论的那个方案");
    expect(keywords).toContain("讨论");
    expect(keywords).toContain("方案");
    // Should not include stop words
    expect(keywords).not.toContain("之前");
    expect(keywords).not.toContain("的");
    expect(keywords).not.toContain("那个");
  });

  it("extracts keywords from mixed language query", () => {
    const keywords = extractKeywords("昨天讨论的 API design");
    expect(keywords).toContain("讨论");
    expect(keywords).toContain("api");
    expect(keywords).toContain("design");
  });

  it("returns specific technical terms", () => {
    const keywords = extractKeywords("what was the solution for the CFR bug");
    expect(keywords).toContain("solution");
    expect(keywords).toContain("cfr");
    expect(keywords).toContain("bug");
  });

  it("extracts keywords from Korean conversational query", () => {
    const keywords = extractKeywords("어제 논의한 배포 전략");
    expect(keywords).toContain("논의한");
    expect(keywords).toContain("배포");
    expect(keywords).toContain("전략");
    // Should not include stop words
    expect(keywords).not.toContain("어제");
  });

  it("strips Korean particles to extract stems", () => {
    const keywords = extractKeywords("서버에서 발생한 에러를 확인");
    expect(keywords).toContain("서버");
    expect(keywords).toContain("에러");
    expect(keywords).toContain("확인");
  });

  it("filters Korean stop words including inflected forms", () => {
    const keywords = extractKeywords("나는 그리고 그래서");
    expect(keywords).not.toContain("나");
    expect(keywords).not.toContain("나는");
    expect(keywords).not.toContain("그리고");
    expect(keywords).not.toContain("그래서");
  });

  it("filters inflected Korean stop words not explicitly listed", () => {
    const keywords = extractKeywords("그녀는 우리는");
    expect(keywords).not.toContain("그녀는");
    expect(keywords).not.toContain("우리는");
    expect(keywords).not.toContain("그녀");
    expect(keywords).not.toContain("우리");
  });

  it("does not produce bogus single-char stems from particle stripping", () => {
    const keywords = extractKeywords("논의");
    expect(keywords).toContain("논의");
    expect(keywords).not.toContain("논");
  });

  it("strips longest Korean trailing particles first", () => {
    const keywords = extractKeywords("기능으로 설명");
    expect(keywords).toContain("기능");
    expect(keywords).not.toContain("기능으");
  });

  it("keeps stripped ASCII stems for mixed Korean tokens", () => {
    const keywords = extractKeywords("API를 배포했다");
    expect(keywords).toContain("api");
    expect(keywords).toContain("배포했다");
  });

  it("handles mixed Korean and English query", () => {
    const keywords = extractKeywords("API 배포에 대한 논의");
    expect(keywords).toContain("api");
    expect(keywords).toContain("배포");
    expect(keywords).toContain("논의");
  });

  it("extracts keywords from Japanese conversational query", () => {
    const keywords = extractKeywords("昨日話したデプロイ戦略");
    expect(keywords).toContain("デプロイ");
    expect(keywords).toContain("戦略");
    expect(keywords).not.toContain("昨日");
  });

  it("handles mixed Japanese and English query", () => {
    const keywords = extractKeywords("昨日話したAPIのバグ");
    expect(keywords).toContain("api");
    expect(keywords).toContain("バグ");
    expect(keywords).not.toContain("した");
  });

  it("filters Japanese stop words", () => {
    const keywords = extractKeywords("これ それ そして どう");
    expect(keywords).not.toContain("これ");
    expect(keywords).not.toContain("それ");
    expect(keywords).not.toContain("そして");
    expect(keywords).not.toContain("どう");
  });

  it("extracts keywords from Spanish conversational query", () => {
    const keywords = extractKeywords("ayer hablamos sobre la estrategia de despliegue");
    expect(keywords).toContain("estrategia");
    expect(keywords).toContain("despliegue");
    expect(keywords).not.toContain("ayer");
    expect(keywords).not.toContain("sobre");
  });

  it("extracts keywords from Portuguese conversational query", () => {
    const keywords = extractKeywords("ontem falamos sobre a estratégia de implantação");
    expect(keywords).toContain("estratégia");
    expect(keywords).toContain("implantação");
    expect(keywords).not.toContain("ontem");
    expect(keywords).not.toContain("sobre");
  });

  it("filters Spanish and Portuguese question stop words", () => {
    const keywords = extractKeywords("cómo cuando donde porquê quando onde");
    expect(keywords).not.toContain("cómo");
    expect(keywords).not.toContain("cuando");
    expect(keywords).not.toContain("donde");
    expect(keywords).not.toContain("porquê");
    expect(keywords).not.toContain("quando");
    expect(keywords).not.toContain("onde");
  });

  it("extracts keywords from Arabic conversational query", () => {
    const keywords = extractKeywords("بالأمس ناقشنا استراتيجية النشر");
    expect(keywords).toContain("ناقشنا");
    expect(keywords).toContain("استراتيجية");
    expect(keywords).toContain("النشر");
    expect(keywords).not.toContain("بالأمس");
  });

  it("filters Arabic question stop words", () => {
    const keywords = extractKeywords("كيف متى أين ماذا");
    expect(keywords).not.toContain("كيف");
    expect(keywords).not.toContain("متى");
    expect(keywords).not.toContain("أين");
    expect(keywords).not.toContain("ماذا");
  });

  it("handles empty query", () => {
    expect(extractKeywords("")).toEqual([]);
    expect(extractKeywords("   ")).toEqual([]);
  });

  it("handles query with only stop words", () => {
    const keywords = extractKeywords("the a an is are");
    expect(keywords.length).toBe(0);
  });

  it("removes duplicate keywords", () => {
    const keywords = extractKeywords("test test testing");
    const testCount = keywords.filter((k) => k === "test").length;
    expect(testCount).toBe(1);
  });
});
|
||||
|
||||
// Coverage for the FTS query expander: shape of the result object and the
// OR-joined expanded query, including the no-keywords fallback.
describe("expandQueryForFts", () => {
  it("returns original query and extracted keywords", () => {
    const result = expandQueryForFts("that API we discussed");
    expect(result.original).toBe("that API we discussed");
    expect(result.keywords).toContain("api");
    expect(result.keywords).toContain("discussed");
  });

  it("builds expanded OR query for FTS", () => {
    const result = expandQueryForFts("the solution for bugs");
    expect(result.expanded).toContain("OR");
    expect(result.expanded).toContain("solution");
    expect(result.expanded).toContain("bugs");
  });

  // When every token is a stop word, expansion must fall back to the
  // untouched original query rather than an empty OR expression.
  it("returns original query when no keywords extracted", () => {
    const result = expandQueryForFts("the");
    expect(result.keywords.length).toBe(0);
    expect(result.expanded).toBe("the");
  });
});
|
||||
@@ -1,810 +0,0 @@
|
||||
/**
|
||||
* Query expansion for FTS-only search mode.
|
||||
*
|
||||
* When no embedding provider is available, we fall back to FTS (full-text search).
|
||||
* FTS works best with specific keywords, but users often ask conversational queries
|
||||
* like "that thing we discussed yesterday" or "之前讨论的那个方案".
|
||||
*
|
||||
* This module extracts meaningful keywords from such queries to improve FTS results.
|
||||
*/
|
||||
|
||||
// Common stop words that don't add search value
|
||||
const STOP_WORDS_EN = new Set([
|
||||
// Articles and determiners
|
||||
"a",
|
||||
"an",
|
||||
"the",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
// Pronouns
|
||||
"i",
|
||||
"me",
|
||||
"my",
|
||||
"we",
|
||||
"our",
|
||||
"you",
|
||||
"your",
|
||||
"he",
|
||||
"she",
|
||||
"it",
|
||||
"they",
|
||||
"them",
|
||||
// Common verbs
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"can",
|
||||
"may",
|
||||
"might",
|
||||
// Prepositions
|
||||
"in",
|
||||
"on",
|
||||
"at",
|
||||
"to",
|
||||
"for",
|
||||
"of",
|
||||
"with",
|
||||
"by",
|
||||
"from",
|
||||
"about",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"over",
|
||||
// Conjunctions
|
||||
"and",
|
||||
"or",
|
||||
"but",
|
||||
"if",
|
||||
"then",
|
||||
"because",
|
||||
"as",
|
||||
"while",
|
||||
"when",
|
||||
"where",
|
||||
"what",
|
||||
"which",
|
||||
"who",
|
||||
"how",
|
||||
"why",
|
||||
// Time references (vague, not useful for FTS)
|
||||
"yesterday",
|
||||
"today",
|
||||
"tomorrow",
|
||||
"earlier",
|
||||
"later",
|
||||
"recently",
|
||||
"before",
|
||||
"ago",
|
||||
"just",
|
||||
"now",
|
||||
// Vague references
|
||||
"thing",
|
||||
"things",
|
||||
"stuff",
|
||||
"something",
|
||||
"anything",
|
||||
"everything",
|
||||
"nothing",
|
||||
// Question words
|
||||
"please",
|
||||
"help",
|
||||
"find",
|
||||
"show",
|
||||
"get",
|
||||
"tell",
|
||||
"give",
|
||||
]);
|
||||
|
||||
// Spanish stop words dropped from conversational queries before FTS matching.
const STOP_WORDS_ES = new Set([
  // Articles and determiners
  "el", "la", "los", "las", "un", "una", "unos", "unas", "este", "esta", "ese", "esa",
  // Pronouns
  "yo", "me", "mi", "nosotros", "nosotras", "tu", "tus", "usted", "ustedes", "ellos", "ellas",
  // Prepositions and conjunctions
  "de", "del", "a", "en", "con", "por", "para", "sobre", "entre", "y", "o",
  "pero", "si", "porque", "como",
  // Common verbs / auxiliaries
  "es", "son", "fue", "fueron", "ser", "estar", "haber", "tener", "hacer",
  // Time references (vague)
  "ayer", "hoy", "mañana", "antes", "despues", "después", "ahora", "recientemente",
  // Question/request words
  "que", "qué", "cómo", "cuando", "cuándo", "donde", "dónde", "porqué", "favor", "ayuda",
]);

// Portuguese stop words dropped from conversational queries before FTS matching.
const STOP_WORDS_PT = new Set([
  // Articles and determiners
  "o", "a", "os", "as", "um", "uma", "uns", "umas", "este", "esta", "esse", "essa",
  // Pronouns
  "eu", "me", "meu", "minha", "nos", "nós", "você", "vocês", "ele", "ela", "eles", "elas",
  // Prepositions and conjunctions
  "de", "do", "da", "em", "com", "por", "para", "sobre", "entre", "e", "ou",
  "mas", "se", "porque", "como",
  // Common verbs / auxiliaries
  "é", "são", "foi", "foram", "ser", "estar", "ter", "fazer",
  // Time references (vague)
  "ontem", "hoje", "amanhã", "antes", "depois", "agora", "recentemente",
  // Question/request words
  "que", "quê", "quando", "onde", "porquê", "favor", "ajuda",
]);

// Arabic stop words dropped from conversational queries before FTS matching.
const STOP_WORDS_AR = new Set([
  // Articles and connectors
  "ال", "و", "أو", "لكن", "ثم", "بل",
  // Pronouns / references
  "أنا", "نحن", "هو", "هي", "هم", "هذا", "هذه", "ذلك", "تلك", "هنا", "هناك",
  // Common prepositions
  "من", "إلى", "الى", "في", "على", "عن", "مع", "بين", "ل", "ب", "ك",
  // Common auxiliaries / vague verbs
  "كان", "كانت", "يكون", "تكون", "صار", "أصبح", "يمكن", "ممكن",
  // Time references (vague)
  "بالأمس", "امس", "اليوم", "غدا", "الآن", "قبل", "بعد", "مؤخرا",
  // Question/request words
  "لماذا", "كيف", "ماذا", "متى", "أين", "هل",
  // NOTE(review): "من فضلك" contains a space, so it can never equal a single
  // token produced by the whitespace-splitting tokenizer — confirm intent.
  "من فضلك", "فضلا", "ساعد",
]);
|
||||
|
||||
const STOP_WORDS_KO = new Set([
|
||||
// Particles (조사)
|
||||
"은",
|
||||
"는",
|
||||
"이",
|
||||
"가",
|
||||
"을",
|
||||
"를",
|
||||
"의",
|
||||
"에",
|
||||
"에서",
|
||||
"로",
|
||||
"으로",
|
||||
"와",
|
||||
"과",
|
||||
"도",
|
||||
"만",
|
||||
"까지",
|
||||
"부터",
|
||||
"한테",
|
||||
"에게",
|
||||
"께",
|
||||
"처럼",
|
||||
"같이",
|
||||
"보다",
|
||||
"마다",
|
||||
"밖에",
|
||||
"대로",
|
||||
// Pronouns (대명사)
|
||||
"나",
|
||||
"나는",
|
||||
"내가",
|
||||
"나를",
|
||||
"너",
|
||||
"우리",
|
||||
"저",
|
||||
"저희",
|
||||
"그",
|
||||
"그녀",
|
||||
"그들",
|
||||
"이것",
|
||||
"저것",
|
||||
"그것",
|
||||
"여기",
|
||||
"저기",
|
||||
"거기",
|
||||
// Common verbs / auxiliaries (일반 동사/보조 동사)
|
||||
"있다",
|
||||
"없다",
|
||||
"하다",
|
||||
"되다",
|
||||
"이다",
|
||||
"아니다",
|
||||
"보다",
|
||||
"주다",
|
||||
"오다",
|
||||
"가다",
|
||||
// Nouns (의존 명사 / vague)
|
||||
"것",
|
||||
"거",
|
||||
"등",
|
||||
"수",
|
||||
"때",
|
||||
"곳",
|
||||
"중",
|
||||
"분",
|
||||
// Adverbs
|
||||
"잘",
|
||||
"더",
|
||||
"또",
|
||||
"매우",
|
||||
"정말",
|
||||
"아주",
|
||||
"많이",
|
||||
"너무",
|
||||
"좀",
|
||||
// Conjunctions
|
||||
"그리고",
|
||||
"하지만",
|
||||
"그래서",
|
||||
"그런데",
|
||||
"그러나",
|
||||
"또는",
|
||||
"그러면",
|
||||
// Question words
|
||||
"왜",
|
||||
"어떻게",
|
||||
"뭐",
|
||||
"언제",
|
||||
"어디",
|
||||
"누구",
|
||||
"무엇",
|
||||
"어떤",
|
||||
// Time (vague)
|
||||
"어제",
|
||||
"오늘",
|
||||
"내일",
|
||||
"최근",
|
||||
"지금",
|
||||
"아까",
|
||||
"나중",
|
||||
"전에",
|
||||
// Request words
|
||||
"제발",
|
||||
"부탁",
|
||||
]);
|
||||
|
||||
// Common Korean trailing particles to strip from words for tokenization
|
||||
// Sorted by descending length so longest-match-first is guaranteed.
|
||||
const KO_TRAILING_PARTICLES = [
|
||||
"에서",
|
||||
"으로",
|
||||
"에게",
|
||||
"한테",
|
||||
"처럼",
|
||||
"같이",
|
||||
"보다",
|
||||
"까지",
|
||||
"부터",
|
||||
"마다",
|
||||
"밖에",
|
||||
"대로",
|
||||
"은",
|
||||
"는",
|
||||
"이",
|
||||
"가",
|
||||
"을",
|
||||
"를",
|
||||
"의",
|
||||
"에",
|
||||
"로",
|
||||
"와",
|
||||
"과",
|
||||
"도",
|
||||
"만",
|
||||
].toSorted((a, b) => b.length - a.length);
|
||||
|
||||
function stripKoreanTrailingParticle(token: string): string | null {
|
||||
for (const particle of KO_TRAILING_PARTICLES) {
|
||||
if (token.length > particle.length && token.endsWith(particle)) {
|
||||
return token.slice(0, -particle.length);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function isUsefulKoreanStem(stem: string): boolean {
|
||||
// Prevent bogus one-syllable stems from words like "논의" -> "논".
|
||||
if (/[\uac00-\ud7af]/.test(stem)) {
|
||||
return stem.length >= 2;
|
||||
}
|
||||
// Keep stripped ASCII stems for mixed tokens like "API를" -> "api".
|
||||
return /^[a-z0-9_]+$/i.test(stem);
|
||||
}
|
||||
|
||||
// Japanese stop words dropped from conversational queries before FTS matching.
const STOP_WORDS_JA = new Set([
  // Pronouns and references
  "これ", "それ", "あれ", "この", "その", "あの", "ここ", "そこ", "あそこ",
  // Common auxiliaries / vague verbs
  "する", "した", "して", "です", "ます", "いる", "ある", "なる", "できる",
  // Particles / connectors
  "の", "こと", "もの", "ため", "そして", "しかし", "また", "でも", "から",
  "まで", "より", "だけ",
  // Question words
  "なぜ", "どう", "何", "いつ", "どこ", "誰", "どれ",
  // Time (vague)
  "昨日", "今日", "明日", "最近", "今", "さっき", "前", "後",
]);
|
||||
|
||||
// Chinese stop words dropped from conversational queries before FTS matching.
// Note the tokenizer emits single characters AND bigrams for Chinese, so both
// single-character and two-character entries are needed here.
const STOP_WORDS_ZH = new Set([
  // Pronouns
  "我", "我们", "你", "你们", "他", "她", "它", "他们", "这", "那", "这个",
  "那个", "这些", "那些",
  // Auxiliary words
  "的", "了", "着", "过", "得", "地", "吗", "呢", "吧", "啊", "呀", "嘛", "啦",
  // Verbs (common, vague)
  "是", "有", "在", "被", "把", "给", "让", "用", "到", "去", "来", "做", "说",
  "看", "找", "想", "要", "能", "会", "可以",
  // Prepositions and conjunctions
  "和", "与", "或", "但", "但是", "因为", "所以", "如果", "虽然", "而", "也",
  "都", "就", "还", "又", "再", "才", "只",
  // Time (vague)
  "之前", "以前", "之后", "以后", "刚才", "现在", "昨天", "今天", "明天", "最近",
  // Vague references
  "东西", "事情", "事", "什么", "哪个", "哪些", "怎么", "为什么", "多少",
  // Question/request words
  "请", "帮", "帮忙", "告诉",
]);
|
||||
|
||||
export function isQueryStopWordToken(token: string): boolean {
|
||||
return (
|
||||
STOP_WORDS_EN.has(token) ||
|
||||
STOP_WORDS_ES.has(token) ||
|
||||
STOP_WORDS_PT.has(token) ||
|
||||
STOP_WORDS_AR.has(token) ||
|
||||
STOP_WORDS_ZH.has(token) ||
|
||||
STOP_WORDS_KO.has(token) ||
|
||||
STOP_WORDS_JA.has(token)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a token looks like a meaningful keyword.
|
||||
* Returns false for short tokens, numbers-only, etc.
|
||||
*/
|
||||
function isValidKeyword(token: string): boolean {
|
||||
if (!token || token.length === 0) {
|
||||
return false;
|
||||
}
|
||||
// Skip very short English words (likely stop words or fragments)
|
||||
if (/^[a-zA-Z]+$/.test(token) && token.length < 3) {
|
||||
return false;
|
||||
}
|
||||
// Skip pure numbers (not useful for semantic search)
|
||||
if (/^\d+$/.test(token)) {
|
||||
return false;
|
||||
}
|
||||
// Skip tokens that are all punctuation
|
||||
if (/^[\p{P}\p{S}]+$/u.test(token)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
 * Simple tokenizer that handles English, Chinese, Korean, and Japanese text.
 * For Chinese, we do character-based splitting since we don't have a proper segmenter.
 * For English, we split on whitespace and punctuation.
 *
 * Emitted tokens are lowercase; CJK segments additionally yield character
 * bigrams so two-character words ("讨论", "方案") can match FTS indexes.
 */
function tokenize(text: string): string[] {
  const tokens: string[] = [];
  const normalized = text.toLowerCase().trim();

  // Split into segments (English words, Chinese character sequences, etc.)
  // on whitespace and Unicode punctuation.
  const segments = normalized.split(/[\s\p{P}]+/u).filter(Boolean);

  for (const segment of segments) {
    // Japanese text often mixes scripts (kanji/kana/ASCII) without spaces.
    // Extract script-specific chunks so technical terms like "API" / "バグ" are retained.
    if (/[\u3040-\u30ff]/.test(segment)) {
      // Chunks: ASCII words, katakana runs (incl. long-vowel mark), kanji
      // runs, and hiragana runs of 2+ (single hiragana are mostly particles).
      const jpParts =
        segment.match(/[a-z0-9_]+|[\u30a0-\u30ffー]+|[\u4e00-\u9fff]+|[\u3040-\u309f]{2,}/g) ?? [];
      for (const part of jpParts) {
        if (/^[\u4e00-\u9fff]+$/.test(part)) {
          // Kanji run: emit the run itself plus character bigrams.
          tokens.push(part);
          for (let i = 0; i < part.length - 1; i++) {
            tokens.push(part[i] + part[i + 1]);
          }
        } else {
          tokens.push(part);
        }
      }
    } else if (/[\u4e00-\u9fff]/.test(segment)) {
      // Check if segment contains CJK characters (Chinese)
      // For Chinese, extract character n-grams (unigrams and bigrams)
      const chars = Array.from(segment).filter((c) => /[\u4e00-\u9fff]/.test(c));
      // Add individual characters
      tokens.push(...chars);
      // Add bigrams for better phrase matching
      for (let i = 0; i < chars.length - 1; i++) {
        tokens.push(chars[i] + chars[i + 1]);
      }
    } else if (/[\uac00-\ud7af\u3131-\u3163]/.test(segment)) {
      // For Korean (Hangul syllables and jamo), keep the word as-is unless it is
      // effectively a stop word once trailing particles are removed.
      const stem = stripKoreanTrailingParticle(segment);
      const stemIsStopWord = stem !== null && STOP_WORDS_KO.has(stem);
      if (!STOP_WORDS_KO.has(segment) && !stemIsStopWord) {
        tokens.push(segment);
      }
      // Also emit particle-stripped stems when they are useful keywords.
      if (stem && !STOP_WORDS_KO.has(stem) && isUsefulKoreanStem(stem)) {
        tokens.push(stem);
      }
    } else {
      // For non-CJK, keep as single token
      tokens.push(segment);
    }
  }

  return tokens;
}
|
||||
|
||||
/**
|
||||
* Extract keywords from a conversational query for FTS search.
|
||||
*
|
||||
* Examples:
|
||||
* - "that thing we discussed about the API" → ["discussed", "API"]
|
||||
* - "之前讨论的那个方案" → ["讨论", "方案"]
|
||||
* - "what was the solution for the bug" → ["solution", "bug"]
|
||||
*/
|
||||
export function extractKeywords(query: string): string[] {
|
||||
const tokens = tokenize(query);
|
||||
const keywords: string[] = [];
|
||||
const seen = new Set<string>();
|
||||
|
||||
for (const token of tokens) {
|
||||
// Skip stop words
|
||||
if (isQueryStopWordToken(token)) {
|
||||
continue;
|
||||
}
|
||||
// Skip invalid keywords
|
||||
if (!isValidKeyword(token)) {
|
||||
continue;
|
||||
}
|
||||
// Skip duplicates
|
||||
if (seen.has(token)) {
|
||||
continue;
|
||||
}
|
||||
seen.add(token);
|
||||
keywords.push(token);
|
||||
}
|
||||
|
||||
return keywords;
|
||||
}
|
||||
|
||||
/**
|
||||
* Expand a query for FTS search.
|
||||
* Returns both the original query and extracted keywords for OR-matching.
|
||||
*
|
||||
* @param query - User's original query
|
||||
* @returns Object with original query and extracted keywords
|
||||
*/
|
||||
export function expandQueryForFts(query: string): {
|
||||
original: string;
|
||||
keywords: string[];
|
||||
expanded: string;
|
||||
} {
|
||||
const original = query.trim();
|
||||
const keywords = extractKeywords(original);
|
||||
|
||||
// Build expanded query: original terms OR extracted keywords
|
||||
// This ensures both exact matches and keyword matches are found
|
||||
const expanded = keywords.length > 0 ? `${original} OR ${keywords.join(" OR ")}` : original;
|
||||
|
||||
return { original, keywords, expanded };
|
||||
}
|
||||
|
||||
/**
|
||||
* Type for an optional LLM-based query expander.
|
||||
* Can be provided to enhance keyword extraction with semantic understanding.
|
||||
*/
|
||||
export type LlmQueryExpander = (query: string) => Promise<string[]>;
|
||||
|
||||
/**
|
||||
* Expand query with optional LLM assistance.
|
||||
* Falls back to local extraction if LLM is unavailable or fails.
|
||||
*/
|
||||
export async function expandQueryWithLlm(
|
||||
query: string,
|
||||
llmExpander?: LlmQueryExpander,
|
||||
): Promise<string[]> {
|
||||
// If LLM expander is provided, try it first
|
||||
if (llmExpander) {
|
||||
try {
|
||||
const llmKeywords = await llmExpander(query);
|
||||
if (llmKeywords.length > 0) {
|
||||
return llmKeywords;
|
||||
}
|
||||
} catch {
|
||||
// LLM failed, fall back to local extraction
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to local keyword extraction
|
||||
return extractKeywords(query);
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js";
|
||||
import { resolveMemorySearchConfig } from "../../agents/memory-search.js";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import { isFileMissingError, statRegularFile } from "./fs-utils.js";
|
||||
import { isMemoryPath, normalizeExtraMemoryPaths } from "./internal.js";
|
||||
|
||||
/**
 * Read a markdown memory file, restricted to the agent workspace's memory
 * paths or explicitly configured extra paths.
 *
 * Returns the file text (optionally a 1-indexed line window via `from` /
 * `lines`) and the workspace-relative path. Missing files yield empty text
 * rather than an error. Throws `Error("path required")` when the path is
 * empty, outside the allowed roots, or not a `.md` file — the same message is
 * reused for all three rejections, so callers cannot distinguish them.
 */
export async function readMemoryFile(params: {
  workspaceDir: string;
  extraPaths?: string[];
  relPath: string;
  from?: number;
  lines?: number;
}): Promise<{ text: string; path: string }> {
  const rawPath = params.relPath.trim();
  if (!rawPath) {
    throw new Error("path required");
  }
  // Resolve to an absolute path; relative inputs are anchored at the workspace.
  const absPath = path.isAbsolute(rawPath)
    ? path.resolve(rawPath)
    : path.resolve(params.workspaceDir, rawPath);
  // Workspace-relative form (forward slashes) used for containment checks
  // and as the returned path.
  const relPath = path.relative(params.workspaceDir, absPath).replace(/\\/g, "/");
  // Inside the workspace iff the relative form neither escapes ("..") nor
  // stays absolute (different drive on Windows) nor equals the root itself.
  const inWorkspace = relPath.length > 0 && !relPath.startsWith("..") && !path.isAbsolute(relPath);
  const allowedWorkspace = inWorkspace && isMemoryPath(relPath);
  let allowedAdditional = false;
  if (!allowedWorkspace && (params.extraPaths?.length ?? 0) > 0) {
    const additionalPaths = normalizeExtraMemoryPaths(params.workspaceDir, params.extraPaths);
    for (const additionalPath of additionalPaths) {
      try {
        // lstat (not stat) so symlinked roots are detected and skipped,
        // preventing link-based escapes from the allowed directories.
        const stat = await fs.lstat(additionalPath);
        if (stat.isSymbolicLink()) {
          continue;
        }
        if (stat.isDirectory()) {
          // Exact match or strict prefix with a separator — avoids treating
          // "/a/bc" as inside "/a/b".
          if (absPath === additionalPath || absPath.startsWith(`${additionalPath}${path.sep}`)) {
            allowedAdditional = true;
            break;
          }
          continue;
        }
        // A file entry only allows exactly that markdown file.
        if (stat.isFile() && absPath === additionalPath && absPath.endsWith(".md")) {
          allowedAdditional = true;
          break;
        }
      } catch {}
      // Best-effort: unreadable extra paths are simply not granted.
    }
  }
  if (!allowedWorkspace && !allowedAdditional) {
    throw new Error("path required");
  }
  // Only markdown files are readable through this API.
  if (!absPath.endsWith(".md")) {
    throw new Error("path required");
  }
  const statResult = await statRegularFile(absPath);
  if (statResult.missing) {
    return { text: "", path: relPath };
  }
  let content: string;
  try {
    content = await fs.readFile(absPath, "utf-8");
  } catch (err) {
    // File may disappear between stat and read; treat as missing.
    if (isFileMissingError(err)) {
      return { text: "", path: relPath };
    }
    throw err;
  }
  if (!params.from && !params.lines) {
    return { text: content, path: relPath };
  }
  // Line-window mode: `from` is 1-indexed; both values are clamped to >= 1.
  const fileLines = content.split("\n");
  const start = Math.max(1, params.from ?? 1);
  const count = Math.max(1, params.lines ?? fileLines.length);
  const slice = fileLines.slice(start - 1, start - 1 + count);
  return { text: slice.join("\n"), path: relPath };
}
|
||||
|
||||
export async function readAgentMemoryFile(params: {
|
||||
cfg: OpenClawConfig;
|
||||
agentId: string;
|
||||
relPath: string;
|
||||
from?: number;
|
||||
lines?: number;
|
||||
}): Promise<{ text: string; path: string }> {
|
||||
const settings = resolveMemorySearchConfig(params.cfg, params.agentId);
|
||||
if (!settings) {
|
||||
throw new Error("memory search disabled");
|
||||
}
|
||||
return await readMemoryFile({
|
||||
workspaceDir: resolveAgentWorkspaceDir(params.cfg, params.agentId),
|
||||
extraPaths: settings.extraPaths,
|
||||
relPath: params.relPath,
|
||||
from: params.from,
|
||||
lines: params.lines,
|
||||
});
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js";
|
||||
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
|
||||
|
||||
export function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined {
|
||||
const trimmed = baseUrl.trim();
|
||||
if (!trimmed) {
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
const parsed = new URL(trimmed);
|
||||
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
|
||||
return undefined;
|
||||
}
|
||||
// Keep policy tied to the configured host so private operator endpoints
|
||||
// continue to work, while cross-host redirects stay blocked.
|
||||
return { allowedHostnames: [parsed.hostname] };
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Perform an SSRF-guarded HTTP request and hand the response to `onResponse`.
 *
 * The guard's `release` handle is always awaited in `finally`, so connection
 * resources are freed even when `onResponse` throws. The audit context
 * defaults to "memory-remote" when the caller does not supply one.
 */
export async function withRemoteHttpResponse<T>(params: {
  url: string;
  init?: RequestInit;
  ssrfPolicy?: SsrFPolicy;
  auditContext?: string;
  onResponse: (response: Response) => Promise<T>;
}): Promise<T> {
  const { response, release } = await fetchWithSsrFGuard({
    url: params.url,
    init: params.init,
    policy: params.ssrfPolicy,
    auditContext: params.auditContext ?? "memory-remote",
  });
  try {
    return await params.onResponse(response);
  } finally {
    // Release the guarded connection regardless of handler outcome.
    await release();
  }
}
|
||||
@@ -1,18 +0,0 @@
|
||||
import {
|
||||
hasConfiguredSecretInput,
|
||||
normalizeResolvedSecretInputString,
|
||||
} from "../../config/types.secrets.js";
|
||||
|
||||
/**
 * True when `value` is a configured secret input (delegates to the shared
 * config helper; re-exported under a memory-specific name so memory code
 * does not import config internals directly).
 */
export function hasConfiguredMemorySecretInput(value: unknown): boolean {
  return hasConfiguredSecretInput(value);
}
|
||||
|
||||
/**
 * Resolve a secret input to its string value, or undefined when unset.
 * Thin wrapper over the shared config normalizer; `path` identifies the
 * config location for error reporting in the underlying helper.
 */
export function resolveMemorySecretInputString(params: {
  value: unknown;
  path: string;
}): string | undefined {
  return normalizeResolvedSecretInputString({
    value: params.value,
    path: params.path,
  });
}
|
||||
@@ -1,87 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { afterEach, beforeEach, describe, expect, it } from "vitest";
|
||||
import { buildSessionEntry } from "./session-files.js";
|
||||
|
||||
describe("buildSessionEntry", () => {
|
||||
let tmpDir: string;
|
||||
|
||||
beforeEach(async () => {
|
||||
tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "session-entry-test-"));
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await fs.rm(tmpDir, { recursive: true, force: true });
|
||||
});
|
||||
|
||||
it("returns lineMap tracking original JSONL line numbers", async () => {
|
||||
// Simulate a real session JSONL file with metadata records interspersed
|
||||
// Lines 1-3: non-message metadata records
|
||||
// Line 4: user message
|
||||
// Line 5: metadata
|
||||
// Line 6: assistant message
|
||||
// Line 7: user message
|
||||
const jsonlLines = [
|
||||
JSON.stringify({ type: "custom", customType: "model-snapshot", data: {} }),
|
||||
JSON.stringify({ type: "custom", customType: "openclaw.cache-ttl", data: {} }),
|
||||
JSON.stringify({ type: "session-meta", agentId: "test" }),
|
||||
JSON.stringify({ type: "message", message: { role: "user", content: "Hello world" } }),
|
||||
JSON.stringify({ type: "custom", customType: "tool-result", data: {} }),
|
||||
JSON.stringify({
|
||||
type: "message",
|
||||
message: { role: "assistant", content: "Hi there, how can I help?" },
|
||||
}),
|
||||
JSON.stringify({ type: "message", message: { role: "user", content: "Tell me a joke" } }),
|
||||
];
|
||||
const filePath = path.join(tmpDir, "session.jsonl");
|
||||
await fs.writeFile(filePath, jsonlLines.join("\n"));
|
||||
|
||||
const entry = await buildSessionEntry(filePath);
|
||||
expect(entry).not.toBeNull();
|
||||
|
||||
// The content should have 3 lines (3 message records)
|
||||
const contentLines = entry!.content.split("\n");
|
||||
expect(contentLines).toHaveLength(3);
|
||||
expect(contentLines[0]).toContain("User: Hello world");
|
||||
expect(contentLines[1]).toContain("Assistant: Hi there");
|
||||
expect(contentLines[2]).toContain("User: Tell me a joke");
|
||||
|
||||
// lineMap should map each content line to its original JSONL line (1-indexed)
|
||||
// Content line 0 → JSONL line 4 (the first user message)
|
||||
// Content line 1 → JSONL line 6 (the assistant message)
|
||||
// Content line 2 → JSONL line 7 (the second user message)
|
||||
expect(entry!.lineMap).toBeDefined();
|
||||
expect(entry!.lineMap).toEqual([4, 6, 7]);
|
||||
});
|
||||
|
||||
it("returns empty lineMap when no messages are found", async () => {
|
||||
const jsonlLines = [
|
||||
JSON.stringify({ type: "custom", customType: "model-snapshot", data: {} }),
|
||||
JSON.stringify({ type: "session-meta", agentId: "test" }),
|
||||
];
|
||||
const filePath = path.join(tmpDir, "empty-session.jsonl");
|
||||
await fs.writeFile(filePath, jsonlLines.join("\n"));
|
||||
|
||||
const entry = await buildSessionEntry(filePath);
|
||||
expect(entry).not.toBeNull();
|
||||
expect(entry!.content).toBe("");
|
||||
expect(entry!.lineMap).toEqual([]);
|
||||
});
|
||||
|
||||
it("skips blank lines and invalid JSON without breaking lineMap", async () => {
|
||||
const jsonlLines = [
|
||||
"",
|
||||
"not valid json",
|
||||
JSON.stringify({ type: "message", message: { role: "user", content: "First" } }),
|
||||
"",
|
||||
JSON.stringify({ type: "message", message: { role: "assistant", content: "Second" } }),
|
||||
];
|
||||
const filePath = path.join(tmpDir, "gaps.jsonl");
|
||||
await fs.writeFile(filePath, jsonlLines.join("\n"));
|
||||
|
||||
const entry = await buildSessionEntry(filePath);
|
||||
expect(entry).not.toBeNull();
|
||||
expect(entry!.lineMap).toEqual([3, 5]);
|
||||
});
|
||||
});
|
||||
@@ -1,131 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { resolveSessionTranscriptsDirForAgent } from "../../config/sessions/paths.js";
|
||||
import { redactSensitiveText } from "../../logging/redact.js";
|
||||
import { createSubsystemLogger } from "../../logging/subsystem.js";
|
||||
import { hashText } from "./internal.js";
|
||||
|
||||
const log = createSubsystemLogger("memory");
|
||||
|
||||
/** In-memory representation of one session transcript, prepared for indexing. */
export type SessionFileEntry = {
  /** Workspace-relative path, always "sessions/<basename>". */
  path: string;
  /** Absolute filesystem path the entry was built from. */
  absPath: string;
  /** File mtime in milliseconds, from fs.stat. */
  mtimeMs: number;
  /** File size in bytes, from fs.stat. */
  size: number;
  /** Hash of the rendered content plus lineMap, used to detect changes. */
  hash: string;
  /** Redacted "User:/Assistant:" transcript, one message per line. */
  content: string;
  /** Maps each content line (0-indexed) to its 1-indexed JSONL source line. */
  lineMap: number[];
};
||||
|
||||
export async function listSessionFilesForAgent(agentId: string): Promise<string[]> {
|
||||
const dir = resolveSessionTranscriptsDirForAgent(agentId);
|
||||
try {
|
||||
const entries = await fs.readdir(dir, { withFileTypes: true });
|
||||
return entries
|
||||
.filter((entry) => entry.isFile())
|
||||
.map((entry) => entry.name)
|
||||
.filter((name) => name.endsWith(".jsonl"))
|
||||
.map((name) => path.join(dir, name));
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export function sessionPathForFile(absPath: string): string {
|
||||
return path.join("sessions", path.basename(absPath)).replace(/\\/g, "/");
|
||||
}
|
||||
|
||||
function normalizeSessionText(value: string): string {
|
||||
return value
|
||||
.replace(/\s*\n+\s*/g, " ")
|
||||
.replace(/\s+/g, " ")
|
||||
.trim();
|
||||
}
|
||||
|
||||
export function extractSessionText(content: unknown): string | null {
|
||||
if (typeof content === "string") {
|
||||
const normalized = normalizeSessionText(content);
|
||||
return normalized ? normalized : null;
|
||||
}
|
||||
if (!Array.isArray(content)) {
|
||||
return null;
|
||||
}
|
||||
const parts: string[] = [];
|
||||
for (const block of content) {
|
||||
if (!block || typeof block !== "object") {
|
||||
continue;
|
||||
}
|
||||
const record = block as { type?: unknown; text?: unknown };
|
||||
if (record.type !== "text" || typeof record.text !== "string") {
|
||||
continue;
|
||||
}
|
||||
const normalized = normalizeSessionText(record.text);
|
||||
if (normalized) {
|
||||
parts.push(normalized);
|
||||
}
|
||||
}
|
||||
if (parts.length === 0) {
|
||||
return null;
|
||||
}
|
||||
return parts.join(" ");
|
||||
}
|
||||
|
||||
export async function buildSessionEntry(absPath: string): Promise<SessionFileEntry | null> {
|
||||
try {
|
||||
const stat = await fs.stat(absPath);
|
||||
const raw = await fs.readFile(absPath, "utf-8");
|
||||
const lines = raw.split("\n");
|
||||
const collected: string[] = [];
|
||||
const lineMap: number[] = [];
|
||||
for (let jsonlIdx = 0; jsonlIdx < lines.length; jsonlIdx++) {
|
||||
const line = lines[jsonlIdx];
|
||||
if (!line.trim()) {
|
||||
continue;
|
||||
}
|
||||
let record: unknown;
|
||||
try {
|
||||
record = JSON.parse(line);
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
!record ||
|
||||
typeof record !== "object" ||
|
||||
(record as { type?: unknown }).type !== "message"
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
const message = (record as { message?: unknown }).message as
|
||||
| { role?: unknown; content?: unknown }
|
||||
| undefined;
|
||||
if (!message || typeof message.role !== "string") {
|
||||
continue;
|
||||
}
|
||||
if (message.role !== "user" && message.role !== "assistant") {
|
||||
continue;
|
||||
}
|
||||
const text = extractSessionText(message.content);
|
||||
if (!text) {
|
||||
continue;
|
||||
}
|
||||
const safe = redactSensitiveText(text, { mode: "tools" });
|
||||
const label = message.role === "user" ? "User" : "Assistant";
|
||||
collected.push(`${label}: ${safe}`);
|
||||
lineMap.push(jsonlIdx + 1);
|
||||
}
|
||||
const content = collected.join("\n");
|
||||
return {
|
||||
path: sessionPathForFile(absPath),
|
||||
absPath,
|
||||
mtimeMs: stat.mtimeMs,
|
||||
size: stat.size,
|
||||
hash: hashText(content + "\n" + lineMap.join(",")),
|
||||
content,
|
||||
lineMap,
|
||||
};
|
||||
} catch (err) {
|
||||
log.debug(`Failed reading session file ${absPath}: ${String(err)}`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
import type { DatabaseSync } from "node:sqlite";
|
||||
|
||||
export async function loadSqliteVecExtension(params: {
|
||||
db: DatabaseSync;
|
||||
extensionPath?: string;
|
||||
}): Promise<{ ok: boolean; extensionPath?: string; error?: string }> {
|
||||
try {
|
||||
const sqliteVec = await import("sqlite-vec");
|
||||
const resolvedPath = params.extensionPath?.trim() ? params.extensionPath.trim() : undefined;
|
||||
const extensionPath = resolvedPath ?? sqliteVec.getLoadablePath();
|
||||
|
||||
params.db.enableLoadExtension(true);
|
||||
if (resolvedPath) {
|
||||
params.db.loadExtension(extensionPath);
|
||||
} else {
|
||||
sqliteVec.load(params.db);
|
||||
}
|
||||
|
||||
return { ok: true, extensionPath };
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
return { ok: false, error: message };
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
import { createRequire } from "node:module";
|
||||
import { installProcessWarningFilter } from "../../infra/warning-filter.js";
|
||||
|
||||
const require = createRequire(import.meta.url);
|
||||
|
||||
export function requireNodeSqlite(): typeof import("node:sqlite") {
|
||||
installProcessWarningFilter();
|
||||
try {
|
||||
return require("node:sqlite") as typeof import("node:sqlite");
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
// Node distributions can ship without the experimental builtin SQLite module.
|
||||
// Surface an actionable error instead of the generic "unknown builtin module".
|
||||
throw new Error(
|
||||
`SQLite support is unavailable in this Node runtime (missing node:sqlite). ${message}`,
|
||||
{ cause: err },
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
export type Tone = "ok" | "warn" | "muted";
|
||||
|
||||
export function resolveMemoryVectorState(vector: { enabled: boolean; available?: boolean }): {
|
||||
tone: Tone;
|
||||
state: "ready" | "unavailable" | "disabled" | "unknown";
|
||||
} {
|
||||
if (!vector.enabled) {
|
||||
return { tone: "muted", state: "disabled" };
|
||||
}
|
||||
if (vector.available === true) {
|
||||
return { tone: "ok", state: "ready" };
|
||||
}
|
||||
if (vector.available === false) {
|
||||
return { tone: "warn", state: "unavailable" };
|
||||
}
|
||||
return { tone: "muted", state: "unknown" };
|
||||
}
|
||||
|
||||
export function resolveMemoryFtsState(fts: { enabled: boolean; available: boolean }): {
|
||||
tone: Tone;
|
||||
state: "ready" | "unavailable" | "disabled";
|
||||
} {
|
||||
if (!fts.enabled) {
|
||||
return { tone: "muted", state: "disabled" };
|
||||
}
|
||||
return fts.available ? { tone: "ok", state: "ready" } : { tone: "warn", state: "unavailable" };
|
||||
}
|
||||
|
||||
export function resolveMemoryCacheSummary(cache: { enabled: boolean; entries?: number }): {
|
||||
tone: Tone;
|
||||
text: string;
|
||||
} {
|
||||
if (!cache.enabled) {
|
||||
return { tone: "muted", text: "cache off" };
|
||||
}
|
||||
const suffix = typeof cache.entries === "number" ? ` (${cache.entries})` : "";
|
||||
return { tone: "ok", text: `cache on${suffix}` };
|
||||
}
|
||||
|
||||
export function resolveMemoryCacheState(cache: { enabled: boolean }): {
|
||||
tone: Tone;
|
||||
state: "enabled" | "disabled";
|
||||
} {
|
||||
return cache.enabled ? { tone: "ok", state: "enabled" } : { tone: "muted", state: "disabled" };
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
import { vi } from "vitest";
|
||||
import * as ssrf from "../../../infra/net/ssrf.js";
|
||||
|
||||
export function mockPublicPinnedHostname() {
|
||||
return vi.spyOn(ssrf, "resolvePinnedHostnameWithPolicy").mockImplementation(async (hostname) => {
|
||||
const normalized = hostname.trim().toLowerCase().replace(/\.$/, "");
|
||||
const addresses = ["93.184.216.34"];
|
||||
return {
|
||||
hostname: normalized,
|
||||
addresses,
|
||||
lookup: ssrf.createPinnedLookup({ hostname: normalized, addresses }),
|
||||
};
|
||||
});
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
/** Where indexed content originated: curated memory files or session transcripts. */
export type MemorySource = "memory" | "sessions";

/** One search hit returned by `MemorySearchManager.search`. */
export type MemorySearchResult = {
  /** Workspace-relative path of the matched document. */
  path: string;
  /** Start of the matched span — NOTE(review): presumably 1-indexed; confirm against backend. */
  startLine: number;
  /** End of the matched span (same indexing as startLine). */
  endLine: number;
  /** Match score; scale is backend-specific. */
  score: number;
  /** Text excerpt for the match. */
  snippet: string;
  source: MemorySource;
  /** Optional preformatted citation string for this hit. */
  citation?: string;
};

/** Result of probing whether embeddings can currently be produced. */
export type MemoryEmbeddingProbeResult = {
  ok: boolean;
  /** Failure detail — presumably set only when `ok` is false. */
  error?: string;
};

/** Progress payload passed to the `progress` callback during `sync`. */
export type MemorySyncProgressUpdate = {
  completed: number;
  total: number;
  /** Optional human-readable label for the current step. */
  label?: string;
};
|
||||
|
||||
/**
 * Snapshot of a memory backend's configuration and index health, returned by
 * `MemorySearchManager.status()`. Most fields are optional because backends
 * expose different levels of detail.
 */
export type MemoryProviderStatus = {
  /** Which backend implementation produced this status. */
  backend: "builtin" | "qmd";
  /** Provider actually in use. */
  provider: string;
  model?: string;
  /** Provider originally requested — see `fallback` when it differs from `provider`. */
  requestedProvider?: string;
  files?: number;
  chunks?: number;
  /** Whether the index needs a resync — inferred from the name; confirm against backend. */
  dirty?: boolean;
  workspaceDir?: string;
  dbPath?: string;
  extraPaths?: string[];
  sources?: MemorySource[];
  /** Per-source breakdown of file/chunk counts. */
  sourceCounts?: Array<{ source: MemorySource; files: number; chunks: number }>;
  cache?: { enabled: boolean; entries?: number; maxEntries?: number };
  /** Full-text search state; `error` explains an unavailable index. */
  fts?: { enabled: boolean; available: boolean; error?: string };
  /** Present when the configured provider was replaced at runtime. */
  fallback?: { from: string; reason?: string };
  /** Vector-search (embedding index) state. */
  vector?: {
    enabled: boolean;
    available?: boolean;
    /** Path of the loaded native extension, when applicable. */
    extensionPath?: string;
    loadError?: string;
    /** Embedding dimensionality, when known. */
    dims?: number;
  };
  /** Embedding batch-processing state. */
  batch?: {
    enabled: boolean;
    failures: number;
    limit: number;
    wait: boolean;
    concurrency: number;
    pollIntervalMs: number;
    timeoutMs: number;
    lastError?: string;
    lastProvider?: string;
  };
  /** Backend-specific extras not covered by the common fields. */
  custom?: Record<string, unknown>;
};
||||
|
||||
/**
 * Common interface implemented by memory backends (builtin and qmd).
 * Optional members (`sync`, `close`) are only provided by backends that
 * support them.
 */
export interface MemorySearchManager {
  /** Searches the index; `opts` caps result count and filters by score. */
  search(
    query: string,
    opts?: { maxResults?: number; minScore?: number; sessionKey?: string },
  ): Promise<MemorySearchResult[]>;
  /** Reads a slice of an indexed file by its workspace-relative path. */
  readFile(params: {
    relPath: string;
    from?: number;
    lines?: number;
  }): Promise<{ text: string; path: string }>;
  /** Returns a point-in-time status snapshot. */
  status(): MemoryProviderStatus;
  /** Optional: re-indexes sources, reporting progress via the callback. */
  sync?(params?: {
    reason?: string;
    force?: boolean;
    sessionFiles?: string[];
    progress?: (update: MemorySyncProgressUpdate) => void;
  }): Promise<void>;
  /** Checks whether embeddings can currently be produced. */
  probeEmbeddingAvailability(): Promise<MemoryEmbeddingProbeResult>;
  /** Checks whether vector search is currently available. */
  probeVectorAvailability(): Promise<boolean>;
  /** Optional: releases backend resources. */
  close?(): Promise<void>;
}
|
||||
@@ -4,7 +4,7 @@ import type {
|
||||
MemoryEmbeddingProbeResult,
|
||||
MemoryProviderStatus,
|
||||
MemorySyncProgressUpdate,
|
||||
} from "../plugins/memory-host/types.js";
|
||||
} from "../plugin-sdk/memory-core-host-engine-storage.js";
|
||||
|
||||
export type MemoryPromptSectionBuilder = (params: {
|
||||
availableTools: Set<string>;
|
||||
|
||||
Reference in New Issue
Block a user