diff --git a/src/memory/batch-gemini.test.ts b/src/memory/batch-gemini.test.ts index 67d90a5a78b..57bc71291b9 100644 --- a/src/memory/batch-gemini.test.ts +++ b/src/memory/batch-gemini.test.ts @@ -75,9 +75,11 @@ describe("runGeminiEmbeddingBatches", () => { requests: [ { custom_id: "req-1", - content: { parts: [{ text: "hello world" }] }, - taskType: "RETRIEVAL_DOCUMENT", - outputDimensionality: 1536, + request: { + content: { parts: [{ text: "hello world" }] }, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: 1536, + }, }, ], wait: true, diff --git a/src/memory/batch-gemini.ts b/src/memory/batch-gemini.ts index 111570a998c..3afb5121ff7 100644 --- a/src/memory/batch-gemini.ts +++ b/src/memory/batch-gemini.ts @@ -5,15 +5,13 @@ import { } from "./batch-runner.js"; import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js"; import { debugEmbeddingsLog } from "./embeddings-debug.js"; -import type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; +import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js"; import { hashText } from "./internal.js"; import { withRemoteHttpResponse } from "./remote-http.js"; export type GeminiBatchRequest = { custom_id: string; - content: { parts: Array<{ text: string }> }; - taskType: "RETRIEVAL_DOCUMENT" | "RETRIEVAL_QUERY"; - outputDimensionality?: number; + request: GeminiTextEmbeddingRequest; }; export type GeminiBatchStatus = { @@ -83,13 +81,7 @@ async function submitGeminiBatch(params: { .map((request) => JSON.stringify({ key: request.custom_id, - request: { - content: request.content, - taskType: request.taskType, - ...(typeof request.outputDimensionality === "number" - ? { outputDimensionality: request.outputDimensionality } - : {}), - }, + request: request.request, }), ) .join("\n"); diff --git a/src/memory/embeddings-gemini.test.ts b/src/memory/embeddings-gemini.test.ts index 451ebd795da..36cb6bfd111 100644 --- a/src/memory/embeddings-gemini.test.ts +++ b/src/memory/embeddings-gemini.test.ts @@ -3,6 +3,7 @@ import * as authModule from "../agents/model-auth.js"; import { buildFileDataPart, buildGeminiParts, + buildGeminiTextEmbeddingRequest, buildInlineDataPart, createGeminiEmbeddingProvider, DEFAULT_GEMINI_EMBEDDING_MODEL, @@ -90,6 +91,24 @@ describe("buildFileDataPart", () => { }); }); +describe("buildGeminiTextEmbeddingRequest", () => { + it("builds a text embedding request with optional model and dimensions", () => { + expect( + buildGeminiTextEmbeddingRequest({ + text: "hello", + taskType: "RETRIEVAL_DOCUMENT", + modelPath: "models/gemini-embedding-2-preview", + outputDimensionality: 1536, + }), + ).toEqual({ + model: "models/gemini-embedding-2-preview", + content: { parts: [{ text: "hello" }] }, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: 1536, + }); + }); +}); + // ---------- Model detection ---------- describe("isGeminiEmbedding2Model", () => { @@ -255,6 +274,28 @@ describe("gemini-embedding-2-preview provider", () => { expect(body.outputDimensionality).toBe(768); }); + it("uses custom outputDimensionality for each embedBatch request", async () => { + const fetchMock = createGeminiBatchFetchMock(2); + vi.stubGlobal("fetch", fetchMock); + mockResolvedProviderKey(); + + const { provider } = await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + model: "gemini-embedding-2-preview", + fallback: "none", + outputDimensionality: 768, + }); + + await provider.embedBatch(["text1", "text2"]); + + const body = parseFetchBody(fetchMock); + expect(body.requests).toEqual([ + expect.objectContaining({ outputDimensionality: 768 }), + expect.objectContaining({ outputDimensionality: 768 }), + ]); + }); + it("throws for invalid outputDimensionality", async () => { mockResolvedProviderKey(); diff --git a/src/memory/embeddings-gemini.ts b/src/memory/embeddings-gemini.ts index 3aa13d798f2..f8c3d3f4a06 100644 --- a/src/memory/embeddings-gemini.ts +++ b/src/memory/embeddings-gemini.ts @@ -53,6 +53,12 @@ export type GeminiFilePart = { fileData: { mimeType: string; fileUri: string }; }; export type GeminiPart = GeminiTextPart | GeminiInlinePart | GeminiFilePart; +export type GeminiTextEmbeddingRequest = { + content: { parts: GeminiTextPart[] }; + taskType: GeminiTaskType; + outputDimensionality?: number; + model?: string; +}; /** Convert a string or pre-built parts array into `GeminiPart[]`. */ export function buildGeminiParts(input: string | GeminiPart[]): GeminiPart[] { @@ -72,6 +78,26 @@ export function buildFileDataPart(mimeType: string, fileUri: string): GeminiFile return { fileData: { mimeType, fileUri } }; } +/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */ +export function buildGeminiTextEmbeddingRequest(params: { + text: string; + taskType: GeminiTaskType; + outputDimensionality?: number; + modelPath?: string; +}): GeminiTextEmbeddingRequest { + const request: GeminiTextEmbeddingRequest = { + content: { parts: [{ text: params.text }] }, + taskType: params.taskType, + }; + if (params.modelPath) { + request.model = params.modelPath; + } + if (params.outputDimensionality != null) { + request.outputDimensionality = params.outputDimensionality; + } + return request; +} + /** * Returns true if the given model name is a gemini-embedding-2 variant that * supports `outputDimensionality` and extended task types. @@ -186,13 +212,11 @@ export async function createGeminiEmbeddingProvider( if (!text.trim()) { return []; } - const body: Record = { - content: { parts: [{ text }] }, + const body = buildGeminiTextEmbeddingRequest({ + text, taskType: options.taskType ?? "RETRIEVAL_QUERY", - }; - if (isV2 && outputDimensionality != null) { - body.outputDimensionality = outputDimensionality; - } + outputDimensionality: isV2 ? outputDimensionality : undefined, + }); const payload = await executeWithApiKeyRotation({ provider: "google", apiKeys: client.apiKeys, @@ -205,18 +229,15 @@ export async function createGeminiEmbeddingProvider( if (texts.length === 0) { return []; } - const requests = texts.map((text) => { - const req: Record = { - model: client.modelPath, - content: { parts: [{ text }] }, + const requests = texts.map((text) => + buildGeminiTextEmbeddingRequest({ + text, + modelPath: client.modelPath, taskType: options.taskType ?? "RETRIEVAL_DOCUMENT", - }; - if (isV2 && outputDimensionality != null) { - req.outputDimensionality = outputDimensionality; - } - return req; - }); - const batchBody: Record = { requests }; + outputDimensionality: isV2 ? outputDimensionality : undefined, + }), + ); + const batchBody = { requests }; const payload = await executeWithApiKeyRotation({ provider: "google", apiKeys: client.apiKeys, diff --git a/src/memory/manager-embedding-ops.ts b/src/memory/manager-embedding-ops.ts index 97a26dcc315..bcc653fda7a 100644 --- a/src/memory/manager-embedding-ops.ts +++ b/src/memory/manager-embedding-ops.ts @@ -9,6 +9,7 @@ import { import { type VoyageBatchRequest, runVoyageEmbeddingBatches } from "./batch-voyage.js"; import { enforceEmbeddingMaxInputTokens } from "./embedding-chunk-limits.js"; import { estimateUtf8Bytes } from "./embedding-input-limits.js"; +import { buildGeminiTextEmbeddingRequest } from "./embeddings-gemini.js"; import { chunkMarkdown, hashText, @@ -482,9 +483,11 @@ export abstract class MemoryManagerEmbeddingOps extends MemoryManagerSyncOps { provider: "gemini", enabled: Boolean(gemini), buildRequest: (chunk) => ({ - content: { parts: [{ text: chunk.text }] }, - taskType: "RETRIEVAL_DOCUMENT", - outputDimensionality: this.gemini?.outputDimensionality, + request: buildGeminiTextEmbeddingRequest({ + text: chunk.text, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: this.gemini?.outputDimensionality, + }), }), runBatch: async (runnerOptions) => await runGeminiEmbeddingBatches({