refactor(providers): share transport stream helpers

This commit is contained in:
Vincent Koc
2026-04-04 03:48:26 +09:00
parent fcec417d7d
commit d8458a1481
5 changed files with 234 additions and 103 deletions

View File

@@ -2,7 +2,6 @@ import Anthropic from "@anthropic-ai/sdk";
import type { StreamFn } from "@mariozechner/pi-agent-core";
import {
calculateCost,
createAssistantMessageEventStream,
getEnvApiKey,
parseStreamingJson,
type AnthropicOptions,
@@ -12,9 +11,16 @@ import {
type ThinkingLevel,
} from "@mariozechner/pi-ai";
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./copilot-dynamic-headers.js";
import { sanitizeTransportPayloadText } from "./openai-transport-stream.js";
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
import { transformTransportMessages } from "./transport-message-transform.js";
import {
createEmptyTransportUsage,
createWritableTransportEventStream,
failTransportStream,
finalizeTransportStream,
mergeTransportHeaders,
sanitizeTransportPayloadText,
} from "./transport-stream-shared.js";
const CLAUDE_CODE_VERSION = "2.1.75";
const CLAUDE_CODE_TOOLS = [
@@ -86,10 +92,6 @@ type MutableAssistantOutput = {
errorMessage?: string;
};
function sanitizeAnthropicText(text: string): string {
  // Thin provider-local alias over the shared transport payload sanitizer.
  const sanitized = sanitizeTransportPayloadText(text);
  return sanitized;
}
function supportsAdaptiveThinking(modelId: string): boolean {
return (
modelId.includes("opus-4-6") ||
@@ -143,18 +145,6 @@ function adjustMaxTokensForThinking(params: {
return { maxTokens, thinkingBudget };
}
function mergeHeaders(
  ...headerSources: Array<Record<string, string> | undefined>
): Record<string, string> | undefined {
  // Later sources override earlier ones; when every source is missing or
  // empty, return undefined so callers can fall back to SDK defaults.
  const defined = headerSources.filter(
    (headers): headers is Record<string, string> => headers !== undefined,
  );
  const merged = defined.reduce<Record<string, string>>(
    (accumulated, headers) => ({ ...accumulated, ...headers }),
    {},
  );
  return Object.keys(merged).length > 0 ? merged : undefined;
}
function isAnthropicOAuthToken(apiKey: string): boolean {
  // Anthropic OAuth access tokens carry this marker substring.
  return apiKey.indexOf("sk-ant-oat") !== -1;
}
@@ -197,7 +187,7 @@ function convertContentBlocks(
) {
const hasImages = content.some((item) => item.type === "image");
if (!hasImages) {
return sanitizeAnthropicText(
return sanitizeTransportPayloadText(
content.map((item) => ("text" in item ? item.text : "")).join("\n"),
);
}
@@ -205,7 +195,7 @@ function convertContentBlocks(
if (block.type === "text") {
return {
type: "text",
text: sanitizeAnthropicText(block.text),
text: sanitizeTransportPayloadText(block.text),
};
}
return {
@@ -245,7 +235,7 @@ function convertAnthropicMessages(
if (msg.content.trim().length > 0) {
params.push({
role: "user",
content: sanitizeAnthropicText(msg.content),
content: sanitizeTransportPayloadText(msg.content),
});
}
continue;
@@ -260,7 +250,7 @@ function convertAnthropicMessages(
item.type === "text"
? {
type: "text",
text: sanitizeAnthropicText(item.text),
text: sanitizeTransportPayloadText(item.text),
}
: {
type: "image",
@@ -293,7 +283,7 @@ function convertAnthropicMessages(
if (block.text.trim().length > 0) {
blocks.push({
type: "text",
text: sanitizeAnthropicText(block.text),
text: sanitizeTransportPayloadText(block.text),
});
}
continue;
@@ -312,12 +302,12 @@ function convertAnthropicMessages(
if (!block.thinkingSignature || block.thinkingSignature.trim().length === 0) {
blocks.push({
type: "text",
text: sanitizeAnthropicText(block.thinking),
text: sanitizeTransportPayloadText(block.thinking),
});
} else {
blocks.push({
type: "thinking",
thinking: sanitizeAnthropicText(block.thinking),
thinking: sanitizeTransportPayloadText(block.thinking),
signature: block.thinkingSignature,
});
}
@@ -454,7 +444,7 @@ function createAnthropicTransportClient(params: {
authToken: apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: mergeHeaders(
defaultHeaders: mergeTransportHeaders(
{
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
@@ -483,7 +473,7 @@ function createAnthropicTransportClient(params: {
authToken: apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: mergeHeaders(
defaultHeaders: mergeTransportHeaders(
{
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
@@ -504,7 +494,7 @@ function createAnthropicTransportClient(params: {
apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: mergeHeaders(
defaultHeaders: mergeTransportHeaders(
{
accept: "application/json",
"anthropic-dangerous-direct-browser-access": "true",
@@ -544,7 +534,7 @@ function buildAnthropicParams(
? [
{
type: "text",
text: sanitizeAnthropicText(context.systemPrompt),
text: sanitizeTransportPayloadText(context.systemPrompt),
...(cacheControl ? { cache_control: cacheControl } : {}),
},
]
@@ -554,7 +544,7 @@ function buildAnthropicParams(
params.system = [
{
type: "text",
text: sanitizeAnthropicText(context.systemPrompt),
text: sanitizeTransportPayloadText(context.systemPrompt),
...(cacheControl ? { cache_control: cacheControl } : {}),
},
];
@@ -639,7 +629,7 @@ export function createAnthropicMessagesTransportStreamFn(): StreamFn {
return (rawModel, context, rawOptions) => {
const model = rawModel as AnthropicTransportModel;
const options = rawOptions as AnthropicTransportOptions | undefined;
const stream = createAssistantMessageEventStream();
const { eventStream, stream } = createWritableTransportEventStream();
void (async () => {
const output: MutableAssistantOutput = {
role: "assistant",
@@ -647,14 +637,7 @@ export function createAnthropicMessagesTransportStreamFn(): StreamFn {
api: "anthropic-messages",
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
usage: createEmptyTransportUsage(),
stopReason: "stop",
timestamp: Date.now(),
};
@@ -895,24 +878,21 @@ export function createAnthropicMessagesTransportStreamFn(): StreamFn {
calculateCost(model, output.usage);
}
}
if (transportOptions.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason as never, message: output as never });
stream.end();
finalizeTransportStream({ stream, output, signal: transportOptions.signal });
} catch (error) {
for (const block of output.content) {
delete block.index;
}
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason as never, error: output as never });
stream.end();
failTransportStream({
stream,
output,
signal: options?.signal,
error,
cleanup: () => {
for (const block of output.content) {
delete block.index;
}
},
});
}
})();
return stream as ReturnType<StreamFn>;
return eventStream as ReturnType<StreamFn>;
};
}

View File

@@ -1,7 +1,6 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import {
calculateCost,
createAssistantMessageEventStream,
getEnvApiKey,
type Context,
type Model,
@@ -10,9 +9,17 @@ import {
} from "@mariozechner/pi-ai";
import { parseGeminiAuth } from "../infra/gemini-auth.js";
import { normalizeGoogleApiBaseUrl } from "../infra/google-api-base-url.js";
import { sanitizeTransportPayloadText } from "./openai-transport-stream.js";
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
import { transformTransportMessages } from "./transport-message-transform.js";
import {
createEmptyTransportUsage,
createWritableTransportEventStream,
failTransportStream,
finalizeTransportStream,
mergeTransportHeaders,
sanitizeTransportPayloadText,
type WritableTransportStream,
} from "./transport-stream-shared.js";
type GoogleTransportModel = Model<"google-generative-ai"> & {
headers?: Record<string, string>;
@@ -101,10 +108,6 @@ type GoogleSseChunk = {
let toolCallCounter = 0;
function sanitizeGoogleTransportText(text: string): string {
  // Thin provider-local alias over the shared transport payload sanitizer.
  const sanitized = sanitizeTransportPayloadText(text);
  return sanitized;
}
function isGemini3ProModel(modelId: string): boolean {
  // Case-insensitive match for any "gemini-3[-.minor]-pro" model id variant.
  const normalized = modelId.toLowerCase();
  return /gemini-3(?:\.\d+)?-pro/.test(normalized);
}
@@ -277,14 +280,14 @@ function convertGoogleMessages(model: GoogleTransportModel, context: Context) {
if (typeof msg.content === "string") {
contents.push({
role: "user",
parts: [{ text: sanitizeGoogleTransportText(msg.content) }],
parts: [{ text: sanitizeTransportPayloadText(msg.content) }],
});
continue;
}
const parts = msg.content
.map((item) =>
item.type === "text"
? { text: sanitizeGoogleTransportText(item.text) }
? { text: sanitizeTransportPayloadText(item.text) }
: {
inlineData: {
mimeType: item.mimeType,
@@ -308,7 +311,7 @@ function convertGoogleMessages(model: GoogleTransportModel, context: Context) {
continue;
}
parts.push({
text: sanitizeGoogleTransportText(block.text),
text: sanitizeTransportPayloadText(block.text),
...(isSameProviderAndModel && block.textSignature
? { thoughtSignature: block.textSignature }
: {}),
@@ -322,11 +325,11 @@ function convertGoogleMessages(model: GoogleTransportModel, context: Context) {
if (isSameProviderAndModel) {
parts.push({
thought: true,
text: sanitizeGoogleTransportText(block.thinking),
text: sanitizeTransportPayloadText(block.thinking),
...(block.thinkingSignature ? { thoughtSignature: block.thinkingSignature } : {}),
});
} else {
parts.push({ text: sanitizeGoogleTransportText(block.thinking) });
parts.push({ text: sanitizeTransportPayloadText(block.thinking) });
}
continue;
}
@@ -364,7 +367,7 @@ function convertGoogleMessages(model: GoogleTransportModel, context: Context) {
)
: [];
const responseValue = textResult
? sanitizeGoogleTransportText(textResult)
? sanitizeTransportPayloadText(textResult)
: imageContent.length > 0
? "(see attached image)"
: "";
@@ -442,7 +445,7 @@ export function buildGoogleGenerativeAiParams(
}
if (context.systemPrompt) {
params.systemInstruction = {
parts: [{ text: sanitizeGoogleTransportText(context.systemPrompt) }],
parts: [{ text: sanitizeTransportPayloadText(context.systemPrompt) }],
};
}
if (context.tools?.length) {
@@ -463,12 +466,18 @@ function buildGoogleHeaders(
optionHeaders: Record<string, string> | undefined,
): Record<string, string> {
const authHeaders = apiKey ? parseGeminiAuth(apiKey).headers : undefined;
return {
accept: "text/event-stream",
...authHeaders,
...model.headers,
...optionHeaders,
};
return (
mergeTransportHeaders(
{
accept: "text/event-stream",
},
authHeaders,
model.headers,
optionHeaders,
) ?? {
accept: "text/event-stream",
}
);
}
async function* parseGoogleSseChunks(
@@ -539,7 +548,7 @@ function updateUsage(
}
function pushTextBlockEnd(
stream: { push(event: unknown): void },
stream: WritableTransportStream,
output: MutableAssistantOutput,
blockIndex: number,
) {
@@ -570,8 +579,7 @@ export function createGoogleGenerativeAiTransportStreamFn(): StreamFn {
return (rawModel, context, rawOptions) => {
const model = rawModel as GoogleTransportModel;
const options = rawOptions as GoogleTransportOptions | undefined;
const eventStream = createAssistantMessageEventStream();
const stream = eventStream as unknown as { push(event: unknown): void; end(): void };
const { eventStream, stream } = createWritableTransportEventStream();
void (async () => {
const output: MutableAssistantOutput = {
role: "assistant",
@@ -579,14 +587,7 @@ export function createGoogleGenerativeAiTransportStreamFn(): StreamFn {
api: "google-generative-ai",
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
usage: createEmptyTransportUsage(),
stopReason: "stop",
timestamp: Date.now(),
};
@@ -724,19 +725,9 @@ export function createGoogleGenerativeAiTransportStreamFn(): StreamFn {
if (currentBlockIndex >= 0) {
pushTextBlockEnd(stream, output, currentBlockIndex);
}
if (options?.signal?.aborted) {
throw new Error("Request was aborted");
}
if (output.stopReason === "aborted" || output.stopReason === "error") {
throw new Error("An unknown error occurred");
}
stream.push({ type: "done", reason: output.stopReason as never, message: output as never });
stream.end();
finalizeTransportStream({ stream, output, signal: options?.signal });
} catch (error) {
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason as never, error: output as never });
stream.end();
failTransportStream({ stream, output, signal: options?.signal, error });
}
})();
return eventStream as unknown as ReturnType<StreamFn>;

View File

@@ -23,6 +23,7 @@ import { resolveOpenAICompletionsCompatDefaultsFromCapabilities } from "./openai
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
import { transformTransportMessages } from "./transport-message-transform.js";
import { sanitizeTransportPayloadText } from "./transport-stream-shared.js";
const DEFAULT_AZURE_OPENAI_API_VERSION = "2024-12-01-preview";
@@ -81,12 +82,7 @@ type MutableAssistantOutput = {
errorMessage?: string;
};
export function sanitizeTransportPayloadText(text: string): string {
return text.replace(
/[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g,
"",
);
}
export { sanitizeTransportPayloadText } from "./transport-stream-shared.js";
function stringifyUnknown(value: unknown, fallback = ""): string {
if (typeof value === "string") {

View File

@@ -0,0 +1,75 @@
import { describe, expect, it, vi } from "vitest";
import {
failTransportStream,
finalizeTransportStream,
mergeTransportHeaders,
sanitizeTransportPayloadText,
} from "./transport-stream-shared.js";
// Unit tests for the shared transport stream helpers. Each test exercises one
// helper with stub streams so no real provider SDK or network is involved.
describe("transport stream shared helpers", () => {
// Unpaired UTF-16 surrogate halves are stripped, while valid surrogate pairs
// (real emoji) pass through untouched.
it("sanitizes unpaired surrogate code units", () => {
const high = String.fromCharCode(0xd83d);
const low = String.fromCharCode(0xdc00);
expect(sanitizeTransportPayloadText(`left${high}right`)).toBe("leftright");
expect(sanitizeTransportPayloadText(`left${low}right`)).toBe("leftright");
expect(sanitizeTransportPayloadText("emoji 🙈 ok")).toBe("emoji 🙈 ok");
});
// Later header sources override earlier ones; an all-undefined input yields
// undefined so callers can fall back to SDK defaults.
it("merges transport headers in source order", () => {
expect(
mergeTransportHeaders(
{ accept: "text/event-stream", "x-base": "one" },
{ authorization: "Bearer token" },
{ "x-base": "two" },
),
).toEqual({
accept: "text/event-stream",
authorization: "Bearer token",
"x-base": "two",
});
expect(mergeTransportHeaders(undefined, undefined)).toBeUndefined();
});
// A non-error stopReason produces exactly one "done" event and closes the
// stream once.
it("finalizes successful transport streams", () => {
const push = vi.fn();
const end = vi.fn();
const output = { stopReason: "stop" };
finalizeTransportStream({
stream: { push, end },
output,
});
expect(push).toHaveBeenCalledWith({
type: "done",
reason: "stop",
message: output,
});
expect(end).toHaveBeenCalledTimes(1);
});
// Failure path: cleanup runs first, the output is stamped with the error
// reason/message, then a single "error" event is pushed before closing.
it("marks transport stream failures and runs cleanup", () => {
const push = vi.fn();
const end = vi.fn();
const cleanup = vi.fn();
const output: { stopReason: string; errorMessage?: string } = { stopReason: "stop" };
failTransportStream({
stream: { push, end },
output,
error: new Error("boom"),
cleanup,
});
expect(cleanup).toHaveBeenCalledTimes(1);
expect(output.stopReason).toBe("error");
expect(output.errorMessage).toBe("boom");
expect(push).toHaveBeenCalledWith({
type: "error",
reason: "error",
error: output,
});
expect(end).toHaveBeenCalledTimes(1);
});
});

View File

@@ -0,0 +1,89 @@
import { createAssistantMessageEventStream } from "@mariozechner/pi-ai";
/**
 * Token and cost accounting shared by all provider transport streams.
 * Mirrors the usage object each transport mutates while streaming.
 */
export type TransportUsage = {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
totalTokens: number;
cost: { input: number; output: number; cacheRead: number; cacheWrite: number; total: number };
};
/**
 * Minimal writable view of an assistant-message event stream: transports only
 * need to push events and close the stream.
 */
export type WritableTransportStream = {
push(event: unknown): void;
end(): void;
};
// Minimal shape of a transport's mutable output that the finalize/fail
// helpers read and write (stopReason and, on failure, errorMessage).
type TransportOutputShape = {
stopReason: string;
errorMessage?: string;
};
/**
 * Removes unpaired UTF-16 surrogate code units from outgoing payload text.
 * Lone surrogate halves are not valid Unicode and break JSON serialization
 * on some provider endpoints; well-formed surrogate pairs are kept as-is.
 */
export function sanitizeTransportPayloadText(text: string): string {
  // High surrogate not followed by a low, or low surrogate not preceded by a high.
  const unpairedSurrogates =
    /[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g;
  return text.replace(unpairedSurrogates, "");
}
/**
 * Merges any number of optional header maps, with later sources overriding
 * earlier ones. Returns undefined when nothing was contributed, so callers
 * can fall back to their SDK's default headers.
 */
export function mergeTransportHeaders(
  ...headerSources: Array<Record<string, string> | undefined>
): Record<string, string> | undefined {
  const defined = headerSources.filter(
    (headers): headers is Record<string, string> => headers !== undefined,
  );
  const merged = defined.reduce<Record<string, string>>(
    (accumulated, headers) => ({ ...accumulated, ...headers }),
    {},
  );
  return Object.keys(merged).length > 0 ? merged : undefined;
}
/**
 * Builds a zeroed usage record for a new transport stream. A fresh object
 * (including a fresh nested cost object) is returned on every call so streams
 * can mutate their usage independently.
 */
export function createEmptyTransportUsage(): TransportUsage {
  const zeroCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 };
  return {
    input: 0,
    output: 0,
    cacheRead: 0,
    cacheWrite: 0,
    totalTokens: 0,
    cost: zeroCost,
  };
}
/**
 * Creates an assistant-message event stream and pairs it with a writable
 * view exposing only push/end, so transports mutate the stream through a
 * single narrow interface while callers receive the full stream object.
 */
export function createWritableTransportEventStream() {
  const eventStream = createAssistantMessageEventStream();
  const stream = eventStream as unknown as WritableTransportStream;
  return { eventStream, stream };
}
/**
 * Completes a transport stream on the success path.
 *
 * Throws when the request was aborted or the output already carries a
 * terminal error stopReason, so the caller's catch block routes the stream
 * through failTransportStream instead. Otherwise pushes a single "done"
 * event and closes the stream.
 */
export function finalizeTransportStream(params: {
  stream: WritableTransportStream;
  output: TransportOutputShape;
  signal?: AbortSignal;
}): void {
  const { stream, output, signal } = params;
  if (signal?.aborted) {
    throw new Error("Request was aborted");
  }
  const { stopReason } = output;
  if (stopReason === "aborted" || stopReason === "error") {
    throw new Error("An unknown error occurred");
  }
  stream.push({ type: "done", reason: stopReason as never, message: output as never });
  stream.end();
}
export function failTransportStream(params: {
stream: WritableTransportStream;
output: TransportOutputShape;
signal?: AbortSignal;
error: unknown;
cleanup?: () => void;
}): void {
const { stream, output, signal, error, cleanup } = params;
cleanup?.();
output.stopReason = signal?.aborted ? "aborted" : "error";
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
stream.push({ type: "error", reason: output.stopReason as never, error: output as never });
stream.end();
}