From e697fa5e75919e923b1b8a92f18970668f0c7aab Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Sat, 4 Apr 2026 03:33:32 +0900 Subject: [PATCH] feat(providers): add google transport runtime --- src/agents/google-transport-stream.test.ts | 249 +++++++ src/agents/google-transport-stream.ts | 744 +++++++++++++++++++++ src/agents/openai-transport-stream.test.ts | 98 +++ src/agents/provider-transport-stream.ts | 5 + 4 files changed, 1096 insertions(+) create mode 100644 src/agents/google-transport-stream.test.ts create mode 100644 src/agents/google-transport-stream.ts diff --git a/src/agents/google-transport-stream.test.ts b/src/agents/google-transport-stream.test.ts new file mode 100644 index 00000000000..14cb9869004 --- /dev/null +++ b/src/agents/google-transport-stream.test.ts @@ -0,0 +1,249 @@ +import type { Model } from "@mariozechner/pi-ai"; +import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import { attachModelProviderRequestTransport } from "./provider-request-config.js"; + +const { buildGuardedModelFetchMock, guardedFetchMock } = vi.hoisted(() => ({ + buildGuardedModelFetchMock: vi.fn(), + guardedFetchMock: vi.fn(), +})); + +vi.mock("./provider-transport-fetch.js", () => ({ + buildGuardedModelFetch: buildGuardedModelFetchMock, +})); + +let buildGoogleGenerativeAiParams: typeof import("./google-transport-stream.js").buildGoogleGenerativeAiParams; +let createGoogleGenerativeAiTransportStreamFn: typeof import("./google-transport-stream.js").createGoogleGenerativeAiTransportStreamFn; + +function buildSseResponse(events: unknown[]): Response { + const sse = `${events.map((event) => `data: ${JSON.stringify(event)}\n\n`).join("")}data: [DONE]\n\n`; + const encoder = new TextEncoder(); + const body = new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode(sse)); + controller.close(); + }, + }); + return new Response(body, { + status: 200, + headers: { "content-type": "text/event-stream" }, + }); +} + +describe("google transport stream", () => { + beforeAll(async () => { + ({ buildGoogleGenerativeAiParams, createGoogleGenerativeAiTransportStreamFn } = + await import("./google-transport-stream.js")); + }); + + beforeEach(() => { + buildGuardedModelFetchMock.mockReset(); + guardedFetchMock.mockReset(); + buildGuardedModelFetchMock.mockReturnValue(guardedFetchMock); + }); + + it("uses the guarded fetch transport and parses Gemini SSE output", async () => { + guardedFetchMock.mockResolvedValueOnce( + buildSseResponse([ + { + responseId: "resp_1", + candidates: [ + { + content: { + parts: [ + { thought: true, text: "draft", thoughtSignature: "sig_1" }, + { text: "answer" }, + { functionCall: { name: "lookup", args: { q: "hello" } } }, + ], + }, + finishReason: "STOP", + }, + ], + usageMetadata: { + promptTokenCount: 10, + cachedContentTokenCount: 2, + candidatesTokenCount: 5, + thoughtsTokenCount: 3, + totalTokenCount: 18, + }, + }, + ]), + ); + + const model = attachModelProviderRequestTransport( + { + id: "gemini-3.1-pro-preview", + name: "Gemini 3.1 Pro Preview", + api: "google-generative-ai", + provider: "google", + baseUrl: "https://generativelanguage.googleapis.com", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 128000, + maxTokens: 8192, + headers: { "X-Provider": "google" }, + } satisfies Model<"google-generative-ai">, + { + proxy: { + mode: "explicit-proxy", + url: "http://proxy.internal:8443", + }, + }, + ); + + const streamFn = createGoogleGenerativeAiTransportStreamFn(); + const stream = await Promise.resolve( + streamFn( + model, + { + systemPrompt: "Follow policy.", + messages: [{ role: "user", content: "hello", timestamp: 0 }], + tools: [ + { + name: "lookup", + description: "Look up a value", + parameters: { + type: "object", + properties: { q: { type: "string" } }, + required: ["q"], + }, + }, + ], + } as unknown as Parameters[1], + { + apiKey: "gemini-api-key", + reasoning: "medium", + toolChoice: "auto", + } as Parameters[2], + ), + ); + const result = await stream.result(); + + expect(buildGuardedModelFetchMock).toHaveBeenCalledWith(model); + expect(guardedFetchMock).toHaveBeenCalledWith( + "https://generativelanguage.googleapis.com/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent?alt=sse", + expect.objectContaining({ + method: "POST", + headers: expect.objectContaining({ + accept: "text/event-stream", + "Content-Type": "application/json", + "x-goog-api-key": "gemini-api-key", + "X-Provider": "google", + }), + }), + ); + + const init = guardedFetchMock.mock.calls[0]?.[1] as RequestInit; + const requestBody = init.body; + if (typeof requestBody !== "string") { + throw new Error("Expected Google transport request body to be serialized JSON"); + } + const payload = JSON.parse(requestBody) as Record; + expect(payload.systemInstruction).toEqual({ + parts: [{ text: "Follow policy." }], + }); + expect(payload.generationConfig).toMatchObject({ + thinkingConfig: { includeThoughts: true, thinkingLevel: "HIGH" }, + }); + expect(payload.toolConfig).toMatchObject({ + functionCallingConfig: { mode: "AUTO" }, + }); + expect(result).toMatchObject({ + api: "google-generative-ai", + provider: "google", + responseId: "resp_1", + stopReason: "toolUse", + usage: { + input: 8, + output: 8, + cacheRead: 2, + totalTokens: 18, + }, + content: [ + { type: "thinking", thinking: "draft", thinkingSignature: "sig_1" }, + { type: "text", text: "answer" }, + { type: "toolCall", name: "lookup", arguments: { q: "hello" } }, + ], + }); + }); + + it("uses bearer auth when the Google api key is an OAuth JSON payload", async () => { + guardedFetchMock.mockResolvedValueOnce(buildSseResponse([])); + + const model = attachModelProviderRequestTransport( + { + id: "gemini-3-flash-preview", + name: "Gemini 3 Flash Preview", + api: "google-generative-ai", + provider: "custom-google", + baseUrl: "https://generativelanguage.googleapis.com/v1beta", + reasoning: false, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 128000, + maxTokens: 8192, + } satisfies Model<"google-generative-ai">, + { + tls: { + ca: "ca-pem", + }, + }, + ); + + const streamFn = createGoogleGenerativeAiTransportStreamFn(); + const stream = await Promise.resolve( + streamFn( + model, + { + messages: [{ role: "user", content: "hello", timestamp: 0 }], + } as Parameters[1], + { + apiKey: JSON.stringify({ token: "oauth-token", projectId: "demo" }), + } as Parameters[2], + ), + ); + await stream.result(); + + expect(guardedFetchMock).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: "Bearer oauth-token", + "Content-Type": "application/json", + }), + }), + ); + }); + + it("builds direct Gemini payloads without negative fallback thinking budgets", () => { + const model = { + id: "custom-gemini-model", + name: "Custom Gemini", + api: "google-generative-ai", + provider: "custom-google", + baseUrl: "https://proxy.example.com/gemini/v1beta", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 128000, + maxTokens: 8192, + } satisfies Model<"google-generative-ai">; + + const params = buildGoogleGenerativeAiParams( + model, + { + messages: [{ role: "user", content: "hello", timestamp: 0 }], + } as never, + { + reasoning: "medium", + }, + ); + + expect(params.generationConfig).toMatchObject({ + thinkingConfig: { includeThoughts: true }, + }); + expect(params.generationConfig).not.toMatchObject({ + thinkingConfig: { thinkingBudget: -1 }, + }); + }); +}); diff --git a/src/agents/google-transport-stream.ts b/src/agents/google-transport-stream.ts new file mode 100644 index 00000000000..951873ed6ad --- /dev/null +++ b/src/agents/google-transport-stream.ts @@ -0,0 +1,744 @@ +import type { StreamFn } from "@mariozechner/pi-agent-core"; +import { + calculateCost, + createAssistantMessageEventStream, + getEnvApiKey, + type Context, + type Model, + type SimpleStreamOptions, + type ThinkingLevel, +} from "@mariozechner/pi-ai"; +import { parseGeminiAuth } from "../infra/gemini-auth.js"; +import { normalizeGoogleApiBaseUrl } from "../infra/google-api-base-url.js"; +import { sanitizeTransportPayloadText } from "./openai-transport-stream.js"; +import { buildGuardedModelFetch } from "./provider-transport-fetch.js"; +import { transformTransportMessages } from "./transport-message-transform.js"; + +type GoogleTransportModel = Model<"google-generative-ai"> & { + headers?: Record; + provider: string; +}; + +type GoogleThinkingLevel = "MINIMAL" | "LOW" | "MEDIUM" | "HIGH"; + +type GoogleTransportOptions = SimpleStreamOptions & { + toolChoice?: + | "auto" + | "none" + | "any" + | "required" + | { + type: "function"; + function: { + name: string; + }; + }; + thinking?: { + enabled: boolean; + budgetTokens?: number; + level?: GoogleThinkingLevel; + }; +}; + +type GoogleGenerateContentRequest = { + contents: Array>; + generationConfig?: Record; + systemInstruction?: Record; + tools?: Array>; + toolConfig?: Record; +}; + +type GoogleTransportContentBlock = + | { type: "text"; text: string; textSignature?: string } + | { type: "thinking"; thinking: string; thinkingSignature?: string } + | { type: "toolCall"; id: string; name: string; arguments: Record }; + +type MutableAssistantOutput = { + role: "assistant"; + content: Array; + api: "google-generative-ai"; + provider: string; + model: string; + usage: { + input: number; + output: number; + cacheRead: number; + cacheWrite: number; + totalTokens: number; + cost: { input: number; output: number; cacheRead: number; cacheWrite: number; total: number }; + }; + stopReason: string; + timestamp: number; + responseId?: string; + errorMessage?: string; +}; + +type GoogleSseChunk = { + responseId?: string; + candidates?: Array<{ + content?: { + parts?: Array<{ + text?: string; + thought?: boolean; + thoughtSignature?: string; + functionCall?: { + id?: string; + name?: string; + args?: Record; + }; + }>; + }; + finishReason?: string; + }>; + usageMetadata?: { + promptTokenCount?: number; + cachedContentTokenCount?: number; + candidatesTokenCount?: number; + thoughtsTokenCount?: number; + totalTokenCount?: number; + }; +}; + +let toolCallCounter = 0; + +function sanitizeGoogleTransportText(text: string): string { + return sanitizeTransportPayloadText(text); +} + +function isGemini3ProModel(modelId: string): boolean { + return /gemini-3(?:\.\d+)?-pro/.test(modelId.toLowerCase()); +} + +function isGemini3FlashModel(modelId: string): boolean { + return /gemini-3(?:\.\d+)?-flash/.test(modelId.toLowerCase()); +} + +function requiresToolCallId(modelId: string): boolean { + return modelId.startsWith("claude-") || modelId.startsWith("gpt-oss-"); +} + +function supportsMultimodalFunctionResponse(modelId: string): boolean { + const match = modelId.toLowerCase().match(/^gemini(?:-live)?-(\d+)/); + if (!match) { + return true; + } + return Number.parseInt(match[1] ?? "", 10) >= 3; +} + +function retainThoughtSignature(existing: string | undefined, incoming: string | undefined) { + if (typeof incoming === "string" && incoming.length > 0) { + return incoming; + } + return existing; +} + +function mapToolChoice( + choice: GoogleTransportOptions["toolChoice"], +): { mode: "AUTO" | "NONE" | "ANY"; allowedFunctionNames?: string[] } | undefined { + if (!choice) { + return undefined; + } + if (typeof choice === "object" && choice.type === "function") { + return { mode: "ANY", allowedFunctionNames: [choice.function.name] }; + } + switch (choice) { + case "none": + return { mode: "NONE" }; + case "any": + case "required": + return { mode: "ANY" }; + default: + return { mode: "AUTO" }; + } +} + +function mapStopReasonString(reason: string): "stop" | "length" | "error" { + switch (reason) { + case "STOP": + return "stop"; + case "MAX_TOKENS": + return "length"; + default: + return "error"; + } +} + +function normalizeToolCallId(id: string): string { + return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64); +} + +function resolveGoogleModelPath(modelId: string): string { + if (modelId.startsWith("models/") || modelId.startsWith("tunedModels/")) { + return modelId; + } + return `models/${modelId}`; +} + +function buildGoogleRequestUrl(model: GoogleTransportModel): string { + const baseUrl = normalizeGoogleApiBaseUrl(model.baseUrl); + return `${baseUrl}/${resolveGoogleModelPath(model.id)}:streamGenerateContent?alt=sse`; +} + +function resolveThinkingLevel(level: ThinkingLevel, modelId: string): GoogleThinkingLevel { + if (isGemini3ProModel(modelId)) { + switch (level) { + case "minimal": + case "low": + return "LOW"; + case "medium": + case "high": + case "xhigh": + return "HIGH"; + } + } + switch (level) { + case "minimal": + return "MINIMAL"; + case "low": + return "LOW"; + case "medium": + return "MEDIUM"; + case "high": + case "xhigh": + return "HIGH"; + } +} + +function getDisabledThinkingConfig(modelId: string): Record { + if (isGemini3ProModel(modelId)) { + return { thinkingLevel: "LOW" }; + } + if (isGemini3FlashModel(modelId)) { + return { thinkingLevel: "MINIMAL" }; + } + return { thinkingBudget: 0 }; +} + +function getGoogleThinkingBudget( + modelId: string, + effort: ThinkingLevel, + customBudgets?: GoogleTransportOptions["thinkingBudgets"], +): number | undefined { + const normalizedEffort = effort === "xhigh" ? "high" : effort; + if (customBudgets?.[normalizedEffort] !== undefined) { + return customBudgets[normalizedEffort]; + } + if (modelId.includes("2.5-pro")) { + return { minimal: 128, low: 2048, medium: 8192, high: 32768 }[normalizedEffort]; + } + if (modelId.includes("2.5-flash")) { + return { minimal: 128, low: 2048, medium: 8192, high: 24576 }[normalizedEffort]; + } + return undefined; +} + +function resolveGoogleThinkingConfig( + model: GoogleTransportModel, + options: GoogleTransportOptions | undefined, +): Record | undefined { + if (!model.reasoning) { + return undefined; + } + if (options?.thinking) { + if (!options.thinking.enabled) { + return getDisabledThinkingConfig(model.id); + } + const config: Record = { includeThoughts: true }; + if (options.thinking.level) { + config.thinkingLevel = options.thinking.level; + } else if (typeof options.thinking.budgetTokens === "number") { + config.thinkingBudget = options.thinking.budgetTokens; + } + return config; + } + if (!options?.reasoning) { + return getDisabledThinkingConfig(model.id); + } + if (isGemini3ProModel(model.id) || isGemini3FlashModel(model.id)) { + return { + includeThoughts: true, + thinkingLevel: resolveThinkingLevel(options.reasoning, model.id), + }; + } + const budget = getGoogleThinkingBudget(model.id, options.reasoning, options.thinkingBudgets); + return { + includeThoughts: true, + ...(typeof budget === "number" ? { thinkingBudget: budget } : {}), + }; +} + +function convertGoogleMessages(model: GoogleTransportModel, context: Context) { + const contents: Array> = []; + const transformedMessages = transformTransportMessages(context.messages, model, (id) => + requiresToolCallId(model.id) ? normalizeToolCallId(id) : id, + ); + for (const msg of transformedMessages) { + if (msg.role === "user") { + if (typeof msg.content === "string") { + contents.push({ + role: "user", + parts: [{ text: sanitizeGoogleTransportText(msg.content) }], + }); + continue; + } + const parts = msg.content + .map((item) => + item.type === "text" + ? { text: sanitizeGoogleTransportText(item.text) } + : { + inlineData: { + mimeType: item.mimeType, + data: item.data, + }, + }, + ) + .filter((item) => model.input.includes("image") || !("inlineData" in item)); + if (parts.length > 0) { + contents.push({ role: "user", parts }); + } + continue; + } + + if (msg.role === "assistant") { + const isSameProviderAndModel = msg.provider === model.provider && msg.model === model.id; + const parts: Array> = []; + for (const block of msg.content) { + if (block.type === "text") { + if (!block.text.trim()) { + continue; + } + parts.push({ + text: sanitizeGoogleTransportText(block.text), + ...(isSameProviderAndModel && block.textSignature + ? { thoughtSignature: block.textSignature } + : {}), + }); + continue; + } + if (block.type === "thinking") { + if (!block.thinking.trim()) { + continue; + } + if (isSameProviderAndModel) { + parts.push({ + thought: true, + text: sanitizeGoogleTransportText(block.thinking), + ...(block.thinkingSignature ? { thoughtSignature: block.thinkingSignature } : {}), + }); + } else { + parts.push({ text: sanitizeGoogleTransportText(block.thinking) }); + } + continue; + } + if (block.type === "toolCall") { + parts.push({ + functionCall: { + name: block.name, + args: block.arguments ?? {}, + ...(requiresToolCallId(model.id) ? { id: block.id } : {}), + }, + ...(isSameProviderAndModel && block.thoughtSignature + ? { thoughtSignature: block.thoughtSignature } + : {}), + }); + } + } + if (parts.length > 0) { + contents.push({ role: "model", parts }); + } + continue; + } + + if (msg.role === "toolResult") { + const textResult = msg.content + .filter( + (item): item is Extract<(typeof msg.content)[number], { type: "text" }> => + item.type === "text", + ) + .map((item) => item.text) + .join("\n"); + const imageContent = model.input.includes("image") + ? msg.content.filter( + (item): item is Extract<(typeof msg.content)[number], { type: "image" }> => + item.type === "image", + ) + : []; + const responseValue = textResult + ? sanitizeGoogleTransportText(textResult) + : imageContent.length > 0 + ? "(see attached image)" + : ""; + const imageParts = imageContent.map((imageBlock) => ({ + inlineData: { + mimeType: imageBlock.mimeType, + data: imageBlock.data, + }, + })); + const functionResponse = { + functionResponse: { + name: msg.toolName, + response: msg.isError ? { error: responseValue } : { output: responseValue }, + ...(supportsMultimodalFunctionResponse(model.id) && imageParts.length > 0 + ? { parts: imageParts } + : {}), + ...(requiresToolCallId(model.id) ? { id: msg.toolCallId } : {}), + }, + }; + const last = contents[contents.length - 1]; + if ( + last?.role === "user" && + Array.isArray(last.parts) && + last.parts.some((part) => "functionResponse" in part) + ) { + (last.parts as Array>).push(functionResponse); + } else { + contents.push({ role: "user", parts: [functionResponse] }); + } + if (imageParts.length > 0 && !supportsMultimodalFunctionResponse(model.id)) { + contents.push({ role: "user", parts: [{ text: "Tool result image:" }, ...imageParts] }); + } + } + } + return contents; +} + +function convertGoogleTools(tools: NonNullable) { + if (tools.length === 0) { + return undefined; + } + return [ + { + functionDeclarations: tools.map((tool) => ({ + name: tool.name, + description: tool.description, + parametersJsonSchema: tool.parameters, + })), + }, + ]; +} + +export function buildGoogleGenerativeAiParams( + model: GoogleTransportModel, + context: Context, + options?: GoogleTransportOptions, +): GoogleGenerateContentRequest { + const generationConfig: Record = {}; + if (typeof options?.temperature === "number") { + generationConfig.temperature = options.temperature; + } + if (typeof options?.maxTokens === "number") { + generationConfig.maxOutputTokens = options.maxTokens; + } + const thinkingConfig = resolveGoogleThinkingConfig(model, options); + if (thinkingConfig) { + generationConfig.thinkingConfig = thinkingConfig; + } + + const params: GoogleGenerateContentRequest = { + contents: convertGoogleMessages(model, context), + }; + if (Object.keys(generationConfig).length > 0) { + params.generationConfig = generationConfig; + } + if (context.systemPrompt) { + params.systemInstruction = { + parts: [{ text: sanitizeGoogleTransportText(context.systemPrompt) }], + }; + } + if (context.tools?.length) { + params.tools = convertGoogleTools(context.tools); + const toolChoice = mapToolChoice(options?.toolChoice); + if (toolChoice) { + params.toolConfig = { + functionCallingConfig: toolChoice, + }; + } + } + return params; +} + +function buildGoogleHeaders( + model: GoogleTransportModel, + apiKey: string | undefined, + optionHeaders: Record | undefined, +): Record { + const authHeaders = apiKey ? parseGeminiAuth(apiKey).headers : undefined; + return { + accept: "text/event-stream", + ...authHeaders, + ...model.headers, + ...optionHeaders, + }; +} + +async function* parseGoogleSseChunks( + response: Response, + signal?: AbortSignal, +): AsyncGenerator { + if (!response.body) { + throw new Error("No response body"); + } + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + const abortHandler = () => { + void reader.cancel().catch(() => undefined); + }; + signal?.addEventListener("abort", abortHandler); + try { + while (true) { + if (signal?.aborted) { + throw new Error("Request was aborted"); + } + const { done, value } = await reader.read(); + if (done) { + break; + } + buffer += decoder.decode(value, { stream: true }).replace(/\r/g, ""); + let boundary = buffer.indexOf("\n\n"); + while (boundary >= 0) { + const rawEvent = buffer.slice(0, boundary); + buffer = buffer.slice(boundary + 2); + boundary = buffer.indexOf("\n\n"); + const data = rawEvent + .split("\n") + .filter((line) => line.startsWith("data:")) + .map((line) => line.slice(5).trim()) + .join("\n"); + if (!data || data === "[DONE]") { + continue; + } + yield JSON.parse(data) as GoogleSseChunk; + } + } + } finally { + signal?.removeEventListener("abort", abortHandler); + } +} + +function updateUsage( + output: MutableAssistantOutput, + model: GoogleTransportModel, + chunk: GoogleSseChunk, +) { + const usage = chunk.usageMetadata; + if (!usage) { + return; + } + const promptTokens = usage.promptTokenCount || 0; + const cacheRead = usage.cachedContentTokenCount || 0; + output.usage = { + input: Math.max(0, promptTokens - cacheRead), + output: (usage.candidatesTokenCount || 0) + (usage.thoughtsTokenCount || 0), + cacheRead, + cacheWrite: 0, + totalTokens: usage.totalTokenCount || 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }; + calculateCost(model, output.usage); +} + +function pushTextBlockEnd( + stream: { push(event: unknown): void }, + output: MutableAssistantOutput, + blockIndex: number, +) { + const block = output.content[blockIndex]; + if (!block) { + return; + } + if (block.type === "thinking") { + stream.push({ + type: "thinking_end", + contentIndex: blockIndex, + content: block.thinking, + partial: output as never, + }); + return; + } + if (block.type === "text") { + stream.push({ + type: "text_end", + contentIndex: blockIndex, + content: block.text, + partial: output as never, + }); + } +} + +export function createGoogleGenerativeAiTransportStreamFn(): StreamFn { + return (rawModel, context, rawOptions) => { + const model = rawModel as GoogleTransportModel; + const options = rawOptions as GoogleTransportOptions | undefined; + const eventStream = createAssistantMessageEventStream(); + const stream = eventStream as unknown as { push(event: unknown): void; end(): void }; + void (async () => { + const output: MutableAssistantOutput = { + role: "assistant", + content: [], + api: "google-generative-ai", + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }; + try { + const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? undefined; + const fetch = buildGuardedModelFetch(model); + let params = buildGoogleGenerativeAiParams(model, context, options); + const nextParams = await options?.onPayload?.(params, model); + if (nextParams !== undefined) { + params = nextParams as GoogleGenerateContentRequest; + } + const response = await fetch(buildGoogleRequestUrl(model), { + method: "POST", + headers: buildGoogleHeaders(model, apiKey, options?.headers), + body: JSON.stringify(params), + signal: options?.signal, + }); + if (!response.ok) { + const message = await response.text().catch(() => ""); + throw new Error(`Google Generative AI API error (${response.status}): ${message}`); + } + stream.push({ type: "start", partial: output as never }); + let currentBlockIndex = -1; + for await (const chunk of parseGoogleSseChunks(response, options?.signal)) { + output.responseId ||= chunk.responseId; + updateUsage(output, model, chunk); + const candidate = chunk.candidates?.[0]; + if (candidate?.content?.parts) { + for (const part of candidate.content.parts) { + if (typeof part.text === "string") { + const isThinking = part.thought === true; + const currentBlock = output.content[currentBlockIndex]; + if ( + currentBlockIndex < 0 || + !currentBlock || + (isThinking && currentBlock.type !== "thinking") || + (!isThinking && currentBlock.type !== "text") + ) { + if (currentBlockIndex >= 0) { + pushTextBlockEnd(stream, output, currentBlockIndex); + } + if (isThinking) { + output.content.push({ type: "thinking", thinking: "" }); + currentBlockIndex = output.content.length - 1; + stream.push({ + type: "thinking_start", + contentIndex: currentBlockIndex, + partial: output as never, + }); + } else { + output.content.push({ type: "text", text: "" }); + currentBlockIndex = output.content.length - 1; + stream.push({ + type: "text_start", + contentIndex: currentBlockIndex, + partial: output as never, + }); + } + } + const activeBlock = output.content[currentBlockIndex]; + if (activeBlock?.type === "thinking") { + activeBlock.thinking += part.text; + activeBlock.thinkingSignature = retainThoughtSignature( + activeBlock.thinkingSignature, + part.thoughtSignature, + ); + stream.push({ + type: "thinking_delta", + contentIndex: currentBlockIndex, + delta: part.text, + partial: output as never, + }); + } else if (activeBlock?.type === "text") { + activeBlock.text += part.text; + activeBlock.textSignature = retainThoughtSignature( + activeBlock.textSignature, + part.thoughtSignature, + ); + stream.push({ + type: "text_delta", + contentIndex: currentBlockIndex, + delta: part.text, + partial: output as never, + }); + } + } + if (part.functionCall) { + if (currentBlockIndex >= 0) { + pushTextBlockEnd(stream, output, currentBlockIndex); + currentBlockIndex = -1; + } + const providedId = part.functionCall.id; + const isDuplicate = output.content.some( + (block) => block.type === "toolCall" && block.id === providedId, + ); + const toolCallId = + providedId && !isDuplicate + ? providedId + : `${part.functionCall.name || "tool"}_${Date.now()}_${++toolCallCounter}`; + const toolCall: GoogleTransportContentBlock = { + type: "toolCall", + id: toolCallId, + name: part.functionCall.name || "", + arguments: part.functionCall.args ?? {}, + }; + output.content.push(toolCall); + const blockIndex = output.content.length - 1; + stream.push({ + type: "toolcall_start", + contentIndex: blockIndex, + partial: output as never, + }); + stream.push({ + type: "toolcall_delta", + contentIndex: blockIndex, + delta: JSON.stringify(toolCall.arguments), + partial: output as never, + }); + stream.push({ + type: "toolcall_end", + contentIndex: blockIndex, + toolCall, + partial: output as never, + }); + } + } + } + if (typeof candidate?.finishReason === "string") { + output.stopReason = mapStopReasonString(candidate.finishReason); + if (output.content.some((block) => block.type === "toolCall")) { + output.stopReason = "toolUse"; + } + } + } + if (currentBlockIndex >= 0) { + pushTextBlockEnd(stream, output, currentBlockIndex); + } + if (options?.signal?.aborted) { + throw new Error("Request was aborted"); + } + if (output.stopReason === "aborted" || output.stopReason === "error") { + throw new Error("An unknown error occurred"); + } + stream.push({ type: "done", reason: output.stopReason as never, message: output as never }); + stream.end(); + } catch (error) { + output.stopReason = options?.signal?.aborted ? "aborted" : "error"; + output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error); + stream.push({ type: "error", reason: output.stopReason as never, error: output as never }); + stream.end(); + } + })(); + return eventStream as unknown as ReturnType; + }; +} diff --git a/src/agents/openai-transport-stream.test.ts b/src/agents/openai-transport-stream.test.ts index 454cc1b53c8..96057a4a639 100644 --- a/src/agents/openai-transport-stream.test.ts +++ b/src/agents/openai-transport-stream.test.ts @@ -21,6 +21,7 @@ describe("openai transport stream", () => { expect(isTransportAwareApiSupported("openai-completions")).toBe(true); expect(isTransportAwareApiSupported("azure-openai-responses")).toBe(true); expect(isTransportAwareApiSupported("anthropic-messages")).toBe(true); + expect(isTransportAwareApiSupported("google-generative-ai")).toBe(true); }); it("prepares a custom simple-completion api alias when transport overrides are attached", () => { @@ -89,6 +90,103 @@ describe("openai transport stream", () => { expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function"); }); + it("prepares a Google simple-completion api alias when transport overrides are attached", () => { + const model = attachModelProviderRequestTransport( + { + id: "gemini-3.1-pro-preview", + name: "Gemini 3.1 Pro Preview", + api: "google-generative-ai", + provider: "google", + baseUrl: "https://generativelanguage.googleapis.com/v1beta", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + } satisfies Model<"google-generative-ai">, + { + proxy: { + mode: "explicit-proxy", + url: "http://proxy.internal:8443", + }, + }, + ); + + const prepared = prepareTransportAwareSimpleModel(model); + + expect(resolveTransportAwareSimpleApi(model.api)).toBe( + "openclaw-google-generative-ai-transport", + ); + expect(prepared).toMatchObject({ + api: "openclaw-google-generative-ai-transport", + provider: "google", + id: "gemini-3.1-pro-preview", + }); + expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function"); + }); + + it("keeps github-copilot OpenAI-family models on the shared transport seam", () => { + const model = attachModelProviderRequestTransport( + { + id: "gpt-5.4", + name: "GPT-5.4", + api: "openai-responses", + provider: "github-copilot", + baseUrl: "https://api.githubcopilot.com/v1", + reasoning: true, + input: ["text", "image"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + } satisfies Model<"openai-responses">, + { + proxy: { + mode: "explicit-proxy", + url: "http://proxy.internal:8443", + }, + }, + ); + + expect(resolveTransportAwareSimpleApi(model.api)).toBe("openclaw-openai-responses-transport"); + expect(prepareTransportAwareSimpleModel(model)).toMatchObject({ + api: "openclaw-openai-responses-transport", + provider: "github-copilot", + id: "gpt-5.4", + }); + expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function"); + }); + + it("keeps github-copilot Claude models on the shared Anthropic transport seam", () => { + const model = attachModelProviderRequestTransport( + { + id: "claude-sonnet-4.6", + name: "Claude Sonnet 4.6", + api: "anthropic-messages", + provider: "github-copilot", + baseUrl: "https://api.githubcopilot.com/anthropic", + reasoning: true, + input: ["text", "image"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + } satisfies Model<"anthropic-messages">, + { + proxy: { + mode: "explicit-proxy", + url: "http://proxy.internal:8443", + }, + }, + ); + + expect(resolveTransportAwareSimpleApi(model.api)).toBe("openclaw-anthropic-messages-transport"); + expect(prepareTransportAwareSimpleModel(model)).toMatchObject({ + api: "openclaw-anthropic-messages-transport", + provider: "github-copilot", + id: "claude-sonnet-4.6", + }); + expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function"); + }); + it("removes unpaired surrogate code units but preserves valid surrogate pairs", () => { const high = String.fromCharCode(0xd83d); const low = String.fromCharCode(0xdc00); diff --git a/src/agents/provider-transport-stream.ts b/src/agents/provider-transport-stream.ts index 3a1de3f7e2b..a30d46f560a 100644 --- a/src/agents/provider-transport-stream.ts +++ b/src/agents/provider-transport-stream.ts @@ -1,6 +1,7 @@ import type { StreamFn } from "@mariozechner/pi-agent-core"; import type { Api, Model } from "@mariozechner/pi-ai"; import { createAnthropicMessagesTransportStreamFn } from "./anthropic-transport-stream.js"; +import { createGoogleGenerativeAiTransportStreamFn } from "./google-transport-stream.js"; import { createAzureOpenAIResponsesTransportStreamFn, createOpenAICompletionsTransportStreamFn, @@ -13,6 +14,7 @@ const SUPPORTED_TRANSPORT_APIS = new Set([ "openai-completions", "azure-openai-responses", "anthropic-messages", + "google-generative-ai", ]); const SIMPLE_TRANSPORT_API_ALIAS: Record = { @@ -20,6 +22,7 @@ const SIMPLE_TRANSPORT_API_ALIAS: Record = { "openai-completions": "openclaw-openai-completions-transport", "azure-openai-responses": "openclaw-azure-openai-responses-transport", "anthropic-messages": "openclaw-anthropic-messages-transport", + "google-generative-ai": "openclaw-google-generative-ai-transport", }; function hasTransportOverrides(model: Model): boolean { @@ -53,6 +56,8 @@ export function createTransportAwareStreamFnForModel(model: Model): StreamF return createAzureOpenAIResponsesTransportStreamFn(); case "anthropic-messages": return createAnthropicMessagesTransportStreamFn(); + case "google-generative-ai": + return createGoogleGenerativeAiTransportStreamFn(); default: return undefined; }