refactor: dedupe embedding provider test fixtures

This commit is contained in:
Peter Steinberger
2026-04-08 15:25:03 +01:00
parent 27560b7b68
commit b358db1775
4 changed files with 118 additions and 179 deletions

View File

@@ -1,61 +1,21 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../agents/model-auth.js";
import {
createGeminiBatchFetchMock,
createGeminiFetchMock,
installFetchMock,
mockResolvedProviderKey,
parseFetchBody,
readFirstFetchRequest,
type JsonFetchMock,
} from "./embeddings-provider.test-support.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
vi.mock("../../infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));
vi.mock("../../agents/model-auth.js", async () => {
const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js");
return createModelAuthMockModule();
});
const createGeminiFetchMock = (embeddingValues = [1, 2, 3]) =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ embedding: { values: embeddingValues } }),
}));
const createGeminiBatchFetchMock = (count: number, embeddingValues = [1, 2, 3]) =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({
embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })),
}),
}));
function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
const [url, init] = fetchMock.mock.calls[0] ?? [];
return { url, init: init as RequestInit | undefined };
}
function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) {
const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined;
return JSON.parse((init?.body as string) ?? "{}") as Record<string, unknown>;
}
function magnitude(values: number[]) {
return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0));
}
@@ -92,25 +52,13 @@ afterEach(() => {
vi.unstubAllGlobals();
});
function mockResolvedProviderKey(apiKey = "test-key") {
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
apiKey,
mode: "api-key",
source: "test",
});
}
type GeminiFetchMock =
| ReturnType<typeof createGeminiFetchMock>
| ReturnType<typeof createGeminiBatchFetchMock>;
async function createProviderWithFetch(
fetchMock: GeminiFetchMock,
fetchMock: JsonFetchMock,
options: Partial<Parameters<typeof createGeminiEmbeddingProvider>[0]> & { model: string },
) {
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
const { provider } = await createGeminiEmbeddingProvider({
config: {} as never,
provider: "gemini",
@@ -429,7 +377,7 @@ describe("gemini-embedding-2-preview provider", () => {
});
it("throws for invalid outputDimensionality", async () => {
mockResolvedProviderKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
await expect(
createGeminiEmbeddingProvider({
@@ -493,7 +441,7 @@ describe("gemini model normalization", () => {
const fetchMock = createGeminiFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
const { provider } = await createGeminiEmbeddingProvider({
config: {} as never,
@@ -512,7 +460,7 @@ describe("gemini model normalization", () => {
const fetchMock = createGeminiFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
const { provider } = await createGeminiEmbeddingProvider({
config: {} as never,
@@ -531,7 +479,7 @@ describe("gemini model normalization", () => {
const fetchMock = createGeminiFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockResolvedProviderKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
const { provider } = await createGeminiEmbeddingProvider({
config: {} as never,
@@ -549,7 +497,7 @@ describe("gemini model normalization", () => {
it("defaults to gemini-embedding-001 when model is empty", async () => {
const fetchMock = createGeminiFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockResolvedProviderKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
const { provider, client } = await createGeminiEmbeddingProvider({
config: {} as never,
@@ -563,7 +511,7 @@ describe("gemini model normalization", () => {
});
it("returns empty array for blank query text", async () => {
mockResolvedProviderKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
const { provider } = await createGeminiEmbeddingProvider({
config: {} as never,
@@ -577,7 +525,7 @@ describe("gemini model normalization", () => {
});
it("returns empty array for empty batch", async () => {
mockResolvedProviderKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider);
const { provider } = await createGeminiEmbeddingProvider({
config: {} as never,

View File

@@ -0,0 +1,75 @@
import { vi } from "vitest";
import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js";
vi.mock("../../infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));
type FetchPayloadFactory = (input: RequestInfo | URL, init?: RequestInit) => unknown;
export type JsonFetchMock = ReturnType<typeof createJsonResponseFetchMock>;
export function createJsonResponseFetchMock(payload: unknown | FetchPayloadFactory) {
const fetchMock = vi.fn<FetchMock>(async (input: RequestInfo | URL, init?: RequestInit) => {
const body = typeof payload === "function" ? payload(input, init) : payload;
return new Response(JSON.stringify(body), {
status: 200,
headers: { "Content-Type": "application/json" },
});
});
return withFetchPreconnect(fetchMock);
}
export function createEmbeddingDataFetchMock(embeddingValues = [0.1, 0.2, 0.3]) {
return createJsonResponseFetchMock({ data: [{ embedding: embeddingValues }] });
}
export function createGeminiFetchMock(embeddingValues = [1, 2, 3]) {
return createJsonResponseFetchMock({ embedding: { values: embeddingValues } });
}
export function createGeminiBatchFetchMock(count: number, embeddingValues = [1, 2, 3]) {
return createJsonResponseFetchMock({
embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })),
});
}
export function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}
export function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
const [url, init] = fetchMock.mock.calls[0] ?? [];
return { url, init: init as RequestInit | undefined };
}
export function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) {
const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined;
return JSON.parse((init?.body as string) ?? "{}") as Record<string, unknown>;
}
export function mockResolvedProviderKey(
resolveApiKeyForProvider: typeof import("../../agents/model-auth.js").resolveApiKeyForProvider,
apiKey = "test-key",
) {
vi.mocked(resolveApiKeyForProvider).mockResolvedValue({
apiKey,
mode: "api-key",
source: "test",
});
}

View File

@@ -1,47 +1,19 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../agents/model-auth.js";
import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js";
import {
createEmbeddingDataFetchMock,
createJsonResponseFetchMock,
installFetchMock,
mockResolvedProviderKey,
type JsonFetchMock,
} from "./embeddings-provider.test-support.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
vi.mock("../../infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));
vi.mock("../../agents/model-auth.js", async () => {
const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js");
return createModelAuthMockModule();
});
const createFetchMock = () => {
const fetchMock = vi.fn<FetchMock>(
async (_input: RequestInfo | URL, _init?: RequestInit) =>
new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), {
status: 200,
headers: { "Content-Type": "application/json" },
}),
);
return withFetchPreconnect(fetchMock);
};
function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}
let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider;
let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel;
@@ -55,21 +27,10 @@ beforeEach(() => {
vi.doUnmock("undici");
});
function mockVoyageApiKey() {
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
apiKey: "voyage-key-123",
mode: "api-key",
source: "test",
});
}
async function createDefaultVoyageProvider(
model: string,
fetchMock: ReturnType<typeof createFetchMock>,
) {
async function createDefaultVoyageProvider(model: string, fetchMock: JsonFetchMock) {
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
mockVoyageApiKey();
mockResolvedProviderKey(authModule.resolveApiKeyForProvider, "voyage-key-123");
return createVoyageEmbeddingProvider({
config: {} as never,
provider: "voyage",
@@ -86,7 +47,7 @@ describe("voyage embedding provider", () => {
});
it("configures client with correct defaults and headers", async () => {
const fetchMock = createFetchMock();
const fetchMock = createEmbeddingDataFetchMock();
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
await result.provider.embedQuery("test query");
@@ -113,7 +74,7 @@ describe("voyage embedding provider", () => {
});
it("respects remote overrides for baseUrl and apiKey", async () => {
const fetchMock = createFetchMock();
const fetchMock = createEmbeddingDataFetchMock();
installFetchMock(fetchMock as unknown as typeof globalThis.fetch);
mockPublicPinnedHostname();
@@ -142,17 +103,9 @@ describe("voyage embedding provider", () => {
});
it("passes input_type=document for embedBatch", async () => {
const fetchMock = withFetchPreconnect(
vi.fn<FetchMock>(
async (_input: RequestInfo | URL, _init?: RequestInit) =>
new Response(
JSON.stringify({
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
}),
{ status: 200, headers: { "Content-Type": "application/json" } },
),
),
);
const fetchMock = createJsonResponseFetchMock({
data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }],
});
const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock);
await result.provider.embedBatch(["doc1", "doc2"]);

View File

@@ -2,6 +2,13 @@ import { setTimeout as sleep } from "node:timers/promises";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../agents/model-auth.js";
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js";
import {
createEmbeddingDataFetchMock,
createGeminiFetchMock,
installFetchMock,
mockResolvedProviderKey as mockResolvedProviderKeyBase,
readFirstFetchRequest,
} from "./embeddings-provider.test-support.js";
import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
import * as nodeLlamaModule from "./node-llama.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
@@ -20,25 +27,6 @@ const {
resolveCredentialsMock: vi.fn(),
}));
vi.mock("../../infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
init?: RequestInit;
fetchImpl?: typeof fetch;
}) => {
const fetchImpl = params.fetchImpl ?? globalThis.fetch;
if (!fetchImpl) {
throw new Error("fetch is not available");
}
const response = await fetchImpl(params.url, params.init);
return {
response,
finalUrl: params.url,
release: async () => {},
};
},
}));
vi.mock("./embeddings-ollama.js", () => ({
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
}));
@@ -60,28 +48,7 @@ vi.mock("@aws-sdk/credential-provider-node", () => ({
defaultProvider: defaultProviderMock.mockImplementation(() => resolveCredentialsMock),
}));
const createFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ data: [{ embedding: [1, 2, 3] }] }),
}));
const createGeminiFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
status: 200,
json: async () => ({ embedding: { values: [1, 2, 3] } }),
}));
function installFetchMock(fetchMock: typeof globalThis.fetch) {
vi.stubGlobal("fetch", fetchMock);
}
function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
const [url, init] = fetchMock.mock.calls[0] ?? [];
return { url, init: init as RequestInit | undefined };
}
const createFetchMock = () => createEmbeddingDataFetchMock([1, 2, 3]);
type ResolvedProviderAuth = Awaited<ReturnType<typeof authModule.resolveApiKeyForProvider>>;
@@ -108,11 +75,7 @@ function requireProvider(result: Awaited<ReturnType<typeof createEmbeddingProvid
}
function mockResolvedProviderKey(apiKey = "provider-key") {
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
apiKey,
mode: "api-key",
source: "test",
});
mockResolvedProviderKeyBase(authModule.resolveApiKeyForProvider, apiKey);
}
function mockMissingLocalEmbeddingDependency() {
@@ -192,7 +155,7 @@ describe("embedding provider remote overrides", () => {
expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled();
const url = fetchMock.mock.calls[0]?.[0];
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
const init = fetchMock.mock.calls[0]?.[1];
expect(url).toBe("https://example.com/v1/embeddings");
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers.Authorization).toBe("Bearer remote-key");
@@ -233,7 +196,7 @@ describe("embedding provider remote overrides", () => {
await provider.embedQuery("hello");
expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1);
const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined;
const init = fetchMock.mock.calls[0]?.[1];
const headers = (init?.headers as Record<string, string>) ?? {};
expect(headers.Authorization).toBe("Bearer provider-key");
});