diff --git a/extensions/xai/index.test.ts b/extensions/xai/index.test.ts index 1dea1d530f0..8ef1cc4d225 100644 --- a/extensions/xai/index.test.ts +++ b/extensions/xai/index.test.ts @@ -1,8 +1,11 @@ -import type { StreamFn } from "@mariozechner/pi-agent-core"; -import type { Context, Model } from "@mariozechner/pi-ai"; import { describe, expect, it } from "vitest"; import { registerSingleProviderPlugin } from "../../test/helpers/plugins/plugin-registration.js"; import plugin from "./index.js"; +import { + createXaiPayloadCaptureStream, + expectXaiFastToolStreamShaping, + runXaiGrok4ResponseStream, +} from "./test-helpers.js"; function createProviderModel(overrides: { id: string; @@ -59,54 +62,17 @@ describe("xai provider plugin", () => { it("wires provider stream shaping for fast mode and tool-stream defaults", async () => { const provider = await registerSingleProviderPlugin(plugin); - let capturedModelId = ""; - let capturedPayload: Record | undefined; - const baseStreamFn: StreamFn = (model, _context, options) => { - capturedModelId = model.id; - const payload: Record = { - reasoning: { effort: "high" }, - tools: [ - { - type: "function", - function: { - name: "write", - parameters: { type: "object", properties: {} }, - strict: true, - }, - }, - ], - }; - options?.onPayload?.(payload as never, model as never); - capturedPayload = payload; - return { - result: async () => ({}) as never, - async *[Symbol.asyncIterator]() {}, - } as unknown as ReturnType; - }; + const capture = createXaiPayloadCaptureStream(); const wrapped = provider.wrapStreamFn?.({ provider: "xai", modelId: "grok-4", extraParams: { fastMode: true }, - streamFn: baseStreamFn, + streamFn: capture.streamFn, } as never); - void wrapped?.( - { - api: "openai-responses", - provider: "xai", - id: "grok-4", - } as Model<"openai-responses">, - { messages: [] } as Context, - {}, - ); - - expect(capturedModelId).toBe("grok-4-fast"); - expect(capturedPayload).toMatchObject({ tool_stream: true }); - expect(capturedPayload).not.toHaveProperty("reasoning"); - expect( - (capturedPayload?.tools as Array<{ function?: Record }>)[0]?.function, - ).not.toHaveProperty("strict"); + runXaiGrok4ResponseStream(wrapped); + expectXaiFastToolStreamShaping(capture); }); it("defaults tool_stream extra params but preserves explicit values", async () => { diff --git a/extensions/xai/stream.test.ts b/extensions/xai/stream.test.ts index 8d19965ece8..520f800b229 100644 --- a/extensions/xai/stream.test.ts +++ b/extensions/xai/stream.test.ts @@ -6,14 +6,11 @@ import { createXaiToolPayloadCompatibilityWrapper, wrapXaiProviderStream, } from "./stream.js"; - -type ToolPayload = { - function?: Record; -}; -type XaiTestPayload = Record & { - tools?: Array<{ type?: string; function?: Record }>; - input?: unknown[]; -}; +import { + createXaiPayloadCaptureStream, + expectXaiFastToolStreamShaping, + runXaiGrok4ResponseStream, +} from "./test-helpers.js"; type XaiStreamApi = Extract; function captureWrappedModelId(params: { @@ -94,51 +91,15 @@ describe("xai stream wrappers", () => { }); it("composes the xai provider stream chain from extra params", () => { - let capturedModelId = ""; - let capturedPayload: XaiTestPayload | undefined; - const baseStreamFn: StreamFn = (model, _context, options) => { - capturedModelId = model.id; - const payload: XaiTestPayload = { - reasoning: { effort: "high" }, - tools: [ - { - type: "function", - function: { - name: "write", - parameters: { type: "object", properties: {} }, - strict: true, - }, - }, - ], - }; - options?.onPayload?.(payload as never, model as never); - capturedPayload = payload; - return { - result: async () => ({}) as never, - async *[Symbol.asyncIterator]() {}, - } as unknown as ReturnType; - }; + const capture = createXaiPayloadCaptureStream(); const wrapped = wrapXaiProviderStream({ - streamFn: baseStreamFn, + streamFn: capture.streamFn, extraParams: { fastMode: true }, } as never); - void wrapped?.( - { - api: "openai-responses", - provider: "xai", - id: "grok-4", - } as Model<"openai-responses">, - { messages: [] } as Context, - {}, - ); - - expect(capturedModelId).toBe("grok-4-fast"); - expect(capturedPayload).toMatchObject({ tool_stream: true }); - expect(capturedPayload).not.toHaveProperty("reasoning"); - const payloadTools = capturedPayload?.tools as ToolPayload[] | undefined; - expect(payloadTools?.[0]?.function).not.toHaveProperty("strict"); + runXaiGrok4ResponseStream(wrapped); + expectXaiFastToolStreamShaping(capture); }); it("strips unsupported strict and reasoning controls from tool payloads", () => { diff --git a/extensions/xai/test-helpers.ts b/extensions/xai/test-helpers.ts new file mode 100644 index 00000000000..77848a1a8ac --- /dev/null +++ b/extensions/xai/test-helpers.ts @@ -0,0 +1,73 @@ +import type { StreamFn } from "@mariozechner/pi-agent-core"; +import type { Context, Model } from "@mariozechner/pi-ai"; +import { expect } from "vitest"; + +export type XaiToolPayloadFunction = { + function?: Record; +}; + +export type XaiTestPayload = Record & { + tools?: Array<{ type?: string; function?: Record }>; + input?: unknown[]; +}; + +export function createXaiToolStreamPayload(): XaiTestPayload { + return { + reasoning: { effort: "high" }, + tools: [ + { + type: "function", + function: { + name: "write", + parameters: { type: "object", properties: {} }, + strict: true, + }, + }, + ], + }; +} + +export function createXaiPayloadCaptureStream() { + let capturedModelId = ""; + let capturedPayload: XaiTestPayload | undefined; + + const streamFn: StreamFn = (model, _context, options) => { + capturedModelId = model.id; + const payload = createXaiToolStreamPayload(); + options?.onPayload?.(payload as never, model as never); + capturedPayload = payload; + return { + result: async () => ({}) as never, + async *[Symbol.asyncIterator]() {}, + } as unknown as ReturnType; + }; + + return { + streamFn, + getCapturedModelId: () => capturedModelId, + getCapturedPayload: () => capturedPayload, + }; +} + +export function runXaiGrok4ResponseStream(streamFn: StreamFn | undefined) { + void streamFn?.( + { + api: "openai-responses", + provider: "xai", + id: "grok-4", + } as Model<"openai-responses">, + { messages: [] } as Context, + {}, + ); +} + +export function expectXaiFastToolStreamShaping( + capture: ReturnType, +) { + const capturedPayload = capture.getCapturedPayload(); + expect(capture.getCapturedModelId()).toBe("grok-4-fast"); + expect(capturedPayload).toMatchObject({ tool_stream: true }); + expect(capturedPayload).not.toHaveProperty("reasoning"); + const payloadTools = capturedPayload?.tools as XaiToolPayloadFunction[] | undefined; + expect(payloadTools?.[0]?.function).not.toHaveProperty("strict"); +}