refactor: move memory embeddings into provider plugins

Peter Steinberger
2026-04-17 01:31:39 +01:00
parent 7e9ff0f86e
commit 77e6e4cf87
94 changed files with 1039 additions and 7125 deletions

View File

@@ -318,7 +318,7 @@ Current bundled provider examples:
| `plugin-sdk/memory-core` | Bundled memory-core helpers | Memory manager/config/file/CLI helper surface |
| `plugin-sdk/memory-core-engine-runtime` | Memory engine runtime facade | Memory index/search runtime facade |
| `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine | Memory host foundation engine exports |
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory host embedding engine exports |
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory embedding contracts, registry access, local provider, and generic batch/remote helpers; concrete remote providers live in their owning plugins |
| `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine | Memory host QMD engine exports |
| `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine | Memory host storage engine exports |
| `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers | Memory host multimodal helpers |
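
For orientation, a minimal sketch of the per-plugin surface this table now describes: the SDK entry point keeps the embedding contracts and shared helpers, while each provider plugin owns and registers its own adapter. The `acme` id, model name, and `embedWithAcme` helper below are hypothetical; the field names mirror the bedrock/gemini/mistral adapters added in this commit.

```ts
import {
  isMissingEmbeddingApiKeyError,
  type MemoryEmbeddingProvider,
  type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";

// Hypothetical remote embedding call owned by the plugin.
async function embedWithAcme(texts: string[]): Promise<number[][]> {
  return texts.map(() => [0, 0, 0]);
}

export const acmeMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
  id: "acme",
  defaultModel: "acme-embed-v1",
  transport: "remote",
  authProviderId: "acme",
  autoSelectPriority: 70,
  allowExplicitWhenConfiguredAuto: true,
  shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
  create: async (options) => {
    const model = options.model?.trim() || "acme-embed-v1";
    const provider: MemoryEmbeddingProvider = {
      id: "acme",
      model,
      embedQuery: async (text) => (await embedWithAcme([text]))[0] ?? [],
      embedBatch: (texts) => embedWithAcme(texts),
    };
    return {
      provider,
      runtime: {
        id: "acme",
        // Presumably feeds the embedding cache key, like the bundled adapters.
        cacheKeyData: { provider: "acme", model },
      },
    };
  },
};
```

The plugin entry then calls `api.registerMemoryEmbeddingProvider(acmeMemoryEmbeddingProviderAdapter)` and declares `"contracts": { "memoryEmbeddingProviders": ["acme"] }` in its manifest, as the diffs below do for the bundled providers.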

View File

@@ -264,7 +264,7 @@ explicitly promotes one as public.
| `plugin-sdk/memory-core` | Bundled memory-core helper surface for manager/config/file/CLI helpers |
| `plugin-sdk/memory-core-engine-runtime` | Memory index/search runtime facade |
| `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine exports |
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine exports |
| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding contracts, registry access, local provider, and generic batch/remote helpers |
| `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine exports |
| `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine exports |
| `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers |

View File

@@ -1,7 +1,10 @@
import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
import {
debugEmbeddingsLog,
sanitizeAndNormalizeEmbedding,
type MemoryEmbeddingProvider,
type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
// ---------------------------------------------------------------------------
// Types & constants
@@ -254,8 +257,8 @@ function parseCohereBatch(family: Family, raw: string): number[][] {
// ---------------------------------------------------------------------------
export async function createBedrockEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> {
options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: BedrockEmbeddingClient }> {
const client = resolveBedrockEmbeddingClient(options);
const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk();
const sdk = new BedrockRuntimeClient({ region: client.region });
@@ -333,7 +336,7 @@ export async function createBedrockEmbeddingProvider(
// ---------------------------------------------------------------------------
export function resolveBedrockEmbeddingClient(
options: EmbeddingProviderOptions,
options: MemoryEmbeddingProviderCreateOptions,
): BedrockEmbeddingClient {
const model = normalizeBedrockEmbeddingModel(options.model);
const spec = resolveSpec(model);

View File

@@ -0,0 +1,37 @@
import {
isMissingEmbeddingApiKeyError,
type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import {
createBedrockEmbeddingProvider,
DEFAULT_BEDROCK_EMBEDDING_MODEL,
} from "./embedding-provider.js";
export const bedrockMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
id: "bedrock",
defaultModel: DEFAULT_BEDROCK_EMBEDDING_MODEL,
transport: "remote",
authProviderId: "amazon-bedrock",
autoSelectPriority: 60,
allowExplicitWhenConfiguredAuto: true,
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
create: async (options) => {
const { provider, client } = await createBedrockEmbeddingProvider({
...options,
provider: "bedrock",
fallback: "none",
});
return {
provider,
runtime: {
id: "bedrock",
cacheKeyData: {
provider: "bedrock",
region: client.region,
model: client.model,
dimensions: client.dimensions,
},
},
};
},
};
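
Roughly how the host consumes an adapter like the one above once it is registered: the registry hands back the adapter, `create` builds the provider plus runtime metadata, and the memory index calls `embedQuery`/`embedBatch`. This is a simplified sketch; the real call path goes through the SDK registry with a full options object, so the casts below stand in for config the host would supply.

```ts
import { bedrockMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";

// Simplified hand-off: the host resolves adapters via the shared registry
// (listMemoryEmbeddingProviders / getMemoryEmbeddingProvider) rather than
// importing one directly, and supplies real config instead of this cast.
const { provider, runtime } = await bedrockMemoryEmbeddingProviderAdapter.create({
  config: {} as never,
  model: bedrockMemoryEmbeddingProviderAdapter.defaultModel,
} as never);

const queryVector = await provider.embedQuery("how do I rotate API keys?");
const chunkVectors = await provider.embedBatch(["chunk one", "chunk two"]);

// runtime.cacheKeyData (provider/region/model/dimensions) presumably keys the
// embedding cache so vectors from different models are never mixed.
console.log(runtime.id, queryVector.length, chunkVectors.length);
```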

View File

@@ -2,6 +2,9 @@
"id": "amazon-bedrock",
"enabledByDefault": true,
"providers": ["amazon-bedrock"],
"contracts": {
"memoryEmbeddingProviders": ["bedrock"]
},
"configSchema": {
"type": "object",
"additionalProperties": false,

View File

@@ -5,7 +5,9 @@
"description": "OpenClaw Amazon Bedrock provider plugin",
"type": "module",
"dependencies": {
"@aws-sdk/client-bedrock": "3.1028.0"
"@aws-sdk/client-bedrock": "3.1028.0",
"@aws-sdk/client-bedrock-runtime": "3.1028.0",
"@aws-sdk/credential-provider-node": "3.972.30"
},
"devDependencies": {
"@openclaw/plugin-sdk": "workspace:*"

View File

@@ -14,6 +14,7 @@ import {
resolveBedrockConfigApiKey,
resolveImplicitBedrockProvider,
} from "./api.js";
import { bedrockMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
type GuardrailConfig = {
guardrailIdentifier: string;
@@ -78,6 +79,8 @@ export function registerAmazonBedrockPlugin(api: OpenClawPluginApi): void {
const pluginConfig = (api.pluginConfig ?? {}) as AmazonBedrockPluginConfig;
const guardrail = pluginConfig.guardrail;
api.registerMemoryEmbeddingProvider(bedrockMemoryEmbeddingProviderAdapter);
const baseWrapStreamFn = ({ modelId, streamFn }: { modelId: string; streamFn?: StreamFn }) =>
isAnthropicBedrockModel(modelId) ? streamFn : createBedrockNoCacheWrapper(streamFn);

View File

@@ -4,7 +4,6 @@ const resolveFirstGithubTokenMock = vi.hoisted(() => vi.fn());
const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn());
const resolveConfiguredSecretInputStringMock = vi.hoisted(() => vi.fn());
const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn());
const createGitHubCopilotEmbeddingProviderMock = vi.hoisted(() => vi.fn());
vi.mock("./auth.js", () => ({
resolveFirstGithubToken: resolveFirstGithubTokenMock,
@@ -19,10 +18,6 @@ vi.mock("openclaw/plugin-sdk/github-copilot-token", () => ({
resolveCopilotApiToken: resolveCopilotApiTokenMock,
}));
vi.mock("openclaw/plugin-sdk/memory-core-host-engine-embeddings", () => ({
createGitHubCopilotEmbeddingProvider: createGitHubCopilotEmbeddingProviderMock,
}));
vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
}));
@@ -73,15 +68,6 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
source: "test",
baseUrl: TEST_BASE_URL,
});
createGitHubCopilotEmbeddingProviderMock.mockImplementation(async (client) => ({
provider: {
id: "github-copilot",
model: client.model,
embedQuery: async () => [0.1, 0.2, 0.3],
embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2, 0.3]),
},
client,
}));
});
afterEach(() => {
@@ -89,7 +75,6 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
resolveConfiguredSecretInputStringMock.mockReset();
resolveFirstGithubTokenMock.mockReset();
resolveCopilotApiTokenMock.mockReset();
createGitHubCopilotEmbeddingProviderMock.mockReset();
fetchWithSsrFGuardMock.mockReset();
});
@@ -113,12 +98,8 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
const result = await githubCopilotMemoryEmbeddingProviderAdapter.create(defaultCreateOptions());
expect(result.provider?.model).toBe("text-embedding-3-small");
expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith(
expect.objectContaining({
baseUrl: TEST_BASE_URL,
githubToken: "gh_test_token_123",
model: "text-embedding-3-small",
}),
expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith(
expect.objectContaining({ githubToken: "gh_test_token_123" }),
);
});
@@ -217,14 +198,12 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
} as never);
expect(resolveFirstGithubTokenMock).toHaveBeenCalled();
expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith({
baseUrl: "https://proxy.example/v1",
env: process.env,
fetchImpl: fetch,
githubToken: "gh_remote_token",
headers: { "X-Proxy-Token": "proxy" },
model: "text-embedding-3-small",
});
expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith(
expect.objectContaining({
env: process.env,
githubToken: "gh_remote_token",
}),
);
const discoveryCall = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as {
init: { headers: Record<string, string> };

View File

@@ -4,7 +4,10 @@ import {
resolveCopilotApiToken,
} from "openclaw/plugin-sdk/github-copilot-token";
import {
createGitHubCopilotEmbeddingProvider,
buildRemoteBaseUrlPolicy,
sanitizeAndNormalizeEmbedding,
withRemoteHttpResponse,
type MemoryEmbeddingProvider,
type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { fetchWithSsrFGuard, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
@@ -44,6 +47,15 @@ type CopilotModelEntry = {
supported_endpoints?: unknown;
};
type GitHubCopilotEmbeddingClient = {
githubToken: string;
model: string;
baseUrl?: string;
headers?: Record<string, string>;
env?: NodeJS.ProcessEnv;
fetchImpl?: typeof fetch;
};
function isCopilotSetupError(err: unknown): boolean {
if (!(err instanceof Error)) {
return false;
@@ -147,9 +159,126 @@ function pickBestModel(available: string[], userModel?: string): string {
throw new Error("No embedding models available from GitHub Copilot");
}
function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] {
if (!payload || typeof payload !== "object") {
throw new Error("GitHub Copilot embeddings response missing data[]");
}
const data = (payload as { data?: unknown }).data;
if (!Array.isArray(data)) {
throw new Error("GitHub Copilot embeddings response missing data[]");
}
const vectors = Array.from<number[] | undefined>({ length: expectedCount });
for (const entry of data) {
if (!entry || typeof entry !== "object") {
throw new Error("GitHub Copilot embeddings response contains an invalid entry");
}
const indexValue = (entry as { index?: unknown }).index;
const embedding = (entry as { embedding?: unknown }).embedding;
const index = typeof indexValue === "number" ? indexValue : Number.NaN;
if (!Number.isInteger(index)) {
throw new Error("GitHub Copilot embeddings response contains an invalid index");
}
if (index < 0 || index >= expectedCount) {
throw new Error("GitHub Copilot embeddings response contains an out-of-range index");
}
if (vectors[index] !== undefined) {
throw new Error("GitHub Copilot embeddings response contains duplicate indexes");
}
if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) {
throw new Error("GitHub Copilot embeddings response contains an invalid embedding");
}
vectors[index] = sanitizeAndNormalizeEmbedding(embedding);
}
for (let index = 0; index < expectedCount; index += 1) {
if (vectors[index] === undefined) {
throw new Error("GitHub Copilot embeddings response missing vectors for some inputs");
}
}
return vectors as number[][];
}
async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{
baseUrl: string;
headers: Record<string, string>;
}> {
const token = await resolveCopilotApiToken({
githubToken: client.githubToken,
env: client.env,
fetchImpl: client.fetchImpl,
});
const baseUrl = client.baseUrl?.trim() || token.baseUrl || DEFAULT_COPILOT_API_BASE_URL;
return {
baseUrl,
headers: {
...COPILOT_HEADERS_STATIC,
...client.headers,
Authorization: `Bearer ${token.token}`,
},
};
}
async function createGitHubCopilotEmbeddingProvider(
client: GitHubCopilotEmbeddingClient,
): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
const embed = async (input: string[]): Promise<number[][]> => {
if (input.length === 0) {
return [];
}
const session = await resolveGitHubCopilotEmbeddingSession(client);
const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`;
return await withRemoteHttpResponse({
url,
fetchImpl: client.fetchImpl,
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
init: {
method: "POST",
headers: session.headers,
body: JSON.stringify({ model: client.model, input }),
},
onResponse: async (response) => {
if (!response.ok) {
throw new Error(
`GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`,
);
}
let payload: unknown;
try {
payload = await response.json();
} catch {
throw new Error("GitHub Copilot embeddings returned invalid JSON");
}
return parseGitHubCopilotEmbeddingPayload(payload, input.length);
},
});
};
return {
provider: {
id: COPILOT_EMBEDDING_PROVIDER_ID,
model: client.model,
embedQuery: async (text) => {
const [vector] = await embed([text]);
return vector ?? [];
},
embedBatch: embed,
},
client: {
...client,
baseUrl: initialSession.baseUrl,
},
};
}
export const githubCopilotMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
id: COPILOT_EMBEDDING_PROVIDER_ID,
transport: "remote",
authProviderId: COPILOT_EMBEDDING_PROVIDER_ID,
autoSelectPriority: 15,
allowExplicitWhenConfiguredAuto: true,
shouldContinueAutoSelection: (err: unknown) => isCopilotSetupError(err),

View File

@@ -1,14 +1,15 @@
import crypto from "node:crypto";
import {
buildEmbeddingBatchGroupOptions,
runEmbeddingBatchGroups,
type EmbeddingBatchExecutionParams,
} from "./batch-runner.js";
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js";
import { hashText } from "./internal.js";
import { withRemoteHttpResponse } from "./remote-http.js";
buildBatchHeaders,
debugEmbeddingsLog,
normalizeBatchBaseUrl,
sanitizeAndNormalizeEmbedding,
withRemoteHttpResponse,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embedding-provider.js";
export type GeminiBatchRequest = {
custom_id: string;
@@ -40,6 +41,10 @@ export type GeminiBatchOutputLine = {
};
const GEMINI_BATCH_MAX_REQUESTS = 50000;
function hashText(text: string): string {
return crypto.createHash("sha256").update(text).digest("hex");
}
function getGeminiUploadUrl(baseUrl: string): string {
if (baseUrl.includes("/v1beta")) {
return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");

View File

@@ -1,84 +1,40 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../agents/model-auth.js";
import { afterEach, describe, expect, it, vi } from "vitest";
import {
buildGeminiEmbeddingRequest,
buildGeminiTextEmbeddingRequest,
createGeminiEmbeddingProvider,
DEFAULT_GEMINI_EMBEDDING_MODEL,
GEMINI_EMBEDDING_2_MODELS,
isGeminiEmbedding2Model,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
} from "./embeddings-gemini-request.js";
import {
createGeminiBatchFetchMock,
createJsonResponseFetchMock,
installFetchMock,
mockResolvedProviderKey,
parseFetchBody,
readFirstFetchRequest,
type JsonFetchMock,
} from "./embeddings-provider.test-support.js";
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
resolveApiKeyForProviderMock: vi.fn(),
}));
vi.mock("../../agents/model-auth.js", () => {
return {
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
if (auth.apiKey) {
return auth.apiKey;
}
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
},
};
});
vi.mock("../../agents/api-key-rotation.js", () => ({
collectProviderApiKeysForExecution: (params: { primaryApiKey?: string }) =>
params.primaryApiKey ? [params.primaryApiKey] : [],
executeWithApiKeyRotation: async <T>(params: {
apiKeys: string[];
execute: (apiKey: string) => Promise<T>;
}) => {
const apiKey = params.apiKeys[0];
if (!apiKey) {
throw new Error('No API keys configured for provider "google".');
}
return await params.execute(apiKey);
},
}));
beforeEach(() => {
vi.useRealTimers();
vi.doUnmock("undici");
});
} from "./embedding-provider.js";
afterEach(() => {
vi.doUnmock("undici");
vi.resetAllMocks();
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
type GeminiProviderOptions = Parameters<
typeof import("./embeddings-gemini.js").createGeminiEmbeddingProvider
>[0];
async function createProviderWithFetch(
fetchMock: JsonFetchMock,
options: Partial<GeminiProviderOptions> & { model: string },
) {
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
const { createGeminiEmbeddingProvider } = await import("./embeddings-gemini.js");
const { provider } = await createGeminiEmbeddingProvider({
config: {} as never,
provider: "gemini",
fallback: "none",
...options,
function installFetchMock(
handler: (input: RequestInfo | URL, init?: RequestInit) => unknown,
): ReturnType<typeof vi.fn> {
const fetchMock = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => {
return new Response(JSON.stringify(handler(input, init)), {
status: 200,
headers: { "Content-Type": "application/json" },
});
});
return provider;
vi.stubGlobal("fetch", fetchMock);
return fetchMock;
}
function fetchJsonBody(fetchMock: ReturnType<typeof vi.fn>, index: number): unknown {
const init = fetchMock.mock.calls[index]?.[1] as RequestInit | undefined;
const body = init?.body;
if (typeof body !== "string") {
throw new Error("Expected JSON string request body.");
}
return JSON.parse(body) as unknown;
}
describe("Gemini embedding request helpers", () => {
@@ -149,24 +105,9 @@ describe("Gemini embedding request helpers", () => {
});
});
describe("gemini embedding provider", () => {
describe("Gemini embedding provider", () => {
it("handles legacy and v2 request/response behavior", async () => {
const legacyFetch = createGeminiBatchFetchMock(2);
const legacyProvider = await createProviderWithFetch(legacyFetch, {
model: "gemini-embedding-001",
});
await legacyProvider.embedQuery("test query");
await legacyProvider.embedBatch(["text1", "text2"]);
expect(parseFetchBody(legacyFetch, 0)).toMatchObject({
taskType: "RETRIEVAL_QUERY",
content: { parts: [{ text: "test query" }] },
});
expect(parseFetchBody(legacyFetch, 0)).not.toHaveProperty("outputDimensionality");
expect(parseFetchBody(legacyFetch, 1)).not.toHaveProperty("outputDimensionality");
const v2Fetch = createJsonResponseFetchMock((input) => {
const fetchMock = installFetchMock((input) => {
const url = input instanceof URL ? input.href : typeof input === "string" ? input : input.url;
return url.endsWith(":batchEmbedContents")
? {
@@ -176,16 +117,22 @@ describe("gemini embedding provider", () => {
}
: { embedding: { values: [3, 4, Number.NaN] } };
});
const v2Provider = await createProviderWithFetch(v2Fetch, {
const { provider } = await createGeminiEmbeddingProvider({
config: {} as never,
provider: "gemini",
remote: { apiKey: "test-key" },
model: "gemini-embedding-2-preview",
outputDimensionality: 768,
taskType: "SEMANTIC_SIMILARITY",
fallback: "none",
});
await expect(v2Provider.embedQuery(" ")).resolves.toEqual([]);
await expect(v2Provider.embedBatch([])).resolves.toEqual([]);
await expect(v2Provider.embedQuery("test query")).resolves.toEqual([0.6, 0.8, 0]);
const structuredBatch = await v2Provider.embedBatchInputs?.([
await expect(provider.embedQuery(" ")).resolves.toEqual([]);
await expect(provider.embedBatch([])).resolves.toEqual([]);
await expect(provider.embedQuery("test query")).resolves.toEqual([0.6, 0.8, 0]);
const structuredBatch = await provider.embedBatchInputs?.([
{
text: "Image file: diagram.png",
parts: [
@@ -206,38 +153,39 @@ describe("gemini embedding provider", () => {
[0, 0, 1],
]);
const { url } = readFirstFetchRequest(v2Fetch);
expect(url).toBe(
expect(fetchMock.mock.calls[0]?.[0]).toBe(
"https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:embedContent",
);
expect(parseFetchBody(v2Fetch, 0)).toMatchObject({
expect(fetchJsonBody(fetchMock, 0)).toMatchObject({
outputDimensionality: 768,
taskType: "SEMANTIC_SIMILARITY",
content: { parts: [{ text: "test query" }] },
});
expect(parseFetchBody(v2Fetch, 1).requests).toEqual([
{
model: "models/gemini-embedding-2-preview",
content: {
parts: [
{ text: "Image file: diagram.png" },
{ inlineData: { mimeType: "image/png", data: "img" } },
],
expect(fetchJsonBody(fetchMock, 1)).toMatchObject({
requests: [
{
model: "models/gemini-embedding-2-preview",
content: {
parts: [
{ text: "Image file: diagram.png" },
{ inlineData: { mimeType: "image/png", data: "img" } },
],
},
taskType: "SEMANTIC_SIMILARITY",
outputDimensionality: 768,
},
taskType: "SEMANTIC_SIMILARITY",
outputDimensionality: 768,
},
{
model: "models/gemini-embedding-2-preview",
content: {
parts: [
{ text: "Audio file: note.wav" },
{ inlineData: { mimeType: "audio/wav", data: "aud" } },
],
{
model: "models/gemini-embedding-2-preview",
content: {
parts: [
{ text: "Audio file: note.wav" },
{ inlineData: { mimeType: "audio/wav", data: "aud" } },
],
},
taskType: "SEMANTIC_SIMILARITY",
outputDimensionality: 768,
},
taskType: "SEMANTIC_SIMILARITY",
outputDimensionality: 768,
},
]);
],
});
});
});

View File

@@ -1,44 +1,22 @@
import { parseGeminiAuth } from "openclaw/plugin-sdk/image-generation-core";
import {
buildRemoteBaseUrlPolicy,
debugEmbeddingsLog,
sanitizeAndNormalizeEmbedding,
withRemoteHttpResponse,
type EmbeddingInput,
type MemoryEmbeddingProvider,
type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { resolveMemorySecretInputString } from "openclaw/plugin-sdk/memory-core-host-secret";
import {
collectProviderApiKeysForExecution,
executeWithApiKeyRotation,
} from "../../agents/api-key-rotation.js";
import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js";
import { parseGeminiAuth } from "../../infra/gemini-auth.js";
import {
DEFAULT_GOOGLE_API_BASE_URL,
normalizeGoogleApiBaseUrl,
} from "../../infra/google-api-base-url.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { normalizeOptionalString } from "../../shared/string-coerce.js";
import type { EmbeddingInput } from "./embedding-inputs.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import {
buildGeminiEmbeddingRequest,
buildGeminiTextEmbeddingRequest,
isGeminiEmbedding2Model,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
} from "./embeddings-gemini-request.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";
export {
buildGeminiEmbeddingRequest,
buildGeminiTextEmbeddingRequest,
DEFAULT_GEMINI_EMBEDDING_MODEL,
GEMINI_EMBEDDING_2_MODELS,
isGeminiEmbedding2Model,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
type GeminiEmbeddingRequest,
type GeminiInlinePart,
type GeminiPart,
type GeminiTaskType,
type GeminiTextEmbeddingRequest,
type GeminiTextPart,
} from "./embeddings-gemini-request.js";
requireApiKey,
resolveApiKeyForProvider,
} from "openclaw/plugin-sdk/provider-auth-runtime";
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
export type GeminiEmbeddingClient = {
baseUrl: string;
@@ -50,9 +28,111 @@ export type GeminiEmbeddingClient = {
outputDimensionality?: number;
};
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
const DEFAULT_GOOGLE_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
"text-embedding-004": 2048,
"gemini-embedding-001": 2048,
"gemini-embedding-2-preview": 8192,
};
export type GeminiTaskType = NonNullable<MemoryEmbeddingProviderCreateOptions["taskType"]>;
// --- gemini-embedding-2-preview support ---
export const GEMINI_EMBEDDING_2_MODELS = new Set([
"gemini-embedding-2-preview",
// Add the GA model name here once released.
]);
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
export type GeminiTextPart = { text: string };
export type GeminiInlinePart = {
inlineData: { mimeType: string; data: string };
};
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
export type GeminiEmbeddingRequest = {
content: { parts: GeminiPart[] };
taskType: GeminiTaskType;
outputDimensionality?: number;
model?: string;
};
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
export function buildGeminiTextEmbeddingRequest(params: {
text: string;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiTextEmbeddingRequest {
return buildGeminiEmbeddingRequest({
input: { text: params.text },
taskType: params.taskType,
outputDimensionality: params.outputDimensionality,
modelPath: params.modelPath,
});
}
export function buildGeminiEmbeddingRequest(params: {
input: EmbeddingInput;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiEmbeddingRequest {
const request: GeminiEmbeddingRequest = {
content: {
parts: params.input.parts?.map((part) =>
part.type === "text"
? ({ text: part.text } satisfies GeminiTextPart)
: ({
inlineData: { mimeType: part.mimeType, data: part.data },
} satisfies GeminiInlinePart),
) ?? [{ text: params.input.text }],
},
taskType: params.taskType,
};
if (params.modelPath) {
request.model = params.modelPath;
}
if (params.outputDimensionality != null) {
request.outputDimensionality = params.outputDimensionality;
}
return request;
}
/**
* Returns true if the given model name is a gemini-embedding-2 variant that
* supports `outputDimensionality` and extended task types.
*/
export function isGeminiEmbedding2Model(model: string): boolean {
return GEMINI_EMBEDDING_2_MODELS.has(model);
}
/**
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
* Returns `undefined` for older models (they don't support the param).
*/
export function resolveGeminiOutputDimensionality(
model: string,
requested?: number,
): number | undefined {
if (!isGeminiEmbedding2Model(model)) {
return undefined;
}
if (requested == null) {
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
}
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
if (!valid.includes(requested)) {
throw new Error(
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
);
}
return requested;
}
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
const trimmed = resolveMemorySecretInputString({
value: remoteApiKey,
@@ -67,6 +147,21 @@ function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
return trimmed;
}
export function normalizeGeminiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) {
return DEFAULT_GEMINI_EMBEDDING_MODEL;
}
const withoutPrefix = trimmed.replace(/^models\//, "");
if (withoutPrefix.startsWith("gemini/")) {
return withoutPrefix.slice("gemini/".length);
}
if (withoutPrefix.startsWith("google/")) {
return withoutPrefix.slice("google/".length);
}
return withoutPrefix;
}
async function fetchGeminiEmbeddingPayload(params: {
client: GeminiEmbeddingClient;
endpoint: string;
@@ -120,9 +215,30 @@ function buildGeminiModelPath(model: string): string {
return model.startsWith("models/") ? model : `models/${model}`;
}
function normalizeGoogleApiBaseUrl(baseUrl: string): string {
const trimmed = baseUrl.trim().replace(/\/+$/, "");
if (!trimmed) {
return DEFAULT_GOOGLE_API_BASE_URL;
}
try {
const url = new URL(trimmed);
url.hash = "";
url.search = "";
if (
url.origin.toLowerCase() === "https://generativelanguage.googleapis.com" &&
url.pathname.replace(/\/+$/, "") === ""
) {
url.pathname = "/v1beta";
}
return url.toString().replace(/\/+$/, "");
} catch {
return trimmed;
}
}
export async function createGeminiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: GeminiEmbeddingClient }> {
const client = await resolveGeminiEmbeddingClient(options);
const baseUrl = client.baseUrl.replace(/\/$/, "");
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
@@ -190,7 +306,7 @@ export async function createGeminiEmbeddingProvider(
}
export async function resolveGeminiEmbeddingClient(
options: EmbeddingProviderOptions,
options: MemoryEmbeddingProviderCreateOptions,
): Promise<GeminiEmbeddingClient> {
const remote = options.remote;
const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
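
A quick illustration of the dimensionality and model-normalization rules that moved into this plugin-local module (expected results follow from the constants and checks in the diff above; the relative import path is the one used by the new Gemini adapter):

```ts
import {
  normalizeGeminiModel,
  resolveGeminiOutputDimensionality,
} from "./embedding-provider.js";

// gemini-embedding-2 models accept 768 / 1536 / 3072 and default to 3072.
resolveGeminiOutputDimensionality("gemini-embedding-2-preview");      // 3072
resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768); // 768

// Older models do not support the parameter, so it is dropped entirely.
resolveGeminiOutputDimensionality("gemini-embedding-001", 768);       // undefined

// Out-of-range values throw rather than being silently coerced.
// resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512); // Error

// Model ids are normalized by stripping "models/", "gemini/", and "google/" prefixes;
// blank input falls back to DEFAULT_GEMINI_EMBEDDING_MODEL.
normalizeGeminiModel("models/gemini-embedding-001");       // "gemini-embedding-001"
normalizeGeminiModel("google/gemini-embedding-2-preview"); // "gemini-embedding-2-preview"
normalizeGeminiModel("   ");                               // "gemini-embedding-001"
```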

View File

@@ -3,6 +3,7 @@ import type { MediaUnderstandingProvider } from "openclaw/plugin-sdk/media-under
import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
import { buildGoogleGeminiCliBackend } from "./cli-backend.js";
import { registerGoogleGeminiCliProvider } from "./gemini-cli-provider.js";
import { geminiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
import { buildGoogleMusicGenerationProvider } from "./music-generation-provider.js";
import { registerGoogleProvider } from "./provider-registration.js";
import { buildGoogleSpeechProvider } from "./speech-provider.js";
@@ -111,6 +112,7 @@ export default definePluginEntry({
api.registerCliBackend(buildGoogleGeminiCliBackend());
registerGoogleGeminiCliProvider(api);
registerGoogleProvider(api);
api.registerMemoryEmbeddingProvider(geminiMemoryEmbeddingProviderAdapter);
api.registerImageGenerationProvider(createLazyGoogleImageGenerationProvider());
api.registerMediaUnderstandingProvider(createLazyGoogleMediaUnderstandingProvider());
api.registerMusicGenerationProvider(buildGoogleMusicGenerationProvider());

View File

@@ -0,0 +1,79 @@
import {
hasNonTextEmbeddingParts,
isMissingEmbeddingApiKeyError,
mapBatchEmbeddingsByIndex,
sanitizeEmbeddingCacheHeaders,
type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { runGeminiEmbeddingBatches } from "./embedding-batch.js";
import {
buildGeminiEmbeddingRequest,
createGeminiEmbeddingProvider,
DEFAULT_GEMINI_EMBEDDING_MODEL,
} from "./embedding-provider.js";
function supportsGeminiMultimodalEmbeddings(model: string): boolean {
const normalized = model
.trim()
.replace(/^models\//, "")
.replace(/^(gemini|google)\//, "");
return normalized === "gemini-embedding-2-preview";
}
export const geminiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
id: "gemini",
defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
transport: "remote",
authProviderId: "google",
autoSelectPriority: 30,
allowExplicitWhenConfiguredAuto: true,
supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model),
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
create: async (options) => {
const { provider, client } = await createGeminiEmbeddingProvider({
...options,
provider: "gemini",
fallback: "none",
});
return {
provider,
runtime: {
id: "gemini",
cacheKeyData: {
provider: "gemini",
baseUrl: client.baseUrl,
model: client.model,
outputDimensionality: client.outputDimensionality,
headers: sanitizeEmbeddingCacheHeaders(client.headers, [
"authorization",
"x-goog-api-key",
]),
},
batchEmbed: async (batch) => {
if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) {
return null;
}
const byCustomId = await runGeminiEmbeddingBatches({
gemini: client,
agentId: batch.agentId,
requests: batch.chunks.map((chunk, index) => ({
custom_id: String(index),
request: buildGeminiEmbeddingRequest({
input: chunk.embeddingInput ?? { text: chunk.text },
taskType: "RETRIEVAL_DOCUMENT",
modelPath: client.modelPath,
outputDimensionality: client.outputDimensionality,
}),
})),
wait: batch.wait,
concurrency: batch.concurrency,
pollIntervalMs: batch.pollIntervalMs,
timeoutMs: batch.timeoutMs,
debug: batch.debug,
});
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
},
},
};
},
};

View File

@@ -46,6 +46,7 @@
},
"contracts": {
"mediaUnderstandingProviders": ["google"],
"memoryEmbeddingProviders": ["gemini"],
"imageGenerationProviders": ["google"],
"musicGenerationProviders": ["google"],
"speechProviders": ["google"],

View File

@@ -8,6 +8,7 @@ import {
type ProviderRuntimeModel,
} from "openclaw/plugin-sdk/plugin-entry";
import { CUSTOM_LOCAL_AUTH_MARKER } from "openclaw/plugin-sdk/provider-auth";
import { lmstudioMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
import {
LMSTUDIO_DEFAULT_API_KEY_ENV_VAR,
LMSTUDIO_LOCAL_API_KEY_PLACEHOLDER,
@@ -52,6 +53,7 @@ export default definePluginEntry({
name: "LM Studio Provider",
description: "Bundled LM Studio provider plugin",
register(api: OpenClawPluginApi) {
api.registerMemoryEmbeddingProvider(lmstudioMemoryEmbeddingProviderAdapter);
api.registerProvider({
id: PROVIDER_ID,
label: "LM Studio",

View File

@@ -0,0 +1,35 @@
import {
sanitizeEmbeddingCacheHeaders,
type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import {
createLmstudioEmbeddingProvider,
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
} from "./src/embedding-provider.js";
export const lmstudioMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
id: "lmstudio",
defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
transport: "remote",
authProviderId: "lmstudio",
allowExplicitWhenConfiguredAuto: true,
create: async (options) => {
const { provider, client } = await createLmstudioEmbeddingProvider({
...options,
provider: "lmstudio",
fallback: "none",
});
return {
provider,
runtime: {
id: "lmstudio",
cacheKeyData: {
provider: "lmstudio",
baseUrl: client.baseUrl,
model: client.model,
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
},
},
};
},
};

View File

@@ -21,6 +21,9 @@
"groupHint": "Self-hosted open-weight models"
}
],
"contracts": {
"memoryEmbeddingProviders": ["lmstudio"]
},
"configSchema": {
"type": "object",
"additionalProperties": false,

View File

@@ -1,20 +1,21 @@
import { formatErrorMessage } from "../../infra/errors.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { createSubsystemLogger } from "../../logging/subsystem.js";
import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
import {
buildRemoteBaseUrlPolicy,
createRemoteEmbeddingProvider,
normalizeEmbeddingModelWithPrefixes,
type MemoryEmbeddingProvider,
type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { resolveMemorySecretInputString } from "openclaw/plugin-sdk/memory-core-host-secret";
import { formatErrorMessage, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
import { LMSTUDIO_DEFAULT_EMBEDDING_MODEL, LMSTUDIO_PROVIDER_ID } from "./defaults.js";
import { ensureLmstudioModelLoaded } from "./models.fetch.js";
import { resolveLmstudioInferenceBase } from "./models.js";
import {
buildLmstudioAuthHeaders,
ensureLmstudioModelLoaded,
LMSTUDIO_DEFAULT_EMBEDDING_MODEL,
LMSTUDIO_PROVIDER_ID,
resolveLmstudioInferenceBase,
resolveLmstudioProviderHeaders,
resolveLmstudioRuntimeApiKey,
} from "../../plugin-sdk/lmstudio-runtime.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import { createRemoteEmbeddingProvider } from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";
} from "./runtime.js";
const log = createSubsystemLogger("memory/embeddings");
@@ -47,7 +48,7 @@ function hasAuthorizationHeader(headers: Record<string, string> | undefined): bo
/** Resolves API key (real or synthetic placeholder) from runtime/provider auth config. */
async function resolveLmstudioApiKey(
options: EmbeddingProviderOptions,
options: MemoryEmbeddingProviderCreateOptions,
): Promise<string | undefined> {
try {
return await resolveLmstudioRuntimeApiKey({
@@ -65,8 +66,8 @@ async function resolveLmstudioApiKey(
/** Creates the LM Studio embedding provider client and preloads the target model before return. */
export async function createLmstudioEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: LmstudioEmbeddingClient }> {
options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: LmstudioEmbeddingClient }> {
const providerConfig = options.config.models?.providers?.lmstudio;
const providerBaseUrl = providerConfig?.baseUrl?.trim();
const isFallbackActivation = options.fallback === "lmstudio" && options.provider !== "lmstudio";

View File

@@ -1,10 +1,5 @@
import {
DEFAULT_GEMINI_EMBEDDING_MODEL,
DEFAULT_LOCAL_MODEL,
DEFAULT_MISTRAL_EMBEDDING_MODEL,
DEFAULT_OLLAMA_EMBEDDING_MODEL,
DEFAULT_OPENAI_EMBEDDING_MODEL,
DEFAULT_VOYAGE_EMBEDDING_MODEL,
getMemoryEmbeddingProvider,
listMemoryEmbeddingProviders,
type MemoryEmbeddingProvider,
@@ -15,15 +10,7 @@ import {
import { formatErrorMessage } from "../dreaming-shared.js";
import { canAutoSelectLocal } from "./provider-adapters.js";
export {
DEFAULT_GEMINI_EMBEDDING_MODEL,
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
DEFAULT_LOCAL_MODEL,
DEFAULT_MISTRAL_EMBEDDING_MODEL,
DEFAULT_OLLAMA_EMBEDDING_MODEL,
DEFAULT_OPENAI_EMBEDDING_MODEL,
DEFAULT_VOYAGE_EMBEDDING_MODEL,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
export { DEFAULT_LOCAL_MODEL } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
export type EmbeddingProvider = MemoryEmbeddingProvider;
export type EmbeddingProviderId = string;

View File

@@ -11,9 +11,9 @@ import {
} from "../../../../src/plugins/memory-embedding-providers.js";
import "./test-runtime-mocks.js";
import type { MemoryIndexManager } from "./index.js";
import { getMemorySearchManager, closeAllMemorySearchManagers } from "./index.js";
import { closeAllMemorySearchManagers, getMemorySearchManager } from "./index.js";
import {
DEFAULT_OLLAMA_EMBEDDING_MODEL,
DEFAULT_LOCAL_MODEL,
registerBuiltInMemoryEmbeddingProviders,
} from "./provider-adapters.js";
@@ -112,14 +112,14 @@ vi.mock("./embeddings.js", () => {
});
describe("memory index", () => {
it("registers the builtin ollama embedding provider", () => {
const adapter = listRegisteredAdapters().find((entry) => entry.id === "ollama");
it("registers the builtin local embedding provider", () => {
const adapter = listRegisteredAdapters().find((entry) => entry.id === "local");
expect(adapter).toBeDefined();
expect(adapter).toEqual(
expect.objectContaining({
id: "ollama",
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
id: "local",
defaultModel: DEFAULT_LOCAL_MODEL,
}),
);
});

View File

@@ -1,31 +1,13 @@
import fsSync from "node:fs";
import {
DEFAULT_GEMINI_EMBEDDING_MODEL,
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
DEFAULT_LOCAL_MODEL,
DEFAULT_MISTRAL_EMBEDDING_MODEL,
DEFAULT_OLLAMA_EMBEDDING_MODEL,
DEFAULT_OPENAI_EMBEDDING_MODEL,
DEFAULT_VOYAGE_EMBEDDING_MODEL,
OPENAI_BATCH_ENDPOINT,
buildGeminiEmbeddingRequest,
createGeminiEmbeddingProvider,
createLmstudioEmbeddingProvider,
createLocalEmbeddingProvider,
createMistralEmbeddingProvider,
createOllamaEmbeddingProvider,
createOpenAiEmbeddingProvider,
createVoyageEmbeddingProvider,
hasNonTextEmbeddingParts,
DEFAULT_LOCAL_MODEL,
listMemoryEmbeddingProviders,
listRegisteredMemoryEmbeddingProviderAdapters,
runGeminiEmbeddingBatches,
runOpenAiEmbeddingBatches,
runVoyageEmbeddingBatches,
type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { resolveUserPath } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import { getProviderEnvVars } from "openclaw/plugin-sdk/provider-env-vars";
import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
import { formatErrorMessage } from "../dreaming-shared.js";
import { filterUnregisteredMemoryEmbeddingProviderAdapters } from "./provider-adapter-registration.js";
@@ -37,31 +19,6 @@ export type BuiltinMemoryEmbeddingProviderDoctorMetadata = {
autoSelectPriority?: number;
};
function isMissingApiKeyError(err: unknown): boolean {
return formatErrorMessage(err).includes("No API key found for provider");
}
function sanitizeHeaders(
headers: Record<string, string>,
excludedHeaderNames: string[],
): Array<[string, string]> {
const excluded = new Set(
excludedHeaderNames.map((name) => normalizeLowercaseStringOrEmpty(name)),
);
return Object.entries(headers)
.filter(([key]) => !excluded.has(normalizeLowercaseStringOrEmpty(key)))
.toSorted(([a], [b]) => a.localeCompare(b))
.map(([key, value]) => [key, value]);
}
function mapBatchEmbeddingsByIndex(byCustomId: Map<string, number[]>, count: number): number[][] {
const embeddings: number[][] = [];
for (let index = 0; index < count; index += 1) {
embeddings.push(byCustomId.get(String(index)) ?? []);
}
return embeddings;
}
function isNodeLlamaCppMissing(err: unknown): boolean {
if (!(err instanceof Error)) {
return false;
@@ -70,6 +27,20 @@ function isNodeLlamaCppMissing(err: unknown): boolean {
return code === "ERR_MODULE_NOT_FOUND" && err.message.includes("node-llama-cpp");
}
function listRemoteEmbeddingSetupHints(): string[] {
try {
return listMemoryEmbeddingProviders()
.filter(
(adapter) =>
adapter.transport === "remote" && typeof adapter.autoSelectPriority === "number",
)
.toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0))
.map((adapter) => `Or set agents.defaults.memorySearch.provider = "${adapter.id}" (remote).`);
} catch {
return [];
}
}
function formatLocalSetupError(err: unknown): string {
const detail = formatErrorMessage(err);
const missing = isNodeLlamaCppMissing(err);
@@ -87,9 +58,7 @@ function formatLocalSetupError(err: unknown): string {
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
: null,
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
...["openai", "gemini", "voyage", "mistral"].map(
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
),
...listRemoteEmbeddingSetupHints(),
]
.filter(Boolean)
.join("\n");
@@ -111,237 +80,6 @@ function canAutoSelectLocal(modelPath?: string): boolean {
}
}
function supportsGeminiMultimodalEmbeddings(model: string): boolean {
const normalized = model
.trim()
.replace(/^models\//, "")
.replace(/^(gemini|google)\//, "");
return normalized === "gemini-embedding-2-preview";
}
function resolveMemoryEmbeddingAuthProviderId(providerId: string): string {
return providerId === "gemini" ? "google" : providerId;
}
const openAiAdapter: MemoryEmbeddingProviderAdapter = {
id: "openai",
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
transport: "remote",
autoSelectPriority: 20,
allowExplicitWhenConfiguredAuto: true,
shouldContinueAutoSelection: isMissingApiKeyError,
create: async (options) => {
const { provider, client } = await createOpenAiEmbeddingProvider({
...options,
provider: "openai",
fallback: "none",
});
return {
provider,
runtime: {
id: "openai",
cacheKeyData: {
provider: "openai",
baseUrl: client.baseUrl,
model: client.model,
headers: sanitizeHeaders(client.headers, ["authorization"]),
},
batchEmbed: async (batch) => {
const byCustomId = await runOpenAiEmbeddingBatches({
openAi: client,
agentId: batch.agentId,
requests: batch.chunks.map((chunk, index) => ({
custom_id: String(index),
method: "POST",
url: OPENAI_BATCH_ENDPOINT,
body: {
model: client.model,
input: chunk.text,
},
})),
wait: batch.wait,
concurrency: batch.concurrency,
pollIntervalMs: batch.pollIntervalMs,
timeoutMs: batch.timeoutMs,
debug: batch.debug,
});
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
},
},
};
},
};
const geminiAdapter: MemoryEmbeddingProviderAdapter = {
id: "gemini",
defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
transport: "remote",
autoSelectPriority: 30,
allowExplicitWhenConfiguredAuto: true,
supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model),
shouldContinueAutoSelection: isMissingApiKeyError,
create: async (options) => {
const { provider, client } = await createGeminiEmbeddingProvider({
...options,
provider: "gemini",
fallback: "none",
});
return {
provider,
runtime: {
id: "gemini",
cacheKeyData: {
provider: "gemini",
baseUrl: client.baseUrl,
model: client.model,
outputDimensionality: client.outputDimensionality,
headers: sanitizeHeaders(client.headers, ["authorization", "x-goog-api-key"]),
},
batchEmbed: async (batch) => {
if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) {
return null;
}
const byCustomId = await runGeminiEmbeddingBatches({
gemini: client,
agentId: batch.agentId,
requests: batch.chunks.map((chunk, index) => ({
custom_id: String(index),
request: buildGeminiEmbeddingRequest({
input: chunk.embeddingInput ?? { text: chunk.text },
taskType: "RETRIEVAL_DOCUMENT",
modelPath: client.modelPath,
outputDimensionality: client.outputDimensionality,
}),
})),
wait: batch.wait,
concurrency: batch.concurrency,
pollIntervalMs: batch.pollIntervalMs,
timeoutMs: batch.timeoutMs,
debug: batch.debug,
});
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
},
},
};
},
};
const voyageAdapter: MemoryEmbeddingProviderAdapter = {
id: "voyage",
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
transport: "remote",
autoSelectPriority: 40,
allowExplicitWhenConfiguredAuto: true,
shouldContinueAutoSelection: isMissingApiKeyError,
create: async (options) => {
const { provider, client } = await createVoyageEmbeddingProvider({
...options,
provider: "voyage",
fallback: "none",
});
return {
provider,
runtime: {
id: "voyage",
batchEmbed: async (batch) => {
const byCustomId = await runVoyageEmbeddingBatches({
client,
agentId: batch.agentId,
requests: batch.chunks.map((chunk, index) => ({
custom_id: String(index),
body: {
input: chunk.text,
},
})),
wait: batch.wait,
concurrency: batch.concurrency,
pollIntervalMs: batch.pollIntervalMs,
timeoutMs: batch.timeoutMs,
debug: batch.debug,
});
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
},
},
};
},
};
const mistralAdapter: MemoryEmbeddingProviderAdapter = {
id: "mistral",
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
transport: "remote",
autoSelectPriority: 50,
allowExplicitWhenConfiguredAuto: true,
shouldContinueAutoSelection: isMissingApiKeyError,
create: async (options) => {
const { provider, client } = await createMistralEmbeddingProvider({
...options,
provider: "mistral",
fallback: "none",
});
return {
provider,
runtime: {
id: "mistral",
cacheKeyData: {
provider: "mistral",
model: client.model,
},
},
};
},
};
const ollamaAdapter: MemoryEmbeddingProviderAdapter = {
id: "ollama",
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
transport: "remote",
create: async (options) => {
const { provider, client } = await createOllamaEmbeddingProvider({
...options,
provider: "ollama",
fallback: "none",
});
return {
provider,
runtime: {
id: "ollama",
cacheKeyData: {
provider: "ollama",
baseUrl: client.baseUrl,
model: client.model,
headers: sanitizeHeaders(client.headers, ["authorization"]),
},
},
};
},
};
const lmstudioAdapter: MemoryEmbeddingProviderAdapter = {
id: "lmstudio",
defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
transport: "remote",
create: async (options) => {
const { provider, client } = await createLmstudioEmbeddingProvider({
...options,
provider: "lmstudio",
fallback: "none",
});
return {
provider,
runtime: {
id: "lmstudio",
cacheKeyData: {
provider: "lmstudio",
baseUrl: client.baseUrl,
model: client.model,
headers: sanitizeHeaders(client.headers, ["authorization"]),
},
},
};
},
};
const localAdapter: MemoryEmbeddingProviderAdapter = {
id: "local",
defaultModel: DEFAULT_LOCAL_MODEL,
@@ -368,24 +106,14 @@ const localAdapter: MemoryEmbeddingProviderAdapter = {
},
};
export const builtinMemoryEmbeddingProviderAdapters = [
localAdapter,
openAiAdapter,
geminiAdapter,
voyageAdapter,
mistralAdapter,
ollamaAdapter,
lmstudioAdapter,
] as const;
export const builtinMemoryEmbeddingProviderAdapters = [localAdapter] as const;
const builtinMemoryEmbeddingProviderAdapterById = new Map(
builtinMemoryEmbeddingProviderAdapters.map((adapter) => [adapter.id, adapter]),
);
export { DEFAULT_LOCAL_MODEL };
export function getBuiltinMemoryEmbeddingProviderAdapter(
id: string,
): MemoryEmbeddingProviderAdapter | undefined {
return builtinMemoryEmbeddingProviderAdapterById.get(id);
return listMemoryEmbeddingProviders().find((adapter) => adapter.id === id);
}
export function registerBuiltInMemoryEmbeddingProviders(register: {
@@ -409,7 +137,7 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata(
if (!adapter) {
return null;
}
const authProviderId = resolveMemoryEmbeddingAuthProviderId(adapter.id);
const authProviderId = adapter.authProviderId ?? adapter.id;
return {
providerId: adapter.id,
authProviderId,
@@ -420,27 +148,19 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata(
}
export function listBuiltinAutoSelectMemoryEmbeddingProviderDoctorMetadata(): Array<BuiltinMemoryEmbeddingProviderDoctorMetadata> {
return builtinMemoryEmbeddingProviderAdapters
return listMemoryEmbeddingProviders()
.filter((adapter) => typeof adapter.autoSelectPriority === "number")
.toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0))
.map((adapter) => ({
providerId: adapter.id,
authProviderId: resolveMemoryEmbeddingAuthProviderId(adapter.id),
envVars: getProviderEnvVars(resolveMemoryEmbeddingAuthProviderId(adapter.id)),
transport: adapter.transport === "local" ? "local" : "remote",
autoSelectPriority: adapter.autoSelectPriority,
}));
.map((adapter) => {
const authProviderId = adapter.authProviderId ?? adapter.id;
return {
providerId: adapter.id,
authProviderId,
envVars: getProviderEnvVars(authProviderId),
transport: adapter.transport === "local" ? "local" : "remote",
autoSelectPriority: adapter.autoSelectPriority,
};
});
}
export {
DEFAULT_GEMINI_EMBEDDING_MODEL,
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
DEFAULT_LOCAL_MODEL,
DEFAULT_MISTRAL_EMBEDDING_MODEL,
DEFAULT_OLLAMA_EMBEDDING_MODEL,
DEFAULT_OPENAI_EMBEDDING_MODEL,
DEFAULT_VOYAGE_EMBEDDING_MODEL,
canAutoSelectLocal,
formatLocalSetupError,
isMissingApiKeyError,
};
export { canAutoSelectLocal, formatLocalSetupError };

View File

@@ -16,6 +16,7 @@ import {
writeFileWithinRoot,
type OpenClawConfig,
} from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import {
buildSessionEntry,
deriveQmdScopeChannel,
@@ -47,7 +48,6 @@ import {
type ResolvedQmdConfig,
type ResolvedQmdMcporterConfig,
} from "openclaw/plugin-sdk/memory-core-host-engine-storage";
import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import {
localeLowercasePreservingWhitespace,
normalizeLowercaseStringOrEmpty,
@@ -1945,8 +1945,7 @@ export class QmdMemoryManager implements MemorySearchManager {
from?: number,
lines?: number,
): Promise<
| { missing: true }
| { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean }
{ missing: true } | { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean }
> {
const start = Math.max(1, from ?? 1);
const count = Math.max(1, lines ?? Number.POSITIVE_INFINITY);

View File

@@ -1,10 +1,11 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import {
createRemoteEmbeddingProvider,
normalizeEmbeddingModelWithPrefixes,
resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
type MemoryEmbeddingProvider,
type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
export type MistralEmbeddingClient = {
baseUrl: string;
@@ -25,8 +26,8 @@ export function normalizeMistralModel(model: string): string {
}
export async function createMistralEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: MistralEmbeddingClient }> {
const client = await resolveMistralEmbeddingClient(options);
return {
@@ -40,7 +41,7 @@ export async function createMistralEmbeddingProvider(
}
export async function resolveMistralEmbeddingClient(
options: EmbeddingProviderOptions,
options: MemoryEmbeddingProviderCreateOptions,
): Promise<MistralEmbeddingClient> {
return await resolveRemoteEmbeddingClient({
provider: "mistral",

View File

@@ -1,6 +1,7 @@
import { defineSingleProviderPluginEntry } from "openclaw/plugin-sdk/provider-entry";
import { applyMistralModelCompat } from "./api.js";
import { mistralMediaUnderstandingProvider } from "./media-understanding-provider.js";
import { mistralMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
import { applyMistralConfig, MISTRAL_DEFAULT_MODEL_REF } from "./onboard.js";
import { buildMistralProvider } from "./provider-catalog.js";
import { contributeMistralResolvedModelCompat } from "./provider-compat.js";
@@ -48,6 +49,7 @@ export default defineSingleProviderPluginEntry({
buildReplayPolicy: () => buildMistralReplayPolicy(),
},
register(api) {
api.registerMemoryEmbeddingProvider(mistralMemoryEmbeddingProviderAdapter);
api.registerMediaUnderstandingProvider(mistralMediaUnderstandingProvider);
},
});

View File

@@ -0,0 +1,35 @@
import {
isMissingEmbeddingApiKeyError,
type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import {
createMistralEmbeddingProvider,
DEFAULT_MISTRAL_EMBEDDING_MODEL,
} from "./embedding-provider.js";
export const mistralMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
id: "mistral",
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
transport: "remote",
authProviderId: "mistral",
autoSelectPriority: 50,
allowExplicitWhenConfiguredAuto: true,
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
create: async (options) => {
const { provider, client } = await createMistralEmbeddingProvider({
...options,
provider: "mistral",
fallback: "none",
});
return {
provider,
runtime: {
id: "mistral",
cacheKeyData: {
provider: "mistral",
model: client.model,
},
},
};
},
};

View File

@@ -21,6 +21,7 @@
}
],
"contracts": {
"memoryEmbeddingProviders": ["mistral"],
"mediaUnderstandingProviders": ["mistral"]
},
"configSchema": {

View File

@@ -8,6 +8,7 @@ export const ollamaMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapte
id: "ollama",
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
transport: "remote",
authProviderId: "ollama",
create: async (options) => {
const { provider, client } = await createOllamaEmbeddingProvider({
...options,

View File

@@ -17,8 +17,8 @@ import {
type ProviderBatchOutputLine,
uploadBatchJsonlFile,
withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { OpenAiEmbeddingClient } from "./embedding-provider.js";
export type OpenAiBatchRequest = {
custom_id: string;

View File

@@ -1,11 +1,11 @@
import { parseStaticModelRef } from "../../agents/model-ref-shared.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../plugins/provider-model-defaults.js";
import {
createRemoteEmbeddingProvider,
resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
type MemoryEmbeddingProvider,
type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "./default-models.js";
export type OpenAiEmbeddingClient = {
baseUrl: string;
@@ -28,13 +28,12 @@ export function normalizeOpenAiModel(model: string): string {
if (!trimmed) {
return DEFAULT_OPENAI_EMBEDDING_MODEL;
}
const parsed = parseStaticModelRef(trimmed, "openai");
return parsed && parsed.provider === "openai" ? parsed.model : trimmed;
return trimmed.startsWith("openai/") ? trimmed.slice("openai/".length) : trimmed;
}
export async function createOpenAiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: OpenAiEmbeddingClient }> {
const client = await resolveOpenAiEmbeddingClient(options);
return {
@@ -49,7 +48,7 @@ export async function createOpenAiEmbeddingProvider(
}
export async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions,
options: MemoryEmbeddingProviderCreateOptions,
): Promise<OpenAiEmbeddingClient> {
return await resolveRemoteEmbeddingClient({
provider: "openai",

View File

@@ -6,6 +6,7 @@ import {
openaiCodexMediaUnderstandingProvider,
openaiMediaUnderstandingProvider,
} from "./media-understanding-provider.js";
import { openAiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
import { buildOpenAICodexProviderPlugin } from "./openai-codex-provider.js";
import { buildOpenAIProvider } from "./openai-provider.js";
import {
@@ -39,6 +40,7 @@ export default definePluginEntry({
api.registerCliBackend(buildOpenAICodexCliBackend());
api.registerProvider(buildProviderWithPromptContribution(buildOpenAIProvider()));
api.registerProvider(buildProviderWithPromptContribution(buildOpenAICodexProviderPlugin()));
api.registerMemoryEmbeddingProvider(openAiMemoryEmbeddingProviderAdapter);
api.registerImageGenerationProvider(buildOpenAIImageGenerationProvider());
api.registerRealtimeTranscriptionProvider(buildOpenAIRealtimeTranscriptionProvider());
api.registerRealtimeVoiceProvider(buildOpenAIRealtimeVoiceProvider());

View File

@@ -0,0 +1,61 @@
import {
isMissingEmbeddingApiKeyError,
mapBatchEmbeddingsByIndex,
sanitizeEmbeddingCacheHeaders,
type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { OPENAI_BATCH_ENDPOINT, runOpenAiEmbeddingBatches } from "./embedding-batch.js";
import {
createOpenAiEmbeddingProvider,
DEFAULT_OPENAI_EMBEDDING_MODEL,
} from "./embedding-provider.js";
export const openAiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
id: "openai",
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
transport: "remote",
authProviderId: "openai",
autoSelectPriority: 20,
allowExplicitWhenConfiguredAuto: true,
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
create: async (options) => {
const { provider, client } = await createOpenAiEmbeddingProvider({
...options,
provider: "openai",
fallback: "none",
});
return {
provider,
runtime: {
id: "openai",
cacheKeyData: {
provider: "openai",
baseUrl: client.baseUrl,
model: client.model,
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
},
batchEmbed: async (batch) => {
const byCustomId = await runOpenAiEmbeddingBatches({
openAi: client,
agentId: batch.agentId,
requests: batch.chunks.map((chunk, index) => ({
custom_id: String(index),
method: "POST",
url: OPENAI_BATCH_ENDPOINT,
body: {
model: client.model,
input: chunk.text,
},
})),
wait: batch.wait,
concurrency: batch.concurrency,
pollIntervalMs: batch.pollIntervalMs,
timeoutMs: batch.timeoutMs,
debug: batch.debug,
});
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
},
},
};
},
};
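The batchEmbed hook above keys every request by the stringified chunk index, so mapBatchEmbeddingsByIndex presumably just re-orders the keyed results into a dense array. An illustrative stand-in, not the SDK helper itself (error handling may differ):

// Illustrative stand-in: byCustomId arrives keyed by "0", "1", ... and must
// come back as number[][] aligned with batch.chunks; gaps are treated as errors.
function mapBatchEmbeddingsByIndexSketch(
  byCustomId: Map<string, number[]>,
  length: number,
): number[][] {
  const embeddings: number[][] = [];
  for (let index = 0; index < length; index++) {
    const embedding = byCustomId.get(String(index));
    if (!embedding) {
      throw new Error(`missing batch embedding for chunk ${index}`);
    }
    embeddings.push(embedding);
  }
  return embeddings;
}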

View File

@@ -39,6 +39,7 @@
"speechProviders": ["openai"],
"realtimeTranscriptionProviders": ["openai"],
"realtimeVoiceProviders": ["openai"],
"memoryEmbeddingProviders": ["openai"],
"mediaUnderstandingProviders": ["openai", "openai-codex"],
"imageGenerationProviders": ["openai"],
"videoGenerationProviders": ["openai"]

View File

@@ -19,8 +19,8 @@ import {
type ProviderBatchOutputLine,
uploadBatchJsonlFile,
withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { VoyageEmbeddingClient } from "./embedding-provider.js";
/**
* Voyage Batch API Input Line format.

View File

@@ -1,8 +1,11 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
import {
fetchRemoteEmbeddingVectors,
normalizeEmbeddingModelWithPrefixes,
resolveRemoteEmbeddingBearerClient,
type MemoryEmbeddingProvider,
type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
export type VoyageEmbeddingClient = {
baseUrl: string;
@@ -28,8 +31,8 @@ export function normalizeVoyageModel(model: string): string {
}
export async function createVoyageEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: VoyageEmbeddingClient }> {
const client = await resolveVoyageEmbeddingClient(options);
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
@@ -70,7 +73,7 @@ export async function createVoyageEmbeddingProvider(
}
export async function resolveVoyageEmbeddingClient(
options: EmbeddingProviderOptions,
options: MemoryEmbeddingProviderCreateOptions,
): Promise<VoyageEmbeddingClient> {
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
provider: "voyage",

View File

@@ -0,0 +1,11 @@
import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
import { voyageMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
export default definePluginEntry({
id: "voyage",
name: "Voyage Embeddings",
description: "Bundled Voyage memory embedding provider plugin",
register(api) {
api.registerMemoryEmbeddingProvider(voyageMemoryEmbeddingProviderAdapter);
},
});

View File

@@ -0,0 +1,56 @@
import {
isMissingEmbeddingApiKeyError,
mapBatchEmbeddingsByIndex,
sanitizeEmbeddingCacheHeaders,
type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { runVoyageEmbeddingBatches } from "./embedding-batch.js";
import {
createVoyageEmbeddingProvider,
DEFAULT_VOYAGE_EMBEDDING_MODEL,
} from "./embedding-provider.js";
export const voyageMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
id: "voyage",
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
transport: "remote",
authProviderId: "voyage",
autoSelectPriority: 40,
allowExplicitWhenConfiguredAuto: true,
shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
create: async (options) => {
const { provider, client } = await createVoyageEmbeddingProvider({
...options,
provider: "voyage",
fallback: "none",
});
return {
provider,
runtime: {
id: "voyage",
cacheKeyData: {
provider: "voyage",
baseUrl: client.baseUrl,
model: client.model,
headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
},
batchEmbed: async (batch) => {
const byCustomId = await runVoyageEmbeddingBatches({
client,
agentId: batch.agentId,
requests: batch.chunks.map((chunk, index) => ({
custom_id: String(index),
body: { input: chunk.text },
})),
wait: batch.wait,
concurrency: batch.concurrency,
pollIntervalMs: batch.pollIntervalMs,
timeoutMs: batch.timeoutMs,
debug: batch.debug,
});
return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
},
},
};
},
};
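Both remote adapters strip the authorization header before folding client headers into cacheKeyData, so API keys never become part of an embedding cache key. A plausible shape for sanitizeEmbeddingCacheHeaders, shown only as an assumption about the SDK helper:

// Assumed behaviour: drop the listed header names case-insensitively, keep the rest.
function sanitizeEmbeddingCacheHeadersSketch(
  headers: Record<string, string>,
  drop: string[],
): Record<string, string> {
  const blocked = new Set(drop.map((name) => name.toLowerCase()));
  return Object.fromEntries(
    Object.entries(headers).filter(([name]) => !blocked.has(name.toLowerCase())),
  );
}

// sanitizeEmbeddingCacheHeadersSketch({ Authorization: "Bearer key", "X-Team": "core" }, ["authorization"])
// keeps only { "X-Team": "core" }.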

View File

@@ -0,0 +1,15 @@
{
"id": "voyage",
"enabledByDefault": true,
"contracts": {
"memoryEmbeddingProviders": ["voyage"]
},
"providerAuthEnvVars": {
"voyage": ["VOYAGE_API_KEY"]
},
"configSchema": {
"type": "object",
"additionalProperties": false,
"properties": {}
}
}

View File

@@ -0,0 +1,15 @@
{
"name": "@openclaw/voyage-provider",
"version": "2026.4.15-beta.1",
"private": true,
"description": "OpenClaw Voyage embedding provider plugin",
"type": "module",
"devDependencies": {
"@openclaw/plugin-sdk": "workspace:*"
},
"openclaw": {
"extensions": [
"./index.ts"
]
}
}

View File

@@ -1,116 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
vi.mock("./remote-http.js", () => ({
withRemoteHttpResponse: vi.fn(),
}));
function magnitude(values: number[]) {
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
}
describe("runGeminiEmbeddingBatches", () => {
let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches;
let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
beforeAll(async () => {
({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js"));
({ withRemoteHttpResponse } = await import("./remote-http.js"));
remoteHttpMock = vi.mocked(withRemoteHttpResponse);
});
beforeEach(() => {
vi.clearAllMocks();
});
afterEach(() => {
vi.resetAllMocks();
vi.unstubAllGlobals();
});
const mockClient: GeminiEmbeddingClient = {
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
headers: {},
model: "gemini-embedding-2-preview",
modelPath: "models/gemini-embedding-2-preview",
apiKeys: ["test-key"],
outputDimensionality: 1536,
};
it("includes outputDimensionality in batch upload requests", async () => {
remoteHttpMock.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart");
const body = params.init?.body;
if (!(body instanceof Blob)) {
throw new Error("expected multipart blob body");
}
const text = await body.text();
expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"');
expect(text).toContain('"outputDimensionality":1536');
return await params.onResponse(
new Response(JSON.stringify({ name: "files/file-123" }), {
status: 200,
headers: { "Content-Type": "application/json" },
}),
);
});
remoteHttpMock.mockImplementationOnce(async (params) => {
expect(params.url).toMatch(/:asyncBatchEmbedContent$/u);
return await params.onResponse(
new Response(
JSON.stringify({
name: "batches/batch-1",
state: "COMPLETED",
outputConfig: { file: "files/output-1" },
}),
{
status: 200,
headers: { "Content-Type": "application/json" },
},
),
);
});
remoteHttpMock.mockImplementationOnce(async (params) => {
expect(params.url).toMatch(/\/files\/output-1:download$/u);
return await params.onResponse(
new Response(
JSON.stringify({
key: "req-1",
response: { embedding: { values: [3, 4] } },
}),
{
status: 200,
headers: { "Content-Type": "application/jsonl" },
},
),
);
});
const results = await runGeminiEmbeddingBatches({
gemini: mockClient,
agentId: "main",
requests: [
{
custom_id: "req-1",
request: {
content: { parts: [{ text: "hello world" }] },
taskType: "RETRIEVAL_DOCUMENT",
outputDimensionality: 1536,
},
},
],
wait: true,
pollIntervalMs: 1,
timeoutMs: 1000,
concurrency: 1,
});
const embedding = results.get("req-1");
expect(embedding).toBeDefined();
expect(embedding?.[0]).toBeCloseTo(0.6, 5);
expect(embedding?.[1]).toBeCloseTo(0.8, 5);
expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5);
expect(remoteHttpMock).toHaveBeenCalledTimes(3);
});
});

View File

@@ -1,259 +0,0 @@
import {
applyEmbeddingBatchOutputLine,
buildBatchHeaders,
buildEmbeddingBatchGroupOptions,
EMBEDDING_BATCH_ENDPOINT,
extractBatchErrorMessage,
formatUnavailableBatchError,
normalizeBatchBaseUrl,
postJsonWithRetry,
resolveBatchCompletionFromStatus,
resolveCompletedBatchResult,
runEmbeddingBatchGroups,
throwIfBatchTerminalFailure,
type EmbeddingBatchExecutionParams,
type EmbeddingBatchStatus,
type BatchCompletionResult,
type ProviderBatchOutputLine,
uploadBatchJsonlFile,
withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
export type OpenAiBatchRequest = {
custom_id: string;
method: "POST";
url: "/v1/embeddings";
body: {
model: string;
input: string;
};
};
export type OpenAiBatchStatus = EmbeddingBatchStatus;
export type OpenAiBatchOutputLine = ProviderBatchOutputLine;
export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
const OPENAI_BATCH_MAX_REQUESTS = 50000;
async function submitOpenAiBatch(params: {
openAi: OpenAiEmbeddingClient;
requests: OpenAiBatchRequest[];
agentId: string;
}): Promise<OpenAiBatchStatus> {
const baseUrl = normalizeBatchBaseUrl(params.openAi);
const inputFileId = await uploadBatchJsonlFile({
client: params.openAi,
requests: params.requests,
errorPrefix: "openai batch file upload failed",
});
return await postJsonWithRetry<OpenAiBatchStatus>({
url: `${baseUrl}/batches`,
headers: buildBatchHeaders(params.openAi, { json: true }),
ssrfPolicy: params.openAi.ssrfPolicy,
body: {
input_file_id: inputFileId,
endpoint: OPENAI_BATCH_ENDPOINT,
completion_window: OPENAI_BATCH_COMPLETION_WINDOW,
metadata: {
source: "openclaw-memory",
agent: params.agentId,
},
},
errorPrefix: "openai batch create failed",
});
}
async function fetchOpenAiBatchStatus(params: {
openAi: OpenAiEmbeddingClient;
batchId: string;
}): Promise<OpenAiBatchStatus> {
return await fetchOpenAiBatchResource({
openAi: params.openAi,
path: `/batches/${params.batchId}`,
errorPrefix: "openai batch status",
parse: async (res) => (await res.json()) as OpenAiBatchStatus,
});
}
async function fetchOpenAiFileContent(params: {
openAi: OpenAiEmbeddingClient;
fileId: string;
}): Promise<string> {
return await fetchOpenAiBatchResource({
openAi: params.openAi,
path: `/files/${params.fileId}/content`,
errorPrefix: "openai batch file content",
parse: async (res) => await res.text(),
});
}
async function fetchOpenAiBatchResource<T>(params: {
openAi: OpenAiEmbeddingClient;
path: string;
errorPrefix: string;
parse: (res: Response) => Promise<T>;
}): Promise<T> {
const baseUrl = normalizeBatchBaseUrl(params.openAi);
return await withRemoteHttpResponse({
url: `${baseUrl}${params.path}`,
ssrfPolicy: params.openAi.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.openAi, { json: true }),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`${params.errorPrefix} failed: ${res.status} ${text}`);
}
return await params.parse(res);
},
});
}
function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
if (!text.trim()) {
return [];
}
return text
.split("\n")
.map((line) => line.trim())
.filter(Boolean)
.map((line) => JSON.parse(line) as OpenAiBatchOutputLine);
}
async function readOpenAiBatchError(params: {
openAi: OpenAiEmbeddingClient;
errorFileId: string;
}): Promise<string | undefined> {
try {
const content = await fetchOpenAiFileContent({
openAi: params.openAi,
fileId: params.errorFileId,
});
const lines = parseOpenAiBatchOutput(content);
return extractBatchErrorMessage(lines);
} catch (err) {
return formatUnavailableBatchError(err);
}
}
async function waitForOpenAiBatch(params: {
openAi: OpenAiEmbeddingClient;
batchId: string;
wait: boolean;
pollIntervalMs: number;
timeoutMs: number;
debug?: (message: string, data?: Record<string, unknown>) => void;
initial?: OpenAiBatchStatus;
}): Promise<BatchCompletionResult> {
const start = Date.now();
let current: OpenAiBatchStatus | undefined = params.initial;
while (true) {
const status =
current ??
(await fetchOpenAiBatchStatus({
openAi: params.openAi,
batchId: params.batchId,
}));
const state = status.status ?? "unknown";
if (state === "completed") {
return resolveBatchCompletionFromStatus({
provider: "openai",
batchId: params.batchId,
status,
});
}
await throwIfBatchTerminalFailure({
provider: "openai",
status: { ...status, id: params.batchId },
readError: async (errorFileId) =>
await readOpenAiBatchError({
openAi: params.openAi,
errorFileId,
}),
});
if (!params.wait) {
throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`);
}
if (Date.now() - start > params.timeoutMs) {
throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
}
params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
current = undefined;
}
}
export async function runOpenAiEmbeddingBatches(
params: {
openAi: OpenAiEmbeddingClient;
agentId: string;
requests: OpenAiBatchRequest[];
} & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
return await runEmbeddingBatchGroups({
...buildEmbeddingBatchGroupOptions(params, {
maxRequests: OPENAI_BATCH_MAX_REQUESTS,
debugLabel: "memory embeddings: openai batch submit",
}),
runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
const batchInfo = await submitOpenAiBatch({
openAi: params.openAi,
requests: group,
agentId: params.agentId,
});
if (!batchInfo.id) {
throw new Error("openai batch create failed: missing batch id");
}
const batchId = batchInfo.id;
params.debug?.("memory embeddings: openai batch created", {
batchId: batchInfo.id,
status: batchInfo.status,
group: groupIndex + 1,
groups,
requests: group.length,
});
const completed = await resolveCompletedBatchResult({
provider: "openai",
status: batchInfo,
wait: params.wait,
waitForBatch: async () =>
await waitForOpenAiBatch({
openAi: params.openAi,
batchId,
wait: params.wait,
pollIntervalMs: params.pollIntervalMs,
timeoutMs: params.timeoutMs,
debug: params.debug,
initial: batchInfo,
}),
});
const content = await fetchOpenAiFileContent({
openAi: params.openAi,
fileId: completed.outputFileId,
});
const outputLines = parseOpenAiBatchOutput(content);
const errors: string[] = [];
const remaining = new Set(group.map((request) => request.custom_id));
for (const line of outputLines) {
applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
}
if (errors.length > 0) {
throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
}
if (remaining.size > 0) {
throw new Error(
`openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
);
}
},
});
}
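For orientation, the removed module drove the standard OpenAI Batch API flow: upload a JSONL file of embedding requests, create a batch against /v1/embeddings, poll until completed, then download and parse the output file. The payload shapes below are illustrative literals in that format; the exact fields the shared helpers rely on are assumed:

// One uploaded JSONL line per chunk, keyed by custom_id:
const requestLine = JSON.stringify({
  custom_id: "0",
  method: "POST",
  url: "/v1/embeddings",
  body: { model: "text-embedding-3-small", input: "hello world" },
});

// One output line per request; applyEmbeddingBatchOutputLine is assumed to read
// the vector out of response.body.data[0].embedding.
const outputLine = JSON.stringify({
  custom_id: "0",
  response: {
    status_code: 200,
    body: { data: [{ embedding: [0.12, -0.05] }] },
  },
});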

View File

@@ -1,176 +0,0 @@
import { ReadableStream } from "node:stream/web";
import { setTimeout as nativeSleep } from "node:timers/promises";
import { describe, expect, it, vi } from "vitest";
import {
runVoyageEmbeddingBatches,
type VoyageBatchOutputLine,
type VoyageBatchRequest,
} from "./batch-voyage.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
const realNow = Date.now.bind(Date);
describe("runVoyageEmbeddingBatches", () => {
const mockClient: VoyageEmbeddingClient = {
baseUrl: "https://api.voyageai.com/v1",
headers: { Authorization: "Bearer test-key" },
model: "voyage-4-large",
};
const mockRequests: VoyageBatchRequest[] = [
{ custom_id: "req-1", body: { input: "text1" } },
{ custom_id: "req-2", body: { input: "text2" } },
];
it("successfully submits batch, waits, and streams results", async () => {
const outputLines: VoyageBatchOutputLine[] = [
{
custom_id: "req-1",
response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
},
{
custom_id: "req-2",
response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
},
];
const withRemoteHttpResponse = vi.fn();
const postJsonWithRetry = vi.fn();
const uploadBatchJsonlFile = vi.fn();
// Create a stream that emits the NDJSON lines
const stream = new ReadableStream({
start(controller) {
const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
controller.enqueue(new TextEncoder().encode(text));
controller.close();
},
});
uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
expect(params.errorPrefix).toBe("voyage batch file upload failed");
expect(params.requests).toEqual(mockRequests);
return "file-123";
});
postJsonWithRetry.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/batches");
expect(params.body).toMatchObject({
input_file_id: "file-123",
completion_window: "12h",
request_params: {
model: "voyage-4-large",
input_type: "document",
},
});
return {
id: "batch-abc",
status: "pending",
};
});
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/batches/batch-abc");
return await params.onResponse(
new Response(
JSON.stringify({
id: "batch-abc",
status: "completed",
output_file_id: "file-out-999",
}),
{
status: 200,
headers: { "Content-Type": "application/json" },
},
),
);
});
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/files/file-out-999/content");
return await params.onResponse(
new Response(stream as unknown as BodyInit, {
status: 200,
headers: { "Content-Type": "application/x-ndjson" },
}),
);
});
const results = await runVoyageEmbeddingBatches({
client: mockClient,
agentId: "agent-1",
requests: mockRequests,
wait: true,
pollIntervalMs: 1, // fast poll
timeoutMs: 1000,
concurrency: 1,
deps: {
now: realNow,
sleep: async (ms) => {
await nativeSleep(ms);
},
postJsonWithRetry,
uploadBatchJsonlFile,
withRemoteHttpResponse,
},
});
expect(results.size).toBe(2);
expect(results.get("req-1")).toEqual([0.1, 0.1]);
expect(results.get("req-2")).toEqual([0.2, 0.2]);
expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
});
it("handles empty lines and stream chunks correctly", async () => {
const withRemoteHttpResponse = vi.fn();
const postJsonWithRetry = vi.fn();
const uploadBatchJsonlFile = vi.fn();
const stream = new ReadableStream({
start(controller) {
const line1 = JSON.stringify({
custom_id: "req-1",
response: { body: { data: [{ embedding: [1] }] } },
});
const line2 = JSON.stringify({
custom_id: "req-2",
response: { body: { data: [{ embedding: [2] }] } },
});
// Split across chunks
controller.enqueue(new TextEncoder().encode(line1 + "\n"));
controller.enqueue(new TextEncoder().encode("\n")); // empty line
controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
controller.close();
},
});
uploadBatchJsonlFile.mockResolvedValueOnce("f1");
postJsonWithRetry.mockResolvedValueOnce({
id: "b1",
status: "completed",
output_file_id: "out1",
});
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/files/out1/content");
return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
});
const results = await runVoyageEmbeddingBatches({
client: mockClient,
agentId: "a1",
requests: mockRequests,
wait: true,
pollIntervalMs: 1,
timeoutMs: 1000,
concurrency: 1,
deps: {
now: realNow,
sleep: async (ms) => {
await nativeSleep(ms);
},
postJsonWithRetry,
uploadBatchJsonlFile,
withRemoteHttpResponse,
},
});
expect(results.get("req-1")).toEqual([1]);
expect(results.get("req-2")).toEqual([2]);
});
});

View File

@@ -1,40 +1,14 @@
import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js";
import type { EmbeddingProvider } from "./embeddings.js";
const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192;
const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048;
const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record<string, number> = {
"openai:text-embedding-3-small": 8192,
"openai:text-embedding-3-large": 8192,
"openai:text-embedding-ada-002": 8191,
"gemini:text-embedding-004": 2048,
"gemini:gemini-embedding-001": 2048,
"gemini:gemini-embedding-2-preview": 8192,
"voyage:voyage-3": 32000,
"voyage:voyage-3-lite": 16000,
"voyage:voyage-code-3": 32000,
};
export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number {
if (typeof provider.maxInputTokens === "number") {
return provider.maxInputTokens;
}
// Provider/model mapping is best-effort; different providers use different
// limits and we prefer to be conservative when we don't know.
const key = normalizeLowercaseStringOrEmpty(`${provider.id}:${provider.model}`);
const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key];
if (typeof known === "number") {
return known;
}
// Provider-specific conservative fallbacks. This prevents us from accidentally
// using the OpenAI default for providers with much smaller limits.
if (normalizeLowercaseStringOrEmpty(provider.id) === "gemini") {
return 2048;
}
if (normalizeLowercaseStringOrEmpty(provider.id) === "local") {
if (provider.id === "local") {
return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS;
}
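With this trim the resolver is left with three outcomes, assuming nothing else survives outside the hunk: an explicit provider.maxInputTokens wins, the local provider falls back to 2048, and everything else gets the generic 8192 default. A compact restatement of that order:

// Sketch of the trimmed fallback order; constants match the ones kept above.
type EmbeddingProviderLike = { id: string; model: string; maxInputTokens?: number };

function resolveMaxInputTokensSketch(provider: EmbeddingProviderLike): number {
  if (typeof provider.maxInputTokens === "number") {
    return provider.maxInputTokens; // plugin-declared limit takes precedence
  }
  if (provider.id === "local") {
    return 2048; // DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS
  }
  return 8192; // DEFAULT_EMBEDDING_MAX_INPUT_TOKENS
}

// resolveMaxInputTokensSketch({ id: "voyage", model: "voyage-3", maxInputTokens: 32000 }) -> 32000
// resolveMaxInputTokensSketch({ id: "local", model: "any-local-model" }) -> 2048
// resolveMaxInputTokensSketch({ id: "openai", model: "text-embedding-3-small" }) -> 8192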

View File

@@ -1,377 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
const { defaultProviderMock, resolveCredentialsMock, sendMock } = vi.hoisted(() => ({
defaultProviderMock: vi.fn(),
resolveCredentialsMock: vi.fn(),
sendMock: vi.fn(),
}));
vi.mock("@aws-sdk/client-bedrock-runtime", () => {
class MockClient {
region: string;
constructor(config: { region: string }) {
this.region = config.region;
}
send = sendMock;
}
class MockCommand {
input: unknown;
constructor(input: unknown) {
this.input = input;
}
}
return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand };
});
vi.mock("@aws-sdk/credential-provider-node", () => ({
defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
}));
let createBedrockEmbeddingProvider: typeof import("./embeddings-bedrock.js").createBedrockEmbeddingProvider;
let resolveBedrockEmbeddingClient: typeof import("./embeddings-bedrock.js").resolveBedrockEmbeddingClient;
let normalizeBedrockEmbeddingModel: typeof import("./embeddings-bedrock.js").normalizeBedrockEmbeddingModel;
let hasAwsCredentials: typeof import("./embeddings-bedrock.js").hasAwsCredentials;
beforeAll(async () => {
({
createBedrockEmbeddingProvider,
resolveBedrockEmbeddingClient,
normalizeBedrockEmbeddingModel,
hasAwsCredentials,
} = await import("./embeddings-bedrock.js"));
});
beforeEach(() => {
defaultProviderMock.mockImplementation(() => resolveCredentialsMock);
});
const enc = (body: unknown) => ({ body: new TextEncoder().encode(JSON.stringify(body)) });
const reqBody = (i = 0): Record<string, unknown> =>
JSON.parse(sendMock.mock.calls[i][0].input.body);
describe("bedrock embedding provider", () => {
const originalEnv = process.env;
afterEach(() => {
process.env = originalEnv;
vi.restoreAllMocks();
defaultProviderMock.mockClear();
resolveCredentialsMock.mockReset();
sendMock.mockReset();
});
// --- Normalization ---
it("normalizes model names with prefixes", () => {
expect(normalizeBedrockEmbeddingModel("bedrock/amazon.titan-embed-text-v2:0")).toBe(
"amazon.titan-embed-text-v2:0",
);
expect(normalizeBedrockEmbeddingModel("amazon-bedrock/cohere.embed-english-v3")).toBe(
"cohere.embed-english-v3",
);
expect(normalizeBedrockEmbeddingModel("")).toBe("amazon.titan-embed-text-v2:0");
});
// --- Client resolution ---
it("resolves region from env", () => {
process.env = { ...originalEnv, AWS_REGION: "eu-west-1" };
const c = resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(c.region).toBe("eu-west-1");
expect(c.dimensions).toBe(1024);
});
it("defaults to us-east-1", () => {
process.env = { ...originalEnv };
delete process.env.AWS_REGION;
delete process.env.AWS_DEFAULT_REGION;
expect(
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
}).region,
).toBe("us-east-1");
});
it("extracts region from baseUrl", () => {
process.env = { ...originalEnv };
delete process.env.AWS_REGION;
const c = resolveBedrockEmbeddingClient({
config: {
models: {
providers: {
"amazon-bedrock": { baseUrl: "https://bedrock-runtime.ap-southeast-2.amazonaws.com" },
},
},
} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(c.region).toBe("ap-southeast-2");
});
it("validates dimensions", () => {
expect(() =>
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
outputDimensionality: 768,
}),
).toThrow("Invalid dimensions 768");
});
it("accepts valid dimensions", () => {
expect(
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
outputDimensionality: 256,
}).dimensions,
).toBe(256);
});
it("resolves throughput-suffixed variants", () => {
expect(
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v1:2:8k",
fallback: "none",
}).dimensions,
).toBe(1536);
});
// --- Credential detection ---
it("detects access keys", async () => {
await expect(
hasAwsCredentials({
AWS_ACCESS_KEY_ID: "A",
AWS_SECRET_ACCESS_KEY: "s",
} as NodeJS.ProcessEnv),
).resolves.toBe(true);
});
it("detects profile", async () => {
await expect(hasAwsCredentials({ AWS_PROFILE: "default" } as NodeJS.ProcessEnv)).resolves.toBe(
true,
);
});
it("detects ECS task role", async () => {
await expect(
hasAwsCredentials({ AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: "/v2" } as NodeJS.ProcessEnv),
).resolves.toBe(true);
});
it("detects EKS IRSA", async () => {
await expect(
hasAwsCredentials({
AWS_WEB_IDENTITY_TOKEN_FILE: "/var/run/secrets/token",
AWS_ROLE_ARN: "arn:aws:iam::123:role/x",
} as NodeJS.ProcessEnv),
).resolves.toBe(true);
});
it("detects credentials via the AWS SDK default provider chain", async () => {
resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(true);
expect(defaultProviderMock).toHaveBeenCalledWith({ timeout: 1000, maxRetries: 0 });
});
it("returns false with no creds", async () => {
resolveCredentialsMock.mockRejectedValue(new Error("no aws credentials"));
await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(false);
});
// --- Titan V2 ---
it("embeds with Titan V2", async () => {
sendMock.mockResolvedValue(enc({ embedding: [0.1, 0.2, 0.3] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(await provider.embedQuery("test")).toHaveLength(3);
expect(reqBody()).toMatchObject({ inputText: "test", normalize: true, dimensions: 1024 });
});
it("returns empty for blank text", async () => {
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(await provider.embedQuery(" ")).toEqual([]);
expect(sendMock).not.toHaveBeenCalled();
});
it("batches Titan V2 concurrently", async () => {
sendMock
.mockResolvedValueOnce(enc({ embedding: [0.1] }))
.mockResolvedValueOnce(enc({ embedding: [0.2] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
expect(sendMock).toHaveBeenCalledTimes(2);
});
// --- Titan V1 ---
it("sends only inputText for Titan V1", async () => {
sendMock.mockResolvedValue(enc({ embedding: [0.5] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v1",
fallback: "none",
});
await provider.embedQuery("hi");
expect(reqBody()).toEqual({ inputText: "hi" });
});
it("handles Titan G1 text variant", async () => {
sendMock.mockResolvedValue(enc({ embedding: [0.1] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-g1-text-02",
fallback: "none",
});
await provider.embedQuery("hi");
expect(reqBody()).toEqual({ inputText: "hi" });
});
// --- Cohere V3 ---
it("embeds Cohere V3 batch in single call", async () => {
sendMock.mockResolvedValue(enc({ embeddings: [[0.1], [0.2]] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "cohere.embed-english-v3",
fallback: "none",
});
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
expect(sendMock).toHaveBeenCalledTimes(1);
expect(reqBody()).toMatchObject({ texts: ["a", "b"], input_type: "search_document" });
});
it("uses search_query for Cohere embedQuery", async () => {
sendMock.mockResolvedValue(enc({ embeddings: [[0.1]] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "cohere.embed-english-v3",
fallback: "none",
});
await provider.embedQuery("q");
expect(reqBody().input_type).toBe("search_query");
});
// --- Cohere V4 ---
it("embeds Cohere V4 with embedding_types + output_dimension", async () => {
sendMock.mockResolvedValue(enc({ embeddings: { float: [[0.1], [0.2]] } }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "cohere.embed-v4:0",
fallback: "none",
});
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
expect(reqBody()).toMatchObject({ embedding_types: ["float"], output_dimension: 1536 });
});
it("validates Cohere V4 dimensions", () => {
expect(() =>
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "cohere.embed-v4:0",
fallback: "none",
outputDimensionality: 2048,
}),
).toThrow("Invalid dimensions 2048");
});
// --- Nova ---
it("embeds Nova with SINGLE_EMBEDDING format", async () => {
sendMock.mockResolvedValue(
enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1, 0.2] }] }),
);
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.nova-2-multimodal-embeddings-v1:0",
fallback: "none",
});
expect(await provider.embedQuery("hi")).toHaveLength(2);
expect(reqBody().taskType).toBe("SINGLE_EMBEDDING");
});
it("validates Nova dimensions", () => {
expect(() =>
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.nova-2-multimodal-embeddings-v1:0",
fallback: "none",
outputDimensionality: 512,
}),
).toThrow("Invalid dimensions 512");
});
it("batches Nova concurrently", async () => {
sendMock
.mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1] }] }))
.mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.2] }] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.nova-2-multimodal-embeddings-v1:0",
fallback: "none",
});
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
expect(sendMock).toHaveBeenCalledTimes(2);
});
// --- TwelveLabs ---
it("embeds TwelveLabs Marengo", async () => {
sendMock.mockResolvedValue(enc({ data: [{ embedding: [0.1, 0.2] }] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "twelvelabs.marengo-embed-3-0-v1:0",
fallback: "none",
});
expect(await provider.embedQuery("hi")).toHaveLength(2);
expect(reqBody()).toEqual({ inputType: "text", text: { inputText: "hi" } });
});
it("embeds TwelveLabs object-style responses", async () => {
sendMock.mockResolvedValue(enc({ data: { embedding: [0.3, 0.4] } }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "twelvelabs.marengo-embed-2-7-v1:0",
fallback: "none",
});
expect(await provider.embedQuery("hi")).toEqual([0.6, 0.8]);
});
});

View File

@@ -1,398 +0,0 @@
import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
// ---------------------------------------------------------------------------
// Types & constants
// ---------------------------------------------------------------------------
export type BedrockEmbeddingClient = {
region: string;
model: string;
dimensions?: number;
};
export const DEFAULT_BEDROCK_EMBEDDING_MODEL = "amazon.titan-embed-text-v2:0";
/** Request/response format family — each has a different API shape. */
type Family = "titan-v1" | "titan-v2" | "cohere-v3" | "cohere-v4" | "nova" | "twelvelabs";
interface ModelSpec {
maxTokens: number;
dims: number;
validDims?: number[];
family: Family;
}
// ---------------------------------------------------------------------------
// Model catalog
// ---------------------------------------------------------------------------
const MODELS: Record<string, ModelSpec> = {
"amazon.titan-embed-text-v2:0": {
maxTokens: 8192,
dims: 1024,
validDims: [256, 512, 1024],
family: "titan-v2",
},
"amazon.titan-embed-text-v1": { maxTokens: 8000, dims: 1536, family: "titan-v1" },
"amazon.titan-embed-g1-text-02": { maxTokens: 8000, dims: 1536, family: "titan-v1" },
"amazon.titan-embed-image-v1": { maxTokens: 128, dims: 1024, family: "titan-v1" },
"cohere.embed-english-v3": { maxTokens: 512, dims: 1024, family: "cohere-v3" },
"cohere.embed-multilingual-v3": { maxTokens: 512, dims: 1024, family: "cohere-v3" },
"cohere.embed-v4:0": {
maxTokens: 128000,
dims: 1536,
validDims: [256, 384, 512, 768, 1024, 1536],
family: "cohere-v4",
},
"amazon.nova-2-multimodal-embeddings-v1:0": {
maxTokens: 8192,
dims: 1024,
validDims: [256, 384, 1024, 3072],
family: "nova",
},
"twelvelabs.marengo-embed-2-7-v1:0": { maxTokens: 512, dims: 1024, family: "twelvelabs" },
"twelvelabs.marengo-embed-3-0-v1:0": { maxTokens: 512, dims: 512, family: "twelvelabs" },
};
/** Resolve spec, stripping throughput suffixes like `:2:8k` or `:0:512`. */
function resolveSpec(modelId: string): ModelSpec | undefined {
if (MODELS[modelId]) {
return MODELS[modelId];
}
const parts = modelId.split(":");
for (let i = parts.length - 1; i >= 1; i--) {
const spec = MODELS[parts.slice(0, i).join(":")];
if (spec) {
return spec;
}
}
return undefined;
}
/** Infer family from model ID prefix when not in catalog. */
function inferFamily(modelId: string): Family {
const id = normalizeLowercaseStringOrEmpty(modelId);
if (id.startsWith("amazon.titan-embed-text-v2")) {
return "titan-v2";
}
if (id.startsWith("amazon.titan-embed")) {
return "titan-v1";
}
if (id.startsWith("amazon.nova")) {
return "nova";
}
if (id.startsWith("cohere.embed-v4")) {
return "cohere-v4";
}
if (id.startsWith("cohere.embed")) {
return "cohere-v3";
}
if (id.startsWith("twelvelabs.")) {
return "twelvelabs";
}
return "titan-v1"; // safest default — simplest request format
}
// ---------------------------------------------------------------------------
// AWS SDK lazy loader
// ---------------------------------------------------------------------------
type SdkClient = import("@aws-sdk/client-bedrock-runtime").BedrockRuntimeClient;
type SdkCommand = import("@aws-sdk/client-bedrock-runtime").InvokeModelCommand;
interface AwsSdk {
BedrockRuntimeClient: new (config: { region: string }) => SdkClient;
InvokeModelCommand: new (input: {
modelId: string;
body: string;
contentType: string;
accept: string;
}) => SdkCommand;
}
interface AwsCredentialProviderSdk {
defaultProvider: (init?: { timeout?: number; maxRetries?: number }) => () => Promise<{
accessKeyId?: string;
}>;
}
let sdkCache: AwsSdk | null = null;
let credentialProviderSdkCache: AwsCredentialProviderSdk | null | undefined;
async function loadSdk(): Promise<AwsSdk> {
if (sdkCache) {
return sdkCache;
}
try {
sdkCache = (await import("@aws-sdk/client-bedrock-runtime")) as unknown as AwsSdk;
return sdkCache;
} catch {
throw new Error(
"No API key found for provider bedrock: @aws-sdk/client-bedrock-runtime is not installed. " +
"Install it with: npm install @aws-sdk/client-bedrock-runtime",
);
}
}
async function loadCredentialProviderSdk(): Promise<AwsCredentialProviderSdk | null> {
if (credentialProviderSdkCache !== undefined) {
return credentialProviderSdkCache;
}
try {
credentialProviderSdkCache =
(await import("@aws-sdk/credential-provider-node")) as unknown as AwsCredentialProviderSdk;
} catch {
credentialProviderSdkCache = null;
}
return credentialProviderSdkCache;
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
const MODEL_PREFIX_RE = /^(?:bedrock|amazon-bedrock|aws)\//;
const REGION_RE = /bedrock-runtime\.([a-z0-9-]+)\./;
export function normalizeBedrockEmbeddingModel(model: string): string {
const trimmed = model.trim();
return trimmed ? trimmed.replace(MODEL_PREFIX_RE, "") : DEFAULT_BEDROCK_EMBEDDING_MODEL;
}
function regionFromUrl(url: string | undefined): string | undefined {
return url?.trim() ? REGION_RE.exec(url)?.[1] : undefined;
}
// ---------------------------------------------------------------------------
// Request builders
// ---------------------------------------------------------------------------
function buildBody(family: Family, text: string, dims?: number): string {
switch (family) {
case "titan-v2": {
const b: Record<string, unknown> = { inputText: text };
if (dims != null) {
b.dimensions = dims;
b.normalize = true;
}
return JSON.stringify(b);
}
case "titan-v1":
return JSON.stringify({ inputText: text });
case "nova":
return JSON.stringify({
taskType: "SINGLE_EMBEDDING",
singleEmbeddingParams: {
embeddingPurpose: "GENERIC_INDEX",
embeddingDimension: dims ?? 1024,
text: { truncationMode: "END", value: text },
},
});
case "twelvelabs":
return JSON.stringify({ inputType: "text", text: { inputText: text } });
default:
return JSON.stringify({ inputText: text });
}
}
function buildCohereBody(
family: Family,
texts: string[],
inputType: "search_query" | "search_document",
dims?: number,
): string {
const body: Record<string, unknown> = { texts, input_type: inputType, truncate: "END" };
if (family === "cohere-v4") {
body.embedding_types = ["float"];
if (dims != null) {
body.output_dimension = dims;
}
}
return JSON.stringify(body);
}
// ---------------------------------------------------------------------------
// Response parsers
// ---------------------------------------------------------------------------
function parseSingle(family: Family, raw: string): number[] {
const data = JSON.parse(raw);
switch (family) {
case "nova":
return data.embeddings?.[0]?.embedding ?? [];
case "twelvelabs": {
if (Array.isArray(data.data)) {
return data.data[0]?.embedding ?? [];
}
if (Array.isArray(data.data?.embedding)) {
return data.data.embedding;
}
return data.embedding ?? [];
}
default:
return data.embedding ?? [];
}
}
function parseCohereBatch(family: Family, raw: string): number[][] {
const data = JSON.parse(raw);
const embeddings = data.embeddings;
if (!embeddings) {
return [];
}
if (family === "cohere-v4" && !Array.isArray(embeddings)) {
return embeddings.float ?? [];
}
return embeddings;
}
// ---------------------------------------------------------------------------
// Provider
// ---------------------------------------------------------------------------
export async function createBedrockEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> {
const client = resolveBedrockEmbeddingClient(options);
const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk();
const sdk = new BedrockRuntimeClient({ region: client.region });
const spec = resolveSpec(client.model);
const family = spec?.family ?? inferFamily(client.model);
debugEmbeddingsLog("memory embeddings: bedrock client", {
region: client.region,
model: client.model,
dimensions: client.dimensions,
family,
});
const invoke = async (body: string): Promise<string> => {
const res = await sdk.send(
new InvokeModelCommand({
modelId: client.model,
body,
contentType: "application/json",
accept: "application/json",
}),
);
return new TextDecoder().decode(res.body);
};
const isCohere = family === "cohere-v3" || family === "cohere-v4";
const embedSingle = async (text: string): Promise<number[]> => {
const raw = await invoke(buildBody(family, text, client.dimensions));
return sanitizeAndNormalizeEmbedding(parseSingle(family, raw));
};
const embedCohere = async (
texts: string[],
inputType: "search_query" | "search_document",
): Promise<number[][]> => {
const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions));
return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e));
};
const embedQuery = async (text: string): Promise<number[]> => {
if (!text.trim()) {
return [];
}
if (isCohere) {
return (await embedCohere([text], "search_query"))[0] ?? [];
}
return embedSingle(text);
};
const embedBatch = async (texts: string[]): Promise<number[][]> => {
if (texts.length === 0) {
return [];
}
if (isCohere) {
return embedCohere(texts, "search_document");
}
return Promise.all(texts.map((t) => (t.trim() ? embedSingle(t) : Promise.resolve([]))));
};
return {
provider: {
id: "bedrock",
model: client.model,
maxInputTokens: spec?.maxTokens,
embedQuery,
embedBatch,
},
client,
};
}
// ---------------------------------------------------------------------------
// Client resolution
// ---------------------------------------------------------------------------
export function resolveBedrockEmbeddingClient(
options: EmbeddingProviderOptions,
): BedrockEmbeddingClient {
const model = normalizeBedrockEmbeddingModel(options.model);
const spec = resolveSpec(model);
const providerConfig = options.config.models?.providers?.["amazon-bedrock"];
const region =
regionFromUrl(options.remote?.baseUrl) ??
regionFromUrl(providerConfig?.baseUrl) ??
process.env.AWS_REGION ??
process.env.AWS_DEFAULT_REGION ??
"us-east-1";
let dimensions: number | undefined;
if (options.outputDimensionality != null) {
if (spec?.validDims && !spec.validDims.includes(options.outputDimensionality)) {
throw new Error(
`Invalid dimensions ${options.outputDimensionality} for ${model}. Valid values: ${spec.validDims.join(", ")}`,
);
}
dimensions = options.outputDimensionality;
} else {
dimensions = spec?.dims;
}
return { region, model, dimensions };
}
// ---------------------------------------------------------------------------
// Credential detection
// ---------------------------------------------------------------------------
const CREDENTIAL_ENV_VARS = [
"AWS_PROFILE",
"AWS_BEARER_TOKEN_BEDROCK",
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
"AWS_CONTAINER_CREDENTIALS_FULL_URI",
"AWS_EC2_METADATA_SERVICE_ENDPOINT",
"AWS_WEB_IDENTITY_TOKEN_FILE",
"AWS_ROLE_ARN",
] as const;
export async function hasAwsCredentials(env: NodeJS.ProcessEnv = process.env): Promise<boolean> {
if (env.AWS_ACCESS_KEY_ID?.trim() && env.AWS_SECRET_ACCESS_KEY?.trim()) {
return true;
}
if (CREDENTIAL_ENV_VARS.some((k) => env[k]?.trim())) {
return true;
}
const credentialProviderSdk = await loadCredentialProviderSdk();
if (!credentialProviderSdk) {
return false;
}
try {
const credentials = await credentialProviderSdk.defaultProvider({
timeout: 1000,
maxRetries: 0,
})();
return typeof credentials.accessKeyId === "string" && credentials.accessKeyId.trim().length > 0;
} catch {
return false;
}
}
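The deleted module's call pattern, for reference, in the placeholder style the deleted tests used; real callers supply a resolved config and need the optional AWS SDK packages installed:

import { createBedrockEmbeddingProvider, hasAwsCredentials } from "./embeddings-bedrock.js";

// Usage sketch for the surface shown above (assumes the plugin-side successor
// keeps the same exports).
async function bedrockEmbedSketch(): Promise<number[]> {
  if (!(await hasAwsCredentials())) {
    throw new Error("no AWS credentials detected");
  }
  const { provider, client } = await createBedrockEmbeddingProvider({
    config: {} as never, // placeholder; real callers pass resolved agent config
    provider: "bedrock",
    model: "amazon.titan-embed-text-v2:0",
    fallback: "none",
  });
  console.log(client.region, client.dimensions);
  return await provider.embedQuery("hello world");
}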

View File

@@ -1,121 +0,0 @@
import type { EmbeddingInput } from "./embedding-inputs.js";
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
export const GEMINI_EMBEDDING_2_MODELS = new Set([
"gemini-embedding-2-preview",
// Add the GA model name here once released.
]);
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
export type GeminiTaskType =
| "RETRIEVAL_QUERY"
| "RETRIEVAL_DOCUMENT"
| "SEMANTIC_SIMILARITY"
| "CLASSIFICATION"
| "CLUSTERING"
| "QUESTION_ANSWERING"
| "FACT_VERIFICATION";
export type GeminiTextPart = { text: string };
export type GeminiInlinePart = {
inlineData: { mimeType: string; data: string };
};
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
export type GeminiEmbeddingRequest = {
content: { parts: GeminiPart[] };
taskType: GeminiTaskType;
outputDimensionality?: number;
model?: string;
};
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
export function buildGeminiTextEmbeddingRequest(params: {
text: string;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiTextEmbeddingRequest {
return buildGeminiEmbeddingRequest({
input: { text: params.text },
taskType: params.taskType,
outputDimensionality: params.outputDimensionality,
modelPath: params.modelPath,
});
}
export function buildGeminiEmbeddingRequest(params: {
input: EmbeddingInput;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiEmbeddingRequest {
const request: GeminiEmbeddingRequest = {
content: {
parts: params.input.parts?.map((part) =>
part.type === "text"
? ({ text: part.text } satisfies GeminiTextPart)
: ({
inlineData: { mimeType: part.mimeType, data: part.data },
} satisfies GeminiInlinePart),
) ?? [{ text: params.input.text }],
},
taskType: params.taskType,
};
if (params.modelPath) {
request.model = params.modelPath;
}
if (params.outputDimensionality != null) {
request.outputDimensionality = params.outputDimensionality;
}
return request;
}
/**
* Returns true if the given model name is a gemini-embedding-2 variant that
* supports `outputDimensionality` and extended task types.
*/
export function isGeminiEmbedding2Model(model: string): boolean {
return GEMINI_EMBEDDING_2_MODELS.has(model);
}
/**
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
* Returns `undefined` for older models (they don't support the param).
*/
export function resolveGeminiOutputDimensionality(
model: string,
requested?: number,
): number | undefined {
if (!isGeminiEmbedding2Model(model)) {
return undefined;
}
if (requested == null) {
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
}
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
if (!valid.includes(requested)) {
throw new Error(
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
);
}
return requested;
}
export function normalizeGeminiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) {
return DEFAULT_GEMINI_EMBEDDING_MODEL;
}
const withoutPrefix = trimmed.replace(/^models\//, "");
if (withoutPrefix.startsWith("gemini/")) {
return withoutPrefix.slice("gemini/".length);
}
if (withoutPrefix.startsWith("google/")) {
return withoutPrefix.slice("google/".length);
}
return withoutPrefix;
}

View File

@@ -1,52 +0,0 @@
import { describe, expect, it } from "vitest";
import {
buildGeminiEmbeddingRequest,
DEFAULT_GEMINI_EMBEDDING_MODEL,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
} from "./embeddings-gemini-request.js";
describe("package Gemini embedding request helpers", () => {
it("builds multimodal v2 requests and resolves model settings", () => {
expect(
buildGeminiEmbeddingRequest({
input: {
text: "Image file: diagram.png",
parts: [
{ type: "text", text: "Image file: diagram.png" },
{ type: "inline-data", mimeType: "image/png", data: "abc123" },
],
},
taskType: "RETRIEVAL_DOCUMENT",
modelPath: "models/gemini-embedding-2-preview",
outputDimensionality: 1536,
}),
).toEqual({
model: "models/gemini-embedding-2-preview",
content: {
parts: [
{ text: "Image file: diagram.png" },
{ inlineData: { mimeType: "image/png", data: "abc123" } },
],
},
taskType: "RETRIEVAL_DOCUMENT",
outputDimensionality: 1536,
});
expect(resolveGeminiOutputDimensionality("gemini-embedding-001")).toBeUndefined();
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview")).toBe(3072);
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768)).toBe(768);
expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512)).toThrow(
/Invalid outputDimensionality 512/,
);
expect(normalizeGeminiModel("models/gemini-embedding-2-preview")).toBe(
"gemini-embedding-2-preview",
);
expect(normalizeGeminiModel("gemini/gemini-embedding-2-preview")).toBe(
"gemini-embedding-2-preview",
);
expect(normalizeGeminiModel("google/gemini-embedding-2-preview")).toBe(
"gemini-embedding-2-preview",
);
expect(normalizeGeminiModel("")).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL);
});
});

View File

@@ -1,238 +0,0 @@
import {
collectProviderApiKeysForExecution,
executeWithApiKeyRotation,
} from "../../../../src/agents/api-key-rotation.js";
import { requireApiKey, resolveApiKeyForProvider } from "../../../../src/agents/model-auth.js";
import { parseGeminiAuth } from "../../../../src/infra/gemini-auth.js";
import {
DEFAULT_GOOGLE_API_BASE_URL,
normalizeGoogleApiBaseUrl,
} from "../../../../src/infra/google-api-base-url.js";
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
import type { EmbeddingInput } from "./embedding-inputs.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import {
buildGeminiEmbeddingRequest,
buildGeminiTextEmbeddingRequest,
isGeminiEmbedding2Model,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
} from "./embeddings-gemini-request.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";
export {
buildGeminiEmbeddingRequest,
buildGeminiTextEmbeddingRequest,
DEFAULT_GEMINI_EMBEDDING_MODEL,
GEMINI_EMBEDDING_2_MODELS,
isGeminiEmbedding2Model,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
type GeminiEmbeddingRequest,
type GeminiInlinePart,
type GeminiPart,
type GeminiTaskType,
type GeminiTextEmbeddingRequest,
type GeminiTextPart,
} from "./embeddings-gemini-request.js";
export type GeminiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
modelPath: string;
apiKeys: string[];
outputDimensionality?: number;
};
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
"text-embedding-004": 2048,
};
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
const trimmed = resolveMemorySecretInputString({
value: remoteApiKey,
path: "agents.*.memorySearch.remote.apiKey",
});
if (!trimmed) {
return undefined;
}
if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") {
return process.env[trimmed]?.trim();
}
return trimmed;
}
async function fetchGeminiEmbeddingPayload(params: {
client: GeminiEmbeddingClient;
endpoint: string;
body: unknown;
}): Promise<{
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
}> {
return await executeWithApiKeyRotation({
provider: "google",
apiKeys: params.client.apiKeys,
execute: async (apiKey) => {
const authHeaders = parseGeminiAuth(apiKey);
const headers = {
...authHeaders.headers,
...params.client.headers,
};
return await withRemoteHttpResponse({
url: params.endpoint,
ssrfPolicy: params.client.ssrfPolicy,
init: {
method: "POST",
headers,
body: JSON.stringify(params.body),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
}
return (await res.json()) as {
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
};
},
});
},
});
}
function normalizeGeminiBaseUrl(raw: string): string {
const trimmed = raw.replace(/\/+$/, "");
const openAiIndex = trimmed.indexOf("/openai");
if (openAiIndex > -1) {
return normalizeGoogleApiBaseUrl(trimmed.slice(0, openAiIndex));
}
return normalizeGoogleApiBaseUrl(trimmed);
}
function buildGeminiModelPath(model: string): string {
return model.startsWith("models/") ? model : `models/${model}`;
}
export async function createGeminiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
const client = await resolveGeminiEmbeddingClient(options);
const baseUrl = client.baseUrl.replace(/\/$/, "");
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
const isV2 = isGeminiEmbedding2Model(client.model);
const outputDimensionality = client.outputDimensionality;
const embedQuery = async (text: string): Promise<number[]> => {
if (!text.trim()) {
return [];
}
const payload = await fetchGeminiEmbeddingPayload({
client,
endpoint: embedUrl,
body: buildGeminiTextEmbeddingRequest({
text,
taskType: options.taskType ?? "RETRIEVAL_QUERY",
outputDimensionality: isV2 ? outputDimensionality : undefined,
}),
});
return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
};
const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
if (inputs.length === 0) {
return [];
}
const payload = await fetchGeminiEmbeddingPayload({
client,
endpoint: batchUrl,
body: {
requests: inputs.map((input) =>
buildGeminiEmbeddingRequest({
input,
modelPath: client.modelPath,
taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
outputDimensionality: isV2 ? outputDimensionality : undefined,
}),
),
},
});
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
};
const embedBatch = async (texts: string[]): Promise<number[][]> => {
return await embedBatchInputs(
texts.map((text) => ({
text,
})),
);
};
return {
provider: {
id: "gemini",
model: client.model,
maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
embedQuery,
embedBatch,
embedBatchInputs,
},
client,
};
}
export async function resolveGeminiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
const remote = options.remote;
const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
const remoteBaseUrl = remote?.baseUrl?.trim();
const apiKey = remoteApiKey
? remoteApiKey
: requireApiKey(
await resolveApiKeyForProvider({
provider: "google",
cfg: options.config,
agentDir: options.agentDir,
}),
"google",
);
const providerConfig = options.config.models?.providers?.google;
const rawBaseUrl =
remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GOOGLE_API_BASE_URL;
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
const headers: Record<string, string> = {
...headerOverrides,
};
const apiKeys = collectProviderApiKeysForExecution({
provider: "google",
primaryApiKey: apiKey,
});
const model = normalizeGeminiModel(options.model);
const modelPath = buildGeminiModelPath(model);
const outputDimensionality = resolveGeminiOutputDimensionality(
model,
options.outputDimensionality,
);
debugEmbeddingsLog("memory embeddings: gemini client", {
rawBaseUrl,
baseUrl,
model,
modelPath,
outputDimensionality,
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
});
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
}

View File

@@ -1 +0,0 @@
export * from "../../../../src/memory-host-sdk/host/embeddings-lmstudio.js";

View File

@@ -1,19 +0,0 @@
import { describe, expect, it } from "vitest";
import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js";
describe("normalizeMistralModel", () => {
it("returns the default model for empty values", () => {
expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
});
it("strips the mistral/ prefix", () => {
expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed");
expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed");
});
it("keeps explicit non-prefixed models", () => {
expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed");
expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2");
});
});

View File

@@ -1,51 +0,0 @@
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import {
createRemoteEmbeddingProvider,
resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
export type MistralEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
};
export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed";
const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1";
export function normalizeMistralModel(model: string): string {
return normalizeEmbeddingModelWithPrefixes({
model,
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
prefixes: ["mistral/"],
});
}
export async function createMistralEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
const client = await resolveMistralEmbeddingClient(options);
return {
provider: createRemoteEmbeddingProvider({
id: "mistral",
client,
errorPrefix: "mistral embeddings failed",
}),
client,
};
}
export async function resolveMistralEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<MistralEmbeddingClient> {
return await resolveRemoteEmbeddingClient({
provider: "mistral",
options,
defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL,
normalizeModel: normalizeMistralModel,
});
}

View File

@@ -1,43 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({
createOllamaEmbeddingProviderMock: vi.fn(async (options: unknown) => ({
provider: { source: "mock-provider", options },
client: { source: "mock-client" },
})),
}));
vi.mock("../../../../src/plugin-sdk/ollama-runtime.js", () => ({
DEFAULT_OLLAMA_EMBEDDING_MODEL: "nomic-embed-text",
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
}));
describe("memory-host-sdk Ollama embedding facade", () => {
beforeEach(() => {
createOllamaEmbeddingProviderMock.mockClear();
});
it("re-exports the default Ollama embedding model", async () => {
const mod = await import("./embeddings-ollama.js");
expect(mod.DEFAULT_OLLAMA_EMBEDDING_MODEL).toBe("nomic-embed-text");
});
it("delegates provider creation to the plugin-sdk runtime facade", async () => {
const mod = await import("./embeddings-ollama.js");
const options = {
provider: "ollama",
model: "nomic-embed-text",
fallback: "none",
config: {},
};
const result = await mod.createOllamaEmbeddingProvider(options as never);
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledTimes(1);
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledWith(options);
expect(result).toEqual({
provider: { source: "mock-provider", options },
client: { source: "mock-client" },
});
});
});

View File

@@ -1,5 +0,0 @@
export type { OllamaEmbeddingClient } from "../../../../src/plugin-sdk/ollama-runtime.js";
export {
createOllamaEmbeddingProvider,
DEFAULT_OLLAMA_EMBEDDING_MODEL,
} from "../../../../src/plugin-sdk/ollama-runtime.js";

View File

@@ -1,58 +0,0 @@
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../../../src/plugins/provider-model-defaults.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import {
createRemoteEmbeddingProvider,
resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
export type OpenAiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
};
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
export const DEFAULT_OPENAI_EMBEDDING_MODEL = OPENAI_DEFAULT_EMBEDDING_MODEL;
const OPENAI_MAX_INPUT_TOKENS: Record<string, number> = {
"text-embedding-3-small": 8192,
"text-embedding-3-large": 8192,
"text-embedding-ada-002": 8191,
};
export function normalizeOpenAiModel(model: string): string {
return normalizeEmbeddingModelWithPrefixes({
model,
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
prefixes: ["openai/"],
});
}
export async function createOpenAiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
const client = await resolveOpenAiEmbeddingClient(options);
return {
provider: createRemoteEmbeddingProvider({
id: "openai",
client,
errorPrefix: "openai embeddings failed",
maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model],
}),
client,
};
}
export async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> {
return await resolveRemoteEmbeddingClient({
provider: "openai",
options,
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
normalizeModel: normalizeOpenAiModel,
});
}

View File

@@ -4,7 +4,7 @@ import type { EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";
export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral";
export type RemoteEmbeddingProviderId = string;
export async function resolveRemoteEmbeddingBearerClient(params: {
provider: RemoteEmbeddingProviderId;

View File

@@ -1,188 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../../../src/agents/model-auth.js";
import { type FetchMock, withFetchPreconnect } from "../../../../src/test-utils/fetch-mock.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
resolveApiKeyForProviderMock: vi.fn(),
}));
vi.mock("../../../../src/agents/model-auth.js", () => {
return {
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
if (auth.apiKey) {
return auth.apiKey;
}
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
},
};
});
const createFetchMock = () => {
const fetchMock = vi.fn<FetchMock>(
async (_input: RequestInfo | URL, _init?: RequestInit) =>
new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), {
status: 200,
headers: { "Content-Type": "application/json" },
}),
);
return withFetchPreconnect(fetchMock);
};
function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}
let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
beforeAll(async () => {
({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
await import("./embeddings-voyage.js"));
});
beforeEach(() => {
vi.useRealTimers();
vi.doUnmock("undici");
});
function mockVoyageApiKey() {
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
apiKey: "voyage-key-123",
mode: "api-key",
source: "test",
});
}
async function createDefaultVoyageProvider(
model: string,
fetchMock: ReturnType<typeof createFetchMock>,
) {
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockVoyageApiKey();
return createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
model,
fallback: "none",
});
}
describe("voyage embedding provider", () => {
afterEach(() => {
vi.doUnmock("undici");
vi.resetAllMocks();
vi.unstubAllGlobals();
});
it("configures client with correct defaults and headers", async () => {
const fetchMock = createFetchMock();
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
await result.provider.embedQuery("test query");
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith(
expect.objectContaining({ provider: "voyage" }),
);
const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://api.voyageai.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer voyage-key-123");
expect(headers["Content-Type"]).toBe("application/json");
const body = JSON.parse(init?.body as string);
expect(body).toEqual({
model: "voyage-4-large",
input: ["test query"],
input_type: "query",
});
});
it("respects remote overrides for baseUrl and apiKey", async () => {
const fetchMock = createFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
const result = await createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
model: "voyage-4-lite",
fallback: "none",
remote: {
baseUrl: "https://example.com",
apiKey: "remote-override-key",
headers: { "X-Custom": "123" },
},
});
await result.provider.embedQuery("test");
const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://example.com/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-override-key");
expect(headers["X-Custom"]).toBe("123");
});
it("passes input_type=document for embedBatch", async () => {
const fetchMock = withFetchPreconnect(
vi.fn<FetchMock>(
async (_input: RequestInfo | URL, _init?: RequestInit) =>
new Response(
JSON.stringify({
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
}),
{ status: 200, headers: { "Content-Type": "application/json" } },
),
),
);
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
await result.provider.embedBatch(["doc1", "doc2"]);
const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [, init] = call as [RequestInfo | URL, RequestInit | undefined];
const body = JSON.parse(init?.body as string);
expect(body).toEqual({
model: "voyage-4-large",
input: ["doc1", "doc2"],
input_type: "document",
});
});
it("normalizes model names", async () => {
expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2");
expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large");
expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite");
expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default
});
});

View File

@@ -1,82 +0,0 @@
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
export type VoyageEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
};
export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large";
const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1";
const VOYAGE_MAX_INPUT_TOKENS: Record<string, number> = {
"voyage-3": 32000,
"voyage-3-lite": 16000,
"voyage-code-3": 32000,
};
export function normalizeVoyageModel(model: string): string {
return normalizeEmbeddingModelWithPrefixes({
model,
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
prefixes: ["voyage/"],
});
}
export async function createVoyageEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
const client = await resolveVoyageEmbeddingClient(options);
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => {
if (input.length === 0) {
return [];
}
const body: { model: string; input: string[]; input_type?: "query" | "document" } = {
model: client.model,
input,
};
if (input_type) {
body.input_type = input_type;
}
return await fetchRemoteEmbeddingVectors({
url,
headers: client.headers,
ssrfPolicy: client.ssrfPolicy,
body,
errorPrefix: "voyage embeddings failed",
});
};
return {
provider: {
id: "voyage",
model: client.model,
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
embedQuery: async (text) => {
const [vec] = await embed([text], "query");
return vec ?? [];
},
embedBatch: async (texts) => embed(texts, "document"),
},
client,
};
}
export async function resolveVoyageEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<VoyageEmbeddingClient> {
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
provider: "voyage",
options,
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
});
const model = normalizeVoyageModel(options.model);
return { baseUrl, headers, ssrfPolicy, model };
}

View File

@@ -1,199 +1,8 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../../../src/agents/model-auth.js";
import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
import * as nodeLlamaModule from "./node-llama.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
import { describe, expect, it } from "vitest";
import { DEFAULT_LOCAL_MODEL } from "./embeddings.js";
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
resolveApiKeyForProviderMock: vi.fn(),
}));
vi.mock("../../../../src/agents/model-auth.js", () => {
return {
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
if (auth.apiKey) {
return auth.apiKey;
}
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
},
};
});
vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));
const createEmbeddingDataFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
}));
const createGeminiFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ embedding: { values: [1, 2, 3] } }),
}));
beforeEach(() => {
vi.spyOn(authModule, "resolveApiKeyForProvider");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
});
afterEach(() => {
vi.resetAllMocks();
vi.unstubAllGlobals();
});
function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
const [url, init] = fetchMock.mock.calls[0] ?? [];
return { url, init: init as RequestInit | undefined };
}
function requireProvider(result: Awaited<ReturnType<typeof createEmbeddingProvider>>) {
if (!result.provider) {
throw new Error("Expected embedding provider");
}
return result.provider;
}
function mockResolvedProviderKey(apiKey = "provider-key") {
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
apiKey,
mode: "api-key",
source: "test",
});
}
describe("package embedding provider smoke", () => {
it("uses remote OpenAI baseUrl/apiKey and merges headers", async () => {
const fetchMock = createEmbeddingDataFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey("provider-key");
const result = await createEmbeddingProvider({
config: {
models: {
providers: {
openai: {
baseUrl: "https://api.openai.com/v1",
headers: { "X-Provider": "p", "X-Shared": "provider" },
},
},
},
} as never,
provider: "openai",
remote: {
baseUrl: "https://example.com/v1",
apiKey: " remote-key ",
headers: { "X-Shared": "remote", "X-Remote": "r" },
},
model: "text-embedding-3-small",
fallback: "openai",
});
await requireProvider(result).embedQuery("hello");
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
const { url, init } = readFirstFetchRequest(fetchMock);
expect(url).toBe("https://example.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-key");
expect(headers["X-Provider"]).toBe("p");
expect(headers["X-Shared"]).toBe("remote");
expect(headers["X-Remote"]).toBe("r");
});
it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => {
const fetchMock = createGeminiFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
vi.stubEnv("GEMINI_API_KEY", "env-gemini-key");
const result = await createEmbeddingProvider({
config: {} as never,
provider: "gemini",
remote: {
apiKey: "GEMINI_API_KEY", // pragma: allowlist secret
},
model: "text-embedding-004",
fallback: "openai",
});
await requireProvider(result).embedQuery("hello");
const { init } = readFirstFetchRequest(fetchMock);
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers["x-goog-api-key"]).toBe("env-gemini-key");
});
it("normalizes local embeddings and resolves the default local model", async () => {
const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf");
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
getLlama: async () => ({
loadModel: vi.fn().mockResolvedValue({
createEmbeddingContext: vi.fn().mockResolvedValue({
getEmbeddingFor: vi.fn().mockResolvedValue({
vector: new Float32Array([2.35, 3.45, 0.63, 4.3]),
}),
}),
}),
}),
resolveModelFile: resolveModelFileMock,
LlamaLogLevel: { error: 0 },
} as never);
const result = await createEmbeddingProvider({
config: {} as never,
provider: "local",
model: "",
fallback: "none",
});
const embedding = await requireProvider(result).embedQuery("test query");
const magnitude = Math.sqrt(embedding.reduce((sum, value) => sum + value * value, 0));
expect(magnitude).toBeCloseTo(1, 5);
expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
});
it("returns null provider when explicit primary and fallback auth paths fail", async () => {
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
new Error("No API key found for provider"),
);
const result = await createEmbeddingProvider({
config: {} as never,
provider: "openai",
model: "text-embedding-3-small",
fallback: "gemini",
});
expect(result.provider).toBeNull();
expect(result.requestedProvider).toBe("openai");
expect(result.fallbackFrom).toBe("openai");
expect(result.providerUnavailableReason).toContain("Fallback to gemini failed");
describe("package embeddings barrel", () => {
it("re-exports the source local embedding contract", () => {
expect(DEFAULT_LOCAL_MODEL).toContain("embeddinggemma");
});
});

View File

@@ -1,373 +1 @@
import fsSync from "node:fs";
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
import type { OpenClawConfig } from "../../../../src/config/config.js";
import type { SecretInput } from "../../../../src/config/types.secrets.js";
import { formatErrorMessage } from "../../../../src/infra/errors.js";
import { resolveUserPath } from "../../../../src/utils.js";
import type { EmbeddingInput } from "./embedding-inputs.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import {
createBedrockEmbeddingProvider,
hasAwsCredentials,
type BedrockEmbeddingClient,
} from "./embeddings-bedrock.js";
import {
createGeminiEmbeddingProvider,
type GeminiEmbeddingClient,
type GeminiTaskType,
} from "./embeddings-gemini.js";
import {
createLmstudioEmbeddingProvider,
type LmstudioEmbeddingClient,
} from "./embeddings-lmstudio.js";
import {
createMistralEmbeddingProvider,
type MistralEmbeddingClient,
} from "./embeddings-mistral.js";
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
import { importNodeLlamaCpp } from "./node-llama.js";
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
export type { LmstudioEmbeddingClient } from "./embeddings-lmstudio.js";
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
export type { BedrockEmbeddingClient } from "./embeddings-bedrock.js";
export type EmbeddingProvider = {
id: string;
model: string;
maxInputTokens?: number;
embedQuery: (text: string) => Promise<number[]>;
embedBatch: (texts: string[]) => Promise<number[][]>;
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
};
export type EmbeddingProviderId =
| "openai"
| "local"
| "gemini"
| "voyage"
| "mistral"
| "bedrock"
| "lmstudio"
| "ollama";
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
// Remote providers considered for auto-selection when provider === "auto".
// LM Studio and Ollama are intentionally excluded here so that "auto" mode does not
// implicitly assume either instance is available.
// Bedrock is handled separately when AWS credentials are detected.
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;
export type EmbeddingProviderResult = {
provider: EmbeddingProvider | null;
requestedProvider: EmbeddingProviderRequest;
fallbackFrom?: EmbeddingProviderId;
fallbackReason?: string;
providerUnavailableReason?: string;
openAi?: OpenAiEmbeddingClient;
gemini?: GeminiEmbeddingClient;
voyage?: VoyageEmbeddingClient;
mistral?: MistralEmbeddingClient;
bedrock?: BedrockEmbeddingClient;
lmstudio?: LmstudioEmbeddingClient;
ollama?: OllamaEmbeddingClient;
};
export type EmbeddingProviderOptions = {
config: OpenClawConfig;
agentDir?: string;
provider: EmbeddingProviderRequest;
remote?: {
baseUrl?: string;
apiKey?: SecretInput;
headers?: Record<string, string>;
};
model: string;
fallback: EmbeddingProviderFallback;
local?: {
modelPath?: string;
modelCacheDir?: string;
};
/** Provider-specific output vector dimensions for supported embedding families. */
outputDimensionality?: number;
/** Gemini: override the default task type sent with embedding requests. */
taskType?: GeminiTaskType;
};
export const DEFAULT_LOCAL_MODEL =
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";
function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
const modelPath = options.local?.modelPath?.trim();
if (!modelPath) {
return false;
}
if (/^(hf:|https?:)/i.test(modelPath)) {
return false;
}
const resolved = resolveUserPath(modelPath);
try {
return fsSync.statSync(resolved).isFile();
} catch {
return false;
}
}
function isMissingApiKeyError(err: unknown): boolean {
const message = formatErrorMessage(err);
return message.includes("No API key found for provider");
}
export async function createLocalEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProvider> {
const modelPath = options.local?.modelPath?.trim() || DEFAULT_LOCAL_MODEL;
const modelCacheDir = options.local?.modelCacheDir?.trim();
// Lazy-load node-llama-cpp to keep startup light unless local is enabled.
const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp();
let llama: Llama | null = null;
let embeddingModel: LlamaModel | null = null;
let embeddingContext: LlamaEmbeddingContext | null = null;
let initPromise: Promise<LlamaEmbeddingContext> | null = null;
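// Initialize the llama runtime, model, and embedding context once; concurrent callers
// share the in-flight promise, and a failed init clears it so the next call can retry.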
const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
if (embeddingContext) {
return embeddingContext;
}
if (initPromise) {
return initPromise;
}
initPromise = (async () => {
try {
if (!llama) {
llama = await getLlama({ logLevel: LlamaLogLevel.error });
}
if (!embeddingModel) {
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
embeddingModel = await llama.loadModel({ modelPath: resolved });
}
if (!embeddingContext) {
embeddingContext = await embeddingModel.createEmbeddingContext();
}
return embeddingContext;
} catch (err) {
initPromise = null;
throw err;
}
})();
return initPromise;
};
return {
id: "local",
model: modelPath,
embedQuery: async (text) => {
const ctx = await ensureContext();
const embedding = await ctx.getEmbeddingFor(text);
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
},
embedBatch: async (texts) => {
const ctx = await ensureContext();
const embeddings = await Promise.all(
texts.map(async (text) => {
const embedding = await ctx.getEmbeddingFor(text);
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
}),
);
return embeddings;
},
};
}
export async function createEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
const requestedProvider = options.provider;
const fallback = options.fallback;
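// Instantiate a provider by id and expose the resolved client under its
// provider-specific key on the result (openAi, gemini, voyage, ...).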
const createProvider = async (id: EmbeddingProviderId) => {
if (id === "local") {
const provider = await createLocalEmbeddingProvider(options);
return { provider };
}
if (id === "lmstudio") {
const { provider, client } = await createLmstudioEmbeddingProvider(options);
return { provider, lmstudio: client };
}
if (id === "ollama") {
const { provider, client } = await createOllamaEmbeddingProvider(options);
return { provider, ollama: client };
}
if (id === "gemini") {
const { provider, client } = await createGeminiEmbeddingProvider(options);
return { provider, gemini: client };
}
if (id === "voyage") {
const { provider, client } = await createVoyageEmbeddingProvider(options);
return { provider, voyage: client };
}
if (id === "mistral") {
const { provider, client } = await createMistralEmbeddingProvider(options);
return { provider, mistral: client };
}
if (id === "bedrock") {
const { provider, client } = await createBedrockEmbeddingProvider(options);
return { provider, bedrock: client };
}
const { provider, client } = await createOpenAiEmbeddingProvider(options);
return { provider, openAi: client };
};
const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);
if (requestedProvider === "auto") {
const missingKeyErrors: string[] = [];
let localError: string | null = null;
if (canAutoSelectLocal(options)) {
try {
const local = await createProvider("local");
return { ...local, requestedProvider };
} catch (err) {
localError = formatLocalSetupError(err);
}
}
for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
try {
const result = await createProvider(provider);
return { ...result, requestedProvider };
} catch (err) {
const message = formatPrimaryError(err, provider);
if (isMissingApiKeyError(err)) {
missingKeyErrors.push(message);
continue;
}
// Non-auth errors (e.g., network) are still fatal
const wrapped = new Error(message) as Error & { cause?: unknown };
wrapped.cause = err;
throw wrapped;
}
}
// Try bedrock if AWS credentials are available
if (await hasAwsCredentials()) {
try {
const result = await createProvider("bedrock");
return { ...result, requestedProvider };
} catch (err) {
const message = formatPrimaryError(err, "bedrock");
if (isMissingApiKeyError(err)) {
missingKeyErrors.push(message);
} else {
const wrapped = new Error(message) as Error & { cause?: unknown };
wrapped.cause = err;
throw wrapped;
}
}
}
// All providers failed due to missing API keys - return null provider for FTS-only mode
const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
return {
provider: null,
requestedProvider,
providerUnavailableReason: reason,
};
}
try {
const primary = await createProvider(requestedProvider);
return { ...primary, requestedProvider };
} catch (primaryErr) {
const reason = formatPrimaryError(primaryErr, requestedProvider);
if (fallback && fallback !== "none" && fallback !== requestedProvider) {
try {
const fallbackResult = await createProvider(fallback);
return {
...fallbackResult,
requestedProvider,
fallbackFrom: requestedProvider,
fallbackReason: reason,
};
} catch (fallbackErr) {
// Both primary and fallback failed - check if it's auth-related
const fallbackReason = formatErrorMessage(fallbackErr);
const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
// Both failed due to missing API keys - return null for FTS-only mode
return {
provider: null,
requestedProvider,
fallbackFrom: requestedProvider,
fallbackReason: reason,
providerUnavailableReason: combinedReason,
};
}
// Non-auth errors are still fatal
const wrapped = new Error(combinedReason) as Error & {
cause?: unknown;
};
wrapped.cause = fallbackErr;
throw wrapped;
}
}
// No fallback configured - check if we should degrade to FTS-only
if (isMissingApiKeyError(primaryErr)) {
return {
provider: null,
requestedProvider,
providerUnavailableReason: reason,
};
}
const wrapped = new Error(reason) as Error & { cause?: unknown };
wrapped.cause = primaryErr;
throw wrapped;
}
}
function isNodeLlamaCppMissing(err: unknown): boolean {
if (!(err instanceof Error)) {
return false;
}
const code = (err as Error & { code?: unknown }).code;
if (code === "ERR_MODULE_NOT_FOUND") {
return err.message.includes("node-llama-cpp");
}
return false;
}
function formatLocalSetupError(err: unknown): string {
const detail = formatErrorMessage(err);
const missing = isNodeLlamaCppMissing(err);
return [
"Local embeddings unavailable.",
missing
? "Reason: optional dependency node-llama-cpp is missing (or failed to install)."
: detail
? `Reason: ${detail}`
: undefined,
missing && detail ? `Detail: ${detail}` : null,
"To enable local embeddings:",
"1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)",
missing
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
: null,
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
...REMOTE_EMBEDDING_PROVIDER_IDS.map(
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
),
]
.filter(Boolean)
.join("\n");
}
export * from "../../../../src/memory-host-sdk/host/embeddings.js";

View File

@@ -100,21 +100,3 @@ export function classifyMemoryMultimodalPath(
}
return null;
}
export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
const trimmed = model.trim();
if (!trimmed) {
return "";
}
return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, "");
}
export function supportsMemoryMultimodalEmbeddings(params: {
provider: string;
model: string;
}): boolean {
if (params.provider !== "gemini") {
return false;
}
return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview";
}

pnpm-lock.yaml generated
View File

@@ -362,6 +362,12 @@ importers:
'@aws-sdk/client-bedrock':
specifier: 3.1028.0
version: 3.1028.0
'@aws-sdk/client-bedrock-runtime':
specifier: 3.1028.0
version: 3.1028.0
'@aws-sdk/credential-provider-node':
specifier: 3.972.30
version: 3.972.30
devDependencies:
'@openclaw/plugin-sdk':
specifier: workspace:*
@@ -1225,6 +1231,12 @@ importers:
specifier: workspace:*
version: link:../../packages/plugin-sdk
extensions/voyage:
devDependencies:
'@openclaw/plugin-sdk':
specifier: workspace:*
version: link:../../packages/plugin-sdk
extensions/vydra:
devDependencies:
'@openclaw/plugin-sdk':

View File

@@ -6,7 +6,6 @@ import type { SecretInput } from "../config/types.secrets.js";
import {
isMemoryMultimodalEnabled,
normalizeMemoryMultimodalSettings,
supportsMemoryMultimodalEmbeddings,
type MemoryMultimodalSettings,
} from "../memory-host-sdk/multimodal.js";
import { getMemoryEmbeddingProvider } from "../plugins/memory-embedding-provider-runtime.js";
@@ -389,24 +388,9 @@ export function resolveMemorySearchConfig(
const multimodalActive = isMemoryMultimodalEnabled(resolved.multimodal);
const multimodalProvider =
resolved.provider === "auto" ? undefined : getMemoryEmbeddingProvider(resolved.provider);
const builtinMultimodalSupport =
resolved.provider === "auto"
? false
: supportsMemoryMultimodalEmbeddings({
provider: resolved.provider,
model: resolved.model,
});
if (
multimodalActive &&
!(
// Fall back to the built-in helper when the provider is not registered yet
// or when a registered adapter does not implement multimodal capability checks.
(
multimodalProvider?.supportsMultimodalEmbeddings?.({
model: resolved.model,
}) ?? builtinMultimodalSupport
)
)
!(multimodalProvider?.supportsMultimodalEmbeddings?.({ model: resolved.model }) ?? false)
) {
throw new Error(
"agents.*.memorySearch.multimodal requires a provider adapter that supports multimodal embeddings for the configured model.",

View File

@@ -16,50 +16,56 @@ export type {
MemoryEmbeddingProviderRuntime,
} from "../plugins/memory-embedding-providers.js";
export { createLocalEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./host/embeddings.js";
export { extractBatchErrorMessage, formatUnavailableBatchError } from "./host/batch-error-utils.js";
export { postJsonWithRetry } from "./host/batch-http.js";
export { applyEmbeddingBatchOutputLine } from "./host/batch-output.js";
export {
createGeminiEmbeddingProvider,
DEFAULT_GEMINI_EMBEDDING_MODEL,
buildGeminiEmbeddingRequest,
} from "./host/embeddings-gemini.js";
EMBEDDING_BATCH_ENDPOINT,
type EmbeddingBatchStatus,
type ProviderBatchOutputLine,
} from "./host/batch-provider-common.js";
export {
createLmstudioEmbeddingProvider,
DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
} from "./host/embeddings-lmstudio.js";
export type { LmstudioEmbeddingClient } from "./host/embeddings-lmstudio.js";
buildEmbeddingBatchGroupOptions,
runEmbeddingBatchGroups,
type EmbeddingBatchExecutionParams,
} from "./host/batch-runner.js";
export {
createMistralEmbeddingProvider,
DEFAULT_MISTRAL_EMBEDDING_MODEL,
} from "./host/embeddings-mistral.js";
resolveBatchCompletionFromStatus,
resolveCompletedBatchResult,
throwIfBatchTerminalFailure,
type BatchCompletionResult,
} from "./host/batch-status.js";
export { uploadBatchJsonlFile } from "./host/batch-upload.js";
export {
createGitHubCopilotEmbeddingProvider,
type GitHubCopilotEmbeddingClient,
} from "./host/embeddings-github-copilot.js";
export {
createOllamaEmbeddingProvider,
DEFAULT_OLLAMA_EMBEDDING_MODEL,
} from "./host/embeddings-ollama.js";
export type { OllamaEmbeddingClient } from "./host/embeddings-ollama.js";
export {
createOpenAiEmbeddingProvider,
DEFAULT_OPENAI_EMBEDDING_MODEL,
} from "./host/embeddings-openai.js";
export {
createVoyageEmbeddingProvider,
DEFAULT_VOYAGE_EMBEDDING_MODEL,
} from "./host/embeddings-voyage.js";
export { runGeminiEmbeddingBatches, type GeminiBatchRequest } from "./host/batch-gemini.js";
export {
OPENAI_BATCH_ENDPOINT,
runOpenAiEmbeddingBatches,
type OpenAiBatchRequest,
} from "./host/batch-openai.js";
export { runVoyageEmbeddingBatches, type VoyageBatchRequest } from "./host/batch-voyage.js";
buildBatchHeaders,
normalizeBatchBaseUrl,
type BatchHttpClientConfig,
} from "./host/batch-utils.js";
export { enforceEmbeddingMaxInputTokens } from "./host/embedding-chunk-limits.js";
export {
isMissingEmbeddingApiKeyError,
mapBatchEmbeddingsByIndex,
sanitizeEmbeddingCacheHeaders,
} from "./host/embedding-provider-adapter-utils.js";
export { sanitizeAndNormalizeEmbedding } from "./host/embedding-vectors.js";
export { debugEmbeddingsLog } from "./host/embeddings-debug.js";
export { normalizeEmbeddingModelWithPrefixes } from "./host/embeddings-model-normalize.js";
export {
resolveRemoteEmbeddingBearerClient,
type RemoteEmbeddingProviderId,
} from "./host/embeddings-remote-client.js";
export {
createRemoteEmbeddingProvider,
resolveRemoteEmbeddingClient,
type RemoteEmbeddingClient,
} from "./host/embeddings-remote-provider.js";
export { fetchRemoteEmbeddingVectors } from "./host/embeddings-remote-fetch.js";
export {
estimateStructuredEmbeddingInputBytes,
estimateUtf8Bytes,
} from "./host/embedding-input-limits.js";
export { hasNonTextEmbeddingParts, type EmbeddingInput } from "./host/embedding-inputs.js";
export { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./host/remote-http.js";
export {
buildCaseInsensitiveExtensionGlob,
classifyMemoryMultimodalPath,

View File

@@ -1,116 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
vi.mock("./remote-http.js", () => ({
withRemoteHttpResponse: vi.fn(),
}));
function magnitude(values: number[]) {
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
}
describe("runGeminiEmbeddingBatches", () => {
let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches;
let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
beforeAll(async () => {
({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js"));
({ withRemoteHttpResponse } = await import("./remote-http.js"));
remoteHttpMock = vi.mocked(withRemoteHttpResponse);
});
beforeEach(() => {
vi.clearAllMocks();
});
afterEach(() => {
vi.resetAllMocks();
vi.unstubAllGlobals();
});
const mockClient: GeminiEmbeddingClient = {
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
headers: {},
model: "gemini-embedding-2-preview",
modelPath: "models/gemini-embedding-2-preview",
apiKeys: ["test-key"],
outputDimensionality: 1536,
};
it("includes outputDimensionality in batch upload requests", async () => {
remoteHttpMock.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart");
const body = params.init?.body;
if (!(body instanceof Blob)) {
throw new Error("expected multipart blob body");
}
const text = await body.text();
expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"');
expect(text).toContain('"outputDimensionality":1536');
return await params.onResponse(
new Response(JSON.stringify({ name: "files/file-123" }), {
status: 200,
headers: { "Content-Type": "application/json" },
}),
);
});
remoteHttpMock.mockImplementationOnce(async (params) => {
expect(params.url).toMatch(/:asyncBatchEmbedContent$/u);
return await params.onResponse(
new Response(
JSON.stringify({
name: "batches/batch-1",
state: "COMPLETED",
outputConfig: { file: "files/output-1" },
}),
{
status: 200,
headers: { "Content-Type": "application/json" },
},
),
);
});
remoteHttpMock.mockImplementationOnce(async (params) => {
expect(params.url).toMatch(/\/files\/output-1:download$/u);
return await params.onResponse(
new Response(
JSON.stringify({
key: "req-1",
response: { embedding: { values: [3, 4] } },
}),
{
status: 200,
headers: { "Content-Type": "application/jsonl" },
},
),
);
});
const results = await runGeminiEmbeddingBatches({
gemini: mockClient,
agentId: "main",
requests: [
{
custom_id: "req-1",
request: {
content: { parts: [{ text: "hello world" }] },
taskType: "RETRIEVAL_DOCUMENT",
outputDimensionality: 1536,
},
},
],
wait: true,
pollIntervalMs: 1,
timeoutMs: 1000,
concurrency: 1,
});
const embedding = results.get("req-1");
expect(embedding).toBeDefined();
expect(embedding?.[0]).toBeCloseTo(0.6, 5);
expect(embedding?.[1]).toBeCloseTo(0.8, 5);
expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5);
expect(remoteHttpMock).toHaveBeenCalledTimes(3);
});
});

View File

@@ -1,368 +0,0 @@
import {
buildEmbeddingBatchGroupOptions,
runEmbeddingBatchGroups,
type EmbeddingBatchExecutionParams,
} from "./batch-runner.js";
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js";
import { hashText } from "./internal.js";
import { withRemoteHttpResponse } from "./remote-http.js";
export type GeminiBatchRequest = {
custom_id: string;
request: GeminiTextEmbeddingRequest;
};
export type GeminiBatchStatus = {
name?: string;
state?: string;
outputConfig?: { file?: string; fileId?: string };
metadata?: {
output?: {
responsesFile?: string;
};
};
error?: { message?: string };
};
export type GeminiBatchOutputLine = {
key?: string;
custom_id?: string;
request_id?: string;
embedding?: { values?: number[] };
response?: {
embedding?: { values?: number[] };
error?: { message?: string };
};
error?: { message?: string };
};
const GEMINI_BATCH_MAX_REQUESTS = 50000;
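// Derive the Files API media-upload base: ".../v1beta" becomes ".../upload/v1beta",
// any other base simply gains an "/upload" suffix.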
function getGeminiUploadUrl(baseUrl: string): string {
if (baseUrl.includes("/v1beta")) {
return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");
}
return `${baseUrl.replace(/\/$/, "")}/upload`;
}
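// Assemble a multipart/related body: a JSON metadata part describing the file, followed
// by the JSONL payload, joined with a deterministic boundary derived from the display name.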
function buildGeminiUploadBody(params: { jsonl: string; displayName: string }): {
body: Blob;
contentType: string;
} {
const boundary = `openclaw-${hashText(params.displayName)}`;
const jsonPart = JSON.stringify({
file: {
displayName: params.displayName,
mimeType: "application/jsonl",
},
});
const delimiter = `--${boundary}\r\n`;
const closeDelimiter = `--${boundary}--\r\n`;
const parts = [
`${delimiter}Content-Type: application/json; charset=UTF-8\r\n\r\n${jsonPart}\r\n`,
`${delimiter}Content-Type: application/jsonl; charset=UTF-8\r\n\r\n${params.jsonl}\r\n`,
closeDelimiter,
];
const body = new Blob([parts.join("")], { type: "multipart/related" });
return {
body,
contentType: `multipart/related; boundary=${boundary}`,
};
}
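// Submit one batch: upload the JSONL requests through the Files API, then create an
// asyncBatchEmbedContent job that references the uploaded file id.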
async function submitGeminiBatch(params: {
gemini: GeminiEmbeddingClient;
requests: GeminiBatchRequest[];
agentId: string;
}): Promise<GeminiBatchStatus> {
const baseUrl = normalizeBatchBaseUrl(params.gemini);
const jsonl = params.requests
.map((request) =>
JSON.stringify({
key: request.custom_id,
request: request.request,
}),
)
.join("\n");
const displayName = `memory-embeddings-${hashText(String(Date.now()))}`;
const uploadPayload = buildGeminiUploadBody({ jsonl, displayName });
const uploadUrl = `${getGeminiUploadUrl(baseUrl)}/files?uploadType=multipart`;
debugEmbeddingsLog("memory embeddings: gemini batch upload", {
uploadUrl,
baseUrl,
requests: params.requests.length,
});
const filePayload = await withRemoteHttpResponse({
url: uploadUrl,
ssrfPolicy: params.gemini.ssrfPolicy,
init: {
method: "POST",
headers: {
...buildBatchHeaders(params.gemini, { json: false }),
"Content-Type": uploadPayload.contentType,
},
body: uploadPayload.body,
},
onResponse: async (fileRes) => {
if (!fileRes.ok) {
const text = await fileRes.text();
throw new Error(`gemini batch file upload failed: ${fileRes.status} ${text}`);
}
return (await fileRes.json()) as { name?: string; file?: { name?: string } };
},
});
const fileId = filePayload.name ?? filePayload.file?.name;
if (!fileId) {
throw new Error("gemini batch file upload failed: missing file id");
}
const batchBody = {
batch: {
displayName: `memory-embeddings-${params.agentId}`,
inputConfig: {
file_name: fileId,
},
},
};
const batchEndpoint = `${baseUrl}/${params.gemini.modelPath}:asyncBatchEmbedContent`;
debugEmbeddingsLog("memory embeddings: gemini batch create", {
batchEndpoint,
fileId,
});
return await withRemoteHttpResponse({
url: batchEndpoint,
ssrfPolicy: params.gemini.ssrfPolicy,
init: {
method: "POST",
headers: buildBatchHeaders(params.gemini, { json: true }),
body: JSON.stringify(batchBody),
},
onResponse: async (batchRes) => {
if (batchRes.ok) {
return (await batchRes.json()) as GeminiBatchStatus;
}
const text = await batchRes.text();
if (batchRes.status === 404) {
throw new Error(
"gemini batch create failed: 404 (asyncBatchEmbedContent not available for this model/baseUrl). Disable remote.batch.enabled or switch providers.",
);
}
throw new Error(`gemini batch create failed: ${batchRes.status} ${text}`);
},
});
}
async function fetchGeminiBatchStatus(params: {
gemini: GeminiEmbeddingClient;
batchName: string;
}): Promise<GeminiBatchStatus> {
const baseUrl = normalizeBatchBaseUrl(params.gemini);
const name = params.batchName.startsWith("batches/")
? params.batchName
: `batches/${params.batchName}`;
const statusUrl = `${baseUrl}/${name}`;
debugEmbeddingsLog("memory embeddings: gemini batch status", { statusUrl });
return await withRemoteHttpResponse({
url: statusUrl,
ssrfPolicy: params.gemini.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.gemini, { json: true }),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini batch status failed: ${res.status} ${text}`);
}
return (await res.json()) as GeminiBatchStatus;
},
});
}
async function fetchGeminiFileContent(params: {
gemini: GeminiEmbeddingClient;
fileId: string;
}): Promise<string> {
const baseUrl = normalizeBatchBaseUrl(params.gemini);
const file = params.fileId.startsWith("files/") ? params.fileId : `files/${params.fileId}`;
const downloadUrl = `${baseUrl}/${file}:download`;
debugEmbeddingsLog("memory embeddings: gemini batch download", { downloadUrl });
return await withRemoteHttpResponse({
url: downloadUrl,
ssrfPolicy: params.gemini.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.gemini, { json: true }),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini batch file content failed: ${res.status} ${text}`);
}
return await res.text();
},
});
}
function parseGeminiBatchOutput(text: string): GeminiBatchOutputLine[] {
if (!text.trim()) {
return [];
}
return text
.split("\n")
.map((line) => line.trim())
.filter(Boolean)
.map((line) => JSON.parse(line) as GeminiBatchOutputLine);
}
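// Poll the batch until it reaches a terminal state and return the output file id on
// success; failures, disabled waiting, and timeouts all surface as thrown errors.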
async function waitForGeminiBatch(params: {
gemini: GeminiEmbeddingClient;
batchName: string;
wait: boolean;
pollIntervalMs: number;
timeoutMs: number;
debug?: (message: string, data?: Record<string, unknown>) => void;
initial?: GeminiBatchStatus;
}): Promise<{ outputFileId: string }> {
const start = Date.now();
let current: GeminiBatchStatus | undefined = params.initial;
while (true) {
const status =
current ??
(await fetchGeminiBatchStatus({
gemini: params.gemini,
batchName: params.batchName,
}));
const state = status.state ?? "UNKNOWN";
if (["SUCCEEDED", "COMPLETED", "DONE"].includes(state)) {
const outputFileId =
status.outputConfig?.file ??
status.outputConfig?.fileId ??
status.metadata?.output?.responsesFile;
if (!outputFileId) {
throw new Error(`gemini batch ${params.batchName} completed without output file`);
}
return { outputFileId };
}
if (["FAILED", "CANCELLED", "CANCELED", "EXPIRED"].includes(state)) {
const message = status.error?.message ?? "unknown error";
throw new Error(`gemini batch ${params.batchName} ${state}: ${message}`);
}
if (!params.wait) {
throw new Error(`gemini batch ${params.batchName} still ${state}; wait disabled`);
}
if (Date.now() - start > params.timeoutMs) {
throw new Error(`gemini batch ${params.batchName} timed out after ${params.timeoutMs}ms`);
}
params.debug?.(`gemini batch ${params.batchName} ${state}; waiting ${params.pollIntervalMs}ms`);
await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
current = undefined;
}
}
export async function runGeminiEmbeddingBatches(
params: {
gemini: GeminiEmbeddingClient;
agentId: string;
requests: GeminiBatchRequest[];
} & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
return await runEmbeddingBatchGroups({
...buildEmbeddingBatchGroupOptions(params, {
maxRequests: GEMINI_BATCH_MAX_REQUESTS,
debugLabel: "memory embeddings: gemini batch submit",
}),
runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
const batchInfo = await submitGeminiBatch({
gemini: params.gemini,
requests: group,
agentId: params.agentId,
});
const batchName = batchInfo.name ?? "";
if (!batchName) {
throw new Error("gemini batch create failed: missing batch name");
}
params.debug?.("memory embeddings: gemini batch created", {
batchName,
state: batchInfo.state,
group: groupIndex + 1,
groups,
requests: group.length,
});
if (
!params.wait &&
batchInfo.state &&
!["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state)
) {
throw new Error(
`gemini batch ${batchName} submitted; enable remote.batch.wait to await completion`,
);
}
const completed =
batchInfo.state && ["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state)
? {
outputFileId:
batchInfo.outputConfig?.file ??
batchInfo.outputConfig?.fileId ??
batchInfo.metadata?.output?.responsesFile ??
"",
}
: await waitForGeminiBatch({
gemini: params.gemini,
batchName,
wait: params.wait,
pollIntervalMs: params.pollIntervalMs,
timeoutMs: params.timeoutMs,
debug: params.debug,
initial: batchInfo,
});
if (!completed.outputFileId) {
throw new Error(`gemini batch ${batchName} completed without output file`);
}
const content = await fetchGeminiFileContent({
gemini: params.gemini,
fileId: completed.outputFileId,
});
const outputLines = parseGeminiBatchOutput(content);
const errors: string[] = [];
const remaining = new Set(group.map((request) => request.custom_id));
for (const line of outputLines) {
const customId = line.key ?? line.custom_id ?? line.request_id;
if (!customId) {
continue;
}
remaining.delete(customId);
if (line.error?.message) {
errors.push(`${customId}: ${line.error.message}`);
continue;
}
if (line.response?.error?.message) {
errors.push(`${customId}: ${line.response.error.message}`);
continue;
}
const embedding = sanitizeAndNormalizeEmbedding(
line.embedding?.values ?? line.response?.embedding?.values ?? [],
);
if (embedding.length === 0) {
errors.push(`${customId}: empty embedding`);
continue;
}
byCustomId.set(customId, embedding);
}
if (errors.length > 0) {
throw new Error(`gemini batch ${batchName} failed: ${errors.join("; ")}`);
}
if (remaining.size > 0) {
throw new Error(`gemini batch ${batchName} missing ${remaining.size} embedding responses`);
}
},
});
}

View File

@@ -1,108 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
const mocks = vi.hoisted(() => ({
uploadBatchJsonlFile: vi.fn(async () => "file_in"),
postJsonWithRetry: vi.fn(async () => ({ id: "batch_1", status: "in_progress" })),
resolveCompletedBatchResult: vi.fn(async () => ({ outputFileId: "file_out" })),
withRemoteHttpResponse: vi.fn(
async (params: { url: string; onResponse: (res: Response) => Promise<unknown> }) => {
if (params.url.endsWith("/files/file_out/content")) {
const content = [
JSON.stringify({
custom_id: "0",
response: {
status_code: 200,
body: { data: [{ embedding: [1, 0, 0], index: 0 }] },
},
}),
JSON.stringify({
custom_id: "1",
response: {
status_code: 200,
body: { data: [{ embedding: [2, 0, 0], index: 0 }] },
},
}),
].join("\n");
return await params.onResponse({
ok: true,
status: 200,
text: async () => content,
} as Response);
}
return await params.onResponse({
ok: true,
status: 200,
json: async () => ({ id: "batch_1", status: "completed", output_file_id: "file_out" }),
} as Response);
},
),
}));
vi.mock("./batch-upload.js", () => ({
uploadBatchJsonlFile: mocks.uploadBatchJsonlFile,
}));
vi.mock("./batch-http.js", () => ({
postJsonWithRetry: mocks.postJsonWithRetry,
}));
vi.mock("./batch-status.js", () => ({
resolveBatchCompletionFromStatus: vi.fn(),
resolveCompletedBatchResult: mocks.resolveCompletedBatchResult,
throwIfBatchTerminalFailure: vi.fn(),
}));
vi.mock("./remote-http.js", () => ({
withRemoteHttpResponse: mocks.withRemoteHttpResponse,
}));
describe("runOpenAiEmbeddingBatches", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("maps uploaded batch output rows back to embeddings", async () => {
const { runOpenAiEmbeddingBatches, OPENAI_BATCH_ENDPOINT } = await import("./batch-openai.js");
const result = await runOpenAiEmbeddingBatches({
openAi: {
baseUrl: "https://api.openai.com/v1",
headers: { Authorization: "Bearer test" },
fetchImpl: fetch,
model: "text-embedding-3-small",
},
agentId: "main",
requests: [
{
custom_id: "0",
method: "POST",
url: OPENAI_BATCH_ENDPOINT,
body: { model: "text-embedding-3-small", input: "hello" },
},
{
custom_id: "1",
method: "POST",
url: OPENAI_BATCH_ENDPOINT,
body: { model: "text-embedding-3-small", input: "world" },
},
],
wait: true,
pollIntervalMs: 1,
timeoutMs: 1000,
concurrency: 3,
});
expect(mocks.uploadBatchJsonlFile).toHaveBeenCalled();
expect(mocks.postJsonWithRetry).toHaveBeenCalledWith(
expect.objectContaining({
errorPrefix: "openai batch create failed",
body: expect.objectContaining({
endpoint: OPENAI_BATCH_ENDPOINT,
metadata: { source: "openclaw-memory", agent: "main" },
}),
}),
);
expect(result.get("0")).toEqual([1, 0, 0]);
expect(result.get("1")).toEqual([2, 0, 0]);
});
});

View File

@@ -1,176 +0,0 @@
import { ReadableStream } from "node:stream/web";
import { setTimeout as nativeSleep } from "node:timers/promises";
import { describe, expect, it, vi } from "vitest";
import {
runVoyageEmbeddingBatches,
type VoyageBatchOutputLine,
type VoyageBatchRequest,
} from "./batch-voyage.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
const realNow = Date.now.bind(Date);
describe("runVoyageEmbeddingBatches", () => {
const mockClient: VoyageEmbeddingClient = {
baseUrl: "https://api.voyageai.com/v1",
headers: { Authorization: "Bearer test-key" },
model: "voyage-4-large",
};
const mockRequests: VoyageBatchRequest[] = [
{ custom_id: "req-1", body: { input: "text1" } },
{ custom_id: "req-2", body: { input: "text2" } },
];
it("successfully submits batch, waits, and streams results", async () => {
const outputLines: VoyageBatchOutputLine[] = [
{
custom_id: "req-1",
response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
},
{
custom_id: "req-2",
response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
},
];
const withRemoteHttpResponse = vi.fn();
const postJsonWithRetry = vi.fn();
const uploadBatchJsonlFile = vi.fn();
// Create a stream that emits the NDJSON lines
const stream = new ReadableStream({
start(controller) {
const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
controller.enqueue(new TextEncoder().encode(text));
controller.close();
},
});
uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
expect(params.errorPrefix).toBe("voyage batch file upload failed");
expect(params.requests).toEqual(mockRequests);
return "file-123";
});
postJsonWithRetry.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/batches");
expect(params.body).toMatchObject({
input_file_id: "file-123",
completion_window: "12h",
request_params: {
model: "voyage-4-large",
input_type: "document",
},
});
return {
id: "batch-abc",
status: "pending",
};
});
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/batches/batch-abc");
return await params.onResponse(
new Response(
JSON.stringify({
id: "batch-abc",
status: "completed",
output_file_id: "file-out-999",
}),
{
status: 200,
headers: { "Content-Type": "application/json" },
},
),
);
});
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/files/file-out-999/content");
return await params.onResponse(
new Response(stream as unknown as BodyInit, {
status: 200,
headers: { "Content-Type": "application/x-ndjson" },
}),
);
});
const results = await runVoyageEmbeddingBatches({
client: mockClient,
agentId: "agent-1",
requests: mockRequests,
wait: true,
pollIntervalMs: 1, // fast poll
timeoutMs: 1000,
concurrency: 1,
deps: {
now: realNow,
sleep: async (ms) => {
await nativeSleep(ms);
},
postJsonWithRetry,
uploadBatchJsonlFile,
withRemoteHttpResponse,
},
});
expect(results.size).toBe(2);
expect(results.get("req-1")).toEqual([0.1, 0.1]);
expect(results.get("req-2")).toEqual([0.2, 0.2]);
expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
});
it("handles empty lines and stream chunks correctly", async () => {
const withRemoteHttpResponse = vi.fn();
const postJsonWithRetry = vi.fn();
const uploadBatchJsonlFile = vi.fn();
const stream = new ReadableStream({
start(controller) {
const line1 = JSON.stringify({
custom_id: "req-1",
response: { body: { data: [{ embedding: [1] }] } },
});
const line2 = JSON.stringify({
custom_id: "req-2",
response: { body: { data: [{ embedding: [2] }] } },
});
// Split across chunks
controller.enqueue(new TextEncoder().encode(line1 + "\n"));
controller.enqueue(new TextEncoder().encode("\n")); // empty line
controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
controller.close();
},
});
uploadBatchJsonlFile.mockResolvedValueOnce("f1");
postJsonWithRetry.mockResolvedValueOnce({
id: "b1",
status: "completed",
output_file_id: "out1",
});
withRemoteHttpResponse.mockImplementationOnce(async (params) => {
expect(params.url).toContain("/files/out1/content");
return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
});
const results = await runVoyageEmbeddingBatches({
client: mockClient,
agentId: "a1",
requests: mockRequests,
wait: true,
pollIntervalMs: 1,
timeoutMs: 1000,
concurrency: 1,
deps: {
now: realNow,
sleep: async (ms) => {
await nativeSleep(ms);
},
postJsonWithRetry,
uploadBatchJsonlFile,
withRemoteHttpResponse,
},
});
expect(results.get("req-1")).toEqual([1]);
expect(results.get("req-2")).toEqual([2]);
});
});

View File

@@ -1,315 +0,0 @@
import { createInterface } from "node:readline";
import { Readable } from "node:stream";
import {
applyEmbeddingBatchOutputLine,
buildBatchHeaders,
buildEmbeddingBatchGroupOptions,
EMBEDDING_BATCH_ENDPOINT,
extractBatchErrorMessage,
formatUnavailableBatchError,
normalizeBatchBaseUrl,
postJsonWithRetry,
resolveBatchCompletionFromStatus,
resolveCompletedBatchResult,
runEmbeddingBatchGroups,
throwIfBatchTerminalFailure,
type EmbeddingBatchExecutionParams,
type EmbeddingBatchStatus,
type BatchCompletionResult,
type ProviderBatchOutputLine,
uploadBatchJsonlFile,
withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
/**
* Voyage Batch API Input Line format.
* See: https://docs.voyageai.com/docs/batch-inference
*/
export type VoyageBatchRequest = {
custom_id: string;
body: {
input: string | string[];
};
};
export type VoyageBatchStatus = EmbeddingBatchStatus;
export type VoyageBatchOutputLine = ProviderBatchOutputLine;
export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
const VOYAGE_BATCH_COMPLETION_WINDOW = "12h";
const VOYAGE_BATCH_MAX_REQUESTS = 50000;
type VoyageBatchDeps = {
now: () => number;
sleep: (ms: number) => Promise<void>;
postJsonWithRetry: typeof postJsonWithRetry;
uploadBatchJsonlFile: typeof uploadBatchJsonlFile;
withRemoteHttpResponse: typeof withRemoteHttpResponse;
};
function resolveVoyageBatchDeps(overrides: Partial<VoyageBatchDeps> | undefined): VoyageBatchDeps {
return {
now: overrides?.now ?? Date.now,
sleep:
overrides?.sleep ??
(async (ms: number) => await new Promise((resolve) => setTimeout(resolve, ms))),
postJsonWithRetry: overrides?.postJsonWithRetry ?? postJsonWithRetry,
uploadBatchJsonlFile: overrides?.uploadBatchJsonlFile ?? uploadBatchJsonlFile,
withRemoteHttpResponse: overrides?.withRemoteHttpResponse ?? withRemoteHttpResponse,
};
}
async function assertVoyageResponseOk(res: Response, context: string): Promise<void> {
if (!res.ok) {
const text = await res.text();
throw new Error(`${context}: ${res.status} ${text}`);
}
}
function buildVoyageBatchRequest<T>(params: {
client: VoyageEmbeddingClient;
path: string;
onResponse: (res: Response) => Promise<T>;
}) {
const baseUrl = normalizeBatchBaseUrl(params.client);
return {
url: `${baseUrl}/${params.path}`,
ssrfPolicy: params.client.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.client, { json: true }),
},
onResponse: params.onResponse,
};
}
async function submitVoyageBatch(params: {
client: VoyageEmbeddingClient;
requests: VoyageBatchRequest[];
agentId: string;
deps: VoyageBatchDeps;
}): Promise<VoyageBatchStatus> {
const baseUrl = normalizeBatchBaseUrl(params.client);
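// 1. Upload the JSONL input file containing this group's embedding requests.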
const inputFileId = await params.deps.uploadBatchJsonlFile({
client: params.client,
requests: params.requests,
errorPrefix: "voyage batch file upload failed",
});
// 2. Create the batch job via the Voyage Batches API.
return await params.deps.postJsonWithRetry<VoyageBatchStatus>({
url: `${baseUrl}/batches`,
headers: buildBatchHeaders(params.client, { json: true }),
ssrfPolicy: params.client.ssrfPolicy,
body: {
input_file_id: inputFileId,
endpoint: VOYAGE_BATCH_ENDPOINT,
completion_window: VOYAGE_BATCH_COMPLETION_WINDOW,
request_params: {
model: params.client.model,
input_type: "document",
},
metadata: {
source: "clawdbot-memory",
agent: params.agentId,
},
},
errorPrefix: "voyage batch create failed",
});
}
async function fetchVoyageBatchStatus(params: {
client: VoyageEmbeddingClient;
batchId: string;
deps: VoyageBatchDeps;
}): Promise<VoyageBatchStatus> {
return await params.deps.withRemoteHttpResponse(
buildVoyageBatchRequest({
client: params.client,
path: `batches/${params.batchId}`,
onResponse: async (res) => {
await assertVoyageResponseOk(res, "voyage batch status failed");
return (await res.json()) as VoyageBatchStatus;
},
}),
);
}
async function readVoyageBatchError(params: {
client: VoyageEmbeddingClient;
errorFileId: string;
deps: VoyageBatchDeps;
}): Promise<string | undefined> {
try {
return await params.deps.withRemoteHttpResponse(
buildVoyageBatchRequest({
client: params.client,
path: `files/${params.errorFileId}/content`,
onResponse: async (res) => {
await assertVoyageResponseOk(res, "voyage batch error file content failed");
const text = await res.text();
if (!text.trim()) {
return undefined;
}
const lines = text
.split("\n")
.map((line) => line.trim())
.filter(Boolean)
.map((line) => JSON.parse(line) as VoyageBatchOutputLine);
return extractBatchErrorMessage(lines);
},
}),
);
} catch (err) {
return formatUnavailableBatchError(err);
}
}
async function waitForVoyageBatch(params: {
client: VoyageEmbeddingClient;
batchId: string;
wait: boolean;
pollIntervalMs: number;
timeoutMs: number;
debug?: (message: string, data?: Record<string, unknown>) => void;
initial?: VoyageBatchStatus;
deps: VoyageBatchDeps;
}): Promise<BatchCompletionResult> {
const start = params.deps.now();
let current: VoyageBatchStatus | undefined = params.initial;
while (true) {
const status =
current ??
(await fetchVoyageBatchStatus({
client: params.client,
batchId: params.batchId,
deps: params.deps,
}));
const state = status.status ?? "unknown";
if (state === "completed") {
return resolveBatchCompletionFromStatus({
provider: "voyage",
batchId: params.batchId,
status,
});
}
await throwIfBatchTerminalFailure({
provider: "voyage",
status: { ...status, id: params.batchId },
readError: async (errorFileId) =>
await readVoyageBatchError({
client: params.client,
errorFileId,
deps: params.deps,
}),
});
if (!params.wait) {
throw new Error(`voyage batch ${params.batchId} still ${state}; wait disabled`);
}
if (params.deps.now() - start > params.timeoutMs) {
throw new Error(`voyage batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
}
params.debug?.(`voyage batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
await params.deps.sleep(params.pollIntervalMs);
current = undefined;
}
}
export async function runVoyageEmbeddingBatches(
params: {
client: VoyageEmbeddingClient;
agentId: string;
requests: VoyageBatchRequest[];
deps?: Partial<VoyageBatchDeps>;
} & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
const deps = resolveVoyageBatchDeps(params.deps);
return await runEmbeddingBatchGroups({
...buildEmbeddingBatchGroupOptions(params, {
maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
debugLabel: "memory embeddings: voyage batch submit",
}),
runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
const batchInfo = await submitVoyageBatch({
client: params.client,
requests: group,
agentId: params.agentId,
deps,
});
if (!batchInfo.id) {
throw new Error("voyage batch create failed: missing batch id");
}
const batchId = batchInfo.id;
params.debug?.("memory embeddings: voyage batch created", {
batchId: batchInfo.id,
status: batchInfo.status,
group: groupIndex + 1,
groups,
requests: group.length,
});
const completed = await resolveCompletedBatchResult({
provider: "voyage",
status: batchInfo,
wait: params.wait,
waitForBatch: async () =>
await waitForVoyageBatch({
client: params.client,
batchId,
wait: params.wait,
pollIntervalMs: params.pollIntervalMs,
timeoutMs: params.timeoutMs,
debug: params.debug,
initial: batchInfo,
deps,
}),
});
const baseUrl = normalizeBatchBaseUrl(params.client);
const errors: string[] = [];
const remaining = new Set(group.map((request) => request.custom_id));
await deps.withRemoteHttpResponse({
url: `${baseUrl}/files/${completed.outputFileId}/content`,
ssrfPolicy: params.client.ssrfPolicy,
init: {
headers: buildBatchHeaders(params.client, { json: true }),
},
onResponse: async (contentRes) => {
if (!contentRes.ok) {
const text = await contentRes.text();
throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`);
}
if (!contentRes.body) {
return;
}
const reader = createInterface({
input: Readable.fromWeb(
contentRes.body as unknown as import("stream/web").ReadableStream,
),
terminal: false,
});
for await (const rawLine of reader) {
if (!rawLine.trim()) {
continue;
}
const line = JSON.parse(rawLine) as VoyageBatchOutputLine;
applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
}
},
});
if (errors.length > 0) {
throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`);
}
if (remaining.size > 0) {
throw new Error(
`voyage batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
);
}
},
});
}
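
A minimal driver sketch for runVoyageEmbeddingBatches above, assuming the client shape used by the accompanying test; the API key, model, and timings are placeholders, and real callers obtain the client through the provider's own resolution helpers rather than building it by hand:

import { runVoyageEmbeddingBatches, type VoyageBatchRequest } from "./batch-voyage.js";

const client = {
  baseUrl: "https://api.voyageai.com/v1",
  headers: { Authorization: "Bearer <voyage-api-key>" }, // placeholder credential
  model: "voyage-4-large",
};

const requests: VoyageBatchRequest[] = [
  { custom_id: "0", body: { input: "first chunk" } },
  { custom_id: "1", body: { input: "second chunk" } },
];

// Resolves to a Map keyed by custom_id once every group completes;
// throws on terminal batch failure, timeout, or missing responses.
const byCustomId = await runVoyageEmbeddingBatches({
  client,
  agentId: "main",
  requests,
  wait: true,
  pollIntervalMs: 5_000,
  timeoutMs: 60 * 60 * 1000,
  concurrency: 2,
});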

View File

@@ -1,40 +1,14 @@
import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js";
import type { EmbeddingProvider } from "./embeddings.js";
const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192;
const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048;
const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record<string, number> = {
"openai:text-embedding-3-small": 8192,
"openai:text-embedding-3-large": 8192,
"openai:text-embedding-ada-002": 8191,
"gemini:text-embedding-004": 2048,
"gemini:gemini-embedding-001": 2048,
"gemini:gemini-embedding-2-preview": 8192,
"voyage:voyage-3": 32000,
"voyage:voyage-3-lite": 16000,
"voyage:voyage-code-3": 32000,
};
export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number {
if (typeof provider.maxInputTokens === "number") {
return provider.maxInputTokens;
}
// Provider/model mapping is best-effort; different providers use different
// limits and we prefer to be conservative when we don't know.
const key = normalizeLowercaseStringOrEmpty(`${provider.id}:${provider.model}`);
const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key];
if (typeof known === "number") {
return known;
}
// Provider-specific conservative fallbacks. This prevents us from accidentally
// using the OpenAI default for providers with much smaller limits.
if (normalizeLowercaseStringOrEmpty(provider.id) === "gemini") {
return 2048;
}
if (normalizeLowercaseStringOrEmpty(provider.id) === "local") {
if (provider.id === "local") {
return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS;
}
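
The pre-refactor resolution above follows a simple precedence: an explicit maxInputTokens on the provider wins, then the known provider:model table, then a conservative per-provider fallback. A short sketch of how the removed version resolves a few stub providers (imports and the full EmbeddingProvider type are elided because the file path is not shown in this view, hence the casts):

const openAi = { id: "openai", model: "text-embedding-3-small" };
const gemini = { id: "gemini", model: "unknown-future-model" };
const local = { id: "local", model: "minilm", maxInputTokens: 512 };

resolveEmbeddingMaxInputTokens(openAi as never); // 8192 (hit in the known provider:model table)
resolveEmbeddingMaxInputTokens(gemini as never); // 2048 (conservative Gemini fallback)
resolveEmbeddingMaxInputTokens(local as never); // 512 (explicit maxInputTokens takes precedence)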

View File

@@ -0,0 +1,29 @@
import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js";
export function isMissingEmbeddingApiKeyError(err: unknown): boolean {
return err instanceof Error && err.message.includes("No API key found for provider");
}
export function sanitizeEmbeddingCacheHeaders(
headers: Record<string, string>,
excludedHeaderNames: string[],
): Array<[string, string]> {
const excluded = new Set(
excludedHeaderNames.map((name) => normalizeLowercaseStringOrEmpty(name)),
);
return Object.entries(headers)
.filter(([key]) => !excluded.has(normalizeLowercaseStringOrEmpty(key)))
.toSorted(([a], [b]) => a.localeCompare(b))
.map(([key, value]) => [key, value]);
}
export function mapBatchEmbeddingsByIndex(
byCustomId: Map<string, number[]>,
count: number,
): number[][] {
const embeddings: number[][] = [];
for (let index = 0; index < count; index += 1) {
embeddings.push(byCustomId.get(String(index)) ?? []);
}
return embeddings;
}
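
The three helpers above are additions in this commit; a short usage sketch with illustrative inputs (imports elided because the new module's path is not shown in this view):

// Batch runners return vectors keyed by custom_id; direct callers number requests "0".."n-1".
const byCustomId = new Map<string, number[]>([
  ["0", [0.1, 0.2]],
  ["2", [0.5, 0.6]],
]);
mapBatchEmbeddingsByIndex(byCustomId, 3);
// => [[0.1, 0.2], [], [0.5, 0.6]] (index 1 had no response, so the caller sees an empty vector)

// Cache keys should ignore volatile headers; the result is exclusion-filtered, sorted, and stable.
sanitizeEmbeddingCacheHeaders(
  { Authorization: "Bearer secret", "X-Provider": "voyage" },
  ["authorization"],
);
// => [["X-Provider", "voyage"]]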

View File

@@ -1,377 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
const { defaultProviderMock, resolveCredentialsMock, sendMock } = vi.hoisted(() => ({
defaultProviderMock: vi.fn(),
resolveCredentialsMock: vi.fn(),
sendMock: vi.fn(),
}));
vi.mock("@aws-sdk/client-bedrock-runtime", () => {
class MockClient {
region: string;
constructor(config: { region: string }) {
this.region = config.region;
}
send = sendMock;
}
class MockCommand {
input: unknown;
constructor(input: unknown) {
this.input = input;
}
}
return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand };
});
vi.mock("@aws-sdk/credential-provider-node", () => ({
defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
}));
let createBedrockEmbeddingProvider: typeof import("./embeddings-bedrock.js").createBedrockEmbeddingProvider;
let resolveBedrockEmbeddingClient: typeof import("./embeddings-bedrock.js").resolveBedrockEmbeddingClient;
let normalizeBedrockEmbeddingModel: typeof import("./embeddings-bedrock.js").normalizeBedrockEmbeddingModel;
let hasAwsCredentials: typeof import("./embeddings-bedrock.js").hasAwsCredentials;
beforeAll(async () => {
({
createBedrockEmbeddingProvider,
resolveBedrockEmbeddingClient,
normalizeBedrockEmbeddingModel,
hasAwsCredentials,
} = await import("./embeddings-bedrock.js"));
});
beforeEach(() => {
defaultProviderMock.mockImplementation(() => resolveCredentialsMock);
});
const enc = (body: unknown) => ({ body: new TextEncoder().encode(JSON.stringify(body)) });
const reqBody = (i = 0): Record<string, unknown> =>
JSON.parse(sendMock.mock.calls[i][0].input.body);
describe("bedrock embedding provider", () => {
const originalEnv = process.env;
afterEach(() => {
process.env = originalEnv;
vi.restoreAllMocks();
defaultProviderMock.mockClear();
resolveCredentialsMock.mockReset();
sendMock.mockReset();
});
// --- Normalization ---
it("normalizes model names with prefixes", () => {
expect(normalizeBedrockEmbeddingModel("bedrock/amazon.titan-embed-text-v2:0")).toBe(
"amazon.titan-embed-text-v2:0",
);
expect(normalizeBedrockEmbeddingModel("amazon-bedrock/cohere.embed-english-v3")).toBe(
"cohere.embed-english-v3",
);
expect(normalizeBedrockEmbeddingModel("")).toBe("amazon.titan-embed-text-v2:0");
});
// --- Client resolution ---
it("resolves region from env", () => {
process.env = { ...originalEnv, AWS_REGION: "eu-west-1" };
const c = resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(c.region).toBe("eu-west-1");
expect(c.dimensions).toBe(1024);
});
it("defaults to us-east-1", () => {
process.env = { ...originalEnv };
delete process.env.AWS_REGION;
delete process.env.AWS_DEFAULT_REGION;
expect(
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
}).region,
).toBe("us-east-1");
});
it("extracts region from baseUrl", () => {
process.env = { ...originalEnv };
delete process.env.AWS_REGION;
const c = resolveBedrockEmbeddingClient({
config: {
models: {
providers: {
"amazon-bedrock": { baseUrl: "https://bedrock-runtime.ap-southeast-2.amazonaws.com" },
},
},
} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(c.region).toBe("ap-southeast-2");
});
it("validates dimensions", () => {
expect(() =>
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
outputDimensionality: 768,
}),
).toThrow("Invalid dimensions 768");
});
it("accepts valid dimensions", () => {
expect(
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
outputDimensionality: 256,
}).dimensions,
).toBe(256);
});
it("resolves throughput-suffixed variants", () => {
expect(
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v1:2:8k",
fallback: "none",
}).dimensions,
).toBe(1536);
});
// --- Credential detection ---
it("detects access keys", async () => {
await expect(
hasAwsCredentials({
AWS_ACCESS_KEY_ID: "A",
AWS_SECRET_ACCESS_KEY: "s",
} as NodeJS.ProcessEnv),
).resolves.toBe(true);
});
it("detects profile", async () => {
await expect(hasAwsCredentials({ AWS_PROFILE: "default" } as NodeJS.ProcessEnv)).resolves.toBe(
true,
);
});
it("detects ECS task role", async () => {
await expect(
hasAwsCredentials({ AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: "/v2" } as NodeJS.ProcessEnv),
).resolves.toBe(true);
});
it("detects EKS IRSA", async () => {
await expect(
hasAwsCredentials({
AWS_WEB_IDENTITY_TOKEN_FILE: "/var/run/secrets/token",
AWS_ROLE_ARN: "arn:aws:iam::123:role/x",
} as NodeJS.ProcessEnv),
).resolves.toBe(true);
});
it("detects credentials via the AWS SDK default provider chain", async () => {
resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(true);
expect(defaultProviderMock).toHaveBeenCalledWith({ timeout: 1000, maxRetries: 0 });
});
it("returns false with no creds", async () => {
resolveCredentialsMock.mockRejectedValue(new Error("no aws credentials"));
await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(false);
});
// --- Titan V2 ---
it("embeds with Titan V2", async () => {
sendMock.mockResolvedValue(enc({ embedding: [0.1, 0.2, 0.3] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(await provider.embedQuery("test")).toHaveLength(3);
expect(reqBody()).toMatchObject({ inputText: "test", normalize: true, dimensions: 1024 });
});
it("returns empty for blank text", async () => {
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(await provider.embedQuery(" ")).toEqual([]);
expect(sendMock).not.toHaveBeenCalled();
});
it("batches Titan V2 concurrently", async () => {
sendMock
.mockResolvedValueOnce(enc({ embedding: [0.1] }))
.mockResolvedValueOnce(enc({ embedding: [0.2] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v2:0",
fallback: "none",
});
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
expect(sendMock).toHaveBeenCalledTimes(2);
});
// --- Titan V1 ---
it("sends only inputText for Titan V1", async () => {
sendMock.mockResolvedValue(enc({ embedding: [0.5] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-text-v1",
fallback: "none",
});
await provider.embedQuery("hi");
expect(reqBody()).toEqual({ inputText: "hi" });
});
it("handles Titan G1 text variant", async () => {
sendMock.mockResolvedValue(enc({ embedding: [0.1] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.titan-embed-g1-text-02",
fallback: "none",
});
await provider.embedQuery("hi");
expect(reqBody()).toEqual({ inputText: "hi" });
});
// --- Cohere V3 ---
it("embeds Cohere V3 batch in single call", async () => {
sendMock.mockResolvedValue(enc({ embeddings: [[0.1], [0.2]] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "cohere.embed-english-v3",
fallback: "none",
});
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
expect(sendMock).toHaveBeenCalledTimes(1);
expect(reqBody()).toMatchObject({ texts: ["a", "b"], input_type: "search_document" });
});
it("uses search_query for Cohere embedQuery", async () => {
sendMock.mockResolvedValue(enc({ embeddings: [[0.1]] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "cohere.embed-english-v3",
fallback: "none",
});
await provider.embedQuery("q");
expect(reqBody().input_type).toBe("search_query");
});
// --- Cohere V4 ---
it("embeds Cohere V4 with embedding_types + output_dimension", async () => {
sendMock.mockResolvedValue(enc({ embeddings: { float: [[0.1], [0.2]] } }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "cohere.embed-v4:0",
fallback: "none",
});
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
expect(reqBody()).toMatchObject({ embedding_types: ["float"], output_dimension: 1536 });
});
it("validates Cohere V4 dimensions", () => {
expect(() =>
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "cohere.embed-v4:0",
fallback: "none",
outputDimensionality: 2048,
}),
).toThrow("Invalid dimensions 2048");
});
// --- Nova ---
it("embeds Nova with SINGLE_EMBEDDING format", async () => {
sendMock.mockResolvedValue(
enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1, 0.2] }] }),
);
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.nova-2-multimodal-embeddings-v1:0",
fallback: "none",
});
expect(await provider.embedQuery("hi")).toHaveLength(2);
expect(reqBody().taskType).toBe("SINGLE_EMBEDDING");
});
it("validates Nova dimensions", () => {
expect(() =>
resolveBedrockEmbeddingClient({
config: {} as never,
provider: "bedrock",
model: "amazon.nova-2-multimodal-embeddings-v1:0",
fallback: "none",
outputDimensionality: 512,
}),
).toThrow("Invalid dimensions 512");
});
it("batches Nova concurrently", async () => {
sendMock
.mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1] }] }))
.mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.2] }] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "amazon.nova-2-multimodal-embeddings-v1:0",
fallback: "none",
});
expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
expect(sendMock).toHaveBeenCalledTimes(2);
});
// --- TwelveLabs ---
it("embeds TwelveLabs Marengo", async () => {
sendMock.mockResolvedValue(enc({ data: [{ embedding: [0.1, 0.2] }] }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "twelvelabs.marengo-embed-3-0-v1:0",
fallback: "none",
});
expect(await provider.embedQuery("hi")).toHaveLength(2);
expect(reqBody()).toEqual({ inputType: "text", text: { inputText: "hi" } });
});
it("embeds TwelveLabs object-style responses", async () => {
sendMock.mockResolvedValue(enc({ data: { embedding: [0.3, 0.4] } }));
const { provider } = await createBedrockEmbeddingProvider({
config: {} as never,
provider: "bedrock",
model: "twelvelabs.marengo-embed-2-7-v1:0",
fallback: "none",
});
expect(await provider.embedQuery("hi")).toEqual([0.6, 0.8]);
});
});
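
Condensing what the deleted test above exercises: the Bedrock provider keeps a single call shape for callers while building family-specific request bodies (Titan, Cohere, Nova, TwelveLabs) internally. A sketch using the pre-refactor option names from the test; config is stubbed here and credentials are expected from the ambient AWS chain:

import { createBedrockEmbeddingProvider } from "./embeddings-bedrock.js";

const { provider } = await createBedrockEmbeddingProvider({
  config: {} as never, // real callers pass the loaded config
  provider: "bedrock",
  model: "amazon.titan-embed-text-v2:0",
  fallback: "none",
});

// Titan V2 requests carry inputText/normalize/dimensions; other model families get
// their own body shapes, as the family-specific cases above assert.
const vector = await provider.embedQuery("hello world"); // sanitized, normalized number[]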

View File

@@ -1,115 +0,0 @@
import type { EmbeddingInput } from "./embedding-inputs.js";
import type { GeminiTaskType } from "./embeddings.types.js";
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
export const GEMINI_EMBEDDING_2_MODELS = new Set([
"gemini-embedding-2-preview",
// Add the GA model name here once released.
]);
const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
export type { GeminiTaskType } from "./embeddings.types.js";
export type GeminiTextPart = { text: string };
export type GeminiInlinePart = {
inlineData: { mimeType: string; data: string };
};
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
export type GeminiEmbeddingRequest = {
content: { parts: GeminiPart[] };
taskType: GeminiTaskType;
outputDimensionality?: number;
model?: string;
};
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
export function buildGeminiTextEmbeddingRequest(params: {
text: string;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiTextEmbeddingRequest {
return buildGeminiEmbeddingRequest({
input: { text: params.text },
taskType: params.taskType,
outputDimensionality: params.outputDimensionality,
modelPath: params.modelPath,
});
}
export function buildGeminiEmbeddingRequest(params: {
input: EmbeddingInput;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiEmbeddingRequest {
const request: GeminiEmbeddingRequest = {
content: {
parts: params.input.parts?.map((part) =>
part.type === "text"
? ({ text: part.text } satisfies GeminiTextPart)
: ({
inlineData: { mimeType: part.mimeType, data: part.data },
} satisfies GeminiInlinePart),
) ?? [{ text: params.input.text }],
},
taskType: params.taskType,
};
if (params.modelPath) {
request.model = params.modelPath;
}
if (params.outputDimensionality != null) {
request.outputDimensionality = params.outputDimensionality;
}
return request;
}
/**
* Returns true if the given model name is a gemini-embedding-2 variant that
* supports `outputDimensionality` and extended task types.
*/
export function isGeminiEmbedding2Model(model: string): boolean {
return GEMINI_EMBEDDING_2_MODELS.has(model);
}
/**
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
* Returns `undefined` for older models (they don't support the param).
*/
export function resolveGeminiOutputDimensionality(
model: string,
requested?: number,
): number | undefined {
if (!isGeminiEmbedding2Model(model)) {
return undefined;
}
if (requested == null) {
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
}
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
if (!valid.includes(requested)) {
throw new Error(
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
);
}
return requested;
}
export function normalizeGeminiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) {
return DEFAULT_GEMINI_EMBEDDING_MODEL;
}
const withoutPrefix = trimmed.replace(/^models\//, "");
if (withoutPrefix.startsWith("gemini/")) {
return withoutPrefix.slice("gemini/".length);
}
if (withoutPrefix.startsWith("google/")) {
return withoutPrefix.slice("google/".length);
}
return withoutPrefix;
}
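
A small sketch of what the removed builder produces for a gemini-embedding-2 model; the taskType value is assumed (the actual union lives in embeddings.types.ts) and imports are elided because the file path is not shown in this view:

const dims = resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 1536); // 1536 (768/1536/3072 allowed)
// resolveGeminiOutputDimensionality("gemini-embedding-001") => undefined (older models ignore the param)

buildGeminiTextEmbeddingRequest({
  text: "hello",
  taskType: "RETRIEVAL_DOCUMENT",
  outputDimensionality: dims,
  modelPath: "models/gemini-embedding-2-preview",
});
// => {
//   content: { parts: [{ text: "hello" }] },
//   taskType: "RETRIEVAL_DOCUMENT",
//   model: "models/gemini-embedding-2-preview",
//   outputDimensionality: 1536,
// }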

View File

@@ -1,178 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn());
const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn());
vi.mock("../../agents/github-copilot-token.js", () => ({
DEFAULT_COPILOT_API_BASE_URL: "https://api.githubcopilot.test",
resolveCopilotApiToken: resolveCopilotApiTokenMock,
}));
vi.mock("../../infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
}));
import { createGitHubCopilotEmbeddingProvider } from "./embeddings-github-copilot.js";
function mockFetchResponse(spec: { ok: boolean; status?: number; json?: unknown; text?: string }) {
fetchWithSsrFGuardMock.mockImplementationOnce(async () => ({
response: {
ok: spec.ok,
status: spec.status ?? (spec.ok ? 200 : 500),
json: async () => spec.json,
text: async () => spec.text ?? "",
},
release: vi.fn(async () => {}),
}));
}
describe("createGitHubCopilotEmbeddingProvider", () => {
beforeEach(() => {
resolveCopilotApiTokenMock.mockResolvedValue({
token: "copilot-token-a",
expiresAt: Date.now() + 3_600_000,
source: "test",
baseUrl: "https://api.githubcopilot.test",
});
});
afterEach(() => {
vi.restoreAllMocks();
resolveCopilotApiTokenMock.mockReset();
fetchWithSsrFGuardMock.mockReset();
});
it("normalizes embeddings returned for queries", async () => {
mockFetchResponse({
ok: true,
json: {
data: [{ index: 0, embedding: [3, 4] }],
},
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
});
await expect(provider.embedQuery("hello")).resolves.toEqual([0.6, 0.8]);
expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith(
expect.objectContaining({
url: "https://api.githubcopilot.test/embeddings",
}),
);
});
it("preserves input order by explicit response index", async () => {
mockFetchResponse({
ok: true,
json: {
data: [
{ index: 1, embedding: [0, 2] },
{ index: 0, embedding: [1, 0] },
],
},
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
});
await expect(provider.embedBatch(["first", "second"])).resolves.toEqual([
[1, 0],
[0, 1],
]);
});
it("uses a fresh Copilot token for later requests", async () => {
resolveCopilotApiTokenMock
.mockResolvedValueOnce({
token: "copilot-token-create",
expiresAt: Date.now() + 3_600_000,
source: "test",
baseUrl: "https://api.githubcopilot.test",
})
.mockResolvedValueOnce({
token: "copilot-token-first",
expiresAt: Date.now() + 3_600_000,
source: "test",
baseUrl: "https://api.githubcopilot.test",
})
.mockResolvedValueOnce({
token: "copilot-token-second",
expiresAt: Date.now() + 3_600_000,
source: "test",
baseUrl: "https://api.githubcopilot.test",
});
mockFetchResponse({
ok: true,
json: { data: [{ index: 0, embedding: [1, 0] }] },
});
mockFetchResponse({
ok: true,
json: { data: [{ index: 0, embedding: [0, 1] }] },
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
});
await provider.embedQuery("first");
await provider.embedQuery("second");
const firstHeaders = fetchWithSsrFGuardMock.mock.calls[0]?.[0]?.init?.headers as Record<
string,
string
>;
const secondHeaders = fetchWithSsrFGuardMock.mock.calls[1]?.[0]?.init?.headers as Record<
string,
string
>;
expect(firstHeaders.Authorization).toBe("Bearer copilot-token-first");
expect(secondHeaders.Authorization).toBe("Bearer copilot-token-second");
});
it("honors custom baseUrl and header overrides", async () => {
mockFetchResponse({
ok: true,
json: { data: [{ index: 0, embedding: [1, 0] }] },
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
baseUrl: "https://proxy.example/v1",
headers: { "X-Proxy-Token": "proxy" },
});
await provider.embedQuery("hello");
const call = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as {
init: { headers: Record<string, string> };
url: string;
};
expect(call.url).toBe("https://proxy.example/v1/embeddings");
expect(call.init.headers["X-Proxy-Token"]).toBe("proxy");
expect(call.init.headers.Authorization).toBe("Bearer copilot-token-a");
});
it("fails fast on sparse or malformed embedding payloads", async () => {
mockFetchResponse({
ok: true,
json: {
data: [{ index: 1, embedding: [1, 0] }],
},
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
});
await expect(provider.embedBatch(["first", "second"])).rejects.toThrow(
"GitHub Copilot embeddings response missing vectors for some inputs",
);
});
});

View File

@@ -1,151 +0,0 @@
import {
DEFAULT_COPILOT_API_BASE_URL,
resolveCopilotApiToken,
} from "../../agents/github-copilot-token.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import type { EmbeddingProvider } from "./embeddings.types.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
export type GitHubCopilotEmbeddingClient = {
githubToken: string;
model: string;
baseUrl?: string;
headers?: Record<string, string>;
env?: NodeJS.ProcessEnv;
fetchImpl?: typeof fetch;
};
const COPILOT_EMBEDDING_PROVIDER_ID = "github-copilot";
const COPILOT_HEADERS_STATIC: Record<string, string> = {
"Content-Type": "application/json",
"Editor-Version": "vscode/1.96.2",
"User-Agent": "GitHubCopilotChat/0.26.7",
};
function resolveConfiguredBaseUrl(
configuredBaseUrl: string | undefined,
tokenBaseUrl: string | undefined,
): string {
const trimmed = configuredBaseUrl?.trim();
if (trimmed) {
return trimmed;
}
return tokenBaseUrl || DEFAULT_COPILOT_API_BASE_URL;
}
async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{
baseUrl: string;
headers: Record<string, string>;
}> {
const token = await resolveCopilotApiToken({
githubToken: client.githubToken,
env: client.env,
fetchImpl: client.fetchImpl,
});
const baseUrl = resolveConfiguredBaseUrl(client.baseUrl, token.baseUrl);
return {
baseUrl,
headers: {
...COPILOT_HEADERS_STATIC,
...client.headers,
Authorization: `Bearer ${token.token}`,
},
};
}
function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] {
if (!payload || typeof payload !== "object") {
throw new Error("GitHub Copilot embeddings response missing data[]");
}
const data = (payload as { data?: unknown }).data;
if (!Array.isArray(data)) {
throw new Error("GitHub Copilot embeddings response missing data[]");
}
const vectors = Array.from<number[] | undefined>({ length: expectedCount });
for (const entry of data) {
if (!entry || typeof entry !== "object") {
throw new Error("GitHub Copilot embeddings response contains an invalid entry");
}
const indexValue = (entry as { index?: unknown }).index;
const embedding = (entry as { embedding?: unknown }).embedding;
const index = typeof indexValue === "number" ? indexValue : Number.NaN;
if (!Number.isInteger(index)) {
throw new Error("GitHub Copilot embeddings response contains an invalid index");
}
if (index < 0 || index >= expectedCount) {
throw new Error("GitHub Copilot embeddings response contains an out-of-range index");
}
if (vectors[index] !== undefined) {
throw new Error("GitHub Copilot embeddings response contains duplicate indexes");
}
if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) {
throw new Error("GitHub Copilot embeddings response contains an invalid embedding");
}
vectors[index] = sanitizeAndNormalizeEmbedding(embedding);
}
for (let index = 0; index < expectedCount; index += 1) {
if (vectors[index] === undefined) {
throw new Error("GitHub Copilot embeddings response missing vectors for some inputs");
}
}
return vectors as number[][];
}
export async function createGitHubCopilotEmbeddingProvider(
client: GitHubCopilotEmbeddingClient,
): Promise<{ provider: EmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
const embed = async (input: string[]): Promise<number[][]> => {
if (input.length === 0) {
return [];
}
const session = await resolveGitHubCopilotEmbeddingSession(client);
const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`;
return await withRemoteHttpResponse({
url,
fetchImpl: client.fetchImpl,
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
init: {
method: "POST",
headers: session.headers,
body: JSON.stringify({ model: client.model, input }),
},
onResponse: async (response) => {
if (!response.ok) {
throw new Error(
`GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`,
);
}
let payload: unknown;
try {
payload = await response.json();
} catch {
throw new Error("GitHub Copilot embeddings returned invalid JSON");
}
return parseGitHubCopilotEmbeddingPayload(payload, input.length);
},
});
};
return {
provider: {
id: COPILOT_EMBEDDING_PROVIDER_ID,
model: client.model,
embedQuery: async (text) => {
const [vector] = await embed([text]);
return vector ?? [];
},
embedBatch: embed,
},
client: {
...client,
baseUrl: initialSession.baseUrl,
},
};
}
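
A minimal consumer sketch for the Copilot provider above; the GitHub token is a placeholder, and as the implementation shows, a short-lived Copilot API token is re-resolved inside each embed call:

import { createGitHubCopilotEmbeddingProvider } from "./embeddings-github-copilot.js";

const { provider } = await createGitHubCopilotEmbeddingProvider({
  githubToken: "<github-token>", // placeholder credential
  model: "text-embedding-3-small",
});

const query = await provider.embedQuery("search text"); // single normalized vector
const docs = await provider.embedBatch(["chunk one", "chunk two"]); // order restored via response index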

View File

@@ -1,387 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import type { OpenClawConfig } from "../../config/config.js";
const ensureLmstudioModelLoadedMock = vi.hoisted(() => vi.fn());
const resolveLmstudioRuntimeApiKeyMock = vi.hoisted(() => vi.fn());
vi.mock("../../plugin-sdk/lmstudio-runtime.js", () => ({
buildLmstudioAuthHeaders: ({
apiKey,
json,
headers,
}: {
apiKey?: string;
json?: boolean;
headers?: Record<string, string>;
}) => ({
...(json ? { "Content-Type": "application/json" } : {}),
...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}),
...headers,
}),
ensureLmstudioModelLoaded: (...args: unknown[]) => ensureLmstudioModelLoadedMock(...args),
LMSTUDIO_DEFAULT_EMBEDDING_MODEL: "text-embedding-nomic-embed-text-v1.5",
LMSTUDIO_PROVIDER_ID: "lmstudio",
resolveLmstudioInferenceBase: (baseUrl?: string) => {
const normalized = (baseUrl || "http://localhost:1234").replace(/\/+$/u, "");
if (normalized.endsWith("/api/v1")) {
return normalized.slice(0, -"/api/v1".length) + "/v1";
}
if (normalized.endsWith("/v1")) {
return normalized;
}
return `${normalized}/v1`;
},
resolveLmstudioProviderHeaders: ({ headers }: { headers?: Record<string, string> }) =>
headers ?? {},
resolveLmstudioRuntimeApiKey: (...args: unknown[]) => resolveLmstudioRuntimeApiKeyMock(...args),
}));
let createLmstudioEmbeddingProvider: typeof import("./embeddings-lmstudio.js").createLmstudioEmbeddingProvider;
describe("embeddings-lmstudio", () => {
const originalFetch = globalThis.fetch;
const jsonResponse = (embedding: number[]) =>
new Response(
JSON.stringify({
data: [{ embedding }],
}),
{
status: 200,
headers: { "content-type": "application/json" },
},
);
function mockEmbeddingFetch(embedding: number[]) {
const fetchMock = vi.fn<typeof fetch>();
fetchMock.mockResolvedValue(jsonResponse(embedding));
globalThis.fetch = fetchMock as unknown as typeof fetch;
return fetchMock;
}
beforeAll(async () => {
({ createLmstudioEmbeddingProvider } = await import("./embeddings-lmstudio.js"));
});
beforeEach(() => {
ensureLmstudioModelLoadedMock.mockReset();
resolveLmstudioRuntimeApiKeyMock.mockReset();
});
afterEach(() => {
globalThis.fetch = originalFetch;
});
it("embeds against inference base and warms model with resolved key", async () => {
ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-lmstudio-key");
const fetchMock = mockEmbeddingFetch([0.1, 0.2]);
const { provider } = await createLmstudioEmbeddingProvider({
config: {
models: {
providers: {
lmstudio: {
baseUrl: "http://localhost:1234/api/v1/",
headers: { "X-Provider": "provider" },
models: [],
},
},
},
} as OpenClawConfig,
provider: "lmstudio",
model: "lmstudio/text-embedding-nomic-embed-text-v1.5",
fallback: "none",
});
await provider.embedQuery("hello");
expect(fetchMock).toHaveBeenCalledWith(
"http://localhost:1234/v1/embeddings",
expect.objectContaining({
method: "POST",
headers: expect.objectContaining({
"Content-Type": "application/json",
Authorization: "Bearer profile-lmstudio-key",
"X-Provider": "provider",
}),
}),
);
expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({
baseUrl: "http://localhost:1234/v1",
apiKey: "profile-lmstudio-key",
headers: {
"X-Provider": "provider",
},
ssrfPolicy: { allowedHostnames: ["localhost"] },
modelKey: "text-embedding-nomic-embed-text-v1.5",
timeoutMs: 120_000,
});
});
it("uses memorySearch remote overrides for primary lmstudio", async () => {
ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key");
const fetchMock = mockEmbeddingFetch([1, 2, 3]);
const { provider } = await createLmstudioEmbeddingProvider({
config: {
models: {
providers: {
lmstudio: {
baseUrl: "http://localhost:1234",
headers: {
"X-Provider": "provider",
"X-Config-Only": "from-provider",
},
models: [],
},
},
},
} as OpenClawConfig,
provider: "lmstudio",
model: "",
fallback: "none",
remote: {
baseUrl: "http://localhost:9999",
apiKey: "remote-lmstudio-key",
headers: {
"X-Provider": "remote",
"X-Remote-Only": "from-remote",
},
},
});
await provider.embedBatch(["one", "two"]);
expect(fetchMock).toHaveBeenCalledWith(
"http://localhost:9999/v1/embeddings",
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer remote-lmstudio-key",
"X-Provider": "remote",
"X-Config-Only": "from-provider",
"X-Remote-Only": "from-remote",
}),
}),
);
expect(resolveLmstudioRuntimeApiKeyMock).not.toHaveBeenCalled();
expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({
baseUrl: "http://localhost:9999/v1",
apiKey: "remote-lmstudio-key",
headers: {
"X-Provider": "remote",
"X-Config-Only": "from-provider",
"X-Remote-Only": "from-remote",
},
ssrfPolicy: { allowedHostnames: ["localhost"] },
modelKey: "text-embedding-nomic-embed-text-v1.5",
timeoutMs: 120_000,
});
});
it("preserves remote Authorization header auth for primary lmstudio", async () => {
ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("stale-profile-key");
const fetchMock = mockEmbeddingFetch([1, 2, 3]);
const { provider } = await createLmstudioEmbeddingProvider({
config: {
models: {
providers: {
lmstudio: {
baseUrl: "http://localhost:1234",
headers: {
"X-Provider": "provider",
},
models: [],
},
},
},
} as OpenClawConfig,
provider: "lmstudio",
model: "",
fallback: "none",
remote: {
baseUrl: "http://localhost:9999",
headers: {
Authorization: "Bearer remote-proxy-token",
"X-Remote-Only": "from-remote",
},
},
});
await provider.embedBatch(["one", "two"]);
expect(fetchMock).toHaveBeenCalledWith(
"http://localhost:9999/v1/embeddings",
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer remote-proxy-token",
"X-Provider": "provider",
"X-Remote-Only": "from-remote",
}),
}),
);
expect(resolveLmstudioRuntimeApiKeyMock).not.toHaveBeenCalled();
expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({
baseUrl: "http://localhost:9999/v1",
apiKey: undefined,
headers: {
"X-Provider": "provider",
Authorization: "Bearer remote-proxy-token",
"X-Remote-Only": "from-remote",
},
ssrfPolicy: { allowedHostnames: ["localhost"] },
modelKey: "text-embedding-nomic-embed-text-v1.5",
timeoutMs: 120_000,
});
});
it("ignores memorySearch remote overrides for fallback lmstudio activation", async () => {
ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key");
const fetchMock = mockEmbeddingFetch([1, 2, 3]);
const { provider } = await createLmstudioEmbeddingProvider({
config: {
models: {
providers: {
lmstudio: {
baseUrl: "http://localhost:1234",
headers: {
"X-Provider": "provider",
"X-Config-Only": "from-provider",
},
models: [],
},
},
},
} as OpenClawConfig,
provider: "openai",
model: "",
fallback: "lmstudio",
remote: {
baseUrl: "http://localhost:9999",
apiKey: "remote-lmstudio-key",
headers: {
"X-Provider": "remote",
"X-Remote-Only": "from-remote",
},
},
});
await provider.embedBatch(["one", "two"]);
expect(fetchMock).toHaveBeenCalledWith(
"http://localhost:1234/v1/embeddings",
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer profile-key",
"X-Provider": "provider",
"X-Config-Only": "from-provider",
}),
}),
);
const callHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Record<string, string>;
expect(callHeaders["X-Remote-Only"]).toBeUndefined();
expect(resolveLmstudioRuntimeApiKeyMock).toHaveBeenCalled();
expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({
baseUrl: "http://localhost:1234/v1",
apiKey: "profile-key",
headers: {
"X-Provider": "provider",
"X-Config-Only": "from-provider",
},
ssrfPolicy: { allowedHostnames: ["localhost"] },
modelKey: "text-embedding-nomic-embed-text-v1.5",
timeoutMs: 120_000,
});
});
it("skips remote SecretRef resolution for fallback lmstudio activation", async () => {
ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key");
const fetchMock = mockEmbeddingFetch([1, 2, 3]);
const { provider } = await createLmstudioEmbeddingProvider({
config: {
models: {
providers: {
lmstudio: {
baseUrl: "http://localhost:1234",
headers: {
"X-Provider": "provider",
},
models: [],
},
},
},
} as OpenClawConfig,
provider: "openai",
model: "",
fallback: "lmstudio",
remote: {
baseUrl: "http://localhost:9999",
apiKey: { source: "env", provider: "default", id: "OPENAI_API_KEY" },
headers: {
"X-Remote-Only": "from-remote",
},
},
});
await provider.embedQuery("hello");
expect(fetchMock).toHaveBeenCalledWith(
"http://localhost:1234/v1/embeddings",
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer profile-key",
"X-Provider": "provider",
}),
}),
);
const callHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Record<string, string>;
expect(callHeaders["X-Remote-Only"]).toBeUndefined();
expect(resolveLmstudioRuntimeApiKeyMock).toHaveBeenCalled();
});
it("uses env-template-backed provider api keys in embedding requests", async () => {
ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("template-lmstudio-key");
const fetchMock = mockEmbeddingFetch([0.3, 0.4]);
const { provider } = await createLmstudioEmbeddingProvider({
config: {
models: {
providers: {
lmstudio: {
baseUrl: "http://localhost:1234/v1",
apiKey: "${LM_API_TOKEN}",
models: [],
},
},
},
} as OpenClawConfig,
provider: "lmstudio",
model: "text-embedding-nomic-embed-text-v1.5",
fallback: "none",
});
await provider.embedQuery("hello");
expect(fetchMock).toHaveBeenCalledWith(
"http://localhost:1234/v1/embeddings",
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer template-lmstudio-key",
}),
}),
);
});
});

View File

@@ -1,19 +0,0 @@
import { describe, expect, it } from "vitest";
import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js";
describe("normalizeMistralModel", () => {
it("returns the default model for empty values", () => {
expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
});
it("strips the mistral/ prefix", () => {
expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed");
expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed");
});
it("keeps explicit non-prefixed models", () => {
expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed");
expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2");
});
});

View File

@@ -1,43 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({
createOllamaEmbeddingProviderMock: vi.fn(async (options: unknown) => ({
provider: { source: "mock-provider", options },
client: { source: "mock-client" },
})),
}));
vi.mock("../../plugin-sdk/ollama-runtime.js", () => ({
DEFAULT_OLLAMA_EMBEDDING_MODEL: "nomic-embed-text",
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
}));
describe("embeddings-ollama facade", () => {
beforeEach(() => {
createOllamaEmbeddingProviderMock.mockClear();
});
it("re-exports the default Ollama embedding model", async () => {
const mod = await import("./embeddings-ollama.js");
expect(mod.DEFAULT_OLLAMA_EMBEDDING_MODEL).toBe("nomic-embed-text");
});
it("delegates provider creation to the plugin-sdk runtime facade", async () => {
const mod = await import("./embeddings-ollama.js");
const options = {
provider: "ollama",
model: "nomic-embed-text",
fallback: "none",
config: {},
};
const result = await mod.createOllamaEmbeddingProvider(options as never);
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledTimes(1);
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledWith(options);
expect(result).toEqual({
provider: { source: "mock-provider", options },
client: { source: "mock-client" },
});
});
});

View File

@@ -1,5 +0,0 @@
export type { OllamaEmbeddingClient } from "../../plugin-sdk/ollama-runtime.js";
export {
createOllamaEmbeddingProvider,
DEFAULT_OLLAMA_EMBEDDING_MODEL,
} from "../../plugin-sdk/ollama-runtime.js";

View File

@@ -1,25 +0,0 @@
import { describe, expect, it } from "vitest";
import { DEFAULT_OPENAI_EMBEDDING_MODEL, normalizeOpenAiModel } from "./embeddings-openai.js";
describe("normalizeOpenAiModel", () => {
it("returns the default model when input is blank", () => {
expect(normalizeOpenAiModel("")).toBe(DEFAULT_OPENAI_EMBEDDING_MODEL);
expect(normalizeOpenAiModel(" ")).toBe(DEFAULT_OPENAI_EMBEDDING_MODEL);
});
it("strips the openai/ prefix", () => {
expect(normalizeOpenAiModel("openai/text-embedding-3-small")).toBe("text-embedding-3-small");
expect(normalizeOpenAiModel("openai/text-embedding-ada-002")).toBe("text-embedding-ada-002");
});
it("preserves explicit third-party provider prefixes", () => {
expect(normalizeOpenAiModel("spark/text-embedding-3-small")).toBe(
"spark/text-embedding-3-small",
);
expect(normalizeOpenAiModel("litellm/azure/ada-002")).toBe("litellm/azure/ada-002");
});
it("preserves unprefixed model ids", () => {
expect(normalizeOpenAiModel("text-embedding-3-large")).toBe("text-embedding-3-large");
});
});

View File

@@ -1,88 +0,0 @@
import { vi } from "vitest";
import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js";
vi.mock("../../infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));
type FetchPayloadFactory = (input: RequestInfo | URL, init?: RequestInit) => unknown;
type JsonResponseFetchMock = ReturnType<typeof vi.fn<FetchMock>> & {
preconnect: (
url: string | URL,
options?: { dns?: boolean; tcp?: boolean; http?: boolean; https?: boolean },
) => void;
__openclawAcceptsDispatcher: true;
};
export type JsonFetchMock = ReturnType<typeof createJsonResponseFetchMock>;
export function createJsonResponseFetchMock(payload: FetchPayloadFactory): JsonResponseFetchMock;
export function createJsonResponseFetchMock(payload: unknown): JsonResponseFetchMock;
export function createJsonResponseFetchMock(payload: unknown) {
const fetchMock = vi.fn<FetchMock>(async (input: RequestInfo | URL, init?: RequestInit) => {
const body =
typeof payload === "function" ? (payload as FetchPayloadFactory)(input, init) : payload;
const serialized = JSON.stringify(body);
return {
ok: true,
status: 200,
json: async () => body,
text: async () => serialized,
} as Response;
});
return withFetchPreconnect(fetchMock) as JsonResponseFetchMock;
}
export function createEmbeddingDataFetchMock(embeddingValues = [0.1, 0.2, 0.3]) {
return createJsonResponseFetchMock({ data: [{ embedding: embeddingValues }] });
}
export function createGeminiFetchMock(embeddingValues = [1, 2, 3]) {
return createJsonResponseFetchMock({ embedding: { values: embeddingValues } });
}
export function createGeminiBatchFetchMock(count: number, embeddingValues = [1, 2, 3]) {
return createJsonResponseFetchMock({
embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })),
});
}
export function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}
export function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
const [url, init] = fetchMock.mock.calls[0] ?? [];
return { url, init: init as RequestInit | undefined };
}
export function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) {
const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined;
return JSON.parse((init?.body as string) ?? "{}") as Record<string, unknown>;
}
export function mockResolvedProviderKey(
resolveApiKeyForProvider: typeof import("../../agents/model-auth.js").resolveApiKeyForProvider,
apiKey = "test-key",
) {
vi.mocked(resolveApiKeyForProvider).mockResolvedValue({
apiKey,
mode: "api-key",
source: "test",
});
}

View File

@@ -5,7 +5,7 @@ import type { EmbeddingProviderOptions } from "./embeddings.types.js";
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";
export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral";
export type RemoteEmbeddingProviderId = string;
export async function resolveRemoteEmbeddingBearerClient(params: {
provider: RemoteEmbeddingProviderId;

View File

@@ -1,141 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../agents/model-auth.js";
import {
createEmbeddingDataFetchMock,
createJsonResponseFetchMock,
installFetchMock,
mockResolvedProviderKey,
type JsonFetchMock,
} from "./embeddings-provider.test-support.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
resolveApiKeyForProviderMock: vi.fn(),
}));
vi.mock("../../agents/model-auth.js", () => {
return {
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
if (auth.apiKey) {
return auth.apiKey;
}
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
},
};
});
let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
beforeAll(async () => {
({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
await import("./embeddings-voyage.js"));
});
beforeEach(() => {
vi.useRealTimers();
vi.doUnmock("undici");
});
async function createDefaultVoyageProvider(model: string, fetchMock: JsonFetchMock) {
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider, "voyage-key-123");
return createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
model,
fallback: "none",
});
}
describe("voyage embedding provider", () => {
afterEach(() => {
vi.doUnmock("undici");
vi.resetAllMocks();
vi.unstubAllGlobals();
});
it("configures client with correct defaults and headers", async () => {
const fetchMock = createEmbeddingDataFetchMock();
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
await result.provider.embedQuery("test query");
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith(
expect.objectContaining({ provider: "voyage" }),
);
const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://api.voyageai.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer voyage-key-123");
expect(headers["Content-Type"]).toBe("application/json");
const body = JSON.parse(init?.body as string);
expect(body).toEqual({
model: "voyage-4-large",
input: ["test query"],
input_type: "query",
});
});
it("respects remote overrides for baseUrl and apiKey", async () => {
const fetchMock = createEmbeddingDataFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
const result = await createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
model: "voyage-4-lite",
fallback: "none",
remote: {
baseUrl: "https://example.com",
apiKey: "remote-override-key",
headers: { "X-Custom": "123" },
},
});
await result.provider.embedQuery("test");
const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://example.com/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-override-key");
expect(headers["X-Custom"]).toBe("123");
});
it("passes input_type=document for embedBatch", async () => {
const fetchMock = createJsonResponseFetchMock({
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
});
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
await result.provider.embedBatch(["doc1", "doc2"]);
const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [, init] = call as [RequestInfo | URL, RequestInit | undefined];
const body = JSON.parse(init?.body as string);
expect(body).toEqual({
model: "voyage-4-large",
input: ["doc1", "doc2"],
input_type: "document",
});
});
it("normalizes model names", async () => {
expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2");
expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large");
expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite");
expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default
});
});

View File

@@ -1,768 +1,67 @@
import { setTimeout as sleep } from "node:timers/promises";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../agents/model-auth.js";
import {
createEmbeddingDataFetchMock,
createGeminiFetchMock,
installFetchMock,
mockResolvedProviderKey as mockResolvedProviderKeyBase,
readFirstFetchRequest,
} from "./embeddings-provider.test-support.js";
import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
import { createLocalEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
import * as nodeLlamaModule from "./node-llama.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
const {
bedrockSendMock,
createOllamaEmbeddingProviderMock,
defaultProviderMock,
resolveApiKeyForProviderMock,
resolveCredentialsMock,
} = vi.hoisted(() => ({
bedrockSendMock: vi.fn(),
createOllamaEmbeddingProviderMock: vi.fn(async () => {
throw new Error("Unexpected ollama provider in embeddings.test.ts");
}),
defaultProviderMock: vi.fn(),
resolveCredentialsMock: vi.fn(),
resolveApiKeyForProviderMock: vi.fn(),
}));
vi.mock("../../agents/model-auth.js", () => {
return {
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
if (auth.apiKey) {
return auth.apiKey;
}
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
},
};
});
vi.mock("./embeddings-ollama.js", () => ({
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
}));
vi.mock("@aws-sdk/client-bedrock-runtime", () => {
class MockClient {
send = bedrockSendMock;
}
class MockCommand {
input: unknown;
constructor(input: unknown) {
this.input = input;
}
}
return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand };
});
vi.mock("@aws-sdk/credential-provider-node", () => ({
defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
}));
const createFetchMock = () => createEmbeddingDataFetchMock([1, 2, 3]);
type ResolvedProviderAuth = Awaited<ReturnType<typeof authModule.resolveApiKeyForProvider>>;
beforeEach(() => {
vi.spyOn(authModule, "resolveApiKeyForProvider");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
defaultProviderMock.mockImplementation(() => resolveCredentialsMock);
});
beforeEach(() => {
vi.useRealTimers();
});
afterEach(() => {
vi.resetAllMocks();
vi.unstubAllGlobals();
});
function requireProvider(result: Awaited<ReturnType<typeof createEmbeddingProvider>>) {
if (!result.provider) {
throw new Error("Expected embedding provider");
}
return result.provider;
function mockLocalEmbeddingRuntime(vector = new Float32Array([2.35, 3.45, 0.63, 4.3])) {
const getEmbeddingFor = vi.fn().mockResolvedValue({ vector });
const createEmbeddingContext = vi.fn().mockResolvedValue({ getEmbeddingFor });
const loadModel = vi.fn().mockResolvedValue({ createEmbeddingContext });
const resolveModelFile = vi.fn(async (modelPath: string) => `/resolved/${modelPath}`);
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
getLlama: async () => ({ loadModel }),
resolveModelFile,
LlamaLogLevel: { error: 0 },
} as never);
return { createEmbeddingContext, getEmbeddingFor, loadModel, resolveModelFile };
}
function mockResolvedProviderKey(apiKey = "provider-key") {
mockResolvedProviderKeyBase(authModule.resolveApiKeyForProvider, apiKey);
}
describe("local embedding provider", () => {
it("normalizes local embeddings and resolves the default local model", async () => {
const runtime = mockLocalEmbeddingRuntime();
function mockMissingLocalEmbeddingDependency() {
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockRejectedValue(
Object.assign(new Error("Cannot find package 'node-llama-cpp'"), {
code: "ERR_MODULE_NOT_FOUND",
}),
);
}
function createLocalProvider(options?: { fallback?: "none" | "openai" }) {
return createEmbeddingProvider({
config: {} as never,
provider: "local",
model: "text-embedding-3-small",
fallback: options?.fallback ?? "none",
});
}
function expectAutoSelectedProvider(
result: Awaited<ReturnType<typeof createEmbeddingProvider>>,
expectedId: "openai" | "gemini" | "mistral" | "bedrock",
) {
expect(result.requestedProvider).toBe("auto");
const provider = requireProvider(result);
expect(provider.id).toBe(expectedId);
return provider;
}
function createAutoProvider(model = "") {
return createEmbeddingProvider({
config: {} as never,
provider: "auto",
model,
fallback: "none",
});
}
describe("embedding provider remote overrides", () => {
it("uses remote baseUrl/apiKey and merges headers", async () => {
const fetchMock = createFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey("provider-key");
const cfg = {
models: {
providers: {
openai: {
baseUrl: "https://api.openai.com/v1",
headers: {
"X-Provider": "p",
"X-Shared": "provider",
},
},
},
},
};
const result = await createEmbeddingProvider({
config: cfg as never,
provider: "openai",
remote: {
baseUrl: "https://example.com/v1",
apiKey: " remote-key ",
headers: {
"X-Shared": "remote",
"X-Remote": "r",
},
},
model: "text-embedding-3-small",
fallback: "openai",
});
const provider = requireProvider(result);
await provider.embedQuery("hello");
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
const url = fetchMock.mock.calls[0]?.[0];
const init = fetchMock.mock.calls[0]?.[1];
expect(url).toBe("https://example.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-key");
expect(headers["Content-Type"]).toBe("application/json");
expect(headers["X-Provider"]).toBe("p");
expect(headers["X-Shared"]).toBe("remote");
expect(headers["X-Remote"]).toBe("r");
});
it("falls back to resolved api key when remote apiKey is blank", async () => {
const fetchMock = createFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey("provider-key");
const cfg = {
models: {
providers: {
openai: {
baseUrl: "https://api.openai.com/v1",
},
},
},
};
const result = await createEmbeddingProvider({
config: cfg as never,
provider: "openai",
remote: {
baseUrl: "https://example.com/v1",
apiKey: " ",
},
model: "text-embedding-3-small",
fallback: "openai",
});
const provider = requireProvider(result);
await provider.embedQuery("hello");
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1);
const init = fetchMock.mock.calls[0]?.[1];
const headers = (init?.headers as Record<string, string>) ?? {};
expect(headers.Authorization).toBe("Bearer provider-key");
});
it("builds Gemini embeddings requests with api key header", async () => {
const fetchMock = createGeminiFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey("provider-key");
const cfg = {
models: {
providers: {
google: {
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
},
},
},
};
const result = await createEmbeddingProvider({
config: cfg as never,
provider: "gemini",
remote: {
apiKey: "gemini-key",
},
model: "text-embedding-004",
fallback: "openai",
});
const provider = requireProvider(result);
await provider.embedQuery("hello");
const { url, init } = readFirstFetchRequest(fetchMock);
expect(url).toBe(
"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent",
);
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers["x-goog-api-key"]).toBe("gemini-key");
expect(headers["Content-Type"]).toBe("application/json");
});
it("fails fast when Gemini remote apiKey is an unresolved SecretRef", async () => {
vi.stubEnv("GEMINI_API_KEY", "");
await expect(
createEmbeddingProvider({
config: {} as never,
provider: "gemini",
remote: {
apiKey: { source: "env", provider: "default", id: "GEMINI_API_KEY" },
},
model: "text-embedding-004",
fallback: "openai",
}),
).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey:/i);
});
it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => {
vi.stubEnv("GEMINI_API_KEY", "env-gemini-key");
const result = await createEmbeddingProvider({
config: {} as never,
provider: "gemini",
remote: {
apiKey: "GEMINI_API_KEY", // pragma: allowlist secret
},
model: "text-embedding-004",
fallback: "openai",
});
const provider = requireProvider(result);
expect(provider.id).toBe("gemini");
expect(result.gemini?.apiKeys).toEqual(["env-gemini-key"]);
});
it("builds Mistral embeddings requests with bearer auth", async () => {
const fetchMock = createFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey("provider-key");
const cfg = {
models: {
providers: {
mistral: {
baseUrl: "https://api.mistral.ai/v1",
},
},
},
};
const result = await createEmbeddingProvider({
config: cfg as never,
provider: "mistral",
remote: {
apiKey: "mistral-key", // pragma: allowlist secret
},
model: "mistral/mistral-embed",
fallback: "none",
});
const provider = requireProvider(result);
await provider.embedQuery("hello");
const { url, init } = readFirstFetchRequest(fetchMock);
expect(url).toBe("https://api.mistral.ai/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer mistral-key");
const payload = JSON.parse((init?.body as string | undefined) ?? "{}") as { model?: string };
expect(payload.model).toBe("mistral-embed");
});
});
describe("embedding provider auto selection", () => {
it("keeps explicit model when openai is selected", async () => {
const fetchMock = vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
}));
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
if (provider === "openai") {
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
}
throw new Error(`Unexpected provider ${provider}`);
});
const result = await createEmbeddingProvider({
config: {} as never,
provider: "auto",
model: "text-embedding-3-small",
fallback: "none",
});
expect(result.requestedProvider).toBe("auto");
const provider = requireProvider(result);
expect(provider.id).toBe("openai");
await provider.embedQuery("hello");
const url = fetchMock.mock.calls[0]?.[0];
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
expect(url).toBe("https://api.openai.com/v1/embeddings");
const payload = JSON.parse(init?.body as string) as { model?: string };
expect(payload.model).toBe("text-embedding-3-small");
});
it("selects the first available remote provider in auto mode", async () => {
const cases: Array<{
name: string;
expectedProvider: "openai" | "gemini" | "mistral";
resolveApiKey: (provider: string) => ResolvedProviderAuth;
}> = [
{
name: "openai first",
expectedProvider: "openai" as const,
resolveApiKey(provider: string): ResolvedProviderAuth {
if (provider === "openai") {
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
}
throw new Error(`No API key found for provider "${provider}".`);
},
},
{
name: "gemini fallback",
expectedProvider: "gemini" as const,
resolveApiKey(provider: string): ResolvedProviderAuth {
if (provider === "openai") {
throw new Error('No API key found for provider "openai".');
}
if (provider === "google") {
return {
apiKey: "gemini-key",
source: "env: GEMINI_API_KEY",
mode: "api-key" as const,
};
}
throw new Error(`Unexpected provider ${provider}`);
},
},
{
name: "mistral after earlier misses",
expectedProvider: "mistral" as const,
resolveApiKey(provider: string): ResolvedProviderAuth {
if (provider === "mistral") {
return {
apiKey: "mistral-key",
source: "env: MISTRAL_API_KEY",
mode: "api-key" as const,
};
}
throw new Error(`No API key found for provider "${provider}".`);
},
},
];
for (const testCase of cases) {
vi.mocked(authModule.resolveApiKeyForProvider).mockReset();
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) =>
testCase.resolveApiKey(provider),
);
const result = await createAutoProvider();
expectAutoSelectedProvider(result, testCase.expectedProvider);
}
});
it("selects Bedrock in auto mode when the AWS credential chain resolves", async () => {
bedrockSendMock.mockResolvedValue({
body: new TextEncoder().encode(JSON.stringify({ embedding: [1, 2, 3] })),
});
resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
throw new Error(`No API key found for provider "${provider}".`);
});
const result = await createAutoProvider();
const provider = expectAutoSelectedProvider(result, "bedrock");
await provider.embedQuery("hello");
expect(bedrockSendMock).toHaveBeenCalledTimes(1);
});
it("rethrows non-auth Bedrock setup errors in auto mode", async () => {
resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
throw new Error(`No API key found for provider "${provider}".`);
});
await expect(
createEmbeddingProvider({
config: {} as never,
provider: "auto",
model: "",
fallback: "none",
outputDimensionality: 768,
}),
).rejects.toThrow("Invalid dimensions 768");
});
});
describe("embedding provider local fallback", () => {
it("falls back to openai when node-llama-cpp is missing", async () => {
mockMissingLocalEmbeddingDependency();
const fetchMock = createFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockResolvedProviderKey("provider-key");
const result = await createLocalProvider({ fallback: "openai" });
const provider = requireProvider(result);
expect(provider.id).toBe("openai");
expect(result.fallbackFrom).toBe("local");
expect(result.fallbackReason).toContain("node-llama-cpp");
});
it("throws a helpful error when local is requested and fallback is none", async () => {
mockMissingLocalEmbeddingDependency();
await expect(createLocalProvider()).rejects.toThrow(/optional dependency node-llama-cpp/i);
});
it("mentions every remote provider in local setup guidance", async () => {
mockMissingLocalEmbeddingDependency();
await expect(createLocalProvider()).rejects.toThrow(/provider = "gemini"/i);
await expect(createLocalProvider()).rejects.toThrow(/provider = "mistral"/i);
});
});
describe("local embedding normalization", () => {
async function createLocalProviderForTest() {
return createEmbeddingProvider({
const provider = await createLocalEmbeddingProvider({
config: {} as never,
provider: "local",
model: "",
fallback: "none",
});
}
function mockSingleLocalEmbeddingVector(
vector: number[],
resolveModelFile: (modelPath: string, modelDirectory?: string) => Promise<string> = async () =>
"/fake/model.gguf",
): void {
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
getLlama: async () => ({
loadModel: vi.fn().mockResolvedValue({
createEmbeddingContext: vi.fn().mockResolvedValue({
getEmbeddingFor: vi.fn().mockResolvedValue({
vector: new Float32Array(vector),
}),
}),
}),
}),
resolveModelFile,
LlamaLogLevel: { error: 0 },
} as never);
}
it("normalizes local embeddings to magnitude ~1.0", async () => {
const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9];
const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf");
mockSingleLocalEmbeddingVector(unnormalizedVector, resolveModelFileMock);
const result = await createLocalProviderForTest();
const provider = requireProvider(result);
const embedding = await provider.embedQuery("test query");
const magnitude = Math.sqrt(embedding.reduce((sum, value) => sum + value * value, 0));
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
expect(magnitude).toBeCloseTo(1.0, 5);
expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
expect(magnitude).toBeCloseTo(1, 5);
expect(runtime.resolveModelFile).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
expect(runtime.getEmbeddingFor).toHaveBeenCalledWith("test query");
});
it("handles zero vector without division by zero", async () => {
const zeroVector = [0, 0, 0, 0];
it("trims explicit local model paths and cache directories", async () => {
const runtime = mockLocalEmbeddingRuntime(new Float32Array([1, 0]));
mockSingleLocalEmbeddingVector(zeroVector);
const result = await createLocalProviderForTest();
const provider = requireProvider(result);
const embedding = await provider.embedQuery("test");
expect(embedding).toEqual([0, 0, 0, 0]);
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
});
it("sanitizes non-finite values before normalization", async () => {
const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
mockSingleLocalEmbeddingVector(nonFiniteVector);
const result = await createLocalProviderForTest();
const provider = requireProvider(result);
const embedding = await provider.embedQuery("test");
expect(embedding).toEqual([1, 0, 0, 0]);
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
});
it("normalizes batch embeddings to magnitude ~1.0", async () => {
const unnormalizedVectors = [
[2.35, 3.45, 0.63, 4.3],
[10.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0],
];
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
getLlama: async () => ({
loadModel: vi.fn().mockResolvedValue({
createEmbeddingContext: vi.fn().mockResolvedValue({
getEmbeddingFor: vi
.fn()
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) })
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) })
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }),
}),
}),
}),
resolveModelFile: async () => "/fake/model.gguf",
LlamaLogLevel: { error: 0 },
} as never);
const result = await createLocalProviderForTest();
const provider = requireProvider(result);
const embeddings = await provider.embedBatch(["text1", "text2", "text3"]);
for (const embedding of embeddings) {
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
expect(magnitude).toBeCloseTo(1.0, 5);
}
});
});
describe("local embedding ensureContext concurrency", () => {
async function setupLocalProviderWithMockedInit(params?: {
initializationDelayMs?: number;
failFirstGetLlama?: boolean;
}) {
const getLlamaSpy = vi.fn();
const loadModelSpy = vi.fn();
const createContextSpy = vi.fn();
let shouldFail = params?.failFirstGetLlama ?? false;
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
getLlama: async (...args: unknown[]) => {
getLlamaSpy(...args);
if (shouldFail) {
shouldFail = false;
throw new Error("transient init failure");
}
if (params?.initializationDelayMs) {
await sleep(params.initializationDelayMs);
}
return {
loadModel: async (...modelArgs: unknown[]) => {
loadModelSpy(...modelArgs);
if (params?.initializationDelayMs) {
await sleep(params.initializationDelayMs);
}
return {
createEmbeddingContext: async () => {
createContextSpy();
return {
getEmbeddingFor: vi.fn().mockResolvedValue({
vector: new Float32Array([1, 0, 0, 0]),
}),
};
},
};
},
};
},
resolveModelFile: async () => "/fake/model.gguf",
LlamaLogLevel: { error: 0 },
} as never);
const result = await createEmbeddingProvider({
const provider = await createLocalEmbeddingProvider({
config: {} as never,
provider: "local",
model: "",
fallback: "none",
local: {
modelPath: " /models/embed.gguf ",
modelCacheDir: " /cache/models ",
},
});
return {
provider: requireProvider(result),
getLlamaSpy,
loadModelSpy,
createContextSpy,
};
}
await provider.embedBatch(["a", "b"]);
it("loads the model only once when embedBatch is called concurrently", async () => {
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
await setupLocalProviderWithMockedInit({
initializationDelayMs: 5,
});
const results = await Promise.all([
provider.embedBatch(["text1"]),
provider.embedBatch(["text2"]),
provider.embedBatch(["text3"]),
provider.embedBatch(["text4"]),
]);
expect(results).toHaveLength(4);
for (const embeddings of results) {
expect(embeddings).toHaveLength(1);
expect(embeddings[0]).toHaveLength(4);
}
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
expect(loadModelSpy).toHaveBeenCalledTimes(1);
expect(createContextSpy).toHaveBeenCalledTimes(1);
});
it("retries initialization after a transient ensureContext failure", async () => {
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
await setupLocalProviderWithMockedInit({
failFirstGetLlama: true,
});
await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure");
const recovered = await provider.embedBatch(["second"]);
expect(recovered).toHaveLength(1);
expect(recovered[0]).toHaveLength(4);
expect(getLlamaSpy).toHaveBeenCalledTimes(2);
expect(loadModelSpy).toHaveBeenCalledTimes(1);
expect(createContextSpy).toHaveBeenCalledTimes(1);
});
it("shares initialization when embedQuery and embedBatch start concurrently", async () => {
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
await setupLocalProviderWithMockedInit({
initializationDelayMs: 5,
});
const [queryA, batch, queryB] = await Promise.all([
provider.embedQuery("query-a"),
provider.embedBatch(["batch-a", "batch-b"]),
provider.embedQuery("query-b"),
]);
expect(queryA).toHaveLength(4);
expect(batch).toHaveLength(2);
expect(queryB).toHaveLength(4);
expect(batch[0]).toHaveLength(4);
expect(batch[1]).toHaveLength(4);
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
expect(loadModelSpy).toHaveBeenCalledTimes(1);
expect(createContextSpy).toHaveBeenCalledTimes(1);
});
});
describe("FTS-only fallback when no provider available", () => {
it("returns null provider when all requested auth paths fail", async () => {
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
new Error("No API key found for provider"),
);
for (const testCase of [
{
name: "auto mode",
options: {
config: {} as never,
provider: "auto" as const,
model: "",
fallback: "none" as const,
},
requestedProvider: "auto",
fallbackFrom: undefined,
reasonIncludes: "No API key",
},
{
name: "explicit provider only",
options: {
config: {} as never,
provider: "openai" as const,
model: "text-embedding-3-small",
fallback: "none" as const,
},
requestedProvider: "openai",
fallbackFrom: undefined,
reasonIncludes: "No API key",
},
{
name: "primary and fallback",
options: {
config: {} as never,
provider: "openai" as const,
model: "text-embedding-3-small",
fallback: "gemini" as const,
},
requestedProvider: "openai",
fallbackFrom: "openai",
reasonIncludes: "Fallback to gemini failed",
},
]) {
const result = await createEmbeddingProvider(testCase.options);
expect(result.provider, testCase.name).toBeNull();
expect(result.requestedProvider, testCase.name).toBe(testCase.requestedProvider);
expect(result.fallbackFrom, testCase.name).toBe(testCase.fallbackFrom);
expect(result.providerUnavailableReason, testCase.name).toContain(testCase.reasonIncludes);
}
expect(provider.model).toBe("/models/embed.gguf");
expect(runtime.resolveModelFile).toHaveBeenCalledWith("/models/embed.gguf", "/cache/models");
expect(runtime.getEmbeddingFor).toHaveBeenCalledTimes(2);
});
});

View File

@@ -1,43 +1,9 @@
import fsSync from "node:fs";
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
import { formatErrorMessage } from "../../infra/errors.js";
import { normalizeOptionalString } from "../../shared/string-coerce.js";
import { resolveUserPath } from "../../utils.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import {
createBedrockEmbeddingProvider,
hasAwsCredentials,
type BedrockEmbeddingClient,
} from "./embeddings-bedrock.js";
import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./embeddings-gemini.js";
import {
createLmstudioEmbeddingProvider,
type LmstudioEmbeddingClient,
} from "./embeddings-lmstudio.js";
import {
createMistralEmbeddingProvider,
type MistralEmbeddingClient,
} from "./embeddings-mistral.js";
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
import type {
EmbeddingProvider,
EmbeddingProviderFallback,
EmbeddingProviderId,
EmbeddingProviderOptions,
EmbeddingProviderRequest,
GeminiTaskType,
} from "./embeddings.types.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
import { importNodeLlamaCpp } from "./node-llama.js";
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
export type { LmstudioEmbeddingClient } from "./embeddings-lmstudio.js";
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
export type { BedrockEmbeddingClient } from "./embeddings-bedrock.js";
export type {
EmbeddingProvider,
EmbeddingProviderFallback,
@@ -47,51 +13,9 @@ export type {
GeminiTaskType,
} from "./embeddings.types.js";
// Remote providers considered for auto-selection when provider === "auto".
// LM Studio and Ollama are intentionally excluded here so that "auto" mode does not
// implicitly assume either instance is available.
// Bedrock is handled separately when AWS credentials are detected.
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;
export type EmbeddingProviderResult = {
provider: EmbeddingProvider | null;
requestedProvider: EmbeddingProviderRequest;
fallbackFrom?: EmbeddingProviderId;
fallbackReason?: string;
providerUnavailableReason?: string;
openAi?: OpenAiEmbeddingClient;
gemini?: GeminiEmbeddingClient;
voyage?: VoyageEmbeddingClient;
mistral?: MistralEmbeddingClient;
ollama?: OllamaEmbeddingClient;
bedrock?: BedrockEmbeddingClient;
lmstudio?: LmstudioEmbeddingClient;
};
export const DEFAULT_LOCAL_MODEL =
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";
function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
const modelPath = options.local?.modelPath?.trim();
if (!modelPath) {
return false;
}
if (/^(hf:|https?:)/i.test(modelPath)) {
return false;
}
const resolved = resolveUserPath(modelPath);
try {
return fsSync.statSync(resolved).isFile();
} catch {
return false;
}
}
function isMissingApiKeyError(err: unknown): boolean {
const message = formatErrorMessage(err);
return message.includes("No API key found for provider");
}
export async function createLocalEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProvider> {
@@ -154,186 +78,3 @@ export async function createLocalEmbeddingProvider(
},
};
}
export async function createEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
const requestedProvider = options.provider;
const fallback = options.fallback;
const createProvider = async (id: EmbeddingProviderId) => {
if (id === "local") {
const provider = await createLocalEmbeddingProvider(options);
return { provider };
}
if (id === "lmstudio") {
const { provider, client } = await createLmstudioEmbeddingProvider(options);
return { provider, lmstudio: client };
}
if (id === "ollama") {
const { provider, client } = await createOllamaEmbeddingProvider(options);
return { provider, ollama: client };
}
if (id === "gemini") {
const { provider, client } = await createGeminiEmbeddingProvider(options);
return { provider, gemini: client };
}
if (id === "voyage") {
const { provider, client } = await createVoyageEmbeddingProvider(options);
return { provider, voyage: client };
}
if (id === "mistral") {
const { provider, client } = await createMistralEmbeddingProvider(options);
return { provider, mistral: client };
}
if (id === "bedrock") {
const { provider, client } = await createBedrockEmbeddingProvider(options);
return { provider, bedrock: client };
}
const { provider, client } = await createOpenAiEmbeddingProvider(options);
return { provider, openAi: client };
};
const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);
if (requestedProvider === "auto") {
const missingKeyErrors: string[] = [];
let localError: string | null = null;
if (canAutoSelectLocal(options)) {
try {
const local = await createProvider("local");
return { ...local, requestedProvider };
} catch (err) {
localError = formatLocalSetupError(err);
}
}
for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
try {
const result = await createProvider(provider);
return { ...result, requestedProvider };
} catch (err) {
const message = formatPrimaryError(err, provider);
if (isMissingApiKeyError(err)) {
missingKeyErrors.push(message);
continue;
}
// Non-auth errors (e.g., network) are still fatal
const wrapped = new Error(message) as Error & { cause?: unknown };
wrapped.cause = err;
throw wrapped;
}
}
// Try bedrock if AWS credentials are available
if (await hasAwsCredentials()) {
try {
const result = await createProvider("bedrock");
return { ...result, requestedProvider };
} catch (err) {
const message = formatPrimaryError(err, "bedrock");
if (isMissingApiKeyError(err)) {
missingKeyErrors.push(message);
} else {
const wrapped = new Error(message) as Error & { cause?: unknown };
wrapped.cause = err;
throw wrapped;
}
}
}
// All providers failed due to missing API keys - return null provider for FTS-only mode
const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
return {
provider: null,
requestedProvider,
providerUnavailableReason: reason,
};
}
try {
const primary = await createProvider(requestedProvider);
return { ...primary, requestedProvider };
} catch (primaryErr) {
const reason = formatPrimaryError(primaryErr, requestedProvider);
if (fallback && fallback !== "none" && fallback !== requestedProvider) {
try {
const fallbackResult = await createProvider(fallback);
return {
...fallbackResult,
requestedProvider,
fallbackFrom: requestedProvider,
fallbackReason: reason,
};
} catch (fallbackErr) {
// Both primary and fallback failed - check if it's auth-related
const fallbackReason = formatErrorMessage(fallbackErr);
const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
// Both failed due to missing API keys - return null for FTS-only mode
return {
provider: null,
requestedProvider,
fallbackFrom: requestedProvider,
fallbackReason: reason,
providerUnavailableReason: combinedReason,
};
}
// Non-auth errors are still fatal
const wrapped = new Error(combinedReason) as Error & { cause?: unknown };
wrapped.cause = fallbackErr;
throw wrapped;
}
}
// No fallback configured - check if we should degrade to FTS-only
if (isMissingApiKeyError(primaryErr)) {
return {
provider: null,
requestedProvider,
providerUnavailableReason: reason,
};
}
const wrapped = new Error(reason) as Error & { cause?: unknown };
wrapped.cause = primaryErr;
throw wrapped;
}
}
function isNodeLlamaCppMissing(err: unknown): boolean {
if (!(err instanceof Error)) {
return false;
}
const code = (err as Error & { code?: unknown }).code;
if (code === "ERR_MODULE_NOT_FOUND") {
return err.message.includes("node-llama-cpp");
}
return false;
}
function formatLocalSetupError(err: unknown): string {
const detail = formatErrorMessage(err);
const missing = isNodeLlamaCppMissing(err);
return [
"Local embeddings unavailable.",
missing
? "Reason: optional dependency node-llama-cpp is missing (or failed to install)."
: detail
? `Reason: ${detail}`
: undefined,
missing && detail ? `Detail: ${detail}` : null,
"To enable local embeddings:",
"1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)",
missing
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
: null,
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
...REMOTE_EMBEDDING_PROVIDER_IDS.map(
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
),
]
.filter(Boolean)
.join("\n");
}

View File

@@ -11,18 +11,9 @@ export type EmbeddingProvider = {
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
};
export type EmbeddingProviderId =
| "openai"
| "local"
| "gemini"
| "voyage"
| "mistral"
| "lmstudio"
| "ollama"
| "bedrock";
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
export type EmbeddingProviderId = string;
export type EmbeddingProviderRequest = string;
export type EmbeddingProviderFallback = string;
export type GeminiTaskType =
| "RETRIEVAL_QUERY"
@@ -36,14 +27,14 @@ export type GeminiTaskType =
export type EmbeddingProviderOptions = {
config: OpenClawConfig;
agentDir?: string;
provider: EmbeddingProviderRequest;
provider?: EmbeddingProviderRequest;
remote?: {
baseUrl?: string;
apiKey?: SecretInput;
headers?: Record<string, string>;
};
model: string;
fallback: EmbeddingProviderFallback;
fallback?: EmbeddingProviderFallback;
local?: {
modelPath?: string;
modelCacheDir?: string;

View File

@@ -106,21 +106,3 @@ export function classifyMemoryMultimodalPath(
}
return null;
}
export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
const trimmed = model.trim();
if (!trimmed) {
return "";
}
return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, "");
}
export function supportsMemoryMultimodalEmbeddings(params: {
provider: string;
model: string;
}): boolean {
if (params.provider !== "gemini") {
return false;
}
return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview";
}

View File

@@ -85,11 +85,7 @@ export interface MemorySearchManager {
onDebug?: (debug: MemorySearchRuntimeDebug) => void;
},
): Promise<MemorySearchResult[]>;
readFile(params: {
relPath: string;
from?: number;
lines?: number;
}): Promise<MemoryReadResult>;
readFile(params: { relPath: string; from?: number; lines?: number }): Promise<MemoryReadResult>;
status(): MemoryProviderStatus;
sync?(params?: {
reason?: string;

View File

@@ -1,6 +1,5 @@
export {
isMemoryMultimodalEnabled,
normalizeMemoryMultimodalSettings,
supportsMemoryMultimodalEmbeddings,
type MemoryMultimodalSettings,
} from "./host/multimodal.js";

View File

@@ -5,6 +5,10 @@ import path from "node:path";
import { fileURLToPath, pathToFileURL } from "node:url";
export { resolveEnvApiKey } from "../agents/model-auth-env.js";
export {
collectProviderApiKeysForExecution,
executeWithApiKeyRotation,
} from "../agents/api-key-rotation.js";
export { NON_ENV_SECRETREF_MARKER } from "../agents/model-auth-markers.js";
export {
requireApiKey,

View File

@@ -84,13 +84,29 @@ export function resolvePluginCapabilityProviders<K extends CapabilityProviderReg
}): CapabilityProviderForKey<K>[] {
const activeRegistry = resolveRuntimePluginRegistry();
const activeProviders = activeRegistry?.[params.key] ?? [];
if (activeProviders.length > 0) {
if (activeProviders.length > 0 && params.key !== "memoryEmbeddingProviders") {
return activeProviders.map((entry) => entry.provider) as CapabilityProviderForKey<K>[];
}
const compatConfig = resolveCapabilityProviderConfig({ key: params.key, cfg: params.cfg });
const loadOptions = compatConfig === undefined ? undefined : { config: compatConfig };
const registry = resolveRuntimePluginRegistry(loadOptions);
return (registry?.[params.key] ?? []).map(
(entry) => entry.provider,
) as CapabilityProviderForKey<K>[];
if (params.key !== "memoryEmbeddingProviders") {
return (registry?.[params.key] ?? []).map(
(entry) => entry.provider,
) as CapabilityProviderForKey<K>[];
}
const merged = new Map<string, CapabilityProviderForKey<K>>();
for (const entry of activeProviders) {
const provider = entry.provider as CapabilityProviderForKey<K> & { id?: string };
if (provider.id) {
merged.set(provider.id, provider);
}
}
for (const entry of registry?.[params.key] ?? []) {
const provider = entry.provider as CapabilityProviderForKey<K> & { id?: string };
if (provider.id && !merged.has(provider.id)) {
merged.set(provider.id, provider);
}
}
return [...merged.values()];
}
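The intent of the special case above: for memoryEmbeddingProviders, runtime-registered adapters no longer hide declared capability adapters; both sets are merged by id, with registered entries taking precedence. A self-contained sketch of that first-wins merge rule (ids are illustrative, not from this commit):

type ProviderLike = { id: string };

// First-wins merge by id: active (runtime-registered) entries are considered before
// declared (capability manifest) entries, so a registered provider shadows a declared
// provider with the same id while new declared ids are still appended.
function mergeProvidersById<T extends ProviderLike>(active: T[], declared: T[]): T[] {
  const merged = new Map<string, T>();
  for (const provider of [...active, ...declared]) {
    if (!merged.has(provider.id)) {
      merged.set(provider.id, provider);
    }
  }
  return [...merged.values()];
}

// mergeProvidersById([{ id: "openai" }], [{ id: "openai" }, { id: "ollama" }])
//   -> the registered "openai" entry followed by the declared "ollama" entry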

View File

@@ -28,7 +28,14 @@ const packageManifestContractTests: PackageManifestContractParams[] = [
},
{ pluginId: "irc", minHostVersionBaseline: "2026.3.22" },
{ pluginId: "line", minHostVersionBaseline: "2026.3.22" },
{ pluginId: "amazon-bedrock", mirroredRootRuntimeDeps: ["@aws-sdk/client-bedrock"] },
{
pluginId: "amazon-bedrock",
mirroredRootRuntimeDeps: [
"@aws-sdk/client-bedrock",
"@aws-sdk/client-bedrock-runtime",
"@aws-sdk/credential-provider-node",
],
},
{
pluginId: "amazon-bedrock-mantle",
mirroredRootRuntimeDeps: ["@aws/bedrock-token-generator"],

View File

@@ -36,7 +36,7 @@ afterEach(() => {
});
describe("memory embedding provider runtime resolution", () => {
it("prefers registered adapters over capability fallback adapters", () => {
it("merges registered and declared capability fallback adapters", () => {
registerMemoryEmbeddingProvider({
id: "registered",
create: async () => ({ provider: null }),
@@ -45,9 +45,10 @@ describe("memory embedding provider runtime resolution", () => {
expect(runtimeModule.listMemoryEmbeddingProviders().map((adapter) => adapter.id)).toEqual([
"registered",
"capability",
]);
expect(runtimeModule.getMemoryEmbeddingProvider("registered")?.id).toBe("registered");
expect(mocks.resolvePluginCapabilityProviders).not.toHaveBeenCalled();
expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(1);
});
it("falls back to declared capability adapters when the registry is cold", () => {
@@ -60,14 +61,22 @@ describe("memory embedding provider runtime resolution", () => {
expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(2);
});
it("does not consult capability fallback once runtime adapters are registered", () => {
registerMemoryEmbeddingProvider({
it("prefers registered adapters over declared capability fallback adapters with the same id", () => {
const registered = {
id: "openai",
create: async () => ({ provider: null }),
} satisfies MemoryEmbeddingProviderAdapter;
registerMemoryEmbeddingProvider({
...registered,
});
mocks.resolvePluginCapabilityProviders.mockReturnValue([createCapabilityAdapter("ollama")]);
mocks.resolvePluginCapabilityProviders.mockReturnValue([createCapabilityAdapter("openai")]);
expect(runtimeModule.getMemoryEmbeddingProvider("ollama")).toBeUndefined();
expect(mocks.resolvePluginCapabilityProviders).not.toHaveBeenCalled();
expect(runtimeModule.getMemoryEmbeddingProvider("openai")).toEqual(
expect.objectContaining({ id: "openai" }),
);
expect(runtimeModule.listMemoryEmbeddingProviders().map((adapter) => adapter.id)).toEqual([
"openai",
]);
expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(1);
});
});

View File

@@ -15,13 +15,16 @@ export function listMemoryEmbeddingProviders(
cfg?: OpenClawConfig,
): MemoryEmbeddingProviderAdapter[] {
const registered = listRegisteredMemoryEmbeddingProviderAdapters();
if (registered.length > 0) {
return registered;
}
return resolvePluginCapabilityProviders({
const merged = new Map(registered.map((adapter) => [adapter.id, adapter]));
for (const adapter of resolvePluginCapabilityProviders({
key: "memoryEmbeddingProviders",
cfg,
});
})) {
if (!merged.has(adapter.id)) {
merged.set(adapter.id, adapter);
}
}
return [...merged.values()];
}
export function getMemoryEmbeddingProvider(
@@ -32,8 +35,5 @@ export function getMemoryEmbeddingProvider(
if (registered) {
return registered.adapter;
}
if (listRegisteredMemoryEmbeddingProviders().length > 0) {
return undefined;
}
return listMemoryEmbeddingProviders(cfg).find((adapter) => adapter.id === id);
}

View File

@@ -35,6 +35,8 @@ export type MemoryEmbeddingProvider = {
export type MemoryEmbeddingProviderCreateOptions = {
config: OpenClawConfig;
agentDir?: string;
provider?: string;
fallback?: string;
remote?: {
baseUrl?: string;
apiKey?: SecretInput;
@@ -46,6 +48,14 @@ export type MemoryEmbeddingProviderCreateOptions = {
modelCacheDir?: string;
};
outputDimensionality?: number;
taskType?:
| "RETRIEVAL_QUERY"
| "RETRIEVAL_DOCUMENT"
| "SEMANTIC_SIMILARITY"
| "CLASSIFICATION"
| "CLUSTERING"
| "QUESTION_ANSWERING"
| "FACT_VERIFICATION";
};
export type MemoryEmbeddingProviderCreateResult = {
@@ -57,6 +67,7 @@ export type MemoryEmbeddingProviderAdapter = {
id: string;
defaultModel?: string;
transport?: "local" | "remote";
authProviderId?: string;
autoSelectPriority?: number;
allowExplicitWhenConfiguredAuto?: boolean;
supportsMultimodalEmbeddings?: (params: { model: string }) => boolean;
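Taken together, these contracts let an external plugin supply an embedding backend. A hedged sketch of such a registration (assumptions: the plugin-sdk/memory-core-host-engine-embeddings entry point re-exports registerMemoryEmbeddingProvider alongside these types, and MemoryEmbeddingProvider only requires the id/model/embedQuery/embedBatch members exercised by the tests above; the "acme" provider, its model name, and its endpoint are invented for illustration):

import {
  registerMemoryEmbeddingProvider,
  type MemoryEmbeddingProvider,
  type MemoryEmbeddingProviderAdapter,
  type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";

// Stand-in for the plugin's real call to its remote embeddings endpoint.
async function callAcmeEmbeddingsApi(texts: string[]): Promise<number[][]> {
  return texts.map(() => [0, 0, 0]);
}

const acmeAdapter = {
  id: "acme",
  transport: "remote",
  defaultModel: "acme-embed-1",
  create: async (_options: MemoryEmbeddingProviderCreateOptions) => {
    const provider: MemoryEmbeddingProvider = {
      id: "acme",
      model: "acme-embed-1",
      embedQuery: async (text: string) => (await callAcmeEmbeddingsApi([text]))[0] ?? [],
      embedBatch: async (texts: string[]) => callAcmeEmbeddingsApi(texts),
    };
    return { provider };
  },
} satisfies MemoryEmbeddingProviderAdapter;

registerMemoryEmbeddingProvider(acmeAdapter);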