refactor(providers): add stream family hooks

This commit is contained in:
Vincent Koc
2026-04-04 20:19:43 +09:00
parent 1037af01ad
commit bc648ac8e6
7 changed files with 162 additions and 38 deletions

View File

@@ -1,6 +1,9 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { describe, expect, it } from "vitest";
import { composeProviderStreamWrappers } from "./provider-stream.js";
import {
buildProviderStreamFamilyHooks,
composeProviderStreamWrappers,
} from "./provider-stream.js";
describe("composeProviderStreamWrappers", () => {
it("applies wrappers left to right", async () => {
@@ -33,3 +36,86 @@ describe("composeProviderStreamWrappers", () => {
expect(composeProviderStreamWrappers(baseStreamFn)).toBe(baseStreamFn);
});
});
describe("buildProviderStreamFamilyHooks", () => {
  it("covers the stream family matrix", () => {
    // Latest payload and model id seen by the stub stream function below.
    let lastPayload: Record<string, unknown> | undefined;
    let lastModelId: string | undefined;

    // Stub stream function: seeds a payload carrying a thinkingBudget, lets the
    // wrapper under test patch it through `onPayload`, then records the result.
    const recordingStreamFn: StreamFn = (model, _context, options) => {
      lastModelId = String(model.id);
      const seededPayload: Record<string, unknown> = {
        config: { thinkingConfig: { thinkingBudget: -1 } },
      };
      options?.onPayload?.(seededPayload as never, model as never);
      lastPayload = seededPayload;
      return {} as never;
    };

    // Wraps the stub with the given family hooks and invokes it once.
    const runWrapped = (
      hooks: ReturnType<typeof buildProviderStreamFamilyHooks>,
      wrapContext: Record<string, unknown>,
      model: Record<string, unknown>,
    ): void => {
      hooks.wrapStreamFn?.({ streamFn: recordingStreamFn, ...wrapContext } as never)(
        model as never,
        {} as never,
        {},
      );
    };

    // google-thinking: expects a thinkingLevel in the payload and no thinkingBudget.
    runWrapped(
      buildProviderStreamFamilyHooks("google-thinking"),
      { thinkingLevel: "high" },
      { api: "google-generative-ai", id: "gemini-3.1-pro-preview" },
    );
    expect(lastPayload).toMatchObject({
      config: { thinkingConfig: { thinkingLevel: "HIGH" } },
    });
    const googlePayload = lastPayload as Record<string, unknown>;
    const googleConfig = googlePayload.config as Record<string, unknown>;
    const googleThinkingConfig = googleConfig.thinkingConfig as Record<string, unknown>;
    expect(googleThinkingConfig).not.toHaveProperty("thinkingBudget");

    // minimax-fast-mode: expects the stub to receive the "-highspeed" model id.
    runWrapped(
      buildProviderStreamFamilyHooks("minimax-fast-mode"),
      { extraParams: { fastMode: true } },
      { api: "anthropic-messages", provider: "minimax", id: "MiniMax-M2.7" },
    );
    expect(lastModelId).toBe("MiniMax-M2.7-highspeed");

    // moonshot-thinking: thinkingLevel "off" should yield a disabled thinking entry.
    runWrapped(
      buildProviderStreamFamilyHooks("moonshot-thinking"),
      { thinkingLevel: "off" },
      { api: "openai-completions", id: "kimi-k2.5" },
    );
    expect(lastPayload).toMatchObject({
      config: { thinkingConfig: { thinkingBudget: -1 } },
      thinking: { type: "disabled" },
    });

    // tool-stream-default-on: enabled when unspecified, omitted when explicitly false.
    const toolStreamHooks = buildProviderStreamFamilyHooks("tool-stream-default-on");
    runWrapped(toolStreamHooks, { extraParams: {} }, { id: "glm-4.7" });
    expect(lastPayload).toMatchObject({
      config: { thinkingConfig: { thinkingBudget: -1 } },
      tool_stream: true,
    });

    runWrapped(toolStreamHooks, { extraParams: { tool_stream: false } }, { id: "glm-4.7" });
    expect(lastPayload).toMatchObject({
      config: { thinkingConfig: { thinkingBudget: -1 } },
    });
    expect(lastPayload).not.toHaveProperty("tool_stream");
  });
});

View File

@@ -1,4 +1,22 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import type { ProviderPlugin } from "../plugins/types.js";
import type { ProviderWrapStreamFnContext } from "./plugin-entry.js";
import {
createGoogleThinkingPayloadWrapper,
sanitizeGoogleThinkingPayload,
} from "../agents/pi-embedded-runner/google-stream-wrappers.js";
import { createMinimaxFastModeWrapper } from "../agents/pi-embedded-runner/minimax-stream-wrappers.js";
import {
createMoonshotThinkingWrapper,
resolveMoonshotThinkingType,
} from "../agents/pi-embedded-runner/moonshot-thinking-stream-wrappers.js";
import {
createKilocodeWrapper,
createOpenRouterSystemCacheWrapper,
createOpenRouterWrapper,
isProxyReasoningUnsupported,
} from "../agents/pi-embedded-runner/proxy-stream-wrappers.js";
import { createToolStreamWrapper, createZaiToolStreamWrapper } from "../agents/pi-embedded-runner/zai-stream-wrappers.js";
export type ProviderStreamWrapperFactory =
| ((streamFn: StreamFn | undefined) => StreamFn | undefined)
@@ -16,6 +34,46 @@ export function composeProviderStreamWrappers(
);
}
/**
 * Names of the shared stream-wrapper families accepted by
 * `buildProviderStreamFamilyHooks`.
 */
export type ProviderStreamFamily =
// Wrapped via `createGoogleThinkingPayloadWrapper`.
| "google-thinking"
// Wrapped via `createMoonshotThinkingWrapper`.
| "moonshot-thinking"
// Wrapped via `createMinimaxFastModeWrapper`.
| "minimax-fast-mode"
// Wrapped via `createToolStreamWrapper`; on unless `tool_stream` is explicitly `false`.
| "tool-stream-default-on";
/** Hook subset built for a stream family: only `wrapStreamFn` from `ProviderPlugin`. */
type ProviderStreamFamilyHooks = Pick<ProviderPlugin, "wrapStreamFn">;
/**
 * Builds the `wrapStreamFn` hook for one of the shared provider stream families.
 *
 * Each family forwards fields of the wrap context to an existing stream wrapper:
 * - `google-thinking`: `streamFn` and `thinkingLevel` go to
 *   `createGoogleThinkingPayloadWrapper`.
 * - `moonshot-thinking`: the thinking type is resolved from `extraParams.thinking`
 *   and `thinkingLevel`, then passed to `createMoonshotThinkingWrapper`.
 * - `minimax-fast-mode`: fast mode is requested only when `extraParams.fastMode`
 *   is strictly `true`.
 * - `tool-stream-default-on`: tool streaming is requested unless
 *   `extraParams.tool_stream` is strictly `false`.
 *
 * @param family - The stream family whose hook should be built.
 * @returns An object carrying only the `wrapStreamFn` hook for that family.
 */
export function buildProviderStreamFamilyHooks(
  family: ProviderStreamFamily,
): ProviderStreamFamilyHooks {
  switch (family) {
    case "google-thinking": {
      const wrapStreamFn = ({ streamFn, thinkingLevel }: ProviderWrapStreamFnContext) =>
        createGoogleThinkingPayloadWrapper(streamFn, thinkingLevel);
      return { wrapStreamFn };
    }
    case "moonshot-thinking": {
      const wrapStreamFn = ({
        extraParams,
        thinkingLevel,
        streamFn,
      }: ProviderWrapStreamFnContext) => {
        // Explicitly configured thinking and the requested level are both
        // handed to the resolver, which decides the effective type.
        const resolvedType = resolveMoonshotThinkingType({
          configuredThinking: extraParams?.thinking,
          thinkingLevel,
        });
        return createMoonshotThinkingWrapper(streamFn, resolvedType);
      };
      return { wrapStreamFn };
    }
    case "minimax-fast-mode": {
      const wrapStreamFn = ({ streamFn, extraParams }: ProviderWrapStreamFnContext) =>
        createMinimaxFastModeWrapper(streamFn, extraParams?.fastMode === true);
      return { wrapStreamFn };
    }
    case "tool-stream-default-on": {
      const wrapStreamFn = ({ streamFn, extraParams }: ProviderWrapStreamFnContext) =>
        createToolStreamWrapper(streamFn, extraParams?.tool_stream !== false);
      return { wrapStreamFn };
    }
  }
}
// Public stream-wrapper helpers for provider plugins.
export {
@@ -38,18 +96,14 @@ export {
export {
createGoogleThinkingPayloadWrapper,
sanitizeGoogleThinkingPayload,
} from "../agents/pi-embedded-runner/google-stream-wrappers.js";
export { createMinimaxFastModeWrapper } from "../agents/pi-embedded-runner/minimax-stream-wrappers.js";
export {
createMinimaxFastModeWrapper,
createKilocodeWrapper,
createOpenRouterSystemCacheWrapper,
createOpenRouterWrapper,
isProxyReasoningUnsupported,
} from "../agents/pi-embedded-runner/proxy-stream-wrappers.js";
export {
createMoonshotThinkingWrapper,
resolveMoonshotThinkingType,
} from "../agents/pi-embedded-runner/moonshot-thinking-stream-wrappers.js";
};
export {
createOpenAIAttributionHeadersWrapper,
createCodexNativeWebSearchWrapper,
@@ -64,10 +118,7 @@ export {
resolveOpenAITextVerbosity,
} from "../agents/pi-embedded-runner/openai-stream-wrappers.js";
export { streamWithPayloadPatch } from "../agents/pi-embedded-runner/stream-payload-utils.js";
export {
createToolStreamWrapper,
createZaiToolStreamWrapper,
} from "../agents/pi-embedded-runner/zai-stream-wrappers.js";
export { createToolStreamWrapper, createZaiToolStreamWrapper };
export {
getOpenRouterModelCapabilities,
loadOpenRouterModelCapabilities,