test: trim embeddings provider import cost

This commit is contained in:
Shakker
2026-04-02 10:09:44 +01:00
committed by Peter Steinberger
parent 0af1d0ddb2
commit e9e7033ea1

View File

@@ -1,8 +1,17 @@
import { setTimeout as sleep } from "node:timers/promises";
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import * as authModule from "../../../../src/agents/model-auth.js";
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js";
import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js";
import * as nodeLlamaModule from "./node-llama.js";
import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js";
const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({
createOllamaEmbeddingProviderMock: vi.fn(async () => {
throw new Error("Unexpected ollama provider in embeddings.test.ts");
}),
}));
vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: async (params: {
url: string;
@@ -22,6 +31,10 @@ vi.mock("../../../../src/infra/net/fetch-guard.js", () => ({
},
}));
vi.mock("./embeddings-ollama.js", () => ({
createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock,
}));
const createFetchMock = () =>
vi.fn(async (_input?: unknown, _init?: unknown) => ({
ok: true,
@@ -45,22 +58,11 @@ function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) {
return { url, init: init as RequestInit | undefined };
}
type EmbeddingsModule = typeof import("./embeddings.js");
type AuthModule = typeof import("../../../../src/agents/model-auth.js");
type ResolvedProviderAuth = Awaited<ReturnType<AuthModule["resolveApiKeyForProvider"]>>;
type ResolvedProviderAuth = Awaited<ReturnType<typeof authModule.resolveApiKeyForProvider>>;
let authModule: AuthModule;
let nodeLlamaModule: typeof import("./node-llama.js");
let createEmbeddingProvider: EmbeddingsModule["createEmbeddingProvider"];
let DEFAULT_LOCAL_MODEL: EmbeddingsModule["DEFAULT_LOCAL_MODEL"];
beforeAll(async () => {
vi.resetModules();
authModule = await import("../../../../src/agents/model-auth.js");
nodeLlamaModule = await import("./node-llama.js");
beforeEach(() => {
vi.spyOn(authModule, "resolveApiKeyForProvider");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp");
({ createEmbeddingProvider, DEFAULT_LOCAL_MODEL } = await import("./embeddings.js"));
});
beforeEach(() => {
@@ -72,7 +74,7 @@ afterEach(() => {
vi.unstubAllGlobals();
});
function requireProvider(result: Awaited<ReturnType<EmbeddingsModule["createEmbeddingProvider"]>>) {
function requireProvider(result: Awaited<ReturnType<typeof createEmbeddingProvider>>) {
if (!result.provider) {
throw new Error("Expected embedding provider");
}
@@ -105,7 +107,7 @@ function createLocalProvider(options?: { fallback?: "none" | "openai" }) {
}
function expectAutoSelectedProvider(
result: Awaited<ReturnType<EmbeddingsModule["createEmbeddingProvider"]>>,
result: Awaited<ReturnType<typeof createEmbeddingProvider>>,
expectedId: "openai" | "gemini" | "mistral",
) {
expect(result.requestedProvider).toBe("auto");
@@ -574,11 +576,6 @@ describe("local embedding normalization", () => {
});
describe("local embedding ensureContext concurrency", () => {
beforeEach(() => {
vi.resetModules();
vi.doUnmock("./node-llama.js");
});
async function setupLocalProviderWithMockedInit(params?: {
initializationDelayMs?: number;
failFirstGetLlama?: boolean;
@@ -588,7 +585,6 @@ describe("local embedding ensureContext concurrency", () => {
const createContextSpy = vi.fn();
let shouldFail = params?.failFirstGetLlama ?? false;
const nodeLlamaModule = await import("./node-llama.js");
vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({
getLlama: async (...args: unknown[]) => {
getLlamaSpy(...args);
@@ -622,7 +618,6 @@ describe("local embedding ensureContext concurrency", () => {
LlamaLogLevel: { error: 0 },
} as never);
const { createEmbeddingProvider } = await import("./embeddings.js");
const result = await createEmbeddingProvider({
config: {} as never,
provider: "local",
@@ -704,12 +699,6 @@ describe("local embedding ensureContext concurrency", () => {
});
describe("FTS-only fallback when no provider available", () => {
beforeEach(async () => {
authModule = await import("../../../../src/agents/model-auth.js");
({ createEmbeddingProvider, DEFAULT_LOCAL_MODEL } = await import("./embeddings.js"));
vi.spyOn(authModule, "resolveApiKeyForProvider");
});
it("returns null provider when all requested auth paths fail", async () => {
vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue(
new Error("No API key found for provider"),