test: retry transient openai websocket live stream

This commit is contained in:
Peter Steinberger
2026-04-27 18:43:43 +01:00
parent 2f909b0b21
commit b39d80835f

View File

@@ -128,8 +128,35 @@ async function collectEvents(stream: StreamReturn): Promise<AssistantMessageEven
function expectDone(events: AssistantMessageEvent[]): AssistantMessage {
const done = events.find((event) => event.type === "done")?.message;
expect(done).toBeDefined();
return done!;
if (!done) {
throw new MissingDoneEventError(events);
}
return done;
}
class MissingDoneEventError extends Error {
constructor(events: AssistantMessageEvent[]) {
super(
`OpenAI WebSocket stream ended without a done event; event types: ${events.map((event) => event.type).join(", ") || "<none>"}`,
);
this.name = "MissingDoneEventError";
}
}
function isTransientWebSocketLiveError(error: unknown): boolean {
if (error instanceof MissingDoneEventError) {
return true;
}
if (!(error instanceof Error)) {
return false;
}
const message = error.message.toLowerCase();
return (
message.includes("websocket closed") ||
message.includes("websocket stream ended") ||
message.includes("timeout") ||
message.includes("aborted")
);
}
function assistantText(message: AssistantMessage): string {
@@ -305,76 +332,89 @@ describe("OpenAI WebSocket e2e", () => {
testFn(
"surfaces replay-safe reasoning metadata on websocket tool turns",
async () => {
const sid = freshSession("tool-reasoning");
const completedResponses: ResponseObject[] = [];
openAIWsStreamModule.__testing.setDepsForTest({
createManager: (options) => {
const manager = new openAIWsConnectionModule.OpenAIWebSocketManager(options);
manager.onMessage((event) => {
if (event.type === "response.completed") {
completedResponses.push(event.response);
}
let lastError: unknown;
for (let attempt = 0; attempt < 2; attempt += 1) {
try {
const sid = freshSession(`tool-reasoning-${attempt}`);
const completedResponses: ResponseObject[] = [];
openAIWsStreamModule.__testing.setDepsForTest({
createManager: (options) => {
const manager = new openAIWsConnectionModule.OpenAIWebSocketManager(options);
manager.onMessage((event) => {
if (event.type === "response.completed") {
completedResponses.push(event.response);
}
});
return manager;
},
});
return manager;
},
});
const streamFn = openAIWsStreamModule.createOpenAIWebSocketStreamFn(API_KEY!, sid);
const firstContext = makeToolContext(
"Think carefully, call the tool `noop` with {} first, then after the tool result reply with exactly TOOL_OK.",
);
const firstDone = expectDone(
await collectEvents(
streamFn(model, firstContext, {
transport: "websocket",
toolChoice: "required",
reasoningEffort: "high",
reasoningSummary: "detailed",
maxTokens: 256,
} as unknown as StreamFnParams[2]),
),
);
const streamFn = openAIWsStreamModule.createOpenAIWebSocketStreamFn(API_KEY!, sid);
const firstContext = makeToolContext(
"Think carefully, call the tool `noop` with {} first, then after the tool result reply with exactly TOOL_OK.",
);
const firstDone = expectDone(
await collectEvents(
streamFn(model, firstContext, {
transport: "websocket",
toolChoice: "required",
reasoningEffort: "high",
reasoningSummary: "detailed",
maxTokens: 256,
} as unknown as StreamFnParams[2]),
),
);
const firstResponse = completedResponses[0];
expect(firstResponse).toBeDefined();
const firstResponse = completedResponses[0];
expect(firstResponse).toBeDefined();
const rawReasoningItems = (firstResponse?.output ?? []).filter(
(item): item is Extract<OutputItem, { type: "reasoning" | `reasoning.${string}` }> =>
item.type === "reasoning" || item.type.startsWith("reasoning."),
);
const replayableReasoningItems = rawReasoningItems.filter(
(item) => extractReasoningText(item).length > 0,
);
const thinkingBlocks = extractThinkingBlocks(firstDone);
expect(thinkingBlocks).toHaveLength(replayableReasoningItems.length);
expect(thinkingBlocks.map((block) => block.thinking)).toEqual(
replayableReasoningItems.map((item) => extractReasoningText(item)),
);
expect(
thinkingBlocks.map((block) => parseReasoningSignature(block.thinkingSignature)),
).toEqual(replayableReasoningItems.map((item) => toExpectedReasoningSignature(item)));
const rawReasoningItems = (firstResponse?.output ?? []).filter(
(item): item is Extract<OutputItem, { type: "reasoning" | `reasoning.${string}` }> =>
item.type === "reasoning" || item.type.startsWith("reasoning."),
);
const replayableReasoningItems = rawReasoningItems.filter(
(item) => extractReasoningText(item).length > 0,
);
const thinkingBlocks = extractThinkingBlocks(firstDone);
expect(thinkingBlocks).toHaveLength(replayableReasoningItems.length);
expect(thinkingBlocks.map((block) => block.thinking)).toEqual(
replayableReasoningItems.map((item) => extractReasoningText(item)),
);
expect(
thinkingBlocks.map((block) => parseReasoningSignature(block.thinkingSignature)),
).toEqual(replayableReasoningItems.map((item) => toExpectedReasoningSignature(item)));
const rawToolCall = firstResponse?.output.find(
(item): item is Extract<OutputItem, { type: "function_call" }> =>
item.type === "function_call",
);
expect(rawToolCall).toBeDefined();
const toolCall = extractToolCall(firstDone);
expect(toolCall?.name).toBe(rawToolCall?.name);
expect(toolCall?.id).toBe(
rawToolCall ? `${rawToolCall.call_id}|${rawToolCall.id}` : undefined,
);
const rawToolCall = firstResponse?.output.find(
(item): item is Extract<OutputItem, { type: "function_call" }> =>
item.type === "function_call",
);
expect(rawToolCall).toBeDefined();
const toolCall = extractToolCall(firstDone);
expect(toolCall?.name).toBe(rawToolCall?.name);
expect(toolCall?.id).toBe(
rawToolCall ? `${rawToolCall.call_id}|${rawToolCall.id}` : undefined,
);
const secondDone = await runWebsocketToolFollowupTurn({
streamFn,
context: firstContext,
firstDone,
toolCallId: toolCall!.id,
output: "TOOL_OK",
});
const secondDone = await runWebsocketToolFollowupTurn({
streamFn,
context: firstContext,
firstDone,
toolCallId: toolCall!.id,
output: "TOOL_OK",
});
expect(assistantText(secondDone)).toMatch(/TOOL_OK/);
expect(assistantText(secondDone)).toMatch(/TOOL_OK/);
return;
} catch (error) {
lastError = error;
openAIWsStreamModule.__testing.setDepsForTest();
if (!isTransientWebSocketLiveError(error) || attempt === 1) {
throw error;
}
}
}
throw lastError;
},
60_000,
120_000,
);
testFn(