diff --git a/docs/plugins/sdk-migration.md b/docs/plugins/sdk-migration.md index 9dc963b47fe..7bc95797116 100644 --- a/docs/plugins/sdk-migration.md +++ b/docs/plugins/sdk-migration.md @@ -318,7 +318,7 @@ Current bundled provider examples: | `plugin-sdk/memory-core` | Bundled memory-core helpers | Memory manager/config/file/CLI helper surface | | `plugin-sdk/memory-core-engine-runtime` | Memory engine runtime facade | Memory index/search runtime facade | | `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine | Memory host foundation engine exports | - | `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory host embedding engine exports | + | `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory embedding contracts, registry access, local provider, and generic batch/remote helpers; concrete remote providers live in their owning plugins | | `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine | Memory host QMD engine exports | | `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine | Memory host storage engine exports | | `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers | Memory host multimodal helpers | diff --git a/docs/plugins/sdk-overview.md b/docs/plugins/sdk-overview.md index 512d1cbc5e4..c1c39f073c1 100644 --- a/docs/plugins/sdk-overview.md +++ b/docs/plugins/sdk-overview.md @@ -264,7 +264,7 @@ explicitly promotes one as public. 
| `plugin-sdk/memory-core` | Bundled memory-core helper surface for manager/config/file/CLI helpers | | `plugin-sdk/memory-core-engine-runtime` | Memory index/search runtime facade | | `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine exports | - | `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine exports | + | `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding contracts, registry access, local provider, and generic batch/remote helpers | | `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine exports | | `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine exports | | `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers | diff --git a/src/memory-host-sdk/host/embeddings-bedrock.ts b/extensions/amazon-bedrock/embedding-provider.ts similarity index 96% rename from src/memory-host-sdk/host/embeddings-bedrock.ts rename to extensions/amazon-bedrock/embedding-provider.ts index 6bdfab2c511..5e8ccb1c5bb 100644 --- a/src/memory-host-sdk/host/embeddings-bedrock.ts +++ b/extensions/amazon-bedrock/embedding-provider.ts @@ -1,7 +1,10 @@ -import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import { debugEmbeddingsLog } from "./embeddings-debug.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js"; +import { + debugEmbeddingsLog, + sanitizeAndNormalizeEmbedding, + type MemoryEmbeddingProvider, + type MemoryEmbeddingProviderCreateOptions, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime"; // --------------------------------------------------------------------------- // Types & constants @@ -254,8 +257,8 @@ function parseCohereBatch(family: Family, raw: string): number[][] { // 
--------------------------------------------------------------------------- export async function createBedrockEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> { + options: MemoryEmbeddingProviderCreateOptions, +): Promise<{ provider: MemoryEmbeddingProvider; client: BedrockEmbeddingClient }> { const client = resolveBedrockEmbeddingClient(options); const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk(); const sdk = new BedrockRuntimeClient({ region: client.region }); @@ -333,7 +336,7 @@ export async function createBedrockEmbeddingProvider( // --------------------------------------------------------------------------- export function resolveBedrockEmbeddingClient( - options: EmbeddingProviderOptions, + options: MemoryEmbeddingProviderCreateOptions, ): BedrockEmbeddingClient { const model = normalizeBedrockEmbeddingModel(options.model); const spec = resolveSpec(model); diff --git a/extensions/amazon-bedrock/memory-embedding-adapter.ts b/extensions/amazon-bedrock/memory-embedding-adapter.ts new file mode 100644 index 00000000000..5b003f72116 --- /dev/null +++ b/extensions/amazon-bedrock/memory-embedding-adapter.ts @@ -0,0 +1,37 @@ +import { + isMissingEmbeddingApiKeyError, + type MemoryEmbeddingProviderAdapter, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { + createBedrockEmbeddingProvider, + DEFAULT_BEDROCK_EMBEDDING_MODEL, +} from "./embedding-provider.js"; + +export const bedrockMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = { + id: "bedrock", + defaultModel: DEFAULT_BEDROCK_EMBEDDING_MODEL, + transport: "remote", + authProviderId: "amazon-bedrock", + autoSelectPriority: 60, + allowExplicitWhenConfiguredAuto: true, + shouldContinueAutoSelection: isMissingEmbeddingApiKeyError, + create: async (options) => { + const { provider, client } = await createBedrockEmbeddingProvider({ + ...options, + provider: "bedrock", + fallback: 
"none", + }); + return { + provider, + runtime: { + id: "bedrock", + cacheKeyData: { + provider: "bedrock", + region: client.region, + model: client.model, + dimensions: client.dimensions, + }, + }, + }; + }, +}; diff --git a/extensions/amazon-bedrock/openclaw.plugin.json b/extensions/amazon-bedrock/openclaw.plugin.json index fbb443606ed..62fe8a3eb7d 100644 --- a/extensions/amazon-bedrock/openclaw.plugin.json +++ b/extensions/amazon-bedrock/openclaw.plugin.json @@ -2,6 +2,9 @@ "id": "amazon-bedrock", "enabledByDefault": true, "providers": ["amazon-bedrock"], + "contracts": { + "memoryEmbeddingProviders": ["bedrock"] + }, "configSchema": { "type": "object", "additionalProperties": false, diff --git a/extensions/amazon-bedrock/package.json b/extensions/amazon-bedrock/package.json index 6b09510a680..977da7e1fe0 100644 --- a/extensions/amazon-bedrock/package.json +++ b/extensions/amazon-bedrock/package.json @@ -5,7 +5,9 @@ "description": "OpenClaw Amazon Bedrock provider plugin", "type": "module", "dependencies": { - "@aws-sdk/client-bedrock": "3.1028.0" + "@aws-sdk/client-bedrock": "3.1028.0", + "@aws-sdk/client-bedrock-runtime": "3.1028.0", + "@aws-sdk/credential-provider-node": "3.972.30" }, "devDependencies": { "@openclaw/plugin-sdk": "workspace:*" diff --git a/extensions/amazon-bedrock/register.sync.runtime.ts b/extensions/amazon-bedrock/register.sync.runtime.ts index 60f282d9a86..a141ffd3444 100644 --- a/extensions/amazon-bedrock/register.sync.runtime.ts +++ b/extensions/amazon-bedrock/register.sync.runtime.ts @@ -14,6 +14,7 @@ import { resolveBedrockConfigApiKey, resolveImplicitBedrockProvider, } from "./api.js"; +import { bedrockMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js"; type GuardrailConfig = { guardrailIdentifier: string; @@ -78,6 +79,8 @@ export function registerAmazonBedrockPlugin(api: OpenClawPluginApi): void { const pluginConfig = (api.pluginConfig ?? 
{}) as AmazonBedrockPluginConfig; const guardrail = pluginConfig.guardrail; + api.registerMemoryEmbeddingProvider(bedrockMemoryEmbeddingProviderAdapter); + const baseWrapStreamFn = ({ modelId, streamFn }: { modelId: string; streamFn?: StreamFn }) => isAnthropicBedrockModel(modelId) ? streamFn : createBedrockNoCacheWrapper(streamFn); diff --git a/extensions/github-copilot/embeddings.test.ts b/extensions/github-copilot/embeddings.test.ts index 73a656d564c..229b5673680 100644 --- a/extensions/github-copilot/embeddings.test.ts +++ b/extensions/github-copilot/embeddings.test.ts @@ -4,7 +4,6 @@ const resolveFirstGithubTokenMock = vi.hoisted(() => vi.fn()); const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn()); const resolveConfiguredSecretInputStringMock = vi.hoisted(() => vi.fn()); const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn()); -const createGitHubCopilotEmbeddingProviderMock = vi.hoisted(() => vi.fn()); vi.mock("./auth.js", () => ({ resolveFirstGithubToken: resolveFirstGithubTokenMock, @@ -19,10 +18,6 @@ vi.mock("openclaw/plugin-sdk/github-copilot-token", () => ({ resolveCopilotApiToken: resolveCopilotApiTokenMock, })); -vi.mock("openclaw/plugin-sdk/memory-core-host-engine-embeddings", () => ({ - createGitHubCopilotEmbeddingProvider: createGitHubCopilotEmbeddingProviderMock, -})); - vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({ fetchWithSsrFGuard: fetchWithSsrFGuardMock, })); @@ -73,15 +68,6 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => { source: "test", baseUrl: TEST_BASE_URL, }); - createGitHubCopilotEmbeddingProviderMock.mockImplementation(async (client) => ({ - provider: { - id: "github-copilot", - model: client.model, - embedQuery: async () => [0.1, 0.2, 0.3], - embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2, 0.3]), - }, - client, - })); }); afterEach(() => { @@ -89,7 +75,6 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => { resolveConfiguredSecretInputStringMock.mockReset(); 
resolveFirstGithubTokenMock.mockReset(); resolveCopilotApiTokenMock.mockReset(); - createGitHubCopilotEmbeddingProviderMock.mockReset(); fetchWithSsrFGuardMock.mockReset(); }); @@ -113,12 +98,8 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => { const result = await githubCopilotMemoryEmbeddingProviderAdapter.create(defaultCreateOptions()); expect(result.provider?.model).toBe("text-embedding-3-small"); - expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith( - expect.objectContaining({ - baseUrl: TEST_BASE_URL, - githubToken: "gh_test_token_123", - model: "text-embedding-3-small", - }), + expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith( + expect.objectContaining({ githubToken: "gh_test_token_123" }), ); }); @@ -217,14 +198,12 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => { } as never); expect(resolveFirstGithubTokenMock).toHaveBeenCalled(); - expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith({ - baseUrl: "https://proxy.example/v1", - env: process.env, - fetchImpl: fetch, - githubToken: "gh_remote_token", - headers: { "X-Proxy-Token": "proxy" }, - model: "text-embedding-3-small", - }); + expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith( + expect.objectContaining({ + env: process.env, + githubToken: "gh_remote_token", + }), + ); const discoveryCall = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as { init: { headers: Record }; diff --git a/extensions/github-copilot/embeddings.ts b/extensions/github-copilot/embeddings.ts index d06c8a06942..28a4210ec2d 100644 --- a/extensions/github-copilot/embeddings.ts +++ b/extensions/github-copilot/embeddings.ts @@ -4,7 +4,10 @@ import { resolveCopilotApiToken, } from "openclaw/plugin-sdk/github-copilot-token"; import { - createGitHubCopilotEmbeddingProvider, + buildRemoteBaseUrlPolicy, + sanitizeAndNormalizeEmbedding, + withRemoteHttpResponse, + type MemoryEmbeddingProvider, type MemoryEmbeddingProviderAdapter, } from 
"openclaw/plugin-sdk/memory-core-host-engine-embeddings"; import { fetchWithSsrFGuard, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime"; @@ -44,6 +47,15 @@ type CopilotModelEntry = { supported_endpoints?: unknown; }; +type GitHubCopilotEmbeddingClient = { + githubToken: string; + model: string; + baseUrl?: string; + headers?: Record; + env?: NodeJS.ProcessEnv; + fetchImpl?: typeof fetch; +}; + function isCopilotSetupError(err: unknown): boolean { if (!(err instanceof Error)) { return false; @@ -147,9 +159,126 @@ function pickBestModel(available: string[], userModel?: string): string { throw new Error("No embedding models available from GitHub Copilot"); } +function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] { + if (!payload || typeof payload !== "object") { + throw new Error("GitHub Copilot embeddings response missing data[]"); + } + const data = (payload as { data?: unknown }).data; + if (!Array.isArray(data)) { + throw new Error("GitHub Copilot embeddings response missing data[]"); + } + + const vectors = Array.from({ length: expectedCount }); + for (const entry of data) { + if (!entry || typeof entry !== "object") { + throw new Error("GitHub Copilot embeddings response contains an invalid entry"); + } + const indexValue = (entry as { index?: unknown }).index; + const embedding = (entry as { embedding?: unknown }).embedding; + const index = typeof indexValue === "number" ? 
indexValue : Number.NaN; + if (!Number.isInteger(index)) { + throw new Error("GitHub Copilot embeddings response contains an invalid index"); + } + if (index < 0 || index >= expectedCount) { + throw new Error("GitHub Copilot embeddings response contains an out-of-range index"); + } + if (vectors[index] !== undefined) { + throw new Error("GitHub Copilot embeddings response contains duplicate indexes"); + } + if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) { + throw new Error("GitHub Copilot embeddings response contains an invalid embedding"); + } + vectors[index] = sanitizeAndNormalizeEmbedding(embedding); + } + + for (let index = 0; index < expectedCount; index += 1) { + if (vectors[index] === undefined) { + throw new Error("GitHub Copilot embeddings response missing vectors for some inputs"); + } + } + return vectors as number[][]; +} + +async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{ + baseUrl: string; + headers: Record; +}> { + const token = await resolveCopilotApiToken({ + githubToken: client.githubToken, + env: client.env, + fetchImpl: client.fetchImpl, + }); + const baseUrl = client.baseUrl?.trim() || token.baseUrl || DEFAULT_COPILOT_API_BASE_URL; + return { + baseUrl, + headers: { + ...COPILOT_HEADERS_STATIC, + ...client.headers, + Authorization: `Bearer ${token.token}`, + }, + }; +} + +async function createGitHubCopilotEmbeddingProvider( + client: GitHubCopilotEmbeddingClient, +): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> { + const initialSession = await resolveGitHubCopilotEmbeddingSession(client); + + const embed = async (input: string[]): Promise => { + if (input.length === 0) { + return []; + } + + const session = await resolveGitHubCopilotEmbeddingSession(client); + const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`; + return await withRemoteHttpResponse({ + url, + fetchImpl: client.fetchImpl, + 
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl), + init: { + method: "POST", + headers: session.headers, + body: JSON.stringify({ model: client.model, input }), + }, + onResponse: async (response) => { + if (!response.ok) { + throw new Error( + `GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`, + ); + } + + let payload: unknown; + try { + payload = await response.json(); + } catch { + throw new Error("GitHub Copilot embeddings returned invalid JSON"); + } + return parseGitHubCopilotEmbeddingPayload(payload, input.length); + }, + }); + }; + + return { + provider: { + id: COPILOT_EMBEDDING_PROVIDER_ID, + model: client.model, + embedQuery: async (text) => { + const [vector] = await embed([text]); + return vector ?? []; + }, + embedBatch: embed, + }, + client: { + ...client, + baseUrl: initialSession.baseUrl, + }, + }; +} + export const githubCopilotMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = { id: COPILOT_EMBEDDING_PROVIDER_ID, transport: "remote", + authProviderId: COPILOT_EMBEDDING_PROVIDER_ID, autoSelectPriority: 15, allowExplicitWhenConfiguredAuto: true, shouldContinueAutoSelection: (err: unknown) => isCopilotSetupError(err), diff --git a/packages/memory-host-sdk/src/host/batch-gemini.ts b/extensions/google/embedding-batch.ts similarity index 96% rename from packages/memory-host-sdk/src/host/batch-gemini.ts rename to extensions/google/embedding-batch.ts index 4bdc9fa055e..d7fe2cdf36c 100644 --- a/packages/memory-host-sdk/src/host/batch-gemini.ts +++ b/extensions/google/embedding-batch.ts @@ -1,14 +1,15 @@ +import crypto from "node:crypto"; import { buildEmbeddingBatchGroupOptions, runEmbeddingBatchGroups, type EmbeddingBatchExecutionParams, -} from "./batch-runner.js"; -import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import { debugEmbeddingsLog } from "./embeddings-debug.js"; -import type { 
GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js"; -import { hashText } from "./internal.js"; -import { withRemoteHttpResponse } from "./remote-http.js"; + buildBatchHeaders, + debugEmbeddingsLog, + normalizeBatchBaseUrl, + sanitizeAndNormalizeEmbedding, + withRemoteHttpResponse, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embedding-provider.js"; export type GeminiBatchRequest = { custom_id: string; @@ -40,6 +41,10 @@ export type GeminiBatchOutputLine = { }; const GEMINI_BATCH_MAX_REQUESTS = 50000; +function hashText(text: string): string { + return crypto.createHash("sha256").update(text).digest("hex"); +} + function getGeminiUploadUrl(baseUrl: string): string { if (baseUrl.includes("/v1beta")) { return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta"); diff --git a/src/memory-host-sdk/host/embeddings-gemini.test.ts b/extensions/google/embedding-provider.test.ts similarity index 52% rename from src/memory-host-sdk/host/embeddings-gemini.test.ts rename to extensions/google/embedding-provider.test.ts index a9c4826a1c9..c2b068b4b2d 100644 --- a/src/memory-host-sdk/host/embeddings-gemini.test.ts +++ b/extensions/google/embedding-provider.test.ts @@ -1,84 +1,40 @@ -import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; -import * as authModule from "../../agents/model-auth.js"; +import { afterEach, describe, expect, it, vi } from "vitest"; import { buildGeminiEmbeddingRequest, buildGeminiTextEmbeddingRequest, + createGeminiEmbeddingProvider, DEFAULT_GEMINI_EMBEDDING_MODEL, GEMINI_EMBEDDING_2_MODELS, isGeminiEmbedding2Model, normalizeGeminiModel, resolveGeminiOutputDimensionality, -} from "./embeddings-gemini-request.js"; -import { - createGeminiBatchFetchMock, - createJsonResponseFetchMock, - installFetchMock, - mockResolvedProviderKey, - parseFetchBody, - readFirstFetchRequest, - type JsonFetchMock, -} from 
"./embeddings-provider.test-support.js"; - -const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(), -})); - -vi.mock("../../agents/model-auth.js", () => { - return { - resolveApiKeyForProvider: resolveApiKeyForProviderMock, - requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => { - if (auth.apiKey) { - return auth.apiKey; - } - throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`); - }, - }; -}); - -vi.mock("../../agents/api-key-rotation.js", () => ({ - collectProviderApiKeysForExecution: (params: { primaryApiKey?: string }) => - params.primaryApiKey ? [params.primaryApiKey] : [], - executeWithApiKeyRotation: async (params: { - apiKeys: string[]; - execute: (apiKey: string) => Promise; - }) => { - const apiKey = params.apiKeys[0]; - if (!apiKey) { - throw new Error('No API keys configured for provider "google".'); - } - return await params.execute(apiKey); - }, -})); - -beforeEach(() => { - vi.useRealTimers(); - vi.doUnmock("undici"); -}); +} from "./embedding-provider.js"; afterEach(() => { - vi.doUnmock("undici"); - vi.resetAllMocks(); + vi.restoreAllMocks(); vi.unstubAllGlobals(); }); -type GeminiProviderOptions = Parameters< - typeof import("./embeddings-gemini.js").createGeminiEmbeddingProvider ->[0]; - -async function createProviderWithFetch( - fetchMock: JsonFetchMock, - options: Partial & { model: string }, -) { - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockResolvedProviderKey(authModule.resolveApiKeyForProvider); - const { createGeminiEmbeddingProvider } = await import("./embeddings-gemini.js"); - const { provider } = await createGeminiEmbeddingProvider({ - config: {} as never, - provider: "gemini", - fallback: "none", - ...options, +function installFetchMock( + handler: (input: RequestInfo | URL, init?: RequestInit) => unknown, +): ReturnType { + const fetchMock = vi.fn(async (input: RequestInfo | URL, init?: 
RequestInit) => { + return new Response(JSON.stringify(handler(input, init)), { + status: 200, + headers: { "Content-Type": "application/json" }, + }); }); - return provider; + vi.stubGlobal("fetch", fetchMock); + return fetchMock; +} + +function fetchJsonBody(fetchMock: ReturnType, index: number): unknown { + const init = fetchMock.mock.calls[index]?.[1] as RequestInit | undefined; + const body = init?.body; + if (typeof body !== "string") { + throw new Error("Expected JSON string request body."); + } + return JSON.parse(body) as unknown; } describe("Gemini embedding request helpers", () => { @@ -149,24 +105,9 @@ describe("Gemini embedding request helpers", () => { }); }); -describe("gemini embedding provider", () => { +describe("Gemini embedding provider", () => { it("handles legacy and v2 request/response behavior", async () => { - const legacyFetch = createGeminiBatchFetchMock(2); - const legacyProvider = await createProviderWithFetch(legacyFetch, { - model: "gemini-embedding-001", - }); - - await legacyProvider.embedQuery("test query"); - await legacyProvider.embedBatch(["text1", "text2"]); - - expect(parseFetchBody(legacyFetch, 0)).toMatchObject({ - taskType: "RETRIEVAL_QUERY", - content: { parts: [{ text: "test query" }] }, - }); - expect(parseFetchBody(legacyFetch, 0)).not.toHaveProperty("outputDimensionality"); - expect(parseFetchBody(legacyFetch, 1)).not.toHaveProperty("outputDimensionality"); - - const v2Fetch = createJsonResponseFetchMock((input) => { + const fetchMock = installFetchMock((input) => { const url = input instanceof URL ? input.href : typeof input === "string" ? input : input.url; return url.endsWith(":batchEmbedContents") ? 
{ @@ -176,16 +117,22 @@ describe("gemini embedding provider", () => { } : { embedding: { values: [3, 4, Number.NaN] } }; }); - const v2Provider = await createProviderWithFetch(v2Fetch, { + + const { provider } = await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + remote: { apiKey: "test-key" }, model: "gemini-embedding-2-preview", outputDimensionality: 768, taskType: "SEMANTIC_SIMILARITY", + fallback: "none", }); - await expect(v2Provider.embedQuery(" ")).resolves.toEqual([]); - await expect(v2Provider.embedBatch([])).resolves.toEqual([]); - await expect(v2Provider.embedQuery("test query")).resolves.toEqual([0.6, 0.8, 0]); - const structuredBatch = await v2Provider.embedBatchInputs?.([ + await expect(provider.embedQuery(" ")).resolves.toEqual([]); + await expect(provider.embedBatch([])).resolves.toEqual([]); + await expect(provider.embedQuery("test query")).resolves.toEqual([0.6, 0.8, 0]); + + const structuredBatch = await provider.embedBatchInputs?.([ { text: "Image file: diagram.png", parts: [ @@ -206,38 +153,39 @@ describe("gemini embedding provider", () => { [0, 0, 1], ]); - const { url } = readFirstFetchRequest(v2Fetch); - expect(url).toBe( + expect(fetchMock.mock.calls[0]?.[0]).toBe( "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:embedContent", ); - expect(parseFetchBody(v2Fetch, 0)).toMatchObject({ + expect(fetchJsonBody(fetchMock, 0)).toMatchObject({ outputDimensionality: 768, taskType: "SEMANTIC_SIMILARITY", content: { parts: [{ text: "test query" }] }, }); - expect(parseFetchBody(v2Fetch, 1).requests).toEqual([ - { - model: "models/gemini-embedding-2-preview", - content: { - parts: [ - { text: "Image file: diagram.png" }, - { inlineData: { mimeType: "image/png", data: "img" } }, - ], + expect(fetchJsonBody(fetchMock, 1)).toMatchObject({ + requests: [ + { + model: "models/gemini-embedding-2-preview", + content: { + parts: [ + { text: "Image file: diagram.png" }, + { inlineData: { 
mimeType: "image/png", data: "img" } }, + ], + }, + taskType: "SEMANTIC_SIMILARITY", + outputDimensionality: 768, }, - taskType: "SEMANTIC_SIMILARITY", - outputDimensionality: 768, - }, - { - model: "models/gemini-embedding-2-preview", - content: { - parts: [ - { text: "Audio file: note.wav" }, - { inlineData: { mimeType: "audio/wav", data: "aud" } }, - ], + { + model: "models/gemini-embedding-2-preview", + content: { + parts: [ + { text: "Audio file: note.wav" }, + { inlineData: { mimeType: "audio/wav", data: "aud" } }, + ], + }, + taskType: "SEMANTIC_SIMILARITY", + outputDimensionality: 768, }, - taskType: "SEMANTIC_SIMILARITY", - outputDimensionality: 768, - }, - ]); + ], + }); }); }); diff --git a/src/memory-host-sdk/host/embeddings-gemini.ts b/extensions/google/embedding-provider.ts similarity index 54% rename from src/memory-host-sdk/host/embeddings-gemini.ts rename to extensions/google/embedding-provider.ts index d4b989d5fcc..26911cf685f 100644 --- a/src/memory-host-sdk/host/embeddings-gemini.ts +++ b/extensions/google/embedding-provider.ts @@ -1,44 +1,22 @@ +import { parseGeminiAuth } from "openclaw/plugin-sdk/image-generation-core"; +import { + buildRemoteBaseUrlPolicy, + debugEmbeddingsLog, + sanitizeAndNormalizeEmbedding, + withRemoteHttpResponse, + type EmbeddingInput, + type MemoryEmbeddingProvider, + type MemoryEmbeddingProviderCreateOptions, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { resolveMemorySecretInputString } from "openclaw/plugin-sdk/memory-core-host-secret"; import { collectProviderApiKeysForExecution, executeWithApiKeyRotation, -} from "../../agents/api-key-rotation.js"; -import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js"; -import { parseGeminiAuth } from "../../infra/gemini-auth.js"; -import { - DEFAULT_GOOGLE_API_BASE_URL, - normalizeGoogleApiBaseUrl, -} from "../../infra/google-api-base-url.js"; -import type { SsrFPolicy } from "../../infra/net/ssrf.js"; -import { 
normalizeOptionalString } from "../../shared/string-coerce.js"; -import type { EmbeddingInput } from "./embedding-inputs.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import { debugEmbeddingsLog } from "./embeddings-debug.js"; -import { - buildGeminiEmbeddingRequest, - buildGeminiTextEmbeddingRequest, - isGeminiEmbedding2Model, - normalizeGeminiModel, - resolveGeminiOutputDimensionality, -} from "./embeddings-gemini-request.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js"; -import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js"; -import { resolveMemorySecretInputString } from "./secret-input.js"; - -export { - buildGeminiEmbeddingRequest, - buildGeminiTextEmbeddingRequest, - DEFAULT_GEMINI_EMBEDDING_MODEL, - GEMINI_EMBEDDING_2_MODELS, - isGeminiEmbedding2Model, - normalizeGeminiModel, - resolveGeminiOutputDimensionality, - type GeminiEmbeddingRequest, - type GeminiInlinePart, - type GeminiPart, - type GeminiTaskType, - type GeminiTextEmbeddingRequest, - type GeminiTextPart, -} from "./embeddings-gemini-request.js"; + requireApiKey, + resolveApiKeyForProvider, +} from "openclaw/plugin-sdk/provider-auth-runtime"; +import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime"; +import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; export type GeminiEmbeddingClient = { baseUrl: string; @@ -50,9 +28,111 @@ export type GeminiEmbeddingClient = { outputDimensionality?: number; }; +export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001"; +const DEFAULT_GOOGLE_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"; const GEMINI_MAX_INPUT_TOKENS: Record = { "text-embedding-004": 2048, + "gemini-embedding-001": 2048, + "gemini-embedding-2-preview": 8192, }; + +export type GeminiTaskType = NonNullable; + +// --- gemini-embedding-2-preview support --- + +export const GEMINI_EMBEDDING_2_MODELS = new Set([ + 
"gemini-embedding-2-preview", + // Add the GA model name here once released. +]); + +const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072; +const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const; + +export type GeminiTextPart = { text: string }; +export type GeminiInlinePart = { + inlineData: { mimeType: string; data: string }; +}; +export type GeminiPart = GeminiTextPart | GeminiInlinePart; +export type GeminiEmbeddingRequest = { + content: { parts: GeminiPart[] }; + taskType: GeminiTaskType; + outputDimensionality?: number; + model?: string; +}; +export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest; + +/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */ +export function buildGeminiTextEmbeddingRequest(params: { + text: string; + taskType: GeminiTaskType; + outputDimensionality?: number; + modelPath?: string; +}): GeminiTextEmbeddingRequest { + return buildGeminiEmbeddingRequest({ + input: { text: params.text }, + taskType: params.taskType, + outputDimensionality: params.outputDimensionality, + modelPath: params.modelPath, + }); +} + +export function buildGeminiEmbeddingRequest(params: { + input: EmbeddingInput; + taskType: GeminiTaskType; + outputDimensionality?: number; + modelPath?: string; +}): GeminiEmbeddingRequest { + const request: GeminiEmbeddingRequest = { + content: { + parts: params.input.parts?.map((part) => + part.type === "text" + ? ({ text: part.text } satisfies GeminiTextPart) + : ({ + inlineData: { mimeType: part.mimeType, data: part.data }, + } satisfies GeminiInlinePart), + ) ?? 
[{ text: params.input.text }], + }, + taskType: params.taskType, + }; + if (params.modelPath) { + request.model = params.modelPath; + } + if (params.outputDimensionality != null) { + request.outputDimensionality = params.outputDimensionality; + } + return request; +} + +/** + * Returns true if the given model name is a gemini-embedding-2 variant that + * supports `outputDimensionality` and extended task types. + */ +export function isGeminiEmbedding2Model(model: string): boolean { + return GEMINI_EMBEDDING_2_MODELS.has(model); +} + +/** + * Validate and return the `outputDimensionality` for gemini-embedding-2 models. + * Returns `undefined` for older models (they don't support the param). + */ +export function resolveGeminiOutputDimensionality( + model: string, + requested?: number, +): number | undefined { + if (!isGeminiEmbedding2Model(model)) { + return undefined; + } + if (requested == null) { + return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS; + } + const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS; + if (!valid.includes(requested)) { + throw new Error( + `Invalid outputDimensionality ${requested} for ${model}. 
Valid values: ${valid.join(", ")}`, + ); + } + return requested; +} function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined { const trimmed = resolveMemorySecretInputString({ value: remoteApiKey, @@ -67,6 +147,21 @@ function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined { return trimmed; } +export function normalizeGeminiModel(model: string): string { + const trimmed = model.trim(); + if (!trimmed) { + return DEFAULT_GEMINI_EMBEDDING_MODEL; + } + const withoutPrefix = trimmed.replace(/^models\//, ""); + if (withoutPrefix.startsWith("gemini/")) { + return withoutPrefix.slice("gemini/".length); + } + if (withoutPrefix.startsWith("google/")) { + return withoutPrefix.slice("google/".length); + } + return withoutPrefix; +} + async function fetchGeminiEmbeddingPayload(params: { client: GeminiEmbeddingClient; endpoint: string; @@ -120,9 +215,30 @@ function buildGeminiModelPath(model: string): string { return model.startsWith("models/") ? model : `models/${model}`; } +function normalizeGoogleApiBaseUrl(baseUrl: string): string { + const trimmed = baseUrl.trim().replace(/\/+$/, ""); + if (!trimmed) { + return DEFAULT_GOOGLE_API_BASE_URL; + } + try { + const url = new URL(trimmed); + url.hash = ""; + url.search = ""; + if ( + url.origin.toLowerCase() === "https://generativelanguage.googleapis.com" && + url.pathname.replace(/\/+$/, "") === "" + ) { + url.pathname = "/v1beta"; + } + return url.toString().replace(/\/+$/, ""); + } catch { + return trimmed; + } +} + export async function createGeminiEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> { + options: MemoryEmbeddingProviderCreateOptions, +): Promise<{ provider: MemoryEmbeddingProvider; client: GeminiEmbeddingClient }> { const client = await resolveGeminiEmbeddingClient(options); const baseUrl = client.baseUrl.replace(/\/$/, ""); const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`; @@ -190,7 
+306,7 @@ export async function createGeminiEmbeddingProvider( } export async function resolveGeminiEmbeddingClient( - options: EmbeddingProviderOptions, + options: MemoryEmbeddingProviderCreateOptions, ): Promise<GeminiEmbeddingClient> { const remote = options.remote; const remoteApiKey = resolveRemoteApiKey(remote?.apiKey); diff --git a/extensions/google/index.ts b/extensions/google/index.ts index 3f8b36eea9b..195e79602a5 100644 --- a/extensions/google/index.ts +++ b/extensions/google/index.ts @@ -3,6 +3,7 @@ import type { MediaUnderstandingProvider } from "openclaw/plugin-sdk/media-under import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry"; import { buildGoogleGeminiCliBackend } from "./cli-backend.js"; import { registerGoogleGeminiCliProvider } from "./gemini-cli-provider.js"; +import { geminiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js"; import { buildGoogleMusicGenerationProvider } from "./music-generation-provider.js"; import { registerGoogleProvider } from "./provider-registration.js"; import { buildGoogleSpeechProvider } from "./speech-provider.js"; @@ -111,6 +112,7 @@ export default definePluginEntry({ api.registerCliBackend(buildGoogleGeminiCliBackend()); registerGoogleGeminiCliProvider(api); registerGoogleProvider(api); + api.registerMemoryEmbeddingProvider(geminiMemoryEmbeddingProviderAdapter); api.registerImageGenerationProvider(createLazyGoogleImageGenerationProvider()); api.registerMediaUnderstandingProvider(createLazyGoogleMediaUnderstandingProvider()); api.registerMusicGenerationProvider(buildGoogleMusicGenerationProvider()); diff --git a/extensions/google/memory-embedding-adapter.ts b/extensions/google/memory-embedding-adapter.ts new file mode 100644 index 00000000000..3d544e625ef --- /dev/null +++ b/extensions/google/memory-embedding-adapter.ts @@ -0,0 +1,79 @@ +import { + hasNonTextEmbeddingParts, + isMissingEmbeddingApiKeyError, + mapBatchEmbeddingsByIndex, + sanitizeEmbeddingCacheHeaders, + type
MemoryEmbeddingProviderAdapter, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { runGeminiEmbeddingBatches } from "./embedding-batch.js"; +import { + buildGeminiEmbeddingRequest, + createGeminiEmbeddingProvider, + DEFAULT_GEMINI_EMBEDDING_MODEL, +} from "./embedding-provider.js"; + +function supportsGeminiMultimodalEmbeddings(model: string): boolean { + const normalized = model + .trim() + .replace(/^models\//, "") + .replace(/^(gemini|google)\//, ""); + return normalized === "gemini-embedding-2-preview"; +} + +export const geminiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = { + id: "gemini", + defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL, + transport: "remote", + authProviderId: "google", + autoSelectPriority: 30, + allowExplicitWhenConfiguredAuto: true, + supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model), + shouldContinueAutoSelection: isMissingEmbeddingApiKeyError, + create: async (options) => { + const { provider, client } = await createGeminiEmbeddingProvider({ + ...options, + provider: "gemini", + fallback: "none", + }); + return { + provider, + runtime: { + id: "gemini", + cacheKeyData: { + provider: "gemini", + baseUrl: client.baseUrl, + model: client.model, + outputDimensionality: client.outputDimensionality, + headers: sanitizeEmbeddingCacheHeaders(client.headers, [ + "authorization", + "x-goog-api-key", + ]), + }, + batchEmbed: async (batch) => { + if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) { + return null; + } + const byCustomId = await runGeminiEmbeddingBatches({ + gemini: client, + agentId: batch.agentId, + requests: batch.chunks.map((chunk, index) => ({ + custom_id: String(index), + request: buildGeminiEmbeddingRequest({ + input: chunk.embeddingInput ?? 
{ text: chunk.text }, + taskType: "RETRIEVAL_DOCUMENT", + modelPath: client.modelPath, + outputDimensionality: client.outputDimensionality, + }), + })), + wait: batch.wait, + concurrency: batch.concurrency, + pollIntervalMs: batch.pollIntervalMs, + timeoutMs: batch.timeoutMs, + debug: batch.debug, + }); + return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length); + }, + }, + }; + }, +}; diff --git a/extensions/google/openclaw.plugin.json b/extensions/google/openclaw.plugin.json index 40f0ad25e4d..5ce20ace302 100644 --- a/extensions/google/openclaw.plugin.json +++ b/extensions/google/openclaw.plugin.json @@ -46,6 +46,7 @@ }, "contracts": { "mediaUnderstandingProviders": ["google"], + "memoryEmbeddingProviders": ["gemini"], "imageGenerationProviders": ["google"], "musicGenerationProviders": ["google"], "speechProviders": ["google"], diff --git a/extensions/lmstudio/index.ts b/extensions/lmstudio/index.ts index e47d9bdef9b..7e18c015502 100644 --- a/extensions/lmstudio/index.ts +++ b/extensions/lmstudio/index.ts @@ -8,6 +8,7 @@ import { type ProviderRuntimeModel, } from "openclaw/plugin-sdk/plugin-entry"; import { CUSTOM_LOCAL_AUTH_MARKER } from "openclaw/plugin-sdk/provider-auth"; +import { lmstudioMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js"; import { LMSTUDIO_DEFAULT_API_KEY_ENV_VAR, LMSTUDIO_LOCAL_API_KEY_PLACEHOLDER, @@ -52,6 +53,7 @@ export default definePluginEntry({ name: "LM Studio Provider", description: "Bundled LM Studio provider plugin", register(api: OpenClawPluginApi) { + api.registerMemoryEmbeddingProvider(lmstudioMemoryEmbeddingProviderAdapter); api.registerProvider({ id: PROVIDER_ID, label: "LM Studio", diff --git a/extensions/lmstudio/memory-embedding-adapter.ts b/extensions/lmstudio/memory-embedding-adapter.ts new file mode 100644 index 00000000000..2d56b811be9 --- /dev/null +++ b/extensions/lmstudio/memory-embedding-adapter.ts @@ -0,0 +1,35 @@ +import { + sanitizeEmbeddingCacheHeaders, + type 
MemoryEmbeddingProviderAdapter, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { + createLmstudioEmbeddingProvider, + DEFAULT_LMSTUDIO_EMBEDDING_MODEL, +} from "./src/embedding-provider.js"; + +export const lmstudioMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = { + id: "lmstudio", + defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL, + transport: "remote", + authProviderId: "lmstudio", + allowExplicitWhenConfiguredAuto: true, + create: async (options) => { + const { provider, client } = await createLmstudioEmbeddingProvider({ + ...options, + provider: "lmstudio", + fallback: "none", + }); + return { + provider, + runtime: { + id: "lmstudio", + cacheKeyData: { + provider: "lmstudio", + baseUrl: client.baseUrl, + model: client.model, + headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]), + }, + }, + }; + }, +}; diff --git a/extensions/lmstudio/openclaw.plugin.json b/extensions/lmstudio/openclaw.plugin.json index 0fc035ce76b..0559047b9a9 100644 --- a/extensions/lmstudio/openclaw.plugin.json +++ b/extensions/lmstudio/openclaw.plugin.json @@ -21,6 +21,9 @@ "groupHint": "Self-hosted open-weight models" } ], + "contracts": { + "memoryEmbeddingProviders": ["lmstudio"] + }, "configSchema": { "type": "object", "additionalProperties": false, diff --git a/src/memory-host-sdk/host/embeddings-lmstudio.ts b/extensions/lmstudio/src/embedding-provider.ts similarity index 81% rename from src/memory-host-sdk/host/embeddings-lmstudio.ts rename to extensions/lmstudio/src/embedding-provider.ts index d80c4255f0b..7e2a410485b 100644 --- a/src/memory-host-sdk/host/embeddings-lmstudio.ts +++ b/extensions/lmstudio/src/embedding-provider.ts @@ -1,20 +1,21 @@ -import { formatErrorMessage } from "../../infra/errors.js"; -import type { SsrFPolicy } from "../../infra/net/ssrf.js"; -import { createSubsystemLogger } from "../../logging/subsystem.js"; +import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core"; +import { + 
buildRemoteBaseUrlPolicy, + createRemoteEmbeddingProvider, + normalizeEmbeddingModelWithPrefixes, + type MemoryEmbeddingProvider, + type MemoryEmbeddingProviderCreateOptions, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { resolveMemorySecretInputString } from "openclaw/plugin-sdk/memory-core-host-secret"; +import { formatErrorMessage, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime"; +import { LMSTUDIO_DEFAULT_EMBEDDING_MODEL, LMSTUDIO_PROVIDER_ID } from "./defaults.js"; +import { ensureLmstudioModelLoaded } from "./models.fetch.js"; +import { resolveLmstudioInferenceBase } from "./models.js"; import { buildLmstudioAuthHeaders, - ensureLmstudioModelLoaded, - LMSTUDIO_DEFAULT_EMBEDDING_MODEL, - LMSTUDIO_PROVIDER_ID, - resolveLmstudioInferenceBase, resolveLmstudioProviderHeaders, resolveLmstudioRuntimeApiKey, -} from "../../plugin-sdk/lmstudio-runtime.js"; -import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; -import { createRemoteEmbeddingProvider } from "./embeddings-remote-provider.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js"; -import { buildRemoteBaseUrlPolicy } from "./remote-http.js"; -import { resolveMemorySecretInputString } from "./secret-input.js"; +} from "./runtime.js"; const log = createSubsystemLogger("memory/embeddings"); @@ -47,7 +48,7 @@ function hasAuthorizationHeader(headers: Record | undefined): bo /** Resolves API key (real or synthetic placeholder) from runtime/provider auth config. */ async function resolveLmstudioApiKey( - options: EmbeddingProviderOptions, + options: MemoryEmbeddingProviderCreateOptions, ): Promise { try { return await resolveLmstudioRuntimeApiKey({ @@ -65,8 +66,8 @@ async function resolveLmstudioApiKey( /** Creates the LM Studio embedding provider client and preloads the target model before return. 
*/ export async function createLmstudioEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: LmstudioEmbeddingClient }> { + options: MemoryEmbeddingProviderCreateOptions, +): Promise<{ provider: MemoryEmbeddingProvider; client: LmstudioEmbeddingClient }> { const providerConfig = options.config.models?.providers?.lmstudio; const providerBaseUrl = providerConfig?.baseUrl?.trim(); const isFallbackActivation = options.fallback === "lmstudio" && options.provider !== "lmstudio"; diff --git a/extensions/memory-core/src/memory/embeddings.ts b/extensions/memory-core/src/memory/embeddings.ts index 609d91c1f46..e06aba20c5f 100644 --- a/extensions/memory-core/src/memory/embeddings.ts +++ b/extensions/memory-core/src/memory/embeddings.ts @@ -1,10 +1,5 @@ import { - DEFAULT_GEMINI_EMBEDDING_MODEL, DEFAULT_LOCAL_MODEL, - DEFAULT_MISTRAL_EMBEDDING_MODEL, - DEFAULT_OLLAMA_EMBEDDING_MODEL, - DEFAULT_OPENAI_EMBEDDING_MODEL, - DEFAULT_VOYAGE_EMBEDDING_MODEL, getMemoryEmbeddingProvider, listMemoryEmbeddingProviders, type MemoryEmbeddingProvider, @@ -15,15 +10,7 @@ import { import { formatErrorMessage } from "../dreaming-shared.js"; import { canAutoSelectLocal } from "./provider-adapters.js"; -export { - DEFAULT_GEMINI_EMBEDDING_MODEL, - DEFAULT_LMSTUDIO_EMBEDDING_MODEL, - DEFAULT_LOCAL_MODEL, - DEFAULT_MISTRAL_EMBEDDING_MODEL, - DEFAULT_OLLAMA_EMBEDDING_MODEL, - DEFAULT_OPENAI_EMBEDDING_MODEL, - DEFAULT_VOYAGE_EMBEDDING_MODEL, -} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +export { DEFAULT_LOCAL_MODEL } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; export type EmbeddingProvider = MemoryEmbeddingProvider; export type EmbeddingProviderId = string; diff --git a/extensions/memory-core/src/memory/index.test.ts b/extensions/memory-core/src/memory/index.test.ts index b63904ee10b..750eb257921 100644 --- a/extensions/memory-core/src/memory/index.test.ts +++ b/extensions/memory-core/src/memory/index.test.ts 
@@ -11,9 +11,9 @@ import { } from "../../../../src/plugins/memory-embedding-providers.js"; import "./test-runtime-mocks.js"; import type { MemoryIndexManager } from "./index.js"; -import { getMemorySearchManager, closeAllMemorySearchManagers } from "./index.js"; +import { closeAllMemorySearchManagers, getMemorySearchManager } from "./index.js"; import { - DEFAULT_OLLAMA_EMBEDDING_MODEL, + DEFAULT_LOCAL_MODEL, registerBuiltInMemoryEmbeddingProviders, } from "./provider-adapters.js"; @@ -112,14 +112,14 @@ vi.mock("./embeddings.js", () => { }); describe("memory index", () => { - it("registers the builtin ollama embedding provider", () => { - const adapter = listRegisteredAdapters().find((entry) => entry.id === "ollama"); + it("registers the builtin local embedding provider", () => { + const adapter = listRegisteredAdapters().find((entry) => entry.id === "local"); expect(adapter).toBeDefined(); expect(adapter).toEqual( expect.objectContaining({ - id: "ollama", - defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL, + id: "local", + defaultModel: DEFAULT_LOCAL_MODEL, }), ); }); diff --git a/extensions/memory-core/src/memory/provider-adapters.ts b/extensions/memory-core/src/memory/provider-adapters.ts index e9ceed7b0bd..3cdc631676d 100644 --- a/extensions/memory-core/src/memory/provider-adapters.ts +++ b/extensions/memory-core/src/memory/provider-adapters.ts @@ -1,31 +1,13 @@ import fsSync from "node:fs"; import { - DEFAULT_GEMINI_EMBEDDING_MODEL, - DEFAULT_LMSTUDIO_EMBEDDING_MODEL, - DEFAULT_LOCAL_MODEL, - DEFAULT_MISTRAL_EMBEDDING_MODEL, - DEFAULT_OLLAMA_EMBEDDING_MODEL, - DEFAULT_OPENAI_EMBEDDING_MODEL, - DEFAULT_VOYAGE_EMBEDDING_MODEL, - OPENAI_BATCH_ENDPOINT, - buildGeminiEmbeddingRequest, - createGeminiEmbeddingProvider, - createLmstudioEmbeddingProvider, createLocalEmbeddingProvider, - createMistralEmbeddingProvider, - createOllamaEmbeddingProvider, - createOpenAiEmbeddingProvider, - createVoyageEmbeddingProvider, - hasNonTextEmbeddingParts, + DEFAULT_LOCAL_MODEL, + 
listMemoryEmbeddingProviders, listRegisteredMemoryEmbeddingProviderAdapters, - runGeminiEmbeddingBatches, - runOpenAiEmbeddingBatches, - runVoyageEmbeddingBatches, type MemoryEmbeddingProviderAdapter, } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; import { resolveUserPath } from "openclaw/plugin-sdk/memory-core-host-engine-foundation"; import { getProviderEnvVars } from "openclaw/plugin-sdk/provider-env-vars"; -import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime"; import { formatErrorMessage } from "../dreaming-shared.js"; import { filterUnregisteredMemoryEmbeddingProviderAdapters } from "./provider-adapter-registration.js"; @@ -37,31 +19,6 @@ export type BuiltinMemoryEmbeddingProviderDoctorMetadata = { autoSelectPriority?: number; }; -function isMissingApiKeyError(err: unknown): boolean { - return formatErrorMessage(err).includes("No API key found for provider"); -} - -function sanitizeHeaders( - headers: Record<string, string>, - excludedHeaderNames: string[], -): Array<[string, string]> { - const excluded = new Set( - excludedHeaderNames.map((name) => normalizeLowercaseStringOrEmpty(name)), - ); - return Object.entries(headers) - .filter(([key]) => !excluded.has(normalizeLowercaseStringOrEmpty(key))) - .toSorted(([a], [b]) => a.localeCompare(b)) - .map(([key, value]) => [key, value]); -} - -function mapBatchEmbeddingsByIndex(byCustomId: Map<string, number[]>, count: number): number[][] { - const embeddings: number[][] = []; - for (let index = 0; index < count; index += 1) { - embeddings.push(byCustomId.get(String(index)) ??
[]); - } - return embeddings; -} - function isNodeLlamaCppMissing(err: unknown): boolean { if (!(err instanceof Error)) { return false; @@ -70,6 +27,20 @@ function isNodeLlamaCppMissing(err: unknown): boolean { return code === "ERR_MODULE_NOT_FOUND" && err.message.includes("node-llama-cpp"); } +function listRemoteEmbeddingSetupHints(): string[] { + try { + return listMemoryEmbeddingProviders() + .filter( + (adapter) => + adapter.transport === "remote" && typeof adapter.autoSelectPriority === "number", + ) + .toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0)) + .map((adapter) => `Or set agents.defaults.memorySearch.provider = "${adapter.id}" (remote).`); + } catch { + return []; + } +} + function formatLocalSetupError(err: unknown): string { const detail = formatErrorMessage(err); const missing = isNodeLlamaCppMissing(err); @@ -87,9 +58,7 @@ function formatLocalSetupError(err: unknown): string { ? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest" : null, "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp", - ...["openai", "gemini", "voyage", "mistral"].map( - (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`, - ), + ...listRemoteEmbeddingSetupHints(), ] .filter(Boolean) .join("\n"); @@ -111,237 +80,6 @@ function canAutoSelectLocal(modelPath?: string): boolean { } } -function supportsGeminiMultimodalEmbeddings(model: string): boolean { - const normalized = model - .trim() - .replace(/^models\//, "") - .replace(/^(gemini|google)\//, ""); - return normalized === "gemini-embedding-2-preview"; -} - -function resolveMemoryEmbeddingAuthProviderId(providerId: string): string { - return providerId === "gemini" ? 
"google" : providerId; -} - -const openAiAdapter: MemoryEmbeddingProviderAdapter = { - id: "openai", - defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL, - transport: "remote", - autoSelectPriority: 20, - allowExplicitWhenConfiguredAuto: true, - shouldContinueAutoSelection: isMissingApiKeyError, - create: async (options) => { - const { provider, client } = await createOpenAiEmbeddingProvider({ - ...options, - provider: "openai", - fallback: "none", - }); - return { - provider, - runtime: { - id: "openai", - cacheKeyData: { - provider: "openai", - baseUrl: client.baseUrl, - model: client.model, - headers: sanitizeHeaders(client.headers, ["authorization"]), - }, - batchEmbed: async (batch) => { - const byCustomId = await runOpenAiEmbeddingBatches({ - openAi: client, - agentId: batch.agentId, - requests: batch.chunks.map((chunk, index) => ({ - custom_id: String(index), - method: "POST", - url: OPENAI_BATCH_ENDPOINT, - body: { - model: client.model, - input: chunk.text, - }, - })), - wait: batch.wait, - concurrency: batch.concurrency, - pollIntervalMs: batch.pollIntervalMs, - timeoutMs: batch.timeoutMs, - debug: batch.debug, - }); - return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length); - }, - }, - }; - }, -}; - -const geminiAdapter: MemoryEmbeddingProviderAdapter = { - id: "gemini", - defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL, - transport: "remote", - autoSelectPriority: 30, - allowExplicitWhenConfiguredAuto: true, - supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model), - shouldContinueAutoSelection: isMissingApiKeyError, - create: async (options) => { - const { provider, client } = await createGeminiEmbeddingProvider({ - ...options, - provider: "gemini", - fallback: "none", - }); - return { - provider, - runtime: { - id: "gemini", - cacheKeyData: { - provider: "gemini", - baseUrl: client.baseUrl, - model: client.model, - outputDimensionality: client.outputDimensionality, - headers: sanitizeHeaders(client.headers, 
["authorization", "x-goog-api-key"]), - }, - batchEmbed: async (batch) => { - if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) { - return null; - } - const byCustomId = await runGeminiEmbeddingBatches({ - gemini: client, - agentId: batch.agentId, - requests: batch.chunks.map((chunk, index) => ({ - custom_id: String(index), - request: buildGeminiEmbeddingRequest({ - input: chunk.embeddingInput ?? { text: chunk.text }, - taskType: "RETRIEVAL_DOCUMENT", - modelPath: client.modelPath, - outputDimensionality: client.outputDimensionality, - }), - })), - wait: batch.wait, - concurrency: batch.concurrency, - pollIntervalMs: batch.pollIntervalMs, - timeoutMs: batch.timeoutMs, - debug: batch.debug, - }); - return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length); - }, - }, - }; - }, -}; - -const voyageAdapter: MemoryEmbeddingProviderAdapter = { - id: "voyage", - defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL, - transport: "remote", - autoSelectPriority: 40, - allowExplicitWhenConfiguredAuto: true, - shouldContinueAutoSelection: isMissingApiKeyError, - create: async (options) => { - const { provider, client } = await createVoyageEmbeddingProvider({ - ...options, - provider: "voyage", - fallback: "none", - }); - return { - provider, - runtime: { - id: "voyage", - batchEmbed: async (batch) => { - const byCustomId = await runVoyageEmbeddingBatches({ - client, - agentId: batch.agentId, - requests: batch.chunks.map((chunk, index) => ({ - custom_id: String(index), - body: { - input: chunk.text, - }, - })), - wait: batch.wait, - concurrency: batch.concurrency, - pollIntervalMs: batch.pollIntervalMs, - timeoutMs: batch.timeoutMs, - debug: batch.debug, - }); - return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length); - }, - }, - }; - }, -}; - -const mistralAdapter: MemoryEmbeddingProviderAdapter = { - id: "mistral", - defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL, - transport: "remote", - autoSelectPriority: 50, - 
allowExplicitWhenConfiguredAuto: true, - shouldContinueAutoSelection: isMissingApiKeyError, - create: async (options) => { - const { provider, client } = await createMistralEmbeddingProvider({ - ...options, - provider: "mistral", - fallback: "none", - }); - return { - provider, - runtime: { - id: "mistral", - cacheKeyData: { - provider: "mistral", - model: client.model, - }, - }, - }; - }, -}; - -const ollamaAdapter: MemoryEmbeddingProviderAdapter = { - id: "ollama", - defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL, - transport: "remote", - create: async (options) => { - const { provider, client } = await createOllamaEmbeddingProvider({ - ...options, - provider: "ollama", - fallback: "none", - }); - return { - provider, - runtime: { - id: "ollama", - cacheKeyData: { - provider: "ollama", - baseUrl: client.baseUrl, - model: client.model, - headers: sanitizeHeaders(client.headers, ["authorization"]), - }, - }, - }; - }, -}; - -const lmstudioAdapter: MemoryEmbeddingProviderAdapter = { - id: "lmstudio", - defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL, - transport: "remote", - create: async (options) => { - const { provider, client } = await createLmstudioEmbeddingProvider({ - ...options, - provider: "lmstudio", - fallback: "none", - }); - return { - provider, - runtime: { - id: "lmstudio", - cacheKeyData: { - provider: "lmstudio", - baseUrl: client.baseUrl, - model: client.model, - headers: sanitizeHeaders(client.headers, ["authorization"]), - }, - }, - }; - }, -}; - const localAdapter: MemoryEmbeddingProviderAdapter = { id: "local", defaultModel: DEFAULT_LOCAL_MODEL, @@ -368,24 +106,14 @@ const localAdapter: MemoryEmbeddingProviderAdapter = { }, }; -export const builtinMemoryEmbeddingProviderAdapters = [ - localAdapter, - openAiAdapter, - geminiAdapter, - voyageAdapter, - mistralAdapter, - ollamaAdapter, - lmstudioAdapter, -] as const; +export const builtinMemoryEmbeddingProviderAdapters = [localAdapter] as const; -const builtinMemoryEmbeddingProviderAdapterById = new 
Map( - builtinMemoryEmbeddingProviderAdapters.map((adapter) => [adapter.id, adapter]), -); +export { DEFAULT_LOCAL_MODEL }; export function getBuiltinMemoryEmbeddingProviderAdapter( id: string, ): MemoryEmbeddingProviderAdapter | undefined { - return builtinMemoryEmbeddingProviderAdapterById.get(id); + return listMemoryEmbeddingProviders().find((adapter) => adapter.id === id); } export function registerBuiltInMemoryEmbeddingProviders(register: { @@ -409,7 +137,7 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata( if (!adapter) { return null; } - const authProviderId = resolveMemoryEmbeddingAuthProviderId(adapter.id); + const authProviderId = adapter.authProviderId ?? adapter.id; return { providerId: adapter.id, authProviderId, @@ -420,27 +148,19 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata( } export function listBuiltinAutoSelectMemoryEmbeddingProviderDoctorMetadata(): Array { - return builtinMemoryEmbeddingProviderAdapters + return listMemoryEmbeddingProviders() .filter((adapter) => typeof adapter.autoSelectPriority === "number") .toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0)) - .map((adapter) => ({ - providerId: adapter.id, - authProviderId: resolveMemoryEmbeddingAuthProviderId(adapter.id), - envVars: getProviderEnvVars(resolveMemoryEmbeddingAuthProviderId(adapter.id)), - transport: adapter.transport === "local" ? "local" : "remote", - autoSelectPriority: adapter.autoSelectPriority, - })); + .map((adapter) => { + const authProviderId = adapter.authProviderId ?? adapter.id; + return { + providerId: adapter.id, + authProviderId, + envVars: getProviderEnvVars(authProviderId), + transport: adapter.transport === "local" ? 
"local" : "remote", + autoSelectPriority: adapter.autoSelectPriority, + }; + }); } -export { - DEFAULT_GEMINI_EMBEDDING_MODEL, - DEFAULT_LMSTUDIO_EMBEDDING_MODEL, - DEFAULT_LOCAL_MODEL, - DEFAULT_MISTRAL_EMBEDDING_MODEL, - DEFAULT_OLLAMA_EMBEDDING_MODEL, - DEFAULT_OPENAI_EMBEDDING_MODEL, - DEFAULT_VOYAGE_EMBEDDING_MODEL, - canAutoSelectLocal, - formatLocalSetupError, - isMissingApiKeyError, -}; +export { canAutoSelectLocal, formatLocalSetupError }; diff --git a/extensions/memory-core/src/memory/qmd-manager.ts b/extensions/memory-core/src/memory/qmd-manager.ts index 427931d89f5..ebed89f00ec 100644 --- a/extensions/memory-core/src/memory/qmd-manager.ts +++ b/extensions/memory-core/src/memory/qmd-manager.ts @@ -16,6 +16,7 @@ import { writeFileWithinRoot, type OpenClawConfig, } from "openclaw/plugin-sdk/memory-core-host-engine-foundation"; +import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation"; import { buildSessionEntry, deriveQmdScopeChannel, @@ -47,7 +48,6 @@ import { type ResolvedQmdConfig, type ResolvedQmdMcporterConfig, } from "openclaw/plugin-sdk/memory-core-host-engine-storage"; -import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation"; import { localeLowercasePreservingWhitespace, normalizeLowercaseStringOrEmpty, @@ -1945,8 +1945,7 @@ export class QmdMemoryManager implements MemorySearchManager { from?: number, lines?: number, ): Promise< - | { missing: true } - | { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean } + { missing: true } | { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean } > { const start = Math.max(1, from ?? 1); const count = Math.max(1, lines ?? 
Number.POSITIVE_INFINITY); diff --git a/src/memory-host-sdk/host/embeddings-mistral.ts b/extensions/mistral/embedding-provider.ts similarity index 72% rename from src/memory-host-sdk/host/embeddings-mistral.ts rename to extensions/mistral/embedding-provider.ts index e72efef8028..4ef3c25d5a2 100644 --- a/src/memory-host-sdk/host/embeddings-mistral.ts +++ b/extensions/mistral/embedding-provider.ts @@ -1,10 +1,11 @@ -import type { SsrFPolicy } from "../../infra/net/ssrf.js"; -import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; import { createRemoteEmbeddingProvider, + normalizeEmbeddingModelWithPrefixes, resolveRemoteEmbeddingClient, -} from "./embeddings-remote-provider.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js"; + type MemoryEmbeddingProvider, + type MemoryEmbeddingProviderCreateOptions, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime"; export type MistralEmbeddingClient = { baseUrl: string; @@ -25,8 +26,8 @@ export function normalizeMistralModel(model: string): string { } export async function createMistralEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> { + options: MemoryEmbeddingProviderCreateOptions, +): Promise<{ provider: MemoryEmbeddingProvider; client: MistralEmbeddingClient }> { const client = await resolveMistralEmbeddingClient(options); return { @@ -40,7 +41,7 @@ export async function createMistralEmbeddingProvider( } export async function resolveMistralEmbeddingClient( - options: EmbeddingProviderOptions, + options: MemoryEmbeddingProviderCreateOptions, ): Promise<MistralEmbeddingClient> { return await resolveRemoteEmbeddingClient({ provider: "mistral", diff --git a/extensions/mistral/index.ts b/extensions/mistral/index.ts index f3a87a13ebe..c80f267c102 100644 --- a/extensions/mistral/index.ts +++
b/extensions/mistral/index.ts @@ -1,6 +1,7 @@ import { defineSingleProviderPluginEntry } from "openclaw/plugin-sdk/provider-entry"; import { applyMistralModelCompat } from "./api.js"; import { mistralMediaUnderstandingProvider } from "./media-understanding-provider.js"; +import { mistralMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js"; import { applyMistralConfig, MISTRAL_DEFAULT_MODEL_REF } from "./onboard.js"; import { buildMistralProvider } from "./provider-catalog.js"; import { contributeMistralResolvedModelCompat } from "./provider-compat.js"; @@ -48,6 +49,7 @@ export default defineSingleProviderPluginEntry({ buildReplayPolicy: () => buildMistralReplayPolicy(), }, register(api) { + api.registerMemoryEmbeddingProvider(mistralMemoryEmbeddingProviderAdapter); api.registerMediaUnderstandingProvider(mistralMediaUnderstandingProvider); }, }); diff --git a/extensions/mistral/memory-embedding-adapter.ts b/extensions/mistral/memory-embedding-adapter.ts new file mode 100644 index 00000000000..c52c5f96ad3 --- /dev/null +++ b/extensions/mistral/memory-embedding-adapter.ts @@ -0,0 +1,35 @@ +import { + isMissingEmbeddingApiKeyError, + type MemoryEmbeddingProviderAdapter, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { + createMistralEmbeddingProvider, + DEFAULT_MISTRAL_EMBEDDING_MODEL, +} from "./embedding-provider.js"; + +export const mistralMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = { + id: "mistral", + defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL, + transport: "remote", + authProviderId: "mistral", + autoSelectPriority: 50, + allowExplicitWhenConfiguredAuto: true, + shouldContinueAutoSelection: isMissingEmbeddingApiKeyError, + create: async (options) => { + const { provider, client } = await createMistralEmbeddingProvider({ + ...options, + provider: "mistral", + fallback: "none", + }); + return { + provider, + runtime: { + id: "mistral", + cacheKeyData: { + provider: "mistral", + model: client.model, 
+ }, + }, + }; + }, +}; diff --git a/extensions/mistral/openclaw.plugin.json b/extensions/mistral/openclaw.plugin.json index 6cf38a73a27..53bf7e40e98 100644 --- a/extensions/mistral/openclaw.plugin.json +++ b/extensions/mistral/openclaw.plugin.json @@ -21,6 +21,7 @@ } ], "contracts": { + "memoryEmbeddingProviders": ["mistral"], "mediaUnderstandingProviders": ["mistral"] }, "configSchema": { diff --git a/extensions/ollama/src/memory-embedding-adapter.ts b/extensions/ollama/src/memory-embedding-adapter.ts index eccc9e03773..b3527fe8b4f 100644 --- a/extensions/ollama/src/memory-embedding-adapter.ts +++ b/extensions/ollama/src/memory-embedding-adapter.ts @@ -8,6 +8,7 @@ export const ollamaMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapte id: "ollama", defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL, transport: "remote", + authProviderId: "ollama", create: async (options) => { const { provider, client } = await createOllamaEmbeddingProvider({ ...options, diff --git a/src/memory-host-sdk/host/batch-openai.ts b/extensions/openai/embedding-batch.ts similarity index 98% rename from src/memory-host-sdk/host/batch-openai.ts rename to extensions/openai/embedding-batch.ts index 380a50fc6a3..ccbc0e3e462 100644 --- a/src/memory-host-sdk/host/batch-openai.ts +++ b/extensions/openai/embedding-batch.ts @@ -17,8 +17,8 @@ import { type ProviderBatchOutputLine, uploadBatchJsonlFile, withRemoteHttpResponse, -} from "./batch-embedding-common.js"; -import type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import type { OpenAiEmbeddingClient } from "./embedding-provider.js"; export type OpenAiBatchRequest = { custom_id: string; diff --git a/src/memory-host-sdk/host/embeddings-openai.ts b/extensions/openai/embedding-provider.ts similarity index 67% rename from src/memory-host-sdk/host/embeddings-openai.ts rename to extensions/openai/embedding-provider.ts index c8121dd2426..a536a93b0fe 100644 --- 
a/src/memory-host-sdk/host/embeddings-openai.ts +++ b/extensions/openai/embedding-provider.ts @@ -1,11 +1,11 @@ -import { parseStaticModelRef } from "../../agents/model-ref-shared.js"; -import type { SsrFPolicy } from "../../infra/net/ssrf.js"; -import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../plugins/provider-model-defaults.js"; import { createRemoteEmbeddingProvider, resolveRemoteEmbeddingClient, -} from "./embeddings-remote-provider.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js"; + type MemoryEmbeddingProvider, + type MemoryEmbeddingProviderCreateOptions, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime"; +import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "./default-models.js"; export type OpenAiEmbeddingClient = { baseUrl: string; @@ -28,13 +28,12 @@ export function normalizeOpenAiModel(model: string): string { if (!trimmed) { return DEFAULT_OPENAI_EMBEDDING_MODEL; } - const parsed = parseStaticModelRef(trimmed, "openai"); - return parsed && parsed.provider === "openai" ? parsed.model : trimmed; + return trimmed.startsWith("openai/") ? 
trimmed.slice("openai/".length) : trimmed; } export async function createOpenAiEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> { + options: MemoryEmbeddingProviderCreateOptions, +): Promise<{ provider: MemoryEmbeddingProvider; client: OpenAiEmbeddingClient }> { const client = await resolveOpenAiEmbeddingClient(options); return { @@ -49,7 +48,7 @@ export async function createOpenAiEmbeddingProvider( } export async function resolveOpenAiEmbeddingClient( - options: EmbeddingProviderOptions, + options: MemoryEmbeddingProviderCreateOptions, ): Promise { return await resolveRemoteEmbeddingClient({ provider: "openai", diff --git a/extensions/openai/index.ts b/extensions/openai/index.ts index 822bb0bbedb..d5bf2d12fd9 100644 --- a/extensions/openai/index.ts +++ b/extensions/openai/index.ts @@ -6,6 +6,7 @@ import { openaiCodexMediaUnderstandingProvider, openaiMediaUnderstandingProvider, } from "./media-understanding-provider.js"; +import { openAiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js"; import { buildOpenAICodexProviderPlugin } from "./openai-codex-provider.js"; import { buildOpenAIProvider } from "./openai-provider.js"; import { @@ -39,6 +40,7 @@ export default definePluginEntry({ api.registerCliBackend(buildOpenAICodexCliBackend()); api.registerProvider(buildProviderWithPromptContribution(buildOpenAIProvider())); api.registerProvider(buildProviderWithPromptContribution(buildOpenAICodexProviderPlugin())); + api.registerMemoryEmbeddingProvider(openAiMemoryEmbeddingProviderAdapter); api.registerImageGenerationProvider(buildOpenAIImageGenerationProvider()); api.registerRealtimeTranscriptionProvider(buildOpenAIRealtimeTranscriptionProvider()); api.registerRealtimeVoiceProvider(buildOpenAIRealtimeVoiceProvider()); diff --git a/extensions/openai/memory-embedding-adapter.ts b/extensions/openai/memory-embedding-adapter.ts new file mode 100644 index 
00000000000..16a255c4af9 --- /dev/null +++ b/extensions/openai/memory-embedding-adapter.ts @@ -0,0 +1,61 @@ +import { + isMissingEmbeddingApiKeyError, + mapBatchEmbeddingsByIndex, + sanitizeEmbeddingCacheHeaders, + type MemoryEmbeddingProviderAdapter, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { OPENAI_BATCH_ENDPOINT, runOpenAiEmbeddingBatches } from "./embedding-batch.js"; +import { + createOpenAiEmbeddingProvider, + DEFAULT_OPENAI_EMBEDDING_MODEL, +} from "./embedding-provider.js"; + +export const openAiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = { + id: "openai", + defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL, + transport: "remote", + authProviderId: "openai", + autoSelectPriority: 20, + allowExplicitWhenConfiguredAuto: true, + shouldContinueAutoSelection: isMissingEmbeddingApiKeyError, + create: async (options) => { + const { provider, client } = await createOpenAiEmbeddingProvider({ + ...options, + provider: "openai", + fallback: "none", + }); + return { + provider, + runtime: { + id: "openai", + cacheKeyData: { + provider: "openai", + baseUrl: client.baseUrl, + model: client.model, + headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]), + }, + batchEmbed: async (batch) => { + const byCustomId = await runOpenAiEmbeddingBatches({ + openAi: client, + agentId: batch.agentId, + requests: batch.chunks.map((chunk, index) => ({ + custom_id: String(index), + method: "POST", + url: OPENAI_BATCH_ENDPOINT, + body: { + model: client.model, + input: chunk.text, + }, + })), + wait: batch.wait, + concurrency: batch.concurrency, + pollIntervalMs: batch.pollIntervalMs, + timeoutMs: batch.timeoutMs, + debug: batch.debug, + }); + return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length); + }, + }, + }; + }, +}; diff --git a/extensions/openai/openclaw.plugin.json b/extensions/openai/openclaw.plugin.json index df962ac8b47..b5e3aed4211 100644 --- a/extensions/openai/openclaw.plugin.json +++ 
b/extensions/openai/openclaw.plugin.json @@ -39,6 +39,7 @@ "speechProviders": ["openai"], "realtimeTranscriptionProviders": ["openai"], "realtimeVoiceProviders": ["openai"], + "memoryEmbeddingProviders": ["openai"], "mediaUnderstandingProviders": ["openai", "openai-codex"], "imageGenerationProviders": ["openai"], "videoGenerationProviders": ["openai"] diff --git a/packages/memory-host-sdk/src/host/batch-voyage.ts b/extensions/voyage/embedding-batch.ts similarity index 98% rename from packages/memory-host-sdk/src/host/batch-voyage.ts rename to extensions/voyage/embedding-batch.ts index fcb257a4d7d..0bef2d2aa25 100644 --- a/packages/memory-host-sdk/src/host/batch-voyage.ts +++ b/extensions/voyage/embedding-batch.ts @@ -19,8 +19,8 @@ import { type ProviderBatchOutputLine, uploadBatchJsonlFile, withRemoteHttpResponse, -} from "./batch-embedding-common.js"; -import type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import type { VoyageEmbeddingClient } from "./embedding-provider.js"; /** * Voyage Batch API Input Line format. 
diff --git a/src/memory-host-sdk/host/embeddings-voyage.ts b/extensions/voyage/embedding-provider.ts similarity index 79% rename from src/memory-host-sdk/host/embeddings-voyage.ts rename to extensions/voyage/embedding-provider.ts index 54eb0431524..f4d218c80cc 100644 --- a/src/memory-host-sdk/host/embeddings-voyage.ts +++ b/extensions/voyage/embedding-provider.ts @@ -1,8 +1,11 @@ -import type { SsrFPolicy } from "../../infra/net/ssrf.js"; -import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; -import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; -import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js"; +import { + fetchRemoteEmbeddingVectors, + normalizeEmbeddingModelWithPrefixes, + resolveRemoteEmbeddingBearerClient, + type MemoryEmbeddingProvider, + type MemoryEmbeddingProviderCreateOptions, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime"; export type VoyageEmbeddingClient = { baseUrl: string; @@ -28,8 +31,8 @@ export function normalizeVoyageModel(model: string): string { } export async function createVoyageEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> { + options: MemoryEmbeddingProviderCreateOptions, +): Promise<{ provider: MemoryEmbeddingProvider; client: VoyageEmbeddingClient }> { const client = await resolveVoyageEmbeddingClient(options); const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; @@ -70,7 +73,7 @@ export async function createVoyageEmbeddingProvider( } export async function resolveVoyageEmbeddingClient( - options: EmbeddingProviderOptions, + options: MemoryEmbeddingProviderCreateOptions, ): Promise { const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ provider: 
"voyage", diff --git a/extensions/voyage/index.ts b/extensions/voyage/index.ts new file mode 100644 index 00000000000..d3020f110bc --- /dev/null +++ b/extensions/voyage/index.ts @@ -0,0 +1,11 @@ +import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry"; +import { voyageMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js"; + +export default definePluginEntry({ + id: "voyage", + name: "Voyage Embeddings", + description: "Bundled Voyage memory embedding provider plugin", + register(api) { + api.registerMemoryEmbeddingProvider(voyageMemoryEmbeddingProviderAdapter); + }, +}); diff --git a/extensions/voyage/memory-embedding-adapter.ts b/extensions/voyage/memory-embedding-adapter.ts new file mode 100644 index 00000000000..5d0e0a9841d --- /dev/null +++ b/extensions/voyage/memory-embedding-adapter.ts @@ -0,0 +1,56 @@ +import { + isMissingEmbeddingApiKeyError, + mapBatchEmbeddingsByIndex, + sanitizeEmbeddingCacheHeaders, + type MemoryEmbeddingProviderAdapter, +} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings"; +import { runVoyageEmbeddingBatches } from "./embedding-batch.js"; +import { + createVoyageEmbeddingProvider, + DEFAULT_VOYAGE_EMBEDDING_MODEL, +} from "./embedding-provider.js"; + +export const voyageMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = { + id: "voyage", + defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL, + transport: "remote", + authProviderId: "voyage", + autoSelectPriority: 40, + allowExplicitWhenConfiguredAuto: true, + shouldContinueAutoSelection: isMissingEmbeddingApiKeyError, + create: async (options) => { + const { provider, client } = await createVoyageEmbeddingProvider({ + ...options, + provider: "voyage", + fallback: "none", + }); + return { + provider, + runtime: { + id: "voyage", + cacheKeyData: { + provider: "voyage", + baseUrl: client.baseUrl, + model: client.model, + headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]), + }, + batchEmbed: async (batch) => { + const 
byCustomId = await runVoyageEmbeddingBatches({ + client, + agentId: batch.agentId, + requests: batch.chunks.map((chunk, index) => ({ + custom_id: String(index), + body: { input: chunk.text }, + })), + wait: batch.wait, + concurrency: batch.concurrency, + pollIntervalMs: batch.pollIntervalMs, + timeoutMs: batch.timeoutMs, + debug: batch.debug, + }); + return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length); + }, + }, + }; + }, +}; diff --git a/extensions/voyage/openclaw.plugin.json b/extensions/voyage/openclaw.plugin.json new file mode 100644 index 00000000000..4615d1d3600 --- /dev/null +++ b/extensions/voyage/openclaw.plugin.json @@ -0,0 +1,15 @@ +{ + "id": "voyage", + "enabledByDefault": true, + "contracts": { + "memoryEmbeddingProviders": ["voyage"] + }, + "providerAuthEnvVars": { + "voyage": ["VOYAGE_API_KEY"] + }, + "configSchema": { + "type": "object", + "additionalProperties": false, + "properties": {} + } +} diff --git a/extensions/voyage/package.json b/extensions/voyage/package.json new file mode 100644 index 00000000000..a60b913b084 --- /dev/null +++ b/extensions/voyage/package.json @@ -0,0 +1,15 @@ +{ + "name": "@openclaw/voyage-provider", + "version": "2026.4.15-beta.1", + "private": true, + "description": "OpenClaw Voyage embedding provider plugin", + "type": "module", + "devDependencies": { + "@openclaw/plugin-sdk": "workspace:*" + }, + "openclaw": { + "extensions": [ + "./index.ts" + ] + } +} diff --git a/packages/memory-host-sdk/src/host/batch-gemini.test.ts b/packages/memory-host-sdk/src/host/batch-gemini.test.ts deleted file mode 100644 index 095ebe008b9..00000000000 --- a/packages/memory-host-sdk/src/host/batch-gemini.test.ts +++ /dev/null @@ -1,116 +0,0 @@ -import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; -import type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; - -vi.mock("./remote-http.js", () => ({ - withRemoteHttpResponse: vi.fn(), -})); - -function magnitude(values: number[]) { - 
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0)); -} - -describe("runGeminiEmbeddingBatches", () => { - let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches; - let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse; - let remoteHttpMock: ReturnType>; - - beforeAll(async () => { - ({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js")); - ({ withRemoteHttpResponse } = await import("./remote-http.js")); - remoteHttpMock = vi.mocked(withRemoteHttpResponse); - }); - - beforeEach(() => { - vi.clearAllMocks(); - }); - - afterEach(() => { - vi.resetAllMocks(); - vi.unstubAllGlobals(); - }); - - const mockClient: GeminiEmbeddingClient = { - baseUrl: "https://generativelanguage.googleapis.com/v1beta", - headers: {}, - model: "gemini-embedding-2-preview", - modelPath: "models/gemini-embedding-2-preview", - apiKeys: ["test-key"], - outputDimensionality: 1536, - }; - - it("includes outputDimensionality in batch upload requests", async () => { - remoteHttpMock.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart"); - const body = params.init?.body; - if (!(body instanceof Blob)) { - throw new Error("expected multipart blob body"); - } - const text = await body.text(); - expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"'); - expect(text).toContain('"outputDimensionality":1536'); - return await params.onResponse( - new Response(JSON.stringify({ name: "files/file-123" }), { - status: 200, - headers: { "Content-Type": "application/json" }, - }), - ); - }); - remoteHttpMock.mockImplementationOnce(async (params) => { - expect(params.url).toMatch(/:asyncBatchEmbedContent$/u); - return await params.onResponse( - new Response( - JSON.stringify({ - name: "batches/batch-1", - state: "COMPLETED", - outputConfig: { file: "files/output-1" }, - }), - { - status: 200, - headers: { "Content-Type": "application/json" }, - }, 
- ), - ); - }); - remoteHttpMock.mockImplementationOnce(async (params) => { - expect(params.url).toMatch(/\/files\/output-1:download$/u); - return await params.onResponse( - new Response( - JSON.stringify({ - key: "req-1", - response: { embedding: { values: [3, 4] } }, - }), - { - status: 200, - headers: { "Content-Type": "application/jsonl" }, - }, - ), - ); - }); - - const results = await runGeminiEmbeddingBatches({ - gemini: mockClient, - agentId: "main", - requests: [ - { - custom_id: "req-1", - request: { - content: { parts: [{ text: "hello world" }] }, - taskType: "RETRIEVAL_DOCUMENT", - outputDimensionality: 1536, - }, - }, - ], - wait: true, - pollIntervalMs: 1, - timeoutMs: 1000, - concurrency: 1, - }); - - const embedding = results.get("req-1"); - expect(embedding).toBeDefined(); - expect(embedding?.[0]).toBeCloseTo(0.6, 5); - expect(embedding?.[1]).toBeCloseTo(0.8, 5); - expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5); - expect(remoteHttpMock).toHaveBeenCalledTimes(3); - }); -}); diff --git a/packages/memory-host-sdk/src/host/batch-openai.ts b/packages/memory-host-sdk/src/host/batch-openai.ts deleted file mode 100644 index e17a420812c..00000000000 --- a/packages/memory-host-sdk/src/host/batch-openai.ts +++ /dev/null @@ -1,259 +0,0 @@ -import { - applyEmbeddingBatchOutputLine, - buildBatchHeaders, - buildEmbeddingBatchGroupOptions, - EMBEDDING_BATCH_ENDPOINT, - extractBatchErrorMessage, - formatUnavailableBatchError, - normalizeBatchBaseUrl, - postJsonWithRetry, - resolveBatchCompletionFromStatus, - resolveCompletedBatchResult, - runEmbeddingBatchGroups, - throwIfBatchTerminalFailure, - type EmbeddingBatchExecutionParams, - type EmbeddingBatchStatus, - type BatchCompletionResult, - type ProviderBatchOutputLine, - uploadBatchJsonlFile, - withRemoteHttpResponse, -} from "./batch-embedding-common.js"; -import type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; - -export type OpenAiBatchRequest = { - custom_id: string; - method: "POST"; - url: 
"/v1/embeddings"; - body: { - model: string; - input: string; - }; -}; - -export type OpenAiBatchStatus = EmbeddingBatchStatus; -export type OpenAiBatchOutputLine = ProviderBatchOutputLine; - -export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT; -const OPENAI_BATCH_COMPLETION_WINDOW = "24h"; -const OPENAI_BATCH_MAX_REQUESTS = 50000; - -async function submitOpenAiBatch(params: { - openAi: OpenAiEmbeddingClient; - requests: OpenAiBatchRequest[]; - agentId: string; -}): Promise { - const baseUrl = normalizeBatchBaseUrl(params.openAi); - const inputFileId = await uploadBatchJsonlFile({ - client: params.openAi, - requests: params.requests, - errorPrefix: "openai batch file upload failed", - }); - - return await postJsonWithRetry({ - url: `${baseUrl}/batches`, - headers: buildBatchHeaders(params.openAi, { json: true }), - ssrfPolicy: params.openAi.ssrfPolicy, - body: { - input_file_id: inputFileId, - endpoint: OPENAI_BATCH_ENDPOINT, - completion_window: OPENAI_BATCH_COMPLETION_WINDOW, - metadata: { - source: "openclaw-memory", - agent: params.agentId, - }, - }, - errorPrefix: "openai batch create failed", - }); -} - -async function fetchOpenAiBatchStatus(params: { - openAi: OpenAiEmbeddingClient; - batchId: string; -}): Promise { - return await fetchOpenAiBatchResource({ - openAi: params.openAi, - path: `/batches/${params.batchId}`, - errorPrefix: "openai batch status", - parse: async (res) => (await res.json()) as OpenAiBatchStatus, - }); -} - -async function fetchOpenAiFileContent(params: { - openAi: OpenAiEmbeddingClient; - fileId: string; -}): Promise { - return await fetchOpenAiBatchResource({ - openAi: params.openAi, - path: `/files/${params.fileId}/content`, - errorPrefix: "openai batch file content", - parse: async (res) => await res.text(), - }); -} - -async function fetchOpenAiBatchResource(params: { - openAi: OpenAiEmbeddingClient; - path: string; - errorPrefix: string; - parse: (res: Response) => Promise; -}): Promise { - const baseUrl = 
normalizeBatchBaseUrl(params.openAi); - return await withRemoteHttpResponse({ - url: `${baseUrl}${params.path}`, - ssrfPolicy: params.openAi.ssrfPolicy, - init: { - headers: buildBatchHeaders(params.openAi, { json: true }), - }, - onResponse: async (res) => { - if (!res.ok) { - const text = await res.text(); - throw new Error(`${params.errorPrefix} failed: ${res.status} ${text}`); - } - return await params.parse(res); - }, - }); -} - -function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] { - if (!text.trim()) { - return []; - } - return text - .split("\n") - .map((line) => line.trim()) - .filter(Boolean) - .map((line) => JSON.parse(line) as OpenAiBatchOutputLine); -} - -async function readOpenAiBatchError(params: { - openAi: OpenAiEmbeddingClient; - errorFileId: string; -}): Promise { - try { - const content = await fetchOpenAiFileContent({ - openAi: params.openAi, - fileId: params.errorFileId, - }); - const lines = parseOpenAiBatchOutput(content); - return extractBatchErrorMessage(lines); - } catch (err) { - return formatUnavailableBatchError(err); - } -} - -async function waitForOpenAiBatch(params: { - openAi: OpenAiEmbeddingClient; - batchId: string; - wait: boolean; - pollIntervalMs: number; - timeoutMs: number; - debug?: (message: string, data?: Record) => void; - initial?: OpenAiBatchStatus; -}): Promise { - const start = Date.now(); - let current: OpenAiBatchStatus | undefined = params.initial; - while (true) { - const status = - current ?? - (await fetchOpenAiBatchStatus({ - openAi: params.openAi, - batchId: params.batchId, - })); - const state = status.status ?? 
"unknown"; - if (state === "completed") { - return resolveBatchCompletionFromStatus({ - provider: "openai", - batchId: params.batchId, - status, - }); - } - await throwIfBatchTerminalFailure({ - provider: "openai", - status: { ...status, id: params.batchId }, - readError: async (errorFileId) => - await readOpenAiBatchError({ - openAi: params.openAi, - errorFileId, - }), - }); - if (!params.wait) { - throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`); - } - if (Date.now() - start > params.timeoutMs) { - throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`); - } - params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`); - await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs)); - current = undefined; - } -} - -export async function runOpenAiEmbeddingBatches( - params: { - openAi: OpenAiEmbeddingClient; - agentId: string; - requests: OpenAiBatchRequest[]; - } & EmbeddingBatchExecutionParams, -): Promise> { - return await runEmbeddingBatchGroups({ - ...buildEmbeddingBatchGroupOptions(params, { - maxRequests: OPENAI_BATCH_MAX_REQUESTS, - debugLabel: "memory embeddings: openai batch submit", - }), - runGroup: async ({ group, groupIndex, groups, byCustomId }) => { - const batchInfo = await submitOpenAiBatch({ - openAi: params.openAi, - requests: group, - agentId: params.agentId, - }); - if (!batchInfo.id) { - throw new Error("openai batch create failed: missing batch id"); - } - const batchId = batchInfo.id; - - params.debug?.("memory embeddings: openai batch created", { - batchId: batchInfo.id, - status: batchInfo.status, - group: groupIndex + 1, - groups, - requests: group.length, - }); - - const completed = await resolveCompletedBatchResult({ - provider: "openai", - status: batchInfo, - wait: params.wait, - waitForBatch: async () => - await waitForOpenAiBatch({ - openAi: params.openAi, - batchId, - wait: params.wait, - pollIntervalMs: 
params.pollIntervalMs, - timeoutMs: params.timeoutMs, - debug: params.debug, - initial: batchInfo, - }), - }); - - const content = await fetchOpenAiFileContent({ - openAi: params.openAi, - fileId: completed.outputFileId, - }); - const outputLines = parseOpenAiBatchOutput(content); - const errors: string[] = []; - const remaining = new Set(group.map((request) => request.custom_id)); - - for (const line of outputLines) { - applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId }); - } - - if (errors.length > 0) { - throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`); - } - if (remaining.size > 0) { - throw new Error( - `openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`, - ); - } - }, - }); -} diff --git a/packages/memory-host-sdk/src/host/batch-voyage.test.ts b/packages/memory-host-sdk/src/host/batch-voyage.test.ts deleted file mode 100644 index 2fcdb9ec7c0..00000000000 --- a/packages/memory-host-sdk/src/host/batch-voyage.test.ts +++ /dev/null @@ -1,176 +0,0 @@ -import { ReadableStream } from "node:stream/web"; -import { setTimeout as nativeSleep } from "node:timers/promises"; -import { describe, expect, it, vi } from "vitest"; -import { - runVoyageEmbeddingBatches, - type VoyageBatchOutputLine, - type VoyageBatchRequest, -} from "./batch-voyage.js"; -import type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; - -const realNow = Date.now.bind(Date); - -describe("runVoyageEmbeddingBatches", () => { - const mockClient: VoyageEmbeddingClient = { - baseUrl: "https://api.voyageai.com/v1", - headers: { Authorization: "Bearer test-key" }, - model: "voyage-4-large", - }; - - const mockRequests: VoyageBatchRequest[] = [ - { custom_id: "req-1", body: { input: "text1" } }, - { custom_id: "req-2", body: { input: "text2" } }, - ]; - - it("successfully submits batch, waits, and streams results", async () => { - const outputLines: VoyageBatchOutputLine[] = [ - { - custom_id: "req-1", - response: { status_code: 
200, body: { data: [{ embedding: [0.1, 0.1] }] } }, - }, - { - custom_id: "req-2", - response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } }, - }, - ]; - const withRemoteHttpResponse = vi.fn(); - const postJsonWithRetry = vi.fn(); - const uploadBatchJsonlFile = vi.fn(); - - // Create a stream that emits the NDJSON lines - const stream = new ReadableStream({ - start(controller) { - const text = outputLines.map((l) => JSON.stringify(l)).join("\n"); - controller.enqueue(new TextEncoder().encode(text)); - controller.close(); - }, - }); - uploadBatchJsonlFile.mockImplementationOnce(async (params) => { - expect(params.errorPrefix).toBe("voyage batch file upload failed"); - expect(params.requests).toEqual(mockRequests); - return "file-123"; - }); - postJsonWithRetry.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/batches"); - expect(params.body).toMatchObject({ - input_file_id: "file-123", - completion_window: "12h", - request_params: { - model: "voyage-4-large", - input_type: "document", - }, - }); - return { - id: "batch-abc", - status: "pending", - }; - }); - withRemoteHttpResponse.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/batches/batch-abc"); - return await params.onResponse( - new Response( - JSON.stringify({ - id: "batch-abc", - status: "completed", - output_file_id: "file-out-999", - }), - { - status: 200, - headers: { "Content-Type": "application/json" }, - }, - ), - ); - }); - withRemoteHttpResponse.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/files/file-out-999/content"); - return await params.onResponse( - new Response(stream as unknown as BodyInit, { - status: 200, - headers: { "Content-Type": "application/x-ndjson" }, - }), - ); - }); - - const results = await runVoyageEmbeddingBatches({ - client: mockClient, - agentId: "agent-1", - requests: mockRequests, - wait: true, - pollIntervalMs: 1, // fast poll - timeoutMs: 1000, - concurrency: 1, - deps: 
{ - now: realNow, - sleep: async (ms) => { - await nativeSleep(ms); - }, - postJsonWithRetry, - uploadBatchJsonlFile, - withRemoteHttpResponse, - }, - }); - - expect(results.size).toBe(2); - expect(results.get("req-1")).toEqual([0.1, 0.1]); - expect(results.get("req-2")).toEqual([0.2, 0.2]); - expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1); - expect(postJsonWithRetry).toHaveBeenCalledTimes(1); - expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2); - }); - - it("handles empty lines and stream chunks correctly", async () => { - const withRemoteHttpResponse = vi.fn(); - const postJsonWithRetry = vi.fn(); - const uploadBatchJsonlFile = vi.fn(); - const stream = new ReadableStream({ - start(controller) { - const line1 = JSON.stringify({ - custom_id: "req-1", - response: { body: { data: [{ embedding: [1] }] } }, - }); - const line2 = JSON.stringify({ - custom_id: "req-2", - response: { body: { data: [{ embedding: [2] }] } }, - }); - - // Split across chunks - controller.enqueue(new TextEncoder().encode(line1 + "\n")); - controller.enqueue(new TextEncoder().encode("\n")); // empty line - controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF - controller.close(); - }, - }); - uploadBatchJsonlFile.mockResolvedValueOnce("f1"); - postJsonWithRetry.mockResolvedValueOnce({ - id: "b1", - status: "completed", - output_file_id: "out1", - }); - withRemoteHttpResponse.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/files/out1/content"); - return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 })); - }); - - const results = await runVoyageEmbeddingBatches({ - client: mockClient, - agentId: "a1", - requests: mockRequests, - wait: true, - pollIntervalMs: 1, - timeoutMs: 1000, - concurrency: 1, - deps: { - now: realNow, - sleep: async (ms) => { - await nativeSleep(ms); - }, - postJsonWithRetry, - uploadBatchJsonlFile, - withRemoteHttpResponse, - }, - }); - - 
expect(results.get("req-1")).toEqual([1]); - expect(results.get("req-2")).toEqual([2]); - }); -}); diff --git a/packages/memory-host-sdk/src/host/embedding-model-limits.ts b/packages/memory-host-sdk/src/host/embedding-model-limits.ts index 201e6450416..714114f670e 100644 --- a/packages/memory-host-sdk/src/host/embedding-model-limits.ts +++ b/packages/memory-host-sdk/src/host/embedding-model-limits.ts @@ -1,40 +1,14 @@ -import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js"; import type { EmbeddingProvider } from "./embeddings.js"; const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192; const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048; -const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record = { - "openai:text-embedding-3-small": 8192, - "openai:text-embedding-3-large": 8192, - "openai:text-embedding-ada-002": 8191, - "gemini:text-embedding-004": 2048, - "gemini:gemini-embedding-001": 2048, - "gemini:gemini-embedding-2-preview": 8192, - "voyage:voyage-3": 32000, - "voyage:voyage-3-lite": 16000, - "voyage:voyage-code-3": 32000, -}; - export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number { if (typeof provider.maxInputTokens === "number") { return provider.maxInputTokens; } - // Provider/model mapping is best-effort; different providers use different - // limits and we prefer to be conservative when we don't know. - const key = normalizeLowercaseStringOrEmpty(`${provider.id}:${provider.model}`); - const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key]; - if (typeof known === "number") { - return known; - } - - // Provider-specific conservative fallbacks. This prevents us from accidentally - // using the OpenAI default for providers with much smaller limits. 
- if (normalizeLowercaseStringOrEmpty(provider.id) === "gemini") { - return 2048; - } - if (normalizeLowercaseStringOrEmpty(provider.id) === "local") { + if (provider.id === "local") { return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS; } diff --git a/packages/memory-host-sdk/src/host/embeddings-bedrock.test.ts b/packages/memory-host-sdk/src/host/embeddings-bedrock.test.ts deleted file mode 100644 index 71228daad5f..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-bedrock.test.ts +++ /dev/null @@ -1,377 +0,0 @@ -import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; - -const { defaultProviderMock, resolveCredentialsMock, sendMock } = vi.hoisted(() => ({ - defaultProviderMock: vi.fn(), - resolveCredentialsMock: vi.fn(), - sendMock: vi.fn(), -})); - -vi.mock("@aws-sdk/client-bedrock-runtime", () => { - class MockClient { - region: string; - constructor(config: { region: string }) { - this.region = config.region; - } - send = sendMock; - } - class MockCommand { - input: unknown; - constructor(input: unknown) { - this.input = input; - } - } - return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand }; -}); - -vi.mock("@aws-sdk/credential-provider-node", () => ({ - defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock), -})); - -let createBedrockEmbeddingProvider: typeof import("./embeddings-bedrock.js").createBedrockEmbeddingProvider; -let resolveBedrockEmbeddingClient: typeof import("./embeddings-bedrock.js").resolveBedrockEmbeddingClient; -let normalizeBedrockEmbeddingModel: typeof import("./embeddings-bedrock.js").normalizeBedrockEmbeddingModel; -let hasAwsCredentials: typeof import("./embeddings-bedrock.js").hasAwsCredentials; - -beforeAll(async () => { - ({ - createBedrockEmbeddingProvider, - resolveBedrockEmbeddingClient, - normalizeBedrockEmbeddingModel, - hasAwsCredentials, - } = await import("./embeddings-bedrock.js")); -}); - -beforeEach(() => { - 
defaultProviderMock.mockImplementation(() => resolveCredentialsMock); -}); - -const enc = (body: unknown) => ({ body: new TextEncoder().encode(JSON.stringify(body)) }); -const reqBody = (i = 0): Record => - JSON.parse(sendMock.mock.calls[i][0].input.body); - -describe("bedrock embedding provider", () => { - const originalEnv = process.env; - afterEach(() => { - process.env = originalEnv; - vi.restoreAllMocks(); - defaultProviderMock.mockClear(); - resolveCredentialsMock.mockReset(); - sendMock.mockReset(); - }); - - // --- Normalization --- - - it("normalizes model names with prefixes", () => { - expect(normalizeBedrockEmbeddingModel("bedrock/amazon.titan-embed-text-v2:0")).toBe( - "amazon.titan-embed-text-v2:0", - ); - expect(normalizeBedrockEmbeddingModel("amazon-bedrock/cohere.embed-english-v3")).toBe( - "cohere.embed-english-v3", - ); - expect(normalizeBedrockEmbeddingModel("")).toBe("amazon.titan-embed-text-v2:0"); - }); - - // --- Client resolution --- - - it("resolves region from env", () => { - process.env = { ...originalEnv, AWS_REGION: "eu-west-1" }; - const c = resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(c.region).toBe("eu-west-1"); - expect(c.dimensions).toBe(1024); - }); - - it("defaults to us-east-1", () => { - process.env = { ...originalEnv }; - delete process.env.AWS_REGION; - delete process.env.AWS_DEFAULT_REGION; - expect( - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }).region, - ).toBe("us-east-1"); - }); - - it("extracts region from baseUrl", () => { - process.env = { ...originalEnv }; - delete process.env.AWS_REGION; - const c = resolveBedrockEmbeddingClient({ - config: { - models: { - providers: { - "amazon-bedrock": { baseUrl: "https://bedrock-runtime.ap-southeast-2.amazonaws.com" }, - }, - }, - } as never, - provider: "bedrock", - model: 
"amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(c.region).toBe("ap-southeast-2"); - }); - - it("validates dimensions", () => { - expect(() => - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - outputDimensionality: 768, - }), - ).toThrow("Invalid dimensions 768"); - }); - - it("accepts valid dimensions", () => { - expect( - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - outputDimensionality: 256, - }).dimensions, - ).toBe(256); - }); - - it("resolves throughput-suffixed variants", () => { - expect( - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v1:2:8k", - fallback: "none", - }).dimensions, - ).toBe(1536); - }); - - // --- Credential detection --- - - it("detects access keys", async () => { - await expect( - hasAwsCredentials({ - AWS_ACCESS_KEY_ID: "A", - AWS_SECRET_ACCESS_KEY: "s", - } as NodeJS.ProcessEnv), - ).resolves.toBe(true); - }); - it("detects profile", async () => { - await expect(hasAwsCredentials({ AWS_PROFILE: "default" } as NodeJS.ProcessEnv)).resolves.toBe( - true, - ); - }); - it("detects ECS task role", async () => { - await expect( - hasAwsCredentials({ AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: "/v2" } as NodeJS.ProcessEnv), - ).resolves.toBe(true); - }); - it("detects EKS IRSA", async () => { - await expect( - hasAwsCredentials({ - AWS_WEB_IDENTITY_TOKEN_FILE: "/var/run/secrets/token", - AWS_ROLE_ARN: "arn:aws:iam::123:role/x", - } as NodeJS.ProcessEnv), - ).resolves.toBe(true); - }); - it("detects credentials via the AWS SDK default provider chain", async () => { - resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" }); - await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(true); - expect(defaultProviderMock).toHaveBeenCalledWith({ timeout: 
1000, maxRetries: 0 }); - }); - it("returns false with no creds", async () => { - resolveCredentialsMock.mockRejectedValue(new Error("no aws credentials")); - await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(false); - }); - - // --- Titan V2 --- - - it("embeds with Titan V2", async () => { - sendMock.mockResolvedValue(enc({ embedding: [0.1, 0.2, 0.3] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(await provider.embedQuery("test")).toHaveLength(3); - expect(reqBody()).toMatchObject({ inputText: "test", normalize: true, dimensions: 1024 }); - }); - - it("returns empty for blank text", async () => { - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(await provider.embedQuery(" ")).toEqual([]); - expect(sendMock).not.toHaveBeenCalled(); - }); - - it("batches Titan V2 concurrently", async () => { - sendMock - .mockResolvedValueOnce(enc({ embedding: [0.1] })) - .mockResolvedValueOnce(enc({ embedding: [0.2] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(await provider.embedBatch(["a", "b"])).toHaveLength(2); - expect(sendMock).toHaveBeenCalledTimes(2); - }); - - // --- Titan V1 --- - - it("sends only inputText for Titan V1", async () => { - sendMock.mockResolvedValue(enc({ embedding: [0.5] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v1", - fallback: "none", - }); - await provider.embedQuery("hi"); - expect(reqBody()).toEqual({ inputText: "hi" }); - }); - - it("handles Titan G1 text variant", async () => { - sendMock.mockResolvedValue(enc({ 
embedding: [0.1] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-g1-text-02", - fallback: "none", - }); - await provider.embedQuery("hi"); - expect(reqBody()).toEqual({ inputText: "hi" }); - }); - - // --- Cohere V3 --- - - it("embeds Cohere V3 batch in single call", async () => { - sendMock.mockResolvedValue(enc({ embeddings: [[0.1], [0.2]] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "cohere.embed-english-v3", - fallback: "none", - }); - expect(await provider.embedBatch(["a", "b"])).toHaveLength(2); - expect(sendMock).toHaveBeenCalledTimes(1); - expect(reqBody()).toMatchObject({ texts: ["a", "b"], input_type: "search_document" }); - }); - - it("uses search_query for Cohere embedQuery", async () => { - sendMock.mockResolvedValue(enc({ embeddings: [[0.1]] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "cohere.embed-english-v3", - fallback: "none", - }); - await provider.embedQuery("q"); - expect(reqBody().input_type).toBe("search_query"); - }); - - // --- Cohere V4 --- - - it("embeds Cohere V4 with embedding_types + output_dimension", async () => { - sendMock.mockResolvedValue(enc({ embeddings: { float: [[0.1], [0.2]] } })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "cohere.embed-v4:0", - fallback: "none", - }); - expect(await provider.embedBatch(["a", "b"])).toHaveLength(2); - expect(reqBody()).toMatchObject({ embedding_types: ["float"], output_dimension: 1536 }); - }); - - it("validates Cohere V4 dimensions", () => { - expect(() => - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "cohere.embed-v4:0", - fallback: "none", - outputDimensionality: 2048, - }), - ).toThrow("Invalid dimensions 2048"); - }); - - // 
--- Nova --- - - it("embeds Nova with SINGLE_EMBEDDING format", async () => { - sendMock.mockResolvedValue( - enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1, 0.2] }] }), - ); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.nova-2-multimodal-embeddings-v1:0", - fallback: "none", - }); - expect(await provider.embedQuery("hi")).toHaveLength(2); - expect(reqBody().taskType).toBe("SINGLE_EMBEDDING"); - }); - - it("validates Nova dimensions", () => { - expect(() => - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.nova-2-multimodal-embeddings-v1:0", - fallback: "none", - outputDimensionality: 512, - }), - ).toThrow("Invalid dimensions 512"); - }); - - it("batches Nova concurrently", async () => { - sendMock - .mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1] }] })) - .mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.2] }] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.nova-2-multimodal-embeddings-v1:0", - fallback: "none", - }); - expect(await provider.embedBatch(["a", "b"])).toHaveLength(2); - expect(sendMock).toHaveBeenCalledTimes(2); - }); - - // --- TwelveLabs --- - - it("embeds TwelveLabs Marengo", async () => { - sendMock.mockResolvedValue(enc({ data: [{ embedding: [0.1, 0.2] }] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "twelvelabs.marengo-embed-3-0-v1:0", - fallback: "none", - }); - expect(await provider.embedQuery("hi")).toHaveLength(2); - expect(reqBody()).toEqual({ inputType: "text", text: { inputText: "hi" } }); - }); - - it("embeds TwelveLabs object-style responses", async () => { - sendMock.mockResolvedValue(enc({ data: { embedding: [0.3, 0.4] } })); - const { provider } = await 
createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "twelvelabs.marengo-embed-2-7-v1:0", - fallback: "none", - }); - expect(await provider.embedQuery("hi")).toEqual([0.6, 0.8]); - }); -}); diff --git a/packages/memory-host-sdk/src/host/embeddings-bedrock.ts b/packages/memory-host-sdk/src/host/embeddings-bedrock.ts deleted file mode 100644 index ba05010f4d6..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-bedrock.ts +++ /dev/null @@ -1,398 +0,0 @@ -import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import { debugEmbeddingsLog } from "./embeddings-debug.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; - -// --------------------------------------------------------------------------- -// Types & constants -// --------------------------------------------------------------------------- - -export type BedrockEmbeddingClient = { - region: string; - model: string; - dimensions?: number; -}; - -export const DEFAULT_BEDROCK_EMBEDDING_MODEL = "amazon.titan-embed-text-v2:0"; - -/** Request/response format family — each has a different API shape. 
*/ -type Family = "titan-v1" | "titan-v2" | "cohere-v3" | "cohere-v4" | "nova" | "twelvelabs"; - -interface ModelSpec { - maxTokens: number; - dims: number; - validDims?: number[]; - family: Family; -} - -// --------------------------------------------------------------------------- -// Model catalog -// --------------------------------------------------------------------------- - -const MODELS: Record = { - "amazon.titan-embed-text-v2:0": { - maxTokens: 8192, - dims: 1024, - validDims: [256, 512, 1024], - family: "titan-v2", - }, - "amazon.titan-embed-text-v1": { maxTokens: 8000, dims: 1536, family: "titan-v1" }, - "amazon.titan-embed-g1-text-02": { maxTokens: 8000, dims: 1536, family: "titan-v1" }, - "amazon.titan-embed-image-v1": { maxTokens: 128, dims: 1024, family: "titan-v1" }, - "cohere.embed-english-v3": { maxTokens: 512, dims: 1024, family: "cohere-v3" }, - "cohere.embed-multilingual-v3": { maxTokens: 512, dims: 1024, family: "cohere-v3" }, - "cohere.embed-v4:0": { - maxTokens: 128000, - dims: 1536, - validDims: [256, 384, 512, 768, 1024, 1536], - family: "cohere-v4", - }, - "amazon.nova-2-multimodal-embeddings-v1:0": { - maxTokens: 8192, - dims: 1024, - validDims: [256, 384, 1024, 3072], - family: "nova", - }, - "twelvelabs.marengo-embed-2-7-v1:0": { maxTokens: 512, dims: 1024, family: "twelvelabs" }, - "twelvelabs.marengo-embed-3-0-v1:0": { maxTokens: 512, dims: 512, family: "twelvelabs" }, -}; - -/** Resolve spec, stripping throughput suffixes like `:2:8k` or `:0:512`. */ -function resolveSpec(modelId: string): ModelSpec | undefined { - if (MODELS[modelId]) { - return MODELS[modelId]; - } - const parts = modelId.split(":"); - for (let i = parts.length - 1; i >= 1; i--) { - const spec = MODELS[parts.slice(0, i).join(":")]; - if (spec) { - return spec; - } - } - return undefined; -} - -/** Infer family from model ID prefix when not in catalog. 
*/ -function inferFamily(modelId: string): Family { - const id = normalizeLowercaseStringOrEmpty(modelId); - if (id.startsWith("amazon.titan-embed-text-v2")) { - return "titan-v2"; - } - if (id.startsWith("amazon.titan-embed")) { - return "titan-v1"; - } - if (id.startsWith("amazon.nova")) { - return "nova"; - } - if (id.startsWith("cohere.embed-v4")) { - return "cohere-v4"; - } - if (id.startsWith("cohere.embed")) { - return "cohere-v3"; - } - if (id.startsWith("twelvelabs.")) { - return "twelvelabs"; - } - return "titan-v1"; // safest default — simplest request format -} - -// --------------------------------------------------------------------------- -// AWS SDK lazy loader -// --------------------------------------------------------------------------- - -type SdkClient = import("@aws-sdk/client-bedrock-runtime").BedrockRuntimeClient; -type SdkCommand = import("@aws-sdk/client-bedrock-runtime").InvokeModelCommand; - -interface AwsSdk { - BedrockRuntimeClient: new (config: { region: string }) => SdkClient; - InvokeModelCommand: new (input: { - modelId: string; - body: string; - contentType: string; - accept: string; - }) => SdkCommand; -} - -interface AwsCredentialProviderSdk { - defaultProvider: (init?: { timeout?: number; maxRetries?: number }) => () => Promise<{ - accessKeyId?: string; - }>; -} - -let sdkCache: AwsSdk | null = null; -let credentialProviderSdkCache: AwsCredentialProviderSdk | null | undefined; - -async function loadSdk(): Promise { - if (sdkCache) { - return sdkCache; - } - try { - sdkCache = (await import("@aws-sdk/client-bedrock-runtime")) as unknown as AwsSdk; - return sdkCache; - } catch { - throw new Error( - "No API key found for provider bedrock: @aws-sdk/client-bedrock-runtime is not installed. 
" + - "Install it with: npm install @aws-sdk/client-bedrock-runtime", - ); - } -} - -async function loadCredentialProviderSdk(): Promise { - if (credentialProviderSdkCache !== undefined) { - return credentialProviderSdkCache; - } - try { - credentialProviderSdkCache = - (await import("@aws-sdk/credential-provider-node")) as unknown as AwsCredentialProviderSdk; - } catch { - credentialProviderSdkCache = null; - } - return credentialProviderSdkCache; -} - -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - -const MODEL_PREFIX_RE = /^(?:bedrock|amazon-bedrock|aws)\//; -const REGION_RE = /bedrock-runtime\.([a-z0-9-]+)\./; - -export function normalizeBedrockEmbeddingModel(model: string): string { - const trimmed = model.trim(); - return trimmed ? trimmed.replace(MODEL_PREFIX_RE, "") : DEFAULT_BEDROCK_EMBEDDING_MODEL; -} - -function regionFromUrl(url: string | undefined): string | undefined { - return url?.trim() ? REGION_RE.exec(url)?.[1] : undefined; -} - -// --------------------------------------------------------------------------- -// Request builders -// --------------------------------------------------------------------------- - -function buildBody(family: Family, text: string, dims?: number): string { - switch (family) { - case "titan-v2": { - const b: Record = { inputText: text }; - if (dims != null) { - b.dimensions = dims; - b.normalize = true; - } - return JSON.stringify(b); - } - case "titan-v1": - return JSON.stringify({ inputText: text }); - case "nova": - return JSON.stringify({ - taskType: "SINGLE_EMBEDDING", - singleEmbeddingParams: { - embeddingPurpose: "GENERIC_INDEX", - embeddingDimension: dims ?? 
1024, - text: { truncationMode: "END", value: text }, - }, - }); - case "twelvelabs": - return JSON.stringify({ inputType: "text", text: { inputText: text } }); - default: - return JSON.stringify({ inputText: text }); - } -} - -function buildCohereBody( - family: Family, - texts: string[], - inputType: "search_query" | "search_document", - dims?: number, -): string { - const body: Record = { texts, input_type: inputType, truncate: "END" }; - if (family === "cohere-v4") { - body.embedding_types = ["float"]; - if (dims != null) { - body.output_dimension = dims; - } - } - return JSON.stringify(body); -} - -// --------------------------------------------------------------------------- -// Response parsers -// --------------------------------------------------------------------------- - -function parseSingle(family: Family, raw: string): number[] { - const data = JSON.parse(raw); - switch (family) { - case "nova": - return data.embeddings?.[0]?.embedding ?? []; - case "twelvelabs": { - if (Array.isArray(data.data)) { - return data.data[0]?.embedding ?? []; - } - if (Array.isArray(data.data?.embedding)) { - return data.data.embedding; - } - return data.embedding ?? []; - } - default: - return data.embedding ?? []; - } -} - -function parseCohereBatch(family: Family, raw: string): number[][] { - const data = JSON.parse(raw); - const embeddings = data.embeddings; - if (!embeddings) { - return []; - } - if (family === "cohere-v4" && !Array.isArray(embeddings)) { - return embeddings.float ?? 
[]; - } - return embeddings; -} - -// --------------------------------------------------------------------------- -// Provider -// --------------------------------------------------------------------------- - -export async function createBedrockEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> { - const client = resolveBedrockEmbeddingClient(options); - const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk(); - const sdk = new BedrockRuntimeClient({ region: client.region }); - const spec = resolveSpec(client.model); - const family = spec?.family ?? inferFamily(client.model); - - debugEmbeddingsLog("memory embeddings: bedrock client", { - region: client.region, - model: client.model, - dimensions: client.dimensions, - family, - }); - - const invoke = async (body: string): Promise => { - const res = await sdk.send( - new InvokeModelCommand({ - modelId: client.model, - body, - contentType: "application/json", - accept: "application/json", - }), - ); - return new TextDecoder().decode(res.body); - }; - - const isCohere = family === "cohere-v3" || family === "cohere-v4"; - - const embedSingle = async (text: string): Promise => { - const raw = await invoke(buildBody(family, text, client.dimensions)); - return sanitizeAndNormalizeEmbedding(parseSingle(family, raw)); - }; - - const embedCohere = async ( - texts: string[], - inputType: "search_query" | "search_document", - ): Promise => { - const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions)); - return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e)); - }; - - const embedQuery = async (text: string): Promise => { - if (!text.trim()) { - return []; - } - if (isCohere) { - return (await embedCohere([text], "search_query"))[0] ?? 
[]; - } - return embedSingle(text); - }; - - const embedBatch = async (texts: string[]): Promise => { - if (texts.length === 0) { - return []; - } - if (isCohere) { - return embedCohere(texts, "search_document"); - } - return Promise.all(texts.map((t) => (t.trim() ? embedSingle(t) : Promise.resolve([])))); - }; - - return { - provider: { - id: "bedrock", - model: client.model, - maxInputTokens: spec?.maxTokens, - embedQuery, - embedBatch, - }, - client, - }; -} - -// --------------------------------------------------------------------------- -// Client resolution -// --------------------------------------------------------------------------- - -export function resolveBedrockEmbeddingClient( - options: EmbeddingProviderOptions, -): BedrockEmbeddingClient { - const model = normalizeBedrockEmbeddingModel(options.model); - const spec = resolveSpec(model); - const providerConfig = options.config.models?.providers?.["amazon-bedrock"]; - - const region = - regionFromUrl(options.remote?.baseUrl) ?? - regionFromUrl(providerConfig?.baseUrl) ?? - process.env.AWS_REGION ?? - process.env.AWS_DEFAULT_REGION ?? - "us-east-1"; - - let dimensions: number | undefined; - if (options.outputDimensionality != null) { - if (spec?.validDims && !spec.validDims.includes(options.outputDimensionality)) { - throw new Error( - `Invalid dimensions ${options.outputDimensionality} for ${model}. 
Valid values: ${spec.validDims.join(", ")}`, - ); - } - dimensions = options.outputDimensionality; - } else { - dimensions = spec?.dims; - } - - return { region, model, dimensions }; -} - -// --------------------------------------------------------------------------- -// Credential detection -// --------------------------------------------------------------------------- - -const CREDENTIAL_ENV_VARS = [ - "AWS_PROFILE", - "AWS_BEARER_TOKEN_BEDROCK", - "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", - "AWS_CONTAINER_CREDENTIALS_FULL_URI", - "AWS_EC2_METADATA_SERVICE_ENDPOINT", - "AWS_WEB_IDENTITY_TOKEN_FILE", - "AWS_ROLE_ARN", -] as const; - -export async function hasAwsCredentials(env: NodeJS.ProcessEnv = process.env): Promise { - if (env.AWS_ACCESS_KEY_ID?.trim() && env.AWS_SECRET_ACCESS_KEY?.trim()) { - return true; - } - if (CREDENTIAL_ENV_VARS.some((k) => env[k]?.trim())) { - return true; - } - const credentialProviderSdk = await loadCredentialProviderSdk(); - if (!credentialProviderSdk) { - return false; - } - try { - const credentials = await credentialProviderSdk.defaultProvider({ - timeout: 1000, - maxRetries: 0, - })(); - return typeof credentials.accessKeyId === "string" && credentials.accessKeyId.trim().length > 0; - } catch { - return false; - } -} diff --git a/packages/memory-host-sdk/src/host/embeddings-gemini-request.ts b/packages/memory-host-sdk/src/host/embeddings-gemini-request.ts deleted file mode 100644 index 887376bbad3..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-gemini-request.ts +++ /dev/null @@ -1,121 +0,0 @@ -import type { EmbeddingInput } from "./embedding-inputs.js"; - -export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001"; - -export const GEMINI_EMBEDDING_2_MODELS = new Set([ - "gemini-embedding-2-preview", - // Add the GA model name here once released. 
-]); - -const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072; -const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const; - -export type GeminiTaskType = - | "RETRIEVAL_QUERY" - | "RETRIEVAL_DOCUMENT" - | "SEMANTIC_SIMILARITY" - | "CLASSIFICATION" - | "CLUSTERING" - | "QUESTION_ANSWERING" - | "FACT_VERIFICATION"; - -export type GeminiTextPart = { text: string }; -export type GeminiInlinePart = { - inlineData: { mimeType: string; data: string }; -}; -export type GeminiPart = GeminiTextPart | GeminiInlinePart; -export type GeminiEmbeddingRequest = { - content: { parts: GeminiPart[] }; - taskType: GeminiTaskType; - outputDimensionality?: number; - model?: string; -}; -export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest; - -/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */ -export function buildGeminiTextEmbeddingRequest(params: { - text: string; - taskType: GeminiTaskType; - outputDimensionality?: number; - modelPath?: string; -}): GeminiTextEmbeddingRequest { - return buildGeminiEmbeddingRequest({ - input: { text: params.text }, - taskType: params.taskType, - outputDimensionality: params.outputDimensionality, - modelPath: params.modelPath, - }); -} - -export function buildGeminiEmbeddingRequest(params: { - input: EmbeddingInput; - taskType: GeminiTaskType; - outputDimensionality?: number; - modelPath?: string; -}): GeminiEmbeddingRequest { - const request: GeminiEmbeddingRequest = { - content: { - parts: params.input.parts?.map((part) => - part.type === "text" - ? ({ text: part.text } satisfies GeminiTextPart) - : ({ - inlineData: { mimeType: part.mimeType, data: part.data }, - } satisfies GeminiInlinePart), - ) ?? 
[{ text: params.input.text }], - }, - taskType: params.taskType, - }; - if (params.modelPath) { - request.model = params.modelPath; - } - if (params.outputDimensionality != null) { - request.outputDimensionality = params.outputDimensionality; - } - return request; -} - -/** - * Returns true if the given model name is a gemini-embedding-2 variant that - * supports `outputDimensionality` and extended task types. - */ -export function isGeminiEmbedding2Model(model: string): boolean { - return GEMINI_EMBEDDING_2_MODELS.has(model); -} - -/** - * Validate and return the `outputDimensionality` for gemini-embedding-2 models. - * Returns `undefined` for older models (they don't support the param). - */ -export function resolveGeminiOutputDimensionality( - model: string, - requested?: number, -): number | undefined { - if (!isGeminiEmbedding2Model(model)) { - return undefined; - } - if (requested == null) { - return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS; - } - const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS; - if (!valid.includes(requested)) { - throw new Error( - `Invalid outputDimensionality ${requested} for ${model}. 
Valid values: ${valid.join(", ")}`, - ); - } - return requested; -} - -export function normalizeGeminiModel(model: string): string { - const trimmed = model.trim(); - if (!trimmed) { - return DEFAULT_GEMINI_EMBEDDING_MODEL; - } - const withoutPrefix = trimmed.replace(/^models\//, ""); - if (withoutPrefix.startsWith("gemini/")) { - return withoutPrefix.slice("gemini/".length); - } - if (withoutPrefix.startsWith("google/")) { - return withoutPrefix.slice("google/".length); - } - return withoutPrefix; -} diff --git a/packages/memory-host-sdk/src/host/embeddings-gemini.test.ts b/packages/memory-host-sdk/src/host/embeddings-gemini.test.ts deleted file mode 100644 index 06c804c0bf1..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-gemini.test.ts +++ /dev/null @@ -1,52 +0,0 @@ -import { describe, expect, it } from "vitest"; -import { - buildGeminiEmbeddingRequest, - DEFAULT_GEMINI_EMBEDDING_MODEL, - normalizeGeminiModel, - resolveGeminiOutputDimensionality, -} from "./embeddings-gemini-request.js"; - -describe("package Gemini embedding request helpers", () => { - it("builds multimodal v2 requests and resolves model settings", () => { - expect( - buildGeminiEmbeddingRequest({ - input: { - text: "Image file: diagram.png", - parts: [ - { type: "text", text: "Image file: diagram.png" }, - { type: "inline-data", mimeType: "image/png", data: "abc123" }, - ], - }, - taskType: "RETRIEVAL_DOCUMENT", - modelPath: "models/gemini-embedding-2-preview", - outputDimensionality: 1536, - }), - ).toEqual({ - model: "models/gemini-embedding-2-preview", - content: { - parts: [ - { text: "Image file: diagram.png" }, - { inlineData: { mimeType: "image/png", data: "abc123" } }, - ], - }, - taskType: "RETRIEVAL_DOCUMENT", - outputDimensionality: 1536, - }); - expect(resolveGeminiOutputDimensionality("gemini-embedding-001")).toBeUndefined(); - expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview")).toBe(3072); - 
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768)).toBe(768); - expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512)).toThrow( - /Invalid outputDimensionality 512/, - ); - expect(normalizeGeminiModel("models/gemini-embedding-2-preview")).toBe( - "gemini-embedding-2-preview", - ); - expect(normalizeGeminiModel("gemini/gemini-embedding-2-preview")).toBe( - "gemini-embedding-2-preview", - ); - expect(normalizeGeminiModel("google/gemini-embedding-2-preview")).toBe( - "gemini-embedding-2-preview", - ); - expect(normalizeGeminiModel("")).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL); - }); -}); diff --git a/packages/memory-host-sdk/src/host/embeddings-gemini.ts b/packages/memory-host-sdk/src/host/embeddings-gemini.ts deleted file mode 100644 index ed6bdaca899..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-gemini.ts +++ /dev/null @@ -1,238 +0,0 @@ -import { - collectProviderApiKeysForExecution, - executeWithApiKeyRotation, -} from "../../../../src/agents/api-key-rotation.js"; -import { requireApiKey, resolveApiKeyForProvider } from "../../../../src/agents/model-auth.js"; -import { parseGeminiAuth } from "../../../../src/infra/gemini-auth.js"; -import { - DEFAULT_GOOGLE_API_BASE_URL, - normalizeGoogleApiBaseUrl, -} from "../../../../src/infra/google-api-base-url.js"; -import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js"; -import type { EmbeddingInput } from "./embedding-inputs.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import { debugEmbeddingsLog } from "./embeddings-debug.js"; -import { - buildGeminiEmbeddingRequest, - buildGeminiTextEmbeddingRequest, - isGeminiEmbedding2Model, - normalizeGeminiModel, - resolveGeminiOutputDimensionality, -} from "./embeddings-gemini-request.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; -import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js"; -import { 
resolveMemorySecretInputString } from "./secret-input.js"; - -export { - buildGeminiEmbeddingRequest, - buildGeminiTextEmbeddingRequest, - DEFAULT_GEMINI_EMBEDDING_MODEL, - GEMINI_EMBEDDING_2_MODELS, - isGeminiEmbedding2Model, - normalizeGeminiModel, - resolveGeminiOutputDimensionality, - type GeminiEmbeddingRequest, - type GeminiInlinePart, - type GeminiPart, - type GeminiTaskType, - type GeminiTextEmbeddingRequest, - type GeminiTextPart, -} from "./embeddings-gemini-request.js"; - -export type GeminiEmbeddingClient = { - baseUrl: string; - headers: Record; - ssrfPolicy?: SsrFPolicy; - model: string; - modelPath: string; - apiKeys: string[]; - outputDimensionality?: number; -}; - -const GEMINI_MAX_INPUT_TOKENS: Record = { - "text-embedding-004": 2048, -}; -function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined { - const trimmed = resolveMemorySecretInputString({ - value: remoteApiKey, - path: "agents.*.memorySearch.remote.apiKey", - }); - if (!trimmed) { - return undefined; - } - if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") { - return process.env[trimmed]?.trim(); - } - return trimmed; -} - -async function fetchGeminiEmbeddingPayload(params: { - client: GeminiEmbeddingClient; - endpoint: string; - body: unknown; -}): Promise<{ - embedding?: { values?: number[] }; - embeddings?: Array<{ values?: number[] }>; -}> { - return await executeWithApiKeyRotation({ - provider: "google", - apiKeys: params.client.apiKeys, - execute: async (apiKey) => { - const authHeaders = parseGeminiAuth(apiKey); - const headers = { - ...authHeaders.headers, - ...params.client.headers, - }; - return await withRemoteHttpResponse({ - url: params.endpoint, - ssrfPolicy: params.client.ssrfPolicy, - init: { - method: "POST", - headers, - body: JSON.stringify(params.body), - }, - onResponse: async (res) => { - if (!res.ok) { - const text = await res.text(); - throw new Error(`gemini embeddings failed: ${res.status} ${text}`); - } - return (await res.json()) 
as { - embedding?: { values?: number[] }; - embeddings?: Array<{ values?: number[] }>; - }; - }, - }); - }, - }); -} - -function normalizeGeminiBaseUrl(raw: string): string { - const trimmed = raw.replace(/\/+$/, ""); - const openAiIndex = trimmed.indexOf("/openai"); - if (openAiIndex > -1) { - return normalizeGoogleApiBaseUrl(trimmed.slice(0, openAiIndex)); - } - return normalizeGoogleApiBaseUrl(trimmed); -} - -function buildGeminiModelPath(model: string): string { - return model.startsWith("models/") ? model : `models/${model}`; -} - -export async function createGeminiEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> { - const client = await resolveGeminiEmbeddingClient(options); - const baseUrl = client.baseUrl.replace(/\/$/, ""); - const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`; - const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`; - const isV2 = isGeminiEmbedding2Model(client.model); - const outputDimensionality = client.outputDimensionality; - - const embedQuery = async (text: string): Promise => { - if (!text.trim()) { - return []; - } - const payload = await fetchGeminiEmbeddingPayload({ - client, - endpoint: embedUrl, - body: buildGeminiTextEmbeddingRequest({ - text, - taskType: options.taskType ?? "RETRIEVAL_QUERY", - outputDimensionality: isV2 ? outputDimensionality : undefined, - }), - }); - return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []); - }; - - const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise => { - if (inputs.length === 0) { - return []; - } - const payload = await fetchGeminiEmbeddingPayload({ - client, - endpoint: batchUrl, - body: { - requests: inputs.map((input) => - buildGeminiEmbeddingRequest({ - input, - modelPath: client.modelPath, - taskType: options.taskType ?? "RETRIEVAL_DOCUMENT", - outputDimensionality: isV2 ? 
outputDimensionality : undefined, - }), - ), - }, - }); - const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : []; - return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? [])); - }; - - const embedBatch = async (texts: string[]): Promise => { - return await embedBatchInputs( - texts.map((text) => ({ - text, - })), - ); - }; - - return { - provider: { - id: "gemini", - model: client.model, - maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model], - embedQuery, - embedBatch, - embedBatchInputs, - }, - client, - }; -} - -export async function resolveGeminiEmbeddingClient( - options: EmbeddingProviderOptions, -): Promise { - const remote = options.remote; - const remoteApiKey = resolveRemoteApiKey(remote?.apiKey); - const remoteBaseUrl = remote?.baseUrl?.trim(); - - const apiKey = remoteApiKey - ? remoteApiKey - : requireApiKey( - await resolveApiKeyForProvider({ - provider: "google", - cfg: options.config, - agentDir: options.agentDir, - }), - "google", - ); - - const providerConfig = options.config.models?.providers?.google; - const rawBaseUrl = - remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GOOGLE_API_BASE_URL; - const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl); - const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl); - const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers); - const headers: Record = { - ...headerOverrides, - }; - const apiKeys = collectProviderApiKeysForExecution({ - provider: "google", - primaryApiKey: apiKey, - }); - const model = normalizeGeminiModel(options.model); - const modelPath = buildGeminiModelPath(model); - const outputDimensionality = resolveGeminiOutputDimensionality( - model, - options.outputDimensionality, - ); - debugEmbeddingsLog("memory embeddings: gemini client", { - rawBaseUrl, - baseUrl, - model, - modelPath, - outputDimensionality, - embedEndpoint: `${baseUrl}/${modelPath}:embedContent`, - batchEndpoint: 
`${baseUrl}/${modelPath}:batchEmbedContents`, - }); - return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality }; -} diff --git a/packages/memory-host-sdk/src/host/embeddings-lmstudio.ts b/packages/memory-host-sdk/src/host/embeddings-lmstudio.ts deleted file mode 100644 index 99cc6475868..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-lmstudio.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "../../../../src/memory-host-sdk/host/embeddings-lmstudio.js"; diff --git a/packages/memory-host-sdk/src/host/embeddings-mistral.test.ts b/packages/memory-host-sdk/src/host/embeddings-mistral.test.ts deleted file mode 100644 index 7826cd35467..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-mistral.test.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { describe, expect, it } from "vitest"; -import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js"; - -describe("normalizeMistralModel", () => { - it("returns the default model for empty values", () => { - expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL); - expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL); - }); - - it("strips the mistral/ prefix", () => { - expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed"); - expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed"); - }); - - it("keeps explicit non-prefixed models", () => { - expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed"); - expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2"); - }); -}); diff --git a/packages/memory-host-sdk/src/host/embeddings-mistral.ts b/packages/memory-host-sdk/src/host/embeddings-mistral.ts deleted file mode 100644 index cde20d92556..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-mistral.ts +++ /dev/null @@ -1,51 +0,0 @@ -import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js"; -import { 
normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; -import { - createRemoteEmbeddingProvider, - resolveRemoteEmbeddingClient, -} from "./embeddings-remote-provider.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; - -export type MistralEmbeddingClient = { - baseUrl: string; - headers: Record; - ssrfPolicy?: SsrFPolicy; - model: string; -}; - -export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed"; -const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1"; - -export function normalizeMistralModel(model: string): string { - return normalizeEmbeddingModelWithPrefixes({ - model, - defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL, - prefixes: ["mistral/"], - }); -} - -export async function createMistralEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> { - const client = await resolveMistralEmbeddingClient(options); - - return { - provider: createRemoteEmbeddingProvider({ - id: "mistral", - client, - errorPrefix: "mistral embeddings failed", - }), - client, - }; -} - -export async function resolveMistralEmbeddingClient( - options: EmbeddingProviderOptions, -): Promise { - return await resolveRemoteEmbeddingClient({ - provider: "mistral", - options, - defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL, - normalizeModel: normalizeMistralModel, - }); -} diff --git a/packages/memory-host-sdk/src/host/embeddings-ollama.test.ts b/packages/memory-host-sdk/src/host/embeddings-ollama.test.ts deleted file mode 100644 index 77212567018..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-ollama.test.ts +++ /dev/null @@ -1,43 +0,0 @@ -import { beforeEach, describe, expect, it, vi } from "vitest"; - -const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({ - createOllamaEmbeddingProviderMock: vi.fn(async (options: unknown) => ({ - provider: { source: "mock-provider", options }, - client: { source: "mock-client" }, - 
})), -})); - -vi.mock("../../../../src/plugin-sdk/ollama-runtime.js", () => ({ - DEFAULT_OLLAMA_EMBEDDING_MODEL: "nomic-embed-text", - createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock, -})); - -describe("memory-host-sdk Ollama embedding facade", () => { - beforeEach(() => { - createOllamaEmbeddingProviderMock.mockClear(); - }); - - it("re-exports the default Ollama embedding model", async () => { - const mod = await import("./embeddings-ollama.js"); - expect(mod.DEFAULT_OLLAMA_EMBEDDING_MODEL).toBe("nomic-embed-text"); - }); - - it("delegates provider creation to the plugin-sdk runtime facade", async () => { - const mod = await import("./embeddings-ollama.js"); - const options = { - provider: "ollama", - model: "nomic-embed-text", - fallback: "none", - config: {}, - }; - - const result = await mod.createOllamaEmbeddingProvider(options as never); - - expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledTimes(1); - expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledWith(options); - expect(result).toEqual({ - provider: { source: "mock-provider", options }, - client: { source: "mock-client" }, - }); - }); -}); diff --git a/packages/memory-host-sdk/src/host/embeddings-ollama.ts b/packages/memory-host-sdk/src/host/embeddings-ollama.ts deleted file mode 100644 index 23453bd7c53..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-ollama.ts +++ /dev/null @@ -1,5 +0,0 @@ -export type { OllamaEmbeddingClient } from "../../../../src/plugin-sdk/ollama-runtime.js"; -export { - createOllamaEmbeddingProvider, - DEFAULT_OLLAMA_EMBEDDING_MODEL, -} from "../../../../src/plugin-sdk/ollama-runtime.js"; diff --git a/packages/memory-host-sdk/src/host/embeddings-openai.ts b/packages/memory-host-sdk/src/host/embeddings-openai.ts deleted file mode 100644 index 4d045c863fd..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-openai.ts +++ /dev/null @@ -1,58 +0,0 @@ -import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js"; -import 
{ OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../../../src/plugins/provider-model-defaults.js"; -import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; -import { - createRemoteEmbeddingProvider, - resolveRemoteEmbeddingClient, -} from "./embeddings-remote-provider.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; - -export type OpenAiEmbeddingClient = { - baseUrl: string; - headers: Record; - ssrfPolicy?: SsrFPolicy; - model: string; -}; - -const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"; -export const DEFAULT_OPENAI_EMBEDDING_MODEL = OPENAI_DEFAULT_EMBEDDING_MODEL; -const OPENAI_MAX_INPUT_TOKENS: Record = { - "text-embedding-3-small": 8192, - "text-embedding-3-large": 8192, - "text-embedding-ada-002": 8191, -}; - -export function normalizeOpenAiModel(model: string): string { - return normalizeEmbeddingModelWithPrefixes({ - model, - defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL, - prefixes: ["openai/"], - }); -} - -export async function createOpenAiEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> { - const client = await resolveOpenAiEmbeddingClient(options); - - return { - provider: createRemoteEmbeddingProvider({ - id: "openai", - client, - errorPrefix: "openai embeddings failed", - maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model], - }), - client, - }; -} - -export async function resolveOpenAiEmbeddingClient( - options: EmbeddingProviderOptions, -): Promise { - return await resolveRemoteEmbeddingClient({ - provider: "openai", - options, - defaultBaseUrl: DEFAULT_OPENAI_BASE_URL, - normalizeModel: normalizeOpenAiModel, - }); -} diff --git a/packages/memory-host-sdk/src/host/embeddings-remote-client.ts b/packages/memory-host-sdk/src/host/embeddings-remote-client.ts index 01316cbd946..f5fd0b79920 100644 --- a/packages/memory-host-sdk/src/host/embeddings-remote-client.ts +++ 
b/packages/memory-host-sdk/src/host/embeddings-remote-client.ts @@ -4,7 +4,7 @@ import type { EmbeddingProviderOptions } from "./embeddings.js"; import { buildRemoteBaseUrlPolicy } from "./remote-http.js"; import { resolveMemorySecretInputString } from "./secret-input.js"; -export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral"; +export type RemoteEmbeddingProviderId = string; export async function resolveRemoteEmbeddingBearerClient(params: { provider: RemoteEmbeddingProviderId; diff --git a/packages/memory-host-sdk/src/host/embeddings-voyage.test.ts b/packages/memory-host-sdk/src/host/embeddings-voyage.test.ts deleted file mode 100644 index 2615b8cdb5e..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-voyage.test.ts +++ /dev/null @@ -1,188 +0,0 @@ -import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; -import * as authModule from "../../../../src/agents/model-auth.js"; -import { type FetchMock, withFetchPreconnect } from "../../../../src/test-utils/fetch-mock.js"; -import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; - -vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({ - fetchWithSsrFGuard: async (params: { - url: string; - init?: RequestInit; - fetchImpl?: typeof fetch; - }) => { - const fetchImpl = params.fetchImpl ?? 
globalThis.fetch; - if (!fetchImpl) { - throw new Error("fetch is not available"); - } - const response = await fetchImpl(params.url, params.init); - return { - response, - finalUrl: params.url, - release: async () => {}, - }; - }, -})); - -const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(), -})); - -vi.mock("../../../../src/agents/model-auth.js", () => { - return { - resolveApiKeyForProvider: resolveApiKeyForProviderMock, - requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => { - if (auth.apiKey) { - return auth.apiKey; - } - throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`); - }, - }; -}); - -const createFetchMock = () => { - const fetchMock = vi.fn( - async (_input: RequestInfo | URL, _init?: RequestInit) => - new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), { - status: 200, - headers: { "Content-Type": "application/json" }, - }), - ); - return withFetchPreconnect(fetchMock); -}; - -function installFetchMock(fetchMock: typeof globalThis.fetch) { - vi.stubGlobal("fetch", fetchMock); -} - -let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider; -let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel; - -beforeAll(async () => { - ({ createVoyageEmbeddingProvider, normalizeVoyageModel } = - await import("./embeddings-voyage.js")); -}); - -beforeEach(() => { - vi.useRealTimers(); - vi.doUnmock("undici"); -}); - -function mockVoyageApiKey() { - vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ - apiKey: "voyage-key-123", - mode: "api-key", - source: "test", - }); -} - -async function createDefaultVoyageProvider( - model: string, - fetchMock: ReturnType, -) { - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - mockVoyageApiKey(); - return createVoyageEmbeddingProvider({ - config: {} 
as never, - provider: "voyage", - model, - fallback: "none", - }); -} - -describe("voyage embedding provider", () => { - afterEach(() => { - vi.doUnmock("undici"); - vi.resetAllMocks(); - vi.unstubAllGlobals(); - }); - - it("configures client with correct defaults and headers", async () => { - const fetchMock = createFetchMock(); - const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock); - - await result.provider.embedQuery("test query"); - - expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith( - expect.objectContaining({ provider: "voyage" }), - ); - - const call = fetchMock.mock.calls[0]; - expect(call).toBeDefined(); - const [url, init] = call as [RequestInfo | URL, RequestInit | undefined]; - expect(url).toBe("https://api.voyageai.com/v1/embeddings"); - - const headers = (init?.headers ?? {}) as Record; - expect(headers.Authorization).toBe("Bearer voyage-key-123"); - expect(headers["Content-Type"]).toBe("application/json"); - - const body = JSON.parse(init?.body as string); - expect(body).toEqual({ - model: "voyage-4-large", - input: ["test query"], - input_type: "query", - }); - }); - - it("respects remote overrides for baseUrl and apiKey", async () => { - const fetchMock = createFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - - const result = await createVoyageEmbeddingProvider({ - config: {} as never, - provider: "voyage", - model: "voyage-4-lite", - fallback: "none", - remote: { - baseUrl: "https://example.com", - apiKey: "remote-override-key", - headers: { "X-Custom": "123" }, - }, - }); - - await result.provider.embedQuery("test"); - - const call = fetchMock.mock.calls[0]; - expect(call).toBeDefined(); - const [url, init] = call as [RequestInfo | URL, RequestInit | undefined]; - expect(url).toBe("https://example.com/embeddings"); - - const headers = (init?.headers ?? 
{}) as Record; - expect(headers.Authorization).toBe("Bearer remote-override-key"); - expect(headers["X-Custom"]).toBe("123"); - }); - - it("passes input_type=document for embedBatch", async () => { - const fetchMock = withFetchPreconnect( - vi.fn( - async (_input: RequestInfo | URL, _init?: RequestInit) => - new Response( - JSON.stringify({ - data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }], - }), - { status: 200, headers: { "Content-Type": "application/json" } }, - ), - ), - ); - const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock); - - await result.provider.embedBatch(["doc1", "doc2"]); - - const call = fetchMock.mock.calls[0]; - expect(call).toBeDefined(); - const [, init] = call as [RequestInfo | URL, RequestInit | undefined]; - const body = JSON.parse(init?.body as string); - expect(body).toEqual({ - model: "voyage-4-large", - input: ["doc1", "doc2"], - input_type: "document", - }); - }); - - it("normalizes model names", async () => { - expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2"); - expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large"); - expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite"); - expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default - }); -}); diff --git a/packages/memory-host-sdk/src/host/embeddings-voyage.ts b/packages/memory-host-sdk/src/host/embeddings-voyage.ts deleted file mode 100644 index f46614e5566..00000000000 --- a/packages/memory-host-sdk/src/host/embeddings-voyage.ts +++ /dev/null @@ -1,82 +0,0 @@ -import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js"; -import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; -import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; -import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; -import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; - -export type 
VoyageEmbeddingClient = { - baseUrl: string; - headers: Record; - ssrfPolicy?: SsrFPolicy; - model: string; -}; - -export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large"; -const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1"; -const VOYAGE_MAX_INPUT_TOKENS: Record = { - "voyage-3": 32000, - "voyage-3-lite": 16000, - "voyage-code-3": 32000, -}; - -export function normalizeVoyageModel(model: string): string { - return normalizeEmbeddingModelWithPrefixes({ - model, - defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL, - prefixes: ["voyage/"], - }); -} - -export async function createVoyageEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> { - const client = await resolveVoyageEmbeddingClient(options); - const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; - - const embed = async (input: string[], input_type?: "query" | "document"): Promise => { - if (input.length === 0) { - return []; - } - const body: { model: string; input: string[]; input_type?: "query" | "document" } = { - model: client.model, - input, - }; - if (input_type) { - body.input_type = input_type; - } - - return await fetchRemoteEmbeddingVectors({ - url, - headers: client.headers, - ssrfPolicy: client.ssrfPolicy, - body, - errorPrefix: "voyage embeddings failed", - }); - }; - - return { - provider: { - id: "voyage", - model: client.model, - maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model], - embedQuery: async (text) => { - const [vec] = await embed([text], "query"); - return vec ?? 
[]; - }, - embedBatch: async (texts) => embed(texts, "document"), - }, - client, - }; -} - -export async function resolveVoyageEmbeddingClient( - options: EmbeddingProviderOptions, -): Promise { - const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ - provider: "voyage", - options, - defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL, - }); - const model = normalizeVoyageModel(options.model); - return { baseUrl, headers, ssrfPolicy, model }; -} diff --git a/packages/memory-host-sdk/src/host/embeddings.test.ts b/packages/memory-host-sdk/src/host/embeddings.test.ts index fe1d5abdac3..f1e33e6acad 100644 --- a/packages/memory-host-sdk/src/host/embeddings.test.ts +++ b/packages/memory-host-sdk/src/host/embeddings.test.ts @@ -1,199 +1,8 @@ -import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; -import * as authModule from "../../../../src/agents/model-auth.js"; -import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js"; -import * as nodeLlamaModule from "./node-llama.js"; -import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; +import { describe, expect, it } from "vitest"; +import { DEFAULT_LOCAL_MODEL } from "./embeddings.js"; -const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(), -})); - -vi.mock("../../../../src/agents/model-auth.js", () => { - return { - resolveApiKeyForProvider: resolveApiKeyForProviderMock, - requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => { - if (auth.apiKey) { - return auth.apiKey; - } - throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`); - }, - }; -}); - -vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({ - fetchWithSsrFGuard: async (params: { - url: string; - init?: RequestInit; - fetchImpl?: typeof fetch; - }) => { - const fetchImpl = params.fetchImpl ?? 
globalThis.fetch; - if (!fetchImpl) { - throw new Error("fetch is not available"); - } - const response = await fetchImpl(params.url, params.init); - return { - response, - finalUrl: params.url, - release: async () => {}, - }; - }, -})); - -const createEmbeddingDataFetchMock = () => - vi.fn(async (_input?: unknown, _init?: unknown) => ({ - ok: true, - status: 200, - json: async () => ({ data: [{ embedding: [1, 2, 3] }] }), - })); - -const createGeminiFetchMock = () => - vi.fn(async (_input?: unknown, _init?: unknown) => ({ - ok: true, - status: 200, - json: async () => ({ embedding: { values: [1, 2, 3] } }), - })); - -beforeEach(() => { - vi.spyOn(authModule, "resolveApiKeyForProvider"); - vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp"); -}); - -afterEach(() => { - vi.resetAllMocks(); - vi.unstubAllGlobals(); -}); - -function installFetchMock(fetchMock: typeof globalThis.fetch) { - vi.stubGlobal("fetch", fetchMock); -} - -function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) { - const [url, init] = fetchMock.mock.calls[0] ?? 
[]; - return { url, init: init as RequestInit | undefined }; -} - -function requireProvider(result: Awaited>) { - if (!result.provider) { - throw new Error("Expected embedding provider"); - } - return result.provider; -} - -function mockResolvedProviderKey(apiKey = "provider-key") { - vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ - apiKey, - mode: "api-key", - source: "test", - }); -} - -describe("package embedding provider smoke", () => { - it("uses remote OpenAI baseUrl/apiKey and merges headers", async () => { - const fetchMock = createEmbeddingDataFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - mockResolvedProviderKey("provider-key"); - - const result = await createEmbeddingProvider({ - config: { - models: { - providers: { - openai: { - baseUrl: "https://api.openai.com/v1", - headers: { "X-Provider": "p", "X-Shared": "provider" }, - }, - }, - }, - } as never, - provider: "openai", - remote: { - baseUrl: "https://example.com/v1", - apiKey: " remote-key ", - headers: { "X-Shared": "remote", "X-Remote": "r" }, - }, - model: "text-embedding-3-small", - fallback: "openai", - }); - - await requireProvider(result).embedQuery("hello"); - - expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled(); - const { url, init } = readFirstFetchRequest(fetchMock); - expect(url).toBe("https://example.com/v1/embeddings"); - const headers = (init?.headers ?? 
{}) as Record; - expect(headers.Authorization).toBe("Bearer remote-key"); - expect(headers["X-Provider"]).toBe("p"); - expect(headers["X-Shared"]).toBe("remote"); - expect(headers["X-Remote"]).toBe("r"); - }); - - it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => { - const fetchMock = createGeminiFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - vi.stubEnv("GEMINI_API_KEY", "env-gemini-key"); - - const result = await createEmbeddingProvider({ - config: {} as never, - provider: "gemini", - remote: { - apiKey: "GEMINI_API_KEY", // pragma: allowlist secret - }, - model: "text-embedding-004", - fallback: "openai", - }); - - await requireProvider(result).embedQuery("hello"); - - const { init } = readFirstFetchRequest(fetchMock); - const headers = (init?.headers ?? {}) as Record; - expect(headers["x-goog-api-key"]).toBe("env-gemini-key"); - }); - - it("normalizes local embeddings and resolves the default local model", async () => { - const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf"); - vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({ - getLlama: async () => ({ - loadModel: vi.fn().mockResolvedValue({ - createEmbeddingContext: vi.fn().mockResolvedValue({ - getEmbeddingFor: vi.fn().mockResolvedValue({ - vector: new Float32Array([2.35, 3.45, 0.63, 4.3]), - }), - }), - }), - }), - resolveModelFile: resolveModelFileMock, - LlamaLogLevel: { error: 0 }, - } as never); - - const result = await createEmbeddingProvider({ - config: {} as never, - provider: "local", - model: "", - fallback: "none", - }); - - const embedding = await requireProvider(result).embedQuery("test query"); - const magnitude = Math.sqrt(embedding.reduce((sum, value) => sum + value * value, 0)); - expect(magnitude).toBeCloseTo(1, 5); - expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined); - }); - - it("returns null provider when explicit primary and 
fallback auth paths fail", async () => { - vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue( - new Error("No API key found for provider"), - ); - - const result = await createEmbeddingProvider({ - config: {} as never, - provider: "openai", - model: "text-embedding-3-small", - fallback: "gemini", - }); - - expect(result.provider).toBeNull(); - expect(result.requestedProvider).toBe("openai"); - expect(result.fallbackFrom).toBe("openai"); - expect(result.providerUnavailableReason).toContain("Fallback to gemini failed"); +describe("package embeddings barrel", () => { + it("re-exports the source local embedding contract", () => { + expect(DEFAULT_LOCAL_MODEL).toContain("embeddinggemma"); }); }); diff --git a/packages/memory-host-sdk/src/host/embeddings.ts b/packages/memory-host-sdk/src/host/embeddings.ts index 5deab0a591d..89aaf665439 100644 --- a/packages/memory-host-sdk/src/host/embeddings.ts +++ b/packages/memory-host-sdk/src/host/embeddings.ts @@ -1,373 +1 @@ -import fsSync from "node:fs"; -import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp"; -import type { OpenClawConfig } from "../../../../src/config/config.js"; -import type { SecretInput } from "../../../../src/config/types.secrets.js"; -import { formatErrorMessage } from "../../../../src/infra/errors.js"; -import { resolveUserPath } from "../../../../src/utils.js"; -import type { EmbeddingInput } from "./embedding-inputs.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import { - createBedrockEmbeddingProvider, - hasAwsCredentials, - type BedrockEmbeddingClient, -} from "./embeddings-bedrock.js"; -import { - createGeminiEmbeddingProvider, - type GeminiEmbeddingClient, - type GeminiTaskType, -} from "./embeddings-gemini.js"; -import { - createLmstudioEmbeddingProvider, - type LmstudioEmbeddingClient, -} from "./embeddings-lmstudio.js"; -import { - createMistralEmbeddingProvider, - type MistralEmbeddingClient, -} from 
"./embeddings-mistral.js"; -import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js"; -import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js"; -import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js"; -import { importNodeLlamaCpp } from "./node-llama.js"; - -export type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; -export type { LmstudioEmbeddingClient } from "./embeddings-lmstudio.js"; -export type { MistralEmbeddingClient } from "./embeddings-mistral.js"; -export type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; -export type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; -export type { OllamaEmbeddingClient } from "./embeddings-ollama.js"; -export type { BedrockEmbeddingClient } from "./embeddings-bedrock.js"; - -export type EmbeddingProvider = { - id: string; - model: string; - maxInputTokens?: number; - embedQuery: (text: string) => Promise; - embedBatch: (texts: string[]) => Promise; - embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise; -}; - -export type EmbeddingProviderId = - | "openai" - | "local" - | "gemini" - | "voyage" - | "mistral" - | "bedrock" - | "lmstudio" - | "ollama"; -export type EmbeddingProviderRequest = EmbeddingProviderId | "auto"; -export type EmbeddingProviderFallback = EmbeddingProviderId | "none"; - -// Remote providers considered for auto-selection when provider === "auto". -// LM Studio and Ollama are intentionally excluded here so that "auto" mode does not -// implicitly assume either instance is available. -// Bedrock is handled separately when AWS credentials are detected. 
-const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const; - -export type EmbeddingProviderResult = { - provider: EmbeddingProvider | null; - requestedProvider: EmbeddingProviderRequest; - fallbackFrom?: EmbeddingProviderId; - fallbackReason?: string; - providerUnavailableReason?: string; - openAi?: OpenAiEmbeddingClient; - gemini?: GeminiEmbeddingClient; - voyage?: VoyageEmbeddingClient; - mistral?: MistralEmbeddingClient; - bedrock?: BedrockEmbeddingClient; - lmstudio?: LmstudioEmbeddingClient; - ollama?: OllamaEmbeddingClient; -}; - -export type EmbeddingProviderOptions = { - config: OpenClawConfig; - agentDir?: string; - provider: EmbeddingProviderRequest; - remote?: { - baseUrl?: string; - apiKey?: SecretInput; - headers?: Record; - }; - model: string; - fallback: EmbeddingProviderFallback; - local?: { - modelPath?: string; - modelCacheDir?: string; - }; - /** Provider-specific output vector dimensions for supported embedding families. */ - outputDimensionality?: number; - /** Gemini: override the default task type sent with embedding requests. 
*/ - taskType?: GeminiTaskType; -}; - -export const DEFAULT_LOCAL_MODEL = - "hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf"; - -function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean { - const modelPath = options.local?.modelPath?.trim(); - if (!modelPath) { - return false; - } - if (/^(hf:|https?:)/i.test(modelPath)) { - return false; - } - const resolved = resolveUserPath(modelPath); - try { - return fsSync.statSync(resolved).isFile(); - } catch { - return false; - } -} - -function isMissingApiKeyError(err: unknown): boolean { - const message = formatErrorMessage(err); - return message.includes("No API key found for provider"); -} - -export async function createLocalEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise { - const modelPath = options.local?.modelPath?.trim() || DEFAULT_LOCAL_MODEL; - const modelCacheDir = options.local?.modelCacheDir?.trim(); - - // Lazy-load node-llama-cpp to keep startup light unless local is enabled. 
- const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp(); - - let llama: Llama | null = null; - let embeddingModel: LlamaModel | null = null; - let embeddingContext: LlamaEmbeddingContext | null = null; - let initPromise: Promise | null = null; - - const ensureContext = async (): Promise => { - if (embeddingContext) { - return embeddingContext; - } - if (initPromise) { - return initPromise; - } - initPromise = (async () => { - try { - if (!llama) { - llama = await getLlama({ logLevel: LlamaLogLevel.error }); - } - if (!embeddingModel) { - const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined); - embeddingModel = await llama.loadModel({ modelPath: resolved }); - } - if (!embeddingContext) { - embeddingContext = await embeddingModel.createEmbeddingContext(); - } - return embeddingContext; - } catch (err) { - initPromise = null; - throw err; - } - })(); - return initPromise; - }; - - return { - id: "local", - model: modelPath, - embedQuery: async (text) => { - const ctx = await ensureContext(); - const embedding = await ctx.getEmbeddingFor(text); - return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); - }, - embedBatch: async (texts) => { - const ctx = await ensureContext(); - const embeddings = await Promise.all( - texts.map(async (text) => { - const embedding = await ctx.getEmbeddingFor(text); - return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); - }), - ); - return embeddings; - }, - }; -} - -export async function createEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise { - const requestedProvider = options.provider; - const fallback = options.fallback; - - const createProvider = async (id: EmbeddingProviderId) => { - if (id === "local") { - const provider = await createLocalEmbeddingProvider(options); - return { provider }; - } - if (id === "lmstudio") { - const { provider, client } = await createLmstudioEmbeddingProvider(options); - return { provider, lmstudio: client 
}; - } - if (id === "ollama") { - const { provider, client } = await createOllamaEmbeddingProvider(options); - return { provider, ollama: client }; - } - if (id === "gemini") { - const { provider, client } = await createGeminiEmbeddingProvider(options); - return { provider, gemini: client }; - } - if (id === "voyage") { - const { provider, client } = await createVoyageEmbeddingProvider(options); - return { provider, voyage: client }; - } - if (id === "mistral") { - const { provider, client } = await createMistralEmbeddingProvider(options); - return { provider, mistral: client }; - } - if (id === "bedrock") { - const { provider, client } = await createBedrockEmbeddingProvider(options); - return { provider, bedrock: client }; - } - const { provider, client } = await createOpenAiEmbeddingProvider(options); - return { provider, openAi: client }; - }; - - const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) => - provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err); - - if (requestedProvider === "auto") { - const missingKeyErrors: string[] = []; - let localError: string | null = null; - - if (canAutoSelectLocal(options)) { - try { - const local = await createProvider("local"); - return { ...local, requestedProvider }; - } catch (err) { - localError = formatLocalSetupError(err); - } - } - - for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) { - try { - const result = await createProvider(provider); - return { ...result, requestedProvider }; - } catch (err) { - const message = formatPrimaryError(err, provider); - if (isMissingApiKeyError(err)) { - missingKeyErrors.push(message); - continue; - } - // Non-auth errors (e.g., network) are still fatal - const wrapped = new Error(message) as Error & { cause?: unknown }; - wrapped.cause = err; - throw wrapped; - } - } - - // Try bedrock if AWS credentials are available - if (await hasAwsCredentials()) { - try { - const result = await createProvider("bedrock"); - return { ...result, 
requestedProvider }; - } catch (err) { - const message = formatPrimaryError(err, "bedrock"); - if (isMissingApiKeyError(err)) { - missingKeyErrors.push(message); - } else { - const wrapped = new Error(message) as Error & { cause?: unknown }; - wrapped.cause = err; - throw wrapped; - } - } - } - - // All providers failed due to missing API keys - return null provider for FTS-only mode - const details = [...missingKeyErrors, localError].filter(Boolean) as string[]; - const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available."; - return { - provider: null, - requestedProvider, - providerUnavailableReason: reason, - }; - } - - try { - const primary = await createProvider(requestedProvider); - return { ...primary, requestedProvider }; - } catch (primaryErr) { - const reason = formatPrimaryError(primaryErr, requestedProvider); - if (fallback && fallback !== "none" && fallback !== requestedProvider) { - try { - const fallbackResult = await createProvider(fallback); - return { - ...fallbackResult, - requestedProvider, - fallbackFrom: requestedProvider, - fallbackReason: reason, - }; - } catch (fallbackErr) { - // Both primary and fallback failed - check if it's auth-related - const fallbackReason = formatErrorMessage(fallbackErr); - const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`; - if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) { - // Both failed due to missing API keys - return null for FTS-only mode - return { - provider: null, - requestedProvider, - fallbackFrom: requestedProvider, - fallbackReason: reason, - providerUnavailableReason: combinedReason, - }; - } - // Non-auth errors are still fatal - const wrapped = new Error(combinedReason) as Error & { - cause?: unknown; - }; - wrapped.cause = fallbackErr; - throw wrapped; - } - } - // No fallback configured - check if we should degrade to FTS-only - if (isMissingApiKeyError(primaryErr)) { - return { - provider: null, 
- requestedProvider, - providerUnavailableReason: reason, - }; - } - const wrapped = new Error(reason) as Error & { cause?: unknown }; - wrapped.cause = primaryErr; - throw wrapped; - } -} - -function isNodeLlamaCppMissing(err: unknown): boolean { - if (!(err instanceof Error)) { - return false; - } - const code = (err as Error & { code?: unknown }).code; - if (code === "ERR_MODULE_NOT_FOUND") { - return err.message.includes("node-llama-cpp"); - } - return false; -} - -function formatLocalSetupError(err: unknown): string { - const detail = formatErrorMessage(err); - const missing = isNodeLlamaCppMissing(err); - return [ - "Local embeddings unavailable.", - missing - ? "Reason: optional dependency node-llama-cpp is missing (or failed to install)." - : detail - ? `Reason: ${detail}` - : undefined, - missing && detail ? `Detail: ${detail}` : null, - "To enable local embeddings:", - "1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)", - missing - ? 
"2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest" - : null, - "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp", - ...REMOTE_EMBEDDING_PROVIDER_IDS.map( - (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`, - ), - ] - .filter(Boolean) - .join("\n"); -} +export * from "../../../../src/memory-host-sdk/host/embeddings.js"; diff --git a/packages/memory-host-sdk/src/host/multimodal.ts b/packages/memory-host-sdk/src/host/multimodal.ts index c5e1e07004f..32167b48f81 100644 --- a/packages/memory-host-sdk/src/host/multimodal.ts +++ b/packages/memory-host-sdk/src/host/multimodal.ts @@ -100,21 +100,3 @@ export function classifyMemoryMultimodalPath( } return null; } - -export function normalizeGeminiEmbeddingModelForMemory(model: string): string { - const trimmed = model.trim(); - if (!trimmed) { - return ""; - } - return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, ""); -} - -export function supportsMemoryMultimodalEmbeddings(params: { - provider: string; - model: string; -}): boolean { - if (params.provider !== "gemini") { - return false; - } - return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview"; -} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 58d75cb7ca2..595327767f9 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -362,6 +362,12 @@ importers: '@aws-sdk/client-bedrock': specifier: 3.1028.0 version: 3.1028.0 + '@aws-sdk/client-bedrock-runtime': + specifier: 3.1028.0 + version: 3.1028.0 + '@aws-sdk/credential-provider-node': + specifier: 3.972.30 + version: 3.972.30 devDependencies: '@openclaw/plugin-sdk': specifier: workspace:* @@ -1225,6 +1231,12 @@ importers: specifier: workspace:* version: link:../../packages/plugin-sdk + extensions/voyage: + devDependencies: + '@openclaw/plugin-sdk': + specifier: workspace:* + version: link:../../packages/plugin-sdk + extensions/vydra: 
devDependencies: '@openclaw/plugin-sdk': diff --git a/src/agents/memory-search.ts b/src/agents/memory-search.ts index b65b59b8b5f..420bc74a44f 100644 --- a/src/agents/memory-search.ts +++ b/src/agents/memory-search.ts @@ -6,7 +6,6 @@ import type { SecretInput } from "../config/types.secrets.js"; import { isMemoryMultimodalEnabled, normalizeMemoryMultimodalSettings, - supportsMemoryMultimodalEmbeddings, type MemoryMultimodalSettings, } from "../memory-host-sdk/multimodal.js"; import { getMemoryEmbeddingProvider } from "../plugins/memory-embedding-provider-runtime.js"; @@ -389,24 +388,9 @@ export function resolveMemorySearchConfig( const multimodalActive = isMemoryMultimodalEnabled(resolved.multimodal); const multimodalProvider = resolved.provider === "auto" ? undefined : getMemoryEmbeddingProvider(resolved.provider); - const builtinMultimodalSupport = - resolved.provider === "auto" - ? false - : supportsMemoryMultimodalEmbeddings({ - provider: resolved.provider, - model: resolved.model, - }); if ( multimodalActive && - !( - // Fall back to the built-in helper when the provider is not registered yet - // or when a registered adapter does not implement multimodal capability checks. - ( - multimodalProvider?.supportsMultimodalEmbeddings?.({ - model: resolved.model, - }) ?? builtinMultimodalSupport - ) - ) + !(multimodalProvider?.supportsMultimodalEmbeddings?.({ model: resolved.model }) ?? 
false) ) { throw new Error( "agents.*.memorySearch.multimodal requires a provider adapter that supports multimodal embeddings for the configured model.", diff --git a/src/memory-host-sdk/engine-embeddings.ts b/src/memory-host-sdk/engine-embeddings.ts index ad794fe4f1d..39f140e419b 100644 --- a/src/memory-host-sdk/engine-embeddings.ts +++ b/src/memory-host-sdk/engine-embeddings.ts @@ -16,50 +16,56 @@ export type { MemoryEmbeddingProviderRuntime, } from "../plugins/memory-embedding-providers.js"; export { createLocalEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./host/embeddings.js"; +export { extractBatchErrorMessage, formatUnavailableBatchError } from "./host/batch-error-utils.js"; +export { postJsonWithRetry } from "./host/batch-http.js"; +export { applyEmbeddingBatchOutputLine } from "./host/batch-output.js"; export { - createGeminiEmbeddingProvider, - DEFAULT_GEMINI_EMBEDDING_MODEL, - buildGeminiEmbeddingRequest, -} from "./host/embeddings-gemini.js"; + EMBEDDING_BATCH_ENDPOINT, + type EmbeddingBatchStatus, + type ProviderBatchOutputLine, +} from "./host/batch-provider-common.js"; export { - createLmstudioEmbeddingProvider, - DEFAULT_LMSTUDIO_EMBEDDING_MODEL, -} from "./host/embeddings-lmstudio.js"; -export type { LmstudioEmbeddingClient } from "./host/embeddings-lmstudio.js"; + buildEmbeddingBatchGroupOptions, + runEmbeddingBatchGroups, + type EmbeddingBatchExecutionParams, +} from "./host/batch-runner.js"; export { - createMistralEmbeddingProvider, - DEFAULT_MISTRAL_EMBEDDING_MODEL, -} from "./host/embeddings-mistral.js"; + resolveBatchCompletionFromStatus, + resolveCompletedBatchResult, + throwIfBatchTerminalFailure, + type BatchCompletionResult, +} from "./host/batch-status.js"; +export { uploadBatchJsonlFile } from "./host/batch-upload.js"; export { - createGitHubCopilotEmbeddingProvider, - type GitHubCopilotEmbeddingClient, -} from "./host/embeddings-github-copilot.js"; -export { - createOllamaEmbeddingProvider, - DEFAULT_OLLAMA_EMBEDDING_MODEL, -} from 
"./host/embeddings-ollama.js"; -export type { OllamaEmbeddingClient } from "./host/embeddings-ollama.js"; -export { - createOpenAiEmbeddingProvider, - DEFAULT_OPENAI_EMBEDDING_MODEL, -} from "./host/embeddings-openai.js"; -export { - createVoyageEmbeddingProvider, - DEFAULT_VOYAGE_EMBEDDING_MODEL, -} from "./host/embeddings-voyage.js"; -export { runGeminiEmbeddingBatches, type GeminiBatchRequest } from "./host/batch-gemini.js"; -export { - OPENAI_BATCH_ENDPOINT, - runOpenAiEmbeddingBatches, - type OpenAiBatchRequest, -} from "./host/batch-openai.js"; -export { runVoyageEmbeddingBatches, type VoyageBatchRequest } from "./host/batch-voyage.js"; + buildBatchHeaders, + normalizeBatchBaseUrl, + type BatchHttpClientConfig, +} from "./host/batch-utils.js"; export { enforceEmbeddingMaxInputTokens } from "./host/embedding-chunk-limits.js"; +export { + isMissingEmbeddingApiKeyError, + mapBatchEmbeddingsByIndex, + sanitizeEmbeddingCacheHeaders, +} from "./host/embedding-provider-adapter-utils.js"; +export { sanitizeAndNormalizeEmbedding } from "./host/embedding-vectors.js"; +export { debugEmbeddingsLog } from "./host/embeddings-debug.js"; +export { normalizeEmbeddingModelWithPrefixes } from "./host/embeddings-model-normalize.js"; +export { + resolveRemoteEmbeddingBearerClient, + type RemoteEmbeddingProviderId, +} from "./host/embeddings-remote-client.js"; +export { + createRemoteEmbeddingProvider, + resolveRemoteEmbeddingClient, + type RemoteEmbeddingClient, +} from "./host/embeddings-remote-provider.js"; +export { fetchRemoteEmbeddingVectors } from "./host/embeddings-remote-fetch.js"; export { estimateStructuredEmbeddingInputBytes, estimateUtf8Bytes, } from "./host/embedding-input-limits.js"; export { hasNonTextEmbeddingParts, type EmbeddingInput } from "./host/embedding-inputs.js"; +export { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./host/remote-http.js"; export { buildCaseInsensitiveExtensionGlob, classifyMemoryMultimodalPath, diff --git 
a/src/memory-host-sdk/host/batch-gemini.test.ts b/src/memory-host-sdk/host/batch-gemini.test.ts deleted file mode 100644 index 095ebe008b9..00000000000 --- a/src/memory-host-sdk/host/batch-gemini.test.ts +++ /dev/null @@ -1,116 +0,0 @@ -import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; -import type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; - -vi.mock("./remote-http.js", () => ({ - withRemoteHttpResponse: vi.fn(), -})); - -function magnitude(values: number[]) { - return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0)); -} - -describe("runGeminiEmbeddingBatches", () => { - let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches; - let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse; - let remoteHttpMock: ReturnType>; - - beforeAll(async () => { - ({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js")); - ({ withRemoteHttpResponse } = await import("./remote-http.js")); - remoteHttpMock = vi.mocked(withRemoteHttpResponse); - }); - - beforeEach(() => { - vi.clearAllMocks(); - }); - - afterEach(() => { - vi.resetAllMocks(); - vi.unstubAllGlobals(); - }); - - const mockClient: GeminiEmbeddingClient = { - baseUrl: "https://generativelanguage.googleapis.com/v1beta", - headers: {}, - model: "gemini-embedding-2-preview", - modelPath: "models/gemini-embedding-2-preview", - apiKeys: ["test-key"], - outputDimensionality: 1536, - }; - - it("includes outputDimensionality in batch upload requests", async () => { - remoteHttpMock.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart"); - const body = params.init?.body; - if (!(body instanceof Blob)) { - throw new Error("expected multipart blob body"); - } - const text = await body.text(); - expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"'); - expect(text).toContain('"outputDimensionality":1536'); - return await 
params.onResponse( - new Response(JSON.stringify({ name: "files/file-123" }), { - status: 200, - headers: { "Content-Type": "application/json" }, - }), - ); - }); - remoteHttpMock.mockImplementationOnce(async (params) => { - expect(params.url).toMatch(/:asyncBatchEmbedContent$/u); - return await params.onResponse( - new Response( - JSON.stringify({ - name: "batches/batch-1", - state: "COMPLETED", - outputConfig: { file: "files/output-1" }, - }), - { - status: 200, - headers: { "Content-Type": "application/json" }, - }, - ), - ); - }); - remoteHttpMock.mockImplementationOnce(async (params) => { - expect(params.url).toMatch(/\/files\/output-1:download$/u); - return await params.onResponse( - new Response( - JSON.stringify({ - key: "req-1", - response: { embedding: { values: [3, 4] } }, - }), - { - status: 200, - headers: { "Content-Type": "application/jsonl" }, - }, - ), - ); - }); - - const results = await runGeminiEmbeddingBatches({ - gemini: mockClient, - agentId: "main", - requests: [ - { - custom_id: "req-1", - request: { - content: { parts: [{ text: "hello world" }] }, - taskType: "RETRIEVAL_DOCUMENT", - outputDimensionality: 1536, - }, - }, - ], - wait: true, - pollIntervalMs: 1, - timeoutMs: 1000, - concurrency: 1, - }); - - const embedding = results.get("req-1"); - expect(embedding).toBeDefined(); - expect(embedding?.[0]).toBeCloseTo(0.6, 5); - expect(embedding?.[1]).toBeCloseTo(0.8, 5); - expect(magnitude(embedding ?? 
[])).toBeCloseTo(1, 5); - expect(remoteHttpMock).toHaveBeenCalledTimes(3); - }); -}); diff --git a/src/memory-host-sdk/host/batch-gemini.ts b/src/memory-host-sdk/host/batch-gemini.ts deleted file mode 100644 index 4bdc9fa055e..00000000000 --- a/src/memory-host-sdk/host/batch-gemini.ts +++ /dev/null @@ -1,368 +0,0 @@ -import { - buildEmbeddingBatchGroupOptions, - runEmbeddingBatchGroups, - type EmbeddingBatchExecutionParams, -} from "./batch-runner.js"; -import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import { debugEmbeddingsLog } from "./embeddings-debug.js"; -import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js"; -import { hashText } from "./internal.js"; -import { withRemoteHttpResponse } from "./remote-http.js"; - -export type GeminiBatchRequest = { - custom_id: string; - request: GeminiTextEmbeddingRequest; -}; - -export type GeminiBatchStatus = { - name?: string; - state?: string; - outputConfig?: { file?: string; fileId?: string }; - metadata?: { - output?: { - responsesFile?: string; - }; - }; - error?: { message?: string }; -}; - -export type GeminiBatchOutputLine = { - key?: string; - custom_id?: string; - request_id?: string; - embedding?: { values?: number[] }; - response?: { - embedding?: { values?: number[] }; - error?: { message?: string }; - }; - error?: { message?: string }; -}; - -const GEMINI_BATCH_MAX_REQUESTS = 50000; -function getGeminiUploadUrl(baseUrl: string): string { - if (baseUrl.includes("/v1beta")) { - return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta"); - } - return `${baseUrl.replace(/\/$/, "")}/upload`; -} - -function buildGeminiUploadBody(params: { jsonl: string; displayName: string }): { - body: Blob; - contentType: string; -} { - const boundary = `openclaw-${hashText(params.displayName)}`; - const jsonPart = JSON.stringify({ - file: { - displayName: params.displayName, - mimeType: 
"application/jsonl", - }, - }); - const delimiter = `--${boundary}\r\n`; - const closeDelimiter = `--${boundary}--\r\n`; - const parts = [ - `${delimiter}Content-Type: application/json; charset=UTF-8\r\n\r\n${jsonPart}\r\n`, - `${delimiter}Content-Type: application/jsonl; charset=UTF-8\r\n\r\n${params.jsonl}\r\n`, - closeDelimiter, - ]; - const body = new Blob([parts.join("")], { type: "multipart/related" }); - return { - body, - contentType: `multipart/related; boundary=${boundary}`, - }; -} - -async function submitGeminiBatch(params: { - gemini: GeminiEmbeddingClient; - requests: GeminiBatchRequest[]; - agentId: string; -}): Promise { - const baseUrl = normalizeBatchBaseUrl(params.gemini); - const jsonl = params.requests - .map((request) => - JSON.stringify({ - key: request.custom_id, - request: request.request, - }), - ) - .join("\n"); - const displayName = `memory-embeddings-${hashText(String(Date.now()))}`; - const uploadPayload = buildGeminiUploadBody({ jsonl, displayName }); - - const uploadUrl = `${getGeminiUploadUrl(baseUrl)}/files?uploadType=multipart`; - debugEmbeddingsLog("memory embeddings: gemini batch upload", { - uploadUrl, - baseUrl, - requests: params.requests.length, - }); - const filePayload = await withRemoteHttpResponse({ - url: uploadUrl, - ssrfPolicy: params.gemini.ssrfPolicy, - init: { - method: "POST", - headers: { - ...buildBatchHeaders(params.gemini, { json: false }), - "Content-Type": uploadPayload.contentType, - }, - body: uploadPayload.body, - }, - onResponse: async (fileRes) => { - if (!fileRes.ok) { - const text = await fileRes.text(); - throw new Error(`gemini batch file upload failed: ${fileRes.status} ${text}`); - } - return (await fileRes.json()) as { name?: string; file?: { name?: string } }; - }, - }); - const fileId = filePayload.name ?? 
filePayload.file?.name; - if (!fileId) { - throw new Error("gemini batch file upload failed: missing file id"); - } - - const batchBody = { - batch: { - displayName: `memory-embeddings-${params.agentId}`, - inputConfig: { - file_name: fileId, - }, - }, - }; - - const batchEndpoint = `${baseUrl}/${params.gemini.modelPath}:asyncBatchEmbedContent`; - debugEmbeddingsLog("memory embeddings: gemini batch create", { - batchEndpoint, - fileId, - }); - return await withRemoteHttpResponse({ - url: batchEndpoint, - ssrfPolicy: params.gemini.ssrfPolicy, - init: { - method: "POST", - headers: buildBatchHeaders(params.gemini, { json: true }), - body: JSON.stringify(batchBody), - }, - onResponse: async (batchRes) => { - if (batchRes.ok) { - return (await batchRes.json()) as GeminiBatchStatus; - } - const text = await batchRes.text(); - if (batchRes.status === 404) { - throw new Error( - "gemini batch create failed: 404 (asyncBatchEmbedContent not available for this model/baseUrl). Disable remote.batch.enabled or switch providers.", - ); - } - throw new Error(`gemini batch create failed: ${batchRes.status} ${text}`); - }, - }); -} - -async function fetchGeminiBatchStatus(params: { - gemini: GeminiEmbeddingClient; - batchName: string; -}): Promise { - const baseUrl = normalizeBatchBaseUrl(params.gemini); - const name = params.batchName.startsWith("batches/") - ? 
params.batchName - : `batches/${params.batchName}`; - const statusUrl = `${baseUrl}/${name}`; - debugEmbeddingsLog("memory embeddings: gemini batch status", { statusUrl }); - return await withRemoteHttpResponse({ - url: statusUrl, - ssrfPolicy: params.gemini.ssrfPolicy, - init: { - headers: buildBatchHeaders(params.gemini, { json: true }), - }, - onResponse: async (res) => { - if (!res.ok) { - const text = await res.text(); - throw new Error(`gemini batch status failed: ${res.status} ${text}`); - } - return (await res.json()) as GeminiBatchStatus; - }, - }); -} - -async function fetchGeminiFileContent(params: { - gemini: GeminiEmbeddingClient; - fileId: string; -}): Promise { - const baseUrl = normalizeBatchBaseUrl(params.gemini); - const file = params.fileId.startsWith("files/") ? params.fileId : `files/${params.fileId}`; - const downloadUrl = `${baseUrl}/${file}:download`; - debugEmbeddingsLog("memory embeddings: gemini batch download", { downloadUrl }); - return await withRemoteHttpResponse({ - url: downloadUrl, - ssrfPolicy: params.gemini.ssrfPolicy, - init: { - headers: buildBatchHeaders(params.gemini, { json: true }), - }, - onResponse: async (res) => { - if (!res.ok) { - const text = await res.text(); - throw new Error(`gemini batch file content failed: ${res.status} ${text}`); - } - return await res.text(); - }, - }); -} - -function parseGeminiBatchOutput(text: string): GeminiBatchOutputLine[] { - if (!text.trim()) { - return []; - } - return text - .split("\n") - .map((line) => line.trim()) - .filter(Boolean) - .map((line) => JSON.parse(line) as GeminiBatchOutputLine); -} - -async function waitForGeminiBatch(params: { - gemini: GeminiEmbeddingClient; - batchName: string; - wait: boolean; - pollIntervalMs: number; - timeoutMs: number; - debug?: (message: string, data?: Record) => void; - initial?: GeminiBatchStatus; -}): Promise<{ outputFileId: string }> { - const start = Date.now(); - let current: GeminiBatchStatus | undefined = params.initial; - while 
(true) { - const status = - current ?? - (await fetchGeminiBatchStatus({ - gemini: params.gemini, - batchName: params.batchName, - })); - const state = status.state ?? "UNKNOWN"; - if (["SUCCEEDED", "COMPLETED", "DONE"].includes(state)) { - const outputFileId = - status.outputConfig?.file ?? - status.outputConfig?.fileId ?? - status.metadata?.output?.responsesFile; - if (!outputFileId) { - throw new Error(`gemini batch ${params.batchName} completed without output file`); - } - return { outputFileId }; - } - if (["FAILED", "CANCELLED", "CANCELED", "EXPIRED"].includes(state)) { - const message = status.error?.message ?? "unknown error"; - throw new Error(`gemini batch ${params.batchName} ${state}: ${message}`); - } - if (!params.wait) { - throw new Error(`gemini batch ${params.batchName} still ${state}; wait disabled`); - } - if (Date.now() - start > params.timeoutMs) { - throw new Error(`gemini batch ${params.batchName} timed out after ${params.timeoutMs}ms`); - } - params.debug?.(`gemini batch ${params.batchName} ${state}; waiting ${params.pollIntervalMs}ms`); - await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs)); - current = undefined; - } -} - -export async function runGeminiEmbeddingBatches( - params: { - gemini: GeminiEmbeddingClient; - agentId: string; - requests: GeminiBatchRequest[]; - } & EmbeddingBatchExecutionParams, -): Promise> { - return await runEmbeddingBatchGroups({ - ...buildEmbeddingBatchGroupOptions(params, { - maxRequests: GEMINI_BATCH_MAX_REQUESTS, - debugLabel: "memory embeddings: gemini batch submit", - }), - runGroup: async ({ group, groupIndex, groups, byCustomId }) => { - const batchInfo = await submitGeminiBatch({ - gemini: params.gemini, - requests: group, - agentId: params.agentId, - }); - const batchName = batchInfo.name ?? 
""; - if (!batchName) { - throw new Error("gemini batch create failed: missing batch name"); - } - - params.debug?.("memory embeddings: gemini batch created", { - batchName, - state: batchInfo.state, - group: groupIndex + 1, - groups, - requests: group.length, - }); - - if ( - !params.wait && - batchInfo.state && - !["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state) - ) { - throw new Error( - `gemini batch ${batchName} submitted; enable remote.batch.wait to await completion`, - ); - } - - const completed = - batchInfo.state && ["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state) - ? { - outputFileId: - batchInfo.outputConfig?.file ?? - batchInfo.outputConfig?.fileId ?? - batchInfo.metadata?.output?.responsesFile ?? - "", - } - : await waitForGeminiBatch({ - gemini: params.gemini, - batchName, - wait: params.wait, - pollIntervalMs: params.pollIntervalMs, - timeoutMs: params.timeoutMs, - debug: params.debug, - initial: batchInfo, - }); - if (!completed.outputFileId) { - throw new Error(`gemini batch ${batchName} completed without output file`); - } - - const content = await fetchGeminiFileContent({ - gemini: params.gemini, - fileId: completed.outputFileId, - }); - const outputLines = parseGeminiBatchOutput(content); - const errors: string[] = []; - const remaining = new Set(group.map((request) => request.custom_id)); - - for (const line of outputLines) { - const customId = line.key ?? line.custom_id ?? line.request_id; - if (!customId) { - continue; - } - remaining.delete(customId); - if (line.error?.message) { - errors.push(`${customId}: ${line.error.message}`); - continue; - } - if (line.response?.error?.message) { - errors.push(`${customId}: ${line.response.error.message}`); - continue; - } - const embedding = sanitizeAndNormalizeEmbedding( - line.embedding?.values ?? line.response?.embedding?.values ?? 
[], - ); - if (embedding.length === 0) { - errors.push(`${customId}: empty embedding`); - continue; - } - byCustomId.set(customId, embedding); - } - - if (errors.length > 0) { - throw new Error(`gemini batch ${batchName} failed: ${errors.join("; ")}`); - } - if (remaining.size > 0) { - throw new Error(`gemini batch ${batchName} missing ${remaining.size} embedding responses`); - } - }, - }); -} diff --git a/src/memory-host-sdk/host/batch-openai.test.ts b/src/memory-host-sdk/host/batch-openai.test.ts deleted file mode 100644 index ae48b8d7d93..00000000000 --- a/src/memory-host-sdk/host/batch-openai.test.ts +++ /dev/null @@ -1,108 +0,0 @@ -import { beforeEach, describe, expect, it, vi } from "vitest"; - -const mocks = vi.hoisted(() => ({ - uploadBatchJsonlFile: vi.fn(async () => "file_in"), - postJsonWithRetry: vi.fn(async () => ({ id: "batch_1", status: "in_progress" })), - resolveCompletedBatchResult: vi.fn(async () => ({ outputFileId: "file_out" })), - withRemoteHttpResponse: vi.fn( - async (params: { url: string; onResponse: (res: Response) => Promise }) => { - if (params.url.endsWith("/files/file_out/content")) { - const content = [ - JSON.stringify({ - custom_id: "0", - response: { - status_code: 200, - body: { data: [{ embedding: [1, 0, 0], index: 0 }] }, - }, - }), - JSON.stringify({ - custom_id: "1", - response: { - status_code: 200, - body: { data: [{ embedding: [2, 0, 0], index: 0 }] }, - }, - }), - ].join("\n"); - return await params.onResponse({ - ok: true, - status: 200, - text: async () => content, - } as Response); - } - return await params.onResponse({ - ok: true, - status: 200, - json: async () => ({ id: "batch_1", status: "completed", output_file_id: "file_out" }), - } as Response); - }, - ), -})); - -vi.mock("./batch-upload.js", () => ({ - uploadBatchJsonlFile: mocks.uploadBatchJsonlFile, -})); - -vi.mock("./batch-http.js", () => ({ - postJsonWithRetry: mocks.postJsonWithRetry, -})); - -vi.mock("./batch-status.js", () => ({ - 
resolveBatchCompletionFromStatus: vi.fn(), - resolveCompletedBatchResult: mocks.resolveCompletedBatchResult, - throwIfBatchTerminalFailure: vi.fn(), -})); - -vi.mock("./remote-http.js", () => ({ - withRemoteHttpResponse: mocks.withRemoteHttpResponse, -})); - -describe("runOpenAiEmbeddingBatches", () => { - beforeEach(() => { - vi.clearAllMocks(); - }); - - it("maps uploaded batch output rows back to embeddings", async () => { - const { runOpenAiEmbeddingBatches, OPENAI_BATCH_ENDPOINT } = await import("./batch-openai.js"); - - const result = await runOpenAiEmbeddingBatches({ - openAi: { - baseUrl: "https://api.openai.com/v1", - headers: { Authorization: "Bearer test" }, - fetchImpl: fetch, - model: "text-embedding-3-small", - }, - agentId: "main", - requests: [ - { - custom_id: "0", - method: "POST", - url: OPENAI_BATCH_ENDPOINT, - body: { model: "text-embedding-3-small", input: "hello" }, - }, - { - custom_id: "1", - method: "POST", - url: OPENAI_BATCH_ENDPOINT, - body: { model: "text-embedding-3-small", input: "world" }, - }, - ], - wait: true, - pollIntervalMs: 1, - timeoutMs: 1000, - concurrency: 3, - }); - - expect(mocks.uploadBatchJsonlFile).toHaveBeenCalled(); - expect(mocks.postJsonWithRetry).toHaveBeenCalledWith( - expect.objectContaining({ - errorPrefix: "openai batch create failed", - body: expect.objectContaining({ - endpoint: OPENAI_BATCH_ENDPOINT, - metadata: { source: "openclaw-memory", agent: "main" }, - }), - }), - ); - expect(result.get("0")).toEqual([1, 0, 0]); - expect(result.get("1")).toEqual([2, 0, 0]); - }); -}); diff --git a/src/memory-host-sdk/host/batch-voyage.test.ts b/src/memory-host-sdk/host/batch-voyage.test.ts deleted file mode 100644 index 2fcdb9ec7c0..00000000000 --- a/src/memory-host-sdk/host/batch-voyage.test.ts +++ /dev/null @@ -1,176 +0,0 @@ -import { ReadableStream } from "node:stream/web"; -import { setTimeout as nativeSleep } from "node:timers/promises"; -import { describe, expect, it, vi } from "vitest"; -import { - 
runVoyageEmbeddingBatches, - type VoyageBatchOutputLine, - type VoyageBatchRequest, -} from "./batch-voyage.js"; -import type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; - -const realNow = Date.now.bind(Date); - -describe("runVoyageEmbeddingBatches", () => { - const mockClient: VoyageEmbeddingClient = { - baseUrl: "https://api.voyageai.com/v1", - headers: { Authorization: "Bearer test-key" }, - model: "voyage-4-large", - }; - - const mockRequests: VoyageBatchRequest[] = [ - { custom_id: "req-1", body: { input: "text1" } }, - { custom_id: "req-2", body: { input: "text2" } }, - ]; - - it("successfully submits batch, waits, and streams results", async () => { - const outputLines: VoyageBatchOutputLine[] = [ - { - custom_id: "req-1", - response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } }, - }, - { - custom_id: "req-2", - response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } }, - }, - ]; - const withRemoteHttpResponse = vi.fn(); - const postJsonWithRetry = vi.fn(); - const uploadBatchJsonlFile = vi.fn(); - - // Create a stream that emits the NDJSON lines - const stream = new ReadableStream({ - start(controller) { - const text = outputLines.map((l) => JSON.stringify(l)).join("\n"); - controller.enqueue(new TextEncoder().encode(text)); - controller.close(); - }, - }); - uploadBatchJsonlFile.mockImplementationOnce(async (params) => { - expect(params.errorPrefix).toBe("voyage batch file upload failed"); - expect(params.requests).toEqual(mockRequests); - return "file-123"; - }); - postJsonWithRetry.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/batches"); - expect(params.body).toMatchObject({ - input_file_id: "file-123", - completion_window: "12h", - request_params: { - model: "voyage-4-large", - input_type: "document", - }, - }); - return { - id: "batch-abc", - status: "pending", - }; - }); - withRemoteHttpResponse.mockImplementationOnce(async (params) => { - 
expect(params.url).toContain("/batches/batch-abc"); - return await params.onResponse( - new Response( - JSON.stringify({ - id: "batch-abc", - status: "completed", - output_file_id: "file-out-999", - }), - { - status: 200, - headers: { "Content-Type": "application/json" }, - }, - ), - ); - }); - withRemoteHttpResponse.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/files/file-out-999/content"); - return await params.onResponse( - new Response(stream as unknown as BodyInit, { - status: 200, - headers: { "Content-Type": "application/x-ndjson" }, - }), - ); - }); - - const results = await runVoyageEmbeddingBatches({ - client: mockClient, - agentId: "agent-1", - requests: mockRequests, - wait: true, - pollIntervalMs: 1, // fast poll - timeoutMs: 1000, - concurrency: 1, - deps: { - now: realNow, - sleep: async (ms) => { - await nativeSleep(ms); - }, - postJsonWithRetry, - uploadBatchJsonlFile, - withRemoteHttpResponse, - }, - }); - - expect(results.size).toBe(2); - expect(results.get("req-1")).toEqual([0.1, 0.1]); - expect(results.get("req-2")).toEqual([0.2, 0.2]); - expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1); - expect(postJsonWithRetry).toHaveBeenCalledTimes(1); - expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2); - }); - - it("handles empty lines and stream chunks correctly", async () => { - const withRemoteHttpResponse = vi.fn(); - const postJsonWithRetry = vi.fn(); - const uploadBatchJsonlFile = vi.fn(); - const stream = new ReadableStream({ - start(controller) { - const line1 = JSON.stringify({ - custom_id: "req-1", - response: { body: { data: [{ embedding: [1] }] } }, - }); - const line2 = JSON.stringify({ - custom_id: "req-2", - response: { body: { data: [{ embedding: [2] }] } }, - }); - - // Split across chunks - controller.enqueue(new TextEncoder().encode(line1 + "\n")); - controller.enqueue(new TextEncoder().encode("\n")); // empty line - controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF - 
controller.close(); - }, - }); - uploadBatchJsonlFile.mockResolvedValueOnce("f1"); - postJsonWithRetry.mockResolvedValueOnce({ - id: "b1", - status: "completed", - output_file_id: "out1", - }); - withRemoteHttpResponse.mockImplementationOnce(async (params) => { - expect(params.url).toContain("/files/out1/content"); - return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 })); - }); - - const results = await runVoyageEmbeddingBatches({ - client: mockClient, - agentId: "a1", - requests: mockRequests, - wait: true, - pollIntervalMs: 1, - timeoutMs: 1000, - concurrency: 1, - deps: { - now: realNow, - sleep: async (ms) => { - await nativeSleep(ms); - }, - postJsonWithRetry, - uploadBatchJsonlFile, - withRemoteHttpResponse, - }, - }); - - expect(results.get("req-1")).toEqual([1]); - expect(results.get("req-2")).toEqual([2]); - }); -}); diff --git a/src/memory-host-sdk/host/batch-voyage.ts b/src/memory-host-sdk/host/batch-voyage.ts deleted file mode 100644 index fcb257a4d7d..00000000000 --- a/src/memory-host-sdk/host/batch-voyage.ts +++ /dev/null @@ -1,315 +0,0 @@ -import { createInterface } from "node:readline"; -import { Readable } from "node:stream"; -import { - applyEmbeddingBatchOutputLine, - buildBatchHeaders, - buildEmbeddingBatchGroupOptions, - EMBEDDING_BATCH_ENDPOINT, - extractBatchErrorMessage, - formatUnavailableBatchError, - normalizeBatchBaseUrl, - postJsonWithRetry, - resolveBatchCompletionFromStatus, - resolveCompletedBatchResult, - runEmbeddingBatchGroups, - throwIfBatchTerminalFailure, - type EmbeddingBatchExecutionParams, - type EmbeddingBatchStatus, - type BatchCompletionResult, - type ProviderBatchOutputLine, - uploadBatchJsonlFile, - withRemoteHttpResponse, -} from "./batch-embedding-common.js"; -import type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; - -/** - * Voyage Batch API Input Line format. 
- * See: https://docs.voyageai.com/docs/batch-inference - */ -export type VoyageBatchRequest = { - custom_id: string; - body: { - input: string | string[]; - }; -}; - -export type VoyageBatchStatus = EmbeddingBatchStatus; -export type VoyageBatchOutputLine = ProviderBatchOutputLine; - -export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT; -const VOYAGE_BATCH_COMPLETION_WINDOW = "12h"; -const VOYAGE_BATCH_MAX_REQUESTS = 50000; - -type VoyageBatchDeps = { - now: () => number; - sleep: (ms: number) => Promise; - postJsonWithRetry: typeof postJsonWithRetry; - uploadBatchJsonlFile: typeof uploadBatchJsonlFile; - withRemoteHttpResponse: typeof withRemoteHttpResponse; -}; - -function resolveVoyageBatchDeps(overrides: Partial | undefined): VoyageBatchDeps { - return { - now: overrides?.now ?? Date.now, - sleep: - overrides?.sleep ?? - (async (ms: number) => await new Promise((resolve) => setTimeout(resolve, ms))), - postJsonWithRetry: overrides?.postJsonWithRetry ?? postJsonWithRetry, - uploadBatchJsonlFile: overrides?.uploadBatchJsonlFile ?? uploadBatchJsonlFile, - withRemoteHttpResponse: overrides?.withRemoteHttpResponse ?? 
withRemoteHttpResponse, - }; -} - -async function assertVoyageResponseOk(res: Response, context: string): Promise { - if (!res.ok) { - const text = await res.text(); - throw new Error(`${context}: ${res.status} ${text}`); - } -} - -function buildVoyageBatchRequest(params: { - client: VoyageEmbeddingClient; - path: string; - onResponse: (res: Response) => Promise; -}) { - const baseUrl = normalizeBatchBaseUrl(params.client); - return { - url: `${baseUrl}/${params.path}`, - ssrfPolicy: params.client.ssrfPolicy, - init: { - headers: buildBatchHeaders(params.client, { json: true }), - }, - onResponse: params.onResponse, - }; -} - -async function submitVoyageBatch(params: { - client: VoyageEmbeddingClient; - requests: VoyageBatchRequest[]; - agentId: string; - deps: VoyageBatchDeps; -}): Promise { - const baseUrl = normalizeBatchBaseUrl(params.client); - const inputFileId = await params.deps.uploadBatchJsonlFile({ - client: params.client, - requests: params.requests, - errorPrefix: "voyage batch file upload failed", - }); - - // 2. 
Create batch job using Voyage Batches API - return await params.deps.postJsonWithRetry({ - url: `${baseUrl}/batches`, - headers: buildBatchHeaders(params.client, { json: true }), - ssrfPolicy: params.client.ssrfPolicy, - body: { - input_file_id: inputFileId, - endpoint: VOYAGE_BATCH_ENDPOINT, - completion_window: VOYAGE_BATCH_COMPLETION_WINDOW, - request_params: { - model: params.client.model, - input_type: "document", - }, - metadata: { - source: "clawdbot-memory", - agent: params.agentId, - }, - }, - errorPrefix: "voyage batch create failed", - }); -} - -async function fetchVoyageBatchStatus(params: { - client: VoyageEmbeddingClient; - batchId: string; - deps: VoyageBatchDeps; -}): Promise { - return await params.deps.withRemoteHttpResponse( - buildVoyageBatchRequest({ - client: params.client, - path: `batches/${params.batchId}`, - onResponse: async (res) => { - await assertVoyageResponseOk(res, "voyage batch status failed"); - return (await res.json()) as VoyageBatchStatus; - }, - }), - ); -} - -async function readVoyageBatchError(params: { - client: VoyageEmbeddingClient; - errorFileId: string; - deps: VoyageBatchDeps; -}): Promise { - try { - return await params.deps.withRemoteHttpResponse( - buildVoyageBatchRequest({ - client: params.client, - path: `files/${params.errorFileId}/content`, - onResponse: async (res) => { - await assertVoyageResponseOk(res, "voyage batch error file content failed"); - const text = await res.text(); - if (!text.trim()) { - return undefined; - } - const lines = text - .split("\n") - .map((line) => line.trim()) - .filter(Boolean) - .map((line) => JSON.parse(line) as VoyageBatchOutputLine); - return extractBatchErrorMessage(lines); - }, - }), - ); - } catch (err) { - return formatUnavailableBatchError(err); - } -} - -async function waitForVoyageBatch(params: { - client: VoyageEmbeddingClient; - batchId: string; - wait: boolean; - pollIntervalMs: number; - timeoutMs: number; - debug?: (message: string, data?: Record) => void; - 
initial?: VoyageBatchStatus; - deps: VoyageBatchDeps; -}): Promise { - const start = params.deps.now(); - let current: VoyageBatchStatus | undefined = params.initial; - while (true) { - const status = - current ?? - (await fetchVoyageBatchStatus({ - client: params.client, - batchId: params.batchId, - deps: params.deps, - })); - const state = status.status ?? "unknown"; - if (state === "completed") { - return resolveBatchCompletionFromStatus({ - provider: "voyage", - batchId: params.batchId, - status, - }); - } - await throwIfBatchTerminalFailure({ - provider: "voyage", - status: { ...status, id: params.batchId }, - readError: async (errorFileId) => - await readVoyageBatchError({ - client: params.client, - errorFileId, - deps: params.deps, - }), - }); - if (!params.wait) { - throw new Error(`voyage batch ${params.batchId} still ${state}; wait disabled`); - } - if (params.deps.now() - start > params.timeoutMs) { - throw new Error(`voyage batch ${params.batchId} timed out after ${params.timeoutMs}ms`); - } - params.debug?.(`voyage batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`); - await params.deps.sleep(params.pollIntervalMs); - current = undefined; - } -} - -export async function runVoyageEmbeddingBatches( - params: { - client: VoyageEmbeddingClient; - agentId: string; - requests: VoyageBatchRequest[]; - deps?: Partial; - } & EmbeddingBatchExecutionParams, -): Promise> { - const deps = resolveVoyageBatchDeps(params.deps); - return await runEmbeddingBatchGroups({ - ...buildEmbeddingBatchGroupOptions(params, { - maxRequests: VOYAGE_BATCH_MAX_REQUESTS, - debugLabel: "memory embeddings: voyage batch submit", - }), - runGroup: async ({ group, groupIndex, groups, byCustomId }) => { - const batchInfo = await submitVoyageBatch({ - client: params.client, - requests: group, - agentId: params.agentId, - deps, - }); - if (!batchInfo.id) { - throw new Error("voyage batch create failed: missing batch id"); - } - const batchId = batchInfo.id; - - 
params.debug?.("memory embeddings: voyage batch created", { - batchId: batchInfo.id, - status: batchInfo.status, - group: groupIndex + 1, - groups, - requests: group.length, - }); - - const completed = await resolveCompletedBatchResult({ - provider: "voyage", - status: batchInfo, - wait: params.wait, - waitForBatch: async () => - await waitForVoyageBatch({ - client: params.client, - batchId, - wait: params.wait, - pollIntervalMs: params.pollIntervalMs, - timeoutMs: params.timeoutMs, - debug: params.debug, - initial: batchInfo, - deps, - }), - }); - - const baseUrl = normalizeBatchBaseUrl(params.client); - const errors: string[] = []; - const remaining = new Set(group.map((request) => request.custom_id)); - - await deps.withRemoteHttpResponse({ - url: `${baseUrl}/files/${completed.outputFileId}/content`, - ssrfPolicy: params.client.ssrfPolicy, - init: { - headers: buildBatchHeaders(params.client, { json: true }), - }, - onResponse: async (contentRes) => { - if (!contentRes.ok) { - const text = await contentRes.text(); - throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`); - } - - if (!contentRes.body) { - return; - } - const reader = createInterface({ - input: Readable.fromWeb( - contentRes.body as unknown as import("stream/web").ReadableStream, - ), - terminal: false, - }); - - for await (const rawLine of reader) { - if (!rawLine.trim()) { - continue; - } - const line = JSON.parse(rawLine) as VoyageBatchOutputLine; - applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId }); - } - }, - }); - - if (errors.length > 0) { - throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`); - } - if (remaining.size > 0) { - throw new Error( - `voyage batch ${batchInfo.id} missing ${remaining.size} embedding responses`, - ); - } - }, - }); -} diff --git a/src/memory-host-sdk/host/embedding-model-limits.ts b/src/memory-host-sdk/host/embedding-model-limits.ts index 00766748848..714114f670e 100644 --- 
a/src/memory-host-sdk/host/embedding-model-limits.ts +++ b/src/memory-host-sdk/host/embedding-model-limits.ts @@ -1,40 +1,14 @@ -import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js"; import type { EmbeddingProvider } from "./embeddings.js"; const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192; const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048; -const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record = { - "openai:text-embedding-3-small": 8192, - "openai:text-embedding-3-large": 8192, - "openai:text-embedding-ada-002": 8191, - "gemini:text-embedding-004": 2048, - "gemini:gemini-embedding-001": 2048, - "gemini:gemini-embedding-2-preview": 8192, - "voyage:voyage-3": 32000, - "voyage:voyage-3-lite": 16000, - "voyage:voyage-code-3": 32000, -}; - export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number { if (typeof provider.maxInputTokens === "number") { return provider.maxInputTokens; } - // Provider/model mapping is best-effort; different providers use different - // limits and we prefer to be conservative when we don't know. - const key = normalizeLowercaseStringOrEmpty(`${provider.id}:${provider.model}`); - const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key]; - if (typeof known === "number") { - return known; - } - - // Provider-specific conservative fallbacks. This prevents us from accidentally - // using the OpenAI default for providers with much smaller limits. 
- if (normalizeLowercaseStringOrEmpty(provider.id) === "gemini") { - return 2048; - } - if (normalizeLowercaseStringOrEmpty(provider.id) === "local") { + if (provider.id === "local") { return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS; } diff --git a/src/memory-host-sdk/host/embedding-provider-adapter-utils.ts b/src/memory-host-sdk/host/embedding-provider-adapter-utils.ts new file mode 100644 index 00000000000..401173b2826 --- /dev/null +++ b/src/memory-host-sdk/host/embedding-provider-adapter-utils.ts @@ -0,0 +1,29 @@ +import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js"; + +export function isMissingEmbeddingApiKeyError(err: unknown): boolean { + return err instanceof Error && err.message.includes("No API key found for provider"); +} + +export function sanitizeEmbeddingCacheHeaders( + headers: Record, + excludedHeaderNames: string[], +): Array<[string, string]> { + const excluded = new Set( + excludedHeaderNames.map((name) => normalizeLowercaseStringOrEmpty(name)), + ); + return Object.entries(headers) + .filter(([key]) => !excluded.has(normalizeLowercaseStringOrEmpty(key))) + .toSorted(([a], [b]) => a.localeCompare(b)) + .map(([key, value]) => [key, value]); +} + +export function mapBatchEmbeddingsByIndex( + byCustomId: Map, + count: number, +): number[][] { + const embeddings: number[][] = []; + for (let index = 0; index < count; index += 1) { + embeddings.push(byCustomId.get(String(index)) ?? 
[]); + } + return embeddings; +} diff --git a/src/memory-host-sdk/host/embeddings-bedrock.test.ts b/src/memory-host-sdk/host/embeddings-bedrock.test.ts deleted file mode 100644 index 71228daad5f..00000000000 --- a/src/memory-host-sdk/host/embeddings-bedrock.test.ts +++ /dev/null @@ -1,377 +0,0 @@ -import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; - -const { defaultProviderMock, resolveCredentialsMock, sendMock } = vi.hoisted(() => ({ - defaultProviderMock: vi.fn(), - resolveCredentialsMock: vi.fn(), - sendMock: vi.fn(), -})); - -vi.mock("@aws-sdk/client-bedrock-runtime", () => { - class MockClient { - region: string; - constructor(config: { region: string }) { - this.region = config.region; - } - send = sendMock; - } - class MockCommand { - input: unknown; - constructor(input: unknown) { - this.input = input; - } - } - return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand }; -}); - -vi.mock("@aws-sdk/credential-provider-node", () => ({ - defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock), -})); - -let createBedrockEmbeddingProvider: typeof import("./embeddings-bedrock.js").createBedrockEmbeddingProvider; -let resolveBedrockEmbeddingClient: typeof import("./embeddings-bedrock.js").resolveBedrockEmbeddingClient; -let normalizeBedrockEmbeddingModel: typeof import("./embeddings-bedrock.js").normalizeBedrockEmbeddingModel; -let hasAwsCredentials: typeof import("./embeddings-bedrock.js").hasAwsCredentials; - -beforeAll(async () => { - ({ - createBedrockEmbeddingProvider, - resolveBedrockEmbeddingClient, - normalizeBedrockEmbeddingModel, - hasAwsCredentials, - } = await import("./embeddings-bedrock.js")); -}); - -beforeEach(() => { - defaultProviderMock.mockImplementation(() => resolveCredentialsMock); -}); - -const enc = (body: unknown) => ({ body: new TextEncoder().encode(JSON.stringify(body)) }); -const reqBody = (i = 0): Record => - 
JSON.parse(sendMock.mock.calls[i][0].input.body); - -describe("bedrock embedding provider", () => { - const originalEnv = process.env; - afterEach(() => { - process.env = originalEnv; - vi.restoreAllMocks(); - defaultProviderMock.mockClear(); - resolveCredentialsMock.mockReset(); - sendMock.mockReset(); - }); - - // --- Normalization --- - - it("normalizes model names with prefixes", () => { - expect(normalizeBedrockEmbeddingModel("bedrock/amazon.titan-embed-text-v2:0")).toBe( - "amazon.titan-embed-text-v2:0", - ); - expect(normalizeBedrockEmbeddingModel("amazon-bedrock/cohere.embed-english-v3")).toBe( - "cohere.embed-english-v3", - ); - expect(normalizeBedrockEmbeddingModel("")).toBe("amazon.titan-embed-text-v2:0"); - }); - - // --- Client resolution --- - - it("resolves region from env", () => { - process.env = { ...originalEnv, AWS_REGION: "eu-west-1" }; - const c = resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(c.region).toBe("eu-west-1"); - expect(c.dimensions).toBe(1024); - }); - - it("defaults to us-east-1", () => { - process.env = { ...originalEnv }; - delete process.env.AWS_REGION; - delete process.env.AWS_DEFAULT_REGION; - expect( - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }).region, - ).toBe("us-east-1"); - }); - - it("extracts region from baseUrl", () => { - process.env = { ...originalEnv }; - delete process.env.AWS_REGION; - const c = resolveBedrockEmbeddingClient({ - config: { - models: { - providers: { - "amazon-bedrock": { baseUrl: "https://bedrock-runtime.ap-southeast-2.amazonaws.com" }, - }, - }, - } as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(c.region).toBe("ap-southeast-2"); - }); - - it("validates dimensions", () => { - expect(() => - resolveBedrockEmbeddingClient({ - config: 
{} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - outputDimensionality: 768, - }), - ).toThrow("Invalid dimensions 768"); - }); - - it("accepts valid dimensions", () => { - expect( - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - outputDimensionality: 256, - }).dimensions, - ).toBe(256); - }); - - it("resolves throughput-suffixed variants", () => { - expect( - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v1:2:8k", - fallback: "none", - }).dimensions, - ).toBe(1536); - }); - - // --- Credential detection --- - - it("detects access keys", async () => { - await expect( - hasAwsCredentials({ - AWS_ACCESS_KEY_ID: "A", - AWS_SECRET_ACCESS_KEY: "s", - } as NodeJS.ProcessEnv), - ).resolves.toBe(true); - }); - it("detects profile", async () => { - await expect(hasAwsCredentials({ AWS_PROFILE: "default" } as NodeJS.ProcessEnv)).resolves.toBe( - true, - ); - }); - it("detects ECS task role", async () => { - await expect( - hasAwsCredentials({ AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: "/v2" } as NodeJS.ProcessEnv), - ).resolves.toBe(true); - }); - it("detects EKS IRSA", async () => { - await expect( - hasAwsCredentials({ - AWS_WEB_IDENTITY_TOKEN_FILE: "/var/run/secrets/token", - AWS_ROLE_ARN: "arn:aws:iam::123:role/x", - } as NodeJS.ProcessEnv), - ).resolves.toBe(true); - }); - it("detects credentials via the AWS SDK default provider chain", async () => { - resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" }); - await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(true); - expect(defaultProviderMock).toHaveBeenCalledWith({ timeout: 1000, maxRetries: 0 }); - }); - it("returns false with no creds", async () => { - resolveCredentialsMock.mockRejectedValue(new Error("no aws credentials")); - await expect(hasAwsCredentials({} as 
NodeJS.ProcessEnv)).resolves.toBe(false); - }); - - // --- Titan V2 --- - - it("embeds with Titan V2", async () => { - sendMock.mockResolvedValue(enc({ embedding: [0.1, 0.2, 0.3] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(await provider.embedQuery("test")).toHaveLength(3); - expect(reqBody()).toMatchObject({ inputText: "test", normalize: true, dimensions: 1024 }); - }); - - it("returns empty for blank text", async () => { - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(await provider.embedQuery(" ")).toEqual([]); - expect(sendMock).not.toHaveBeenCalled(); - }); - - it("batches Titan V2 concurrently", async () => { - sendMock - .mockResolvedValueOnce(enc({ embedding: [0.1] })) - .mockResolvedValueOnce(enc({ embedding: [0.2] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v2:0", - fallback: "none", - }); - expect(await provider.embedBatch(["a", "b"])).toHaveLength(2); - expect(sendMock).toHaveBeenCalledTimes(2); - }); - - // --- Titan V1 --- - - it("sends only inputText for Titan V1", async () => { - sendMock.mockResolvedValue(enc({ embedding: [0.5] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-text-v1", - fallback: "none", - }); - await provider.embedQuery("hi"); - expect(reqBody()).toEqual({ inputText: "hi" }); - }); - - it("handles Titan G1 text variant", async () => { - sendMock.mockResolvedValue(enc({ embedding: [0.1] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.titan-embed-g1-text-02", - fallback: "none", - 
}); - await provider.embedQuery("hi"); - expect(reqBody()).toEqual({ inputText: "hi" }); - }); - - // --- Cohere V3 --- - - it("embeds Cohere V3 batch in single call", async () => { - sendMock.mockResolvedValue(enc({ embeddings: [[0.1], [0.2]] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "cohere.embed-english-v3", - fallback: "none", - }); - expect(await provider.embedBatch(["a", "b"])).toHaveLength(2); - expect(sendMock).toHaveBeenCalledTimes(1); - expect(reqBody()).toMatchObject({ texts: ["a", "b"], input_type: "search_document" }); - }); - - it("uses search_query for Cohere embedQuery", async () => { - sendMock.mockResolvedValue(enc({ embeddings: [[0.1]] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "cohere.embed-english-v3", - fallback: "none", - }); - await provider.embedQuery("q"); - expect(reqBody().input_type).toBe("search_query"); - }); - - // --- Cohere V4 --- - - it("embeds Cohere V4 with embedding_types + output_dimension", async () => { - sendMock.mockResolvedValue(enc({ embeddings: { float: [[0.1], [0.2]] } })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "cohere.embed-v4:0", - fallback: "none", - }); - expect(await provider.embedBatch(["a", "b"])).toHaveLength(2); - expect(reqBody()).toMatchObject({ embedding_types: ["float"], output_dimension: 1536 }); - }); - - it("validates Cohere V4 dimensions", () => { - expect(() => - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "cohere.embed-v4:0", - fallback: "none", - outputDimensionality: 2048, - }), - ).toThrow("Invalid dimensions 2048"); - }); - - // --- Nova --- - - it("embeds Nova with SINGLE_EMBEDDING format", async () => { - sendMock.mockResolvedValue( - enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1, 0.2] }] }), - ); - const 
{ provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.nova-2-multimodal-embeddings-v1:0", - fallback: "none", - }); - expect(await provider.embedQuery("hi")).toHaveLength(2); - expect(reqBody().taskType).toBe("SINGLE_EMBEDDING"); - }); - - it("validates Nova dimensions", () => { - expect(() => - resolveBedrockEmbeddingClient({ - config: {} as never, - provider: "bedrock", - model: "amazon.nova-2-multimodal-embeddings-v1:0", - fallback: "none", - outputDimensionality: 512, - }), - ).toThrow("Invalid dimensions 512"); - }); - - it("batches Nova concurrently", async () => { - sendMock - .mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1] }] })) - .mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.2] }] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "amazon.nova-2-multimodal-embeddings-v1:0", - fallback: "none", - }); - expect(await provider.embedBatch(["a", "b"])).toHaveLength(2); - expect(sendMock).toHaveBeenCalledTimes(2); - }); - - // --- TwelveLabs --- - - it("embeds TwelveLabs Marengo", async () => { - sendMock.mockResolvedValue(enc({ data: [{ embedding: [0.1, 0.2] }] })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "twelvelabs.marengo-embed-3-0-v1:0", - fallback: "none", - }); - expect(await provider.embedQuery("hi")).toHaveLength(2); - expect(reqBody()).toEqual({ inputType: "text", text: { inputText: "hi" } }); - }); - - it("embeds TwelveLabs object-style responses", async () => { - sendMock.mockResolvedValue(enc({ data: { embedding: [0.3, 0.4] } })); - const { provider } = await createBedrockEmbeddingProvider({ - config: {} as never, - provider: "bedrock", - model: "twelvelabs.marengo-embed-2-7-v1:0", - fallback: "none", - }); - expect(await provider.embedQuery("hi")).toEqual([0.6, 0.8]); - 
}); -}); diff --git a/src/memory-host-sdk/host/embeddings-gemini-request.ts b/src/memory-host-sdk/host/embeddings-gemini-request.ts deleted file mode 100644 index 038843fe207..00000000000 --- a/src/memory-host-sdk/host/embeddings-gemini-request.ts +++ /dev/null @@ -1,115 +0,0 @@ -import type { EmbeddingInput } from "./embedding-inputs.js"; -import type { GeminiTaskType } from "./embeddings.types.js"; - -export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001"; - -export const GEMINI_EMBEDDING_2_MODELS = new Set([ - "gemini-embedding-2-preview", - // Add the GA model name here once released. -]); - -const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072; -const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const; - -export type { GeminiTaskType } from "./embeddings.types.js"; - -export type GeminiTextPart = { text: string }; -export type GeminiInlinePart = { - inlineData: { mimeType: string; data: string }; -}; -export type GeminiPart = GeminiTextPart | GeminiInlinePart; -export type GeminiEmbeddingRequest = { - content: { parts: GeminiPart[] }; - taskType: GeminiTaskType; - outputDimensionality?: number; - model?: string; -}; -export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest; - -/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. 
*/ -export function buildGeminiTextEmbeddingRequest(params: { - text: string; - taskType: GeminiTaskType; - outputDimensionality?: number; - modelPath?: string; -}): GeminiTextEmbeddingRequest { - return buildGeminiEmbeddingRequest({ - input: { text: params.text }, - taskType: params.taskType, - outputDimensionality: params.outputDimensionality, - modelPath: params.modelPath, - }); -} - -export function buildGeminiEmbeddingRequest(params: { - input: EmbeddingInput; - taskType: GeminiTaskType; - outputDimensionality?: number; - modelPath?: string; -}): GeminiEmbeddingRequest { - const request: GeminiEmbeddingRequest = { - content: { - parts: params.input.parts?.map((part) => - part.type === "text" - ? ({ text: part.text } satisfies GeminiTextPart) - : ({ - inlineData: { mimeType: part.mimeType, data: part.data }, - } satisfies GeminiInlinePart), - ) ?? [{ text: params.input.text }], - }, - taskType: params.taskType, - }; - if (params.modelPath) { - request.model = params.modelPath; - } - if (params.outputDimensionality != null) { - request.outputDimensionality = params.outputDimensionality; - } - return request; -} - -/** - * Returns true if the given model name is a gemini-embedding-2 variant that - * supports `outputDimensionality` and extended task types. - */ -export function isGeminiEmbedding2Model(model: string): boolean { - return GEMINI_EMBEDDING_2_MODELS.has(model); -} - -/** - * Validate and return the `outputDimensionality` for gemini-embedding-2 models. - * Returns `undefined` for older models (they don't support the param). 
- */ -export function resolveGeminiOutputDimensionality( - model: string, - requested?: number, -): number | undefined { - if (!isGeminiEmbedding2Model(model)) { - return undefined; - } - if (requested == null) { - return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS; - } - const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS; - if (!valid.includes(requested)) { - throw new Error( - `Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`, - ); - } - return requested; -} - -export function normalizeGeminiModel(model: string): string { - const trimmed = model.trim(); - if (!trimmed) { - return DEFAULT_GEMINI_EMBEDDING_MODEL; - } - const withoutPrefix = trimmed.replace(/^models\//, ""); - if (withoutPrefix.startsWith("gemini/")) { - return withoutPrefix.slice("gemini/".length); - } - if (withoutPrefix.startsWith("google/")) { - return withoutPrefix.slice("google/".length); - } - return withoutPrefix; -} diff --git a/src/memory-host-sdk/host/embeddings-github-copilot.test.ts b/src/memory-host-sdk/host/embeddings-github-copilot.test.ts deleted file mode 100644 index 0f3904b1555..00000000000 --- a/src/memory-host-sdk/host/embeddings-github-copilot.test.ts +++ /dev/null @@ -1,178 +0,0 @@ -import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; - -const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn()); -const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn()); - -vi.mock("../../agents/github-copilot-token.js", () => ({ - DEFAULT_COPILOT_API_BASE_URL: "https://api.githubcopilot.test", - resolveCopilotApiToken: resolveCopilotApiTokenMock, -})); - -vi.mock("../../infra/net/fetch-guard.js", () => ({ - fetchWithSsrFGuard: fetchWithSsrFGuardMock, -})); - -import { createGitHubCopilotEmbeddingProvider } from "./embeddings-github-copilot.js"; - -function mockFetchResponse(spec: { ok: boolean; status?: number; json?: unknown; text?: string }) { - fetchWithSsrFGuardMock.mockImplementationOnce(async () => ({ - 
response: { - ok: spec.ok, - status: spec.status ?? (spec.ok ? 200 : 500), - json: async () => spec.json, - text: async () => spec.text ?? "", - }, - release: vi.fn(async () => {}), - })); -} - -describe("createGitHubCopilotEmbeddingProvider", () => { - beforeEach(() => { - resolveCopilotApiTokenMock.mockResolvedValue({ - token: "copilot-token-a", - expiresAt: Date.now() + 3_600_000, - source: "test", - baseUrl: "https://api.githubcopilot.test", - }); - }); - - afterEach(() => { - vi.restoreAllMocks(); - resolveCopilotApiTokenMock.mockReset(); - fetchWithSsrFGuardMock.mockReset(); - }); - - it("normalizes embeddings returned for queries", async () => { - mockFetchResponse({ - ok: true, - json: { - data: [{ index: 0, embedding: [3, 4] }], - }, - }); - - const { provider } = await createGitHubCopilotEmbeddingProvider({ - githubToken: "gh_test", - model: "text-embedding-3-small", - }); - - await expect(provider.embedQuery("hello")).resolves.toEqual([0.6, 0.8]); - expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith( - expect.objectContaining({ - url: "https://api.githubcopilot.test/embeddings", - }), - ); - }); - - it("preserves input order by explicit response index", async () => { - mockFetchResponse({ - ok: true, - json: { - data: [ - { index: 1, embedding: [0, 2] }, - { index: 0, embedding: [1, 0] }, - ], - }, - }); - - const { provider } = await createGitHubCopilotEmbeddingProvider({ - githubToken: "gh_test", - model: "text-embedding-3-small", - }); - - await expect(provider.embedBatch(["first", "second"])).resolves.toEqual([ - [1, 0], - [0, 1], - ]); - }); - - it("uses a fresh Copilot token for later requests", async () => { - resolveCopilotApiTokenMock - .mockResolvedValueOnce({ - token: "copilot-token-create", - expiresAt: Date.now() + 3_600_000, - source: "test", - baseUrl: "https://api.githubcopilot.test", - }) - .mockResolvedValueOnce({ - token: "copilot-token-first", - expiresAt: Date.now() + 3_600_000, - source: "test", - baseUrl: 
"https://api.githubcopilot.test", - }) - .mockResolvedValueOnce({ - token: "copilot-token-second", - expiresAt: Date.now() + 3_600_000, - source: "test", - baseUrl: "https://api.githubcopilot.test", - }); - mockFetchResponse({ - ok: true, - json: { data: [{ index: 0, embedding: [1, 0] }] }, - }); - mockFetchResponse({ - ok: true, - json: { data: [{ index: 0, embedding: [0, 1] }] }, - }); - - const { provider } = await createGitHubCopilotEmbeddingProvider({ - githubToken: "gh_test", - model: "text-embedding-3-small", - }); - - await provider.embedQuery("first"); - await provider.embedQuery("second"); - - const firstHeaders = fetchWithSsrFGuardMock.mock.calls[0]?.[0]?.init?.headers as Record< - string, - string - >; - const secondHeaders = fetchWithSsrFGuardMock.mock.calls[1]?.[0]?.init?.headers as Record< - string, - string - >; - expect(firstHeaders.Authorization).toBe("Bearer copilot-token-first"); - expect(secondHeaders.Authorization).toBe("Bearer copilot-token-second"); - }); - - it("honors custom baseUrl and header overrides", async () => { - mockFetchResponse({ - ok: true, - json: { data: [{ index: 0, embedding: [1, 0] }] }, - }); - - const { provider } = await createGitHubCopilotEmbeddingProvider({ - githubToken: "gh_test", - model: "text-embedding-3-small", - baseUrl: "https://proxy.example/v1", - headers: { "X-Proxy-Token": "proxy" }, - }); - - await provider.embedQuery("hello"); - - const call = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as { - init: { headers: Record }; - url: string; - }; - expect(call.url).toBe("https://proxy.example/v1/embeddings"); - expect(call.init.headers["X-Proxy-Token"]).toBe("proxy"); - expect(call.init.headers.Authorization).toBe("Bearer copilot-token-a"); - }); - - it("fails fast on sparse or malformed embedding payloads", async () => { - mockFetchResponse({ - ok: true, - json: { - data: [{ index: 1, embedding: [1, 0] }], - }, - }); - - const { provider } = await createGitHubCopilotEmbeddingProvider({ - githubToken: "gh_test", 
- model: "text-embedding-3-small", - }); - - await expect(provider.embedBatch(["first", "second"])).rejects.toThrow( - "GitHub Copilot embeddings response missing vectors for some inputs", - ); - }); -}); diff --git a/src/memory-host-sdk/host/embeddings-github-copilot.ts b/src/memory-host-sdk/host/embeddings-github-copilot.ts deleted file mode 100644 index 246b764abb9..00000000000 --- a/src/memory-host-sdk/host/embeddings-github-copilot.ts +++ /dev/null @@ -1,151 +0,0 @@ -import { - DEFAULT_COPILOT_API_BASE_URL, - resolveCopilotApiToken, -} from "../../agents/github-copilot-token.js"; -import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import type { EmbeddingProvider } from "./embeddings.types.js"; -import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js"; - -export type GitHubCopilotEmbeddingClient = { - githubToken: string; - model: string; - baseUrl?: string; - headers?: Record; - env?: NodeJS.ProcessEnv; - fetchImpl?: typeof fetch; -}; - -const COPILOT_EMBEDDING_PROVIDER_ID = "github-copilot"; - -const COPILOT_HEADERS_STATIC: Record = { - "Content-Type": "application/json", - "Editor-Version": "vscode/1.96.2", - "User-Agent": "GitHubCopilotChat/0.26.7", -}; - -function resolveConfiguredBaseUrl( - configuredBaseUrl: string | undefined, - tokenBaseUrl: string | undefined, -): string { - const trimmed = configuredBaseUrl?.trim(); - if (trimmed) { - return trimmed; - } - return tokenBaseUrl || DEFAULT_COPILOT_API_BASE_URL; -} - -async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{ - baseUrl: string; - headers: Record; -}> { - const token = await resolveCopilotApiToken({ - githubToken: client.githubToken, - env: client.env, - fetchImpl: client.fetchImpl, - }); - const baseUrl = resolveConfiguredBaseUrl(client.baseUrl, token.baseUrl); - return { - baseUrl, - headers: { - ...COPILOT_HEADERS_STATIC, - ...client.headers, - Authorization: `Bearer ${token.token}`, - }, - }; -} 
- -function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] { - if (!payload || typeof payload !== "object") { - throw new Error("GitHub Copilot embeddings response missing data[]"); - } - const data = (payload as { data?: unknown }).data; - if (!Array.isArray(data)) { - throw new Error("GitHub Copilot embeddings response missing data[]"); - } - - const vectors = Array.from({ length: expectedCount }); - for (const entry of data) { - if (!entry || typeof entry !== "object") { - throw new Error("GitHub Copilot embeddings response contains an invalid entry"); - } - const indexValue = (entry as { index?: unknown }).index; - const embedding = (entry as { embedding?: unknown }).embedding; - const index = typeof indexValue === "number" ? indexValue : Number.NaN; - if (!Number.isInteger(index)) { - throw new Error("GitHub Copilot embeddings response contains an invalid index"); - } - if (index < 0 || index >= expectedCount) { - throw new Error("GitHub Copilot embeddings response contains an out-of-range index"); - } - if (vectors[index] !== undefined) { - throw new Error("GitHub Copilot embeddings response contains duplicate indexes"); - } - if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) { - throw new Error("GitHub Copilot embeddings response contains an invalid embedding"); - } - vectors[index] = sanitizeAndNormalizeEmbedding(embedding); - } - - for (let index = 0; index < expectedCount; index += 1) { - if (vectors[index] === undefined) { - throw new Error("GitHub Copilot embeddings response missing vectors for some inputs"); - } - } - return vectors as number[][]; -} - -export async function createGitHubCopilotEmbeddingProvider( - client: GitHubCopilotEmbeddingClient, -): Promise<{ provider: EmbeddingProvider; client: GitHubCopilotEmbeddingClient }> { - const initialSession = await resolveGitHubCopilotEmbeddingSession(client); - - const embed = async (input: string[]): Promise => { - if 
(input.length === 0) { - return []; - } - - const session = await resolveGitHubCopilotEmbeddingSession(client); - const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`; - return await withRemoteHttpResponse({ - url, - fetchImpl: client.fetchImpl, - ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl), - init: { - method: "POST", - headers: session.headers, - body: JSON.stringify({ model: client.model, input }), - }, - onResponse: async (response) => { - if (!response.ok) { - throw new Error( - `GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`, - ); - } - - let payload: unknown; - try { - payload = await response.json(); - } catch { - throw new Error("GitHub Copilot embeddings returned invalid JSON"); - } - return parseGitHubCopilotEmbeddingPayload(payload, input.length); - }, - }); - }; - - return { - provider: { - id: COPILOT_EMBEDDING_PROVIDER_ID, - model: client.model, - embedQuery: async (text) => { - const [vector] = await embed([text]); - return vector ?? []; - }, - embedBatch: embed, - }, - client: { - ...client, - baseUrl: initialSession.baseUrl, - }, - }; -} diff --git a/src/memory-host-sdk/host/embeddings-lmstudio.test.ts b/src/memory-host-sdk/host/embeddings-lmstudio.test.ts deleted file mode 100644 index f953daba36f..00000000000 --- a/src/memory-host-sdk/host/embeddings-lmstudio.test.ts +++ /dev/null @@ -1,387 +0,0 @@ -import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; -import type { OpenClawConfig } from "../../config/config.js"; - -const ensureLmstudioModelLoadedMock = vi.hoisted(() => vi.fn()); -const resolveLmstudioRuntimeApiKeyMock = vi.hoisted(() => vi.fn()); - -vi.mock("../../plugin-sdk/lmstudio-runtime.js", () => ({ - buildLmstudioAuthHeaders: ({ - apiKey, - json, - headers, - }: { - apiKey?: string; - json?: boolean; - headers?: Record; - }) => ({ - ...(json ? { "Content-Type": "application/json" } : {}), - ...(apiKey ? 
{ Authorization: `Bearer ${apiKey}` } : {}), - ...headers, - }), - ensureLmstudioModelLoaded: (...args: unknown[]) => ensureLmstudioModelLoadedMock(...args), - LMSTUDIO_DEFAULT_EMBEDDING_MODEL: "text-embedding-nomic-embed-text-v1.5", - LMSTUDIO_PROVIDER_ID: "lmstudio", - resolveLmstudioInferenceBase: (baseUrl?: string) => { - const normalized = (baseUrl || "http://localhost:1234").replace(/\/+$/u, ""); - if (normalized.endsWith("/api/v1")) { - return normalized.slice(0, -"/api/v1".length) + "/v1"; - } - if (normalized.endsWith("/v1")) { - return normalized; - } - return `${normalized}/v1`; - }, - resolveLmstudioProviderHeaders: ({ headers }: { headers?: Record }) => - headers ?? {}, - resolveLmstudioRuntimeApiKey: (...args: unknown[]) => resolveLmstudioRuntimeApiKeyMock(...args), -})); - -let createLmstudioEmbeddingProvider: typeof import("./embeddings-lmstudio.js").createLmstudioEmbeddingProvider; - -describe("embeddings-lmstudio", () => { - const originalFetch = globalThis.fetch; - const jsonResponse = (embedding: number[]) => - new Response( - JSON.stringify({ - data: [{ embedding }], - }), - { - status: 200, - headers: { "content-type": "application/json" }, - }, - ); - - function mockEmbeddingFetch(embedding: number[]) { - const fetchMock = vi.fn(); - fetchMock.mockResolvedValue(jsonResponse(embedding)); - globalThis.fetch = fetchMock as unknown as typeof fetch; - return fetchMock; - } - - beforeAll(async () => { - ({ createLmstudioEmbeddingProvider } = await import("./embeddings-lmstudio.js")); - }); - - beforeEach(() => { - ensureLmstudioModelLoadedMock.mockReset(); - resolveLmstudioRuntimeApiKeyMock.mockReset(); - }); - - afterEach(() => { - globalThis.fetch = originalFetch; - }); - - it("embeds against inference base and warms model with resolved key", async () => { - ensureLmstudioModelLoadedMock.mockResolvedValue(undefined); - resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-lmstudio-key"); - - const fetchMock = mockEmbeddingFetch([0.1, 0.2]); 
- - const { provider } = await createLmstudioEmbeddingProvider({ - config: { - models: { - providers: { - lmstudio: { - baseUrl: "http://localhost:1234/api/v1/", - headers: { "X-Provider": "provider" }, - models: [], - }, - }, - }, - } as OpenClawConfig, - provider: "lmstudio", - model: "lmstudio/text-embedding-nomic-embed-text-v1.5", - fallback: "none", - }); - - await provider.embedQuery("hello"); - - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:1234/v1/embeddings", - expect.objectContaining({ - method: "POST", - headers: expect.objectContaining({ - "Content-Type": "application/json", - Authorization: "Bearer profile-lmstudio-key", - "X-Provider": "provider", - }), - }), - ); - expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({ - baseUrl: "http://localhost:1234/v1", - apiKey: "profile-lmstudio-key", - headers: { - "X-Provider": "provider", - }, - ssrfPolicy: { allowedHostnames: ["localhost"] }, - modelKey: "text-embedding-nomic-embed-text-v1.5", - timeoutMs: 120_000, - }); - }); - - it("uses memorySearch remote overrides for primary lmstudio", async () => { - ensureLmstudioModelLoadedMock.mockResolvedValue(undefined); - resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key"); - - const fetchMock = mockEmbeddingFetch([1, 2, 3]); - - const { provider } = await createLmstudioEmbeddingProvider({ - config: { - models: { - providers: { - lmstudio: { - baseUrl: "http://localhost:1234", - headers: { - "X-Provider": "provider", - "X-Config-Only": "from-provider", - }, - models: [], - }, - }, - }, - } as OpenClawConfig, - provider: "lmstudio", - model: "", - fallback: "none", - remote: { - baseUrl: "http://localhost:9999", - apiKey: "remote-lmstudio-key", - headers: { - "X-Provider": "remote", - "X-Remote-Only": "from-remote", - }, - }, - }); - - await provider.embedBatch(["one", "two"]); - - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:9999/v1/embeddings", - expect.objectContaining({ - headers: expect.objectContaining({ - 
Authorization: "Bearer remote-lmstudio-key", - "X-Provider": "remote", - "X-Config-Only": "from-provider", - "X-Remote-Only": "from-remote", - }), - }), - ); - expect(resolveLmstudioRuntimeApiKeyMock).not.toHaveBeenCalled(); - expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({ - baseUrl: "http://localhost:9999/v1", - apiKey: "remote-lmstudio-key", - headers: { - "X-Provider": "remote", - "X-Config-Only": "from-provider", - "X-Remote-Only": "from-remote", - }, - ssrfPolicy: { allowedHostnames: ["localhost"] }, - modelKey: "text-embedding-nomic-embed-text-v1.5", - timeoutMs: 120_000, - }); - }); - - it("preserves remote Authorization header auth for primary lmstudio", async () => { - ensureLmstudioModelLoadedMock.mockResolvedValue(undefined); - resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("stale-profile-key"); - - const fetchMock = mockEmbeddingFetch([1, 2, 3]); - - const { provider } = await createLmstudioEmbeddingProvider({ - config: { - models: { - providers: { - lmstudio: { - baseUrl: "http://localhost:1234", - headers: { - "X-Provider": "provider", - }, - models: [], - }, - }, - }, - } as OpenClawConfig, - provider: "lmstudio", - model: "", - fallback: "none", - remote: { - baseUrl: "http://localhost:9999", - headers: { - Authorization: "Bearer remote-proxy-token", - "X-Remote-Only": "from-remote", - }, - }, - }); - - await provider.embedBatch(["one", "two"]); - - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:9999/v1/embeddings", - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: "Bearer remote-proxy-token", - "X-Provider": "provider", - "X-Remote-Only": "from-remote", - }), - }), - ); - expect(resolveLmstudioRuntimeApiKeyMock).not.toHaveBeenCalled(); - expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({ - baseUrl: "http://localhost:9999/v1", - apiKey: undefined, - headers: { - "X-Provider": "provider", - Authorization: "Bearer remote-proxy-token", - "X-Remote-Only": "from-remote", - }, - 
ssrfPolicy: { allowedHostnames: ["localhost"] }, - modelKey: "text-embedding-nomic-embed-text-v1.5", - timeoutMs: 120_000, - }); - }); - - it("ignores memorySearch remote overrides for fallback lmstudio activation", async () => { - ensureLmstudioModelLoadedMock.mockResolvedValue(undefined); - resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key"); - - const fetchMock = mockEmbeddingFetch([1, 2, 3]); - - const { provider } = await createLmstudioEmbeddingProvider({ - config: { - models: { - providers: { - lmstudio: { - baseUrl: "http://localhost:1234", - headers: { - "X-Provider": "provider", - "X-Config-Only": "from-provider", - }, - models: [], - }, - }, - }, - } as OpenClawConfig, - provider: "openai", - model: "", - fallback: "lmstudio", - remote: { - baseUrl: "http://localhost:9999", - apiKey: "remote-lmstudio-key", - headers: { - "X-Provider": "remote", - "X-Remote-Only": "from-remote", - }, - }, - }); - - await provider.embedBatch(["one", "two"]); - - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:1234/v1/embeddings", - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: "Bearer profile-key", - "X-Provider": "provider", - "X-Config-Only": "from-provider", - }), - }), - ); - const callHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Record; - expect(callHeaders["X-Remote-Only"]).toBeUndefined(); - expect(resolveLmstudioRuntimeApiKeyMock).toHaveBeenCalled(); - expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({ - baseUrl: "http://localhost:1234/v1", - apiKey: "profile-key", - headers: { - "X-Provider": "provider", - "X-Config-Only": "from-provider", - }, - ssrfPolicy: { allowedHostnames: ["localhost"] }, - modelKey: "text-embedding-nomic-embed-text-v1.5", - timeoutMs: 120_000, - }); - }); - - it("skips remote SecretRef resolution for fallback lmstudio activation", async () => { - ensureLmstudioModelLoadedMock.mockResolvedValue(undefined); - 
resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key"); - - const fetchMock = mockEmbeddingFetch([1, 2, 3]); - - const { provider } = await createLmstudioEmbeddingProvider({ - config: { - models: { - providers: { - lmstudio: { - baseUrl: "http://localhost:1234", - headers: { - "X-Provider": "provider", - }, - models: [], - }, - }, - }, - } as OpenClawConfig, - provider: "openai", - model: "", - fallback: "lmstudio", - remote: { - baseUrl: "http://localhost:9999", - apiKey: { source: "env", provider: "default", id: "OPENAI_API_KEY" }, - headers: { - "X-Remote-Only": "from-remote", - }, - }, - }); - - await provider.embedQuery("hello"); - - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:1234/v1/embeddings", - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: "Bearer profile-key", - "X-Provider": "provider", - }), - }), - ); - const callHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Record; - expect(callHeaders["X-Remote-Only"]).toBeUndefined(); - expect(resolveLmstudioRuntimeApiKeyMock).toHaveBeenCalled(); - }); - - it("uses env-template-backed provider api keys in embedding requests", async () => { - ensureLmstudioModelLoadedMock.mockResolvedValue(undefined); - resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("template-lmstudio-key"); - - const fetchMock = mockEmbeddingFetch([0.3, 0.4]); - - const { provider } = await createLmstudioEmbeddingProvider({ - config: { - models: { - providers: { - lmstudio: { - baseUrl: "http://localhost:1234/v1", - apiKey: "${LM_API_TOKEN}", - models: [], - }, - }, - }, - } as OpenClawConfig, - provider: "lmstudio", - model: "text-embedding-nomic-embed-text-v1.5", - fallback: "none", - }); - - await provider.embedQuery("hello"); - - expect(fetchMock).toHaveBeenCalledWith( - "http://localhost:1234/v1/embeddings", - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: "Bearer template-lmstudio-key", - }), - }), - ); - }); -}); diff --git 
a/src/memory-host-sdk/host/embeddings-mistral.test.ts b/src/memory-host-sdk/host/embeddings-mistral.test.ts deleted file mode 100644 index 7826cd35467..00000000000 --- a/src/memory-host-sdk/host/embeddings-mistral.test.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { describe, expect, it } from "vitest"; -import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js"; - -describe("normalizeMistralModel", () => { - it("returns the default model for empty values", () => { - expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL); - expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL); - }); - - it("strips the mistral/ prefix", () => { - expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed"); - expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed"); - }); - - it("keeps explicit non-prefixed models", () => { - expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed"); - expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2"); - }); -}); diff --git a/src/memory-host-sdk/host/embeddings-ollama.test.ts b/src/memory-host-sdk/host/embeddings-ollama.test.ts deleted file mode 100644 index 3b601fba4f0..00000000000 --- a/src/memory-host-sdk/host/embeddings-ollama.test.ts +++ /dev/null @@ -1,43 +0,0 @@ -import { beforeEach, describe, expect, it, vi } from "vitest"; - -const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({ - createOllamaEmbeddingProviderMock: vi.fn(async (options: unknown) => ({ - provider: { source: "mock-provider", options }, - client: { source: "mock-client" }, - })), -})); - -vi.mock("../../plugin-sdk/ollama-runtime.js", () => ({ - DEFAULT_OLLAMA_EMBEDDING_MODEL: "nomic-embed-text", - createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock, -})); - -describe("embeddings-ollama facade", () => { - beforeEach(() => { - createOllamaEmbeddingProviderMock.mockClear(); - }); - - it("re-exports the default Ollama 
embedding model", async () => { - const mod = await import("./embeddings-ollama.js"); - expect(mod.DEFAULT_OLLAMA_EMBEDDING_MODEL).toBe("nomic-embed-text"); - }); - - it("delegates provider creation to the plugin-sdk runtime facade", async () => { - const mod = await import("./embeddings-ollama.js"); - const options = { - provider: "ollama", - model: "nomic-embed-text", - fallback: "none", - config: {}, - }; - - const result = await mod.createOllamaEmbeddingProvider(options as never); - - expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledTimes(1); - expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledWith(options); - expect(result).toEqual({ - provider: { source: "mock-provider", options }, - client: { source: "mock-client" }, - }); - }); -}); diff --git a/src/memory-host-sdk/host/embeddings-ollama.ts b/src/memory-host-sdk/host/embeddings-ollama.ts deleted file mode 100644 index 61af79c7330..00000000000 --- a/src/memory-host-sdk/host/embeddings-ollama.ts +++ /dev/null @@ -1,5 +0,0 @@ -export type { OllamaEmbeddingClient } from "../../plugin-sdk/ollama-runtime.js"; -export { - createOllamaEmbeddingProvider, - DEFAULT_OLLAMA_EMBEDDING_MODEL, -} from "../../plugin-sdk/ollama-runtime.js"; diff --git a/src/memory-host-sdk/host/embeddings-openai.test.ts b/src/memory-host-sdk/host/embeddings-openai.test.ts deleted file mode 100644 index 7749afb6271..00000000000 --- a/src/memory-host-sdk/host/embeddings-openai.test.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { describe, expect, it } from "vitest"; -import { DEFAULT_OPENAI_EMBEDDING_MODEL, normalizeOpenAiModel } from "./embeddings-openai.js"; - -describe("normalizeOpenAiModel", () => { - it("returns the default model when input is blank", () => { - expect(normalizeOpenAiModel("")).toBe(DEFAULT_OPENAI_EMBEDDING_MODEL); - expect(normalizeOpenAiModel(" ")).toBe(DEFAULT_OPENAI_EMBEDDING_MODEL); - }); - - it("strips the openai/ prefix", () => { - 
expect(normalizeOpenAiModel("openai/text-embedding-3-small")).toBe("text-embedding-3-small"); - expect(normalizeOpenAiModel("openai/text-embedding-ada-002")).toBe("text-embedding-ada-002"); - }); - - it("preserves explicit third-party provider prefixes", () => { - expect(normalizeOpenAiModel("spark/text-embedding-3-small")).toBe( - "spark/text-embedding-3-small", - ); - expect(normalizeOpenAiModel("litellm/azure/ada-002")).toBe("litellm/azure/ada-002"); - }); - - it("preserves unprefixed model ids", () => { - expect(normalizeOpenAiModel("text-embedding-3-large")).toBe("text-embedding-3-large"); - }); -}); diff --git a/src/memory-host-sdk/host/embeddings-provider.test-support.ts b/src/memory-host-sdk/host/embeddings-provider.test-support.ts deleted file mode 100644 index bb2c09a20f5..00000000000 --- a/src/memory-host-sdk/host/embeddings-provider.test-support.ts +++ /dev/null @@ -1,88 +0,0 @@ -import { vi } from "vitest"; -import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js"; - -vi.mock("../../infra/net/fetch-guard.js", () => ({ - fetchWithSsrFGuard: async (params: { - url: string; - init?: RequestInit; - fetchImpl?: typeof fetch; - }) => { - const fetchImpl = params.fetchImpl ?? 
globalThis.fetch; - if (!fetchImpl) { - throw new Error("fetch is not available"); - } - const response = await fetchImpl(params.url, params.init); - return { - response, - finalUrl: params.url, - release: async () => {}, - }; - }, -})); - -type FetchPayloadFactory = (input: RequestInfo | URL, init?: RequestInit) => unknown; -type JsonResponseFetchMock = ReturnType> & { - preconnect: ( - url: string | URL, - options?: { dns?: boolean; tcp?: boolean; http?: boolean; https?: boolean }, - ) => void; - __openclawAcceptsDispatcher: true; -}; - -export type JsonFetchMock = ReturnType; - -export function createJsonResponseFetchMock(payload: FetchPayloadFactory): JsonResponseFetchMock; -export function createJsonResponseFetchMock(payload: unknown): JsonResponseFetchMock; -export function createJsonResponseFetchMock(payload: unknown) { - const fetchMock = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => { - const body = - typeof payload === "function" ? (payload as FetchPayloadFactory)(input, init) : payload; - const serialized = JSON.stringify(body); - return { - ok: true, - status: 200, - json: async () => body, - text: async () => serialized, - } as Response; - }); - return withFetchPreconnect(fetchMock) as JsonResponseFetchMock; -} - -export function createEmbeddingDataFetchMock(embeddingValues = [0.1, 0.2, 0.3]) { - return createJsonResponseFetchMock({ data: [{ embedding: embeddingValues }] }); -} - -export function createGeminiFetchMock(embeddingValues = [1, 2, 3]) { - return createJsonResponseFetchMock({ embedding: { values: embeddingValues } }); -} - -export function createGeminiBatchFetchMock(count: number, embeddingValues = [1, 2, 3]) { - return createJsonResponseFetchMock({ - embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })), - }); -} - -export function installFetchMock(fetchMock: typeof globalThis.fetch) { - vi.stubGlobal("fetch", fetchMock); -} - -export function readFirstFetchRequest(fetchMock: { mock: { calls: 
unknown[][] } }) { - const [url, init] = fetchMock.mock.calls[0] ?? []; - return { url, init: init as RequestInit | undefined }; -} - -export function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) { - const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined; - return JSON.parse((init?.body as string) ?? "{}") as Record; -} - -export function mockResolvedProviderKey( - resolveApiKeyForProvider: typeof import("../../agents/model-auth.js").resolveApiKeyForProvider, - apiKey = "test-key", -) { - vi.mocked(resolveApiKeyForProvider).mockResolvedValue({ - apiKey, - mode: "api-key", - source: "test", - }); -} diff --git a/src/memory-host-sdk/host/embeddings-remote-client.ts b/src/memory-host-sdk/host/embeddings-remote-client.ts index a8e8fd1d495..77f11479f65 100644 --- a/src/memory-host-sdk/host/embeddings-remote-client.ts +++ b/src/memory-host-sdk/host/embeddings-remote-client.ts @@ -5,7 +5,7 @@ import type { EmbeddingProviderOptions } from "./embeddings.types.js"; import { buildRemoteBaseUrlPolicy } from "./remote-http.js"; import { resolveMemorySecretInputString } from "./secret-input.js"; -export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral"; +export type RemoteEmbeddingProviderId = string; export async function resolveRemoteEmbeddingBearerClient(params: { provider: RemoteEmbeddingProviderId; diff --git a/src/memory-host-sdk/host/embeddings-voyage.test.ts b/src/memory-host-sdk/host/embeddings-voyage.test.ts deleted file mode 100644 index 9122861513e..00000000000 --- a/src/memory-host-sdk/host/embeddings-voyage.test.ts +++ /dev/null @@ -1,141 +0,0 @@ -import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; -import * as authModule from "../../agents/model-auth.js"; -import { - createEmbeddingDataFetchMock, - createJsonResponseFetchMock, - installFetchMock, - mockResolvedProviderKey, - type JsonFetchMock, -} from "./embeddings-provider.test-support.js"; -import { 
mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; - -const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(), -})); - -vi.mock("../../agents/model-auth.js", () => { - return { - resolveApiKeyForProvider: resolveApiKeyForProviderMock, - requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => { - if (auth.apiKey) { - return auth.apiKey; - } - throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`); - }, - }; -}); - -let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider; -let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel; - -beforeAll(async () => { - ({ createVoyageEmbeddingProvider, normalizeVoyageModel } = - await import("./embeddings-voyage.js")); -}); - -beforeEach(() => { - vi.useRealTimers(); - vi.doUnmock("undici"); -}); - -async function createDefaultVoyageProvider(model: string, fetchMock: JsonFetchMock) { - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - mockResolvedProviderKey(authModule.resolveApiKeyForProvider, "voyage-key-123"); - return createVoyageEmbeddingProvider({ - config: {} as never, - provider: "voyage", - model, - fallback: "none", - }); -} - -describe("voyage embedding provider", () => { - afterEach(() => { - vi.doUnmock("undici"); - vi.resetAllMocks(); - vi.unstubAllGlobals(); - }); - - it("configures client with correct defaults and headers", async () => { - const fetchMock = createEmbeddingDataFetchMock(); - const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock); - - await result.provider.embedQuery("test query"); - - expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith( - expect.objectContaining({ provider: "voyage" }), - ); - - const call = fetchMock.mock.calls[0]; - expect(call).toBeDefined(); - const [url, init] = call as [RequestInfo | URL, RequestInit | 
undefined]; - expect(url).toBe("https://api.voyageai.com/v1/embeddings"); - - const headers = (init?.headers ?? {}) as Record; - expect(headers.Authorization).toBe("Bearer voyage-key-123"); - expect(headers["Content-Type"]).toBe("application/json"); - - const body = JSON.parse(init?.body as string); - expect(body).toEqual({ - model: "voyage-4-large", - input: ["test query"], - input_type: "query", - }); - }); - - it("respects remote overrides for baseUrl and apiKey", async () => { - const fetchMock = createEmbeddingDataFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - - const result = await createVoyageEmbeddingProvider({ - config: {} as never, - provider: "voyage", - model: "voyage-4-lite", - fallback: "none", - remote: { - baseUrl: "https://example.com", - apiKey: "remote-override-key", - headers: { "X-Custom": "123" }, - }, - }); - - await result.provider.embedQuery("test"); - - const call = fetchMock.mock.calls[0]; - expect(call).toBeDefined(); - const [url, init] = call as [RequestInfo | URL, RequestInit | undefined]; - expect(url).toBe("https://example.com/embeddings"); - - const headers = (init?.headers ?? 
{}) as Record; - expect(headers.Authorization).toBe("Bearer remote-override-key"); - expect(headers["X-Custom"]).toBe("123"); - }); - - it("passes input_type=document for embedBatch", async () => { - const fetchMock = createJsonResponseFetchMock({ - data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }], - }); - const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock); - - await result.provider.embedBatch(["doc1", "doc2"]); - - const call = fetchMock.mock.calls[0]; - expect(call).toBeDefined(); - const [, init] = call as [RequestInfo | URL, RequestInit | undefined]; - const body = JSON.parse(init?.body as string); - expect(body).toEqual({ - model: "voyage-4-large", - input: ["doc1", "doc2"], - input_type: "document", - }); - }); - - it("normalizes model names", async () => { - expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2"); - expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large"); - expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite"); - expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default - }); -}); diff --git a/src/memory-host-sdk/host/embeddings.test.ts b/src/memory-host-sdk/host/embeddings.test.ts index affcfef00c4..a6073c413e7 100644 --- a/src/memory-host-sdk/host/embeddings.test.ts +++ b/src/memory-host-sdk/host/embeddings.test.ts @@ -1,768 +1,67 @@ -import { setTimeout as sleep } from "node:timers/promises"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; -import * as authModule from "../../agents/model-auth.js"; -import { - createEmbeddingDataFetchMock, - createGeminiFetchMock, - installFetchMock, - mockResolvedProviderKey as mockResolvedProviderKeyBase, - readFirstFetchRequest, -} from "./embeddings-provider.test-support.js"; -import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js"; +import { createLocalEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js"; import * as nodeLlamaModule from 
"./node-llama.js"; -import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; - -const { - bedrockSendMock, - createOllamaEmbeddingProviderMock, - defaultProviderMock, - resolveApiKeyForProviderMock, - resolveCredentialsMock, -} = vi.hoisted(() => ({ - bedrockSendMock: vi.fn(), - createOllamaEmbeddingProviderMock: vi.fn(async () => { - throw new Error("Unexpected ollama provider in embeddings.test.ts"); - }), - defaultProviderMock: vi.fn(), - resolveCredentialsMock: vi.fn(), - resolveApiKeyForProviderMock: vi.fn(), -})); - -vi.mock("../../agents/model-auth.js", () => { - return { - resolveApiKeyForProvider: resolveApiKeyForProviderMock, - requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => { - if (auth.apiKey) { - return auth.apiKey; - } - throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`); - }, - }; -}); - -vi.mock("./embeddings-ollama.js", () => ({ - createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock, -})); - -vi.mock("@aws-sdk/client-bedrock-runtime", () => { - class MockClient { - send = bedrockSendMock; - } - class MockCommand { - input: unknown; - constructor(input: unknown) { - this.input = input; - } - } - return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand }; -}); - -vi.mock("@aws-sdk/credential-provider-node", () => ({ - defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock), -})); - -const createFetchMock = () => createEmbeddingDataFetchMock([1, 2, 3]); - -type ResolvedProviderAuth = Awaited>; beforeEach(() => { - vi.spyOn(authModule, "resolveApiKeyForProvider"); vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp"); - defaultProviderMock.mockImplementation(() => resolveCredentialsMock); -}); - -beforeEach(() => { - vi.useRealTimers(); }); afterEach(() => { vi.resetAllMocks(); - vi.unstubAllGlobals(); }); -function requireProvider(result: Awaited>) { - if (!result.provider) { - throw new Error("Expected embedding 
provider"); - } - return result.provider; +function mockLocalEmbeddingRuntime(vector = new Float32Array([2.35, 3.45, 0.63, 4.3])) { + const getEmbeddingFor = vi.fn().mockResolvedValue({ vector }); + const createEmbeddingContext = vi.fn().mockResolvedValue({ getEmbeddingFor }); + const loadModel = vi.fn().mockResolvedValue({ createEmbeddingContext }); + const resolveModelFile = vi.fn(async (modelPath: string) => `/resolved/${modelPath}`); + + vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({ + getLlama: async () => ({ loadModel }), + resolveModelFile, + LlamaLogLevel: { error: 0 }, + } as never); + + return { createEmbeddingContext, getEmbeddingFor, loadModel, resolveModelFile }; } -function mockResolvedProviderKey(apiKey = "provider-key") { - mockResolvedProviderKeyBase(authModule.resolveApiKeyForProvider, apiKey); -} +describe("local embedding provider", () => { + it("normalizes local embeddings and resolves the default local model", async () => { + const runtime = mockLocalEmbeddingRuntime(); -function mockMissingLocalEmbeddingDependency() { - vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockRejectedValue( - Object.assign(new Error("Cannot find package 'node-llama-cpp'"), { - code: "ERR_MODULE_NOT_FOUND", - }), - ); -} - -function createLocalProvider(options?: { fallback?: "none" | "openai" }) { - return createEmbeddingProvider({ - config: {} as never, - provider: "local", - model: "text-embedding-3-small", - fallback: options?.fallback ?? 
"none", - }); -} - -function expectAutoSelectedProvider( - result: Awaited>, - expectedId: "openai" | "gemini" | "mistral" | "bedrock", -) { - expect(result.requestedProvider).toBe("auto"); - const provider = requireProvider(result); - expect(provider.id).toBe(expectedId); - return provider; -} - -function createAutoProvider(model = "") { - return createEmbeddingProvider({ - config: {} as never, - provider: "auto", - model, - fallback: "none", - }); -} - -describe("embedding provider remote overrides", () => { - it("uses remote baseUrl/apiKey and merges headers", async () => { - const fetchMock = createFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - mockResolvedProviderKey("provider-key"); - - const cfg = { - models: { - providers: { - openai: { - baseUrl: "https://api.openai.com/v1", - headers: { - "X-Provider": "p", - "X-Shared": "provider", - }, - }, - }, - }, - }; - - const result = await createEmbeddingProvider({ - config: cfg as never, - provider: "openai", - remote: { - baseUrl: "https://example.com/v1", - apiKey: " remote-key ", - headers: { - "X-Shared": "remote", - "X-Remote": "r", - }, - }, - model: "text-embedding-3-small", - fallback: "openai", - }); - - const provider = requireProvider(result); - await provider.embedQuery("hello"); - - expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled(); - const url = fetchMock.mock.calls[0]?.[0]; - const init = fetchMock.mock.calls[0]?.[1]; - expect(url).toBe("https://example.com/v1/embeddings"); - const headers = (init?.headers ?? 
{}) as Record; - expect(headers.Authorization).toBe("Bearer remote-key"); - expect(headers["Content-Type"]).toBe("application/json"); - expect(headers["X-Provider"]).toBe("p"); - expect(headers["X-Shared"]).toBe("remote"); - expect(headers["X-Remote"]).toBe("r"); - }); - - it("falls back to resolved api key when remote apiKey is blank", async () => { - const fetchMock = createFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - mockResolvedProviderKey("provider-key"); - - const cfg = { - models: { - providers: { - openai: { - baseUrl: "https://api.openai.com/v1", - }, - }, - }, - }; - - const result = await createEmbeddingProvider({ - config: cfg as never, - provider: "openai", - remote: { - baseUrl: "https://example.com/v1", - apiKey: " ", - }, - model: "text-embedding-3-small", - fallback: "openai", - }); - - const provider = requireProvider(result); - await provider.embedQuery("hello"); - - expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1); - const init = fetchMock.mock.calls[0]?.[1]; - const headers = (init?.headers as Record) ?? 
{}; - expect(headers.Authorization).toBe("Bearer provider-key"); - }); - - it("builds Gemini embeddings requests with api key header", async () => { - const fetchMock = createGeminiFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - mockResolvedProviderKey("provider-key"); - - const cfg = { - models: { - providers: { - google: { - baseUrl: "https://generativelanguage.googleapis.com/v1beta", - }, - }, - }, - }; - - const result = await createEmbeddingProvider({ - config: cfg as never, - provider: "gemini", - remote: { - apiKey: "gemini-key", - }, - model: "text-embedding-004", - fallback: "openai", - }); - - const provider = requireProvider(result); - await provider.embedQuery("hello"); - - const { url, init } = readFirstFetchRequest(fetchMock); - expect(url).toBe( - "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent", - ); - const headers = (init?.headers ?? {}) as Record; - expect(headers["x-goog-api-key"]).toBe("gemini-key"); - expect(headers["Content-Type"]).toBe("application/json"); - }); - - it("fails fast when Gemini remote apiKey is an unresolved SecretRef", async () => { - vi.stubEnv("GEMINI_API_KEY", ""); - - await expect( - createEmbeddingProvider({ - config: {} as never, - provider: "gemini", - remote: { - apiKey: { source: "env", provider: "default", id: "GEMINI_API_KEY" }, - }, - model: "text-embedding-004", - fallback: "openai", - }), - ).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey:/i); - }); - - it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => { - vi.stubEnv("GEMINI_API_KEY", "env-gemini-key"); - - const result = await createEmbeddingProvider({ - config: {} as never, - provider: "gemini", - remote: { - apiKey: "GEMINI_API_KEY", // pragma: allowlist secret - }, - model: "text-embedding-004", - fallback: "openai", - }); - - const provider = requireProvider(result); - expect(provider.id).toBe("gemini"); - 
expect(result.gemini?.apiKeys).toEqual(["env-gemini-key"]); - }); - - it("builds Mistral embeddings requests with bearer auth", async () => { - const fetchMock = createFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - mockResolvedProviderKey("provider-key"); - - const cfg = { - models: { - providers: { - mistral: { - baseUrl: "https://api.mistral.ai/v1", - }, - }, - }, - }; - - const result = await createEmbeddingProvider({ - config: cfg as never, - provider: "mistral", - remote: { - apiKey: "mistral-key", // pragma: allowlist secret - }, - model: "mistral/mistral-embed", - fallback: "none", - }); - - const provider = requireProvider(result); - await provider.embedQuery("hello"); - - const { url, init } = readFirstFetchRequest(fetchMock); - expect(url).toBe("https://api.mistral.ai/v1/embeddings"); - const headers = (init?.headers ?? {}) as Record; - expect(headers.Authorization).toBe("Bearer mistral-key"); - const payload = JSON.parse((init?.body as string | undefined) ?? 
"{}") as { model?: string }; - expect(payload.model).toBe("mistral-embed"); - }); -}); - -describe("embedding provider auto selection", () => { - it("keeps explicit model when openai is selected", async () => { - const fetchMock = vi.fn(async (_input?: unknown, _init?: unknown) => ({ - ok: true, - status: 200, - json: async () => ({ data: [{ embedding: [1, 2, 3] }] }), - })); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockPublicPinnedHostname(); - vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => { - if (provider === "openai") { - return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" }; - } - throw new Error(`Unexpected provider ${provider}`); - }); - - const result = await createEmbeddingProvider({ - config: {} as never, - provider: "auto", - model: "text-embedding-3-small", - fallback: "none", - }); - - expect(result.requestedProvider).toBe("auto"); - const provider = requireProvider(result); - expect(provider.id).toBe("openai"); - await provider.embedQuery("hello"); - const url = fetchMock.mock.calls[0]?.[0]; - const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined; - expect(url).toBe("https://api.openai.com/v1/embeddings"); - const payload = JSON.parse(init?.body as string) as { model?: string }; - expect(payload.model).toBe("text-embedding-3-small"); - }); - - it("selects the first available remote provider in auto mode", async () => { - const cases: Array<{ - name: string; - expectedProvider: "openai" | "gemini" | "mistral"; - resolveApiKey: (provider: string) => ResolvedProviderAuth; - }> = [ - { - name: "openai first", - expectedProvider: "openai" as const, - resolveApiKey(provider: string): ResolvedProviderAuth { - if (provider === "openai") { - return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" }; - } - throw new Error(`No API key found for provider "${provider}".`); - }, - }, - { - name: "gemini fallback", - 
expectedProvider: "gemini" as const, - resolveApiKey(provider: string): ResolvedProviderAuth { - if (provider === "openai") { - throw new Error('No API key found for provider "openai".'); - } - if (provider === "google") { - return { - apiKey: "gemini-key", - source: "env: GEMINI_API_KEY", - mode: "api-key" as const, - }; - } - throw new Error(`Unexpected provider ${provider}`); - }, - }, - { - name: "mistral after earlier misses", - expectedProvider: "mistral" as const, - resolveApiKey(provider: string): ResolvedProviderAuth { - if (provider === "mistral") { - return { - apiKey: "mistral-key", - source: "env: MISTRAL_API_KEY", - mode: "api-key" as const, - }; - } - throw new Error(`No API key found for provider "${provider}".`); - }, - }, - ]; - - for (const testCase of cases) { - vi.mocked(authModule.resolveApiKeyForProvider).mockReset(); - vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => - testCase.resolveApiKey(provider), - ); - - const result = await createAutoProvider(); - expectAutoSelectedProvider(result, testCase.expectedProvider); - } - }); - - it("selects Bedrock in auto mode when the AWS credential chain resolves", async () => { - bedrockSendMock.mockResolvedValue({ - body: new TextEncoder().encode(JSON.stringify({ embedding: [1, 2, 3] })), - }); - resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" }); - vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => { - throw new Error(`No API key found for provider "${provider}".`); - }); - - const result = await createAutoProvider(); - const provider = expectAutoSelectedProvider(result, "bedrock"); - await provider.embedQuery("hello"); - - expect(bedrockSendMock).toHaveBeenCalledTimes(1); - }); - - it("rethrows non-auth Bedrock setup errors in auto mode", async () => { - resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" }); - vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async 
({ provider }) => { - throw new Error(`No API key found for provider "${provider}".`); - }); - - await expect( - createEmbeddingProvider({ - config: {} as never, - provider: "auto", - model: "", - fallback: "none", - outputDimensionality: 768, - }), - ).rejects.toThrow("Invalid dimensions 768"); - }); -}); - -describe("embedding provider local fallback", () => { - it("falls back to openai when node-llama-cpp is missing", async () => { - mockMissingLocalEmbeddingDependency(); - - const fetchMock = createFetchMock(); - installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - - mockResolvedProviderKey("provider-key"); - - const result = await createLocalProvider({ fallback: "openai" }); - - const provider = requireProvider(result); - expect(provider.id).toBe("openai"); - expect(result.fallbackFrom).toBe("local"); - expect(result.fallbackReason).toContain("node-llama-cpp"); - }); - - it("throws a helpful error when local is requested and fallback is none", async () => { - mockMissingLocalEmbeddingDependency(); - await expect(createLocalProvider()).rejects.toThrow(/optional dependency node-llama-cpp/i); - }); - - it("mentions every remote provider in local setup guidance", async () => { - mockMissingLocalEmbeddingDependency(); - await expect(createLocalProvider()).rejects.toThrow(/provider = "gemini"/i); - await expect(createLocalProvider()).rejects.toThrow(/provider = "mistral"/i); - }); -}); - -describe("local embedding normalization", () => { - async function createLocalProviderForTest() { - return createEmbeddingProvider({ + const provider = await createLocalEmbeddingProvider({ config: {} as never, provider: "local", model: "", fallback: "none", }); - } - function mockSingleLocalEmbeddingVector( - vector: number[], - resolveModelFile: (modelPath: string, modelDirectory?: string) => Promise = async () => - "/fake/model.gguf", - ): void { - vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({ - getLlama: async () => ({ - loadModel: 
vi.fn().mockResolvedValue({ - createEmbeddingContext: vi.fn().mockResolvedValue({ - getEmbeddingFor: vi.fn().mockResolvedValue({ - vector: new Float32Array(vector), - }), - }), - }), - }), - resolveModelFile, - LlamaLogLevel: { error: 0 }, - } as never); - } - - it("normalizes local embeddings to magnitude ~1.0", async () => { - const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9]; - const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf"); - - mockSingleLocalEmbeddingVector(unnormalizedVector, resolveModelFileMock); - - const result = await createLocalProviderForTest(); - - const provider = requireProvider(result); const embedding = await provider.embedQuery("test query"); + const magnitude = Math.sqrt(embedding.reduce((sum, value) => sum + value * value, 0)); - const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0)); - - expect(magnitude).toBeCloseTo(1.0, 5); - expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined); + expect(magnitude).toBeCloseTo(1, 5); + expect(runtime.resolveModelFile).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined); + expect(runtime.getEmbeddingFor).toHaveBeenCalledWith("test query"); }); - it("handles zero vector without division by zero", async () => { - const zeroVector = [0, 0, 0, 0]; + it("trims explicit local model paths and cache directories", async () => { + const runtime = mockLocalEmbeddingRuntime(new Float32Array([1, 0])); - mockSingleLocalEmbeddingVector(zeroVector); - - const result = await createLocalProviderForTest(); - - const provider = requireProvider(result); - const embedding = await provider.embedQuery("test"); - - expect(embedding).toEqual([0, 0, 0, 0]); - expect(embedding.every((value) => Number.isFinite(value))).toBe(true); - }); - - it("sanitizes non-finite values before normalization", async () => { - const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY]; - - 
mockSingleLocalEmbeddingVector(nonFiniteVector); - - const result = await createLocalProviderForTest(); - - const provider = requireProvider(result); - const embedding = await provider.embedQuery("test"); - - expect(embedding).toEqual([1, 0, 0, 0]); - expect(embedding.every((value) => Number.isFinite(value))).toBe(true); - }); - - it("normalizes batch embeddings to magnitude ~1.0", async () => { - const unnormalizedVectors = [ - [2.35, 3.45, 0.63, 4.3], - [10.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0], - ]; - - vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({ - getLlama: async () => ({ - loadModel: vi.fn().mockResolvedValue({ - createEmbeddingContext: vi.fn().mockResolvedValue({ - getEmbeddingFor: vi - .fn() - .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) }) - .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) }) - .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }), - }), - }), - }), - resolveModelFile: async () => "/fake/model.gguf", - LlamaLogLevel: { error: 0 }, - } as never); - - const result = await createLocalProviderForTest(); - - const provider = requireProvider(result); - const embeddings = await provider.embedBatch(["text1", "text2", "text3"]); - - for (const embedding of embeddings) { - const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0)); - expect(magnitude).toBeCloseTo(1.0, 5); - } - }); -}); - -describe("local embedding ensureContext concurrency", () => { - async function setupLocalProviderWithMockedInit(params?: { - initializationDelayMs?: number; - failFirstGetLlama?: boolean; - }) { - const getLlamaSpy = vi.fn(); - const loadModelSpy = vi.fn(); - const createContextSpy = vi.fn(); - let shouldFail = params?.failFirstGetLlama ?? 
false; - - vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({ - getLlama: async (...args: unknown[]) => { - getLlamaSpy(...args); - if (shouldFail) { - shouldFail = false; - throw new Error("transient init failure"); - } - if (params?.initializationDelayMs) { - await sleep(params.initializationDelayMs); - } - return { - loadModel: async (...modelArgs: unknown[]) => { - loadModelSpy(...modelArgs); - if (params?.initializationDelayMs) { - await sleep(params.initializationDelayMs); - } - return { - createEmbeddingContext: async () => { - createContextSpy(); - return { - getEmbeddingFor: vi.fn().mockResolvedValue({ - vector: new Float32Array([1, 0, 0, 0]), - }), - }; - }, - }; - }, - }; - }, - resolveModelFile: async () => "/fake/model.gguf", - LlamaLogLevel: { error: 0 }, - } as never); - - const result = await createEmbeddingProvider({ + const provider = await createLocalEmbeddingProvider({ config: {} as never, provider: "local", model: "", fallback: "none", + local: { + modelPath: " /models/embed.gguf ", + modelCacheDir: " /cache/models ", + }, }); - return { - provider: requireProvider(result), - getLlamaSpy, - loadModelSpy, - createContextSpy, - }; - } + await provider.embedBatch(["a", "b"]); - it("loads the model only once when embedBatch is called concurrently", async () => { - const { provider, getLlamaSpy, loadModelSpy, createContextSpy } = - await setupLocalProviderWithMockedInit({ - initializationDelayMs: 5, - }); - - const results = await Promise.all([ - provider.embedBatch(["text1"]), - provider.embedBatch(["text2"]), - provider.embedBatch(["text3"]), - provider.embedBatch(["text4"]), - ]); - - expect(results).toHaveLength(4); - for (const embeddings of results) { - expect(embeddings).toHaveLength(1); - expect(embeddings[0]).toHaveLength(4); - } - - expect(getLlamaSpy).toHaveBeenCalledTimes(1); - expect(loadModelSpy).toHaveBeenCalledTimes(1); - expect(createContextSpy).toHaveBeenCalledTimes(1); - }); - - it("retries initialization after a 
transient ensureContext failure", async () => { - const { provider, getLlamaSpy, loadModelSpy, createContextSpy } = - await setupLocalProviderWithMockedInit({ - failFirstGetLlama: true, - }); - - await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure"); - - const recovered = await provider.embedBatch(["second"]); - expect(recovered).toHaveLength(1); - expect(recovered[0]).toHaveLength(4); - - expect(getLlamaSpy).toHaveBeenCalledTimes(2); - expect(loadModelSpy).toHaveBeenCalledTimes(1); - expect(createContextSpy).toHaveBeenCalledTimes(1); - }); - - it("shares initialization when embedQuery and embedBatch start concurrently", async () => { - const { provider, getLlamaSpy, loadModelSpy, createContextSpy } = - await setupLocalProviderWithMockedInit({ - initializationDelayMs: 5, - }); - - const [queryA, batch, queryB] = await Promise.all([ - provider.embedQuery("query-a"), - provider.embedBatch(["batch-a", "batch-b"]), - provider.embedQuery("query-b"), - ]); - - expect(queryA).toHaveLength(4); - expect(batch).toHaveLength(2); - expect(queryB).toHaveLength(4); - expect(batch[0]).toHaveLength(4); - expect(batch[1]).toHaveLength(4); - - expect(getLlamaSpy).toHaveBeenCalledTimes(1); - expect(loadModelSpy).toHaveBeenCalledTimes(1); - expect(createContextSpy).toHaveBeenCalledTimes(1); - }); -}); - -describe("FTS-only fallback when no provider available", () => { - it("returns null provider when all requested auth paths fail", async () => { - vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue( - new Error("No API key found for provider"), - ); - - for (const testCase of [ - { - name: "auto mode", - options: { - config: {} as never, - provider: "auto" as const, - model: "", - fallback: "none" as const, - }, - requestedProvider: "auto", - fallbackFrom: undefined, - reasonIncludes: "No API key", - }, - { - name: "explicit provider only", - options: { - config: {} as never, - provider: "openai" as const, - model: 
"text-embedding-3-small", - fallback: "none" as const, - }, - requestedProvider: "openai", - fallbackFrom: undefined, - reasonIncludes: "No API key", - }, - { - name: "primary and fallback", - options: { - config: {} as never, - provider: "openai" as const, - model: "text-embedding-3-small", - fallback: "gemini" as const, - }, - requestedProvider: "openai", - fallbackFrom: "openai", - reasonIncludes: "Fallback to gemini failed", - }, - ]) { - const result = await createEmbeddingProvider(testCase.options); - expect(result.provider, testCase.name).toBeNull(); - expect(result.requestedProvider, testCase.name).toBe(testCase.requestedProvider); - expect(result.fallbackFrom, testCase.name).toBe(testCase.fallbackFrom); - expect(result.providerUnavailableReason, testCase.name).toContain(testCase.reasonIncludes); - } + expect(provider.model).toBe("/models/embed.gguf"); + expect(runtime.resolveModelFile).toHaveBeenCalledWith("/models/embed.gguf", "/cache/models"); + expect(runtime.getEmbeddingFor).toHaveBeenCalledTimes(2); }); }); diff --git a/src/memory-host-sdk/host/embeddings.ts b/src/memory-host-sdk/host/embeddings.ts index a8ca24ed7a7..ec672e6b5a9 100644 --- a/src/memory-host-sdk/host/embeddings.ts +++ b/src/memory-host-sdk/host/embeddings.ts @@ -1,43 +1,9 @@ -import fsSync from "node:fs"; import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp"; -import { formatErrorMessage } from "../../infra/errors.js"; import { normalizeOptionalString } from "../../shared/string-coerce.js"; -import { resolveUserPath } from "../../utils.js"; import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; -import { - createBedrockEmbeddingProvider, - hasAwsCredentials, - type BedrockEmbeddingClient, -} from "./embeddings-bedrock.js"; -import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./embeddings-gemini.js"; -import { - createLmstudioEmbeddingProvider, - type LmstudioEmbeddingClient, -} from "./embeddings-lmstudio.js"; -import { - 
createMistralEmbeddingProvider, - type MistralEmbeddingClient, -} from "./embeddings-mistral.js"; -import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js"; -import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js"; -import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js"; -import type { - EmbeddingProvider, - EmbeddingProviderFallback, - EmbeddingProviderId, - EmbeddingProviderOptions, - EmbeddingProviderRequest, - GeminiTaskType, -} from "./embeddings.types.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js"; import { importNodeLlamaCpp } from "./node-llama.js"; -export type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; -export type { LmstudioEmbeddingClient } from "./embeddings-lmstudio.js"; -export type { MistralEmbeddingClient } from "./embeddings-mistral.js"; -export type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; -export type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; -export type { OllamaEmbeddingClient } from "./embeddings-ollama.js"; -export type { BedrockEmbeddingClient } from "./embeddings-bedrock.js"; export type { EmbeddingProvider, EmbeddingProviderFallback, @@ -47,51 +13,9 @@ export type { GeminiTaskType, } from "./embeddings.types.js"; -// Remote providers considered for auto-selection when provider === "auto". -// LM Studio and Ollama are intentionally excluded here so that "auto" mode does not -// implicitly assume either instance is available. -// Bedrock is handled separately when AWS credentials are detected. 
-const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const; - -export type EmbeddingProviderResult = { - provider: EmbeddingProvider | null; - requestedProvider: EmbeddingProviderRequest; - fallbackFrom?: EmbeddingProviderId; - fallbackReason?: string; - providerUnavailableReason?: string; - openAi?: OpenAiEmbeddingClient; - gemini?: GeminiEmbeddingClient; - voyage?: VoyageEmbeddingClient; - mistral?: MistralEmbeddingClient; - ollama?: OllamaEmbeddingClient; - bedrock?: BedrockEmbeddingClient; - lmstudio?: LmstudioEmbeddingClient; -}; - export const DEFAULT_LOCAL_MODEL = "hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf"; -function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean { - const modelPath = options.local?.modelPath?.trim(); - if (!modelPath) { - return false; - } - if (/^(hf:|https?:)/i.test(modelPath)) { - return false; - } - const resolved = resolveUserPath(modelPath); - try { - return fsSync.statSync(resolved).isFile(); - } catch { - return false; - } -} - -function isMissingApiKeyError(err: unknown): boolean { - const message = formatErrorMessage(err); - return message.includes("No API key found for provider"); -} - export async function createLocalEmbeddingProvider( options: EmbeddingProviderOptions, ): Promise { @@ -154,186 +78,3 @@ export async function createLocalEmbeddingProvider( }, }; } - -export async function createEmbeddingProvider( - options: EmbeddingProviderOptions, -): Promise { - const requestedProvider = options.provider; - const fallback = options.fallback; - - const createProvider = async (id: EmbeddingProviderId) => { - if (id === "local") { - const provider = await createLocalEmbeddingProvider(options); - return { provider }; - } - if (id === "lmstudio") { - const { provider, client } = await createLmstudioEmbeddingProvider(options); - return { provider, lmstudio: client }; - } - if (id === "ollama") { - const { provider, client } = await 
createOllamaEmbeddingProvider(options); - return { provider, ollama: client }; - } - if (id === "gemini") { - const { provider, client } = await createGeminiEmbeddingProvider(options); - return { provider, gemini: client }; - } - if (id === "voyage") { - const { provider, client } = await createVoyageEmbeddingProvider(options); - return { provider, voyage: client }; - } - if (id === "mistral") { - const { provider, client } = await createMistralEmbeddingProvider(options); - return { provider, mistral: client }; - } - if (id === "bedrock") { - const { provider, client } = await createBedrockEmbeddingProvider(options); - return { provider, bedrock: client }; - } - const { provider, client } = await createOpenAiEmbeddingProvider(options); - return { provider, openAi: client }; - }; - - const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) => - provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err); - - if (requestedProvider === "auto") { - const missingKeyErrors: string[] = []; - let localError: string | null = null; - - if (canAutoSelectLocal(options)) { - try { - const local = await createProvider("local"); - return { ...local, requestedProvider }; - } catch (err) { - localError = formatLocalSetupError(err); - } - } - - for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) { - try { - const result = await createProvider(provider); - return { ...result, requestedProvider }; - } catch (err) { - const message = formatPrimaryError(err, provider); - if (isMissingApiKeyError(err)) { - missingKeyErrors.push(message); - continue; - } - // Non-auth errors (e.g., network) are still fatal - const wrapped = new Error(message) as Error & { cause?: unknown }; - wrapped.cause = err; - throw wrapped; - } - } - - // Try bedrock if AWS credentials are available - if (await hasAwsCredentials()) { - try { - const result = await createProvider("bedrock"); - return { ...result, requestedProvider }; - } catch (err) { - const message = 
formatPrimaryError(err, "bedrock"); - if (isMissingApiKeyError(err)) { - missingKeyErrors.push(message); - } else { - const wrapped = new Error(message) as Error & { cause?: unknown }; - wrapped.cause = err; - throw wrapped; - } - } - } - - // All providers failed due to missing API keys - return null provider for FTS-only mode - const details = [...missingKeyErrors, localError].filter(Boolean) as string[]; - const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available."; - return { - provider: null, - requestedProvider, - providerUnavailableReason: reason, - }; - } - - try { - const primary = await createProvider(requestedProvider); - return { ...primary, requestedProvider }; - } catch (primaryErr) { - const reason = formatPrimaryError(primaryErr, requestedProvider); - if (fallback && fallback !== "none" && fallback !== requestedProvider) { - try { - const fallbackResult = await createProvider(fallback); - return { - ...fallbackResult, - requestedProvider, - fallbackFrom: requestedProvider, - fallbackReason: reason, - }; - } catch (fallbackErr) { - // Both primary and fallback failed - check if it's auth-related - const fallbackReason = formatErrorMessage(fallbackErr); - const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`; - if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) { - // Both failed due to missing API keys - return null for FTS-only mode - return { - provider: null, - requestedProvider, - fallbackFrom: requestedProvider, - fallbackReason: reason, - providerUnavailableReason: combinedReason, - }; - } - // Non-auth errors are still fatal - const wrapped = new Error(combinedReason) as Error & { cause?: unknown }; - wrapped.cause = fallbackErr; - throw wrapped; - } - } - // No fallback configured - check if we should degrade to FTS-only - if (isMissingApiKeyError(primaryErr)) { - return { - provider: null, - requestedProvider, - providerUnavailableReason: reason, - 
}; - } - const wrapped = new Error(reason) as Error & { cause?: unknown }; - wrapped.cause = primaryErr; - throw wrapped; - } -} - -function isNodeLlamaCppMissing(err: unknown): boolean { - if (!(err instanceof Error)) { - return false; - } - const code = (err as Error & { code?: unknown }).code; - if (code === "ERR_MODULE_NOT_FOUND") { - return err.message.includes("node-llama-cpp"); - } - return false; -} - -function formatLocalSetupError(err: unknown): string { - const detail = formatErrorMessage(err); - const missing = isNodeLlamaCppMissing(err); - return [ - "Local embeddings unavailable.", - missing - ? "Reason: optional dependency node-llama-cpp is missing (or failed to install)." - : detail - ? `Reason: ${detail}` - : undefined, - missing && detail ? `Detail: ${detail}` : null, - "To enable local embeddings:", - "1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)", - missing - ? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest" - : null, - "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp", - ...REMOTE_EMBEDDING_PROVIDER_IDS.map( - (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`, - ), - ] - .filter(Boolean) - .join("\n"); -} diff --git a/src/memory-host-sdk/host/embeddings.types.ts b/src/memory-host-sdk/host/embeddings.types.ts index 297db0ecdc3..d83c5305b0a 100644 --- a/src/memory-host-sdk/host/embeddings.types.ts +++ b/src/memory-host-sdk/host/embeddings.types.ts @@ -11,18 +11,9 @@ export type EmbeddingProvider = { embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise; }; -export type EmbeddingProviderId = - | "openai" - | "local" - | "gemini" - | "voyage" - | "mistral" - | "lmstudio" - | "ollama" - | "bedrock"; - -export type EmbeddingProviderRequest = EmbeddingProviderId | "auto"; -export type EmbeddingProviderFallback = EmbeddingProviderId | "none"; +export type 
EmbeddingProviderId = string; +export type EmbeddingProviderRequest = string; +export type EmbeddingProviderFallback = string; export type GeminiTaskType = | "RETRIEVAL_QUERY" @@ -36,14 +27,14 @@ export type GeminiTaskType = export type EmbeddingProviderOptions = { config: OpenClawConfig; agentDir?: string; - provider: EmbeddingProviderRequest; + provider?: EmbeddingProviderRequest; remote?: { baseUrl?: string; apiKey?: SecretInput; headers?: Record; }; model: string; - fallback: EmbeddingProviderFallback; + fallback?: EmbeddingProviderFallback; local?: { modelPath?: string; modelCacheDir?: string; diff --git a/src/memory-host-sdk/host/multimodal.ts b/src/memory-host-sdk/host/multimodal.ts index baf97c39666..3ad16b9d414 100644 --- a/src/memory-host-sdk/host/multimodal.ts +++ b/src/memory-host-sdk/host/multimodal.ts @@ -106,21 +106,3 @@ export function classifyMemoryMultimodalPath( } return null; } - -export function normalizeGeminiEmbeddingModelForMemory(model: string): string { - const trimmed = model.trim(); - if (!trimmed) { - return ""; - } - return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, ""); -} - -export function supportsMemoryMultimodalEmbeddings(params: { - provider: string; - model: string; -}): boolean { - if (params.provider !== "gemini") { - return false; - } - return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview"; -} diff --git a/src/memory-host-sdk/host/types.ts b/src/memory-host-sdk/host/types.ts index 992cda6579c..92ec371b240 100644 --- a/src/memory-host-sdk/host/types.ts +++ b/src/memory-host-sdk/host/types.ts @@ -85,11 +85,7 @@ export interface MemorySearchManager { onDebug?: (debug: MemorySearchRuntimeDebug) => void; }, ): Promise; - readFile(params: { - relPath: string; - from?: number; - lines?: number; - }): Promise; + readFile(params: { relPath: string; from?: number; lines?: number }): Promise; status(): MemoryProviderStatus; sync?(params?: { reason?: string; diff --git 
a/src/memory-host-sdk/multimodal.ts b/src/memory-host-sdk/multimodal.ts index 5c62de35490..eb11867ac3a 100644 --- a/src/memory-host-sdk/multimodal.ts +++ b/src/memory-host-sdk/multimodal.ts @@ -1,6 +1,5 @@ export { isMemoryMultimodalEnabled, normalizeMemoryMultimodalSettings, - supportsMemoryMultimodalEmbeddings, type MemoryMultimodalSettings, } from "./host/multimodal.js"; diff --git a/src/plugin-sdk/provider-auth-runtime.ts b/src/plugin-sdk/provider-auth-runtime.ts index 40360e23c15..605116d512d 100644 --- a/src/plugin-sdk/provider-auth-runtime.ts +++ b/src/plugin-sdk/provider-auth-runtime.ts @@ -5,6 +5,10 @@ import path from "node:path"; import { fileURLToPath, pathToFileURL } from "node:url"; export { resolveEnvApiKey } from "../agents/model-auth-env.js"; +export { + collectProviderApiKeysForExecution, + executeWithApiKeyRotation, +} from "../agents/api-key-rotation.js"; export { NON_ENV_SECRETREF_MARKER } from "../agents/model-auth-markers.js"; export { requireApiKey, diff --git a/src/plugins/capability-provider-runtime.ts b/src/plugins/capability-provider-runtime.ts index 51ab6d646fd..c3a350b145d 100644 --- a/src/plugins/capability-provider-runtime.ts +++ b/src/plugins/capability-provider-runtime.ts @@ -84,13 +84,29 @@ export function resolvePluginCapabilityProviders[] { const activeRegistry = resolveRuntimePluginRegistry(); const activeProviders = activeRegistry?.[params.key] ?? []; - if (activeProviders.length > 0) { + if (activeProviders.length > 0 && params.key !== "memoryEmbeddingProviders") { return activeProviders.map((entry) => entry.provider) as CapabilityProviderForKey[]; } const compatConfig = resolveCapabilityProviderConfig({ key: params.key, cfg: params.cfg }); const loadOptions = compatConfig === undefined ? undefined : { config: compatConfig }; const registry = resolveRuntimePluginRegistry(loadOptions); - return (registry?.[params.key] ?? 
[]).map( - (entry) => entry.provider, - ) as CapabilityProviderForKey[]; + if (params.key !== "memoryEmbeddingProviders") { + return (registry?.[params.key] ?? []).map( + (entry) => entry.provider, + ) as CapabilityProviderForKey[]; + } + const merged = new Map>(); + for (const entry of activeProviders) { + const provider = entry.provider as CapabilityProviderForKey & { id?: string }; + if (provider.id) { + merged.set(provider.id, provider); + } + } + for (const entry of registry?.[params.key] ?? []) { + const provider = entry.provider as CapabilityProviderForKey & { id?: string }; + if (provider.id && !merged.has(provider.id)) { + merged.set(provider.id, provider); + } + } + return [...merged.values()]; } diff --git a/src/plugins/contracts/package-manifest.contract.test.ts b/src/plugins/contracts/package-manifest.contract.test.ts index 50ee6916996..3f1d4de2bfb 100644 --- a/src/plugins/contracts/package-manifest.contract.test.ts +++ b/src/plugins/contracts/package-manifest.contract.test.ts @@ -28,7 +28,14 @@ const packageManifestContractTests: PackageManifestContractParams[] = [ }, { pluginId: "irc", minHostVersionBaseline: "2026.3.22" }, { pluginId: "line", minHostVersionBaseline: "2026.3.22" }, - { pluginId: "amazon-bedrock", mirroredRootRuntimeDeps: ["@aws-sdk/client-bedrock"] }, + { + pluginId: "amazon-bedrock", + mirroredRootRuntimeDeps: [ + "@aws-sdk/client-bedrock", + "@aws-sdk/client-bedrock-runtime", + "@aws-sdk/credential-provider-node", + ], + }, { pluginId: "amazon-bedrock-mantle", mirroredRootRuntimeDeps: ["@aws/bedrock-token-generator"], diff --git a/src/plugins/memory-embedding-provider-runtime.test.ts b/src/plugins/memory-embedding-provider-runtime.test.ts index d62bc004edf..ba30496f31b 100644 --- a/src/plugins/memory-embedding-provider-runtime.test.ts +++ b/src/plugins/memory-embedding-provider-runtime.test.ts @@ -36,7 +36,7 @@ afterEach(() => { }); describe("memory embedding provider runtime resolution", () => { - it("prefers registered adapters 
over capability fallback adapters", () => { + it("merges registered and declared capability fallback adapters", () => { registerMemoryEmbeddingProvider({ id: "registered", create: async () => ({ provider: null }), @@ -45,9 +45,10 @@ describe("memory embedding provider runtime resolution", () => { expect(runtimeModule.listMemoryEmbeddingProviders().map((adapter) => adapter.id)).toEqual([ "registered", + "capability", ]); expect(runtimeModule.getMemoryEmbeddingProvider("registered")?.id).toBe("registered"); - expect(mocks.resolvePluginCapabilityProviders).not.toHaveBeenCalled(); + expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(1); }); it("falls back to declared capability adapters when the registry is cold", () => { @@ -60,14 +61,22 @@ describe("memory embedding provider runtime resolution", () => { expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(2); }); - it("does not consult capability fallback once runtime adapters are registered", () => { - registerMemoryEmbeddingProvider({ + it("prefers registered adapters over declared capability fallback adapters with the same id", () => { + const registered = { id: "openai", create: async () => ({ provider: null }), + } satisfies MemoryEmbeddingProviderAdapter; + registerMemoryEmbeddingProvider({ + ...registered, }); - mocks.resolvePluginCapabilityProviders.mockReturnValue([createCapabilityAdapter("ollama")]); + mocks.resolvePluginCapabilityProviders.mockReturnValue([createCapabilityAdapter("openai")]); - expect(runtimeModule.getMemoryEmbeddingProvider("ollama")).toBeUndefined(); - expect(mocks.resolvePluginCapabilityProviders).not.toHaveBeenCalled(); + expect(runtimeModule.getMemoryEmbeddingProvider("openai")).toEqual( + expect.objectContaining({ id: "openai" }), + ); + expect(runtimeModule.listMemoryEmbeddingProviders().map((adapter) => adapter.id)).toEqual([ + "openai", + ]); + expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(1); }); }); diff --git 
a/src/plugins/memory-embedding-provider-runtime.ts b/src/plugins/memory-embedding-provider-runtime.ts index 4e4d6316872..7a5e3ff64f0 100644 --- a/src/plugins/memory-embedding-provider-runtime.ts +++ b/src/plugins/memory-embedding-provider-runtime.ts @@ -15,13 +15,16 @@ export function listMemoryEmbeddingProviders( cfg?: OpenClawConfig, ): MemoryEmbeddingProviderAdapter[] { const registered = listRegisteredMemoryEmbeddingProviderAdapters(); - if (registered.length > 0) { - return registered; - } - return resolvePluginCapabilityProviders({ + const merged = new Map(registered.map((adapter) => [adapter.id, adapter])); + for (const adapter of resolvePluginCapabilityProviders({ key: "memoryEmbeddingProviders", cfg, - }); + })) { + if (!merged.has(adapter.id)) { + merged.set(adapter.id, adapter); + } + } + return [...merged.values()]; } export function getMemoryEmbeddingProvider( @@ -32,8 +35,5 @@ export function getMemoryEmbeddingProvider( if (registered) { return registered.adapter; } - if (listRegisteredMemoryEmbeddingProviders().length > 0) { - return undefined; - } return listMemoryEmbeddingProviders(cfg).find((adapter) => adapter.id === id); } diff --git a/src/plugins/memory-embedding-providers.ts b/src/plugins/memory-embedding-providers.ts index 717d44b508f..019b4e18804 100644 --- a/src/plugins/memory-embedding-providers.ts +++ b/src/plugins/memory-embedding-providers.ts @@ -35,6 +35,8 @@ export type MemoryEmbeddingProvider = { export type MemoryEmbeddingProviderCreateOptions = { config: OpenClawConfig; agentDir?: string; + provider?: string; + fallback?: string; remote?: { baseUrl?: string; apiKey?: SecretInput; @@ -46,6 +48,14 @@ export type MemoryEmbeddingProviderCreateOptions = { modelCacheDir?: string; }; outputDimensionality?: number; + taskType?: + | "RETRIEVAL_QUERY" + | "RETRIEVAL_DOCUMENT" + | "SEMANTIC_SIMILARITY" + | "CLASSIFICATION" + | "CLUSTERING" + | "QUESTION_ANSWERING" + | "FACT_VERIFICATION"; }; export type MemoryEmbeddingProviderCreateResult = { 
@@ -57,6 +67,7 @@ export type MemoryEmbeddingProviderAdapter = { id: string; defaultModel?: string; transport?: "local" | "remote"; + authProviderId?: string; autoSelectPriority?: number; allowExplicitWhenConfiguredAuto?: boolean; supportsMultimodalEmbeddings?: (params: { model: string }) => boolean;