diff --git a/src/memory-host-sdk/host/embeddings-gemini.test.ts b/src/memory-host-sdk/host/embeddings-gemini.test.ts index a1f4ef028ef..9ae42863988 100644 --- a/src/memory-host-sdk/host/embeddings-gemini.test.ts +++ b/src/memory-host-sdk/host/embeddings-gemini.test.ts @@ -1,61 +1,21 @@ import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; import * as authModule from "../../agents/model-auth.js"; +import { + createGeminiBatchFetchMock, + createGeminiFetchMock, + installFetchMock, + mockResolvedProviderKey, + parseFetchBody, + readFirstFetchRequest, + type JsonFetchMock, +} from "./embeddings-provider.test-support.js"; import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; -vi.mock("../../infra/net/fetch-guard.js", () => ({ - fetchWithSsrFGuard: async (params: { - url: string; - init?: RequestInit; - fetchImpl?: typeof fetch; - }) => { - const fetchImpl = params.fetchImpl ?? globalThis.fetch; - if (!fetchImpl) { - throw new Error("fetch is not available"); - } - const response = await fetchImpl(params.url, params.init); - return { - response, - finalUrl: params.url, - release: async () => {}, - }; - }, -})); - vi.mock("../../agents/model-auth.js", async () => { const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js"); return createModelAuthMockModule(); }); -const createGeminiFetchMock = (embeddingValues = [1, 2, 3]) => - vi.fn(async (_input?: unknown, _init?: unknown) => ({ - ok: true, - status: 200, - json: async () => ({ embedding: { values: embeddingValues } }), - })); - -const createGeminiBatchFetchMock = (count: number, embeddingValues = [1, 2, 3]) => - vi.fn(async (_input?: unknown, _init?: unknown) => ({ - ok: true, - status: 200, - json: async () => ({ - embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })), - }), - })); - -function installFetchMock(fetchMock: typeof globalThis.fetch) { - vi.stubGlobal("fetch", fetchMock); -} - -function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) { - const [url, init] = fetchMock.mock.calls[0] ?? []; - return { url, init: init as RequestInit | undefined }; -} - -function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) { - const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined; - return JSON.parse((init?.body as string) ?? "{}") as Record; -} - function magnitude(values: number[]) { return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0)); } @@ -92,25 +52,13 @@ afterEach(() => { vi.unstubAllGlobals(); }); -function mockResolvedProviderKey(apiKey = "test-key") { - vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ - apiKey, - mode: "api-key", - source: "test", - }); -} - -type GeminiFetchMock = - | ReturnType - | ReturnType; - async function createProviderWithFetch( - fetchMock: GeminiFetchMock, + fetchMock: JsonFetchMock, options: Partial[0]> & { model: string }, ) { installFetchMock(fetchMock as unknown as typeof globalThis.fetch); mockPublicPinnedHostname(); - mockResolvedProviderKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider); const { provider } = await createGeminiEmbeddingProvider({ config: {} as never, provider: "gemini", @@ -429,7 +377,7 @@ describe("gemini-embedding-2-preview provider", () => { }); it("throws for invalid outputDimensionality", async () => { - mockResolvedProviderKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider); await expect( createGeminiEmbeddingProvider({ @@ -493,7 +441,7 @@ describe("gemini model normalization", () => { const fetchMock = createGeminiFetchMock(); installFetchMock(fetchMock as unknown as typeof globalThis.fetch); mockPublicPinnedHostname(); - mockResolvedProviderKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider); const { provider } = await createGeminiEmbeddingProvider({ config: {} as never, @@ -512,7 +460,7 @@ describe("gemini model normalization", () => { const fetchMock = createGeminiFetchMock(); installFetchMock(fetchMock as unknown as typeof globalThis.fetch); mockPublicPinnedHostname(); - mockResolvedProviderKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider); const { provider } = await createGeminiEmbeddingProvider({ config: {} as never, @@ -531,7 +479,7 @@ describe("gemini model normalization", () => { const fetchMock = createGeminiFetchMock(); installFetchMock(fetchMock as unknown as typeof globalThis.fetch); mockPublicPinnedHostname(); - mockResolvedProviderKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider); const { provider } = await createGeminiEmbeddingProvider({ config: {} as never, @@ -549,7 +497,7 @@ describe("gemini model normalization", () => { it("defaults to gemini-embedding-001 when model is empty", async () => { const fetchMock = createGeminiFetchMock(); installFetchMock(fetchMock as unknown as typeof globalThis.fetch); - mockResolvedProviderKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider); const { provider, client } = await createGeminiEmbeddingProvider({ config: {} as never, @@ -563,7 +511,7 @@ describe("gemini model normalization", () => { }); it("returns empty array for blank query text", async () => { - mockResolvedProviderKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider); const { provider } = await createGeminiEmbeddingProvider({ config: {} as never, @@ -577,7 +525,7 @@ describe("gemini model normalization", () => { }); it("returns empty array for empty batch", async () => { - mockResolvedProviderKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider); const { provider } = await createGeminiEmbeddingProvider({ config: {} as never, diff --git a/src/memory-host-sdk/host/embeddings-provider.test-support.ts b/src/memory-host-sdk/host/embeddings-provider.test-support.ts new file mode 100644 index 00000000000..e5809f750c1 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-provider.test-support.ts @@ -0,0 +1,75 @@ +import { vi } from "vitest"; +import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js"; + +vi.mock("../../infra/net/fetch-guard.js", () => ({ + fetchWithSsrFGuard: async (params: { + url: string; + init?: RequestInit; + fetchImpl?: typeof fetch; + }) => { + const fetchImpl = params.fetchImpl ?? globalThis.fetch; + if (!fetchImpl) { + throw new Error("fetch is not available"); + } + const response = await fetchImpl(params.url, params.init); + return { + response, + finalUrl: params.url, + release: async () => {}, + }; + }, +})); + +type FetchPayloadFactory = (input: RequestInfo | URL, init?: RequestInit) => unknown; + +export type JsonFetchMock = ReturnType; + +export function createJsonResponseFetchMock(payload: unknown | FetchPayloadFactory) { + const fetchMock = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => { + const body = typeof payload === "function" ? payload(input, init) : payload; + return new Response(JSON.stringify(body), { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + }); + return withFetchPreconnect(fetchMock); +} + +export function createEmbeddingDataFetchMock(embeddingValues = [0.1, 0.2, 0.3]) { + return createJsonResponseFetchMock({ data: [{ embedding: embeddingValues }] }); +} + +export function createGeminiFetchMock(embeddingValues = [1, 2, 3]) { + return createJsonResponseFetchMock({ embedding: { values: embeddingValues } }); +} + +export function createGeminiBatchFetchMock(count: number, embeddingValues = [1, 2, 3]) { + return createJsonResponseFetchMock({ + embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })), + }); +} + +export function installFetchMock(fetchMock: typeof globalThis.fetch) { + vi.stubGlobal("fetch", fetchMock); +} + +export function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) { + const [url, init] = fetchMock.mock.calls[0] ?? []; + return { url, init: init as RequestInit | undefined }; +} + +export function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) { + const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined; + return JSON.parse((init?.body as string) ?? "{}") as Record; +} + +export function mockResolvedProviderKey( + resolveApiKeyForProvider: typeof import("../../agents/model-auth.js").resolveApiKeyForProvider, + apiKey = "test-key", +) { + vi.mocked(resolveApiKeyForProvider).mockResolvedValue({ + apiKey, + mode: "api-key", + source: "test", + }); +} diff --git a/src/memory-host-sdk/host/embeddings-voyage.test.ts b/src/memory-host-sdk/host/embeddings-voyage.test.ts index c5834796299..6ca0d99fed2 100644 --- a/src/memory-host-sdk/host/embeddings-voyage.test.ts +++ b/src/memory-host-sdk/host/embeddings-voyage.test.ts @@ -1,47 +1,19 @@ import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; import * as authModule from "../../agents/model-auth.js"; -import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js"; +import { + createEmbeddingDataFetchMock, + createJsonResponseFetchMock, + installFetchMock, + mockResolvedProviderKey, + type JsonFetchMock, +} from "./embeddings-provider.test-support.js"; import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; -vi.mock("../../infra/net/fetch-guard.js", () => ({ - fetchWithSsrFGuard: async (params: { - url: string; - init?: RequestInit; - fetchImpl?: typeof fetch; - }) => { - const fetchImpl = params.fetchImpl ?? globalThis.fetch; - if (!fetchImpl) { - throw new Error("fetch is not available"); - } - const response = await fetchImpl(params.url, params.init); - return { - response, - finalUrl: params.url, - release: async () => {}, - }; - }, -})); - vi.mock("../../agents/model-auth.js", async () => { const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js"); return createModelAuthMockModule(); }); -const createFetchMock = () => { - const fetchMock = vi.fn( - async (_input: RequestInfo | URL, _init?: RequestInit) => - new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), { - status: 200, - headers: { "Content-Type": "application/json" }, - }), - ); - return withFetchPreconnect(fetchMock); -}; - -function installFetchMock(fetchMock: typeof globalThis.fetch) { - vi.stubGlobal("fetch", fetchMock); -} - let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider; let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel; @@ -55,21 +27,10 @@ beforeEach(() => { vi.doUnmock("undici"); }); -function mockVoyageApiKey() { - vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ - apiKey: "voyage-key-123", - mode: "api-key", - source: "test", - }); -} - -async function createDefaultVoyageProvider( - model: string, - fetchMock: ReturnType, -) { +async function createDefaultVoyageProvider(model: string, fetchMock: JsonFetchMock) { installFetchMock(fetchMock as unknown as typeof globalThis.fetch); mockPublicPinnedHostname(); - mockVoyageApiKey(); + mockResolvedProviderKey(authModule.resolveApiKeyForProvider, "voyage-key-123"); return createVoyageEmbeddingProvider({ config: {} as never, provider: "voyage", @@ -86,7 +47,7 @@ describe("voyage embedding provider", () => { }); it("configures client with correct defaults and headers", async () => { - const fetchMock = createFetchMock(); + const fetchMock = createEmbeddingDataFetchMock(); const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock); await result.provider.embedQuery("test query"); @@ -113,7 +74,7 @@ describe("voyage embedding provider", () => { }); it("respects remote overrides for baseUrl and apiKey", async () => { - const fetchMock = createFetchMock(); + const fetchMock = createEmbeddingDataFetchMock(); installFetchMock(fetchMock as unknown as typeof globalThis.fetch); mockPublicPinnedHostname(); @@ -142,17 +103,9 @@ describe("voyage embedding provider", () => { }); it("passes input_type=document for embedBatch", async () => { - const fetchMock = withFetchPreconnect( - vi.fn( - async (_input: RequestInfo | URL, _init?: RequestInit) => - new Response( - JSON.stringify({ - data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }], - }), - { status: 200, headers: { "Content-Type": "application/json" } }, - ), - ), - ); + const fetchMock = createJsonResponseFetchMock({ + data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }], + }); const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock); await result.provider.embedBatch(["doc1", "doc2"]); diff --git a/src/memory-host-sdk/host/embeddings.test.ts b/src/memory-host-sdk/host/embeddings.test.ts index 04cf29403c6..6d573e28ed1 100644 --- a/src/memory-host-sdk/host/embeddings.test.ts +++ b/src/memory-host-sdk/host/embeddings.test.ts @@ -2,6 +2,13 @@ import { setTimeout as sleep } from "node:timers/promises"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import * as authModule from "../../agents/model-auth.js"; import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js"; +import { + createEmbeddingDataFetchMock, + createGeminiFetchMock, + installFetchMock, + mockResolvedProviderKey as mockResolvedProviderKeyBase, + readFirstFetchRequest, +} from "./embeddings-provider.test-support.js"; import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js"; import * as nodeLlamaModule from "./node-llama.js"; import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; @@ -20,25 +27,6 @@ const { resolveCredentialsMock: vi.fn(), })); -vi.mock("../../infra/net/fetch-guard.js", () => ({ - fetchWithSsrFGuard: async (params: { - url: string; - init?: RequestInit; - fetchImpl?: typeof fetch; - }) => { - const fetchImpl = params.fetchImpl ?? globalThis.fetch; - if (!fetchImpl) { - throw new Error("fetch is not available"); - } - const response = await fetchImpl(params.url, params.init); - return { - response, - finalUrl: params.url, - release: async () => {}, - }; - }, -})); - vi.mock("./embeddings-ollama.js", () => ({ createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock, })); @@ -60,28 +48,7 @@ vi.mock("@aws-sdk/credential-provider-node", () => ({ defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock), })); -const createFetchMock = () => - vi.fn(async (_input?: unknown, _init?: unknown) => ({ - ok: true, - status: 200, - json: async () => ({ data: [{ embedding: [1, 2, 3] }] }), - })); - -const createGeminiFetchMock = () => - vi.fn(async (_input?: unknown, _init?: unknown) => ({ - ok: true, - status: 200, - json: async () => ({ embedding: { values: [1, 2, 3] } }), - })); - -function installFetchMock(fetchMock: typeof globalThis.fetch) { - vi.stubGlobal("fetch", fetchMock); -} - -function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) { - const [url, init] = fetchMock.mock.calls[0] ?? []; - return { url, init: init as RequestInit | undefined }; -} +const createFetchMock = () => createEmbeddingDataFetchMock([1, 2, 3]); type ResolvedProviderAuth = Awaited>; @@ -108,11 +75,7 @@ function requireProvider(result: Awaited { expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled(); const url = fetchMock.mock.calls[0]?.[0]; - const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined; + const init = fetchMock.mock.calls[0]?.[1]; expect(url).toBe("https://example.com/v1/embeddings"); const headers = (init?.headers ?? {}) as Record; expect(headers.Authorization).toBe("Bearer remote-key"); @@ -233,7 +196,7 @@ describe("embedding provider remote overrides", () => { await provider.embedQuery("hello"); expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1); - const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined; + const init = fetchMock.mock.calls[0]?.[1]; const headers = (init?.headers as Record) ?? {}; expect(headers.Authorization).toBe("Bearer provider-key"); });