diff --git a/extensions/google/index.test.ts b/extensions/google/index.test.ts index 2f0024a81e4..66d14fe1fd0 100644 --- a/extensions/google/index.test.ts +++ b/extensions/google/index.test.ts @@ -1,4 +1,3 @@ -import type { StreamFn } from "@mariozechner/pi-agent-core"; import type { Context, Model } from "@mariozechner/pi-ai"; import type { ProviderReplaySessionEntry, @@ -9,6 +8,7 @@ import { registerProviderPlugin, requireRegisteredProvider, } from "../../test/helpers/plugins/provider-registration.js"; +import { createCapturedThinkingConfigStream } from "../../test/helpers/plugins/stream-hooks.js"; import { registerGoogleGeminiCliProvider } from "./gemini-cli-provider.js"; import { registerGoogleProvider } from "./provider-registration.js"; @@ -146,24 +146,14 @@ describe("google provider plugin hooks", () => { }); const googleProvider = requireRegisteredProvider(providers, "google"); const cliProvider = requireRegisteredProvider(providers, "google-gemini-cli"); - let capturedPayload: Record | undefined; - - const baseStreamFn: StreamFn = (model, _context, options) => { - const payload = { config: { thinkingConfig: { thinkingBudget: -1 } } } as Record< - string, - unknown - >; - options?.onPayload?.(payload as never, model as never); - capturedPayload = payload; - return {} as never; - }; + const capturedStream = createCapturedThinkingConfigStream(); const runCase = (provider: typeof googleProvider, providerId: string) => { const wrapped = provider.wrapStreamFn?.({ provider: providerId, modelId: "gemini-3.1-pro-preview", thinkingLevel: "high", - streamFn: baseStreamFn, + streamFn: capturedStream.streamFn, } as never); void wrapped?.( @@ -176,6 +166,7 @@ describe("google provider plugin hooks", () => { {}, ); + const capturedPayload = capturedStream.getCapturedPayload(); expect(capturedPayload).toMatchObject({ config: { thinkingConfig: { thinkingLevel: "HIGH" } }, }); diff --git a/extensions/moonshot/index.test.ts b/extensions/moonshot/index.test.ts index 21aaebc8768..167eb5d82fd 100644 --- a/extensions/moonshot/index.test.ts +++ b/extensions/moonshot/index.test.ts @@ -1,7 +1,7 @@ -import type { StreamFn } from "@mariozechner/pi-agent-core"; import type { Context, Model } from "@mariozechner/pi-ai"; import { describe, expect, it } from "vitest"; import { registerSingleProviderPlugin } from "../../test/helpers/plugins/plugin-registration.js"; +import { createCapturedThinkingConfigStream } from "../../test/helpers/plugins/stream-hooks.js"; import plugin from "./index.js"; describe("moonshot provider plugin", () => { @@ -25,22 +25,13 @@ describe("moonshot provider plugin", () => { it("wires moonshot-thinking stream hooks", async () => { const provider = await registerSingleProviderPlugin(plugin); - let capturedPayload: Record | undefined; - const baseStreamFn: StreamFn = (model, _context, options) => { - const payload = { config: { thinkingConfig: { thinkingBudget: -1 } } } as Record< - string, - unknown - >; - options?.onPayload?.(payload as never, model as never); - capturedPayload = payload; - return {} as never; - }; + const capturedStream = createCapturedThinkingConfigStream(); const wrapped = provider.wrapStreamFn?.({ provider: "moonshot", modelId: "kimi-k2.5", thinkingLevel: "off", - streamFn: baseStreamFn, + streamFn: capturedStream.streamFn, } as never); void wrapped?.( @@ -53,7 +44,7 @@ describe("moonshot provider plugin", () => { {}, ); - expect(capturedPayload).toMatchObject({ + expect(capturedStream.getCapturedPayload()).toMatchObject({ config: { thinkingConfig: { thinkingBudget: -1 } }, thinking: { type: "disabled" }, }); diff --git a/test/helpers/plugins/stream-hooks.ts b/test/helpers/plugins/stream-hooks.ts new file mode 100644 index 00000000000..1b8b7fe7048 --- /dev/null +++ b/test/helpers/plugins/stream-hooks.ts @@ -0,0 +1,18 @@ +import type { StreamFn } from "@mariozechner/pi-agent-core"; + +export function createCapturedThinkingConfigStream() { + let capturedPayload: Record | undefined; + const streamFn: StreamFn = (model, _context, options) => { + const payload = { config: { thinkingConfig: { thinkingBudget: -1 } } } as Record< + string, + unknown + >; + options?.onPayload?.(payload as never, model as never); + capturedPayload = payload; + return {} as never; + }; + return { + streamFn, + getCapturedPayload: () => capturedPayload, + }; +}