refactor: move memory embeddings into provider plugins
@@ -318,7 +318,7 @@ Current bundled provider examples:
 | `plugin-sdk/memory-core` | Bundled memory-core helpers | Memory manager/config/file/CLI helper surface |
 | `plugin-sdk/memory-core-engine-runtime` | Memory engine runtime facade | Memory index/search runtime facade |
 | `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine | Memory host foundation engine exports |
-| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory host embedding engine exports |
+| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine | Memory embedding contracts, registry access, local provider, and generic batch/remote helpers; concrete remote providers live in their owning plugins |
 | `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine | Memory host QMD engine exports |
 | `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine | Memory host storage engine exports |
 | `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers | Memory host multimodal helpers |
@@ -264,7 +264,7 @@ explicitly promotes one as public.
 | `plugin-sdk/memory-core` | Bundled memory-core helper surface for manager/config/file/CLI helpers |
 | `plugin-sdk/memory-core-engine-runtime` | Memory index/search runtime facade |
 | `plugin-sdk/memory-core-host-engine-foundation` | Memory host foundation engine exports |
-| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding engine exports |
+| `plugin-sdk/memory-core-host-engine-embeddings` | Memory host embedding contracts, registry access, local provider, and generic batch/remote helpers |
 | `plugin-sdk/memory-core-host-engine-qmd` | Memory host QMD engine exports |
 | `plugin-sdk/memory-core-host-engine-storage` | Memory host storage engine exports |
 | `plugin-sdk/memory-core-host-multimodal` | Memory host multimodal helpers |
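Consumed from a plugin, the promoted subpath looks like this (a minimal sketch, not part of this commit; it uses only contract names that the hunks below re-export):

import {
  sanitizeAndNormalizeEmbedding,
  type MemoryEmbeddingProvider,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";

// A toy provider against the public contract; the field names (id, model,
// embedQuery, embedBatch) match the provider literals later in this diff.
const echoProvider: MemoryEmbeddingProvider = {
  id: "echo",
  model: "echo-1",
  embedQuery: async (text) => sanitizeAndNormalizeEmbedding([text.length, 1]),
  embedBatch: async (texts) => texts.map((t) => sanitizeAndNormalizeEmbedding([t.length, 1])),
};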
@@ -1,7 +1,10 @@
-import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js";
-import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
-import { debugEmbeddingsLog } from "./embeddings-debug.js";
-import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
+import {
+  debugEmbeddingsLog,
+  sanitizeAndNormalizeEmbedding,
+  type MemoryEmbeddingProvider,
+  type MemoryEmbeddingProviderCreateOptions,
+} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
+import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
 
 // ---------------------------------------------------------------------------
 // Types & constants
@@ -254,8 +257,8 @@ function parseCohereBatch(family: Family, raw: string): number[][] {
 // ---------------------------------------------------------------------------
 
 export async function createBedrockEmbeddingProvider(
-  options: EmbeddingProviderOptions,
-): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> {
+  options: MemoryEmbeddingProviderCreateOptions,
+): Promise<{ provider: MemoryEmbeddingProvider; client: BedrockEmbeddingClient }> {
   const client = resolveBedrockEmbeddingClient(options);
   const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk();
   const sdk = new BedrockRuntimeClient({ region: client.region });
@@ -333,7 +336,7 @@ export async function createBedrockEmbeddingProvider(
 // ---------------------------------------------------------------------------
 
 export function resolveBedrockEmbeddingClient(
-  options: EmbeddingProviderOptions,
+  options: MemoryEmbeddingProviderCreateOptions,
 ): BedrockEmbeddingClient {
   const model = normalizeBedrockEmbeddingModel(options.model);
   const spec = resolveSpec(model);
extensions/amazon-bedrock/memory-embedding-adapter.ts (Normal file, 37 lines)
@@ -0,0 +1,37 @@
+import {
+  isMissingEmbeddingApiKeyError,
+  type MemoryEmbeddingProviderAdapter,
+} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
+import {
+  createBedrockEmbeddingProvider,
+  DEFAULT_BEDROCK_EMBEDDING_MODEL,
+} from "./embedding-provider.js";
+
+export const bedrockMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
+  id: "bedrock",
+  defaultModel: DEFAULT_BEDROCK_EMBEDDING_MODEL,
+  transport: "remote",
+  authProviderId: "amazon-bedrock",
+  autoSelectPriority: 60,
+  allowExplicitWhenConfiguredAuto: true,
+  shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
+  create: async (options) => {
+    const { provider, client } = await createBedrockEmbeddingProvider({
+      ...options,
+      provider: "bedrock",
+      fallback: "none",
+    });
+    return {
+      provider,
+      runtime: {
+        id: "bedrock",
+        cacheKeyData: {
+          provider: "bedrock",
+          region: client.region,
+          model: client.model,
+          dimensions: client.dimensions,
+        },
+      },
+    };
+  },
+};
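The adapter above is pure data plus a create() factory; the host registry is what calls it. A sketch of that call sequence, hedged because the full MemoryEmbeddingProviderCreateOptions shape is not shown in this diff:

import { bedrockMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";

async function sketchEmbed(texts: string[], createOptions: unknown) {
  // createOptions stands in for MemoryEmbeddingProviderCreateOptions; only
  // the fields the adapter spreads through (...options) matter here.
  const { provider, runtime } = await bedrockMemoryEmbeddingProviderAdapter.create(
    createOptions as never,
  );
  // runtime.cacheKeyData (region/model/dimensions) namespaces cached vectors.
  return provider.embedBatch(texts);
}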
@@ -2,6 +2,9 @@
   "id": "amazon-bedrock",
   "enabledByDefault": true,
   "providers": ["amazon-bedrock"],
+  "contracts": {
+    "memoryEmbeddingProviders": ["bedrock"]
+  },
   "configSchema": {
     "type": "object",
     "additionalProperties": false,
@@ -5,7 +5,9 @@
   "description": "OpenClaw Amazon Bedrock provider plugin",
   "type": "module",
   "dependencies": {
-    "@aws-sdk/client-bedrock": "3.1028.0"
+    "@aws-sdk/client-bedrock": "3.1028.0",
+    "@aws-sdk/client-bedrock-runtime": "3.1028.0",
+    "@aws-sdk/credential-provider-node": "3.972.30"
   },
   "devDependencies": {
     "@openclaw/plugin-sdk": "workspace:*"
@@ -14,6 +14,7 @@ import {
   resolveBedrockConfigApiKey,
   resolveImplicitBedrockProvider,
 } from "./api.js";
+import { bedrockMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
 
 type GuardrailConfig = {
   guardrailIdentifier: string;
@@ -78,6 +79,8 @@ export function registerAmazonBedrockPlugin(api: OpenClawPluginApi): void {
   const pluginConfig = (api.pluginConfig ?? {}) as AmazonBedrockPluginConfig;
   const guardrail = pluginConfig.guardrail;
 
+  api.registerMemoryEmbeddingProvider(bedrockMemoryEmbeddingProviderAdapter);
+
   const baseWrapStreamFn = ({ modelId, streamFn }: { modelId: string; streamFn?: StreamFn }) =>
     isAnthropicBedrockModel(modelId) ? streamFn : createBedrockNoCacheWrapper(streamFn);
 
@@ -4,7 +4,6 @@ const resolveFirstGithubTokenMock = vi.hoisted(() => vi.fn());
 const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn());
 const resolveConfiguredSecretInputStringMock = vi.hoisted(() => vi.fn());
 const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn());
-const createGitHubCopilotEmbeddingProviderMock = vi.hoisted(() => vi.fn());
 
 vi.mock("./auth.js", () => ({
   resolveFirstGithubToken: resolveFirstGithubTokenMock,
@@ -19,10 +18,6 @@ vi.mock("openclaw/plugin-sdk/github-copilot-token", () => ({
   resolveCopilotApiToken: resolveCopilotApiTokenMock,
 }));
 
-vi.mock("openclaw/plugin-sdk/memory-core-host-engine-embeddings", () => ({
-  createGitHubCopilotEmbeddingProvider: createGitHubCopilotEmbeddingProviderMock,
-}));
-
 vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({
   fetchWithSsrFGuard: fetchWithSsrFGuardMock,
 }));
@@ -73,15 +68,6 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
       source: "test",
       baseUrl: TEST_BASE_URL,
     });
-    createGitHubCopilotEmbeddingProviderMock.mockImplementation(async (client) => ({
-      provider: {
-        id: "github-copilot",
-        model: client.model,
-        embedQuery: async () => [0.1, 0.2, 0.3],
-        embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2, 0.3]),
-      },
-      client,
-    }));
   });
 
   afterEach(() => {
@@ -89,7 +75,6 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
     resolveConfiguredSecretInputStringMock.mockReset();
     resolveFirstGithubTokenMock.mockReset();
     resolveCopilotApiTokenMock.mockReset();
-    createGitHubCopilotEmbeddingProviderMock.mockReset();
     fetchWithSsrFGuardMock.mockReset();
   });
 
@@ -113,12 +98,8 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
     const result = await githubCopilotMemoryEmbeddingProviderAdapter.create(defaultCreateOptions());
 
     expect(result.provider?.model).toBe("text-embedding-3-small");
-    expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith(
-      expect.objectContaining({
-        baseUrl: TEST_BASE_URL,
-        githubToken: "gh_test_token_123",
-        model: "text-embedding-3-small",
-      }),
+    expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith(
+      expect.objectContaining({ githubToken: "gh_test_token_123" }),
     );
   });
 
@@ -217,14 +198,12 @@ describe("githubCopilotMemoryEmbeddingProviderAdapter", () => {
     } as never);
 
     expect(resolveFirstGithubTokenMock).toHaveBeenCalled();
-    expect(createGitHubCopilotEmbeddingProviderMock).toHaveBeenCalledWith({
-      baseUrl: "https://proxy.example/v1",
-      env: process.env,
-      fetchImpl: fetch,
-      githubToken: "gh_remote_token",
-      headers: { "X-Proxy-Token": "proxy" },
-      model: "text-embedding-3-small",
-    });
+    expect(resolveCopilotApiTokenMock).toHaveBeenCalledWith(
+      expect.objectContaining({
+        env: process.env,
+        githubToken: "gh_remote_token",
+      }),
+    );
 
     const discoveryCall = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as {
       init: { headers: Record<string, string> };
@@ -4,7 +4,10 @@ import {
   resolveCopilotApiToken,
 } from "openclaw/plugin-sdk/github-copilot-token";
 import {
-  createGitHubCopilotEmbeddingProvider,
+  buildRemoteBaseUrlPolicy,
+  sanitizeAndNormalizeEmbedding,
+  withRemoteHttpResponse,
+  type MemoryEmbeddingProvider,
   type MemoryEmbeddingProviderAdapter,
 } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
 import { fetchWithSsrFGuard, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
@@ -44,6 +47,15 @@ type CopilotModelEntry = {
   supported_endpoints?: unknown;
 };
 
+type GitHubCopilotEmbeddingClient = {
+  githubToken: string;
+  model: string;
+  baseUrl?: string;
+  headers?: Record<string, string>;
+  env?: NodeJS.ProcessEnv;
+  fetchImpl?: typeof fetch;
+};
+
 function isCopilotSetupError(err: unknown): boolean {
   if (!(err instanceof Error)) {
     return false;
@@ -147,9 +159,126 @@ function pickBestModel(available: string[], userModel?: string): string {
   throw new Error("No embedding models available from GitHub Copilot");
 }
 
+function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] {
+  if (!payload || typeof payload !== "object") {
+    throw new Error("GitHub Copilot embeddings response missing data[]");
+  }
+  const data = (payload as { data?: unknown }).data;
+  if (!Array.isArray(data)) {
+    throw new Error("GitHub Copilot embeddings response missing data[]");
+  }
+
+  const vectors = Array.from<number[] | undefined>({ length: expectedCount });
+  for (const entry of data) {
+    if (!entry || typeof entry !== "object") {
+      throw new Error("GitHub Copilot embeddings response contains an invalid entry");
+    }
+    const indexValue = (entry as { index?: unknown }).index;
+    const embedding = (entry as { embedding?: unknown }).embedding;
+    const index = typeof indexValue === "number" ? indexValue : Number.NaN;
+    if (!Number.isInteger(index)) {
+      throw new Error("GitHub Copilot embeddings response contains an invalid index");
+    }
+    if (index < 0 || index >= expectedCount) {
+      throw new Error("GitHub Copilot embeddings response contains an out-of-range index");
+    }
+    if (vectors[index] !== undefined) {
+      throw new Error("GitHub Copilot embeddings response contains duplicate indexes");
+    }
+    if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) {
+      throw new Error("GitHub Copilot embeddings response contains an invalid embedding");
+    }
+    vectors[index] = sanitizeAndNormalizeEmbedding(embedding);
+  }
+
+  for (let index = 0; index < expectedCount; index += 1) {
+    if (vectors[index] === undefined) {
+      throw new Error("GitHub Copilot embeddings response missing vectors for some inputs");
+    }
+  }
+  return vectors as number[][];
+}
+
+async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{
+  baseUrl: string;
+  headers: Record<string, string>;
+}> {
+  const token = await resolveCopilotApiToken({
+    githubToken: client.githubToken,
+    env: client.env,
+    fetchImpl: client.fetchImpl,
+  });
+  const baseUrl = client.baseUrl?.trim() || token.baseUrl || DEFAULT_COPILOT_API_BASE_URL;
+  return {
+    baseUrl,
+    headers: {
+      ...COPILOT_HEADERS_STATIC,
+      ...client.headers,
+      Authorization: `Bearer ${token.token}`,
+    },
+  };
+}
+
+async function createGitHubCopilotEmbeddingProvider(
+  client: GitHubCopilotEmbeddingClient,
+): Promise<{ provider: MemoryEmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
+  const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
+
+  const embed = async (input: string[]): Promise<number[][]> => {
+    if (input.length === 0) {
+      return [];
+    }
+
+    const session = await resolveGitHubCopilotEmbeddingSession(client);
+    const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`;
+    return await withRemoteHttpResponse({
+      url,
+      fetchImpl: client.fetchImpl,
+      ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
+      init: {
+        method: "POST",
+        headers: session.headers,
+        body: JSON.stringify({ model: client.model, input }),
+      },
+      onResponse: async (response) => {
+        if (!response.ok) {
+          throw new Error(
+            `GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`,
+          );
+        }
+
+        let payload: unknown;
+        try {
+          payload = await response.json();
+        } catch {
+          throw new Error("GitHub Copilot embeddings returned invalid JSON");
+        }
+        return parseGitHubCopilotEmbeddingPayload(payload, input.length);
+      },
+    });
+  };
+
+  return {
+    provider: {
+      id: COPILOT_EMBEDDING_PROVIDER_ID,
+      model: client.model,
+      embedQuery: async (text) => {
+        const [vector] = await embed([text]);
+        return vector ?? [];
+      },
+      embedBatch: embed,
+    },
+    client: {
+      ...client,
+      baseUrl: initialSession.baseUrl,
+    },
+  };
+}
+
 export const githubCopilotMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
   id: COPILOT_EMBEDDING_PROVIDER_ID,
   transport: "remote",
   authProviderId: COPILOT_EMBEDDING_PROVIDER_ID,
   autoSelectPriority: 15,
   allowExplicitWhenConfiguredAuto: true,
   shouldContinueAutoSelection: (err: unknown) => isCopilotSetupError(err),
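Note how parseGitHubCopilotEmbeddingPayload places vectors by each entry's index field instead of by array position, then verifies completeness. An illustration with a fabricated out-of-order response:

// Fabricated payload, expectedCount = 2: entries arrive out of order but land
// in input order, and any gap, duplicate, or out-of-range index throws.
const payload = {
  data: [
    { index: 1, embedding: [0, 1] },
    { index: 0, embedding: [1, 0] },
  ],
};
// parseGitHubCopilotEmbeddingPayload(payload, 2)
// -> [sanitizeAndNormalizeEmbedding([1, 0]), sanitizeAndNormalizeEmbedding([0, 1])]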
@@ -1,14 +1,15 @@
+import crypto from "node:crypto";
 import {
   buildEmbeddingBatchGroupOptions,
   runEmbeddingBatchGroups,
   type EmbeddingBatchExecutionParams,
 } from "./batch-runner.js";
-import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
-import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
-import { debugEmbeddingsLog } from "./embeddings-debug.js";
-import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js";
-import { hashText } from "./internal.js";
-import { withRemoteHttpResponse } from "./remote-http.js";
+import {
+  buildBatchHeaders,
+  debugEmbeddingsLog,
+  normalizeBatchBaseUrl,
+  sanitizeAndNormalizeEmbedding,
+  withRemoteHttpResponse,
+} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
+import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embedding-provider.js";
 
 export type GeminiBatchRequest = {
   custom_id: string;
@@ -40,6 +41,10 @@ export type GeminiBatchOutputLine = {
 };
 
 const GEMINI_BATCH_MAX_REQUESTS = 50000;
+function hashText(text: string): string {
+  return crypto.createHash("sha256").update(text).digest("hex");
+}
+
 function getGeminiUploadUrl(baseUrl: string): string {
   if (baseUrl.includes("/v1beta")) {
     return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");
@@ -1,84 +1,40 @@
-import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
-import * as authModule from "../../agents/model-auth.js";
+import { afterEach, describe, expect, it, vi } from "vitest";
 import {
   buildGeminiEmbeddingRequest,
   buildGeminiTextEmbeddingRequest,
   createGeminiEmbeddingProvider,
   DEFAULT_GEMINI_EMBEDDING_MODEL,
   GEMINI_EMBEDDING_2_MODELS,
   isGeminiEmbedding2Model,
   normalizeGeminiModel,
   resolveGeminiOutputDimensionality,
-} from "./embeddings-gemini-request.js";
-import {
-  createGeminiBatchFetchMock,
-  createJsonResponseFetchMock,
-  installFetchMock,
-  mockResolvedProviderKey,
-  parseFetchBody,
-  readFirstFetchRequest,
-  type JsonFetchMock,
-} from "./embeddings-provider.test-support.js";
-
-const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
-  resolveApiKeyForProviderMock: vi.fn(),
-}));
-
-vi.mock("../../agents/model-auth.js", () => {
-  return {
-    resolveApiKeyForProvider: resolveApiKeyForProviderMock,
-    requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
-      if (auth.apiKey) {
-        return auth.apiKey;
-      }
-      throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
-    },
-  };
-});
-
-vi.mock("../../agents/api-key-rotation.js", () => ({
-  collectProviderApiKeysForExecution: (params: { primaryApiKey?: string }) =>
-    params.primaryApiKey ? [params.primaryApiKey] : [],
-  executeWithApiKeyRotation: async <T>(params: {
-    apiKeys: string[];
-    execute: (apiKey: string) => Promise<T>;
-  }) => {
-    const apiKey = params.apiKeys[0];
-    if (!apiKey) {
-      throw new Error('No API keys configured for provider "google".');
-    }
-    return await params.execute(apiKey);
-  },
-}));
-
-beforeEach(() => {
-  vi.useRealTimers();
-  vi.doUnmock("undici");
-});
+} from "./embedding-provider.js";
 
 afterEach(() => {
-  vi.doUnmock("undici");
+  vi.resetAllMocks();
   vi.restoreAllMocks();
   vi.unstubAllGlobals();
 });
 
-type GeminiProviderOptions = Parameters<
-  typeof import("./embeddings-gemini.js").createGeminiEmbeddingProvider
->[0];
-
-async function createProviderWithFetch(
-  fetchMock: JsonFetchMock,
-  options: Partial<GeminiProviderOptions> & { model: string },
-) {
-  installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
-  mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
-  const { createGeminiEmbeddingProvider } = await import("./embeddings-gemini.js");
-  const { provider } = await createGeminiEmbeddingProvider({
-    config: {} as never,
-    provider: "gemini",
-    fallback: "none",
-    ...options,
-  });
-  return provider;
-}
+function installFetchMock(
+  handler: (input: RequestInfo | URL, init?: RequestInit) => unknown,
+): ReturnType<typeof vi.fn> {
+  const fetchMock = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => {
+    return new Response(JSON.stringify(handler(input, init)), {
+      status: 200,
+      headers: { "Content-Type": "application/json" },
+    });
+  });
+  vi.stubGlobal("fetch", fetchMock);
+  return fetchMock;
+}
+
+function fetchJsonBody(fetchMock: ReturnType<typeof vi.fn>, index: number): unknown {
+  const init = fetchMock.mock.calls[index]?.[1] as RequestInit | undefined;
+  const body = init?.body;
+  if (typeof body !== "string") {
+    throw new Error("Expected JSON string request body.");
+  }
+  return JSON.parse(body) as unknown;
+}
 
 describe("Gemini embedding request helpers", () => {
@@ -149,24 +105,9 @@ describe("Gemini embedding request helpers", () => {
   });
 });
 
-describe("gemini embedding provider", () => {
+describe("Gemini embedding provider", () => {
   it("handles legacy and v2 request/response behavior", async () => {
-    const legacyFetch = createGeminiBatchFetchMock(2);
-    const legacyProvider = await createProviderWithFetch(legacyFetch, {
-      model: "gemini-embedding-001",
-    });
-
-    await legacyProvider.embedQuery("test query");
-    await legacyProvider.embedBatch(["text1", "text2"]);
-
-    expect(parseFetchBody(legacyFetch, 0)).toMatchObject({
-      taskType: "RETRIEVAL_QUERY",
-      content: { parts: [{ text: "test query" }] },
-    });
-    expect(parseFetchBody(legacyFetch, 0)).not.toHaveProperty("outputDimensionality");
-    expect(parseFetchBody(legacyFetch, 1)).not.toHaveProperty("outputDimensionality");
-
-    const v2Fetch = createJsonResponseFetchMock((input) => {
+    const fetchMock = installFetchMock((input) => {
       const url = input instanceof URL ? input.href : typeof input === "string" ? input : input.url;
       return url.endsWith(":batchEmbedContents")
         ? {
@@ -176,16 +117,22 @@ describe("Gemini embedding provider", () => {
           }
         : { embedding: { values: [3, 4, Number.NaN] } };
     });
-    const v2Provider = await createProviderWithFetch(v2Fetch, {
+
+    const { provider } = await createGeminiEmbeddingProvider({
+      config: {} as never,
+      provider: "gemini",
+      remote: { apiKey: "test-key" },
       model: "gemini-embedding-2-preview",
       outputDimensionality: 768,
       taskType: "SEMANTIC_SIMILARITY",
+      fallback: "none",
     });
-    await expect(v2Provider.embedQuery(" ")).resolves.toEqual([]);
-    await expect(v2Provider.embedBatch([])).resolves.toEqual([]);
-    await expect(v2Provider.embedQuery("test query")).resolves.toEqual([0.6, 0.8, 0]);
-
-    const structuredBatch = await v2Provider.embedBatchInputs?.([
+    await expect(provider.embedQuery(" ")).resolves.toEqual([]);
+    await expect(provider.embedBatch([])).resolves.toEqual([]);
+    await expect(provider.embedQuery("test query")).resolves.toEqual([0.6, 0.8, 0]);
+
+    const structuredBatch = await provider.embedBatchInputs?.([
       {
         text: "Image file: diagram.png",
         parts: [
@@ -206,38 +153,39 @@ describe("Gemini embedding provider", () => {
       [0, 0, 1],
     ]);
 
-    const { url } = readFirstFetchRequest(v2Fetch);
-    expect(url).toBe(
+    expect(fetchMock.mock.calls[0]?.[0]).toBe(
       "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:embedContent",
     );
-    expect(parseFetchBody(v2Fetch, 0)).toMatchObject({
+    expect(fetchJsonBody(fetchMock, 0)).toMatchObject({
       outputDimensionality: 768,
       taskType: "SEMANTIC_SIMILARITY",
       content: { parts: [{ text: "test query" }] },
     });
-    expect(parseFetchBody(v2Fetch, 1).requests).toEqual([
-      {
-        model: "models/gemini-embedding-2-preview",
-        content: {
-          parts: [
-            { text: "Image file: diagram.png" },
-            { inlineData: { mimeType: "image/png", data: "img" } },
-          ],
-        },
-        taskType: "SEMANTIC_SIMILARITY",
-        outputDimensionality: 768,
-      },
-      {
-        model: "models/gemini-embedding-2-preview",
-        content: {
-          parts: [
-            { text: "Audio file: note.wav" },
-            { inlineData: { mimeType: "audio/wav", data: "aud" } },
-          ],
-        },
-        taskType: "SEMANTIC_SIMILARITY",
-        outputDimensionality: 768,
-      },
-    ]);
+    expect(fetchJsonBody(fetchMock, 1)).toMatchObject({
+      requests: [
+        {
+          model: "models/gemini-embedding-2-preview",
+          content: {
+            parts: [
+              { text: "Image file: diagram.png" },
+              { inlineData: { mimeType: "image/png", data: "img" } },
+            ],
+          },
+          taskType: "SEMANTIC_SIMILARITY",
+          outputDimensionality: 768,
+        },
+        {
+          model: "models/gemini-embedding-2-preview",
+          content: {
+            parts: [
+              { text: "Audio file: note.wav" },
+              { inlineData: { mimeType: "audio/wav", data: "aud" } },
+            ],
+          },
+          taskType: "SEMANTIC_SIMILARITY",
+          outputDimensionality: 768,
+        },
+      ],
+    });
   });
 });
@@ -1,44 +1,22 @@
+import { parseGeminiAuth } from "openclaw/plugin-sdk/image-generation-core";
+import {
+  buildRemoteBaseUrlPolicy,
+  debugEmbeddingsLog,
+  sanitizeAndNormalizeEmbedding,
+  withRemoteHttpResponse,
+  type EmbeddingInput,
+  type MemoryEmbeddingProvider,
+  type MemoryEmbeddingProviderCreateOptions,
+} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
+import { resolveMemorySecretInputString } from "openclaw/plugin-sdk/memory-core-host-secret";
 import {
   collectProviderApiKeysForExecution,
   executeWithApiKeyRotation,
-} from "../../agents/api-key-rotation.js";
-import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js";
-import { parseGeminiAuth } from "../../infra/gemini-auth.js";
-import {
-  DEFAULT_GOOGLE_API_BASE_URL,
-  normalizeGoogleApiBaseUrl,
-} from "../../infra/google-api-base-url.js";
-import type { SsrFPolicy } from "../../infra/net/ssrf.js";
-import { normalizeOptionalString } from "../../shared/string-coerce.js";
-import type { EmbeddingInput } from "./embedding-inputs.js";
-import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
-import { debugEmbeddingsLog } from "./embeddings-debug.js";
-import {
-  buildGeminiEmbeddingRequest,
-  buildGeminiTextEmbeddingRequest,
-  isGeminiEmbedding2Model,
-  normalizeGeminiModel,
-  resolveGeminiOutputDimensionality,
-} from "./embeddings-gemini-request.js";
-import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
-import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
-import { resolveMemorySecretInputString } from "./secret-input.js";
-
-export {
-  buildGeminiEmbeddingRequest,
-  buildGeminiTextEmbeddingRequest,
-  DEFAULT_GEMINI_EMBEDDING_MODEL,
-  GEMINI_EMBEDDING_2_MODELS,
-  isGeminiEmbedding2Model,
-  normalizeGeminiModel,
-  resolveGeminiOutputDimensionality,
-  type GeminiEmbeddingRequest,
-  type GeminiInlinePart,
-  type GeminiPart,
-  type GeminiTaskType,
-  type GeminiTextEmbeddingRequest,
-  type GeminiTextPart,
-} from "./embeddings-gemini-request.js";
+  requireApiKey,
+  resolveApiKeyForProvider,
+} from "openclaw/plugin-sdk/provider-auth-runtime";
+import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
+import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
 
 export type GeminiEmbeddingClient = {
   baseUrl: string;
@@ -50,9 +28,111 @@ export type GeminiEmbeddingClient = {
   outputDimensionality?: number;
 };
 
+export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
+const DEFAULT_GOOGLE_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
 const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
   "text-embedding-004": 2048,
   "gemini-embedding-001": 2048,
   "gemini-embedding-2-preview": 8192,
 };
 
+export type GeminiTaskType = NonNullable<MemoryEmbeddingProviderCreateOptions["taskType"]>;
+
+// --- gemini-embedding-2-preview support ---
+
+export const GEMINI_EMBEDDING_2_MODELS = new Set([
+  "gemini-embedding-2-preview",
+  // Add the GA model name here once released.
+]);
+
+const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
+const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;
+
+export type GeminiTextPart = { text: string };
+export type GeminiInlinePart = {
+  inlineData: { mimeType: string; data: string };
+};
+export type GeminiPart = GeminiTextPart | GeminiInlinePart;
+export type GeminiEmbeddingRequest = {
+  content: { parts: GeminiPart[] };
+  taskType: GeminiTaskType;
+  outputDimensionality?: number;
+  model?: string;
+};
+export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;
+
+/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
+export function buildGeminiTextEmbeddingRequest(params: {
+  text: string;
+  taskType: GeminiTaskType;
+  outputDimensionality?: number;
+  modelPath?: string;
+}): GeminiTextEmbeddingRequest {
+  return buildGeminiEmbeddingRequest({
+    input: { text: params.text },
+    taskType: params.taskType,
+    outputDimensionality: params.outputDimensionality,
+    modelPath: params.modelPath,
+  });
+}
+
+export function buildGeminiEmbeddingRequest(params: {
+  input: EmbeddingInput;
+  taskType: GeminiTaskType;
+  outputDimensionality?: number;
+  modelPath?: string;
+}): GeminiEmbeddingRequest {
+  const request: GeminiEmbeddingRequest = {
+    content: {
+      parts: params.input.parts?.map((part) =>
+        part.type === "text"
+          ? ({ text: part.text } satisfies GeminiTextPart)
+          : ({
+              inlineData: { mimeType: part.mimeType, data: part.data },
+            } satisfies GeminiInlinePart),
+      ) ?? [{ text: params.input.text }],
+    },
+    taskType: params.taskType,
+  };
+  if (params.modelPath) {
+    request.model = params.modelPath;
+  }
+  if (params.outputDimensionality != null) {
+    request.outputDimensionality = params.outputDimensionality;
+  }
+  return request;
+}
+
+/**
+ * Returns true if the given model name is a gemini-embedding-2 variant that
+ * supports `outputDimensionality` and extended task types.
+ */
+export function isGeminiEmbedding2Model(model: string): boolean {
+  return GEMINI_EMBEDDING_2_MODELS.has(model);
+}
+
+/**
+ * Validate and return the `outputDimensionality` for gemini-embedding-2 models.
+ * Returns `undefined` for older models (they don't support the param).
+ */
+export function resolveGeminiOutputDimensionality(
+  model: string,
+  requested?: number,
+): number | undefined {
+  if (!isGeminiEmbedding2Model(model)) {
+    return undefined;
+  }
+  if (requested == null) {
+    return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
+  }
+  const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
+  if (!valid.includes(requested)) {
+    throw new Error(
+      `Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
+    );
+  }
+  return requested;
+}
+
 function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
   const trimmed = resolveMemorySecretInputString({
     value: remoteApiKey,
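To make the request-builder contract concrete, here is an illustration (not part of the diff) of what the helpers above produce. The "text" tag is the one the mapper actually checks; the tag used for the binary part ("image" here) is an assumption about EmbeddingInput's part union:

// resolveGeminiOutputDimensionality validates only on embedding-2 models:
//   resolveGeminiOutputDimensionality("gemini-embedding-001", 768)       -> undefined
//   resolveGeminiOutputDimensionality("gemini-embedding-2-preview")      -> 3072 (default)
//   resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512) -> throws (not 768/1536/3072)
const request = buildGeminiEmbeddingRequest({
  input: {
    text: "Image file: diagram.png",
    parts: [
      { type: "text", text: "Image file: diagram.png" },
      { type: "image", mimeType: "image/png", data: "img" }, // part tag assumed
    ],
  },
  taskType: "SEMANTIC_SIMILARITY",
  outputDimensionality: resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768),
  modelPath: "models/gemini-embedding-2-preview",
});
// request.content.parts ->
//   [{ text: "Image file: diagram.png" },
//    { inlineData: { mimeType: "image/png", data: "img" } }]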
@@ -67,6 +147,21 @@ function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
   return trimmed;
 }
 
+export function normalizeGeminiModel(model: string): string {
+  const trimmed = model.trim();
+  if (!trimmed) {
+    return DEFAULT_GEMINI_EMBEDDING_MODEL;
+  }
+  const withoutPrefix = trimmed.replace(/^models\//, "");
+  if (withoutPrefix.startsWith("gemini/")) {
+    return withoutPrefix.slice("gemini/".length);
+  }
+  if (withoutPrefix.startsWith("google/")) {
+    return withoutPrefix.slice("google/".length);
+  }
+  return withoutPrefix;
+}
+
 async function fetchGeminiEmbeddingPayload(params: {
   client: GeminiEmbeddingClient;
   endpoint: string;
@@ -120,9 +215,30 @@ function buildGeminiModelPath(model: string): string {
   return model.startsWith("models/") ? model : `models/${model}`;
 }
 
+function normalizeGoogleApiBaseUrl(baseUrl: string): string {
+  const trimmed = baseUrl.trim().replace(/\/+$/, "");
+  if (!trimmed) {
+    return DEFAULT_GOOGLE_API_BASE_URL;
+  }
+  try {
+    const url = new URL(trimmed);
+    url.hash = "";
+    url.search = "";
+    if (
+      url.origin.toLowerCase() === "https://generativelanguage.googleapis.com" &&
+      url.pathname.replace(/\/+$/, "") === ""
+    ) {
+      url.pathname = "/v1beta";
+    }
+    return url.toString().replace(/\/+$/, "");
+  } catch {
+    return trimmed;
+  }
+}
+
 export async function createGeminiEmbeddingProvider(
-  options: EmbeddingProviderOptions,
-): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
+  options: MemoryEmbeddingProviderCreateOptions,
+): Promise<{ provider: MemoryEmbeddingProvider; client: GeminiEmbeddingClient }> {
   const client = await resolveGeminiEmbeddingClient(options);
   const baseUrl = client.baseUrl.replace(/\/$/, "");
   const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
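Expected behavior of the two normalizers, read straight off their bodies (illustration, not in the diff):

normalizeGeminiModel("models/gemini-embedding-001"); // -> "gemini-embedding-001"
normalizeGeminiModel("google/gemini-embedding-001"); // -> "gemini-embedding-001"
normalizeGeminiModel("  ");                          // -> "gemini-embedding-001" (the default)

normalizeGoogleApiBaseUrl("https://generativelanguage.googleapis.com/");
// -> "https://generativelanguage.googleapis.com/v1beta" (bare Google origin gains /v1beta)
normalizeGoogleApiBaseUrl("https://proxy.example/v1beta/?key=x");
// -> "https://proxy.example/v1beta" (query, hash, and trailing slashes dropped; path kept)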
@@ -190,7 +306,7 @@ export async function createGeminiEmbeddingProvider(
 }
 
 export async function resolveGeminiEmbeddingClient(
-  options: EmbeddingProviderOptions,
+  options: MemoryEmbeddingProviderCreateOptions,
 ): Promise<GeminiEmbeddingClient> {
   const remote = options.remote;
   const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
@@ -3,6 +3,7 @@ import type { MediaUnderstandingProvider } from "openclaw/plugin-sdk/media-under
 import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
 import { buildGoogleGeminiCliBackend } from "./cli-backend.js";
 import { registerGoogleGeminiCliProvider } from "./gemini-cli-provider.js";
+import { geminiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
 import { buildGoogleMusicGenerationProvider } from "./music-generation-provider.js";
 import { registerGoogleProvider } from "./provider-registration.js";
 import { buildGoogleSpeechProvider } from "./speech-provider.js";
@@ -111,6 +112,7 @@ export default definePluginEntry({
     api.registerCliBackend(buildGoogleGeminiCliBackend());
     registerGoogleGeminiCliProvider(api);
     registerGoogleProvider(api);
+    api.registerMemoryEmbeddingProvider(geminiMemoryEmbeddingProviderAdapter);
     api.registerImageGenerationProvider(createLazyGoogleImageGenerationProvider());
     api.registerMediaUnderstandingProvider(createLazyGoogleMediaUnderstandingProvider());
     api.registerMusicGenerationProvider(buildGoogleMusicGenerationProvider());
extensions/google/memory-embedding-adapter.ts (Normal file, 79 lines)
@@ -0,0 +1,79 @@
+import {
+  hasNonTextEmbeddingParts,
+  isMissingEmbeddingApiKeyError,
+  mapBatchEmbeddingsByIndex,
+  sanitizeEmbeddingCacheHeaders,
+  type MemoryEmbeddingProviderAdapter,
+} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
+import { runGeminiEmbeddingBatches } from "./embedding-batch.js";
+import {
+  buildGeminiEmbeddingRequest,
+  createGeminiEmbeddingProvider,
+  DEFAULT_GEMINI_EMBEDDING_MODEL,
+} from "./embedding-provider.js";
+
+function supportsGeminiMultimodalEmbeddings(model: string): boolean {
+  const normalized = model
+    .trim()
+    .replace(/^models\//, "")
+    .replace(/^(gemini|google)\//, "");
+  return normalized === "gemini-embedding-2-preview";
+}
+
+export const geminiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
+  id: "gemini",
+  defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
+  transport: "remote",
+  authProviderId: "google",
+  autoSelectPriority: 30,
+  allowExplicitWhenConfiguredAuto: true,
+  supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model),
+  shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
+  create: async (options) => {
+    const { provider, client } = await createGeminiEmbeddingProvider({
+      ...options,
+      provider: "gemini",
+      fallback: "none",
+    });
+    return {
+      provider,
+      runtime: {
+        id: "gemini",
+        cacheKeyData: {
+          provider: "gemini",
+          baseUrl: client.baseUrl,
+          model: client.model,
+          outputDimensionality: client.outputDimensionality,
+          headers: sanitizeEmbeddingCacheHeaders(client.headers, [
+            "authorization",
+            "x-goog-api-key",
+          ]),
+        },
+        batchEmbed: async (batch) => {
+          if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) {
+            return null;
+          }
+          const byCustomId = await runGeminiEmbeddingBatches({
+            gemini: client,
+            agentId: batch.agentId,
+            requests: batch.chunks.map((chunk, index) => ({
+              custom_id: String(index),
+              request: buildGeminiEmbeddingRequest({
+                input: chunk.embeddingInput ?? { text: chunk.text },
+                taskType: "RETRIEVAL_DOCUMENT",
+                modelPath: client.modelPath,
+                outputDimensionality: client.outputDimensionality,
+              }),
+            })),
+            wait: batch.wait,
+            concurrency: batch.concurrency,
+            pollIntervalMs: batch.pollIntervalMs,
+            timeoutMs: batch.timeoutMs,
+            debug: batch.debug,
+          });
+          return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
+        },
+      },
+    };
+  },
+};
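One detail worth noting in the adapter above: headers enter cacheKeyData only after sanitizeEmbeddingCacheHeaders strips credential headers, so secrets never become part of an embedding cache key while non-secret headers still invalidate stale vectors. A sketch of the expected shape, assuming the SDK helper mirrors the sanitizeHeaders() function this commit deletes from provider-adapters.ts further down:

const sanitized = sanitizeEmbeddingCacheHeaders(
  { Authorization: "Bearer sk-secret", "x-goog-api-key": "k", "X-Route": "eu" },
  ["authorization", "x-goog-api-key"],
);
// assumed result: [["X-Route", "eu"]] (exclusion case-insensitive, output sorted)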
@@ -46,6 +46,7 @@
   },
   "contracts": {
     "mediaUnderstandingProviders": ["google"],
+    "memoryEmbeddingProviders": ["gemini"],
    "imageGenerationProviders": ["google"],
    "musicGenerationProviders": ["google"],
    "speechProviders": ["google"],
@@ -8,6 +8,7 @@ import {
   type ProviderRuntimeModel,
 } from "openclaw/plugin-sdk/plugin-entry";
 import { CUSTOM_LOCAL_AUTH_MARKER } from "openclaw/plugin-sdk/provider-auth";
+import { lmstudioMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
 import {
   LMSTUDIO_DEFAULT_API_KEY_ENV_VAR,
   LMSTUDIO_LOCAL_API_KEY_PLACEHOLDER,
@@ -52,6 +53,7 @@ export default definePluginEntry({
   name: "LM Studio Provider",
   description: "Bundled LM Studio provider plugin",
   register(api: OpenClawPluginApi) {
+    api.registerMemoryEmbeddingProvider(lmstudioMemoryEmbeddingProviderAdapter);
     api.registerProvider({
       id: PROVIDER_ID,
       label: "LM Studio",
extensions/lmstudio/memory-embedding-adapter.ts (Normal file, 35 lines)
@@ -0,0 +1,35 @@
+import {
+  sanitizeEmbeddingCacheHeaders,
+  type MemoryEmbeddingProviderAdapter,
+} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
+import {
+  createLmstudioEmbeddingProvider,
+  DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
+} from "./src/embedding-provider.js";
+
+export const lmstudioMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
+  id: "lmstudio",
+  defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
+  transport: "remote",
+  authProviderId: "lmstudio",
+  allowExplicitWhenConfiguredAuto: true,
+  create: async (options) => {
+    const { provider, client } = await createLmstudioEmbeddingProvider({
+      ...options,
+      provider: "lmstudio",
+      fallback: "none",
+    });
+    return {
+      provider,
+      runtime: {
+        id: "lmstudio",
+        cacheKeyData: {
+          provider: "lmstudio",
+          baseUrl: client.baseUrl,
+          model: client.model,
+          headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
+        },
+      },
+    };
+  },
+};
@@ -21,6 +21,9 @@
       "groupHint": "Self-hosted open-weight models"
     }
   ],
+  "contracts": {
+    "memoryEmbeddingProviders": ["lmstudio"]
+  },
   "configSchema": {
     "type": "object",
     "additionalProperties": false,
@@ -1,20 +1,21 @@
-import { formatErrorMessage } from "../../infra/errors.js";
-import type { SsrFPolicy } from "../../infra/net/ssrf.js";
-import { createSubsystemLogger } from "../../logging/subsystem.js";
+import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
+import {
+  buildRemoteBaseUrlPolicy,
+  createRemoteEmbeddingProvider,
+  normalizeEmbeddingModelWithPrefixes,
+  type MemoryEmbeddingProvider,
+  type MemoryEmbeddingProviderCreateOptions,
+} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
+import { resolveMemorySecretInputString } from "openclaw/plugin-sdk/memory-core-host-secret";
+import { formatErrorMessage, type SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
+import { LMSTUDIO_DEFAULT_EMBEDDING_MODEL, LMSTUDIO_PROVIDER_ID } from "./defaults.js";
+import { ensureLmstudioModelLoaded } from "./models.fetch.js";
+import { resolveLmstudioInferenceBase } from "./models.js";
 import {
   buildLmstudioAuthHeaders,
-  ensureLmstudioModelLoaded,
-  LMSTUDIO_DEFAULT_EMBEDDING_MODEL,
-  LMSTUDIO_PROVIDER_ID,
-  resolveLmstudioInferenceBase,
   resolveLmstudioProviderHeaders,
   resolveLmstudioRuntimeApiKey,
-} from "../../plugin-sdk/lmstudio-runtime.js";
-import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
-import { createRemoteEmbeddingProvider } from "./embeddings-remote-provider.js";
-import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
-import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
-import { resolveMemorySecretInputString } from "./secret-input.js";
+} from "./runtime.js";
 
 const log = createSubsystemLogger("memory/embeddings");
 
@@ -47,7 +48,7 @@ function hasAuthorizationHeader(headers: Record<string, string> | undefined): boolean {
 
 /** Resolves API key (real or synthetic placeholder) from runtime/provider auth config. */
 async function resolveLmstudioApiKey(
-  options: EmbeddingProviderOptions,
+  options: MemoryEmbeddingProviderCreateOptions,
 ): Promise<string | undefined> {
   try {
     return await resolveLmstudioRuntimeApiKey({
@@ -65,8 +66,8 @@ async function resolveLmstudioApiKey(
 
 /** Creates the LM Studio embedding provider client and preloads the target model before return. */
 export async function createLmstudioEmbeddingProvider(
-  options: EmbeddingProviderOptions,
-): Promise<{ provider: EmbeddingProvider; client: LmstudioEmbeddingClient }> {
+  options: MemoryEmbeddingProviderCreateOptions,
+): Promise<{ provider: MemoryEmbeddingProvider; client: LmstudioEmbeddingClient }> {
   const providerConfig = options.config.models?.providers?.lmstudio;
   const providerBaseUrl = providerConfig?.baseUrl?.trim();
   const isFallbackActivation = options.fallback === "lmstudio" && options.provider !== "lmstudio";
@@ -1,10 +1,5 @@
 import {
-  DEFAULT_GEMINI_EMBEDDING_MODEL,
   DEFAULT_LOCAL_MODEL,
-  DEFAULT_MISTRAL_EMBEDDING_MODEL,
-  DEFAULT_OLLAMA_EMBEDDING_MODEL,
-  DEFAULT_OPENAI_EMBEDDING_MODEL,
-  DEFAULT_VOYAGE_EMBEDDING_MODEL,
   getMemoryEmbeddingProvider,
   listMemoryEmbeddingProviders,
   type MemoryEmbeddingProvider,
@@ -15,15 +10,7 @@ import {
 import { formatErrorMessage } from "../dreaming-shared.js";
 import { canAutoSelectLocal } from "./provider-adapters.js";
 
-export {
-  DEFAULT_GEMINI_EMBEDDING_MODEL,
-  DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
-  DEFAULT_LOCAL_MODEL,
-  DEFAULT_MISTRAL_EMBEDDING_MODEL,
-  DEFAULT_OLLAMA_EMBEDDING_MODEL,
-  DEFAULT_OPENAI_EMBEDDING_MODEL,
-  DEFAULT_VOYAGE_EMBEDDING_MODEL,
-} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
+export { DEFAULT_LOCAL_MODEL } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
 
 export type EmbeddingProvider = MemoryEmbeddingProvider;
 export type EmbeddingProviderId = string;
 
@@ -11,9 +11,9 @@ import {
 } from "../../../../src/plugins/memory-embedding-providers.js";
 import "./test-runtime-mocks.js";
 import type { MemoryIndexManager } from "./index.js";
-import { getMemorySearchManager, closeAllMemorySearchManagers } from "./index.js";
+import { closeAllMemorySearchManagers, getMemorySearchManager } from "./index.js";
 import {
-  DEFAULT_OLLAMA_EMBEDDING_MODEL,
+  DEFAULT_LOCAL_MODEL,
   registerBuiltInMemoryEmbeddingProviders,
 } from "./provider-adapters.js";
 
@@ -112,14 +112,14 @@ vi.mock("./embeddings.js", () => {
 });
 
 describe("memory index", () => {
-  it("registers the builtin ollama embedding provider", () => {
-    const adapter = listRegisteredAdapters().find((entry) => entry.id === "ollama");
+  it("registers the builtin local embedding provider", () => {
+    const adapter = listRegisteredAdapters().find((entry) => entry.id === "local");
 
     expect(adapter).toBeDefined();
     expect(adapter).toEqual(
       expect.objectContaining({
-        id: "ollama",
-        defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
+        id: "local",
+        defaultModel: DEFAULT_LOCAL_MODEL,
       }),
     );
   });
 
@@ -1,31 +1,13 @@
 import fsSync from "node:fs";
 import {
-  DEFAULT_GEMINI_EMBEDDING_MODEL,
-  DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
-  DEFAULT_LOCAL_MODEL,
-  DEFAULT_MISTRAL_EMBEDDING_MODEL,
-  DEFAULT_OLLAMA_EMBEDDING_MODEL,
-  DEFAULT_OPENAI_EMBEDDING_MODEL,
-  DEFAULT_VOYAGE_EMBEDDING_MODEL,
-  OPENAI_BATCH_ENDPOINT,
-  buildGeminiEmbeddingRequest,
-  createGeminiEmbeddingProvider,
-  createLmstudioEmbeddingProvider,
   createLocalEmbeddingProvider,
-  createMistralEmbeddingProvider,
-  createOllamaEmbeddingProvider,
-  createOpenAiEmbeddingProvider,
-  createVoyageEmbeddingProvider,
   hasNonTextEmbeddingParts,
+  DEFAULT_LOCAL_MODEL,
+  listMemoryEmbeddingProviders,
   listRegisteredMemoryEmbeddingProviderAdapters,
-  runGeminiEmbeddingBatches,
-  runOpenAiEmbeddingBatches,
-  runVoyageEmbeddingBatches,
   type MemoryEmbeddingProviderAdapter,
 } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
 import { resolveUserPath } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
-import { getProviderEnvVars } from "openclaw/plugin-sdk/provider-env-vars";
-import { normalizeLowercaseStringOrEmpty } from "openclaw/plugin-sdk/text-runtime";
 import { formatErrorMessage } from "../dreaming-shared.js";
 import { filterUnregisteredMemoryEmbeddingProviderAdapters } from "./provider-adapter-registration.js";
 
@@ -37,31 +19,6 @@ export type BuiltinMemoryEmbeddingProviderDoctorMetadata = {
   autoSelectPriority?: number;
 };
 
-function isMissingApiKeyError(err: unknown): boolean {
-  return formatErrorMessage(err).includes("No API key found for provider");
-}
-
-function sanitizeHeaders(
-  headers: Record<string, string>,
-  excludedHeaderNames: string[],
-): Array<[string, string]> {
-  const excluded = new Set(
-    excludedHeaderNames.map((name) => normalizeLowercaseStringOrEmpty(name)),
-  );
-  return Object.entries(headers)
-    .filter(([key]) => !excluded.has(normalizeLowercaseStringOrEmpty(key)))
-    .toSorted(([a], [b]) => a.localeCompare(b))
-    .map(([key, value]) => [key, value]);
-}
-
-function mapBatchEmbeddingsByIndex(byCustomId: Map<string, number[]>, count: number): number[][] {
-  const embeddings: number[][] = [];
-  for (let index = 0; index < count; index += 1) {
-    embeddings.push(byCustomId.get(String(index)) ?? []);
-  }
-  return embeddings;
-}
-
 function isNodeLlamaCppMissing(err: unknown): boolean {
   if (!(err instanceof Error)) {
     return false;
@@ -70,6 +27,20 @@ function isNodeLlamaCppMissing(err: unknown): boolean {
   return code === "ERR_MODULE_NOT_FOUND" && err.message.includes("node-llama-cpp");
 }
 
+function listRemoteEmbeddingSetupHints(): string[] {
+  try {
+    return listMemoryEmbeddingProviders()
+      .filter(
+        (adapter) =>
+          adapter.transport === "remote" && typeof adapter.autoSelectPriority === "number",
+      )
+      .toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0))
+      .map((adapter) => `Or set agents.defaults.memorySearch.provider = "${adapter.id}" (remote).`);
+  } catch {
+    return [];
+  }
+}
+
 function formatLocalSetupError(err: unknown): string {
   const detail = formatErrorMessage(err);
   const missing = isNodeLlamaCppMissing(err);
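listRemoteEmbeddingSetupHints() replaces the hard-coded four-provider list in formatLocalSetupError (next hunk): hints now come from whatever remote adapters are actually registered, sorted by autoSelectPriority. Illustration, assuming the plugins in this diff are enabled (github-copilot=15, openai=20, gemini=30, voyage=40, mistral=50, bedrock=60):

const hints = listRemoteEmbeddingSetupHints();
// hints[0] === 'Or set agents.defaults.memorySearch.provider = "github-copilot" (remote).'
// hints[1] === 'Or set agents.defaults.memorySearch.provider = "openai" (remote).'
// ...and so on; the list shrinks automatically when a plugin is disabled.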
@@ -87,9 +58,7 @@ function formatLocalSetupError(err: unknown): string {
       ? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
       : null,
     "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
-    ...["openai", "gemini", "voyage", "mistral"].map(
-      (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
-    ),
+    ...listRemoteEmbeddingSetupHints(),
   ]
     .filter(Boolean)
     .join("\n");
@@ -111,237 +80,6 @@ function canAutoSelectLocal(modelPath?: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
function supportsGeminiMultimodalEmbeddings(model: string): boolean {
|
||||
const normalized = model
|
||||
.trim()
|
||||
.replace(/^models\//, "")
|
||||
.replace(/^(gemini|google)\//, "");
|
||||
return normalized === "gemini-embedding-2-preview";
|
||||
}
|
||||
|
||||
function resolveMemoryEmbeddingAuthProviderId(providerId: string): string {
|
||||
return providerId === "gemini" ? "google" : providerId;
|
||||
}
|
||||
|
||||
const openAiAdapter: MemoryEmbeddingProviderAdapter = {
|
||||
id: "openai",
|
||||
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
|
||||
transport: "remote",
|
||||
autoSelectPriority: 20,
|
||||
allowExplicitWhenConfiguredAuto: true,
|
||||
shouldContinueAutoSelection: isMissingApiKeyError,
|
||||
create: async (options) => {
|
||||
const { provider, client } = await createOpenAiEmbeddingProvider({
|
||||
...options,
|
||||
provider: "openai",
|
||||
fallback: "none",
|
||||
});
|
||||
return {
|
||||
provider,
|
||||
runtime: {
|
||||
id: "openai",
|
||||
cacheKeyData: {
|
||||
provider: "openai",
|
||||
baseUrl: client.baseUrl,
|
||||
model: client.model,
|
||||
headers: sanitizeHeaders(client.headers, ["authorization"]),
|
||||
},
|
||||
batchEmbed: async (batch) => {
|
||||
const byCustomId = await runOpenAiEmbeddingBatches({
|
||||
            openAi: client,
            agentId: batch.agentId,
            requests: batch.chunks.map((chunk, index) => ({
              custom_id: String(index),
              method: "POST",
              url: OPENAI_BATCH_ENDPOINT,
              body: {
                model: client.model,
                input: chunk.text,
              },
            })),
            wait: batch.wait,
            concurrency: batch.concurrency,
            pollIntervalMs: batch.pollIntervalMs,
            timeoutMs: batch.timeoutMs,
            debug: batch.debug,
          });
          return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
        },
      },
    };
  },
};

const geminiAdapter: MemoryEmbeddingProviderAdapter = {
  id: "gemini",
  defaultModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
  transport: "remote",
  autoSelectPriority: 30,
  allowExplicitWhenConfiguredAuto: true,
  supportsMultimodalEmbeddings: ({ model }) => supportsGeminiMultimodalEmbeddings(model),
  shouldContinueAutoSelection: isMissingApiKeyError,
  create: async (options) => {
    const { provider, client } = await createGeminiEmbeddingProvider({
      ...options,
      provider: "gemini",
      fallback: "none",
    });
    return {
      provider,
      runtime: {
        id: "gemini",
        cacheKeyData: {
          provider: "gemini",
          baseUrl: client.baseUrl,
          model: client.model,
          outputDimensionality: client.outputDimensionality,
          headers: sanitizeHeaders(client.headers, ["authorization", "x-goog-api-key"]),
        },
        batchEmbed: async (batch) => {
          if (batch.chunks.some((chunk) => hasNonTextEmbeddingParts(chunk.embeddingInput))) {
            return null;
          }
          const byCustomId = await runGeminiEmbeddingBatches({
            gemini: client,
            agentId: batch.agentId,
            requests: batch.chunks.map((chunk, index) => ({
              custom_id: String(index),
              request: buildGeminiEmbeddingRequest({
                input: chunk.embeddingInput ?? { text: chunk.text },
                taskType: "RETRIEVAL_DOCUMENT",
                modelPath: client.modelPath,
                outputDimensionality: client.outputDimensionality,
              }),
            })),
            wait: batch.wait,
            concurrency: batch.concurrency,
            pollIntervalMs: batch.pollIntervalMs,
            timeoutMs: batch.timeoutMs,
            debug: batch.debug,
          });
          return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
        },
      },
    };
  },
};

const voyageAdapter: MemoryEmbeddingProviderAdapter = {
  id: "voyage",
  defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
  transport: "remote",
  autoSelectPriority: 40,
  allowExplicitWhenConfiguredAuto: true,
  shouldContinueAutoSelection: isMissingApiKeyError,
  create: async (options) => {
    const { provider, client } = await createVoyageEmbeddingProvider({
      ...options,
      provider: "voyage",
      fallback: "none",
    });
    return {
      provider,
      runtime: {
        id: "voyage",
        batchEmbed: async (batch) => {
          const byCustomId = await runVoyageEmbeddingBatches({
            client,
            agentId: batch.agentId,
            requests: batch.chunks.map((chunk, index) => ({
              custom_id: String(index),
              body: {
                input: chunk.text,
              },
            })),
            wait: batch.wait,
            concurrency: batch.concurrency,
            pollIntervalMs: batch.pollIntervalMs,
            timeoutMs: batch.timeoutMs,
            debug: batch.debug,
          });
          return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
        },
      },
    };
  },
};

const mistralAdapter: MemoryEmbeddingProviderAdapter = {
  id: "mistral",
  defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
  transport: "remote",
  autoSelectPriority: 50,
  allowExplicitWhenConfiguredAuto: true,
  shouldContinueAutoSelection: isMissingApiKeyError,
  create: async (options) => {
    const { provider, client } = await createMistralEmbeddingProvider({
      ...options,
      provider: "mistral",
      fallback: "none",
    });
    return {
      provider,
      runtime: {
        id: "mistral",
        cacheKeyData: {
          provider: "mistral",
          model: client.model,
        },
      },
    };
  },
};

const ollamaAdapter: MemoryEmbeddingProviderAdapter = {
  id: "ollama",
  defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
  transport: "remote",
  create: async (options) => {
    const { provider, client } = await createOllamaEmbeddingProvider({
      ...options,
      provider: "ollama",
      fallback: "none",
    });
    return {
      provider,
      runtime: {
        id: "ollama",
        cacheKeyData: {
          provider: "ollama",
          baseUrl: client.baseUrl,
          model: client.model,
          headers: sanitizeHeaders(client.headers, ["authorization"]),
        },
      },
    };
  },
};

const lmstudioAdapter: MemoryEmbeddingProviderAdapter = {
  id: "lmstudio",
  defaultModel: DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
  transport: "remote",
  create: async (options) => {
    const { provider, client } = await createLmstudioEmbeddingProvider({
      ...options,
      provider: "lmstudio",
      fallback: "none",
    });
    return {
      provider,
      runtime: {
        id: "lmstudio",
        cacheKeyData: {
          provider: "lmstudio",
          baseUrl: client.baseUrl,
          model: client.model,
          headers: sanitizeHeaders(client.headers, ["authorization"]),
        },
      },
    };
  },
};

const localAdapter: MemoryEmbeddingProviderAdapter = {
  id: "local",
  defaultModel: DEFAULT_LOCAL_MODEL,
@@ -368,24 +106,14 @@ const localAdapter: MemoryEmbeddingProviderAdapter = {
  },
};

export const builtinMemoryEmbeddingProviderAdapters = [
  localAdapter,
  openAiAdapter,
  geminiAdapter,
  voyageAdapter,
  mistralAdapter,
  ollamaAdapter,
  lmstudioAdapter,
] as const;
export const builtinMemoryEmbeddingProviderAdapters = [localAdapter] as const;

const builtinMemoryEmbeddingProviderAdapterById = new Map(
  builtinMemoryEmbeddingProviderAdapters.map((adapter) => [adapter.id, adapter]),
);
export { DEFAULT_LOCAL_MODEL };

export function getBuiltinMemoryEmbeddingProviderAdapter(
  id: string,
): MemoryEmbeddingProviderAdapter | undefined {
  return builtinMemoryEmbeddingProviderAdapterById.get(id);
  return listMemoryEmbeddingProviders().find((adapter) => adapter.id === id);
}

export function registerBuiltInMemoryEmbeddingProviders(register: {
@@ -409,7 +137,7 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata(
  if (!adapter) {
    return null;
  }
  const authProviderId = resolveMemoryEmbeddingAuthProviderId(adapter.id);
  const authProviderId = adapter.authProviderId ?? adapter.id;
  return {
    providerId: adapter.id,
    authProviderId,
@@ -420,27 +148,19 @@ export function getBuiltinMemoryEmbeddingProviderDoctorMetadata(
}

export function listBuiltinAutoSelectMemoryEmbeddingProviderDoctorMetadata(): Array<BuiltinMemoryEmbeddingProviderDoctorMetadata> {
  return builtinMemoryEmbeddingProviderAdapters
  return listMemoryEmbeddingProviders()
    .filter((adapter) => typeof adapter.autoSelectPriority === "number")
    .toSorted((a, b) => (a.autoSelectPriority ?? 0) - (b.autoSelectPriority ?? 0))
    .map((adapter) => ({
      providerId: adapter.id,
      authProviderId: resolveMemoryEmbeddingAuthProviderId(adapter.id),
      envVars: getProviderEnvVars(resolveMemoryEmbeddingAuthProviderId(adapter.id)),
      transport: adapter.transport === "local" ? "local" : "remote",
      autoSelectPriority: adapter.autoSelectPriority,
    }));
    .map((adapter) => {
      const authProviderId = adapter.authProviderId ?? adapter.id;
      return {
        providerId: adapter.id,
        authProviderId,
        envVars: getProviderEnvVars(authProviderId),
        transport: adapter.transport === "local" ? "local" : "remote",
        autoSelectPriority: adapter.autoSelectPriority,
      };
    });
}

export {
  DEFAULT_GEMINI_EMBEDDING_MODEL,
  DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
  DEFAULT_LOCAL_MODEL,
  DEFAULT_MISTRAL_EMBEDDING_MODEL,
  DEFAULT_OLLAMA_EMBEDDING_MODEL,
  DEFAULT_OPENAI_EMBEDDING_MODEL,
  DEFAULT_VOYAGE_EMBEDDING_MODEL,
  canAutoSelectLocal,
  formatLocalSetupError,
  isMissingApiKeyError,
};
export { canAutoSelectLocal, formatLocalSetupError };
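
For orientation, a minimal sketch of the adapter contract that now replaces these removed built-ins. The field names (id, transport, create returning { provider, runtime }) come from the diff above; the embedding call itself (embedViaExampleApi) and the exact MemoryEmbeddingProvider surface are assumptions for illustration, not part of this commit.

import type { MemoryEmbeddingProviderAdapter } from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";

// Hypothetical remote embedding call, declared only to keep the sketch self-contained.
declare function embedViaExampleApi(texts: string[], model: string): Promise<number[][]>;

const exampleAdapter: MemoryEmbeddingProviderAdapter = {
  id: "example",
  defaultModel: "example-embed-1",
  transport: "remote",
  authProviderId: "example",
  create: async () => ({
    provider: {
      id: "example",
      model: "example-embed-1",
      // embedQuery/embedBatch mirror the provider surface exercised by the tests further down.
      embedQuery: async (text: string) => (await embedViaExampleApi([text], "example-embed-1"))[0],
      embedBatch: async (texts: string[]) => await embedViaExampleApi(texts, "example-embed-1"),
    },
    runtime: {
      id: "example",
      // cacheKeyData keys cached vectors; changing the model invalidates them.
      cacheKeyData: { provider: "example", model: "example-embed-1" },
    },
  }),
};
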
@@ -16,6 +16,7 @@ import {
  writeFileWithinRoot,
  type OpenClawConfig,
} from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import {
  buildSessionEntry,
  deriveQmdScopeChannel,
@@ -47,7 +48,6 @@ import {
  type ResolvedQmdConfig,
  type ResolvedQmdMcporterConfig,
} from "openclaw/plugin-sdk/memory-core-host-engine-storage";
import { resolveAgentContextLimits } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import {
  localeLowercasePreservingWhitespace,
  normalizeLowercaseStringOrEmpty,
@@ -1945,8 +1945,7 @@ export class QmdMemoryManager implements MemorySearchManager {
    from?: number,
    lines?: number,
  ): Promise<
    | { missing: true }
    | { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean }
    { missing: true } | { missing: false; selectedLines: string[]; moreSourceLinesRemain: boolean }
  > {
    const start = Math.max(1, from ?? 1);
    const count = Math.max(1, lines ?? Number.POSITIVE_INFINITY);

@@ -1,10 +1,11 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import {
  createRemoteEmbeddingProvider,
  normalizeEmbeddingModelWithPrefixes,
  resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
  type MemoryEmbeddingProvider,
  type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";

export type MistralEmbeddingClient = {
  baseUrl: string;
@@ -25,8 +26,8 @@ export function normalizeMistralModel(model: string): string {
}

export async function createMistralEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
  options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: MistralEmbeddingClient }> {
  const client = await resolveMistralEmbeddingClient(options);

  return {
@@ -40,7 +41,7 @@ export async function createMistralEmbeddingProvider(
}

export async function resolveMistralEmbeddingClient(
  options: EmbeddingProviderOptions,
  options: MemoryEmbeddingProviderCreateOptions,
): Promise<MistralEmbeddingClient> {
  return await resolveRemoteEmbeddingClient({
    provider: "mistral",
@@ -1,6 +1,7 @@
import { defineSingleProviderPluginEntry } from "openclaw/plugin-sdk/provider-entry";
import { applyMistralModelCompat } from "./api.js";
import { mistralMediaUnderstandingProvider } from "./media-understanding-provider.js";
import { mistralMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
import { applyMistralConfig, MISTRAL_DEFAULT_MODEL_REF } from "./onboard.js";
import { buildMistralProvider } from "./provider-catalog.js";
import { contributeMistralResolvedModelCompat } from "./provider-compat.js";
@@ -48,6 +49,7 @@ export default defineSingleProviderPluginEntry({
    buildReplayPolicy: () => buildMistralReplayPolicy(),
  },
  register(api) {
    api.registerMemoryEmbeddingProvider(mistralMemoryEmbeddingProviderAdapter);
    api.registerMediaUnderstandingProvider(mistralMediaUnderstandingProvider);
  },
});

35 extensions/mistral/memory-embedding-adapter.ts Normal file
@@ -0,0 +1,35 @@
import {
  isMissingEmbeddingApiKeyError,
  type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import {
  createMistralEmbeddingProvider,
  DEFAULT_MISTRAL_EMBEDDING_MODEL,
} from "./embedding-provider.js";

export const mistralMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
  id: "mistral",
  defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
  transport: "remote",
  authProviderId: "mistral",
  autoSelectPriority: 50,
  allowExplicitWhenConfiguredAuto: true,
  shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
  create: async (options) => {
    const { provider, client } = await createMistralEmbeddingProvider({
      ...options,
      provider: "mistral",
      fallback: "none",
    });
    return {
      provider,
      runtime: {
        id: "mistral",
        cacheKeyData: {
          provider: "mistral",
          model: client.model,
        },
      },
    };
  },
};
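
A sketch of how the host presumably consumes a registered adapter; the lookup mirrors listMemoryEmbeddingProviders from the registry diff above, while the option values and the config variable are illustrative assumptions.

const adapter = listMemoryEmbeddingProviders().find((a) => a.id === "mistral");
if (adapter) {
  // create() resolves credentials/config and hands back the provider plus runtime metadata.
  const { provider, runtime } = await adapter.create({
    config, // OpenClawConfig, assumed to be in scope
    provider: "mistral",
    model: adapter.defaultModel,
    fallback: "none",
  });
  const vector = await provider.embedQuery("hello world");
  // runtime.cacheKeyData feeds the embedding cache key, so a model change invalidates cached vectors.
  void vector;
  void runtime;
}
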
@@ -21,6 +21,7 @@
    }
  ],
  "contracts": {
    "memoryEmbeddingProviders": ["mistral"],
    "mediaUnderstandingProviders": ["mistral"]
  },
  "configSchema": {

@@ -8,6 +8,7 @@ export const ollamaMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapte
  id: "ollama",
  defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
  transport: "remote",
  authProviderId: "ollama",
  create: async (options) => {
    const { provider, client } = await createOllamaEmbeddingProvider({
      ...options,

@@ -17,8 +17,8 @@ import {
  type ProviderBatchOutputLine,
  uploadBatchJsonlFile,
  withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { OpenAiEmbeddingClient } from "./embedding-provider.js";

export type OpenAiBatchRequest = {
  custom_id: string;
@@ -1,11 +1,11 @@
import { parseStaticModelRef } from "../../agents/model-ref-shared.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../plugins/provider-model-defaults.js";
import {
  createRemoteEmbeddingProvider,
  resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
  type MemoryEmbeddingProvider,
  type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "./default-models.js";

export type OpenAiEmbeddingClient = {
  baseUrl: string;
@@ -28,13 +28,12 @@ export function normalizeOpenAiModel(model: string): string {
  if (!trimmed) {
    return DEFAULT_OPENAI_EMBEDDING_MODEL;
  }
  const parsed = parseStaticModelRef(trimmed, "openai");
  return parsed && parsed.provider === "openai" ? parsed.model : trimmed;
  return trimmed.startsWith("openai/") ? trimmed.slice("openai/".length) : trimmed;
}

export async function createOpenAiEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
  options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: OpenAiEmbeddingClient }> {
  const client = await resolveOpenAiEmbeddingClient(options);

  return {
@@ -49,7 +48,7 @@ export async function createOpenAiEmbeddingProvider(
}

export async function resolveOpenAiEmbeddingClient(
  options: EmbeddingProviderOptions,
  options: MemoryEmbeddingProviderCreateOptions,
): Promise<OpenAiEmbeddingClient> {
  return await resolveRemoteEmbeddingClient({
    provider: "openai",
@@ -6,6 +6,7 @@ import {
  openaiCodexMediaUnderstandingProvider,
  openaiMediaUnderstandingProvider,
} from "./media-understanding-provider.js";
import { openAiMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";
import { buildOpenAICodexProviderPlugin } from "./openai-codex-provider.js";
import { buildOpenAIProvider } from "./openai-provider.js";
import {
@@ -39,6 +40,7 @@ export default definePluginEntry({
    api.registerCliBackend(buildOpenAICodexCliBackend());
    api.registerProvider(buildProviderWithPromptContribution(buildOpenAIProvider()));
    api.registerProvider(buildProviderWithPromptContribution(buildOpenAICodexProviderPlugin()));
    api.registerMemoryEmbeddingProvider(openAiMemoryEmbeddingProviderAdapter);
    api.registerImageGenerationProvider(buildOpenAIImageGenerationProvider());
    api.registerRealtimeTranscriptionProvider(buildOpenAIRealtimeTranscriptionProvider());
    api.registerRealtimeVoiceProvider(buildOpenAIRealtimeVoiceProvider());

61 extensions/openai/memory-embedding-adapter.ts Normal file
@@ -0,0 +1,61 @@
import {
  isMissingEmbeddingApiKeyError,
  mapBatchEmbeddingsByIndex,
  sanitizeEmbeddingCacheHeaders,
  type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { OPENAI_BATCH_ENDPOINT, runOpenAiEmbeddingBatches } from "./embedding-batch.js";
import {
  createOpenAiEmbeddingProvider,
  DEFAULT_OPENAI_EMBEDDING_MODEL,
} from "./embedding-provider.js";

export const openAiMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
  id: "openai",
  defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
  transport: "remote",
  authProviderId: "openai",
  autoSelectPriority: 20,
  allowExplicitWhenConfiguredAuto: true,
  shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
  create: async (options) => {
    const { provider, client } = await createOpenAiEmbeddingProvider({
      ...options,
      provider: "openai",
      fallback: "none",
    });
    return {
      provider,
      runtime: {
        id: "openai",
        cacheKeyData: {
          provider: "openai",
          baseUrl: client.baseUrl,
          model: client.model,
          headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
        },
        batchEmbed: async (batch) => {
          const byCustomId = await runOpenAiEmbeddingBatches({
            openAi: client,
            agentId: batch.agentId,
            requests: batch.chunks.map((chunk, index) => ({
              custom_id: String(index),
              method: "POST",
              url: OPENAI_BATCH_ENDPOINT,
              body: {
                model: client.model,
                input: chunk.text,
              },
            })),
            wait: batch.wait,
            concurrency: batch.concurrency,
            pollIntervalMs: batch.pollIntervalMs,
            timeoutMs: batch.timeoutMs,
            debug: batch.debug,
          });
          return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
        },
      },
    };
  },
};
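
The batchEmbed path above leans on mapBatchEmbeddingsByIndex to turn the provider's Map<custom_id, vector> back into input order. Since each custom_id is String(index), the helper is presumably equivalent to this reconstruction (a sketch, not the SDK source; the missing-id behavior is an assumption):

function mapBatchEmbeddingsByIndexSketch(
  byCustomId: Map<string, number[]>,
  count: number,
): Array<number[] | null> {
  const ordered: Array<number[] | null> = [];
  for (let index = 0; index < count; index++) {
    // A missing id yields null here so a caller could fall back to per-chunk embedding.
    ordered.push(byCustomId.get(String(index)) ?? null);
  }
  return ordered;
}
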
@@ -39,6 +39,7 @@
  "speechProviders": ["openai"],
  "realtimeTranscriptionProviders": ["openai"],
  "realtimeVoiceProviders": ["openai"],
  "memoryEmbeddingProviders": ["openai"],
  "mediaUnderstandingProviders": ["openai", "openai-codex"],
  "imageGenerationProviders": ["openai"],
  "videoGenerationProviders": ["openai"]

@@ -19,8 +19,8 @@ import {
  type ProviderBatchOutputLine,
  uploadBatchJsonlFile,
  withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { VoyageEmbeddingClient } from "./embedding-provider.js";

/**
 * Voyage Batch API Input Line format.
@@ -1,8 +1,11 @@
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
import {
  fetchRemoteEmbeddingVectors,
  normalizeEmbeddingModelWithPrefixes,
  resolveRemoteEmbeddingBearerClient,
  type MemoryEmbeddingProvider,
  type MemoryEmbeddingProviderCreateOptions,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime";

export type VoyageEmbeddingClient = {
  baseUrl: string;
@@ -28,8 +31,8 @@ export function normalizeVoyageModel(model: string): string {
}

export async function createVoyageEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
  options: MemoryEmbeddingProviderCreateOptions,
): Promise<{ provider: MemoryEmbeddingProvider; client: VoyageEmbeddingClient }> {
  const client = await resolveVoyageEmbeddingClient(options);
  const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;

@@ -70,7 +73,7 @@ export async function createVoyageEmbeddingProvider(
}

export async function resolveVoyageEmbeddingClient(
  options: EmbeddingProviderOptions,
  options: MemoryEmbeddingProviderCreateOptions,
): Promise<VoyageEmbeddingClient> {
  const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
    provider: "voyage",
11 extensions/voyage/index.ts Normal file
@@ -0,0 +1,11 @@
import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
import { voyageMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";

export default definePluginEntry({
  id: "voyage",
  name: "Voyage Embeddings",
  description: "Bundled Voyage memory embedding provider plugin",
  register(api) {
    api.registerMemoryEmbeddingProvider(voyageMemoryEmbeddingProviderAdapter);
  },
});
56 extensions/voyage/memory-embedding-adapter.ts Normal file
@@ -0,0 +1,56 @@
import {
  isMissingEmbeddingApiKeyError,
  mapBatchEmbeddingsByIndex,
  sanitizeEmbeddingCacheHeaders,
  type MemoryEmbeddingProviderAdapter,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
import { runVoyageEmbeddingBatches } from "./embedding-batch.js";
import {
  createVoyageEmbeddingProvider,
  DEFAULT_VOYAGE_EMBEDDING_MODEL,
} from "./embedding-provider.js";

export const voyageMemoryEmbeddingProviderAdapter: MemoryEmbeddingProviderAdapter = {
  id: "voyage",
  defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
  transport: "remote",
  authProviderId: "voyage",
  autoSelectPriority: 40,
  allowExplicitWhenConfiguredAuto: true,
  shouldContinueAutoSelection: isMissingEmbeddingApiKeyError,
  create: async (options) => {
    const { provider, client } = await createVoyageEmbeddingProvider({
      ...options,
      provider: "voyage",
      fallback: "none",
    });
    return {
      provider,
      runtime: {
        id: "voyage",
        cacheKeyData: {
          provider: "voyage",
          baseUrl: client.baseUrl,
          model: client.model,
          headers: sanitizeEmbeddingCacheHeaders(client.headers, ["authorization"]),
        },
        batchEmbed: async (batch) => {
          const byCustomId = await runVoyageEmbeddingBatches({
            client,
            agentId: batch.agentId,
            requests: batch.chunks.map((chunk, index) => ({
              custom_id: String(index),
              body: { input: chunk.text },
            })),
            wait: batch.wait,
            concurrency: batch.concurrency,
            pollIntervalMs: batch.pollIntervalMs,
            timeoutMs: batch.timeoutMs,
            debug: batch.debug,
          });
          return mapBatchEmbeddingsByIndex(byCustomId, batch.chunks.length);
        },
      },
    };
  },
};
15 extensions/voyage/openclaw.plugin.json Normal file
@@ -0,0 +1,15 @@
{
  "id": "voyage",
  "enabledByDefault": true,
  "contracts": {
    "memoryEmbeddingProviders": ["voyage"]
  },
  "providerAuthEnvVars": {
    "voyage": ["VOYAGE_API_KEY"]
  },
  "configSchema": {
    "type": "object",
    "additionalProperties": false,
    "properties": {}
  }
}
15 extensions/voyage/package.json Normal file
@@ -0,0 +1,15 @@
{
  "name": "@openclaw/voyage-provider",
  "version": "2026.4.15-beta.1",
  "private": true,
  "description": "OpenClaw Voyage embedding provider plugin",
  "type": "module",
  "devDependencies": {
    "@openclaw/plugin-sdk": "workspace:*"
  },
  "openclaw": {
    "extensions": [
      "./index.ts"
    ]
  }
}
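
Taken together, the four voyage files above form the complete template for shipping a remote embedding provider as a plugin: the adapter, an entry that registers it, a manifest declaring the memoryEmbeddingProviders contract plus auth env vars, and a package.json exposing the entry. A condensed sketch for a hypothetical provider "acme" (all names illustrative, not part of this commit):

// extensions/acme/index.ts
import { definePluginEntry } from "openclaw/plugin-sdk/plugin-entry";
// Adapter shaped like voyageMemoryEmbeddingProviderAdapter above (hypothetical module).
import { acmeMemoryEmbeddingProviderAdapter } from "./memory-embedding-adapter.js";

export default definePluginEntry({
  id: "acme",
  name: "Acme Embeddings",
  description: "Example third-party memory embedding provider plugin",
  register(api) {
    // openclaw.plugin.json must also list "acme" under contracts.memoryEmbeddingProviders.
    api.registerMemoryEmbeddingProvider(acmeMemoryEmbeddingProviderAdapter);
  },
});
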
@@ -1,116 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";

vi.mock("./remote-http.js", () => ({
  withRemoteHttpResponse: vi.fn(),
}));

function magnitude(values: number[]) {
  return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
}

describe("runGeminiEmbeddingBatches", () => {
  let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches;
  let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
  let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;

  beforeAll(async () => {
    ({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js"));
    ({ withRemoteHttpResponse } = await import("./remote-http.js"));
    remoteHttpMock = vi.mocked(withRemoteHttpResponse);
  });

  beforeEach(() => {
    vi.clearAllMocks();
  });

  afterEach(() => {
    vi.resetAllMocks();
    vi.unstubAllGlobals();
  });

  const mockClient: GeminiEmbeddingClient = {
    baseUrl: "https://generativelanguage.googleapis.com/v1beta",
    headers: {},
    model: "gemini-embedding-2-preview",
    modelPath: "models/gemini-embedding-2-preview",
    apiKeys: ["test-key"],
    outputDimensionality: 1536,
  };

  it("includes outputDimensionality in batch upload requests", async () => {
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart");
      const body = params.init?.body;
      if (!(body instanceof Blob)) {
        throw new Error("expected multipart blob body");
      }
      const text = await body.text();
      expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"');
      expect(text).toContain('"outputDimensionality":1536');
      return await params.onResponse(
        new Response(JSON.stringify({ name: "files/file-123" }), {
          status: 200,
          headers: { "Content-Type": "application/json" },
        }),
      );
    });
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toMatch(/:asyncBatchEmbedContent$/u);
      return await params.onResponse(
        new Response(
          JSON.stringify({
            name: "batches/batch-1",
            state: "COMPLETED",
            outputConfig: { file: "files/output-1" },
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/json" },
          },
        ),
      );
    });
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toMatch(/\/files\/output-1:download$/u);
      return await params.onResponse(
        new Response(
          JSON.stringify({
            key: "req-1",
            response: { embedding: { values: [3, 4] } },
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/jsonl" },
          },
        ),
      );
    });

    const results = await runGeminiEmbeddingBatches({
      gemini: mockClient,
      agentId: "main",
      requests: [
        {
          custom_id: "req-1",
          request: {
            content: { parts: [{ text: "hello world" }] },
            taskType: "RETRIEVAL_DOCUMENT",
            outputDimensionality: 1536,
          },
        },
      ],
      wait: true,
      pollIntervalMs: 1,
      timeoutMs: 1000,
      concurrency: 1,
    });

    const embedding = results.get("req-1");
    expect(embedding).toBeDefined();
    expect(embedding?.[0]).toBeCloseTo(0.6, 5);
    expect(embedding?.[1]).toBeCloseTo(0.8, 5);
    expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5);
    expect(remoteHttpMock).toHaveBeenCalledTimes(3);
  });
});
@@ -1,259 +0,0 @@
import {
  applyEmbeddingBatchOutputLine,
  buildBatchHeaders,
  buildEmbeddingBatchGroupOptions,
  EMBEDDING_BATCH_ENDPOINT,
  extractBatchErrorMessage,
  formatUnavailableBatchError,
  normalizeBatchBaseUrl,
  postJsonWithRetry,
  resolveBatchCompletionFromStatus,
  resolveCompletedBatchResult,
  runEmbeddingBatchGroups,
  throwIfBatchTerminalFailure,
  type EmbeddingBatchExecutionParams,
  type EmbeddingBatchStatus,
  type BatchCompletionResult,
  type ProviderBatchOutputLine,
  uploadBatchJsonlFile,
  withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { OpenAiEmbeddingClient } from "./embeddings-openai.js";

export type OpenAiBatchRequest = {
  custom_id: string;
  method: "POST";
  url: "/v1/embeddings";
  body: {
    model: string;
    input: string;
  };
};

export type OpenAiBatchStatus = EmbeddingBatchStatus;
export type OpenAiBatchOutputLine = ProviderBatchOutputLine;

export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
const OPENAI_BATCH_MAX_REQUESTS = 50000;

async function submitOpenAiBatch(params: {
  openAi: OpenAiEmbeddingClient;
  requests: OpenAiBatchRequest[];
  agentId: string;
}): Promise<OpenAiBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.openAi);
  const inputFileId = await uploadBatchJsonlFile({
    client: params.openAi,
    requests: params.requests,
    errorPrefix: "openai batch file upload failed",
  });

  return await postJsonWithRetry<OpenAiBatchStatus>({
    url: `${baseUrl}/batches`,
    headers: buildBatchHeaders(params.openAi, { json: true }),
    ssrfPolicy: params.openAi.ssrfPolicy,
    body: {
      input_file_id: inputFileId,
      endpoint: OPENAI_BATCH_ENDPOINT,
      completion_window: OPENAI_BATCH_COMPLETION_WINDOW,
      metadata: {
        source: "openclaw-memory",
        agent: params.agentId,
      },
    },
    errorPrefix: "openai batch create failed",
  });
}

async function fetchOpenAiBatchStatus(params: {
  openAi: OpenAiEmbeddingClient;
  batchId: string;
}): Promise<OpenAiBatchStatus> {
  return await fetchOpenAiBatchResource({
    openAi: params.openAi,
    path: `/batches/${params.batchId}`,
    errorPrefix: "openai batch status",
    parse: async (res) => (await res.json()) as OpenAiBatchStatus,
  });
}

async function fetchOpenAiFileContent(params: {
  openAi: OpenAiEmbeddingClient;
  fileId: string;
}): Promise<string> {
  return await fetchOpenAiBatchResource({
    openAi: params.openAi,
    path: `/files/${params.fileId}/content`,
    errorPrefix: "openai batch file content",
    parse: async (res) => await res.text(),
  });
}

async function fetchOpenAiBatchResource<T>(params: {
  openAi: OpenAiEmbeddingClient;
  path: string;
  errorPrefix: string;
  parse: (res: Response) => Promise<T>;
}): Promise<T> {
  const baseUrl = normalizeBatchBaseUrl(params.openAi);
  return await withRemoteHttpResponse({
    url: `${baseUrl}${params.path}`,
    ssrfPolicy: params.openAi.ssrfPolicy,
    init: {
      headers: buildBatchHeaders(params.openAi, { json: true }),
    },
    onResponse: async (res) => {
      if (!res.ok) {
        const text = await res.text();
        throw new Error(`${params.errorPrefix} failed: ${res.status} ${text}`);
      }
      return await params.parse(res);
    },
  });
}

function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
  if (!text.trim()) {
    return [];
  }
  return text
    .split("\n")
    .map((line) => line.trim())
    .filter(Boolean)
    .map((line) => JSON.parse(line) as OpenAiBatchOutputLine);
}

async function readOpenAiBatchError(params: {
  openAi: OpenAiEmbeddingClient;
  errorFileId: string;
}): Promise<string | undefined> {
  try {
    const content = await fetchOpenAiFileContent({
      openAi: params.openAi,
      fileId: params.errorFileId,
    });
    const lines = parseOpenAiBatchOutput(content);
    return extractBatchErrorMessage(lines);
  } catch (err) {
    return formatUnavailableBatchError(err);
  }
}

async function waitForOpenAiBatch(params: {
  openAi: OpenAiEmbeddingClient;
  batchId: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: OpenAiBatchStatus;
}): Promise<BatchCompletionResult> {
  const start = Date.now();
  let current: OpenAiBatchStatus | undefined = params.initial;
  while (true) {
    const status =
      current ??
      (await fetchOpenAiBatchStatus({
        openAi: params.openAi,
        batchId: params.batchId,
      }));
    const state = status.status ?? "unknown";
    if (state === "completed") {
      return resolveBatchCompletionFromStatus({
        provider: "openai",
        batchId: params.batchId,
        status,
      });
    }
    await throwIfBatchTerminalFailure({
      provider: "openai",
      status: { ...status, id: params.batchId },
      readError: async (errorFileId) =>
        await readOpenAiBatchError({
          openAi: params.openAi,
          errorFileId,
        }),
    });
    if (!params.wait) {
      throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`);
    }
    if (Date.now() - start > params.timeoutMs) {
      throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
    await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
    current = undefined;
  }
}

export async function runOpenAiEmbeddingBatches(
  params: {
    openAi: OpenAiEmbeddingClient;
    agentId: string;
    requests: OpenAiBatchRequest[];
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: OPENAI_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: openai batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      const batchInfo = await submitOpenAiBatch({
        openAi: params.openAi,
        requests: group,
        agentId: params.agentId,
      });
      if (!batchInfo.id) {
        throw new Error("openai batch create failed: missing batch id");
      }
      const batchId = batchInfo.id;

      params.debug?.("memory embeddings: openai batch created", {
        batchId: batchInfo.id,
        status: batchInfo.status,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });

      const completed = await resolveCompletedBatchResult({
        provider: "openai",
        status: batchInfo,
        wait: params.wait,
        waitForBatch: async () =>
          await waitForOpenAiBatch({
            openAi: params.openAi,
            batchId,
            wait: params.wait,
            pollIntervalMs: params.pollIntervalMs,
            timeoutMs: params.timeoutMs,
            debug: params.debug,
            initial: batchInfo,
          }),
      });

      const content = await fetchOpenAiFileContent({
        openAi: params.openAi,
        fileId: completed.outputFileId,
      });
      const outputLines = parseOpenAiBatchOutput(content);
      const errors: string[] = [];
      const remaining = new Set(group.map((request) => request.custom_id));

      for (const line of outputLines) {
        applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
      }

      if (errors.length > 0) {
        throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(
          `openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
        );
      }
    },
  });
}
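
For reference, a hedged usage sketch of this runner, which is deleted here and re-homed in the openai plugin per this commit (the client normally comes from resolveOpenAiEmbeddingClient; the literal values are illustrative only):

const byCustomId = await runOpenAiEmbeddingBatches({
  openAi: client, // OpenAiEmbeddingClient resolved elsewhere
  agentId: "main",
  requests: [
    {
      custom_id: "0",
      method: "POST",
      url: OPENAI_BATCH_ENDPOINT,
      body: { model: "text-embedding-3-small", input: "hello world" },
    },
  ],
  wait: true, // poll until completion (the upstream completion window is 24h)
  concurrency: 1,
  pollIntervalMs: 5000,
  timeoutMs: 60 * 60 * 1000,
});
const vector = byCustomId.get("0"); // number[] on success
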
@@ -1,176 +0,0 @@
import { ReadableStream } from "node:stream/web";
import { setTimeout as nativeSleep } from "node:timers/promises";
import { describe, expect, it, vi } from "vitest";
import {
  runVoyageEmbeddingBatches,
  type VoyageBatchOutputLine,
  type VoyageBatchRequest,
} from "./batch-voyage.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";

const realNow = Date.now.bind(Date);

describe("runVoyageEmbeddingBatches", () => {
  const mockClient: VoyageEmbeddingClient = {
    baseUrl: "https://api.voyageai.com/v1",
    headers: { Authorization: "Bearer test-key" },
    model: "voyage-4-large",
  };

  const mockRequests: VoyageBatchRequest[] = [
    { custom_id: "req-1", body: { input: "text1" } },
    { custom_id: "req-2", body: { input: "text2" } },
  ];

  it("successfully submits batch, waits, and streams results", async () => {
    const outputLines: VoyageBatchOutputLine[] = [
      {
        custom_id: "req-1",
        response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
      },
      {
        custom_id: "req-2",
        response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
      },
    ];
    const withRemoteHttpResponse = vi.fn();
    const postJsonWithRetry = vi.fn();
    const uploadBatchJsonlFile = vi.fn();

    // Create a stream that emits the NDJSON lines
    const stream = new ReadableStream({
      start(controller) {
        const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
        controller.enqueue(new TextEncoder().encode(text));
        controller.close();
      },
    });
    uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
      expect(params.errorPrefix).toBe("voyage batch file upload failed");
      expect(params.requests).toEqual(mockRequests);
      return "file-123";
    });
    postJsonWithRetry.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/batches");
      expect(params.body).toMatchObject({
        input_file_id: "file-123",
        completion_window: "12h",
        request_params: {
          model: "voyage-4-large",
          input_type: "document",
        },
      });
      return {
        id: "batch-abc",
        status: "pending",
      };
    });
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/batches/batch-abc");
      return await params.onResponse(
        new Response(
          JSON.stringify({
            id: "batch-abc",
            status: "completed",
            output_file_id: "file-out-999",
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/json" },
          },
        ),
      );
    });
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/files/file-out-999/content");
      return await params.onResponse(
        new Response(stream as unknown as BodyInit, {
          status: 200,
          headers: { "Content-Type": "application/x-ndjson" },
        }),
      );
    });

    const results = await runVoyageEmbeddingBatches({
      client: mockClient,
      agentId: "agent-1",
      requests: mockRequests,
      wait: true,
      pollIntervalMs: 1, // fast poll
      timeoutMs: 1000,
      concurrency: 1,
      deps: {
        now: realNow,
        sleep: async (ms) => {
          await nativeSleep(ms);
        },
        postJsonWithRetry,
        uploadBatchJsonlFile,
        withRemoteHttpResponse,
      },
    });

    expect(results.size).toBe(2);
    expect(results.get("req-1")).toEqual([0.1, 0.1]);
    expect(results.get("req-2")).toEqual([0.2, 0.2]);
    expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
    expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
    expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
  });

  it("handles empty lines and stream chunks correctly", async () => {
    const withRemoteHttpResponse = vi.fn();
    const postJsonWithRetry = vi.fn();
    const uploadBatchJsonlFile = vi.fn();
    const stream = new ReadableStream({
      start(controller) {
        const line1 = JSON.stringify({
          custom_id: "req-1",
          response: { body: { data: [{ embedding: [1] }] } },
        });
        const line2 = JSON.stringify({
          custom_id: "req-2",
          response: { body: { data: [{ embedding: [2] }] } },
        });

        // Split across chunks
        controller.enqueue(new TextEncoder().encode(line1 + "\n"));
        controller.enqueue(new TextEncoder().encode("\n")); // empty line
        controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
        controller.close();
      },
    });
    uploadBatchJsonlFile.mockResolvedValueOnce("f1");
    postJsonWithRetry.mockResolvedValueOnce({
      id: "b1",
      status: "completed",
      output_file_id: "out1",
    });
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/files/out1/content");
      return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
    });

    const results = await runVoyageEmbeddingBatches({
      client: mockClient,
      agentId: "a1",
      requests: mockRequests,
      wait: true,
      pollIntervalMs: 1,
      timeoutMs: 1000,
      concurrency: 1,
      deps: {
        now: realNow,
        sleep: async (ms) => {
          await nativeSleep(ms);
        },
        postJsonWithRetry,
        uploadBatchJsonlFile,
        withRemoteHttpResponse,
      },
    });

    expect(results.get("req-1")).toEqual([1]);
    expect(results.get("req-2")).toEqual([2]);
  });
});
@@ -1,40 +1,14 @@
import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js";
import type { EmbeddingProvider } from "./embeddings.js";

const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192;
const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048;

const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record<string, number> = {
  "openai:text-embedding-3-small": 8192,
  "openai:text-embedding-3-large": 8192,
  "openai:text-embedding-ada-002": 8191,
  "gemini:text-embedding-004": 2048,
  "gemini:gemini-embedding-001": 2048,
  "gemini:gemini-embedding-2-preview": 8192,
  "voyage:voyage-3": 32000,
  "voyage:voyage-3-lite": 16000,
  "voyage:voyage-code-3": 32000,
};

export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number {
  if (typeof provider.maxInputTokens === "number") {
    return provider.maxInputTokens;
  }

  // Provider/model mapping is best-effort; different providers use different
  // limits and we prefer to be conservative when we don't know.
  const key = normalizeLowercaseStringOrEmpty(`${provider.id}:${provider.model}`);
  const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key];
  if (typeof known === "number") {
    return known;
  }

  // Provider-specific conservative fallbacks. This prevents us from accidentally
  // using the OpenAI default for providers with much smaller limits.
  if (normalizeLowercaseStringOrEmpty(provider.id) === "gemini") {
    return 2048;
  }
  if (normalizeLowercaseStringOrEmpty(provider.id) === "local") {
  if (provider.id === "local") {
    return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS;
  }
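
A quick worked example of the lookup order in resolveEmbeddingMaxInputTokens (arguments abridged to the fields the function reads; values taken from the table above):

// Exact provider:model key wins.
resolveEmbeddingMaxInputTokens({ id: "voyage", model: "voyage-3" } as EmbeddingProvider); // 32000
// Unknown gemini model: the conservative provider fallback applies.
resolveEmbeddingMaxInputTokens({ id: "gemini", model: "future-embed" } as EmbeddingProvider); // 2048
// Unknown local model: DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS.
resolveEmbeddingMaxInputTokens({ id: "local", model: "anything" } as EmbeddingProvider); // 2048
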
@@ -1,377 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";

const { defaultProviderMock, resolveCredentialsMock, sendMock } = vi.hoisted(() => ({
  defaultProviderMock: vi.fn(),
  resolveCredentialsMock: vi.fn(),
  sendMock: vi.fn(),
}));

vi.mock("@aws-sdk/client-bedrock-runtime", () => {
  class MockClient {
    region: string;
    constructor(config: { region: string }) {
      this.region = config.region;
    }
    send = sendMock;
  }
  class MockCommand {
    input: unknown;
    constructor(input: unknown) {
      this.input = input;
    }
  }
  return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand };
});

vi.mock("@aws-sdk/credential-provider-node", () => ({
  defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
}));

let createBedrockEmbeddingProvider: typeof import("./embeddings-bedrock.js").createBedrockEmbeddingProvider;
let resolveBedrockEmbeddingClient: typeof import("./embeddings-bedrock.js").resolveBedrockEmbeddingClient;
let normalizeBedrockEmbeddingModel: typeof import("./embeddings-bedrock.js").normalizeBedrockEmbeddingModel;
let hasAwsCredentials: typeof import("./embeddings-bedrock.js").hasAwsCredentials;

beforeAll(async () => {
  ({
    createBedrockEmbeddingProvider,
    resolveBedrockEmbeddingClient,
    normalizeBedrockEmbeddingModel,
    hasAwsCredentials,
  } = await import("./embeddings-bedrock.js"));
});

beforeEach(() => {
  defaultProviderMock.mockImplementation(() => resolveCredentialsMock);
});

const enc = (body: unknown) => ({ body: new TextEncoder().encode(JSON.stringify(body)) });
const reqBody = (i = 0): Record<string, unknown> =>
  JSON.parse(sendMock.mock.calls[i][0].input.body);

describe("bedrock embedding provider", () => {
  const originalEnv = process.env;
  afterEach(() => {
    process.env = originalEnv;
    vi.restoreAllMocks();
    defaultProviderMock.mockClear();
    resolveCredentialsMock.mockReset();
    sendMock.mockReset();
  });

  // --- Normalization ---

  it("normalizes model names with prefixes", () => {
    expect(normalizeBedrockEmbeddingModel("bedrock/amazon.titan-embed-text-v2:0")).toBe(
      "amazon.titan-embed-text-v2:0",
    );
    expect(normalizeBedrockEmbeddingModel("amazon-bedrock/cohere.embed-english-v3")).toBe(
      "cohere.embed-english-v3",
    );
    expect(normalizeBedrockEmbeddingModel("")).toBe("amazon.titan-embed-text-v2:0");
  });

  // --- Client resolution ---

  it("resolves region from env", () => {
    process.env = { ...originalEnv, AWS_REGION: "eu-west-1" };
    const c = resolveBedrockEmbeddingClient({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(c.region).toBe("eu-west-1");
    expect(c.dimensions).toBe(1024);
  });

  it("defaults to us-east-1", () => {
    process.env = { ...originalEnv };
    delete process.env.AWS_REGION;
    delete process.env.AWS_DEFAULT_REGION;
    expect(
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.titan-embed-text-v2:0",
        fallback: "none",
      }).region,
    ).toBe("us-east-1");
  });

  it("extracts region from baseUrl", () => {
    process.env = { ...originalEnv };
    delete process.env.AWS_REGION;
    const c = resolveBedrockEmbeddingClient({
      config: {
        models: {
          providers: {
            "amazon-bedrock": { baseUrl: "https://bedrock-runtime.ap-southeast-2.amazonaws.com" },
          },
        },
      } as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(c.region).toBe("ap-southeast-2");
  });

  it("validates dimensions", () => {
    expect(() =>
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.titan-embed-text-v2:0",
        fallback: "none",
        outputDimensionality: 768,
      }),
    ).toThrow("Invalid dimensions 768");
  });

  it("accepts valid dimensions", () => {
    expect(
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.titan-embed-text-v2:0",
        fallback: "none",
        outputDimensionality: 256,
      }).dimensions,
    ).toBe(256);
  });

  it("resolves throughput-suffixed variants", () => {
    expect(
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.titan-embed-text-v1:2:8k",
        fallback: "none",
      }).dimensions,
    ).toBe(1536);
  });

  // --- Credential detection ---

  it("detects access keys", async () => {
    await expect(
      hasAwsCredentials({
        AWS_ACCESS_KEY_ID: "A",
        AWS_SECRET_ACCESS_KEY: "s",
      } as NodeJS.ProcessEnv),
    ).resolves.toBe(true);
  });
  it("detects profile", async () => {
    await expect(hasAwsCredentials({ AWS_PROFILE: "default" } as NodeJS.ProcessEnv)).resolves.toBe(
      true,
    );
  });
  it("detects ECS task role", async () => {
    await expect(
      hasAwsCredentials({ AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: "/v2" } as NodeJS.ProcessEnv),
    ).resolves.toBe(true);
  });
  it("detects EKS IRSA", async () => {
    await expect(
      hasAwsCredentials({
        AWS_WEB_IDENTITY_TOKEN_FILE: "/var/run/secrets/token",
        AWS_ROLE_ARN: "arn:aws:iam::123:role/x",
      } as NodeJS.ProcessEnv),
    ).resolves.toBe(true);
  });
  it("detects credentials via the AWS SDK default provider chain", async () => {
    resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
    await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(true);
    expect(defaultProviderMock).toHaveBeenCalledWith({ timeout: 1000, maxRetries: 0 });
  });
  it("returns false with no creds", async () => {
    resolveCredentialsMock.mockRejectedValue(new Error("no aws credentials"));
    await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(false);
  });

  // --- Titan V2 ---

  it("embeds with Titan V2", async () => {
    sendMock.mockResolvedValue(enc({ embedding: [0.1, 0.2, 0.3] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(await provider.embedQuery("test")).toHaveLength(3);
    expect(reqBody()).toMatchObject({ inputText: "test", normalize: true, dimensions: 1024 });
  });

  it("returns empty for blank text", async () => {
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(await provider.embedQuery(" ")).toEqual([]);
    expect(sendMock).not.toHaveBeenCalled();
  });

  it("batches Titan V2 concurrently", async () => {
    sendMock
      .mockResolvedValueOnce(enc({ embedding: [0.1] }))
      .mockResolvedValueOnce(enc({ embedding: [0.2] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
    expect(sendMock).toHaveBeenCalledTimes(2);
  });

  // --- Titan V1 ---

  it("sends only inputText for Titan V1", async () => {
    sendMock.mockResolvedValue(enc({ embedding: [0.5] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v1",
      fallback: "none",
    });
    await provider.embedQuery("hi");
    expect(reqBody()).toEqual({ inputText: "hi" });
  });

  it("handles Titan G1 text variant", async () => {
    sendMock.mockResolvedValue(enc({ embedding: [0.1] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-g1-text-02",
      fallback: "none",
    });
    await provider.embedQuery("hi");
    expect(reqBody()).toEqual({ inputText: "hi" });
  });

  // --- Cohere V3 ---

  it("embeds Cohere V3 batch in single call", async () => {
    sendMock.mockResolvedValue(enc({ embeddings: [[0.1], [0.2]] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "cohere.embed-english-v3",
      fallback: "none",
    });
    expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
    expect(sendMock).toHaveBeenCalledTimes(1);
    expect(reqBody()).toMatchObject({ texts: ["a", "b"], input_type: "search_document" });
  });

  it("uses search_query for Cohere embedQuery", async () => {
    sendMock.mockResolvedValue(enc({ embeddings: [[0.1]] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "cohere.embed-english-v3",
      fallback: "none",
    });
    await provider.embedQuery("q");
    expect(reqBody().input_type).toBe("search_query");
  });

  // --- Cohere V4 ---

  it("embeds Cohere V4 with embedding_types + output_dimension", async () => {
    sendMock.mockResolvedValue(enc({ embeddings: { float: [[0.1], [0.2]] } }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "cohere.embed-v4:0",
      fallback: "none",
    });
    expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
    expect(reqBody()).toMatchObject({ embedding_types: ["float"], output_dimension: 1536 });
  });

  it("validates Cohere V4 dimensions", () => {
    expect(() =>
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "cohere.embed-v4:0",
        fallback: "none",
        outputDimensionality: 2048,
      }),
    ).toThrow("Invalid dimensions 2048");
  });

  // --- Nova ---

  it("embeds Nova with SINGLE_EMBEDDING format", async () => {
    sendMock.mockResolvedValue(
      enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1, 0.2] }] }),
    );
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.nova-2-multimodal-embeddings-v1:0",
      fallback: "none",
    });
    expect(await provider.embedQuery("hi")).toHaveLength(2);
    expect(reqBody().taskType).toBe("SINGLE_EMBEDDING");
  });

  it("validates Nova dimensions", () => {
    expect(() =>
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.nova-2-multimodal-embeddings-v1:0",
        fallback: "none",
        outputDimensionality: 512,
      }),
    ).toThrow("Invalid dimensions 512");
  });

  it("batches Nova concurrently", async () => {
    sendMock
      .mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1] }] }))
      .mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.2] }] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.nova-2-multimodal-embeddings-v1:0",
      fallback: "none",
    });
    expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
    expect(sendMock).toHaveBeenCalledTimes(2);
  });

  // --- TwelveLabs ---

  it("embeds TwelveLabs Marengo", async () => {
    sendMock.mockResolvedValue(enc({ data: [{ embedding: [0.1, 0.2] }] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "twelvelabs.marengo-embed-3-0-v1:0",
      fallback: "none",
    });
    expect(await provider.embedQuery("hi")).toHaveLength(2);
    expect(reqBody()).toEqual({ inputType: "text", text: { inputText: "hi" } });
  });

  it("embeds TwelveLabs object-style responses", async () => {
    sendMock.mockResolvedValue(enc({ data: { embedding: [0.3, 0.4] } }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "twelvelabs.marengo-embed-2-7-v1:0",
      fallback: "none",
    });
    expect(await provider.embedQuery("hi")).toEqual([0.6, 0.8]);
  });
});
@@ -1,398 +0,0 @@
import { normalizeLowercaseStringOrEmpty } from "../../../../src/shared/string-coerce.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";

// ---------------------------------------------------------------------------
// Types & constants
// ---------------------------------------------------------------------------

export type BedrockEmbeddingClient = {
region: string;
model: string;
dimensions?: number;
};

export const DEFAULT_BEDROCK_EMBEDDING_MODEL = "amazon.titan-embed-text-v2:0";

/** Request/response format family — each has a different API shape. */
type Family = "titan-v1" | "titan-v2" | "cohere-v3" | "cohere-v4" | "nova" | "twelvelabs";

interface ModelSpec {
maxTokens: number;
dims: number;
validDims?: number[];
family: Family;
}

// ---------------------------------------------------------------------------
// Model catalog
// ---------------------------------------------------------------------------

const MODELS: Record<string, ModelSpec> = {
"amazon.titan-embed-text-v2:0": {
maxTokens: 8192,
dims: 1024,
validDims: [256, 512, 1024],
family: "titan-v2",
},
"amazon.titan-embed-text-v1": { maxTokens: 8000, dims: 1536, family: "titan-v1" },
"amazon.titan-embed-g1-text-02": { maxTokens: 8000, dims: 1536, family: "titan-v1" },
"amazon.titan-embed-image-v1": { maxTokens: 128, dims: 1024, family: "titan-v1" },
"cohere.embed-english-v3": { maxTokens: 512, dims: 1024, family: "cohere-v3" },
"cohere.embed-multilingual-v3": { maxTokens: 512, dims: 1024, family: "cohere-v3" },
"cohere.embed-v4:0": {
maxTokens: 128000,
dims: 1536,
validDims: [256, 384, 512, 768, 1024, 1536],
family: "cohere-v4",
},
"amazon.nova-2-multimodal-embeddings-v1:0": {
maxTokens: 8192,
dims: 1024,
validDims: [256, 384, 1024, 3072],
family: "nova",
},
"twelvelabs.marengo-embed-2-7-v1:0": { maxTokens: 512, dims: 1024, family: "twelvelabs" },
"twelvelabs.marengo-embed-3-0-v1:0": { maxTokens: 512, dims: 512, family: "twelvelabs" },
};

/** Resolve spec, stripping throughput suffixes like `:2:8k` or `:0:512`. */
function resolveSpec(modelId: string): ModelSpec | undefined {
if (MODELS[modelId]) {
return MODELS[modelId];
}
const parts = modelId.split(":");
for (let i = parts.length - 1; i >= 1; i--) {
const spec = MODELS[parts.slice(0, i).join(":")];
if (spec) {
return spec;
}
}
return undefined;
}

/** Infer family from model ID prefix when not in catalog. */
function inferFamily(modelId: string): Family {
const id = normalizeLowercaseStringOrEmpty(modelId);
if (id.startsWith("amazon.titan-embed-text-v2")) {
return "titan-v2";
}
if (id.startsWith("amazon.titan-embed")) {
return "titan-v1";
}
if (id.startsWith("amazon.nova")) {
return "nova";
}
if (id.startsWith("cohere.embed-v4")) {
return "cohere-v4";
}
if (id.startsWith("cohere.embed")) {
return "cohere-v3";
}
if (id.startsWith("twelvelabs.")) {
return "twelvelabs";
}
return "titan-v1"; // safest default — simplest request format
}

// ---------------------------------------------------------------------------
// AWS SDK lazy loader
// ---------------------------------------------------------------------------

type SdkClient = import("@aws-sdk/client-bedrock-runtime").BedrockRuntimeClient;
type SdkCommand = import("@aws-sdk/client-bedrock-runtime").InvokeModelCommand;

interface AwsSdk {
BedrockRuntimeClient: new (config: { region: string }) => SdkClient;
InvokeModelCommand: new (input: {
modelId: string;
body: string;
contentType: string;
accept: string;
}) => SdkCommand;
}

interface AwsCredentialProviderSdk {
defaultProvider: (init?: { timeout?: number; maxRetries?: number }) => () => Promise<{
accessKeyId?: string;
}>;
}

let sdkCache: AwsSdk | null = null;
let credentialProviderSdkCache: AwsCredentialProviderSdk | null | undefined;

async function loadSdk(): Promise<AwsSdk> {
if (sdkCache) {
return sdkCache;
}
try {
sdkCache = (await import("@aws-sdk/client-bedrock-runtime")) as unknown as AwsSdk;
return sdkCache;
} catch {
throw new Error(
"No API key found for provider bedrock: @aws-sdk/client-bedrock-runtime is not installed. " +
"Install it with: npm install @aws-sdk/client-bedrock-runtime",
);
}
}

async function loadCredentialProviderSdk(): Promise<AwsCredentialProviderSdk | null> {
if (credentialProviderSdkCache !== undefined) {
return credentialProviderSdkCache;
}
try {
credentialProviderSdkCache =
(await import("@aws-sdk/credential-provider-node")) as unknown as AwsCredentialProviderSdk;
} catch {
credentialProviderSdkCache = null;
}
return credentialProviderSdkCache;
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

const MODEL_PREFIX_RE = /^(?:bedrock|amazon-bedrock|aws)\//;
const REGION_RE = /bedrock-runtime\.([a-z0-9-]+)\./;

export function normalizeBedrockEmbeddingModel(model: string): string {
const trimmed = model.trim();
return trimmed ? trimmed.replace(MODEL_PREFIX_RE, "") : DEFAULT_BEDROCK_EMBEDDING_MODEL;
}

function regionFromUrl(url: string | undefined): string | undefined {
return url?.trim() ? REGION_RE.exec(url)?.[1] : undefined;
}

// ---------------------------------------------------------------------------
// Request builders
// ---------------------------------------------------------------------------

function buildBody(family: Family, text: string, dims?: number): string {
switch (family) {
case "titan-v2": {
const b: Record<string, unknown> = { inputText: text };
if (dims != null) {
b.dimensions = dims;
b.normalize = true;
}
return JSON.stringify(b);
}
case "titan-v1":
return JSON.stringify({ inputText: text });
case "nova":
return JSON.stringify({
taskType: "SINGLE_EMBEDDING",
singleEmbeddingParams: {
embeddingPurpose: "GENERIC_INDEX",
embeddingDimension: dims ?? 1024,
text: { truncationMode: "END", value: text },
},
});
case "twelvelabs":
return JSON.stringify({ inputType: "text", text: { inputText: text } });
default:
return JSON.stringify({ inputText: text });
}
}

function buildCohereBody(
family: Family,
texts: string[],
inputType: "search_query" | "search_document",
dims?: number,
): string {
const body: Record<string, unknown> = { texts, input_type: inputType, truncate: "END" };
if (family === "cohere-v4") {
body.embedding_types = ["float"];
if (dims != null) {
body.output_dimension = dims;
}
}
return JSON.stringify(body);
}

// ---------------------------------------------------------------------------
// Response parsers
// ---------------------------------------------------------------------------

function parseSingle(family: Family, raw: string): number[] {
const data = JSON.parse(raw);
switch (family) {
case "nova":
return data.embeddings?.[0]?.embedding ?? [];
case "twelvelabs": {
if (Array.isArray(data.data)) {
return data.data[0]?.embedding ?? [];
}
if (Array.isArray(data.data?.embedding)) {
return data.data.embedding;
}
return data.embedding ?? [];
}
default:
return data.embedding ?? [];
}
}

function parseCohereBatch(family: Family, raw: string): number[][] {
const data = JSON.parse(raw);
const embeddings = data.embeddings;
if (!embeddings) {
return [];
}
if (family === "cohere-v4" && !Array.isArray(embeddings)) {
return embeddings.float ?? [];
}
return embeddings;
}

// ---------------------------------------------------------------------------
// Provider
// ---------------------------------------------------------------------------

export async function createBedrockEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: BedrockEmbeddingClient }> {
const client = resolveBedrockEmbeddingClient(options);
const { BedrockRuntimeClient, InvokeModelCommand } = await loadSdk();
const sdk = new BedrockRuntimeClient({ region: client.region });
const spec = resolveSpec(client.model);
const family = spec?.family ?? inferFamily(client.model);

debugEmbeddingsLog("memory embeddings: bedrock client", {
region: client.region,
model: client.model,
dimensions: client.dimensions,
family,
});

const invoke = async (body: string): Promise<string> => {
const res = await sdk.send(
new InvokeModelCommand({
modelId: client.model,
body,
contentType: "application/json",
accept: "application/json",
}),
);
return new TextDecoder().decode(res.body);
};

const isCohere = family === "cohere-v3" || family === "cohere-v4";

const embedSingle = async (text: string): Promise<number[]> => {
const raw = await invoke(buildBody(family, text, client.dimensions));
return sanitizeAndNormalizeEmbedding(parseSingle(family, raw));
};

const embedCohere = async (
texts: string[],
inputType: "search_query" | "search_document",
): Promise<number[][]> => {
const raw = await invoke(buildCohereBody(family, texts, inputType, client.dimensions));
return parseCohereBatch(family, raw).map((e) => sanitizeAndNormalizeEmbedding(e));
};

const embedQuery = async (text: string): Promise<number[]> => {
if (!text.trim()) {
return [];
}
if (isCohere) {
return (await embedCohere([text], "search_query"))[0] ?? [];
}
return embedSingle(text);
};

const embedBatch = async (texts: string[]): Promise<number[][]> => {
if (texts.length === 0) {
return [];
}
if (isCohere) {
return embedCohere(texts, "search_document");
}
return Promise.all(texts.map((t) => (t.trim() ? embedSingle(t) : Promise.resolve([]))));
};

return {
provider: {
id: "bedrock",
model: client.model,
maxInputTokens: spec?.maxTokens,
embedQuery,
embedBatch,
},
client,
};
}

// ---------------------------------------------------------------------------
// Client resolution
// ---------------------------------------------------------------------------

export function resolveBedrockEmbeddingClient(
options: EmbeddingProviderOptions,
): BedrockEmbeddingClient {
const model = normalizeBedrockEmbeddingModel(options.model);
const spec = resolveSpec(model);
const providerConfig = options.config.models?.providers?.["amazon-bedrock"];

const region =
regionFromUrl(options.remote?.baseUrl) ??
regionFromUrl(providerConfig?.baseUrl) ??
process.env.AWS_REGION ??
process.env.AWS_DEFAULT_REGION ??
"us-east-1";

let dimensions: number | undefined;
if (options.outputDimensionality != null) {
if (spec?.validDims && !spec.validDims.includes(options.outputDimensionality)) {
throw new Error(
`Invalid dimensions ${options.outputDimensionality} for ${model}. Valid values: ${spec.validDims.join(", ")}`,
);
}
dimensions = options.outputDimensionality;
} else {
dimensions = spec?.dims;
}

return { region, model, dimensions };
}

// ---------------------------------------------------------------------------
// Credential detection
// ---------------------------------------------------------------------------

const CREDENTIAL_ENV_VARS = [
"AWS_PROFILE",
"AWS_BEARER_TOKEN_BEDROCK",
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
"AWS_CONTAINER_CREDENTIALS_FULL_URI",
"AWS_EC2_METADATA_SERVICE_ENDPOINT",
"AWS_WEB_IDENTITY_TOKEN_FILE",
"AWS_ROLE_ARN",
] as const;

export async function hasAwsCredentials(env: NodeJS.ProcessEnv = process.env): Promise<boolean> {
if (env.AWS_ACCESS_KEY_ID?.trim() && env.AWS_SECRET_ACCESS_KEY?.trim()) {
return true;
}
if (CREDENTIAL_ENV_VARS.some((k) => env[k]?.trim())) {
return true;
}
const credentialProviderSdk = await loadCredentialProviderSdk();
if (!credentialProviderSdk) {
return false;
}
try {
const credentials = await credentialProviderSdk.defaultProvider({
timeout: 1000,
maxRetries: 0,
})();
return typeof credentials.accessKeyId === "string" && credentials.accessKeyId.trim().length > 0;
} catch {
return false;
}
}
@@ -1,121 +0,0 @@
import type { EmbeddingInput } from "./embedding-inputs.js";

export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";

export const GEMINI_EMBEDDING_2_MODELS = new Set([
"gemini-embedding-2-preview",
// Add the GA model name here once released.
]);

const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;

export type GeminiTaskType =
| "RETRIEVAL_QUERY"
| "RETRIEVAL_DOCUMENT"
| "SEMANTIC_SIMILARITY"
| "CLASSIFICATION"
| "CLUSTERING"
| "QUESTION_ANSWERING"
| "FACT_VERIFICATION";

export type GeminiTextPart = { text: string };
export type GeminiInlinePart = {
inlineData: { mimeType: string; data: string };
};
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
export type GeminiEmbeddingRequest = {
content: { parts: GeminiPart[] };
taskType: GeminiTaskType;
outputDimensionality?: number;
model?: string;
};
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;

/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
export function buildGeminiTextEmbeddingRequest(params: {
text: string;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiTextEmbeddingRequest {
return buildGeminiEmbeddingRequest({
input: { text: params.text },
taskType: params.taskType,
outputDimensionality: params.outputDimensionality,
modelPath: params.modelPath,
});
}

export function buildGeminiEmbeddingRequest(params: {
input: EmbeddingInput;
taskType: GeminiTaskType;
outputDimensionality?: number;
modelPath?: string;
}): GeminiEmbeddingRequest {
const request: GeminiEmbeddingRequest = {
content: {
parts: params.input.parts?.map((part) =>
part.type === "text"
? ({ text: part.text } satisfies GeminiTextPart)
: ({
inlineData: { mimeType: part.mimeType, data: part.data },
} satisfies GeminiInlinePart),
) ?? [{ text: params.input.text }],
},
taskType: params.taskType,
};
if (params.modelPath) {
request.model = params.modelPath;
}
if (params.outputDimensionality != null) {
request.outputDimensionality = params.outputDimensionality;
}
return request;
}

/**
* Returns true if the given model name is a gemini-embedding-2 variant that
* supports `outputDimensionality` and extended task types.
*/
export function isGeminiEmbedding2Model(model: string): boolean {
return GEMINI_EMBEDDING_2_MODELS.has(model);
}

/**
* Validate and return the `outputDimensionality` for gemini-embedding-2 models.
* Returns `undefined` for older models (they don't support the param).
*/
export function resolveGeminiOutputDimensionality(
model: string,
requested?: number,
): number | undefined {
if (!isGeminiEmbedding2Model(model)) {
return undefined;
}
if (requested == null) {
return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
}
const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
if (!valid.includes(requested)) {
throw new Error(
`Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
);
}
return requested;
}

export function normalizeGeminiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) {
return DEFAULT_GEMINI_EMBEDDING_MODEL;
}
const withoutPrefix = trimmed.replace(/^models\//, "");
if (withoutPrefix.startsWith("gemini/")) {
return withoutPrefix.slice("gemini/".length);
}
if (withoutPrefix.startsWith("google/")) {
return withoutPrefix.slice("google/".length);
}
return withoutPrefix;
}
@@ -1,52 +0,0 @@
import { describe, expect, it } from "vitest";
import {
buildGeminiEmbeddingRequest,
DEFAULT_GEMINI_EMBEDDING_MODEL,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
} from "./embeddings-gemini-request.js";

describe("package Gemini embedding request helpers", () => {
it("builds multimodal v2 requests and resolves model settings", () => {
expect(
buildGeminiEmbeddingRequest({
input: {
text: "Image file: diagram.png",
parts: [
{ type: "text", text: "Image file: diagram.png" },
{ type: "inline-data", mimeType: "image/png", data: "abc123" },
],
},
taskType: "RETRIEVAL_DOCUMENT",
modelPath: "models/gemini-embedding-2-preview",
outputDimensionality: 1536,
}),
).toEqual({
model: "models/gemini-embedding-2-preview",
content: {
parts: [
{ text: "Image file: diagram.png" },
{ inlineData: { mimeType: "image/png", data: "abc123" } },
],
},
taskType: "RETRIEVAL_DOCUMENT",
outputDimensionality: 1536,
});
expect(resolveGeminiOutputDimensionality("gemini-embedding-001")).toBeUndefined();
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview")).toBe(3072);
expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768)).toBe(768);
expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512)).toThrow(
/Invalid outputDimensionality 512/,
);
expect(normalizeGeminiModel("models/gemini-embedding-2-preview")).toBe(
"gemini-embedding-2-preview",
);
expect(normalizeGeminiModel("gemini/gemini-embedding-2-preview")).toBe(
"gemini-embedding-2-preview",
);
expect(normalizeGeminiModel("google/gemini-embedding-2-preview")).toBe(
"gemini-embedding-2-preview",
);
expect(normalizeGeminiModel("")).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL);
});
});
@@ -1,238 +0,0 @@
import {
collectProviderApiKeysForExecution,
executeWithApiKeyRotation,
} from "../../../../src/agents/api-key-rotation.js";
import { requireApiKey, resolveApiKeyForProvider } from "../../../../src/agents/model-auth.js";
import { parseGeminiAuth } from "../../../../src/infra/gemini-auth.js";
import {
DEFAULT_GOOGLE_API_BASE_URL,
normalizeGoogleApiBaseUrl,
} from "../../../../src/infra/google-api-base-url.js";
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
import type { EmbeddingInput } from "./embedding-inputs.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import {
buildGeminiEmbeddingRequest,
buildGeminiTextEmbeddingRequest,
isGeminiEmbedding2Model,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
} from "./embeddings-gemini-request.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";

export {
buildGeminiEmbeddingRequest,
buildGeminiTextEmbeddingRequest,
DEFAULT_GEMINI_EMBEDDING_MODEL,
GEMINI_EMBEDDING_2_MODELS,
isGeminiEmbedding2Model,
normalizeGeminiModel,
resolveGeminiOutputDimensionality,
type GeminiEmbeddingRequest,
type GeminiInlinePart,
type GeminiPart,
type GeminiTaskType,
type GeminiTextEmbeddingRequest,
type GeminiTextPart,
} from "./embeddings-gemini-request.js";

export type GeminiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
modelPath: string;
apiKeys: string[];
outputDimensionality?: number;
};

const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
"text-embedding-004": 2048,
};

function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
const trimmed = resolveMemorySecretInputString({
value: remoteApiKey,
path: "agents.*.memorySearch.remote.apiKey",
});
if (!trimmed) {
return undefined;
}
if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") {
return process.env[trimmed]?.trim();
}
return trimmed;
}

async function fetchGeminiEmbeddingPayload(params: {
client: GeminiEmbeddingClient;
endpoint: string;
body: unknown;
}): Promise<{
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
}> {
return await executeWithApiKeyRotation({
provider: "google",
apiKeys: params.client.apiKeys,
execute: async (apiKey) => {
const authHeaders = parseGeminiAuth(apiKey);
const headers = {
...authHeaders.headers,
...params.client.headers,
};
return await withRemoteHttpResponse({
url: params.endpoint,
ssrfPolicy: params.client.ssrfPolicy,
init: {
method: "POST",
headers,
body: JSON.stringify(params.body),
},
onResponse: async (res) => {
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
}
return (await res.json()) as {
embedding?: { values?: number[] };
embeddings?: Array<{ values?: number[] }>;
};
},
});
},
});
}

function normalizeGeminiBaseUrl(raw: string): string {
const trimmed = raw.replace(/\/+$/, "");
const openAiIndex = trimmed.indexOf("/openai");
if (openAiIndex > -1) {
return normalizeGoogleApiBaseUrl(trimmed.slice(0, openAiIndex));
}
return normalizeGoogleApiBaseUrl(trimmed);
}

function buildGeminiModelPath(model: string): string {
return model.startsWith("models/") ? model : `models/${model}`;
}

export async function createGeminiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
const client = await resolveGeminiEmbeddingClient(options);
const baseUrl = client.baseUrl.replace(/\/$/, "");
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
const isV2 = isGeminiEmbedding2Model(client.model);
const outputDimensionality = client.outputDimensionality;

const embedQuery = async (text: string): Promise<number[]> => {
if (!text.trim()) {
return [];
}
const payload = await fetchGeminiEmbeddingPayload({
client,
endpoint: embedUrl,
body: buildGeminiTextEmbeddingRequest({
text,
taskType: options.taskType ?? "RETRIEVAL_QUERY",
outputDimensionality: isV2 ? outputDimensionality : undefined,
}),
});
return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
};

const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
if (inputs.length === 0) {
return [];
}
const payload = await fetchGeminiEmbeddingPayload({
client,
endpoint: batchUrl,
body: {
requests: inputs.map((input) =>
buildGeminiEmbeddingRequest({
input,
modelPath: client.modelPath,
taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
outputDimensionality: isV2 ? outputDimensionality : undefined,
}),
),
},
});
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
};

const embedBatch = async (texts: string[]): Promise<number[][]> => {
return await embedBatchInputs(
texts.map((text) => ({
text,
})),
);
};

return {
provider: {
id: "gemini",
model: client.model,
maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
embedQuery,
embedBatch,
embedBatchInputs,
},
client,
};
}

export async function resolveGeminiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
const remote = options.remote;
const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
const remoteBaseUrl = remote?.baseUrl?.trim();

const apiKey = remoteApiKey
? remoteApiKey
: requireApiKey(
await resolveApiKeyForProvider({
provider: "google",
cfg: options.config,
agentDir: options.agentDir,
}),
"google",
);

const providerConfig = options.config.models?.providers?.google;
const rawBaseUrl =
remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GOOGLE_API_BASE_URL;
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
const headers: Record<string, string> = {
...headerOverrides,
};
const apiKeys = collectProviderApiKeysForExecution({
provider: "google",
primaryApiKey: apiKey,
});
const model = normalizeGeminiModel(options.model);
const modelPath = buildGeminiModelPath(model);
const outputDimensionality = resolveGeminiOutputDimensionality(
model,
options.outputDimensionality,
);
debugEmbeddingsLog("memory embeddings: gemini client", {
rawBaseUrl,
baseUrl,
model,
modelPath,
outputDimensionality,
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
});
return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
}
@@ -1 +0,0 @@
export * from "../../../../src/memory-host-sdk/host/embeddings-lmstudio.js";
@@ -1,19 +0,0 @@
import { describe, expect, it } from "vitest";
import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js";

describe("normalizeMistralModel", () => {
it("returns the default model for empty values", () => {
expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
});

it("strips the mistral/ prefix", () => {
expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed");
expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed");
});

it("keeps explicit non-prefixed models", () => {
expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed");
expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2");
});
});
@@ -1,51 +0,0 @@
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import {
createRemoteEmbeddingProvider,
resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";

export type MistralEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
};

export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed";
const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1";

export function normalizeMistralModel(model: string): string {
return normalizeEmbeddingModelWithPrefixes({
model,
defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL,
prefixes: ["mistral/"],
});
}

export async function createMistralEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> {
const client = await resolveMistralEmbeddingClient(options);

return {
provider: createRemoteEmbeddingProvider({
id: "mistral",
client,
errorPrefix: "mistral embeddings failed",
}),
client,
};
}

export async function resolveMistralEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<MistralEmbeddingClient> {
return await resolveRemoteEmbeddingClient({
provider: "mistral",
options,
defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL,
normalizeModel: normalizeMistralModel,
});
}
@@ -1,43 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";

const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({
createOllamaEmbeddingProviderMock: vi.fn(async (options: unknown) => ({
provider: { source: "mock-provider", options },
client: { source: "mock-client" },
})),
}));

vi.mock("../../../../src/plugin-sdk/ollama-runtime.js", () => ({
DEFAULT_OLLAMA_EMBEDDING_MODEL: "nomic-embed-text",
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
}));

describe("memory-host-sdk Ollama embedding facade", () => {
beforeEach(() => {
createOllamaEmbeddingProviderMock.mockClear();
});

it("re-exports the default Ollama embedding model", async () => {
const mod = await import("./embeddings-ollama.js");
expect(mod.DEFAULT_OLLAMA_EMBEDDING_MODEL).toBe("nomic-embed-text");
});

it("delegates provider creation to the plugin-sdk runtime facade", async () => {
const mod = await import("./embeddings-ollama.js");
const options = {
provider: "ollama",
model: "nomic-embed-text",
fallback: "none",
config: {},
};

const result = await mod.createOllamaEmbeddingProvider(options as never);

expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledTimes(1);
expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledWith(options);
expect(result).toEqual({
provider: { source: "mock-provider", options },
client: { source: "mock-client" },
});
});
});
@@ -1,5 +0,0 @@
export type { OllamaEmbeddingClient } from "../../../../src/plugin-sdk/ollama-runtime.js";
export {
createOllamaEmbeddingProvider,
DEFAULT_OLLAMA_EMBEDDING_MODEL,
} from "../../../../src/plugin-sdk/ollama-runtime.js";
@@ -1,58 +0,0 @@
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../../../src/plugins/provider-model-defaults.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import {
createRemoteEmbeddingProvider,
resolveRemoteEmbeddingClient,
} from "./embeddings-remote-provider.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";

export type OpenAiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
};

const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
export const DEFAULT_OPENAI_EMBEDDING_MODEL = OPENAI_DEFAULT_EMBEDDING_MODEL;
const OPENAI_MAX_INPUT_TOKENS: Record<string, number> = {
"text-embedding-3-small": 8192,
"text-embedding-3-large": 8192,
"text-embedding-ada-002": 8191,
};

export function normalizeOpenAiModel(model: string): string {
return normalizeEmbeddingModelWithPrefixes({
model,
defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL,
prefixes: ["openai/"],
});
}

export async function createOpenAiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
const client = await resolveOpenAiEmbeddingClient(options);

return {
provider: createRemoteEmbeddingProvider({
id: "openai",
client,
errorPrefix: "openai embeddings failed",
maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model],
}),
client,
};
}

export async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> {
return await resolveRemoteEmbeddingClient({
provider: "openai",
options,
defaultBaseUrl: DEFAULT_OPENAI_BASE_URL,
normalizeModel: normalizeOpenAiModel,
});
}
@@ -4,7 +4,7 @@ import type { EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";

export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral";
export type RemoteEmbeddingProviderId = string;

export async function resolveRemoteEmbeddingBearerClient(params: {
provider: RemoteEmbeddingProviderId;
@@ -1,188 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../../../src/agents/model-auth.js";
import { type FetchMock, withFetchPreconnect } from "../../../../src/test-utils/fetch-mock.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";

vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));

const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
resolveApiKeyForProviderMock: vi.fn(),
}));

vi.mock("../../../../src/agents/model-auth.js", () => {
return {
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
if (auth.apiKey) {
return auth.apiKey;
}
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
},
};
});

const createFetchMock = () => {
const fetchMock = vi.fn<FetchMock>(
async (_input: RequestInfo | URL, _init?: RequestInit) =>
new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), {
status: 200,
headers: { "Content-Type": "application/json" },
}),
);
return withFetchPreconnect(fetchMock);
};

function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}

let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;

beforeAll(async () => {
({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
await import("./embeddings-voyage.js"));
});

beforeEach(() => {
vi.useRealTimers();
vi.doUnmock("undici");
});

function mockVoyageApiKey() {
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
apiKey: "voyage-key-123",
mode: "api-key",
source: "test",
});
}

async function createDefaultVoyageProvider(
model: string,
fetchMock: ReturnType<typeof createFetchMock>,
) {
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockVoyageApiKey();
return createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
model,
fallback: "none",
});
}

describe("voyage embedding provider", () => {
afterEach(() => {
vi.doUnmock("undici");
vi.resetAllMocks();
vi.unstubAllGlobals();
});

it("configures client with correct defaults and headers", async () => {
const fetchMock = createFetchMock();
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);

await result.provider.embedQuery("test query");

expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith(
expect.objectContaining({ provider: "voyage" }),
);

const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://api.voyageai.com/v1/embeddings");

const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer voyage-key-123");
expect(headers["Content-Type"]).toBe("application/json");

const body = JSON.parse(init?.body as string);
expect(body).toEqual({
model: "voyage-4-large",
input: ["test query"],
input_type: "query",
});
});

it("respects remote overrides for baseUrl and apiKey", async () => {
const fetchMock = createFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();

const result = await createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
model: "voyage-4-lite",
fallback: "none",
remote: {
baseUrl: "https://example.com",
apiKey: "remote-override-key",
headers: { "X-Custom": "123" },
},
});

await result.provider.embedQuery("test");

const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://example.com/embeddings");

const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-override-key");
expect(headers["X-Custom"]).toBe("123");
});

it("passes input_type=document for embedBatch", async () => {
const fetchMock = withFetchPreconnect(
vi.fn<FetchMock>(
async (_input: RequestInfo | URL, _init?: RequestInit) =>
new Response(
JSON.stringify({
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
}),
{ status: 200, headers: { "Content-Type": "application/json" } },
),
),
);
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);

await result.provider.embedBatch(["doc1", "doc2"]);

const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [, init] = call as [RequestInfo | URL, RequestInit | undefined];
const body = JSON.parse(init?.body as string);
expect(body).toEqual({
model: "voyage-4-large",
input: ["doc1", "doc2"],
input_type: "document",
});
});

it("normalizes model names", async () => {
expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2");
expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large");
expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite");
expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default
});
});
@@ -1,82 +0,0 @@
import type { SsrFPolicy } from "../../../../src/infra/net/ssrf.js";
import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js";
import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js";
import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";

export type VoyageEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
ssrfPolicy?: SsrFPolicy;
model: string;
};

export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large";
const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1";
const VOYAGE_MAX_INPUT_TOKENS: Record<string, number> = {
"voyage-3": 32000,
"voyage-3-lite": 16000,
"voyage-code-3": 32000,
};

export function normalizeVoyageModel(model: string): string {
return normalizeEmbeddingModelWithPrefixes({
model,
defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL,
prefixes: ["voyage/"],
});
}

export async function createVoyageEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> {
const client = await resolveVoyageEmbeddingClient(options);
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;

const embed = async (input: string[], input_type?: "query" | "document"): Promise<number[][]> => {
if (input.length === 0) {
return [];
}
const body: { model: string; input: string[]; input_type?: "query" | "document" } = {
model: client.model,
input,
};
if (input_type) {
body.input_type = input_type;
}

return await fetchRemoteEmbeddingVectors({
url,
headers: client.headers,
ssrfPolicy: client.ssrfPolicy,
body,
errorPrefix: "voyage embeddings failed",
});
};

return {
provider: {
id: "voyage",
model: client.model,
maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model],
embedQuery: async (text) => {
const [vec] = await embed([text], "query");
return vec ?? [];
},
embedBatch: async (texts) => embed(texts, "document"),
},
client,
};
}

export async function resolveVoyageEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<VoyageEmbeddingClient> {
const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({
provider: "voyage",
options,
defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL,
});
const model = normalizeVoyageModel(options.model);
return { baseUrl, headers, ssrfPolicy, model };
}
@@ -1,199 +1,8 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../../../src/agents/model-auth.js";
import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
import * as nodeLlamaModule from "./node-llama.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
import { describe, expect, it } from "vitest";
import { DEFAULT_LOCAL_MODEL } from "./embeddings.js";

const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
resolveApiKeyForProviderMock: vi.fn(),
}));

vi.mock("../../../../src/agents/model-auth.js", () => {
return {
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
if (auth.apiKey) {
return auth.apiKey;
}
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
},
};
});

vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));

const createEmbeddingDataFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
}));

const createGeminiFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ embedding: { values: [1, 2, 3] } }),
}));

beforeEach(() => {
vi.spyOn(authModule, "resolveApiKeyForProvider");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
});

afterEach(() => {
vi.resetAllMocks();
vi.unstubAllGlobals();
});

function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}

function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
const [url, init] = fetchMock.mock.calls[0] ?? [];
return { url, init: init as RequestInit | undefined };
}

function requireProvider(result: Awaited<ReturnType<typeof createEmbeddingProvider>>) {
if (!result.provider) {
throw new Error("Expected embedding provider");
}
return result.provider;
}

function mockResolvedProviderKey(apiKey = "provider-key") {
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
apiKey,
mode: "api-key",
source: "test",
});
}

describe("package embedding provider smoke", () => {
it("uses remote OpenAI baseUrl/apiKey and merges headers", async () => {
const fetchMock = createEmbeddingDataFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey("provider-key");

const result = await createEmbeddingProvider({
config: {
models: {
providers: {
openai: {
baseUrl: "https://api.openai.com/v1",
headers: { "X-Provider": "p", "X-Shared": "provider" },
},
},
},
} as never,
provider: "openai",
remote: {
baseUrl: "https://example.com/v1",
apiKey: " remote-key ",
headers: { "X-Shared": "remote", "X-Remote": "r" },
},
model: "text-embedding-3-small",
fallback: "openai",
});

await requireProvider(result).embedQuery("hello");

expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
const { url, init } = readFirstFetchRequest(fetchMock);
expect(url).toBe("https://example.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-key");
expect(headers["X-Provider"]).toBe("p");
expect(headers["X-Shared"]).toBe("remote");
expect(headers["X-Remote"]).toBe("r");
});

it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => {
const fetchMock = createGeminiFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
vi.stubEnv("GEMINI_API_KEY", "env-gemini-key");

const result = await createEmbeddingProvider({
config: {} as never,
provider: "gemini",
remote: {
apiKey: "GEMINI_API_KEY", // pragma: allowlist secret
},
model: "text-embedding-004",
fallback: "openai",
});

await requireProvider(result).embedQuery("hello");

const { init } = readFirstFetchRequest(fetchMock);
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers["x-goog-api-key"]).toBe("env-gemini-key");
});

it("normalizes local embeddings and resolves the default local model", async () => {
const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf");
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
getLlama: async () => ({
loadModel: vi.fn().mockResolvedValue({
createEmbeddingContext: vi.fn().mockResolvedValue({
getEmbeddingFor: vi.fn().mockResolvedValue({
vector: new Float32Array([2.35, 3.45, 0.63, 4.3]),
}),
}),
}),
}),
resolveModelFile: resolveModelFileMock,
LlamaLogLevel: { error: 0 },
} as never);

const result = await createEmbeddingProvider({
config: {} as never,
provider: "local",
model: "",
fallback: "none",
});

const embedding = await requireProvider(result).embedQuery("test query");
const magnitude = Math.sqrt(embedding.reduce((sum, value) => sum + value * value, 0));
expect(magnitude).toBeCloseTo(1, 5);
expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
});

it("returns null provider when explicit primary and fallback auth paths fail", async () => {
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
new Error("No API key found for provider"),
);

const result = await createEmbeddingProvider({
config: {} as never,
provider: "openai",
model: "text-embedding-3-small",
fallback: "gemini",
});

expect(result.provider).toBeNull();
expect(result.requestedProvider).toBe("openai");
expect(result.fallbackFrom).toBe("openai");
expect(result.providerUnavailableReason).toContain("Fallback to gemini failed");
describe("package embeddings barrel", () => {
it("re-exports the source local embedding contract", () => {
expect(DEFAULT_LOCAL_MODEL).toContain("embeddinggemma");
});
});
@@ -1,373 +1 @@
|
||||
import fsSync from "node:fs";
|
||||
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
|
||||
import type { OpenClawConfig } from "../../../../src/config/config.js";
|
||||
import type { SecretInput } from "../../../../src/config/types.secrets.js";
|
||||
import { formatErrorMessage } from "../../../../src/infra/errors.js";
|
||||
import { resolveUserPath } from "../../../../src/utils.js";
|
||||
import type { EmbeddingInput } from "./embedding-inputs.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import {
|
||||
createBedrockEmbeddingProvider,
|
||||
hasAwsCredentials,
|
||||
type BedrockEmbeddingClient,
|
||||
} from "./embeddings-bedrock.js";
|
||||
import {
|
||||
createGeminiEmbeddingProvider,
|
||||
type GeminiEmbeddingClient,
|
||||
type GeminiTaskType,
|
||||
} from "./embeddings-gemini.js";
|
||||
import {
|
||||
createLmstudioEmbeddingProvider,
|
||||
type LmstudioEmbeddingClient,
|
||||
} from "./embeddings-lmstudio.js";
|
||||
import {
|
||||
createMistralEmbeddingProvider,
|
||||
type MistralEmbeddingClient,
|
||||
} from "./embeddings-mistral.js";
|
||||
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
|
||||
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
import { importNodeLlamaCpp } from "./node-llama.js";
|
||||
|
||||
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
|
||||
export type { LmstudioEmbeddingClient } from "./embeddings-lmstudio.js";
|
||||
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
|
||||
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
|
||||
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
|
||||
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
|
||||
export type { BedrockEmbeddingClient } from "./embeddings-bedrock.js";
|
||||
|
||||
export type EmbeddingProvider = {
|
||||
id: string;
|
||||
model: string;
|
||||
maxInputTokens?: number;
|
||||
embedQuery: (text: string) => Promise<number[]>;
|
||||
embedBatch: (texts: string[]) => Promise<number[][]>;
|
||||
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
|
||||
};
|
||||
|
||||
export type EmbeddingProviderId =
|
||||
| "openai"
|
||||
| "local"
|
||||
| "gemini"
|
||||
| "voyage"
|
||||
| "mistral"
|
||||
| "bedrock"
|
||||
| "lmstudio"
|
||||
| "ollama";
|
||||
export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
|
||||
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
|
||||
|
||||
// Remote providers considered for auto-selection when provider === "auto".
|
||||
// LM Studio and Ollama are intentionally excluded here so that "auto" mode does not
|
||||
// implicitly assume either instance is available.
|
||||
// Bedrock is handled separately when AWS credentials are detected.
|
||||
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;
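
// Illustrative resulting "auto" order (see createEmbeddingProvider below): local
// first when options.local.modelPath points at an existing on-disk file, then
// openai -> gemini -> voyage -> mistral, then bedrock when AWS credentials are
// detected; if every candidate fails on missing keys, the provider is null
// (FTS-only mode).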

export type EmbeddingProviderResult = {
  provider: EmbeddingProvider | null;
  requestedProvider: EmbeddingProviderRequest;
  fallbackFrom?: EmbeddingProviderId;
  fallbackReason?: string;
  providerUnavailableReason?: string;
  openAi?: OpenAiEmbeddingClient;
  gemini?: GeminiEmbeddingClient;
  voyage?: VoyageEmbeddingClient;
  mistral?: MistralEmbeddingClient;
  bedrock?: BedrockEmbeddingClient;
  lmstudio?: LmstudioEmbeddingClient;
  ollama?: OllamaEmbeddingClient;
};

export type EmbeddingProviderOptions = {
  config: OpenClawConfig;
  agentDir?: string;
  provider: EmbeddingProviderRequest;
  remote?: {
    baseUrl?: string;
    apiKey?: SecretInput;
    headers?: Record<string, string>;
  };
  model: string;
  fallback: EmbeddingProviderFallback;
  local?: {
    modelPath?: string;
    modelCacheDir?: string;
  };
  /** Provider-specific output vector dimensions for supported embedding families. */
  outputDimensionality?: number;
  /** Gemini: override the default task type sent with embedding requests. */
  taskType?: GeminiTaskType;
};

export const DEFAULT_LOCAL_MODEL =
  "hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";

function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
  const modelPath = options.local?.modelPath?.trim();
  if (!modelPath) {
    return false;
  }
  if (/^(hf:|https?:)/i.test(modelPath)) {
    return false;
  }
  const resolved = resolveUserPath(modelPath);
  try {
    return fsSync.statSync(resolved).isFile();
  } catch {
    return false;
  }
}

function isMissingApiKeyError(err: unknown): boolean {
  const message = formatErrorMessage(err);
  return message.includes("No API key found for provider");
}

export async function createLocalEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<EmbeddingProvider> {
  const modelPath = options.local?.modelPath?.trim() || DEFAULT_LOCAL_MODEL;
  const modelCacheDir = options.local?.modelCacheDir?.trim();

  // Lazy-load node-llama-cpp to keep startup light unless local is enabled.
  const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp();

  let llama: Llama | null = null;
  let embeddingModel: LlamaModel | null = null;
  let embeddingContext: LlamaEmbeddingContext | null = null;
  let initPromise: Promise<LlamaEmbeddingContext> | null = null;

  const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
    if (embeddingContext) {
      return embeddingContext;
    }
    if (initPromise) {
      return initPromise;
    }
    initPromise = (async () => {
      try {
        if (!llama) {
          llama = await getLlama({ logLevel: LlamaLogLevel.error });
        }
        if (!embeddingModel) {
          const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
          embeddingModel = await llama.loadModel({ modelPath: resolved });
        }
        if (!embeddingContext) {
          embeddingContext = await embeddingModel.createEmbeddingContext();
        }
        return embeddingContext;
      } catch (err) {
        initPromise = null;
        throw err;
      }
    })();
    return initPromise;
  };

  return {
    id: "local",
    model: modelPath,
    embedQuery: async (text) => {
      const ctx = await ensureContext();
      const embedding = await ctx.getEmbeddingFor(text);
      return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
    },
    embedBatch: async (texts) => {
      const ctx = await ensureContext();
      const embeddings = await Promise.all(
        texts.map(async (text) => {
          const embedding = await ctx.getEmbeddingFor(text);
          return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
        }),
      );
      return embeddings;
    },
  };
}

export async function createEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
  const requestedProvider = options.provider;
  const fallback = options.fallback;

  const createProvider = async (id: EmbeddingProviderId) => {
    if (id === "local") {
      const provider = await createLocalEmbeddingProvider(options);
      return { provider };
    }
    if (id === "lmstudio") {
      const { provider, client } = await createLmstudioEmbeddingProvider(options);
      return { provider, lmstudio: client };
    }
    if (id === "ollama") {
      const { provider, client } = await createOllamaEmbeddingProvider(options);
      return { provider, ollama: client };
    }
    if (id === "gemini") {
      const { provider, client } = await createGeminiEmbeddingProvider(options);
      return { provider, gemini: client };
    }
    if (id === "voyage") {
      const { provider, client } = await createVoyageEmbeddingProvider(options);
      return { provider, voyage: client };
    }
    if (id === "mistral") {
      const { provider, client } = await createMistralEmbeddingProvider(options);
      return { provider, mistral: client };
    }
    if (id === "bedrock") {
      const { provider, client } = await createBedrockEmbeddingProvider(options);
      return { provider, bedrock: client };
    }
    const { provider, client } = await createOpenAiEmbeddingProvider(options);
    return { provider, openAi: client };
  };

  const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
    provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);

  if (requestedProvider === "auto") {
    const missingKeyErrors: string[] = [];
    let localError: string | null = null;

    if (canAutoSelectLocal(options)) {
      try {
        const local = await createProvider("local");
        return { ...local, requestedProvider };
      } catch (err) {
        localError = formatLocalSetupError(err);
      }
    }

    for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
      try {
        const result = await createProvider(provider);
        return { ...result, requestedProvider };
      } catch (err) {
        const message = formatPrimaryError(err, provider);
        if (isMissingApiKeyError(err)) {
          missingKeyErrors.push(message);
          continue;
        }
        // Non-auth errors (e.g., network) are still fatal
        const wrapped = new Error(message) as Error & { cause?: unknown };
        wrapped.cause = err;
        throw wrapped;
      }
    }

    // Try bedrock if AWS credentials are available
    if (await hasAwsCredentials()) {
      try {
        const result = await createProvider("bedrock");
        return { ...result, requestedProvider };
      } catch (err) {
        const message = formatPrimaryError(err, "bedrock");
        if (isMissingApiKeyError(err)) {
          missingKeyErrors.push(message);
        } else {
          const wrapped = new Error(message) as Error & { cause?: unknown };
          wrapped.cause = err;
          throw wrapped;
        }
      }
    }

    // All providers failed due to missing API keys - return null provider for FTS-only mode
    const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
    const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
    return {
      provider: null,
      requestedProvider,
      providerUnavailableReason: reason,
    };
  }

  try {
    const primary = await createProvider(requestedProvider);
    return { ...primary, requestedProvider };
  } catch (primaryErr) {
    const reason = formatPrimaryError(primaryErr, requestedProvider);
    if (fallback && fallback !== "none" && fallback !== requestedProvider) {
      try {
        const fallbackResult = await createProvider(fallback);
        return {
          ...fallbackResult,
          requestedProvider,
          fallbackFrom: requestedProvider,
          fallbackReason: reason,
        };
      } catch (fallbackErr) {
        // Both primary and fallback failed - check if it's auth-related
        const fallbackReason = formatErrorMessage(fallbackErr);
        const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
        if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
          // Both failed due to missing API keys - return null for FTS-only mode
          return {
            provider: null,
            requestedProvider,
            fallbackFrom: requestedProvider,
            fallbackReason: reason,
            providerUnavailableReason: combinedReason,
          };
        }
        // Non-auth errors are still fatal
        const wrapped = new Error(combinedReason) as Error & {
          cause?: unknown;
        };
        wrapped.cause = fallbackErr;
        throw wrapped;
      }
    }
    // No fallback configured - check if we should degrade to FTS-only
    if (isMissingApiKeyError(primaryErr)) {
      return {
        provider: null,
        requestedProvider,
        providerUnavailableReason: reason,
      };
    }
    const wrapped = new Error(reason) as Error & { cause?: unknown };
    wrapped.cause = primaryErr;
    throw wrapped;
  }
}

function isNodeLlamaCppMissing(err: unknown): boolean {
  if (!(err instanceof Error)) {
    return false;
  }
  const code = (err as Error & { code?: unknown }).code;
  if (code === "ERR_MODULE_NOT_FOUND") {
    return err.message.includes("node-llama-cpp");
  }
  return false;
}

function formatLocalSetupError(err: unknown): string {
  const detail = formatErrorMessage(err);
  const missing = isNodeLlamaCppMissing(err);
  return [
    "Local embeddings unavailable.",
    missing
      ? "Reason: optional dependency node-llama-cpp is missing (or failed to install)."
      : detail
        ? `Reason: ${detail}`
        : undefined,
    missing && detail ? `Detail: ${detail}` : null,
    "To enable local embeddings:",
    "1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)",
    missing
      ? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
      : null,
    "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
    ...REMOTE_EMBEDDING_PROVIDER_IDS.map(
      (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
    ),
  ]
    .filter(Boolean)
    .join("\n");
}

export * from "../../../../src/memory-host-sdk/host/embeddings.js";
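
With the refactor, this deep module becomes a re-export shim; the supported surface is the plugin-sdk barrel. A hedged sketch of the new consumer-side import (the subpath matches the hunks elsewhere in this commit; the picked names are examples from the export list):

```ts
import {
  createLocalEmbeddingProvider,
  DEFAULT_LOCAL_MODEL,
  sanitizeAndNormalizeEmbedding,
} from "openclaw/plugin-sdk/memory-core-host-engine-embeddings";
```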

@@ -100,21 +100,3 @@ export function classifyMemoryMultimodalPath(
  }
  return null;
}

export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
  const trimmed = model.trim();
  if (!trimmed) {
    return "";
  }
  return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, "");
}

export function supportsMemoryMultimodalEmbeddings(params: {
  provider: string;
  model: string;
}): boolean {
  if (params.provider !== "gemini") {
    return false;
  }
  return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview";
}
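
Taken together, the two helpers gate multimodal support on a normalized Gemini model id. Illustrative calls (values follow directly from the code above):

```ts
normalizeGeminiEmbeddingModelForMemory("models/gemini-embedding-2-preview"); // "gemini-embedding-2-preview"
normalizeGeminiEmbeddingModelForMemory("google/gemini-embedding-2-preview"); // "gemini-embedding-2-preview"
supportsMemoryMultimodalEmbeddings({ provider: "gemini", model: "models/gemini-embedding-2-preview" }); // true
supportsMemoryMultimodalEmbeddings({ provider: "openai", model: "text-embedding-3-small" }); // false
```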

pnpm-lock.yaml (generated, 12 lines changed)
@@ -362,6 +362,12 @@ importers:
      '@aws-sdk/client-bedrock':
        specifier: 3.1028.0
        version: 3.1028.0
      '@aws-sdk/client-bedrock-runtime':
        specifier: 3.1028.0
        version: 3.1028.0
      '@aws-sdk/credential-provider-node':
        specifier: 3.972.30
        version: 3.972.30
    devDependencies:
      '@openclaw/plugin-sdk':
        specifier: workspace:*
@@ -1225,6 +1231,12 @@ importers:
        specifier: workspace:*
        version: link:../../packages/plugin-sdk

  extensions/voyage:
    devDependencies:
      '@openclaw/plugin-sdk':
        specifier: workspace:*
        version: link:../../packages/plugin-sdk

  extensions/vydra:
    devDependencies:
      '@openclaw/plugin-sdk':
@@ -6,7 +6,6 @@ import type { SecretInput } from "../config/types.secrets.js";
import {
  isMemoryMultimodalEnabled,
  normalizeMemoryMultimodalSettings,
  supportsMemoryMultimodalEmbeddings,
  type MemoryMultimodalSettings,
} from "../memory-host-sdk/multimodal.js";
import { getMemoryEmbeddingProvider } from "../plugins/memory-embedding-provider-runtime.js";
@@ -389,24 +388,9 @@ export function resolveMemorySearchConfig(
  const multimodalActive = isMemoryMultimodalEnabled(resolved.multimodal);
  const multimodalProvider =
    resolved.provider === "auto" ? undefined : getMemoryEmbeddingProvider(resolved.provider);
  const builtinMultimodalSupport =
    resolved.provider === "auto"
      ? false
      : supportsMemoryMultimodalEmbeddings({
          provider: resolved.provider,
          model: resolved.model,
        });
  if (
    multimodalActive &&
    !(
      // Fall back to the built-in helper when the provider is not registered yet
      // or when a registered adapter does not implement multimodal capability checks.
      (
        multimodalProvider?.supportsMultimodalEmbeddings?.({
          model: resolved.model,
        }) ?? builtinMultimodalSupport
      )
    )
    !(multimodalProvider?.supportsMultimodalEmbeddings?.({ model: resolved.model }) ?? false)
  ) {
    throw new Error(
      "agents.*.memorySearch.multimodal requires a provider adapter that supports multimodal embeddings for the configured model.",

@@ -16,50 +16,56 @@ export type {
  MemoryEmbeddingProviderRuntime,
} from "../plugins/memory-embedding-providers.js";
export { createLocalEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./host/embeddings.js";
export { extractBatchErrorMessage, formatUnavailableBatchError } from "./host/batch-error-utils.js";
export { postJsonWithRetry } from "./host/batch-http.js";
export { applyEmbeddingBatchOutputLine } from "./host/batch-output.js";
export {
  createGeminiEmbeddingProvider,
  DEFAULT_GEMINI_EMBEDDING_MODEL,
  buildGeminiEmbeddingRequest,
} from "./host/embeddings-gemini.js";
  EMBEDDING_BATCH_ENDPOINT,
  type EmbeddingBatchStatus,
  type ProviderBatchOutputLine,
} from "./host/batch-provider-common.js";
export {
  createLmstudioEmbeddingProvider,
  DEFAULT_LMSTUDIO_EMBEDDING_MODEL,
} from "./host/embeddings-lmstudio.js";
export type { LmstudioEmbeddingClient } from "./host/embeddings-lmstudio.js";
  buildEmbeddingBatchGroupOptions,
  runEmbeddingBatchGroups,
  type EmbeddingBatchExecutionParams,
} from "./host/batch-runner.js";
export {
  createMistralEmbeddingProvider,
  DEFAULT_MISTRAL_EMBEDDING_MODEL,
} from "./host/embeddings-mistral.js";
  resolveBatchCompletionFromStatus,
  resolveCompletedBatchResult,
  throwIfBatchTerminalFailure,
  type BatchCompletionResult,
} from "./host/batch-status.js";
export { uploadBatchJsonlFile } from "./host/batch-upload.js";
export {
  createGitHubCopilotEmbeddingProvider,
  type GitHubCopilotEmbeddingClient,
} from "./host/embeddings-github-copilot.js";
export {
  createOllamaEmbeddingProvider,
  DEFAULT_OLLAMA_EMBEDDING_MODEL,
} from "./host/embeddings-ollama.js";
export type { OllamaEmbeddingClient } from "./host/embeddings-ollama.js";
export {
  createOpenAiEmbeddingProvider,
  DEFAULT_OPENAI_EMBEDDING_MODEL,
} from "./host/embeddings-openai.js";
export {
  createVoyageEmbeddingProvider,
  DEFAULT_VOYAGE_EMBEDDING_MODEL,
} from "./host/embeddings-voyage.js";
export { runGeminiEmbeddingBatches, type GeminiBatchRequest } from "./host/batch-gemini.js";
export {
  OPENAI_BATCH_ENDPOINT,
  runOpenAiEmbeddingBatches,
  type OpenAiBatchRequest,
} from "./host/batch-openai.js";
export { runVoyageEmbeddingBatches, type VoyageBatchRequest } from "./host/batch-voyage.js";
  buildBatchHeaders,
  normalizeBatchBaseUrl,
  type BatchHttpClientConfig,
} from "./host/batch-utils.js";
export { enforceEmbeddingMaxInputTokens } from "./host/embedding-chunk-limits.js";
export {
  isMissingEmbeddingApiKeyError,
  mapBatchEmbeddingsByIndex,
  sanitizeEmbeddingCacheHeaders,
} from "./host/embedding-provider-adapter-utils.js";
export { sanitizeAndNormalizeEmbedding } from "./host/embedding-vectors.js";
export { debugEmbeddingsLog } from "./host/embeddings-debug.js";
export { normalizeEmbeddingModelWithPrefixes } from "./host/embeddings-model-normalize.js";
export {
  resolveRemoteEmbeddingBearerClient,
  type RemoteEmbeddingProviderId,
} from "./host/embeddings-remote-client.js";
export {
  createRemoteEmbeddingProvider,
  resolveRemoteEmbeddingClient,
  type RemoteEmbeddingClient,
} from "./host/embeddings-remote-provider.js";
export { fetchRemoteEmbeddingVectors } from "./host/embeddings-remote-fetch.js";
export {
  estimateStructuredEmbeddingInputBytes,
  estimateUtf8Bytes,
} from "./host/embedding-input-limits.js";
export { hasNonTextEmbeddingParts, type EmbeddingInput } from "./host/embedding-inputs.js";
export { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./host/remote-http.js";
export {
  buildCaseInsensitiveExtensionGlob,
  classifyMemoryMultimodalPath,

@@ -1,116 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import type { GeminiEmbeddingClient } from "./embeddings-gemini.js";

vi.mock("./remote-http.js", () => ({
  withRemoteHttpResponse: vi.fn(),
}));

function magnitude(values: number[]) {
  return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
}

describe("runGeminiEmbeddingBatches", () => {
  let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches;
  let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
  let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;

  beforeAll(async () => {
    ({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js"));
    ({ withRemoteHttpResponse } = await import("./remote-http.js"));
    remoteHttpMock = vi.mocked(withRemoteHttpResponse);
  });

  beforeEach(() => {
    vi.clearAllMocks();
  });

  afterEach(() => {
    vi.resetAllMocks();
    vi.unstubAllGlobals();
  });

  const mockClient: GeminiEmbeddingClient = {
    baseUrl: "https://generativelanguage.googleapis.com/v1beta",
    headers: {},
    model: "gemini-embedding-2-preview",
    modelPath: "models/gemini-embedding-2-preview",
    apiKeys: ["test-key"],
    outputDimensionality: 1536,
  };

  it("includes outputDimensionality in batch upload requests", async () => {
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart");
      const body = params.init?.body;
      if (!(body instanceof Blob)) {
        throw new Error("expected multipart blob body");
      }
      const text = await body.text();
      expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"');
      expect(text).toContain('"outputDimensionality":1536');
      return await params.onResponse(
        new Response(JSON.stringify({ name: "files/file-123" }), {
          status: 200,
          headers: { "Content-Type": "application/json" },
        }),
      );
    });
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toMatch(/:asyncBatchEmbedContent$/u);
      return await params.onResponse(
        new Response(
          JSON.stringify({
            name: "batches/batch-1",
            state: "COMPLETED",
            outputConfig: { file: "files/output-1" },
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/json" },
          },
        ),
      );
    });
    remoteHttpMock.mockImplementationOnce(async (params) => {
      expect(params.url).toMatch(/\/files\/output-1:download$/u);
      return await params.onResponse(
        new Response(
          JSON.stringify({
            key: "req-1",
            response: { embedding: { values: [3, 4] } },
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/jsonl" },
          },
        ),
      );
    });

    const results = await runGeminiEmbeddingBatches({
      gemini: mockClient,
      agentId: "main",
      requests: [
        {
          custom_id: "req-1",
          request: {
            content: { parts: [{ text: "hello world" }] },
            taskType: "RETRIEVAL_DOCUMENT",
            outputDimensionality: 1536,
          },
        },
      ],
      wait: true,
      pollIntervalMs: 1,
      timeoutMs: 1000,
      concurrency: 1,
    });

    const embedding = results.get("req-1");
    expect(embedding).toBeDefined();
    expect(embedding?.[0]).toBeCloseTo(0.6, 5);
    expect(embedding?.[1]).toBeCloseTo(0.8, 5);
    expect(magnitude(embedding ?? [])).toBeCloseTo(1, 5);
    expect(remoteHttpMock).toHaveBeenCalledTimes(3);
  });
});
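
The `[3, 4]` fixture doubles as a normalization check: batch outputs are scaled to unit length (a 3-4-5 right triangle, so magnitude 5). A one-line restatement of what the assertions verify:

```ts
sanitizeAndNormalizeEmbedding([3, 4]); // => [0.6, 0.8], magnitude 1
```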

@@ -1,368 +0,0 @@
import {
  buildEmbeddingBatchGroupOptions,
  runEmbeddingBatchGroups,
  type EmbeddingBatchExecutionParams,
} from "./batch-runner.js";
import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js";
import { hashText } from "./internal.js";
import { withRemoteHttpResponse } from "./remote-http.js";

export type GeminiBatchRequest = {
  custom_id: string;
  request: GeminiTextEmbeddingRequest;
};

export type GeminiBatchStatus = {
  name?: string;
  state?: string;
  outputConfig?: { file?: string; fileId?: string };
  metadata?: {
    output?: {
      responsesFile?: string;
    };
  };
  error?: { message?: string };
};

export type GeminiBatchOutputLine = {
  key?: string;
  custom_id?: string;
  request_id?: string;
  embedding?: { values?: number[] };
  response?: {
    embedding?: { values?: number[] };
    error?: { message?: string };
  };
  error?: { message?: string };
};

const GEMINI_BATCH_MAX_REQUESTS = 50000;

function getGeminiUploadUrl(baseUrl: string): string {
  if (baseUrl.includes("/v1beta")) {
    return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta");
  }
  return `${baseUrl.replace(/\/$/, "")}/upload`;
}

function buildGeminiUploadBody(params: { jsonl: string; displayName: string }): {
  body: Blob;
  contentType: string;
} {
  const boundary = `openclaw-${hashText(params.displayName)}`;
  const jsonPart = JSON.stringify({
    file: {
      displayName: params.displayName,
      mimeType: "application/jsonl",
    },
  });
  const delimiter = `--${boundary}\r\n`;
  const closeDelimiter = `--${boundary}--\r\n`;
  const parts = [
    `${delimiter}Content-Type: application/json; charset=UTF-8\r\n\r\n${jsonPart}\r\n`,
    `${delimiter}Content-Type: application/jsonl; charset=UTF-8\r\n\r\n${params.jsonl}\r\n`,
    closeDelimiter,
  ];
  const body = new Blob([parts.join("")], { type: "multipart/related" });
  return {
    body,
    contentType: `multipart/related; boundary=${boundary}`,
  };
}

async function submitGeminiBatch(params: {
  gemini: GeminiEmbeddingClient;
  requests: GeminiBatchRequest[];
  agentId: string;
}): Promise<GeminiBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.gemini);
  const jsonl = params.requests
    .map((request) =>
      JSON.stringify({
        key: request.custom_id,
        request: request.request,
      }),
    )
    .join("\n");
  const displayName = `memory-embeddings-${hashText(String(Date.now()))}`;
  const uploadPayload = buildGeminiUploadBody({ jsonl, displayName });

  const uploadUrl = `${getGeminiUploadUrl(baseUrl)}/files?uploadType=multipart`;
  debugEmbeddingsLog("memory embeddings: gemini batch upload", {
    uploadUrl,
    baseUrl,
    requests: params.requests.length,
  });
  const filePayload = await withRemoteHttpResponse({
    url: uploadUrl,
    ssrfPolicy: params.gemini.ssrfPolicy,
    init: {
      method: "POST",
      headers: {
        ...buildBatchHeaders(params.gemini, { json: false }),
        "Content-Type": uploadPayload.contentType,
      },
      body: uploadPayload.body,
    },
    onResponse: async (fileRes) => {
      if (!fileRes.ok) {
        const text = await fileRes.text();
        throw new Error(`gemini batch file upload failed: ${fileRes.status} ${text}`);
      }
      return (await fileRes.json()) as { name?: string; file?: { name?: string } };
    },
  });
  const fileId = filePayload.name ?? filePayload.file?.name;
  if (!fileId) {
    throw new Error("gemini batch file upload failed: missing file id");
  }

  const batchBody = {
    batch: {
      displayName: `memory-embeddings-${params.agentId}`,
      inputConfig: {
        file_name: fileId,
      },
    },
  };

  const batchEndpoint = `${baseUrl}/${params.gemini.modelPath}:asyncBatchEmbedContent`;
  debugEmbeddingsLog("memory embeddings: gemini batch create", {
    batchEndpoint,
    fileId,
  });
  return await withRemoteHttpResponse({
    url: batchEndpoint,
    ssrfPolicy: params.gemini.ssrfPolicy,
    init: {
      method: "POST",
      headers: buildBatchHeaders(params.gemini, { json: true }),
      body: JSON.stringify(batchBody),
    },
    onResponse: async (batchRes) => {
      if (batchRes.ok) {
        return (await batchRes.json()) as GeminiBatchStatus;
      }
      const text = await batchRes.text();
      if (batchRes.status === 404) {
        throw new Error(
          "gemini batch create failed: 404 (asyncBatchEmbedContent not available for this model/baseUrl). Disable remote.batch.enabled or switch providers.",
        );
      }
      throw new Error(`gemini batch create failed: ${batchRes.status} ${text}`);
    },
  });
}

async function fetchGeminiBatchStatus(params: {
  gemini: GeminiEmbeddingClient;
  batchName: string;
}): Promise<GeminiBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.gemini);
  const name = params.batchName.startsWith("batches/")
    ? params.batchName
    : `batches/${params.batchName}`;
  const statusUrl = `${baseUrl}/${name}`;
  debugEmbeddingsLog("memory embeddings: gemini batch status", { statusUrl });
  return await withRemoteHttpResponse({
    url: statusUrl,
    ssrfPolicy: params.gemini.ssrfPolicy,
    init: {
      headers: buildBatchHeaders(params.gemini, { json: true }),
    },
    onResponse: async (res) => {
      if (!res.ok) {
        const text = await res.text();
        throw new Error(`gemini batch status failed: ${res.status} ${text}`);
      }
      return (await res.json()) as GeminiBatchStatus;
    },
  });
}

async function fetchGeminiFileContent(params: {
  gemini: GeminiEmbeddingClient;
  fileId: string;
}): Promise<string> {
  const baseUrl = normalizeBatchBaseUrl(params.gemini);
  const file = params.fileId.startsWith("files/") ? params.fileId : `files/${params.fileId}`;
  const downloadUrl = `${baseUrl}/${file}:download`;
  debugEmbeddingsLog("memory embeddings: gemini batch download", { downloadUrl });
  return await withRemoteHttpResponse({
    url: downloadUrl,
    ssrfPolicy: params.gemini.ssrfPolicy,
    init: {
      headers: buildBatchHeaders(params.gemini, { json: true }),
    },
    onResponse: async (res) => {
      if (!res.ok) {
        const text = await res.text();
        throw new Error(`gemini batch file content failed: ${res.status} ${text}`);
      }
      return await res.text();
    },
  });
}

function parseGeminiBatchOutput(text: string): GeminiBatchOutputLine[] {
  if (!text.trim()) {
    return [];
  }
  return text
    .split("\n")
    .map((line) => line.trim())
    .filter(Boolean)
    .map((line) => JSON.parse(line) as GeminiBatchOutputLine);
}

async function waitForGeminiBatch(params: {
  gemini: GeminiEmbeddingClient;
  batchName: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: GeminiBatchStatus;
}): Promise<{ outputFileId: string }> {
  const start = Date.now();
  let current: GeminiBatchStatus | undefined = params.initial;
  while (true) {
    const status =
      current ??
      (await fetchGeminiBatchStatus({
        gemini: params.gemini,
        batchName: params.batchName,
      }));
    const state = status.state ?? "UNKNOWN";
    if (["SUCCEEDED", "COMPLETED", "DONE"].includes(state)) {
      const outputFileId =
        status.outputConfig?.file ??
        status.outputConfig?.fileId ??
        status.metadata?.output?.responsesFile;
      if (!outputFileId) {
        throw new Error(`gemini batch ${params.batchName} completed without output file`);
      }
      return { outputFileId };
    }
    if (["FAILED", "CANCELLED", "CANCELED", "EXPIRED"].includes(state)) {
      const message = status.error?.message ?? "unknown error";
      throw new Error(`gemini batch ${params.batchName} ${state}: ${message}`);
    }
    if (!params.wait) {
      throw new Error(`gemini batch ${params.batchName} still ${state}; wait disabled`);
    }
    if (Date.now() - start > params.timeoutMs) {
      throw new Error(`gemini batch ${params.batchName} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`gemini batch ${params.batchName} ${state}; waiting ${params.pollIntervalMs}ms`);
    await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
    current = undefined;
  }
}

export async function runGeminiEmbeddingBatches(
  params: {
    gemini: GeminiEmbeddingClient;
    agentId: string;
    requests: GeminiBatchRequest[];
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: GEMINI_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: gemini batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      const batchInfo = await submitGeminiBatch({
        gemini: params.gemini,
        requests: group,
        agentId: params.agentId,
      });
      const batchName = batchInfo.name ?? "";
      if (!batchName) {
        throw new Error("gemini batch create failed: missing batch name");
      }

      params.debug?.("memory embeddings: gemini batch created", {
        batchName,
        state: batchInfo.state,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });

      if (
        !params.wait &&
        batchInfo.state &&
        !["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state)
      ) {
        throw new Error(
          `gemini batch ${batchName} submitted; enable remote.batch.wait to await completion`,
        );
      }

      const completed =
        batchInfo.state && ["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state)
          ? {
              outputFileId:
                batchInfo.outputConfig?.file ??
                batchInfo.outputConfig?.fileId ??
                batchInfo.metadata?.output?.responsesFile ??
                "",
            }
          : await waitForGeminiBatch({
              gemini: params.gemini,
              batchName,
              wait: params.wait,
              pollIntervalMs: params.pollIntervalMs,
              timeoutMs: params.timeoutMs,
              debug: params.debug,
              initial: batchInfo,
            });
      if (!completed.outputFileId) {
        throw new Error(`gemini batch ${batchName} completed without output file`);
      }

      const content = await fetchGeminiFileContent({
        gemini: params.gemini,
        fileId: completed.outputFileId,
      });
      const outputLines = parseGeminiBatchOutput(content);
      const errors: string[] = [];
      const remaining = new Set(group.map((request) => request.custom_id));

      for (const line of outputLines) {
        const customId = line.key ?? line.custom_id ?? line.request_id;
        if (!customId) {
          continue;
        }
        remaining.delete(customId);
        if (line.error?.message) {
          errors.push(`${customId}: ${line.error.message}`);
          continue;
        }
        if (line.response?.error?.message) {
          errors.push(`${customId}: ${line.response.error.message}`);
          continue;
        }
        const embedding = sanitizeAndNormalizeEmbedding(
          line.embedding?.values ?? line.response?.embedding?.values ?? [],
        );
        if (embedding.length === 0) {
          errors.push(`${customId}: empty embedding`);
          continue;
        }
        byCustomId.set(customId, embedding);
      }

      if (errors.length > 0) {
        throw new Error(`gemini batch ${batchName} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(`gemini batch ${batchName} missing ${remaining.size} embedding responses`);
      }
    },
  });
}
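
A hedged usage sketch of the public entry point (the `geminiClient` value is assumed to come from `createGeminiEmbeddingProvider`; the request shape mirrors the test fixture earlier in this diff):

```ts
const vectors = await runGeminiEmbeddingBatches({
  gemini: geminiClient,
  agentId: "main",
  requests: [
    {
      custom_id: "doc-1",
      request: {
        content: { parts: [{ text: "hello world" }] },
        taskType: "RETRIEVAL_DOCUMENT",
      },
    },
  ],
  wait: true, // poll until the async batch completes
  pollIntervalMs: 2_000,
  timeoutMs: 600_000,
  concurrency: 2,
});
const embedding = vectors.get("doc-1"); // unit-normalized number[] or undefined
```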

@@ -1,108 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";

const mocks = vi.hoisted(() => ({
  uploadBatchJsonlFile: vi.fn(async () => "file_in"),
  postJsonWithRetry: vi.fn(async () => ({ id: "batch_1", status: "in_progress" })),
  resolveCompletedBatchResult: vi.fn(async () => ({ outputFileId: "file_out" })),
  withRemoteHttpResponse: vi.fn(
    async (params: { url: string; onResponse: (res: Response) => Promise<unknown> }) => {
      if (params.url.endsWith("/files/file_out/content")) {
        const content = [
          JSON.stringify({
            custom_id: "0",
            response: {
              status_code: 200,
              body: { data: [{ embedding: [1, 0, 0], index: 0 }] },
            },
          }),
          JSON.stringify({
            custom_id: "1",
            response: {
              status_code: 200,
              body: { data: [{ embedding: [2, 0, 0], index: 0 }] },
            },
          }),
        ].join("\n");
        return await params.onResponse({
          ok: true,
          status: 200,
          text: async () => content,
        } as Response);
      }
      return await params.onResponse({
        ok: true,
        status: 200,
        json: async () => ({ id: "batch_1", status: "completed", output_file_id: "file_out" }),
      } as Response);
    },
  ),
}));

vi.mock("./batch-upload.js", () => ({
  uploadBatchJsonlFile: mocks.uploadBatchJsonlFile,
}));

vi.mock("./batch-http.js", () => ({
  postJsonWithRetry: mocks.postJsonWithRetry,
}));

vi.mock("./batch-status.js", () => ({
  resolveBatchCompletionFromStatus: vi.fn(),
  resolveCompletedBatchResult: mocks.resolveCompletedBatchResult,
  throwIfBatchTerminalFailure: vi.fn(),
}));

vi.mock("./remote-http.js", () => ({
  withRemoteHttpResponse: mocks.withRemoteHttpResponse,
}));

describe("runOpenAiEmbeddingBatches", () => {
  beforeEach(() => {
    vi.clearAllMocks();
  });

  it("maps uploaded batch output rows back to embeddings", async () => {
    const { runOpenAiEmbeddingBatches, OPENAI_BATCH_ENDPOINT } = await import("./batch-openai.js");

    const result = await runOpenAiEmbeddingBatches({
      openAi: {
        baseUrl: "https://api.openai.com/v1",
        headers: { Authorization: "Bearer test" },
        fetchImpl: fetch,
        model: "text-embedding-3-small",
      },
      agentId: "main",
      requests: [
        {
          custom_id: "0",
          method: "POST",
          url: OPENAI_BATCH_ENDPOINT,
          body: { model: "text-embedding-3-small", input: "hello" },
        },
        {
          custom_id: "1",
          method: "POST",
          url: OPENAI_BATCH_ENDPOINT,
          body: { model: "text-embedding-3-small", input: "world" },
        },
      ],
      wait: true,
      pollIntervalMs: 1,
      timeoutMs: 1000,
      concurrency: 3,
    });

    expect(mocks.uploadBatchJsonlFile).toHaveBeenCalled();
    expect(mocks.postJsonWithRetry).toHaveBeenCalledWith(
      expect.objectContaining({
        errorPrefix: "openai batch create failed",
        body: expect.objectContaining({
          endpoint: OPENAI_BATCH_ENDPOINT,
          metadata: { source: "openclaw-memory", agent: "main" },
        }),
      }),
    );
    expect(result.get("0")).toEqual([1, 0, 0]);
    expect(result.get("1")).toEqual([2, 0, 0]);
  });
});

@@ -1,176 +0,0 @@
import { ReadableStream } from "node:stream/web";
import { setTimeout as nativeSleep } from "node:timers/promises";
import { describe, expect, it, vi } from "vitest";
import {
  runVoyageEmbeddingBatches,
  type VoyageBatchOutputLine,
  type VoyageBatchRequest,
} from "./batch-voyage.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";

const realNow = Date.now.bind(Date);

describe("runVoyageEmbeddingBatches", () => {
  const mockClient: VoyageEmbeddingClient = {
    baseUrl: "https://api.voyageai.com/v1",
    headers: { Authorization: "Bearer test-key" },
    model: "voyage-4-large",
  };

  const mockRequests: VoyageBatchRequest[] = [
    { custom_id: "req-1", body: { input: "text1" } },
    { custom_id: "req-2", body: { input: "text2" } },
  ];

  it("successfully submits batch, waits, and streams results", async () => {
    const outputLines: VoyageBatchOutputLine[] = [
      {
        custom_id: "req-1",
        response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } },
      },
      {
        custom_id: "req-2",
        response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } },
      },
    ];
    const withRemoteHttpResponse = vi.fn();
    const postJsonWithRetry = vi.fn();
    const uploadBatchJsonlFile = vi.fn();

    // Create a stream that emits the NDJSON lines
    const stream = new ReadableStream({
      start(controller) {
        const text = outputLines.map((l) => JSON.stringify(l)).join("\n");
        controller.enqueue(new TextEncoder().encode(text));
        controller.close();
      },
    });
    uploadBatchJsonlFile.mockImplementationOnce(async (params) => {
      expect(params.errorPrefix).toBe("voyage batch file upload failed");
      expect(params.requests).toEqual(mockRequests);
      return "file-123";
    });
    postJsonWithRetry.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/batches");
      expect(params.body).toMatchObject({
        input_file_id: "file-123",
        completion_window: "12h",
        request_params: {
          model: "voyage-4-large",
          input_type: "document",
        },
      });
      return {
        id: "batch-abc",
        status: "pending",
      };
    });
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/batches/batch-abc");
      return await params.onResponse(
        new Response(
          JSON.stringify({
            id: "batch-abc",
            status: "completed",
            output_file_id: "file-out-999",
          }),
          {
            status: 200,
            headers: { "Content-Type": "application/json" },
          },
        ),
      );
    });
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/files/file-out-999/content");
      return await params.onResponse(
        new Response(stream as unknown as BodyInit, {
          status: 200,
          headers: { "Content-Type": "application/x-ndjson" },
        }),
      );
    });

    const results = await runVoyageEmbeddingBatches({
      client: mockClient,
      agentId: "agent-1",
      requests: mockRequests,
      wait: true,
      pollIntervalMs: 1, // fast poll
      timeoutMs: 1000,
      concurrency: 1,
      deps: {
        now: realNow,
        sleep: async (ms) => {
          await nativeSleep(ms);
        },
        postJsonWithRetry,
        uploadBatchJsonlFile,
        withRemoteHttpResponse,
      },
    });

    expect(results.size).toBe(2);
    expect(results.get("req-1")).toEqual([0.1, 0.1]);
    expect(results.get("req-2")).toEqual([0.2, 0.2]);
    expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1);
    expect(postJsonWithRetry).toHaveBeenCalledTimes(1);
    expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2);
  });

  it("handles empty lines and stream chunks correctly", async () => {
    const withRemoteHttpResponse = vi.fn();
    const postJsonWithRetry = vi.fn();
    const uploadBatchJsonlFile = vi.fn();
    const stream = new ReadableStream({
      start(controller) {
        const line1 = JSON.stringify({
          custom_id: "req-1",
          response: { body: { data: [{ embedding: [1] }] } },
        });
        const line2 = JSON.stringify({
          custom_id: "req-2",
          response: { body: { data: [{ embedding: [2] }] } },
        });

        // Split across chunks
        controller.enqueue(new TextEncoder().encode(line1 + "\n"));
        controller.enqueue(new TextEncoder().encode("\n")); // empty line
        controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF
        controller.close();
      },
    });
    uploadBatchJsonlFile.mockResolvedValueOnce("f1");
    postJsonWithRetry.mockResolvedValueOnce({
      id: "b1",
      status: "completed",
      output_file_id: "out1",
    });
    withRemoteHttpResponse.mockImplementationOnce(async (params) => {
      expect(params.url).toContain("/files/out1/content");
      return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 }));
    });

    const results = await runVoyageEmbeddingBatches({
      client: mockClient,
      agentId: "a1",
      requests: mockRequests,
      wait: true,
      pollIntervalMs: 1,
      timeoutMs: 1000,
      concurrency: 1,
      deps: {
        now: realNow,
        sleep: async (ms) => {
          await nativeSleep(ms);
        },
        postJsonWithRetry,
        uploadBatchJsonlFile,
        withRemoteHttpResponse,
      },
    });

    expect(results.get("req-1")).toEqual([1]);
    expect(results.get("req-2")).toEqual([2]);
  });
});

@@ -1,315 +0,0 @@
import { createInterface } from "node:readline";
import { Readable } from "node:stream";
import {
  applyEmbeddingBatchOutputLine,
  buildBatchHeaders,
  buildEmbeddingBatchGroupOptions,
  EMBEDDING_BATCH_ENDPOINT,
  extractBatchErrorMessage,
  formatUnavailableBatchError,
  normalizeBatchBaseUrl,
  postJsonWithRetry,
  resolveBatchCompletionFromStatus,
  resolveCompletedBatchResult,
  runEmbeddingBatchGroups,
  throwIfBatchTerminalFailure,
  type EmbeddingBatchExecutionParams,
  type EmbeddingBatchStatus,
  type BatchCompletionResult,
  type ProviderBatchOutputLine,
  uploadBatchJsonlFile,
  withRemoteHttpResponse,
} from "./batch-embedding-common.js";
import type { VoyageEmbeddingClient } from "./embeddings-voyage.js";

/**
 * Voyage Batch API Input Line format.
 * See: https://docs.voyageai.com/docs/batch-inference
 */
export type VoyageBatchRequest = {
  custom_id: string;
  body: {
    input: string | string[];
  };
};
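// Illustrative serialized input line (one JSONL row per request; hypothetical values):
// {"custom_id":"0","body":{"input":"first chunk of text"}}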

export type VoyageBatchStatus = EmbeddingBatchStatus;
export type VoyageBatchOutputLine = ProviderBatchOutputLine;

export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT;
const VOYAGE_BATCH_COMPLETION_WINDOW = "12h";
const VOYAGE_BATCH_MAX_REQUESTS = 50000;

type VoyageBatchDeps = {
  now: () => number;
  sleep: (ms: number) => Promise<void>;
  postJsonWithRetry: typeof postJsonWithRetry;
  uploadBatchJsonlFile: typeof uploadBatchJsonlFile;
  withRemoteHttpResponse: typeof withRemoteHttpResponse;
};

function resolveVoyageBatchDeps(overrides: Partial<VoyageBatchDeps> | undefined): VoyageBatchDeps {
  return {
    now: overrides?.now ?? Date.now,
    sleep:
      overrides?.sleep ??
      (async (ms: number) => await new Promise((resolve) => setTimeout(resolve, ms))),
    postJsonWithRetry: overrides?.postJsonWithRetry ?? postJsonWithRetry,
    uploadBatchJsonlFile: overrides?.uploadBatchJsonlFile ?? uploadBatchJsonlFile,
    withRemoteHttpResponse: overrides?.withRemoteHttpResponse ?? withRemoteHttpResponse,
  };
}

async function assertVoyageResponseOk(res: Response, context: string): Promise<void> {
  if (!res.ok) {
    const text = await res.text();
    throw new Error(`${context}: ${res.status} ${text}`);
  }
}

function buildVoyageBatchRequest<T>(params: {
  client: VoyageEmbeddingClient;
  path: string;
  onResponse: (res: Response) => Promise<T>;
}) {
  const baseUrl = normalizeBatchBaseUrl(params.client);
  return {
    url: `${baseUrl}/${params.path}`,
    ssrfPolicy: params.client.ssrfPolicy,
    init: {
      headers: buildBatchHeaders(params.client, { json: true }),
    },
    onResponse: params.onResponse,
  };
}

async function submitVoyageBatch(params: {
  client: VoyageEmbeddingClient;
  requests: VoyageBatchRequest[];
  agentId: string;
  deps: VoyageBatchDeps;
}): Promise<VoyageBatchStatus> {
  const baseUrl = normalizeBatchBaseUrl(params.client);
  const inputFileId = await params.deps.uploadBatchJsonlFile({
    client: params.client,
    requests: params.requests,
    errorPrefix: "voyage batch file upload failed",
  });

  // 2. Create batch job using Voyage Batches API
  return await params.deps.postJsonWithRetry<VoyageBatchStatus>({
    url: `${baseUrl}/batches`,
    headers: buildBatchHeaders(params.client, { json: true }),
    ssrfPolicy: params.client.ssrfPolicy,
    body: {
      input_file_id: inputFileId,
      endpoint: VOYAGE_BATCH_ENDPOINT,
      completion_window: VOYAGE_BATCH_COMPLETION_WINDOW,
      request_params: {
        model: params.client.model,
        input_type: "document",
      },
      metadata: {
        source: "clawdbot-memory",
        agent: params.agentId,
      },
    },
    errorPrefix: "voyage batch create failed",
  });
}

async function fetchVoyageBatchStatus(params: {
  client: VoyageEmbeddingClient;
  batchId: string;
  deps: VoyageBatchDeps;
}): Promise<VoyageBatchStatus> {
  return await params.deps.withRemoteHttpResponse(
    buildVoyageBatchRequest({
      client: params.client,
      path: `batches/${params.batchId}`,
      onResponse: async (res) => {
        await assertVoyageResponseOk(res, "voyage batch status failed");
        return (await res.json()) as VoyageBatchStatus;
      },
    }),
  );
}

async function readVoyageBatchError(params: {
  client: VoyageEmbeddingClient;
  errorFileId: string;
  deps: VoyageBatchDeps;
}): Promise<string | undefined> {
  try {
    return await params.deps.withRemoteHttpResponse(
      buildVoyageBatchRequest({
        client: params.client,
        path: `files/${params.errorFileId}/content`,
        onResponse: async (res) => {
          await assertVoyageResponseOk(res, "voyage batch error file content failed");
          const text = await res.text();
          if (!text.trim()) {
            return undefined;
          }
          const lines = text
            .split("\n")
            .map((line) => line.trim())
            .filter(Boolean)
            .map((line) => JSON.parse(line) as VoyageBatchOutputLine);
          return extractBatchErrorMessage(lines);
        },
      }),
    );
  } catch (err) {
    return formatUnavailableBatchError(err);
  }
}

async function waitForVoyageBatch(params: {
  client: VoyageEmbeddingClient;
  batchId: string;
  wait: boolean;
  pollIntervalMs: number;
  timeoutMs: number;
  debug?: (message: string, data?: Record<string, unknown>) => void;
  initial?: VoyageBatchStatus;
  deps: VoyageBatchDeps;
}): Promise<BatchCompletionResult> {
  const start = params.deps.now();
  let current: VoyageBatchStatus | undefined = params.initial;
  while (true) {
    const status =
      current ??
      (await fetchVoyageBatchStatus({
        client: params.client,
        batchId: params.batchId,
        deps: params.deps,
      }));
    const state = status.status ?? "unknown";
    if (state === "completed") {
      return resolveBatchCompletionFromStatus({
        provider: "voyage",
        batchId: params.batchId,
        status,
      });
    }
    await throwIfBatchTerminalFailure({
      provider: "voyage",
      status: { ...status, id: params.batchId },
      readError: async (errorFileId) =>
        await readVoyageBatchError({
          client: params.client,
          errorFileId,
          deps: params.deps,
        }),
    });
    if (!params.wait) {
      throw new Error(`voyage batch ${params.batchId} still ${state}; wait disabled`);
    }
    if (params.deps.now() - start > params.timeoutMs) {
      throw new Error(`voyage batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
    }
    params.debug?.(`voyage batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
    await params.deps.sleep(params.pollIntervalMs);
    current = undefined;
  }
}

export async function runVoyageEmbeddingBatches(
  params: {
    client: VoyageEmbeddingClient;
    agentId: string;
    requests: VoyageBatchRequest[];
    deps?: Partial<VoyageBatchDeps>;
  } & EmbeddingBatchExecutionParams,
): Promise<Map<string, number[]>> {
  const deps = resolveVoyageBatchDeps(params.deps);
  return await runEmbeddingBatchGroups({
    ...buildEmbeddingBatchGroupOptions(params, {
      maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
      debugLabel: "memory embeddings: voyage batch submit",
    }),
    runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
      const batchInfo = await submitVoyageBatch({
        client: params.client,
        requests: group,
        agentId: params.agentId,
        deps,
      });
      if (!batchInfo.id) {
        throw new Error("voyage batch create failed: missing batch id");
      }
      const batchId = batchInfo.id;

      params.debug?.("memory embeddings: voyage batch created", {
        batchId: batchInfo.id,
        status: batchInfo.status,
        group: groupIndex + 1,
        groups,
        requests: group.length,
      });

      const completed = await resolveCompletedBatchResult({
        provider: "voyage",
        status: batchInfo,
        wait: params.wait,
        waitForBatch: async () =>
          await waitForVoyageBatch({
            client: params.client,
            batchId,
            wait: params.wait,
            pollIntervalMs: params.pollIntervalMs,
            timeoutMs: params.timeoutMs,
            debug: params.debug,
            initial: batchInfo,
            deps,
          }),
      });

      const baseUrl = normalizeBatchBaseUrl(params.client);
      const errors: string[] = [];
      const remaining = new Set(group.map((request) => request.custom_id));

      await deps.withRemoteHttpResponse({
        url: `${baseUrl}/files/${completed.outputFileId}/content`,
        ssrfPolicy: params.client.ssrfPolicy,
        init: {
          headers: buildBatchHeaders(params.client, { json: true }),
        },
        onResponse: async (contentRes) => {
          if (!contentRes.ok) {
            const text = await contentRes.text();
            throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`);
          }

          if (!contentRes.body) {
            return;
          }
          const reader = createInterface({
            input: Readable.fromWeb(
              contentRes.body as unknown as import("stream/web").ReadableStream,
            ),
            terminal: false,
          });

          for await (const rawLine of reader) {
            if (!rawLine.trim()) {
              continue;
            }
            const line = JSON.parse(rawLine) as VoyageBatchOutputLine;
            applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId });
          }
        },
      });

      if (errors.length > 0) {
        throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`);
      }
      if (remaining.size > 0) {
        throw new Error(
          `voyage batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
        );
      }
    },
  });
}

@@ -1,40 +1,14 @@
import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js";
import type { EmbeddingProvider } from "./embeddings.js";

const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192;
const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048;

const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record<string, number> = {
  "openai:text-embedding-3-small": 8192,
  "openai:text-embedding-3-large": 8192,
  "openai:text-embedding-ada-002": 8191,
  "gemini:text-embedding-004": 2048,
  "gemini:gemini-embedding-001": 2048,
  "gemini:gemini-embedding-2-preview": 8192,
  "voyage:voyage-3": 32000,
  "voyage:voyage-3-lite": 16000,
  "voyage:voyage-code-3": 32000,
};

export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number {
  if (typeof provider.maxInputTokens === "number") {
    return provider.maxInputTokens;
  }

  // Provider/model mapping is best-effort; different providers use different
  // limits and we prefer to be conservative when we don't know.
  const key = normalizeLowercaseStringOrEmpty(`${provider.id}:${provider.model}`);
  const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key];
  if (typeof known === "number") {
    return known;
  }

  // Provider-specific conservative fallbacks. This prevents us from accidentally
  // using the OpenAI default for providers with much smaller limits.
  if (normalizeLowercaseStringOrEmpty(provider.id) === "gemini") {
    return 2048;
  }
  if (normalizeLowercaseStringOrEmpty(provider.id) === "local") {
  if (provider.id === "local") {
    return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS;
  }
|
||||
|
||||
|
||||
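A quick illustration of the lookup order above (explicit maxInputTokens override, then the known provider:model table, then the per-provider conservative fallbacks). The inline object is a hypothetical provider instance, cast only for the sake of the example:

const sketchProvider = {
  id: "Voyage",
  model: "voyage-3-lite",
} as EmbeddingProvider;

// The key is lowercased before lookup, so "Voyage:voyage-3-lite"
// still matches "voyage:voyage-3-lite" => 16000.
const maxTokens = resolveEmbeddingMaxInputTokens(sketchProvider);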
src/memory-host-sdk/host/embedding-provider-adapter-utils.ts (new file, 29 lines)
@@ -0,0 +1,29 @@
import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js";

export function isMissingEmbeddingApiKeyError(err: unknown): boolean {
  return err instanceof Error && err.message.includes("No API key found for provider");
}

export function sanitizeEmbeddingCacheHeaders(
  headers: Record<string, string>,
  excludedHeaderNames: string[],
): Array<[string, string]> {
  const excluded = new Set(
    excludedHeaderNames.map((name) => normalizeLowercaseStringOrEmpty(name)),
  );
  return Object.entries(headers)
    .filter(([key]) => !excluded.has(normalizeLowercaseStringOrEmpty(key)))
    .toSorted(([a], [b]) => a.localeCompare(b))
    .map(([key, value]) => [key, value]);
}

export function mapBatchEmbeddingsByIndex(
  byCustomId: Map<string, number[]>,
  count: number,
): number[][] {
  const embeddings: number[][] = [];
  for (let index = 0; index < count; index += 1) {
    embeddings.push(byCustomId.get(String(index)) ?? []);
  }
  return embeddings;
}
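mapBatchEmbeddingsByIndex assumes batch custom ids were assigned as stringified input indexes ("0", "1", ...); a missing id degrades to an empty vector instead of throwing, leaving the caller to decide. A small usage sketch:

const byCustomId = new Map<string, number[]>([
  ["0", [0.1, 0.2]],
  ["2", [0.5, 0.6]],
]);
// Index 1 produced no output line, so it comes back as [].
const ordered = mapBatchEmbeddingsByIndex(byCustomId, 3); // [[0.1, 0.2], [], [0.5, 0.6]]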
@@ -1,377 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";

const { defaultProviderMock, resolveCredentialsMock, sendMock } = vi.hoisted(() => ({
  defaultProviderMock: vi.fn(),
  resolveCredentialsMock: vi.fn(),
  sendMock: vi.fn(),
}));

vi.mock("@aws-sdk/client-bedrock-runtime", () => {
  class MockClient {
    region: string;
    constructor(config: { region: string }) {
      this.region = config.region;
    }
    send = sendMock;
  }
  class MockCommand {
    input: unknown;
    constructor(input: unknown) {
      this.input = input;
    }
  }
  return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand };
});

vi.mock("@aws-sdk/credential-provider-node", () => ({
  defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
}));

let createBedrockEmbeddingProvider: typeof import("./embeddings-bedrock.js").createBedrockEmbeddingProvider;
let resolveBedrockEmbeddingClient: typeof import("./embeddings-bedrock.js").resolveBedrockEmbeddingClient;
let normalizeBedrockEmbeddingModel: typeof import("./embeddings-bedrock.js").normalizeBedrockEmbeddingModel;
let hasAwsCredentials: typeof import("./embeddings-bedrock.js").hasAwsCredentials;

beforeAll(async () => {
  ({
    createBedrockEmbeddingProvider,
    resolveBedrockEmbeddingClient,
    normalizeBedrockEmbeddingModel,
    hasAwsCredentials,
  } = await import("./embeddings-bedrock.js"));
});

beforeEach(() => {
  defaultProviderMock.mockImplementation(() => resolveCredentialsMock);
});

const enc = (body: unknown) => ({ body: new TextEncoder().encode(JSON.stringify(body)) });
const reqBody = (i = 0): Record<string, unknown> =>
  JSON.parse(sendMock.mock.calls[i][0].input.body);

describe("bedrock embedding provider", () => {
  const originalEnv = process.env;
  afterEach(() => {
    process.env = originalEnv;
    vi.restoreAllMocks();
    defaultProviderMock.mockClear();
    resolveCredentialsMock.mockReset();
    sendMock.mockReset();
  });

  // --- Normalization ---

  it("normalizes model names with prefixes", () => {
    expect(normalizeBedrockEmbeddingModel("bedrock/amazon.titan-embed-text-v2:0")).toBe(
      "amazon.titan-embed-text-v2:0",
    );
    expect(normalizeBedrockEmbeddingModel("amazon-bedrock/cohere.embed-english-v3")).toBe(
      "cohere.embed-english-v3",
    );
    expect(normalizeBedrockEmbeddingModel("")).toBe("amazon.titan-embed-text-v2:0");
  });

  // --- Client resolution ---

  it("resolves region from env", () => {
    process.env = { ...originalEnv, AWS_REGION: "eu-west-1" };
    const c = resolveBedrockEmbeddingClient({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(c.region).toBe("eu-west-1");
    expect(c.dimensions).toBe(1024);
  });

  it("defaults to us-east-1", () => {
    process.env = { ...originalEnv };
    delete process.env.AWS_REGION;
    delete process.env.AWS_DEFAULT_REGION;
    expect(
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.titan-embed-text-v2:0",
        fallback: "none",
      }).region,
    ).toBe("us-east-1");
  });

  it("extracts region from baseUrl", () => {
    process.env = { ...originalEnv };
    delete process.env.AWS_REGION;
    const c = resolveBedrockEmbeddingClient({
      config: {
        models: {
          providers: {
            "amazon-bedrock": { baseUrl: "https://bedrock-runtime.ap-southeast-2.amazonaws.com" },
          },
        },
      } as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(c.region).toBe("ap-southeast-2");
  });

  it("validates dimensions", () => {
    expect(() =>
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.titan-embed-text-v2:0",
        fallback: "none",
        outputDimensionality: 768,
      }),
    ).toThrow("Invalid dimensions 768");
  });

  it("accepts valid dimensions", () => {
    expect(
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.titan-embed-text-v2:0",
        fallback: "none",
        outputDimensionality: 256,
      }).dimensions,
    ).toBe(256);
  });

  it("resolves throughput-suffixed variants", () => {
    expect(
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.titan-embed-text-v1:2:8k",
        fallback: "none",
      }).dimensions,
    ).toBe(1536);
  });

  // --- Credential detection ---

  it("detects access keys", async () => {
    await expect(
      hasAwsCredentials({
        AWS_ACCESS_KEY_ID: "A",
        AWS_SECRET_ACCESS_KEY: "s",
      } as NodeJS.ProcessEnv),
    ).resolves.toBe(true);
  });
  it("detects profile", async () => {
    await expect(hasAwsCredentials({ AWS_PROFILE: "default" } as NodeJS.ProcessEnv)).resolves.toBe(
      true,
    );
  });
  it("detects ECS task role", async () => {
    await expect(
      hasAwsCredentials({ AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: "/v2" } as NodeJS.ProcessEnv),
    ).resolves.toBe(true);
  });
  it("detects EKS IRSA", async () => {
    await expect(
      hasAwsCredentials({
        AWS_WEB_IDENTITY_TOKEN_FILE: "/var/run/secrets/token",
        AWS_ROLE_ARN: "arn:aws:iam::123:role/x",
      } as NodeJS.ProcessEnv),
    ).resolves.toBe(true);
  });
  it("detects credentials via the AWS SDK default provider chain", async () => {
    resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
    await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(true);
    expect(defaultProviderMock).toHaveBeenCalledWith({ timeout: 1000, maxRetries: 0 });
  });
  it("returns false with no creds", async () => {
    resolveCredentialsMock.mockRejectedValue(new Error("no aws credentials"));
    await expect(hasAwsCredentials({} as NodeJS.ProcessEnv)).resolves.toBe(false);
  });

  // --- Titan V2 ---

  it("embeds with Titan V2", async () => {
    sendMock.mockResolvedValue(enc({ embedding: [0.1, 0.2, 0.3] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(await provider.embedQuery("test")).toHaveLength(3);
    expect(reqBody()).toMatchObject({ inputText: "test", normalize: true, dimensions: 1024 });
  });

  it("returns empty for blank text", async () => {
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(await provider.embedQuery(" ")).toEqual([]);
    expect(sendMock).not.toHaveBeenCalled();
  });

  it("batches Titan V2 concurrently", async () => {
    sendMock
      .mockResolvedValueOnce(enc({ embedding: [0.1] }))
      .mockResolvedValueOnce(enc({ embedding: [0.2] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v2:0",
      fallback: "none",
    });
    expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
    expect(sendMock).toHaveBeenCalledTimes(2);
  });

  // --- Titan V1 ---

  it("sends only inputText for Titan V1", async () => {
    sendMock.mockResolvedValue(enc({ embedding: [0.5] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-text-v1",
      fallback: "none",
    });
    await provider.embedQuery("hi");
    expect(reqBody()).toEqual({ inputText: "hi" });
  });

  it("handles Titan G1 text variant", async () => {
    sendMock.mockResolvedValue(enc({ embedding: [0.1] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.titan-embed-g1-text-02",
      fallback: "none",
    });
    await provider.embedQuery("hi");
    expect(reqBody()).toEqual({ inputText: "hi" });
  });

  // --- Cohere V3 ---

  it("embeds Cohere V3 batch in single call", async () => {
    sendMock.mockResolvedValue(enc({ embeddings: [[0.1], [0.2]] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "cohere.embed-english-v3",
      fallback: "none",
    });
    expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
    expect(sendMock).toHaveBeenCalledTimes(1);
    expect(reqBody()).toMatchObject({ texts: ["a", "b"], input_type: "search_document" });
  });

  it("uses search_query for Cohere embedQuery", async () => {
    sendMock.mockResolvedValue(enc({ embeddings: [[0.1]] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "cohere.embed-english-v3",
      fallback: "none",
    });
    await provider.embedQuery("q");
    expect(reqBody().input_type).toBe("search_query");
  });

  // --- Cohere V4 ---

  it("embeds Cohere V4 with embedding_types + output_dimension", async () => {
    sendMock.mockResolvedValue(enc({ embeddings: { float: [[0.1], [0.2]] } }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "cohere.embed-v4:0",
      fallback: "none",
    });
    expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
    expect(reqBody()).toMatchObject({ embedding_types: ["float"], output_dimension: 1536 });
  });

  it("validates Cohere V4 dimensions", () => {
    expect(() =>
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "cohere.embed-v4:0",
        fallback: "none",
        outputDimensionality: 2048,
      }),
    ).toThrow("Invalid dimensions 2048");
  });

  // --- Nova ---

  it("embeds Nova with SINGLE_EMBEDDING format", async () => {
    sendMock.mockResolvedValue(
      enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1, 0.2] }] }),
    );
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.nova-2-multimodal-embeddings-v1:0",
      fallback: "none",
    });
    expect(await provider.embedQuery("hi")).toHaveLength(2);
    expect(reqBody().taskType).toBe("SINGLE_EMBEDDING");
  });

  it("validates Nova dimensions", () => {
    expect(() =>
      resolveBedrockEmbeddingClient({
        config: {} as never,
        provider: "bedrock",
        model: "amazon.nova-2-multimodal-embeddings-v1:0",
        fallback: "none",
        outputDimensionality: 512,
      }),
    ).toThrow("Invalid dimensions 512");
  });

  it("batches Nova concurrently", async () => {
    sendMock
      .mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.1] }] }))
      .mockResolvedValueOnce(enc({ embeddings: [{ embeddingType: "TEXT", embedding: [0.2] }] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "amazon.nova-2-multimodal-embeddings-v1:0",
      fallback: "none",
    });
    expect(await provider.embedBatch(["a", "b"])).toHaveLength(2);
    expect(sendMock).toHaveBeenCalledTimes(2);
  });

  // --- TwelveLabs ---

  it("embeds TwelveLabs Marengo", async () => {
    sendMock.mockResolvedValue(enc({ data: [{ embedding: [0.1, 0.2] }] }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "twelvelabs.marengo-embed-3-0-v1:0",
      fallback: "none",
    });
    expect(await provider.embedQuery("hi")).toHaveLength(2);
    expect(reqBody()).toEqual({ inputType: "text", text: { inputText: "hi" } });
  });

  it("embeds TwelveLabs object-style responses", async () => {
    sendMock.mockResolvedValue(enc({ data: { embedding: [0.3, 0.4] } }));
    const { provider } = await createBedrockEmbeddingProvider({
      config: {} as never,
      provider: "bedrock",
      model: "twelvelabs.marengo-embed-2-7-v1:0",
      fallback: "none",
    });
    expect(await provider.embedQuery("hi")).toEqual([0.6, 0.8]);
  });
});
@@ -1,115 +0,0 @@
import type { EmbeddingInput } from "./embedding-inputs.js";
import type { GeminiTaskType } from "./embeddings.types.js";

export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";

export const GEMINI_EMBEDDING_2_MODELS = new Set([
  "gemini-embedding-2-preview",
  // Add the GA model name here once released.
]);

const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;

export type { GeminiTaskType } from "./embeddings.types.js";

export type GeminiTextPart = { text: string };
export type GeminiInlinePart = {
  inlineData: { mimeType: string; data: string };
};
export type GeminiPart = GeminiTextPart | GeminiInlinePart;
export type GeminiEmbeddingRequest = {
  content: { parts: GeminiPart[] };
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  model?: string;
};
export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;

/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
export function buildGeminiTextEmbeddingRequest(params: {
  text: string;
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  modelPath?: string;
}): GeminiTextEmbeddingRequest {
  return buildGeminiEmbeddingRequest({
    input: { text: params.text },
    taskType: params.taskType,
    outputDimensionality: params.outputDimensionality,
    modelPath: params.modelPath,
  });
}

export function buildGeminiEmbeddingRequest(params: {
  input: EmbeddingInput;
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  modelPath?: string;
}): GeminiEmbeddingRequest {
  const request: GeminiEmbeddingRequest = {
    content: {
      parts: params.input.parts?.map((part) =>
        part.type === "text"
          ? ({ text: part.text } satisfies GeminiTextPart)
          : ({
              inlineData: { mimeType: part.mimeType, data: part.data },
            } satisfies GeminiInlinePart),
      ) ?? [{ text: params.input.text }],
    },
    taskType: params.taskType,
  };
  if (params.modelPath) {
    request.model = params.modelPath;
  }
  if (params.outputDimensionality != null) {
    request.outputDimensionality = params.outputDimensionality;
  }
  return request;
}

/**
 * Returns true if the given model name is a gemini-embedding-2 variant that
 * supports `outputDimensionality` and extended task types.
 */
export function isGeminiEmbedding2Model(model: string): boolean {
  return GEMINI_EMBEDDING_2_MODELS.has(model);
}

/**
 * Validate and return the `outputDimensionality` for gemini-embedding-2 models.
 * Returns `undefined` for older models (they don't support the param).
 */
export function resolveGeminiOutputDimensionality(
  model: string,
  requested?: number,
): number | undefined {
  if (!isGeminiEmbedding2Model(model)) {
    return undefined;
  }
  if (requested == null) {
    return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
  }
  const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
  if (!valid.includes(requested)) {
    throw new Error(
      `Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
    );
  }
  return requested;
}

export function normalizeGeminiModel(model: string): string {
  const trimmed = model.trim();
  if (!trimmed) {
    return DEFAULT_GEMINI_EMBEDDING_MODEL;
  }
  const withoutPrefix = trimmed.replace(/^models\//, "");
  if (withoutPrefix.startsWith("gemini/")) {
    return withoutPrefix.slice("gemini/".length);
  }
  if (withoutPrefix.startsWith("google/")) {
    return withoutPrefix.slice("google/".length);
  }
  return withoutPrefix;
}
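For reference, a sketch of how the removed Gemini helpers composed (same call shapes as above; the task-type string is assumed to be one of the valid GeminiTaskType values):

// gemini-embedding-2 models accept an explicit dimensionality; older models get undefined.
const dims = resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 1536); // 1536
const request = buildGeminiTextEmbeddingRequest({
  text: "hello",
  taskType: "RETRIEVAL_QUERY",
  outputDimensionality: dims,
});
// => { content: { parts: [{ text: "hello" }] }, taskType: "RETRIEVAL_QUERY", outputDimensionality: 1536 }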
@@ -1,178 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";

const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn());
const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn());

vi.mock("../../agents/github-copilot-token.js", () => ({
  DEFAULT_COPILOT_API_BASE_URL: "https://api.githubcopilot.test",
  resolveCopilotApiToken: resolveCopilotApiTokenMock,
}));

vi.mock("../../infra/net/fetch-guard.js", () => ({
  fetchWithSsrFGuard: fetchWithSsrFGuardMock,
}));

import { createGitHubCopilotEmbeddingProvider } from "./embeddings-github-copilot.js";

function mockFetchResponse(spec: { ok: boolean; status?: number; json?: unknown; text?: string }) {
  fetchWithSsrFGuardMock.mockImplementationOnce(async () => ({
    response: {
      ok: spec.ok,
      status: spec.status ?? (spec.ok ? 200 : 500),
      json: async () => spec.json,
      text: async () => spec.text ?? "",
    },
    release: vi.fn(async () => {}),
  }));
}

describe("createGitHubCopilotEmbeddingProvider", () => {
  beforeEach(() => {
    resolveCopilotApiTokenMock.mockResolvedValue({
      token: "copilot-token-a",
      expiresAt: Date.now() + 3_600_000,
      source: "test",
      baseUrl: "https://api.githubcopilot.test",
    });
  });

  afterEach(() => {
    vi.restoreAllMocks();
    resolveCopilotApiTokenMock.mockReset();
    fetchWithSsrFGuardMock.mockReset();
  });

  it("normalizes embeddings returned for queries", async () => {
    mockFetchResponse({
      ok: true,
      json: {
        data: [{ index: 0, embedding: [3, 4] }],
      },
    });

    const { provider } = await createGitHubCopilotEmbeddingProvider({
      githubToken: "gh_test",
      model: "text-embedding-3-small",
    });

    await expect(provider.embedQuery("hello")).resolves.toEqual([0.6, 0.8]);
    expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith(
      expect.objectContaining({
        url: "https://api.githubcopilot.test/embeddings",
      }),
    );
  });

  it("preserves input order by explicit response index", async () => {
    mockFetchResponse({
      ok: true,
      json: {
        data: [
          { index: 1, embedding: [0, 2] },
          { index: 0, embedding: [1, 0] },
        ],
      },
    });

    const { provider } = await createGitHubCopilotEmbeddingProvider({
      githubToken: "gh_test",
      model: "text-embedding-3-small",
    });

    await expect(provider.embedBatch(["first", "second"])).resolves.toEqual([
      [1, 0],
      [0, 1],
    ]);
  });

  it("uses a fresh Copilot token for later requests", async () => {
    resolveCopilotApiTokenMock
      .mockResolvedValueOnce({
        token: "copilot-token-create",
        expiresAt: Date.now() + 3_600_000,
        source: "test",
        baseUrl: "https://api.githubcopilot.test",
      })
      .mockResolvedValueOnce({
        token: "copilot-token-first",
        expiresAt: Date.now() + 3_600_000,
        source: "test",
        baseUrl: "https://api.githubcopilot.test",
      })
      .mockResolvedValueOnce({
        token: "copilot-token-second",
        expiresAt: Date.now() + 3_600_000,
        source: "test",
        baseUrl: "https://api.githubcopilot.test",
      });
    mockFetchResponse({
      ok: true,
      json: { data: [{ index: 0, embedding: [1, 0] }] },
    });
    mockFetchResponse({
      ok: true,
      json: { data: [{ index: 0, embedding: [0, 1] }] },
    });

    const { provider } = await createGitHubCopilotEmbeddingProvider({
      githubToken: "gh_test",
      model: "text-embedding-3-small",
    });

    await provider.embedQuery("first");
    await provider.embedQuery("second");

    const firstHeaders = fetchWithSsrFGuardMock.mock.calls[0]?.[0]?.init?.headers as Record<
      string,
      string
    >;
    const secondHeaders = fetchWithSsrFGuardMock.mock.calls[1]?.[0]?.init?.headers as Record<
      string,
      string
    >;
    expect(firstHeaders.Authorization).toBe("Bearer copilot-token-first");
    expect(secondHeaders.Authorization).toBe("Bearer copilot-token-second");
  });

  it("honors custom baseUrl and header overrides", async () => {
    mockFetchResponse({
      ok: true,
      json: { data: [{ index: 0, embedding: [1, 0] }] },
    });

    const { provider } = await createGitHubCopilotEmbeddingProvider({
      githubToken: "gh_test",
      model: "text-embedding-3-small",
      baseUrl: "https://proxy.example/v1",
      headers: { "X-Proxy-Token": "proxy" },
    });

    await provider.embedQuery("hello");

    const call = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as {
      init: { headers: Record<string, string> };
      url: string;
    };
    expect(call.url).toBe("https://proxy.example/v1/embeddings");
    expect(call.init.headers["X-Proxy-Token"]).toBe("proxy");
    expect(call.init.headers.Authorization).toBe("Bearer copilot-token-a");
  });

  it("fails fast on sparse or malformed embedding payloads", async () => {
    mockFetchResponse({
      ok: true,
      json: {
        data: [{ index: 1, embedding: [1, 0] }],
      },
    });

    const { provider } = await createGitHubCopilotEmbeddingProvider({
      githubToken: "gh_test",
      model: "text-embedding-3-small",
    });

    await expect(provider.embedBatch(["first", "second"])).rejects.toThrow(
      "GitHub Copilot embeddings response missing vectors for some inputs",
    );
  });
});
@@ -1,151 +0,0 @@
import {
  DEFAULT_COPILOT_API_BASE_URL,
  resolveCopilotApiToken,
} from "../../agents/github-copilot-token.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import type { EmbeddingProvider } from "./embeddings.types.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";

export type GitHubCopilotEmbeddingClient = {
  githubToken: string;
  model: string;
  baseUrl?: string;
  headers?: Record<string, string>;
  env?: NodeJS.ProcessEnv;
  fetchImpl?: typeof fetch;
};

const COPILOT_EMBEDDING_PROVIDER_ID = "github-copilot";

const COPILOT_HEADERS_STATIC: Record<string, string> = {
  "Content-Type": "application/json",
  "Editor-Version": "vscode/1.96.2",
  "User-Agent": "GitHubCopilotChat/0.26.7",
};

function resolveConfiguredBaseUrl(
  configuredBaseUrl: string | undefined,
  tokenBaseUrl: string | undefined,
): string {
  const trimmed = configuredBaseUrl?.trim();
  if (trimmed) {
    return trimmed;
  }
  return tokenBaseUrl || DEFAULT_COPILOT_API_BASE_URL;
}

async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{
  baseUrl: string;
  headers: Record<string, string>;
}> {
  const token = await resolveCopilotApiToken({
    githubToken: client.githubToken,
    env: client.env,
    fetchImpl: client.fetchImpl,
  });
  const baseUrl = resolveConfiguredBaseUrl(client.baseUrl, token.baseUrl);
  return {
    baseUrl,
    headers: {
      ...COPILOT_HEADERS_STATIC,
      ...client.headers,
      Authorization: `Bearer ${token.token}`,
    },
  };
}

function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] {
  if (!payload || typeof payload !== "object") {
    throw new Error("GitHub Copilot embeddings response missing data[]");
  }
  const data = (payload as { data?: unknown }).data;
  if (!Array.isArray(data)) {
    throw new Error("GitHub Copilot embeddings response missing data[]");
  }

  const vectors = Array.from<number[] | undefined>({ length: expectedCount });
  for (const entry of data) {
    if (!entry || typeof entry !== "object") {
      throw new Error("GitHub Copilot embeddings response contains an invalid entry");
    }
    const indexValue = (entry as { index?: unknown }).index;
    const embedding = (entry as { embedding?: unknown }).embedding;
    const index = typeof indexValue === "number" ? indexValue : Number.NaN;
    if (!Number.isInteger(index)) {
      throw new Error("GitHub Copilot embeddings response contains an invalid index");
    }
    if (index < 0 || index >= expectedCount) {
      throw new Error("GitHub Copilot embeddings response contains an out-of-range index");
    }
    if (vectors[index] !== undefined) {
      throw new Error("GitHub Copilot embeddings response contains duplicate indexes");
    }
    if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) {
      throw new Error("GitHub Copilot embeddings response contains an invalid embedding");
    }
    vectors[index] = sanitizeAndNormalizeEmbedding(embedding);
  }

  for (let index = 0; index < expectedCount; index += 1) {
    if (vectors[index] === undefined) {
      throw new Error("GitHub Copilot embeddings response missing vectors for some inputs");
    }
  }
  return vectors as number[][];
}

export async function createGitHubCopilotEmbeddingProvider(
  client: GitHubCopilotEmbeddingClient,
): Promise<{ provider: EmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
  const initialSession = await resolveGitHubCopilotEmbeddingSession(client);

  const embed = async (input: string[]): Promise<number[][]> => {
    if (input.length === 0) {
      return [];
    }

    const session = await resolveGitHubCopilotEmbeddingSession(client);
    const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`;
    return await withRemoteHttpResponse({
      url,
      fetchImpl: client.fetchImpl,
      ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
      init: {
        method: "POST",
        headers: session.headers,
        body: JSON.stringify({ model: client.model, input }),
      },
      onResponse: async (response) => {
        if (!response.ok) {
          throw new Error(
            `GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`,
          );
        }

        let payload: unknown;
        try {
          payload = await response.json();
        } catch {
          throw new Error("GitHub Copilot embeddings returned invalid JSON");
        }
        return parseGitHubCopilotEmbeddingPayload(payload, input.length);
      },
    });
  };

  return {
    provider: {
      id: COPILOT_EMBEDDING_PROVIDER_ID,
      model: client.model,
      embedQuery: async (text) => {
        const [vector] = await embed([text]);
        return vector ?? [];
      },
      embedBatch: embed,
    },
    client: {
      ...client,
      baseUrl: initialSession.baseUrl,
    },
  };
}
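For reference, the removed factory's call shape (its plugin replacement is expected to keep the same contract: embedQuery returns one normalized vector, embedBatch is index-ordered and fails closed on sparse payloads); the token value below is a placeholder:

const { provider } = await createGitHubCopilotEmbeddingProvider({
  githubToken: process.env.GITHUB_TOKEN ?? "", // assumed to hold a valid GitHub token
  model: "text-embedding-3-small",
});
// Vectors pass through sanitizeAndNormalizeEmbedding, so [3, 4] comes back as [0.6, 0.8].
const vector = await provider.embedQuery("hello");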
@@ -1,387 +0,0 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import type { OpenClawConfig } from "../../config/config.js";

const ensureLmstudioModelLoadedMock = vi.hoisted(() => vi.fn());
const resolveLmstudioRuntimeApiKeyMock = vi.hoisted(() => vi.fn());

vi.mock("../../plugin-sdk/lmstudio-runtime.js", () => ({
  buildLmstudioAuthHeaders: ({
    apiKey,
    json,
    headers,
  }: {
    apiKey?: string;
    json?: boolean;
    headers?: Record<string, string>;
  }) => ({
    ...(json ? { "Content-Type": "application/json" } : {}),
    ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}),
    ...headers,
  }),
  ensureLmstudioModelLoaded: (...args: unknown[]) => ensureLmstudioModelLoadedMock(...args),
  LMSTUDIO_DEFAULT_EMBEDDING_MODEL: "text-embedding-nomic-embed-text-v1.5",
  LMSTUDIO_PROVIDER_ID: "lmstudio",
  resolveLmstudioInferenceBase: (baseUrl?: string) => {
    const normalized = (baseUrl || "http://localhost:1234").replace(/\/+$/u, "");
    if (normalized.endsWith("/api/v1")) {
      return normalized.slice(0, -"/api/v1".length) + "/v1";
    }
    if (normalized.endsWith("/v1")) {
      return normalized;
    }
    return `${normalized}/v1`;
  },
  resolveLmstudioProviderHeaders: ({ headers }: { headers?: Record<string, string> }) =>
    headers ?? {},
  resolveLmstudioRuntimeApiKey: (...args: unknown[]) => resolveLmstudioRuntimeApiKeyMock(...args),
}));

let createLmstudioEmbeddingProvider: typeof import("./embeddings-lmstudio.js").createLmstudioEmbeddingProvider;

describe("embeddings-lmstudio", () => {
  const originalFetch = globalThis.fetch;
  const jsonResponse = (embedding: number[]) =>
    new Response(
      JSON.stringify({
        data: [{ embedding }],
      }),
      {
        status: 200,
        headers: { "content-type": "application/json" },
      },
    );

  function mockEmbeddingFetch(embedding: number[]) {
    const fetchMock = vi.fn<typeof fetch>();
    fetchMock.mockResolvedValue(jsonResponse(embedding));
    globalThis.fetch = fetchMock as unknown as typeof fetch;
    return fetchMock;
  }

  beforeAll(async () => {
    ({ createLmstudioEmbeddingProvider } = await import("./embeddings-lmstudio.js"));
  });

  beforeEach(() => {
    ensureLmstudioModelLoadedMock.mockReset();
    resolveLmstudioRuntimeApiKeyMock.mockReset();
  });

  afterEach(() => {
    globalThis.fetch = originalFetch;
  });

  it("embeds against inference base and warms model with resolved key", async () => {
    ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
    resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-lmstudio-key");

    const fetchMock = mockEmbeddingFetch([0.1, 0.2]);

    const { provider } = await createLmstudioEmbeddingProvider({
      config: {
        models: {
          providers: {
            lmstudio: {
              baseUrl: "http://localhost:1234/api/v1/",
              headers: { "X-Provider": "provider" },
              models: [],
            },
          },
        },
      } as OpenClawConfig,
      provider: "lmstudio",
      model: "lmstudio/text-embedding-nomic-embed-text-v1.5",
      fallback: "none",
    });

    await provider.embedQuery("hello");

    expect(fetchMock).toHaveBeenCalledWith(
      "http://localhost:1234/v1/embeddings",
      expect.objectContaining({
        method: "POST",
        headers: expect.objectContaining({
          "Content-Type": "application/json",
          Authorization: "Bearer profile-lmstudio-key",
          "X-Provider": "provider",
        }),
      }),
    );
    expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({
      baseUrl: "http://localhost:1234/v1",
      apiKey: "profile-lmstudio-key",
      headers: {
        "X-Provider": "provider",
      },
      ssrfPolicy: { allowedHostnames: ["localhost"] },
      modelKey: "text-embedding-nomic-embed-text-v1.5",
      timeoutMs: 120_000,
    });
  });

  it("uses memorySearch remote overrides for primary lmstudio", async () => {
    ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
    resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key");

    const fetchMock = mockEmbeddingFetch([1, 2, 3]);

    const { provider } = await createLmstudioEmbeddingProvider({
      config: {
        models: {
          providers: {
            lmstudio: {
              baseUrl: "http://localhost:1234",
              headers: {
                "X-Provider": "provider",
                "X-Config-Only": "from-provider",
              },
              models: [],
            },
          },
        },
      } as OpenClawConfig,
      provider: "lmstudio",
      model: "",
      fallback: "none",
      remote: {
        baseUrl: "http://localhost:9999",
        apiKey: "remote-lmstudio-key",
        headers: {
          "X-Provider": "remote",
          "X-Remote-Only": "from-remote",
        },
      },
    });

    await provider.embedBatch(["one", "two"]);

    expect(fetchMock).toHaveBeenCalledWith(
      "http://localhost:9999/v1/embeddings",
      expect.objectContaining({
        headers: expect.objectContaining({
          Authorization: "Bearer remote-lmstudio-key",
          "X-Provider": "remote",
          "X-Config-Only": "from-provider",
          "X-Remote-Only": "from-remote",
        }),
      }),
    );
    expect(resolveLmstudioRuntimeApiKeyMock).not.toHaveBeenCalled();
    expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({
      baseUrl: "http://localhost:9999/v1",
      apiKey: "remote-lmstudio-key",
      headers: {
        "X-Provider": "remote",
        "X-Config-Only": "from-provider",
        "X-Remote-Only": "from-remote",
      },
      ssrfPolicy: { allowedHostnames: ["localhost"] },
      modelKey: "text-embedding-nomic-embed-text-v1.5",
      timeoutMs: 120_000,
    });
  });

  it("preserves remote Authorization header auth for primary lmstudio", async () => {
    ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
    resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("stale-profile-key");

    const fetchMock = mockEmbeddingFetch([1, 2, 3]);

    const { provider } = await createLmstudioEmbeddingProvider({
      config: {
        models: {
          providers: {
            lmstudio: {
              baseUrl: "http://localhost:1234",
              headers: {
                "X-Provider": "provider",
              },
              models: [],
            },
          },
        },
      } as OpenClawConfig,
      provider: "lmstudio",
      model: "",
      fallback: "none",
      remote: {
        baseUrl: "http://localhost:9999",
        headers: {
          Authorization: "Bearer remote-proxy-token",
          "X-Remote-Only": "from-remote",
        },
      },
    });

    await provider.embedBatch(["one", "two"]);

    expect(fetchMock).toHaveBeenCalledWith(
      "http://localhost:9999/v1/embeddings",
      expect.objectContaining({
        headers: expect.objectContaining({
          Authorization: "Bearer remote-proxy-token",
          "X-Provider": "provider",
          "X-Remote-Only": "from-remote",
        }),
      }),
    );
    expect(resolveLmstudioRuntimeApiKeyMock).not.toHaveBeenCalled();
    expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({
      baseUrl: "http://localhost:9999/v1",
      apiKey: undefined,
      headers: {
        "X-Provider": "provider",
        Authorization: "Bearer remote-proxy-token",
        "X-Remote-Only": "from-remote",
      },
      ssrfPolicy: { allowedHostnames: ["localhost"] },
      modelKey: "text-embedding-nomic-embed-text-v1.5",
      timeoutMs: 120_000,
    });
  });

  it("ignores memorySearch remote overrides for fallback lmstudio activation", async () => {
    ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
    resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key");

    const fetchMock = mockEmbeddingFetch([1, 2, 3]);

    const { provider } = await createLmstudioEmbeddingProvider({
      config: {
        models: {
          providers: {
            lmstudio: {
              baseUrl: "http://localhost:1234",
              headers: {
                "X-Provider": "provider",
                "X-Config-Only": "from-provider",
              },
              models: [],
            },
          },
        },
      } as OpenClawConfig,
      provider: "openai",
      model: "",
      fallback: "lmstudio",
      remote: {
        baseUrl: "http://localhost:9999",
        apiKey: "remote-lmstudio-key",
        headers: {
          "X-Provider": "remote",
          "X-Remote-Only": "from-remote",
        },
      },
    });

    await provider.embedBatch(["one", "two"]);

    expect(fetchMock).toHaveBeenCalledWith(
      "http://localhost:1234/v1/embeddings",
      expect.objectContaining({
        headers: expect.objectContaining({
          Authorization: "Bearer profile-key",
          "X-Provider": "provider",
          "X-Config-Only": "from-provider",
        }),
      }),
    );
    const callHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Record<string, string>;
    expect(callHeaders["X-Remote-Only"]).toBeUndefined();
    expect(resolveLmstudioRuntimeApiKeyMock).toHaveBeenCalled();
    expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith({
      baseUrl: "http://localhost:1234/v1",
      apiKey: "profile-key",
      headers: {
        "X-Provider": "provider",
        "X-Config-Only": "from-provider",
      },
      ssrfPolicy: { allowedHostnames: ["localhost"] },
      modelKey: "text-embedding-nomic-embed-text-v1.5",
      timeoutMs: 120_000,
    });
  });

  it("skips remote SecretRef resolution for fallback lmstudio activation", async () => {
    ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
    resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("profile-key");

    const fetchMock = mockEmbeddingFetch([1, 2, 3]);

    const { provider } = await createLmstudioEmbeddingProvider({
      config: {
        models: {
          providers: {
            lmstudio: {
              baseUrl: "http://localhost:1234",
              headers: {
                "X-Provider": "provider",
              },
              models: [],
            },
          },
        },
      } as OpenClawConfig,
      provider: "openai",
      model: "",
      fallback: "lmstudio",
      remote: {
        baseUrl: "http://localhost:9999",
        apiKey: { source: "env", provider: "default", id: "OPENAI_API_KEY" },
        headers: {
          "X-Remote-Only": "from-remote",
        },
      },
    });

    await provider.embedQuery("hello");

    expect(fetchMock).toHaveBeenCalledWith(
      "http://localhost:1234/v1/embeddings",
      expect.objectContaining({
        headers: expect.objectContaining({
          Authorization: "Bearer profile-key",
          "X-Provider": "provider",
        }),
      }),
    );
    const callHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Record<string, string>;
    expect(callHeaders["X-Remote-Only"]).toBeUndefined();
    expect(resolveLmstudioRuntimeApiKeyMock).toHaveBeenCalled();
  });

  it("uses env-template-backed provider api keys in embedding requests", async () => {
    ensureLmstudioModelLoadedMock.mockResolvedValue(undefined);
    resolveLmstudioRuntimeApiKeyMock.mockResolvedValue("template-lmstudio-key");

    const fetchMock = mockEmbeddingFetch([0.3, 0.4]);

    const { provider } = await createLmstudioEmbeddingProvider({
      config: {
        models: {
          providers: {
            lmstudio: {
              baseUrl: "http://localhost:1234/v1",
              apiKey: "${LM_API_TOKEN}",
              models: [],
            },
          },
        },
      } as OpenClawConfig,
      provider: "lmstudio",
      model: "text-embedding-nomic-embed-text-v1.5",
      fallback: "none",
    });

    await provider.embedQuery("hello");

    expect(fetchMock).toHaveBeenCalledWith(
      "http://localhost:1234/v1/embeddings",
      expect.objectContaining({
        headers: expect.objectContaining({
          Authorization: "Bearer template-lmstudio-key",
        }),
      }),
    );
  });
});
@@ -1,19 +0,0 @@
import { describe, expect, it } from "vitest";
import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js";

describe("normalizeMistralModel", () => {
  it("returns the default model for empty values", () => {
    expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
    expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL);
  });

  it("strips the mistral/ prefix", () => {
    expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed");
    expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed");
  });

  it("keeps explicit non-prefixed models", () => {
    expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed");
    expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2");
  });
});
@@ -1,43 +0,0 @@
import { beforeEach, describe, expect, it, vi } from "vitest";

const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({
  createOllamaEmbeddingProviderMock: vi.fn(async (options: unknown) => ({
    provider: { source: "mock-provider", options },
    client: { source: "mock-client" },
  })),
}));

vi.mock("../../plugin-sdk/ollama-runtime.js", () => ({
  DEFAULT_OLLAMA_EMBEDDING_MODEL: "nomic-embed-text",
  createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
}));

describe("embeddings-ollama facade", () => {
  beforeEach(() => {
    createOllamaEmbeddingProviderMock.mockClear();
  });

  it("re-exports the default Ollama embedding model", async () => {
    const mod = await import("./embeddings-ollama.js");
    expect(mod.DEFAULT_OLLAMA_EMBEDDING_MODEL).toBe("nomic-embed-text");
  });

  it("delegates provider creation to the plugin-sdk runtime facade", async () => {
    const mod = await import("./embeddings-ollama.js");
    const options = {
      provider: "ollama",
      model: "nomic-embed-text",
      fallback: "none",
      config: {},
    };

    const result = await mod.createOllamaEmbeddingProvider(options as never);

    expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledTimes(1);
    expect(createOllamaEmbeddingProviderMock).toHaveBeenCalledWith(options);
    expect(result).toEqual({
      provider: { source: "mock-provider", options },
      client: { source: "mock-client" },
    });
  });
});
@@ -1,5 +0,0 @@
export type { OllamaEmbeddingClient } from "../../plugin-sdk/ollama-runtime.js";
export {
  createOllamaEmbeddingProvider,
  DEFAULT_OLLAMA_EMBEDDING_MODEL,
} from "../../plugin-sdk/ollama-runtime.js";
@@ -1,25 +0,0 @@
import { describe, expect, it } from "vitest";
import { DEFAULT_OPENAI_EMBEDDING_MODEL, normalizeOpenAiModel } from "./embeddings-openai.js";

describe("normalizeOpenAiModel", () => {
  it("returns the default model when input is blank", () => {
    expect(normalizeOpenAiModel("")).toBe(DEFAULT_OPENAI_EMBEDDING_MODEL);
    expect(normalizeOpenAiModel(" ")).toBe(DEFAULT_OPENAI_EMBEDDING_MODEL);
  });

  it("strips the openai/ prefix", () => {
    expect(normalizeOpenAiModel("openai/text-embedding-3-small")).toBe("text-embedding-3-small");
    expect(normalizeOpenAiModel("openai/text-embedding-ada-002")).toBe("text-embedding-ada-002");
  });

  it("preserves explicit third-party provider prefixes", () => {
    expect(normalizeOpenAiModel("spark/text-embedding-3-small")).toBe(
      "spark/text-embedding-3-small",
    );
    expect(normalizeOpenAiModel("litellm/azure/ada-002")).toBe("litellm/azure/ada-002");
  });

  it("preserves unprefixed model ids", () => {
    expect(normalizeOpenAiModel("text-embedding-3-large")).toBe("text-embedding-3-large");
  });
});
@@ -1,88 +0,0 @@
import { vi } from "vitest";
import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js";

vi.mock("../../infra/net/fetch-guard.js", () => ({
  fetchWithSsrFGuard: async (params: {
    url: string;
    init?: RequestInit;
    fetchImpl?: typeof fetch;
  }) => {
    const fetchImpl = params.fetchImpl ?? globalThis.fetch;
    if (!fetchImpl) {
      throw new Error("fetch is not available");
    }
    const response = await fetchImpl(params.url, params.init);
    return {
      response,
      finalUrl: params.url,
      release: async () => {},
    };
  },
}));

type FetchPayloadFactory = (input: RequestInfo | URL, init?: RequestInit) => unknown;
type JsonResponseFetchMock = ReturnType<typeof vi.fn<FetchMock>> & {
  preconnect: (
    url: string | URL,
    options?: { dns?: boolean; tcp?: boolean; http?: boolean; https?: boolean },
  ) => void;
  __openclawAcceptsDispatcher: true;
};

export type JsonFetchMock = ReturnType<typeof createJsonResponseFetchMock>;

export function createJsonResponseFetchMock(payload: FetchPayloadFactory): JsonResponseFetchMock;
export function createJsonResponseFetchMock(payload: unknown): JsonResponseFetchMock;
export function createJsonResponseFetchMock(payload: unknown) {
  const fetchMock = vi.fn<FetchMock>(async (input: RequestInfo | URL, init?: RequestInit) => {
    const body =
      typeof payload === "function" ? (payload as FetchPayloadFactory)(input, init) : payload;
    const serialized = JSON.stringify(body);
    return {
      ok: true,
      status: 200,
      json: async () => body,
      text: async () => serialized,
    } as Response;
  });
  return withFetchPreconnect(fetchMock) as JsonResponseFetchMock;
}

export function createEmbeddingDataFetchMock(embeddingValues = [0.1, 0.2, 0.3]) {
  return createJsonResponseFetchMock({ data: [{ embedding: embeddingValues }] });
}

export function createGeminiFetchMock(embeddingValues = [1, 2, 3]) {
  return createJsonResponseFetchMock({ embedding: { values: embeddingValues } });
}

export function createGeminiBatchFetchMock(count: number, embeddingValues = [1, 2, 3]) {
  return createJsonResponseFetchMock({
    embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })),
  });
}

export function installFetchMock(fetchMock: typeof globalThis.fetch) {
  vi.stubGlobal("fetch", fetchMock);
}

export function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
  const [url, init] = fetchMock.mock.calls[0] ?? [];
  return { url, init: init as RequestInit | undefined };
}

export function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) {
  const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined;
  return JSON.parse((init?.body as string) ?? "{}") as Record<string, unknown>;
}

export function mockResolvedProviderKey(
  resolveApiKeyForProvider: typeof import("../../agents/model-auth.js").resolveApiKeyForProvider,
  apiKey = "test-key",
) {
  vi.mocked(resolveApiKeyForProvider).mockResolvedValue({
    apiKey,
    mode: "api-key",
    source: "test",
  });
}
@@ -5,7 +5,7 @@ import type { EmbeddingProviderOptions } from "./embeddings.types.js";
import { buildRemoteBaseUrlPolicy } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";

-export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral";
+export type RemoteEmbeddingProviderId = string;

export async function resolveRemoteEmbeddingBearerClient(params: {
  provider: RemoteEmbeddingProviderId;
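Widening the id to string lets plugin-owned remote providers reuse the shared bearer-client resolver rather than patching a closed union; a hedged illustration (the plugin id below is hypothetical, and the remaining params follow the signature that continues past this hunk):

// Before: provider had to be "openai" | "voyage" | "mistral".
// After: any plugin-registered id is acceptable, e.g.
//   await resolveRemoteEmbeddingBearerClient({ provider: "my-plugin-embeddings", ...rest });
// where "my-plugin-embeddings" is an id owned by an external plugin and `rest` carries
// the remaining fields of this function's params (not shown in this hunk).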
@@ -1,141 +0,0 @@
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import * as authModule from "../../agents/model-auth.js";
|
||||
import {
|
||||
createEmbeddingDataFetchMock,
|
||||
createJsonResponseFetchMock,
|
||||
installFetchMock,
|
||||
mockResolvedProviderKey,
|
||||
type JsonFetchMock,
|
||||
} from "./embeddings-provider.test-support.js";
|
||||
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
|
||||
|
||||
const { resolveApiKeyForProviderMock } = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../agents/model-auth.js", () => {
|
||||
return {
|
||||
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
|
||||
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
|
||||
if (auth.apiKey) {
|
||||
return auth.apiKey;
|
||||
}
|
||||
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
|
||||
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
|
||||
|
||||
beforeAll(async () => {
|
||||
({ createVoyageEmbeddingProvider, normalizeVoyageModel } =
|
||||
await import("./embeddings-voyage.js"));
|
||||
});
|
||||
|
||||
beforeEach(() => {
vi.useRealTimers();
vi.doUnmock("undici");
});

async function createDefaultVoyageProvider(model: string, fetchMock: JsonFetchMock) {
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider, "voyage-key-123");
return createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
model,
fallback: "none",
});
}

describe("voyage embedding provider", () => {
afterEach(() => {
vi.doUnmock("undici");
vi.resetAllMocks();
vi.unstubAllGlobals();
});

it("configures client with correct defaults and headers", async () => {
const fetchMock = createEmbeddingDataFetchMock();
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);

await result.provider.embedQuery("test query");

expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith(
expect.objectContaining({ provider: "voyage" }),
);

const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://api.voyageai.com/v1/embeddings");

const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer voyage-key-123");
expect(headers["Content-Type"]).toBe("application/json");

const body = JSON.parse(init?.body as string);
expect(body).toEqual({
model: "voyage-4-large",
input: ["test query"],
input_type: "query",
});
});

it("respects remote overrides for baseUrl and apiKey", async () => {
const fetchMock = createEmbeddingDataFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();

const result = await createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
model: "voyage-4-lite",
fallback: "none",
remote: {
baseUrl: "https://example.com",
apiKey: "remote-override-key",
headers: { "X-Custom": "123" },
},
});

await result.provider.embedQuery("test");

const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [url, init] = call as [RequestInfo | URL, RequestInit | undefined];
expect(url).toBe("https://example.com/embeddings");

const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-override-key");
expect(headers["X-Custom"]).toBe("123");
});

it("passes input_type=document for embedBatch", async () => {
const fetchMock = createJsonResponseFetchMock({
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
});
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);

await result.provider.embedBatch(["doc1", "doc2"]);

const call = fetchMock.mock.calls[0];
expect(call).toBeDefined();
const [, init] = call as [RequestInfo | URL, RequestInit | undefined];
const body = JSON.parse(init?.body as string);
expect(body).toEqual({
model: "voyage-4-large",
input: ["doc1", "doc2"],
input_type: "document",
});
});

it("normalizes model names", async () => {
expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2");
expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large");
expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite");
expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default
});
});
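The helpers createEmbeddingDataFetchMock and createJsonResponseFetchMock come from the shared test-support module; the tests above only rely on them resolving to a Response-like object. A minimal sketch of that shape (hypothetical helper name; the real implementation lives in embeddings-provider.test-support.ts), mirroring the inline mock used in the auto-selection test further down:

import { vi } from "vitest";

// Hypothetical sketch: a fetch mock that resolves to a Response-like object
// carrying OpenAI-style embedding data.
function createEmbeddingDataFetchMockSketch(embedding: number[] = [1, 2, 3]) {
  return vi.fn(async (_input?: unknown, _init?: unknown) => ({
    ok: true,
    status: 200,
    json: async () => ({ data: [{ embedding }] }),
  }));
}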
@@ -1,768 +1,67 @@
import { setTimeout as sleep } from "node:timers/promises";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../agents/model-auth.js";
import {
createEmbeddingDataFetchMock,
createGeminiFetchMock,
installFetchMock,
mockResolvedProviderKey as mockResolvedProviderKeyBase,
readFirstFetchRequest,
} from "./embeddings-provider.test-support.js";
import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
import { createLocalEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
import * as nodeLlamaModule from "./node-llama.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";

const {
bedrockSendMock,
createOllamaEmbeddingProviderMock,
defaultProviderMock,
resolveApiKeyForProviderMock,
resolveCredentialsMock,
} = vi.hoisted(() => ({
bedrockSendMock: vi.fn(),
createOllamaEmbeddingProviderMock: vi.fn(async () => {
throw new Error("Unexpected ollama provider in embeddings.test.ts");
}),
defaultProviderMock: vi.fn(),
resolveCredentialsMock: vi.fn(),
resolveApiKeyForProviderMock: vi.fn(),
}));

vi.mock("../../agents/model-auth.js", () => {
return {
resolveApiKeyForProvider: resolveApiKeyForProviderMock,
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
if (auth.apiKey) {
return auth.apiKey;
}
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`);
},
};
});

vi.mock("./embeddings-ollama.js", () => ({
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
}));

vi.mock("@aws-sdk/client-bedrock-runtime", () => {
class MockClient {
send = bedrockSendMock;
}
class MockCommand {
input: unknown;
constructor(input: unknown) {
this.input = input;
}
}
return { BedrockRuntimeClient: MockClient, InvokeModelCommand: MockCommand };
});

vi.mock("@aws-sdk/credential-provider-node", () => ({
defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
}));

const createFetchMock = () => createEmbeddingDataFetchMock([1, 2, 3]);

type ResolvedProviderAuth = Awaited<ReturnType<typeof authModule.resolveApiKeyForProvider>>;

beforeEach(() => {
vi.spyOn(authModule, "resolveApiKeyForProvider");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
defaultProviderMock.mockImplementation(() => resolveCredentialsMock);
});

beforeEach(() => {
vi.useRealTimers();
});

afterEach(() => {
vi.resetAllMocks();
vi.unstubAllGlobals();
});

function requireProvider(result: Awaited<ReturnType<typeof createEmbeddingProvider>>) {
if (!result.provider) {
throw new Error("Expected embedding provider");
}
return result.provider;
function mockLocalEmbeddingRuntime(vector = new Float32Array([2.35, 3.45, 0.63, 4.3])) {
const getEmbeddingFor = vi.fn().mockResolvedValue({ vector });
const createEmbeddingContext = vi.fn().mockResolvedValue({ getEmbeddingFor });
const loadModel = vi.fn().mockResolvedValue({ createEmbeddingContext });
const resolveModelFile = vi.fn(async (modelPath: string) => `/resolved/${modelPath}`);

vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
getLlama: async () => ({ loadModel }),
resolveModelFile,
LlamaLogLevel: { error: 0 },
} as never);

return { createEmbeddingContext, getEmbeddingFor, loadModel, resolveModelFile };
}

function mockResolvedProviderKey(apiKey = "provider-key") {
mockResolvedProviderKeyBase(authModule.resolveApiKeyForProvider, apiKey);
}
describe("local embedding provider", () => {
it("normalizes local embeddings and resolves the default local model", async () => {
const runtime = mockLocalEmbeddingRuntime();

function mockMissingLocalEmbeddingDependency() {
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockRejectedValue(
Object.assign(new Error("Cannot find package 'node-llama-cpp'"), {
code: "ERR_MODULE_NOT_FOUND",
}),
);
}

function createLocalProvider(options?: { fallback?: "none" | "openai" }) {
return createEmbeddingProvider({
config: {} as never,
provider: "local",
model: "text-embedding-3-small",
fallback: options?.fallback ?? "none",
});
}

function expectAutoSelectedProvider(
result: Awaited<ReturnType<typeof createEmbeddingProvider>>,
expectedId: "openai" | "gemini" | "mistral" | "bedrock",
) {
expect(result.requestedProvider).toBe("auto");
const provider = requireProvider(result);
expect(provider.id).toBe(expectedId);
return provider;
}

function createAutoProvider(model = "") {
return createEmbeddingProvider({
config: {} as never,
provider: "auto",
model,
fallback: "none",
});
}

describe("embedding provider remote overrides", () => {
it("uses remote baseUrl/apiKey and merges headers", async () => {
const fetchMock = createFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey("provider-key");

const cfg = {
models: {
providers: {
openai: {
baseUrl: "https://api.openai.com/v1",
headers: {
"X-Provider": "p",
"X-Shared": "provider",
},
},
},
},
};

const result = await createEmbeddingProvider({
config: cfg as never,
provider: "openai",
remote: {
baseUrl: "https://example.com/v1",
apiKey: " remote-key ",
headers: {
"X-Shared": "remote",
"X-Remote": "r",
},
},
model: "text-embedding-3-small",
fallback: "openai",
});

const provider = requireProvider(result);
await provider.embedQuery("hello");

expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
const url = fetchMock.mock.calls[0]?.[0];
const init = fetchMock.mock.calls[0]?.[1];
expect(url).toBe("https://example.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-key");
expect(headers["Content-Type"]).toBe("application/json");
expect(headers["X-Provider"]).toBe("p");
expect(headers["X-Shared"]).toBe("remote");
expect(headers["X-Remote"]).toBe("r");
});

it("falls back to resolved api key when remote apiKey is blank", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "openai",
|
||||
remote: {
|
||||
baseUrl: "https://example.com/v1",
|
||||
apiKey: " ",
|
||||
},
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1);
|
||||
const init = fetchMock.mock.calls[0]?.[1];
|
||||
const headers = (init?.headers as Record<string, string>) ?? {};
|
||||
expect(headers.Authorization).toBe("Bearer provider-key");
|
||||
});
|
||||
|
||||
it("builds Gemini embeddings requests with api key header", async () => {
|
||||
const fetchMock = createGeminiFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
google: {
|
||||
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "gemini",
|
||||
remote: {
|
||||
apiKey: "gemini-key",
|
||||
},
|
||||
model: "text-embedding-004",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
const { url, init } = readFirstFetchRequest(fetchMock);
|
||||
expect(url).toBe(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent",
|
||||
);
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers["x-goog-api-key"]).toBe("gemini-key");
|
||||
expect(headers["Content-Type"]).toBe("application/json");
|
||||
});
|
||||
|
||||
it("fails fast when Gemini remote apiKey is an unresolved SecretRef", async () => {
|
||||
vi.stubEnv("GEMINI_API_KEY", "");
|
||||
|
||||
await expect(
|
||||
createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
remote: {
|
||||
apiKey: { source: "env", provider: "default", id: "GEMINI_API_KEY" },
|
||||
},
|
||||
model: "text-embedding-004",
|
||||
fallback: "openai",
|
||||
}),
|
||||
).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey:/i);
|
||||
});
|
||||
|
||||
it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => {
|
||||
vi.stubEnv("GEMINI_API_KEY", "env-gemini-key");
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "gemini",
|
||||
remote: {
|
||||
apiKey: "GEMINI_API_KEY", // pragma: allowlist secret
|
||||
},
|
||||
model: "text-embedding-004",
|
||||
fallback: "openai",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
expect(provider.id).toBe("gemini");
|
||||
expect(result.gemini?.apiKeys).toEqual(["env-gemini-key"]);
|
||||
});
|
||||
|
||||
it("builds Mistral embeddings requests with bearer auth", async () => {
|
||||
const fetchMock = createFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const cfg = {
|
||||
models: {
|
||||
providers: {
|
||||
mistral: {
|
||||
baseUrl: "https://api.mistral.ai/v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: cfg as never,
|
||||
provider: "mistral",
|
||||
remote: {
|
||||
apiKey: "mistral-key", // pragma: allowlist secret
|
||||
},
|
||||
model: "mistral/mistral-embed",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
const provider = requireProvider(result);
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
const { url, init } = readFirstFetchRequest(fetchMock);
|
||||
expect(url).toBe("https://api.mistral.ai/v1/embeddings");
|
||||
const headers = (init?.headers ?? {}) as Record<string, string>;
|
||||
expect(headers.Authorization).toBe("Bearer mistral-key");
|
||||
const payload = JSON.parse((init?.body as string | undefined) ?? "{}") as { model?: string };
|
||||
expect(payload.model).toBe("mistral-embed");
|
||||
});
|
||||
});
|
||||
|
||||
describe("embedding provider auto selection", () => {
|
||||
it("keeps explicit model when openai is selected", async () => {
|
||||
const fetchMock = vi.fn(async (_input?: unknown, _init?: unknown) => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
|
||||
}));
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
mockPublicPinnedHostname();
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
|
||||
if (provider === "openai") {
|
||||
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
|
||||
}
|
||||
throw new Error(`Unexpected provider ${provider}`);
|
||||
});
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "auto",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
expect(result.requestedProvider).toBe("auto");
|
||||
const provider = requireProvider(result);
|
||||
expect(provider.id).toBe("openai");
|
||||
await provider.embedQuery("hello");
|
||||
const url = fetchMock.mock.calls[0]?.[0];
|
||||
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
|
||||
expect(url).toBe("https://api.openai.com/v1/embeddings");
|
||||
const payload = JSON.parse(init?.body as string) as { model?: string };
|
||||
expect(payload.model).toBe("text-embedding-3-small");
|
||||
});
|
||||
|
||||
it("selects the first available remote provider in auto mode", async () => {
|
||||
const cases: Array<{
|
||||
name: string;
|
||||
expectedProvider: "openai" | "gemini" | "mistral";
|
||||
resolveApiKey: (provider: string) => ResolvedProviderAuth;
|
||||
}> = [
|
||||
{
|
||||
name: "openai first",
|
||||
expectedProvider: "openai" as const,
|
||||
resolveApiKey(provider: string): ResolvedProviderAuth {
|
||||
if (provider === "openai") {
|
||||
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
|
||||
}
|
||||
throw new Error(`No API key found for provider "${provider}".`);
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemini fallback",
|
||||
expectedProvider: "gemini" as const,
|
||||
resolveApiKey(provider: string): ResolvedProviderAuth {
|
||||
if (provider === "openai") {
|
||||
throw new Error('No API key found for provider "openai".');
|
||||
}
|
||||
if (provider === "google") {
|
||||
return {
|
||||
apiKey: "gemini-key",
|
||||
source: "env: GEMINI_API_KEY",
|
||||
mode: "api-key" as const,
|
||||
};
|
||||
}
|
||||
throw new Error(`Unexpected provider ${provider}`);
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mistral after earlier misses",
|
||||
expectedProvider: "mistral" as const,
|
||||
resolveApiKey(provider: string): ResolvedProviderAuth {
|
||||
if (provider === "mistral") {
|
||||
return {
|
||||
apiKey: "mistral-key",
|
||||
source: "env: MISTRAL_API_KEY",
|
||||
mode: "api-key" as const,
|
||||
};
|
||||
}
|
||||
throw new Error(`No API key found for provider "${provider}".`);
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
for (const testCase of cases) {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockReset();
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) =>
|
||||
testCase.resolveApiKey(provider),
|
||||
);
|
||||
|
||||
const result = await createAutoProvider();
|
||||
expectAutoSelectedProvider(result, testCase.expectedProvider);
|
||||
}
|
||||
});
|
||||
|
||||
it("selects Bedrock in auto mode when the AWS credential chain resolves", async () => {
|
||||
bedrockSendMock.mockResolvedValue({
|
||||
body: new TextEncoder().encode(JSON.stringify({ embedding: [1, 2, 3] })),
|
||||
});
|
||||
resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
|
||||
throw new Error(`No API key found for provider "${provider}".`);
|
||||
});
|
||||
|
||||
const result = await createAutoProvider();
|
||||
const provider = expectAutoSelectedProvider(result, "bedrock");
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
expect(bedrockSendMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("rethrows non-auth Bedrock setup errors in auto mode", async () => {
|
||||
resolveCredentialsMock.mockResolvedValue({ accessKeyId: "AKIAEXAMPLE" });
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
|
||||
throw new Error(`No API key found for provider "${provider}".`);
|
||||
});
|
||||
|
||||
await expect(
|
||||
createEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "auto",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
outputDimensionality: 768,
|
||||
}),
|
||||
).rejects.toThrow("Invalid dimensions 768");
|
||||
});
|
||||
});
|
||||
|
||||
describe("embedding provider local fallback", () => {
|
||||
it("falls back to openai when node-llama-cpp is missing", async () => {
|
||||
mockMissingLocalEmbeddingDependency();
|
||||
|
||||
const fetchMock = createFetchMock();
|
||||
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
|
||||
|
||||
mockResolvedProviderKey("provider-key");
|
||||
|
||||
const result = await createLocalProvider({ fallback: "openai" });
|
||||
|
||||
const provider = requireProvider(result);
|
||||
expect(provider.id).toBe("openai");
|
||||
expect(result.fallbackFrom).toBe("local");
|
||||
expect(result.fallbackReason).toContain("node-llama-cpp");
|
||||
});
|
||||
|
||||
it("throws a helpful error when local is requested and fallback is none", async () => {
|
||||
mockMissingLocalEmbeddingDependency();
|
||||
await expect(createLocalProvider()).rejects.toThrow(/optional dependency node-llama-cpp/i);
|
||||
});
|
||||
|
||||
it("mentions every remote provider in local setup guidance", async () => {
|
||||
mockMissingLocalEmbeddingDependency();
|
||||
await expect(createLocalProvider()).rejects.toThrow(/provider = "gemini"/i);
|
||||
await expect(createLocalProvider()).rejects.toThrow(/provider = "mistral"/i);
|
||||
});
|
||||
});
|
||||
|
||||
describe("local embedding normalization", () => {
|
||||
async function createLocalProviderForTest() {
|
||||
return createEmbeddingProvider({
|
||||
const provider = await createLocalEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
});
|
||||
}
|
||||
|
||||
function mockSingleLocalEmbeddingVector(
|
||||
vector: number[],
|
||||
resolveModelFile: (modelPath: string, modelDirectory?: string) => Promise<string> = async () =>
|
||||
"/fake/model.gguf",
|
||||
): void {
|
||||
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array(vector),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile,
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
}
|
||||
|
||||
it("normalizes local embeddings to magnitude ~1.0", async () => {
|
||||
const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9];
|
||||
const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf");
|
||||
|
||||
mockSingleLocalEmbeddingVector(unnormalizedVector, resolveModelFileMock);
|
||||
|
||||
const result = await createLocalProviderForTest();
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const embedding = await provider.embedQuery("test query");
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, value) => sum + value * value, 0));
|
||||
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
|
||||
|
||||
expect(magnitude).toBeCloseTo(1.0, 5);
|
||||
expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
|
||||
expect(magnitude).toBeCloseTo(1, 5);
|
||||
expect(runtime.resolveModelFile).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined);
|
||||
expect(runtime.getEmbeddingFor).toHaveBeenCalledWith("test query");
|
||||
});
|
||||
|
||||
it("handles zero vector without division by zero", async () => {
|
||||
const zeroVector = [0, 0, 0, 0];
|
||||
it("trims explicit local model paths and cache directories", async () => {
|
||||
const runtime = mockLocalEmbeddingRuntime(new Float32Array([1, 0]));
|
||||
|
||||
mockSingleLocalEmbeddingVector(zeroVector);
|
||||
|
||||
const result = await createLocalProviderForTest();
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const embedding = await provider.embedQuery("test");
|
||||
|
||||
expect(embedding).toEqual([0, 0, 0, 0]);
|
||||
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
|
||||
});
|
||||
|
||||
it("sanitizes non-finite values before normalization", async () => {
|
||||
const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
|
||||
|
||||
mockSingleLocalEmbeddingVector(nonFiniteVector);
|
||||
|
||||
const result = await createLocalProviderForTest();
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const embedding = await provider.embedQuery("test");
|
||||
|
||||
expect(embedding).toEqual([1, 0, 0, 0]);
|
||||
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
|
||||
});
|
||||
|
||||
it("normalizes batch embeddings to magnitude ~1.0", async () => {
|
||||
const unnormalizedVectors = [
|
||||
[2.35, 3.45, 0.63, 4.3],
|
||||
[10.0, 0.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
];
|
||||
|
||||
vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({
|
||||
getLlama: async () => ({
|
||||
loadModel: vi.fn().mockResolvedValue({
|
||||
createEmbeddingContext: vi.fn().mockResolvedValue({
|
||||
getEmbeddingFor: vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) })
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) })
|
||||
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const result = await createLocalProviderForTest();
|
||||
|
||||
const provider = requireProvider(result);
|
||||
const embeddings = await provider.embedBatch(["text1", "text2", "text3"]);
|
||||
|
||||
for (const embedding of embeddings) {
|
||||
const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0));
|
||||
expect(magnitude).toBeCloseTo(1.0, 5);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("local embedding ensureContext concurrency", () => {
|
||||
async function setupLocalProviderWithMockedInit(params?: {
|
||||
initializationDelayMs?: number;
|
||||
failFirstGetLlama?: boolean;
|
||||
}) {
|
||||
const getLlamaSpy = vi.fn();
|
||||
const loadModelSpy = vi.fn();
|
||||
const createContextSpy = vi.fn();
|
||||
let shouldFail = params?.failFirstGetLlama ?? false;
|
||||
|
||||
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
|
||||
getLlama: async (...args: unknown[]) => {
|
||||
getLlamaSpy(...args);
|
||||
if (shouldFail) {
|
||||
shouldFail = false;
|
||||
throw new Error("transient init failure");
|
||||
}
|
||||
if (params?.initializationDelayMs) {
|
||||
await sleep(params.initializationDelayMs);
|
||||
}
|
||||
return {
|
||||
loadModel: async (...modelArgs: unknown[]) => {
|
||||
loadModelSpy(...modelArgs);
|
||||
if (params?.initializationDelayMs) {
|
||||
await sleep(params.initializationDelayMs);
|
||||
}
|
||||
return {
|
||||
createEmbeddingContext: async () => {
|
||||
createContextSpy();
|
||||
return {
|
||||
getEmbeddingFor: vi.fn().mockResolvedValue({
|
||||
vector: new Float32Array([1, 0, 0, 0]),
|
||||
}),
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
resolveModelFile: async () => "/fake/model.gguf",
|
||||
LlamaLogLevel: { error: 0 },
|
||||
} as never);
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
const provider = await createLocalEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "local",
|
||||
model: "",
|
||||
fallback: "none",
|
||||
local: {
|
||||
modelPath: " /models/embed.gguf ",
|
||||
modelCacheDir: " /cache/models ",
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
provider: requireProvider(result),
|
||||
getLlamaSpy,
|
||||
loadModelSpy,
|
||||
createContextSpy,
|
||||
};
|
||||
}
|
||||
await provider.embedBatch(["a", "b"]);
|
||||
|
||||
it("loads the model only once when embedBatch is called concurrently", async () => {
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
initializationDelayMs: 5,
|
||||
});
|
||||
|
||||
const results = await Promise.all([
|
||||
provider.embedBatch(["text1"]),
|
||||
provider.embedBatch(["text2"]),
|
||||
provider.embedBatch(["text3"]),
|
||||
provider.embedBatch(["text4"]),
|
||||
]);
|
||||
|
||||
expect(results).toHaveLength(4);
|
||||
for (const embeddings of results) {
|
||||
expect(embeddings).toHaveLength(1);
|
||||
expect(embeddings[0]).toHaveLength(4);
|
||||
}
|
||||
|
||||
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
|
||||
expect(loadModelSpy).toHaveBeenCalledTimes(1);
|
||||
expect(createContextSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("retries initialization after a transient ensureContext failure", async () => {
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
failFirstGetLlama: true,
|
||||
});
|
||||
|
||||
await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure");
|
||||
|
||||
const recovered = await provider.embedBatch(["second"]);
|
||||
expect(recovered).toHaveLength(1);
|
||||
expect(recovered[0]).toHaveLength(4);
|
||||
|
||||
expect(getLlamaSpy).toHaveBeenCalledTimes(2);
|
||||
expect(loadModelSpy).toHaveBeenCalledTimes(1);
|
||||
expect(createContextSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("shares initialization when embedQuery and embedBatch start concurrently", async () => {
|
||||
const { provider, getLlamaSpy, loadModelSpy, createContextSpy } =
|
||||
await setupLocalProviderWithMockedInit({
|
||||
initializationDelayMs: 5,
|
||||
});
|
||||
|
||||
const [queryA, batch, queryB] = await Promise.all([
|
||||
provider.embedQuery("query-a"),
|
||||
provider.embedBatch(["batch-a", "batch-b"]),
|
||||
provider.embedQuery("query-b"),
|
||||
]);
|
||||
|
||||
expect(queryA).toHaveLength(4);
|
||||
expect(batch).toHaveLength(2);
|
||||
expect(queryB).toHaveLength(4);
|
||||
expect(batch[0]).toHaveLength(4);
|
||||
expect(batch[1]).toHaveLength(4);
|
||||
|
||||
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
|
||||
expect(loadModelSpy).toHaveBeenCalledTimes(1);
|
||||
expect(createContextSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("FTS-only fallback when no provider available", () => {
|
||||
it("returns null provider when all requested auth paths fail", async () => {
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
|
||||
new Error("No API key found for provider"),
|
||||
);
|
||||
|
||||
for (const testCase of [
|
||||
{
|
||||
name: "auto mode",
|
||||
options: {
|
||||
config: {} as never,
|
||||
provider: "auto" as const,
|
||||
model: "",
|
||||
fallback: "none" as const,
|
||||
},
|
||||
requestedProvider: "auto",
|
||||
fallbackFrom: undefined,
|
||||
reasonIncludes: "No API key",
|
||||
},
|
||||
{
|
||||
name: "explicit provider only",
|
||||
options: {
|
||||
config: {} as never,
|
||||
provider: "openai" as const,
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "none" as const,
|
||||
},
|
||||
requestedProvider: "openai",
|
||||
fallbackFrom: undefined,
|
||||
reasonIncludes: "No API key",
|
||||
},
|
||||
{
|
||||
name: "primary and fallback",
|
||||
options: {
|
||||
config: {} as never,
|
||||
provider: "openai" as const,
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "gemini" as const,
|
||||
},
|
||||
requestedProvider: "openai",
|
||||
fallbackFrom: "openai",
|
||||
reasonIncludes: "Fallback to gemini failed",
|
||||
},
|
||||
]) {
|
||||
const result = await createEmbeddingProvider(testCase.options);
|
||||
expect(result.provider, testCase.name).toBeNull();
|
||||
expect(result.requestedProvider, testCase.name).toBe(testCase.requestedProvider);
|
||||
expect(result.fallbackFrom, testCase.name).toBe(testCase.fallbackFrom);
|
||||
expect(result.providerUnavailableReason, testCase.name).toContain(testCase.reasonIncludes);
|
||||
}
|
||||
expect(provider.model).toBe("/models/embed.gguf");
|
||||
expect(runtime.resolveModelFile).toHaveBeenCalledWith("/models/embed.gguf", "/cache/models");
|
||||
expect(runtime.getEmbeddingFor).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
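The normalization behavior these tests pin down (unit magnitude, zero-vector passthrough, non-finite sanitization) is the contract of sanitizeAndNormalizeEmbedding. A minimal sketch of that contract, not the repo's actual implementation:

// Sketch: non-finite components become 0, and any non-zero vector is scaled
// to unit (L2) magnitude. A zero vector is returned unchanged so we never
// divide by zero.
function sanitizeAndNormalizeEmbeddingSketch(vector: number[]): number[] {
  const sanitized = vector.map((value) => (Number.isFinite(value) ? value : 0));
  const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
  return magnitude === 0 ? sanitized : sanitized.map((value) => value / magnitude);
}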
@@ -1,43 +1,9 @@
import fsSync from "node:fs";
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
import { formatErrorMessage } from "../../infra/errors.js";
import { normalizeOptionalString } from "../../shared/string-coerce.js";
import { resolveUserPath } from "../../utils.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import {
createBedrockEmbeddingProvider,
hasAwsCredentials,
type BedrockEmbeddingClient,
} from "./embeddings-bedrock.js";
import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./embeddings-gemini.js";
import {
createLmstudioEmbeddingProvider,
type LmstudioEmbeddingClient,
} from "./embeddings-lmstudio.js";
import {
createMistralEmbeddingProvider,
type MistralEmbeddingClient,
} from "./embeddings-mistral.js";
import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js";
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js";
import type {
EmbeddingProvider,
EmbeddingProviderFallback,
EmbeddingProviderId,
EmbeddingProviderOptions,
EmbeddingProviderRequest,
GeminiTaskType,
} from "./embeddings.types.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.types.js";
import { importNodeLlamaCpp } from "./node-llama.js";

export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
export type { LmstudioEmbeddingClient } from "./embeddings-lmstudio.js";
export type { MistralEmbeddingClient } from "./embeddings-mistral.js";
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
export type { VoyageEmbeddingClient } from "./embeddings-voyage.js";
export type { OllamaEmbeddingClient } from "./embeddings-ollama.js";
export type { BedrockEmbeddingClient } from "./embeddings-bedrock.js";
export type {
EmbeddingProvider,
EmbeddingProviderFallback,
@@ -47,51 +13,9 @@ export type {
GeminiTaskType,
} from "./embeddings.types.js";

// Remote providers considered for auto-selection when provider === "auto".
// LM Studio and Ollama are intentionally excluded here so that "auto" mode does not
// implicitly assume either instance is available.
// Bedrock is handled separately when AWS credentials are detected.
const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const;

export type EmbeddingProviderResult = {
provider: EmbeddingProvider | null;
requestedProvider: EmbeddingProviderRequest;
fallbackFrom?: EmbeddingProviderId;
fallbackReason?: string;
providerUnavailableReason?: string;
openAi?: OpenAiEmbeddingClient;
gemini?: GeminiEmbeddingClient;
voyage?: VoyageEmbeddingClient;
mistral?: MistralEmbeddingClient;
ollama?: OllamaEmbeddingClient;
bedrock?: BedrockEmbeddingClient;
lmstudio?: LmstudioEmbeddingClient;
};

export const DEFAULT_LOCAL_MODEL =
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";

function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
const modelPath = options.local?.modelPath?.trim();
if (!modelPath) {
return false;
}
if (/^(hf:|https?:)/i.test(modelPath)) {
return false;
}
const resolved = resolveUserPath(modelPath);
try {
return fsSync.statSync(resolved).isFile();
} catch {
return false;
}
}

function isMissingApiKeyError(err: unknown): boolean {
const message = formatErrorMessage(err);
return message.includes("No API key found for provider");
}

export async function createLocalEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProvider> {
@@ -154,186 +78,3 @@ export async function createLocalEmbeddingProvider(
},
};
}

export async function createEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
const requestedProvider = options.provider;
const fallback = options.fallback;

const createProvider = async (id: EmbeddingProviderId) => {
if (id === "local") {
const provider = await createLocalEmbeddingProvider(options);
return { provider };
}
if (id === "lmstudio") {
const { provider, client } = await createLmstudioEmbeddingProvider(options);
return { provider, lmstudio: client };
}
if (id === "ollama") {
const { provider, client } = await createOllamaEmbeddingProvider(options);
return { provider, ollama: client };
}
if (id === "gemini") {
const { provider, client } = await createGeminiEmbeddingProvider(options);
return { provider, gemini: client };
}
if (id === "voyage") {
const { provider, client } = await createVoyageEmbeddingProvider(options);
return { provider, voyage: client };
}
if (id === "mistral") {
const { provider, client } = await createMistralEmbeddingProvider(options);
return { provider, mistral: client };
}
if (id === "bedrock") {
const { provider, client } = await createBedrockEmbeddingProvider(options);
return { provider, bedrock: client };
}
const { provider, client } = await createOpenAiEmbeddingProvider(options);
return { provider, openAi: client };
};

const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) =>
provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err);

if (requestedProvider === "auto") {
const missingKeyErrors: string[] = [];
let localError: string | null = null;

if (canAutoSelectLocal(options)) {
try {
const local = await createProvider("local");
return { ...local, requestedProvider };
} catch (err) {
localError = formatLocalSetupError(err);
}
}

for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) {
try {
const result = await createProvider(provider);
return { ...result, requestedProvider };
} catch (err) {
const message = formatPrimaryError(err, provider);
if (isMissingApiKeyError(err)) {
missingKeyErrors.push(message);
continue;
}
// Non-auth errors (e.g., network) are still fatal
const wrapped = new Error(message) as Error & { cause?: unknown };
wrapped.cause = err;
throw wrapped;
}
}

// Try bedrock if AWS credentials are available
if (await hasAwsCredentials()) {
try {
const result = await createProvider("bedrock");
return { ...result, requestedProvider };
} catch (err) {
const message = formatPrimaryError(err, "bedrock");
if (isMissingApiKeyError(err)) {
missingKeyErrors.push(message);
} else {
const wrapped = new Error(message) as Error & { cause?: unknown };
wrapped.cause = err;
throw wrapped;
}
}
}

// All providers failed due to missing API keys - return null provider for FTS-only mode
const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
const reason = details.length > 0 ? details.join("\n\n") : "No embeddings provider available.";
return {
provider: null,
requestedProvider,
providerUnavailableReason: reason,
};
}

try {
const primary = await createProvider(requestedProvider);
return { ...primary, requestedProvider };
} catch (primaryErr) {
const reason = formatPrimaryError(primaryErr, requestedProvider);
if (fallback && fallback !== "none" && fallback !== requestedProvider) {
try {
const fallbackResult = await createProvider(fallback);
return {
...fallbackResult,
requestedProvider,
fallbackFrom: requestedProvider,
fallbackReason: reason,
};
} catch (fallbackErr) {
// Both primary and fallback failed - check if it's auth-related
const fallbackReason = formatErrorMessage(fallbackErr);
const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) {
// Both failed due to missing API keys - return null for FTS-only mode
return {
provider: null,
requestedProvider,
fallbackFrom: requestedProvider,
fallbackReason: reason,
providerUnavailableReason: combinedReason,
};
}
// Non-auth errors are still fatal
const wrapped = new Error(combinedReason) as Error & { cause?: unknown };
wrapped.cause = fallbackErr;
throw wrapped;
}
}
// No fallback configured - check if we should degrade to FTS-only
if (isMissingApiKeyError(primaryErr)) {
return {
provider: null,
requestedProvider,
providerUnavailableReason: reason,
};
}
const wrapped = new Error(reason) as Error & { cause?: unknown };
wrapped.cause = primaryErr;
throw wrapped;
}
}

function isNodeLlamaCppMissing(err: unknown): boolean {
if (!(err instanceof Error)) {
return false;
}
const code = (err as Error & { code?: unknown }).code;
if (code === "ERR_MODULE_NOT_FOUND") {
return err.message.includes("node-llama-cpp");
}
return false;
}

function formatLocalSetupError(err: unknown): string {
const detail = formatErrorMessage(err);
const missing = isNodeLlamaCppMissing(err);
return [
"Local embeddings unavailable.",
missing
? "Reason: optional dependency node-llama-cpp is missing (or failed to install)."
: detail
? `Reason: ${detail}`
: undefined,
missing && detail ? `Detail: ${detail}` : null,
"To enable local embeddings:",
"1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)",
missing
? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest"
: null,
"3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp",
...REMOTE_EMBEDDING_PROVIDER_IDS.map(
(provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`,
),
]
.filter(Boolean)
.join("\n");
}
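Putting the selection rules together: auto mode walks local (only when an on-disk model path is configured), then the remote list, then Bedrock, and degrades to a null provider (FTS-only search) when every failure is a missing API key. A hedged usage sketch; the cfg value is a placeholder, as in the tests:

// Sketch: how a caller is expected to consume the result. Non-auth errors
// throw; missing-key failures surface as provider === null with a reason.
const result = await createEmbeddingProvider({
  config: cfg as never, // placeholder config
  provider: "auto",
  model: "",
  fallback: "none",
});

if (result.provider) {
  const vector = await result.provider.embedQuery("hello");
  console.log(result.provider.id, vector.length);
} else {
  // FTS-only mode: keyword search still works, vector search is disabled.
  console.warn(result.providerUnavailableReason);
}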
@@ -11,18 +11,9 @@ export type EmbeddingProvider = {
embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise<number[][]>;
};

export type EmbeddingProviderId =
| "openai"
| "local"
| "gemini"
| "voyage"
| "mistral"
| "lmstudio"
| "ollama"
| "bedrock";

export type EmbeddingProviderRequest = EmbeddingProviderId | "auto";
export type EmbeddingProviderFallback = EmbeddingProviderId | "none";
export type EmbeddingProviderId = string;
export type EmbeddingProviderRequest = string;
export type EmbeddingProviderFallback = string;

export type GeminiTaskType =
| "RETRIEVAL_QUERY"
@@ -36,14 +27,14 @@ export type GeminiTaskType =
export type EmbeddingProviderOptions = {
config: OpenClawConfig;
agentDir?: string;
provider: EmbeddingProviderRequest;
provider?: EmbeddingProviderRequest;
remote?: {
baseUrl?: string;
apiKey?: SecretInput;
headers?: Record<string, string>;
};
model: string;
fallback: EmbeddingProviderFallback;
fallback?: EmbeddingProviderFallback;
local?: {
modelPath?: string;
modelCacheDir?: string;
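With provider and fallback now optional, an options literal can be as small as the config plus a model. A hedged example of the widened shape (cfg is a placeholder value):

// Sketch: both literals type-check against the loosened options. The second
// one relies on provider/fallback being optional after this change.
const explicit: EmbeddingProviderOptions = {
  config: cfg as never, // placeholder config
  provider: "voyage",
  model: "voyage-4-large",
  fallback: "none",
  remote: { baseUrl: "https://example.com", apiKey: "remote-override-key" },
};

const minimal: EmbeddingProviderOptions = {
  config: cfg as never,
  model: "text-embedding-3-small",
};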
@@ -106,21 +106,3 @@ export function classifyMemoryMultimodalPath(
}
return null;
}

export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
const trimmed = model.trim();
if (!trimmed) {
return "";
}
return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, "");
}

export function supportsMemoryMultimodalEmbeddings(params: {
provider: string;
model: string;
}): boolean {
if (params.provider !== "gemini") {
return false;
}
return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview";
}
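The prefix stripping means the Gemini check accepts several spellings of the same model id. For example:

// All three Gemini spellings normalize to "gemini-embedding-2-preview" and
// return true; any non-Gemini provider short-circuits to false.
supportsMemoryMultimodalEmbeddings({ provider: "gemini", model: "gemini-embedding-2-preview" }); // true
supportsMemoryMultimodalEmbeddings({ provider: "gemini", model: "models/gemini-embedding-2-preview" }); // true
supportsMemoryMultimodalEmbeddings({ provider: "gemini", model: "google/gemini-embedding-2-preview" }); // true
supportsMemoryMultimodalEmbeddings({ provider: "openai", model: "gemini-embedding-2-preview" }); // false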
@@ -85,11 +85,7 @@ export interface MemorySearchManager {
onDebug?: (debug: MemorySearchRuntimeDebug) => void;
},
): Promise<MemorySearchResult[]>;
readFile(params: {
relPath: string;
from?: number;
lines?: number;
}): Promise<MemoryReadResult>;
readFile(params: { relPath: string; from?: number; lines?: number }): Promise<MemoryReadResult>;
status(): MemoryProviderStatus;
sync?(params?: {
reason?: string;

@@ -1,6 +1,5 @@
export {
isMemoryMultimodalEnabled,
normalizeMemoryMultimodalSettings,
supportsMemoryMultimodalEmbeddings,
type MemoryMultimodalSettings,
} from "./host/multimodal.js";

@@ -5,6 +5,10 @@ import path from "node:path";
import { fileURLToPath, pathToFileURL } from "node:url";

export { resolveEnvApiKey } from "../agents/model-auth-env.js";
export {
collectProviderApiKeysForExecution,
executeWithApiKeyRotation,
} from "../agents/api-key-rotation.js";
export { NON_ENV_SECRETREF_MARKER } from "../agents/model-auth-markers.js";
export {
requireApiKey,

@@ -84,13 +84,29 @@ export function resolvePluginCapabilityProviders<K extends CapabilityProviderReg
}): CapabilityProviderForKey<K>[] {
const activeRegistry = resolveRuntimePluginRegistry();
const activeProviders = activeRegistry?.[params.key] ?? [];
if (activeProviders.length > 0) {
if (activeProviders.length > 0 && params.key !== "memoryEmbeddingProviders") {
return activeProviders.map((entry) => entry.provider) as CapabilityProviderForKey<K>[];
}
const compatConfig = resolveCapabilityProviderConfig({ key: params.key, cfg: params.cfg });
const loadOptions = compatConfig === undefined ? undefined : { config: compatConfig };
const registry = resolveRuntimePluginRegistry(loadOptions);
return (registry?.[params.key] ?? []).map(
(entry) => entry.provider,
) as CapabilityProviderForKey<K>[];
if (params.key !== "memoryEmbeddingProviders") {
return (registry?.[params.key] ?? []).map(
(entry) => entry.provider,
) as CapabilityProviderForKey<K>[];
}
const merged = new Map<string, CapabilityProviderForKey<K>>();
for (const entry of activeProviders) {
const provider = entry.provider as CapabilityProviderForKey<K> & { id?: string };
if (provider.id) {
merged.set(provider.id, provider);
}
}
for (const entry of registry?.[params.key] ?? []) {
const provider = entry.provider as CapabilityProviderForKey<K> & { id?: string };
if (provider.id && !merged.has(provider.id)) {
merged.set(provider.id, provider);
}
}
return [...merged.values()];
}
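For memoryEmbeddingProviders the function now merges both sources instead of letting an active registry shadow the declared capabilities; runtime-registered adapters win on id collisions because they are inserted first. A condensed sketch of that precedence:

// Sketch of the id-keyed merge: first writer wins, so runtime-registered
// adapters take precedence over capability-declared ones with the same id.
function mergeById<T extends { id?: string }>(registered: T[], declared: T[]): T[] {
  const merged = new Map<string, T>();
  for (const provider of [...registered, ...declared]) {
    if (provider.id && !merged.has(provider.id)) {
      merged.set(provider.id, provider);
    }
  }
  return [...merged.values()];
}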
@@ -28,7 +28,14 @@ const packageManifestContractTests: PackageManifestContractParams[] = [
},
{ pluginId: "irc", minHostVersionBaseline: "2026.3.22" },
{ pluginId: "line", minHostVersionBaseline: "2026.3.22" },
{ pluginId: "amazon-bedrock", mirroredRootRuntimeDeps: ["@aws-sdk/client-bedrock"] },
{
pluginId: "amazon-bedrock",
mirroredRootRuntimeDeps: [
"@aws-sdk/client-bedrock",
"@aws-sdk/client-bedrock-runtime",
"@aws-sdk/credential-provider-node",
],
},
{
pluginId: "amazon-bedrock-mantle",
mirroredRootRuntimeDeps: ["@aws/bedrock-token-generator"],

@@ -36,7 +36,7 @@ afterEach(() => {
});

describe("memory embedding provider runtime resolution", () => {
it("prefers registered adapters over capability fallback adapters", () => {
it("merges registered and declared capability fallback adapters", () => {
registerMemoryEmbeddingProvider({
id: "registered",
create: async () => ({ provider: null }),
@@ -45,9 +45,10 @@ describe("memory embedding provider runtime resolution", () => {

expect(runtimeModule.listMemoryEmbeddingProviders().map((adapter) => adapter.id)).toEqual([
"registered",
"capability",
]);
expect(runtimeModule.getMemoryEmbeddingProvider("registered")?.id).toBe("registered");
expect(mocks.resolvePluginCapabilityProviders).not.toHaveBeenCalled();
expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(1);
});

it("falls back to declared capability adapters when the registry is cold", () => {
@@ -60,14 +61,22 @@ ...
expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(2);
});

it("does not consult capability fallback once runtime adapters are registered", () => {
registerMemoryEmbeddingProvider({
it("prefers registered adapters over declared capability fallback adapters with the same id", () => {
const registered = {
id: "openai",
create: async () => ({ provider: null }),
} satisfies MemoryEmbeddingProviderAdapter;
registerMemoryEmbeddingProvider({
...registered,
});
mocks.resolvePluginCapabilityProviders.mockReturnValue([createCapabilityAdapter("ollama")]);
mocks.resolvePluginCapabilityProviders.mockReturnValue([createCapabilityAdapter("openai")]);

expect(runtimeModule.getMemoryEmbeddingProvider("ollama")).toBeUndefined();
expect(mocks.resolvePluginCapabilityProviders).not.toHaveBeenCalled();
expect(runtimeModule.getMemoryEmbeddingProvider("openai")).toEqual(
expect.objectContaining({ id: "openai" }),
);
expect(runtimeModule.listMemoryEmbeddingProviders().map((adapter) => adapter.id)).toEqual([
"openai",
]);
expect(mocks.resolvePluginCapabilityProviders).toHaveBeenCalledTimes(1);
});
});

@@ -15,13 +15,16 @@ export function listMemoryEmbeddingProviders(
cfg?: OpenClawConfig,
): MemoryEmbeddingProviderAdapter[] {
const registered = listRegisteredMemoryEmbeddingProviderAdapters();
if (registered.length > 0) {
return registered;
}
return resolvePluginCapabilityProviders({
const merged = new Map(registered.map((adapter) => [adapter.id, adapter]));
for (const adapter of resolvePluginCapabilityProviders({
key: "memoryEmbeddingProviders",
cfg,
});
})) {
if (!merged.has(adapter.id)) {
merged.set(adapter.id, adapter);
}
}
return [...merged.values()];
}

export function getMemoryEmbeddingProvider(
@@ -32,8 +35,5 @@ export function getMemoryEmbeddingProvider(
if (registered) {
return registered.adapter;
}
if (listRegisteredMemoryEmbeddingProviders().length > 0) {
return undefined;
}
return listMemoryEmbeddingProviders(cfg).find((adapter) => adapter.id === id);
}
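From a plugin's perspective the flow is: register an adapter at runtime, after which lookups resolve it by id while capability-declared adapters are still listed alongside. A hedged sketch of that round trip (the adapter body is illustrative only, following the stubs used in the tests above):

// Sketch: a plugin registers a minimal adapter; lookups by id then resolve it
// ahead of any capability-declared adapter with the same id.
registerMemoryEmbeddingProvider({
  id: "openai",
  create: async () => ({ provider: null }), // illustrative stub, as in the tests
});

const adapter = getMemoryEmbeddingProvider("openai");
const ids = listMemoryEmbeddingProviders().map((entry) => entry.id);
console.log(adapter?.id, ids);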
@@ -35,6 +35,8 @@ export type MemoryEmbeddingProvider = {
export type MemoryEmbeddingProviderCreateOptions = {
config: OpenClawConfig;
agentDir?: string;
provider?: string;
fallback?: string;
remote?: {
baseUrl?: string;
apiKey?: SecretInput;
@@ -46,6 +48,14 @@ export type MemoryEmbeddingProviderCreateOptions = {
modelCacheDir?: string;
};
outputDimensionality?: number;
taskType?:
| "RETRIEVAL_QUERY"
| "RETRIEVAL_DOCUMENT"
| "SEMANTIC_SIMILARITY"
| "CLASSIFICATION"
| "CLUSTERING"
| "QUESTION_ANSWERING"
| "FACT_VERIFICATION";
};

export type MemoryEmbeddingProviderCreateResult = {
@@ -57,6 +67,3 @@ export type MemoryEmbeddingProviderAdapter = {
id: string;
defaultModel?: string;
transport?: "local" | "remote";
authProviderId?: string;
autoSelectPriority?: number;
allowExplicitWhenConfiguredAuto?: boolean;
supportsMultimodalEmbeddings?: (params: { model: string }) => boolean;
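Taken together, the adapter surface a remote provider plugin implements looks roughly like this; field values are illustrative assumptions, not taken from the diff:

// Sketch of a concrete adapter satisfying MemoryEmbeddingProviderAdapter.
// The create() result shape follows the { provider } stubs in the tests above;
// the transport and priority values here are assumptions for illustration.
const voyageAdapter: MemoryEmbeddingProviderAdapter = {
  id: "voyage",
  defaultModel: "voyage-4-large",
  transport: "remote",
  authProviderId: "voyage",
  autoSelectPriority: 30, // assumed ordering value
  create: async (options: MemoryEmbeddingProviderCreateOptions) => {
    const { provider } = await createVoyageEmbeddingProvider(options as never);
    return { provider };
  },
};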