mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-08 15:51:06 +00:00
feat(providers): add google transport runtime
This commit is contained in:
249
src/agents/google-transport-stream.test.ts
Normal file
249
src/agents/google-transport-stream.test.ts
Normal file
@@ -0,0 +1,249 @@
|
||||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { attachModelProviderRequestTransport } from "./provider-request-config.js";
|
||||
|
||||
// Hoisted so the vi.mock factory below can reference these mocks before the
// mocked module graph is evaluated.
const { buildGuardedModelFetchMock, guardedFetchMock } = vi.hoisted(() => ({
  buildGuardedModelFetchMock: vi.fn(),
  guardedFetchMock: vi.fn(),
}));

// Replace the guarded-fetch factory so no real network transport is built.
vi.mock("./provider-transport-fetch.js", () => ({
  buildGuardedModelFetch: buildGuardedModelFetchMock,
}));

// Imported lazily in beforeAll so the vi.mock above is applied first.
let buildGoogleGenerativeAiParams: typeof import("./google-transport-stream.js").buildGoogleGenerativeAiParams;
let createGoogleGenerativeAiTransportStreamFn: typeof import("./google-transport-stream.js").createGoogleGenerativeAiTransportStreamFn;
|
||||
function buildSseResponse(events: unknown[]): Response {
|
||||
const sse = `${events.map((event) => `data: ${JSON.stringify(event)}\n\n`).join("")}data: [DONE]\n\n`;
|
||||
const encoder = new TextEncoder();
|
||||
const body = new ReadableStream<Uint8Array>({
|
||||
start(controller) {
|
||||
controller.enqueue(encoder.encode(sse));
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
return new Response(body, {
|
||||
status: 200,
|
||||
headers: { "content-type": "text/event-stream" },
|
||||
});
|
||||
}
|
||||
|
||||
describe("google transport stream", () => {
  beforeAll(async () => {
    // Import after vi.mock has replaced the guarded-fetch module.
    ({ buildGoogleGenerativeAiParams, createGoogleGenerativeAiTransportStreamFn } =
      await import("./google-transport-stream.js"));
  });

  beforeEach(() => {
    buildGuardedModelFetchMock.mockReset();
    guardedFetchMock.mockReset();
    buildGuardedModelFetchMock.mockReturnValue(guardedFetchMock);
  });

  it("uses the guarded fetch transport and parses Gemini SSE output", async () => {
    // Single SSE chunk carrying a thought, a text part, a function call,
    // a finish reason, and usage metadata.
    guardedFetchMock.mockResolvedValueOnce(
      buildSseResponse([
        {
          responseId: "resp_1",
          candidates: [
            {
              content: {
                parts: [
                  { thought: true, text: "draft", thoughtSignature: "sig_1" },
                  { text: "answer" },
                  { functionCall: { name: "lookup", args: { q: "hello" } } },
                ],
              },
              finishReason: "STOP",
            },
          ],
          usageMetadata: {
            promptTokenCount: 10,
            cachedContentTokenCount: 2,
            candidatesTokenCount: 5,
            thoughtsTokenCount: 3,
            totalTokenCount: 18,
          },
        },
      ]),
    );

    const model = attachModelProviderRequestTransport(
      {
        id: "gemini-3.1-pro-preview",
        name: "Gemini 3.1 Pro Preview",
        api: "google-generative-ai",
        provider: "google",
        baseUrl: "https://generativelanguage.googleapis.com",
        reasoning: true,
        input: ["text"],
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        contextWindow: 128000,
        maxTokens: 8192,
        headers: { "X-Provider": "google" },
      } satisfies Model<"google-generative-ai">,
      {
        proxy: {
          mode: "explicit-proxy",
          url: "http://proxy.internal:8443",
        },
      },
    );

    const streamFn = createGoogleGenerativeAiTransportStreamFn();
    const stream = await Promise.resolve(
      streamFn(
        model,
        {
          systemPrompt: "Follow policy.",
          messages: [{ role: "user", content: "hello", timestamp: 0 }],
          tools: [
            {
              name: "lookup",
              description: "Look up a value",
              parameters: {
                type: "object",
                properties: { q: { type: "string" } },
                required: ["q"],
              },
            },
          ],
        } as unknown as Parameters<typeof streamFn>[1],
        {
          apiKey: "gemini-api-key",
          reasoning: "medium",
          toolChoice: "auto",
        } as Parameters<typeof streamFn>[2],
      ),
    );
    const result = await stream.result();

    // The transport must route through the guarded fetch and hit the
    // normalized /v1beta streaming endpoint with SSE framing.
    expect(buildGuardedModelFetchMock).toHaveBeenCalledWith(model);
    expect(guardedFetchMock).toHaveBeenCalledWith(
      "https://generativelanguage.googleapis.com/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent?alt=sse",
      expect.objectContaining({
        method: "POST",
        headers: expect.objectContaining({
          accept: "text/event-stream",
          "Content-Type": "application/json",
          "x-goog-api-key": "gemini-api-key",
          "X-Provider": "google",
        }),
      }),
    );

    // Inspect the serialized request payload.
    const init = guardedFetchMock.mock.calls[0]?.[1] as RequestInit;
    const requestBody = init.body;
    if (typeof requestBody !== "string") {
      throw new Error("Expected Google transport request body to be serialized JSON");
    }
    const payload = JSON.parse(requestBody) as Record<string, unknown>;
    expect(payload.systemInstruction).toEqual({
      parts: [{ text: "Follow policy." }],
    });
    // Gemini 3 Pro collapses "medium" effort to thinkingLevel HIGH.
    expect(payload.generationConfig).toMatchObject({
      thinkingConfig: { includeThoughts: true, thinkingLevel: "HIGH" },
    });
    expect(payload.toolConfig).toMatchObject({
      functionCallingConfig: { mode: "AUTO" },
    });
    // usage.input = prompt (10) - cached (2); output = candidates (5) + thoughts (3).
    expect(result).toMatchObject({
      api: "google-generative-ai",
      provider: "google",
      responseId: "resp_1",
      stopReason: "toolUse",
      usage: {
        input: 8,
        output: 8,
        cacheRead: 2,
        totalTokens: 18,
      },
      content: [
        { type: "thinking", thinking: "draft", thinkingSignature: "sig_1" },
        { type: "text", text: "answer" },
        { type: "toolCall", name: "lookup", arguments: { q: "hello" } },
      ],
    });
  });

  it("uses bearer auth when the Google api key is an OAuth JSON payload", async () => {
    guardedFetchMock.mockResolvedValueOnce(buildSseResponse([]));

    const model = attachModelProviderRequestTransport(
      {
        id: "gemini-3-flash-preview",
        name: "Gemini 3 Flash Preview",
        api: "google-generative-ai",
        provider: "custom-google",
        baseUrl: "https://generativelanguage.googleapis.com/v1beta",
        reasoning: false,
        input: ["text"],
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        contextWindow: 128000,
        maxTokens: 8192,
      } satisfies Model<"google-generative-ai">,
      {
        tls: {
          ca: "ca-pem",
        },
      },
    );

    const streamFn = createGoogleGenerativeAiTransportStreamFn();
    const stream = await Promise.resolve(
      streamFn(
        model,
        {
          messages: [{ role: "user", content: "hello", timestamp: 0 }],
        } as Parameters<typeof streamFn>[1],
        {
          // An OAuth JSON payload in place of a plain key must switch auth to
          // an Authorization bearer header instead of x-goog-api-key.
          apiKey: JSON.stringify({ token: "oauth-token", projectId: "demo" }),
        } as Parameters<typeof streamFn>[2],
      ),
    );
    await stream.result();

    expect(guardedFetchMock).toHaveBeenCalledWith(
      expect.any(String),
      expect.objectContaining({
        headers: expect.objectContaining({
          Authorization: "Bearer oauth-token",
          "Content-Type": "application/json",
        }),
      }),
    );
  });

  it("builds direct Gemini payloads without negative fallback thinking budgets", () => {
    const model = {
      id: "custom-gemini-model",
      name: "Custom Gemini",
      api: "google-generative-ai",
      provider: "custom-google",
      baseUrl: "https://proxy.example.com/gemini/v1beta",
      reasoning: true,
      input: ["text"],
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
      contextWindow: 128000,
      maxTokens: 8192,
    } satisfies Model<"google-generative-ai">;

    const params = buildGoogleGenerativeAiParams(
      model,
      {
        messages: [{ role: "user", content: "hello", timestamp: 0 }],
      } as never,
      {
        reasoning: "medium",
      },
    );

    // Unknown model families get includeThoughts only — never a sentinel
    // thinkingBudget of -1.
    expect(params.generationConfig).toMatchObject({
      thinkingConfig: { includeThoughts: true },
    });
    expect(params.generationConfig).not.toMatchObject({
      thinkingConfig: { thinkingBudget: -1 },
    });
  });
});
|
||||
744
src/agents/google-transport-stream.ts
Normal file
744
src/agents/google-transport-stream.ts
Normal file
@@ -0,0 +1,744 @@
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import {
|
||||
calculateCost,
|
||||
createAssistantMessageEventStream,
|
||||
getEnvApiKey,
|
||||
type Context,
|
||||
type Model,
|
||||
type SimpleStreamOptions,
|
||||
type ThinkingLevel,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { parseGeminiAuth } from "../infra/gemini-auth.js";
|
||||
import { normalizeGoogleApiBaseUrl } from "../infra/google-api-base-url.js";
|
||||
import { sanitizeTransportPayloadText } from "./openai-transport-stream.js";
|
||||
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
|
||||
import { transformTransportMessages } from "./transport-message-transform.js";
|
||||
|
||||
/**
 * Model shape used by this transport: a google-generative-ai model plus the
 * optional per-model headers merged into each request.
 */
type GoogleTransportModel = Model<"google-generative-ai"> & {
  headers?: Record<string, string>;
  provider: string;
};

/** Gemini thinkingLevel enum values accepted in generationConfig.thinkingConfig. */
type GoogleThinkingLevel = "MINIMAL" | "LOW" | "MEDIUM" | "HIGH";

/**
 * Stream options understood by this transport on top of SimpleStreamOptions:
 * an OpenAI-style tool choice and an explicit thinking override that takes
 * precedence over the generic reasoning effort.
 */
type GoogleTransportOptions = SimpleStreamOptions & {
  toolChoice?:
    | "auto"
    | "none"
    | "any"
    | "required"
    | {
        type: "function";
        function: {
          name: string;
        };
      };
  thinking?: {
    enabled: boolean;
    budgetTokens?: number;
    level?: GoogleThinkingLevel;
  };
};
|
||||
|
||||
/** Request body for models/<id>:streamGenerateContent. */
type GoogleGenerateContentRequest = {
  contents: Array<Record<string, unknown>>;
  generationConfig?: Record<string, unknown>;
  systemInstruction?: Record<string, unknown>;
  tools?: Array<Record<string, unknown>>;
  toolConfig?: Record<string, unknown>;
};

/** Assistant content blocks accumulated while streaming. */
type GoogleTransportContentBlock =
  | { type: "text"; text: string; textSignature?: string }
  | { type: "thinking"; thinking: string; thinkingSignature?: string }
  | { type: "toolCall"; id: string; name: string; arguments: Record<string, unknown> };

/**
 * Mutable in-progress assistant message; the same object is mirrored into
 * every pushed stream event as `partial` and finally as the done message.
 */
type MutableAssistantOutput = {
  role: "assistant";
  content: Array<GoogleTransportContentBlock>;
  api: "google-generative-ai";
  provider: string;
  model: string;
  usage: {
    input: number;
    output: number;
    cacheRead: number;
    cacheWrite: number;
    totalTokens: number;
    cost: { input: number; output: number; cacheRead: number; cacheWrite: number; total: number };
  };
  stopReason: string;
  timestamp: number;
  responseId?: string;
  errorMessage?: string;
};

/** Subset of a Gemini SSE chunk this transport reads; all fields optional. */
type GoogleSseChunk = {
  responseId?: string;
  candidates?: Array<{
    content?: {
      parts?: Array<{
        text?: string;
        thought?: boolean;
        thoughtSignature?: string;
        functionCall?: {
          id?: string;
          name?: string;
          args?: Record<string, unknown>;
        };
      }>;
    };
    finishReason?: string;
  }>;
  usageMetadata?: {
    promptTokenCount?: number;
    cachedContentTokenCount?: number;
    candidatesTokenCount?: number;
    thoughtsTokenCount?: number;
    totalTokenCount?: number;
  };
};
|
||||
|
||||
// Monotonic counter used to mint unique fallback tool-call ids when a chunk
// omits an id or repeats one already seen in this message.
let toolCallCounter = 0;

/**
 * Sanitizes text destined for the Google payload. Delegates to the shared
 * OpenAI transport sanitizer so every transport strips the same content.
 */
function sanitizeGoogleTransportText(text: string): string {
  return sanitizeTransportPayloadText(text);
}
|
||||
|
||||
function isGemini3ProModel(modelId: string): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(modelId.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(modelId: string): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(modelId.toLowerCase());
|
||||
}
|
||||
|
||||
function requiresToolCallId(modelId: string): boolean {
|
||||
return modelId.startsWith("claude-") || modelId.startsWith("gpt-oss-");
|
||||
}
|
||||
|
||||
function supportsMultimodalFunctionResponse(modelId: string): boolean {
|
||||
const match = modelId.toLowerCase().match(/^gemini(?:-live)?-(\d+)/);
|
||||
if (!match) {
|
||||
return true;
|
||||
}
|
||||
return Number.parseInt(match[1] ?? "", 10) >= 3;
|
||||
}
|
||||
|
||||
function retainThoughtSignature(existing: string | undefined, incoming: string | undefined) {
|
||||
if (typeof incoming === "string" && incoming.length > 0) {
|
||||
return incoming;
|
||||
}
|
||||
return existing;
|
||||
}
|
||||
|
||||
function mapToolChoice(
|
||||
choice: GoogleTransportOptions["toolChoice"],
|
||||
): { mode: "AUTO" | "NONE" | "ANY"; allowedFunctionNames?: string[] } | undefined {
|
||||
if (!choice) {
|
||||
return undefined;
|
||||
}
|
||||
if (typeof choice === "object" && choice.type === "function") {
|
||||
return { mode: "ANY", allowedFunctionNames: [choice.function.name] };
|
||||
}
|
||||
switch (choice) {
|
||||
case "none":
|
||||
return { mode: "NONE" };
|
||||
case "any":
|
||||
case "required":
|
||||
return { mode: "ANY" };
|
||||
default:
|
||||
return { mode: "AUTO" };
|
||||
}
|
||||
}
|
||||
|
||||
function mapStopReasonString(reason: string): "stop" | "length" | "error" {
|
||||
switch (reason) {
|
||||
case "STOP":
|
||||
return "stop";
|
||||
case "MAX_TOKENS":
|
||||
return "length";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeToolCallId(id: string): string {
|
||||
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
|
||||
}
|
||||
|
||||
function resolveGoogleModelPath(modelId: string): string {
|
||||
if (modelId.startsWith("models/") || modelId.startsWith("tunedModels/")) {
|
||||
return modelId;
|
||||
}
|
||||
return `models/${modelId}`;
|
||||
}
|
||||
|
||||
/**
 * Builds the streaming generateContent URL: the configured base URL is
 * normalized first (normalizeGoogleApiBaseUrl — presumably adds/keeps the
 * /v1beta segment; see the infra helper), and alt=sse requests SSE framing.
 */
function buildGoogleRequestUrl(model: GoogleTransportModel): string {
  const baseUrl = normalizeGoogleApiBaseUrl(model.baseUrl);
  return `${baseUrl}/${resolveGoogleModelPath(model.id)}:streamGenerateContent?alt=sse`;
}
|
||||
|
||||
function resolveThinkingLevel(level: ThinkingLevel, modelId: string): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(modelId)) {
|
||||
switch (level) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
case "xhigh":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (level) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
case "xhigh":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getDisabledThinkingConfig(modelId: string): Record<string, unknown> {
|
||||
if (isGemini3ProModel(modelId)) {
|
||||
return { thinkingLevel: "LOW" };
|
||||
}
|
||||
if (isGemini3FlashModel(modelId)) {
|
||||
return { thinkingLevel: "MINIMAL" };
|
||||
}
|
||||
return { thinkingBudget: 0 };
|
||||
}
|
||||
|
||||
function getGoogleThinkingBudget(
|
||||
modelId: string,
|
||||
effort: ThinkingLevel,
|
||||
customBudgets?: GoogleTransportOptions["thinkingBudgets"],
|
||||
): number | undefined {
|
||||
const normalizedEffort = effort === "xhigh" ? "high" : effort;
|
||||
if (customBudgets?.[normalizedEffort] !== undefined) {
|
||||
return customBudgets[normalizedEffort];
|
||||
}
|
||||
if (modelId.includes("2.5-pro")) {
|
||||
return { minimal: 128, low: 2048, medium: 8192, high: 32768 }[normalizedEffort];
|
||||
}
|
||||
if (modelId.includes("2.5-flash")) {
|
||||
return { minimal: 128, low: 2048, medium: 8192, high: 24576 }[normalizedEffort];
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function resolveGoogleThinkingConfig(
|
||||
model: GoogleTransportModel,
|
||||
options: GoogleTransportOptions | undefined,
|
||||
): Record<string, unknown> | undefined {
|
||||
if (!model.reasoning) {
|
||||
return undefined;
|
||||
}
|
||||
if (options?.thinking) {
|
||||
if (!options.thinking.enabled) {
|
||||
return getDisabledThinkingConfig(model.id);
|
||||
}
|
||||
const config: Record<string, unknown> = { includeThoughts: true };
|
||||
if (options.thinking.level) {
|
||||
config.thinkingLevel = options.thinking.level;
|
||||
} else if (typeof options.thinking.budgetTokens === "number") {
|
||||
config.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
return config;
|
||||
}
|
||||
if (!options?.reasoning) {
|
||||
return getDisabledThinkingConfig(model.id);
|
||||
}
|
||||
if (isGemini3ProModel(model.id) || isGemini3FlashModel(model.id)) {
|
||||
return {
|
||||
includeThoughts: true,
|
||||
thinkingLevel: resolveThinkingLevel(options.reasoning, model.id),
|
||||
};
|
||||
}
|
||||
const budget = getGoogleThinkingBudget(model.id, options.reasoning, options.thinkingBudgets);
|
||||
return {
|
||||
includeThoughts: true,
|
||||
...(typeof budget === "number" ? { thinkingBudget: budget } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Converts transport messages into Gemini `contents` entries.
 *
 * User messages become `user` turns (text sanitized, images as inlineData and
 * dropped for text-only models); assistant messages become `model` turns;
 * tool results become `user` turns carrying functionResponse parts, with
 * consecutive results batched into one turn. Thought signatures are only
 * replayed to the exact provider/model pair that produced them.
 */
function convertGoogleMessages(model: GoogleTransportModel, context: Context) {
  const contents: Array<Record<string, unknown>> = [];
  // Normalize tool-call ids up front for backends (claude-*, gpt-oss-*) that
  // require ids restricted to [A-Za-z0-9_-].
  const transformedMessages = transformTransportMessages(context.messages, model, (id) =>
    requiresToolCallId(model.id) ? normalizeToolCallId(id) : id,
  );
  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        contents.push({
          role: "user",
          parts: [{ text: sanitizeGoogleTransportText(msg.content) }],
        });
        continue;
      }
      // Structured user content: sanitize text parts, map the rest to
      // inlineData; drop image parts when the model lacks image input.
      const parts = msg.content
        .map((item) =>
          item.type === "text"
            ? { text: sanitizeGoogleTransportText(item.text) }
            : {
                inlineData: {
                  mimeType: item.mimeType,
                  data: item.data,
                },
              },
        )
        .filter((item) => model.input.includes("image") || !("inlineData" in item));
      if (parts.length > 0) {
        contents.push({ role: "user", parts });
      }
      continue;
    }

    if (msg.role === "assistant") {
      // Signatures are only valid for the model that emitted them.
      const isSameProviderAndModel = msg.provider === model.provider && msg.model === model.id;
      const parts: Array<Record<string, unknown>> = [];
      for (const block of msg.content) {
        if (block.type === "text") {
          if (!block.text.trim()) {
            continue; // skip whitespace-only text blocks
          }
          parts.push({
            text: sanitizeGoogleTransportText(block.text),
            ...(isSameProviderAndModel && block.textSignature
              ? { thoughtSignature: block.textSignature }
              : {}),
          });
          continue;
        }
        if (block.type === "thinking") {
          if (!block.thinking.trim()) {
            continue; // skip whitespace-only thinking blocks
          }
          if (isSameProviderAndModel) {
            parts.push({
              thought: true,
              text: sanitizeGoogleTransportText(block.thinking),
              ...(block.thinkingSignature ? { thoughtSignature: block.thinkingSignature } : {}),
            });
          } else {
            // Foreign thinking is downgraded to a plain text part.
            parts.push({ text: sanitizeGoogleTransportText(block.thinking) });
          }
          continue;
        }
        if (block.type === "toolCall") {
          parts.push({
            functionCall: {
              name: block.name,
              args: block.arguments ?? {},
              ...(requiresToolCallId(model.id) ? { id: block.id } : {}),
            },
            ...(isSameProviderAndModel && block.thoughtSignature
              ? { thoughtSignature: block.thoughtSignature }
              : {}),
          });
        }
      }
      if (parts.length > 0) {
        contents.push({ role: "model", parts });
      }
      continue;
    }

    if (msg.role === "toolResult") {
      // Collapse all text parts into one string; keep image parts only when
      // the model accepts image input.
      const textResult = msg.content
        .filter(
          (item): item is Extract<(typeof msg.content)[number], { type: "text" }> =>
            item.type === "text",
        )
        .map((item) => item.text)
        .join("\n");
      const imageContent = model.input.includes("image")
        ? msg.content.filter(
            (item): item is Extract<(typeof msg.content)[number], { type: "image" }> =>
              item.type === "image",
          )
        : [];
      // Image-only results still need a textual placeholder in the response.
      const responseValue = textResult
        ? sanitizeGoogleTransportText(textResult)
        : imageContent.length > 0
          ? "(see attached image)"
          : "";
      const imageParts = imageContent.map((imageBlock) => ({
        inlineData: {
          mimeType: imageBlock.mimeType,
          data: imageBlock.data,
        },
      }));
      const functionResponse = {
        functionResponse: {
          name: msg.toolName,
          response: msg.isError ? { error: responseValue } : { output: responseValue },
          // Gemini >= 3 accepts inline image parts inside the response itself.
          ...(supportsMultimodalFunctionResponse(model.id) && imageParts.length > 0
            ? { parts: imageParts }
            : {}),
          ...(requiresToolCallId(model.id) ? { id: msg.toolCallId } : {}),
        },
      };
      // Batch consecutive tool results into a single user turn.
      const last = contents[contents.length - 1];
      if (
        last?.role === "user" &&
        Array.isArray(last.parts) &&
        last.parts.some((part) => "functionResponse" in part)
      ) {
        (last.parts as Array<Record<string, unknown>>).push(functionResponse);
      } else {
        contents.push({ role: "user", parts: [functionResponse] });
      }
      // Older models cannot carry images inside functionResponse, so attach
      // them as a follow-up user message instead.
      if (imageParts.length > 0 && !supportsMultimodalFunctionResponse(model.id)) {
        contents.push({ role: "user", parts: [{ text: "Tool result image:" }, ...imageParts] });
      }
    }
  }
  return contents;
}
|
||||
|
||||
function convertGoogleTools(tools: NonNullable<Context["tools"]>) {
|
||||
if (tools.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
return [
|
||||
{
|
||||
functionDeclarations: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parametersJsonSchema: tool.parameters,
|
||||
})),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
export function buildGoogleGenerativeAiParams(
|
||||
model: GoogleTransportModel,
|
||||
context: Context,
|
||||
options?: GoogleTransportOptions,
|
||||
): GoogleGenerateContentRequest {
|
||||
const generationConfig: Record<string, unknown> = {};
|
||||
if (typeof options?.temperature === "number") {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (typeof options?.maxTokens === "number") {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
const thinkingConfig = resolveGoogleThinkingConfig(model, options);
|
||||
if (thinkingConfig) {
|
||||
generationConfig.thinkingConfig = thinkingConfig;
|
||||
}
|
||||
|
||||
const params: GoogleGenerateContentRequest = {
|
||||
contents: convertGoogleMessages(model, context),
|
||||
};
|
||||
if (Object.keys(generationConfig).length > 0) {
|
||||
params.generationConfig = generationConfig;
|
||||
}
|
||||
if (context.systemPrompt) {
|
||||
params.systemInstruction = {
|
||||
parts: [{ text: sanitizeGoogleTransportText(context.systemPrompt) }],
|
||||
};
|
||||
}
|
||||
if (context.tools?.length) {
|
||||
params.tools = convertGoogleTools(context.tools);
|
||||
const toolChoice = mapToolChoice(options?.toolChoice);
|
||||
if (toolChoice) {
|
||||
params.toolConfig = {
|
||||
functionCallingConfig: toolChoice,
|
||||
};
|
||||
}
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
function buildGoogleHeaders(
|
||||
model: GoogleTransportModel,
|
||||
apiKey: string | undefined,
|
||||
optionHeaders: Record<string, string> | undefined,
|
||||
): Record<string, string> {
|
||||
const authHeaders = apiKey ? parseGeminiAuth(apiKey).headers : undefined;
|
||||
return {
|
||||
accept: "text/event-stream",
|
||||
...authHeaders,
|
||||
...model.headers,
|
||||
...optionHeaders,
|
||||
};
|
||||
}
|
||||
|
||||
async function* parseGoogleSseChunks(
|
||||
response: Response,
|
||||
signal?: AbortSignal,
|
||||
): AsyncGenerator<GoogleSseChunk> {
|
||||
if (!response.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
const abortHandler = () => {
|
||||
void reader.cancel().catch(() => undefined);
|
||||
};
|
||||
signal?.addEventListener("abort", abortHandler);
|
||||
try {
|
||||
while (true) {
|
||||
if (signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
buffer += decoder.decode(value, { stream: true }).replace(/\r/g, "");
|
||||
let boundary = buffer.indexOf("\n\n");
|
||||
while (boundary >= 0) {
|
||||
const rawEvent = buffer.slice(0, boundary);
|
||||
buffer = buffer.slice(boundary + 2);
|
||||
boundary = buffer.indexOf("\n\n");
|
||||
const data = rawEvent
|
||||
.split("\n")
|
||||
.filter((line) => line.startsWith("data:"))
|
||||
.map((line) => line.slice(5).trim())
|
||||
.join("\n");
|
||||
if (!data || data === "[DONE]") {
|
||||
continue;
|
||||
}
|
||||
yield JSON.parse(data) as GoogleSseChunk;
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
signal?.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
}
|
||||
|
||||
function updateUsage(
|
||||
output: MutableAssistantOutput,
|
||||
model: GoogleTransportModel,
|
||||
chunk: GoogleSseChunk,
|
||||
) {
|
||||
const usage = chunk.usageMetadata;
|
||||
if (!usage) {
|
||||
return;
|
||||
}
|
||||
const promptTokens = usage.promptTokenCount || 0;
|
||||
const cacheRead = usage.cachedContentTokenCount || 0;
|
||||
output.usage = {
|
||||
input: Math.max(0, promptTokens - cacheRead),
|
||||
output: (usage.candidatesTokenCount || 0) + (usage.thoughtsTokenCount || 0),
|
||||
cacheRead,
|
||||
cacheWrite: 0,
|
||||
totalTokens: usage.totalTokenCount || 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
|
||||
function pushTextBlockEnd(
|
||||
stream: { push(event: unknown): void },
|
||||
output: MutableAssistantOutput,
|
||||
blockIndex: number,
|
||||
) {
|
||||
const block = output.content[blockIndex];
|
||||
if (!block) {
|
||||
return;
|
||||
}
|
||||
if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex,
|
||||
content: block.thinking,
|
||||
partial: output as never,
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex,
|
||||
content: block.text,
|
||||
partial: output as never,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Creates a StreamFn that talks to the Gemini streamGenerateContent SSE
 * endpoint directly through the guarded (transport-aware) fetch built for
 * the model. Emits start / text / thinking / toolcall / done (or error)
 * events while accumulating the assistant message, which every event carries
 * as `partial`.
 */
export function createGoogleGenerativeAiTransportStreamFn(): StreamFn {
  return (rawModel, context, rawOptions) => {
    const model = rawModel as GoogleTransportModel;
    const options = rawOptions as GoogleTransportOptions | undefined;
    const eventStream = createAssistantMessageEventStream();
    const stream = eventStream as unknown as { push(event: unknown): void; end(): void };
    void (async () => {
      // Mutable in-progress message; mirrored into every pushed event.
      const output: MutableAssistantOutput = {
        role: "assistant",
        content: [],
        api: "google-generative-ai",
        provider: model.provider,
        model: model.id,
        usage: {
          input: 0,
          output: 0,
          cacheRead: 0,
          cacheWrite: 0,
          totalTokens: 0,
          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
        },
        stopReason: "stop",
        timestamp: Date.now(),
      };
      try {
        const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? undefined;
        const fetch = buildGuardedModelFetch(model);
        let params = buildGoogleGenerativeAiParams(model, context, options);
        // Give callers a chance to rewrite the outgoing payload.
        const nextParams = await options?.onPayload?.(params, model);
        if (nextParams !== undefined) {
          params = nextParams as GoogleGenerateContentRequest;
        }
        const response = await fetch(buildGoogleRequestUrl(model), {
          method: "POST",
          headers: buildGoogleHeaders(model, apiKey, options?.headers),
          body: JSON.stringify(params),
          signal: options?.signal,
        });
        if (!response.ok) {
          const message = await response.text().catch(() => "");
          throw new Error(`Google Generative AI API error (${response.status}): ${message}`);
        }
        stream.push({ type: "start", partial: output as never });
        // Index of the text/thinking block currently being streamed; -1 when none.
        let currentBlockIndex = -1;
        for await (const chunk of parseGoogleSseChunks(response, options?.signal)) {
          output.responseId ||= chunk.responseId;
          updateUsage(output, model, chunk);
          // Only the first candidate is consumed.
          const candidate = chunk.candidates?.[0];
          if (candidate?.content?.parts) {
            for (const part of candidate.content.parts) {
              if (typeof part.text === "string") {
                const isThinking = part.thought === true;
                const currentBlock = output.content[currentBlockIndex];
                // Open a new block when none is active or the kind switches
                // between thinking and text, closing the previous one first.
                if (
                  currentBlockIndex < 0 ||
                  !currentBlock ||
                  (isThinking && currentBlock.type !== "thinking") ||
                  (!isThinking && currentBlock.type !== "text")
                ) {
                  if (currentBlockIndex >= 0) {
                    pushTextBlockEnd(stream, output, currentBlockIndex);
                  }
                  if (isThinking) {
                    output.content.push({ type: "thinking", thinking: "" });
                    currentBlockIndex = output.content.length - 1;
                    stream.push({
                      type: "thinking_start",
                      contentIndex: currentBlockIndex,
                      partial: output as never,
                    });
                  } else {
                    output.content.push({ type: "text", text: "" });
                    currentBlockIndex = output.content.length - 1;
                    stream.push({
                      type: "text_start",
                      contentIndex: currentBlockIndex,
                      partial: output as never,
                    });
                  }
                }
                const activeBlock = output.content[currentBlockIndex];
                if (activeBlock?.type === "thinking") {
                  activeBlock.thinking += part.text;
                  // Keep the latest non-empty signature for later replay.
                  activeBlock.thinkingSignature = retainThoughtSignature(
                    activeBlock.thinkingSignature,
                    part.thoughtSignature,
                  );
                  stream.push({
                    type: "thinking_delta",
                    contentIndex: currentBlockIndex,
                    delta: part.text,
                    partial: output as never,
                  });
                } else if (activeBlock?.type === "text") {
                  activeBlock.text += part.text;
                  activeBlock.textSignature = retainThoughtSignature(
                    activeBlock.textSignature,
                    part.thoughtSignature,
                  );
                  stream.push({
                    type: "text_delta",
                    contentIndex: currentBlockIndex,
                    delta: part.text,
                    partial: output as never,
                  });
                }
              }
              if (part.functionCall) {
                // A function call terminates any open text/thinking block.
                if (currentBlockIndex >= 0) {
                  pushTextBlockEnd(stream, output, currentBlockIndex);
                  currentBlockIndex = -1;
                }
                const providedId = part.functionCall.id;
                const isDuplicate = output.content.some(
                  (block) => block.type === "toolCall" && block.id === providedId,
                );
                // Mint a unique id when none was provided or it collides with
                // an earlier call in this message.
                const toolCallId =
                  providedId && !isDuplicate
                    ? providedId
                    : `${part.functionCall.name || "tool"}_${Date.now()}_${++toolCallCounter}`;
                const toolCall: GoogleTransportContentBlock = {
                  type: "toolCall",
                  id: toolCallId,
                  name: part.functionCall.name || "",
                  arguments: part.functionCall.args ?? {},
                };
                output.content.push(toolCall);
                const blockIndex = output.content.length - 1;
                // Gemini delivers complete calls, so start/delta/end are
                // emitted back-to-back.
                stream.push({
                  type: "toolcall_start",
                  contentIndex: blockIndex,
                  partial: output as never,
                });
                stream.push({
                  type: "toolcall_delta",
                  contentIndex: blockIndex,
                  delta: JSON.stringify(toolCall.arguments),
                  partial: output as never,
                });
                stream.push({
                  type: "toolcall_end",
                  contentIndex: blockIndex,
                  toolCall,
                  partial: output as never,
                });
              }
            }
          }
          if (typeof candidate?.finishReason === "string") {
            output.stopReason = mapStopReasonString(candidate.finishReason);
            // Any emitted tool call overrides the reported finish reason.
            if (output.content.some((block) => block.type === "toolCall")) {
              output.stopReason = "toolUse";
            }
          }
        }
        // Close a still-open text/thinking block at end of stream.
        if (currentBlockIndex >= 0) {
          pushTextBlockEnd(stream, output, currentBlockIndex);
        }
        if (options?.signal?.aborted) {
          throw new Error("Request was aborted");
        }
        if (output.stopReason === "aborted" || output.stopReason === "error") {
          throw new Error("An unknown error occurred");
        }
        stream.push({ type: "done", reason: output.stopReason as never, message: output as never });
        stream.end();
      } catch (error) {
        output.stopReason = options?.signal?.aborted ? "aborted" : "error";
        output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
        stream.push({ type: "error", reason: output.stopReason as never, error: output as never });
        stream.end();
      }
    })();
    return eventStream as unknown as ReturnType<StreamFn>;
  };
}
|
||||
@@ -21,6 +21,7 @@ describe("openai transport stream", () => {
|
||||
expect(isTransportAwareApiSupported("openai-completions")).toBe(true);
|
||||
expect(isTransportAwareApiSupported("azure-openai-responses")).toBe(true);
|
||||
expect(isTransportAwareApiSupported("anthropic-messages")).toBe(true);
|
||||
expect(isTransportAwareApiSupported("google-generative-ai")).toBe(true);
|
||||
});
|
||||
|
||||
it("prepares a custom simple-completion api alias when transport overrides are attached", () => {
|
||||
@@ -89,6 +90,103 @@ describe("openai transport stream", () => {
|
||||
expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function");
|
||||
});
|
||||
|
||||
it("prepares a Google simple-completion api alias when transport overrides are attached", () => {
|
||||
const model = attachModelProviderRequestTransport(
|
||||
{
|
||||
id: "gemini-3.1-pro-preview",
|
||||
name: "Gemini 3.1 Pro Preview",
|
||||
api: "google-generative-ai",
|
||||
provider: "google",
|
||||
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 8192,
|
||||
} satisfies Model<"google-generative-ai">,
|
||||
{
|
||||
proxy: {
|
||||
mode: "explicit-proxy",
|
||||
url: "http://proxy.internal:8443",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const prepared = prepareTransportAwareSimpleModel(model);
|
||||
|
||||
expect(resolveTransportAwareSimpleApi(model.api)).toBe(
|
||||
"openclaw-google-generative-ai-transport",
|
||||
);
|
||||
expect(prepared).toMatchObject({
|
||||
api: "openclaw-google-generative-ai-transport",
|
||||
provider: "google",
|
||||
id: "gemini-3.1-pro-preview",
|
||||
});
|
||||
expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function");
|
||||
});
|
||||
|
||||
it("keeps github-copilot OpenAI-family models on the shared transport seam", () => {
|
||||
const model = attachModelProviderRequestTransport(
|
||||
{
|
||||
id: "gpt-5.4",
|
||||
name: "GPT-5.4",
|
||||
api: "openai-responses",
|
||||
provider: "github-copilot",
|
||||
baseUrl: "https://api.githubcopilot.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 8192,
|
||||
} satisfies Model<"openai-responses">,
|
||||
{
|
||||
proxy: {
|
||||
mode: "explicit-proxy",
|
||||
url: "http://proxy.internal:8443",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(resolveTransportAwareSimpleApi(model.api)).toBe("openclaw-openai-responses-transport");
|
||||
expect(prepareTransportAwareSimpleModel(model)).toMatchObject({
|
||||
api: "openclaw-openai-responses-transport",
|
||||
provider: "github-copilot",
|
||||
id: "gpt-5.4",
|
||||
});
|
||||
expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function");
|
||||
});
|
||||
|
||||
it("keeps github-copilot Claude models on the shared Anthropic transport seam", () => {
|
||||
const model = attachModelProviderRequestTransport(
|
||||
{
|
||||
id: "claude-sonnet-4.6",
|
||||
name: "Claude Sonnet 4.6",
|
||||
api: "anthropic-messages",
|
||||
provider: "github-copilot",
|
||||
baseUrl: "https://api.githubcopilot.com/anthropic",
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 8192,
|
||||
} satisfies Model<"anthropic-messages">,
|
||||
{
|
||||
proxy: {
|
||||
mode: "explicit-proxy",
|
||||
url: "http://proxy.internal:8443",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(resolveTransportAwareSimpleApi(model.api)).toBe("openclaw-anthropic-messages-transport");
|
||||
expect(prepareTransportAwareSimpleModel(model)).toMatchObject({
|
||||
api: "openclaw-anthropic-messages-transport",
|
||||
provider: "github-copilot",
|
||||
id: "claude-sonnet-4.6",
|
||||
});
|
||||
expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function");
|
||||
});
|
||||
|
||||
it("removes unpaired surrogate code units but preserves valid surrogate pairs", () => {
|
||||
const high = String.fromCharCode(0xd83d);
|
||||
const low = String.fromCharCode(0xdc00);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import type { Api, Model } from "@mariozechner/pi-ai";
|
||||
import { createAnthropicMessagesTransportStreamFn } from "./anthropic-transport-stream.js";
|
||||
import { createGoogleGenerativeAiTransportStreamFn } from "./google-transport-stream.js";
|
||||
import {
|
||||
createAzureOpenAIResponsesTransportStreamFn,
|
||||
createOpenAICompletionsTransportStreamFn,
|
||||
@@ -13,6 +14,7 @@ const SUPPORTED_TRANSPORT_APIS = new Set<Api>([
|
||||
"openai-completions",
|
||||
"azure-openai-responses",
|
||||
"anthropic-messages",
|
||||
"google-generative-ai",
|
||||
]);
|
||||
|
||||
const SIMPLE_TRANSPORT_API_ALIAS: Record<string, Api> = {
|
||||
@@ -20,6 +22,7 @@ const SIMPLE_TRANSPORT_API_ALIAS: Record<string, Api> = {
|
||||
"openai-completions": "openclaw-openai-completions-transport",
|
||||
"azure-openai-responses": "openclaw-azure-openai-responses-transport",
|
||||
"anthropic-messages": "openclaw-anthropic-messages-transport",
|
||||
"google-generative-ai": "openclaw-google-generative-ai-transport",
|
||||
};
|
||||
|
||||
function hasTransportOverrides(model: Model<Api>): boolean {
|
||||
@@ -53,6 +56,8 @@ export function createTransportAwareStreamFnForModel(model: Model<Api>): StreamF
|
||||
return createAzureOpenAIResponsesTransportStreamFn();
|
||||
case "anthropic-messages":
|
||||
return createAnthropicMessagesTransportStreamFn();
|
||||
case "google-generative-ai":
|
||||
return createGoogleGenerativeAiTransportStreamFn();
|
||||
default:
|
||||
return undefined;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user