test(openai): retry stalled websocket reasoning turn

This commit is contained in:
Vincent Koc
2026-05-06 22:05:45 -07:00
parent 0597e8a065
commit 9910cdb7a9

View File

@@ -144,6 +144,37 @@ class MissingDoneEventError extends Error {
}
}
/**
 * Error used to signal that a labelled live WebSocket attempt exceeded its
 * time budget.
 *
 * The message has the form `<label> timed out after <timeoutMs>ms`, and
 * `name` is set to the class name so the error is identifiable in logs.
 */
class WebSocketLiveAttemptTimeoutError extends Error {
  /**
   * @param label - Human-readable description of the attempt that timed out.
   * @param timeoutMs - The time budget, in milliseconds, that was exceeded.
   */
  constructor(label: string, timeoutMs: number) {
    const description = `${label} timed out after ${timeoutMs}ms`;
    super(description);
    this.name = "WebSocketLiveAttemptTimeoutError";
  }
}
/**
 * Runs `run` and races its promise against a timer of `timeoutMs` milliseconds.
 *
 * Settles with whatever `run()` settles with if that happens first; otherwise
 * rejects with a `WebSocketLiveAttemptTimeoutError` built from `label` and
 * `timeoutMs`. The timer is always cleared once the race settles.
 *
 * Note that nothing here cancels `run()`: if the timer wins, the promise
 * returned by `run()` is simply no longer awaited.
 *
 * @param label - Description of the attempt, used in the timeout error message.
 * @param timeoutMs - Time budget in milliseconds before the attempt is rejected.
 * @param run - Factory producing the promise to guard.
 * @returns The value `run()` resolves with.
 * @throws WebSocketLiveAttemptTimeoutError when the timer fires first.
 */
async function withWebSocketLiveAttemptTimeout<T>(
  label: string,
  timeoutMs: number,
  run: () => Promise<T>,
): Promise<T> {
  let timeoutHandle: ReturnType<typeof setTimeout> | undefined;
  try {
    // Start the guarded work before arming the timer, so a synchronous throw
    // from `run` never leaves a timer behind.
    const attempt = run();
    const deadline = new Promise<never>((_resolve, reject) => {
      const onTimeout = (): void => {
        reject(new WebSocketLiveAttemptTimeoutError(label, timeoutMs));
      };
      timeoutHandle = setTimeout(onTimeout, timeoutMs);
      // Where the handle supports it, unref so a pending timer alone does not
      // keep the process alive.
      timeoutHandle.unref?.();
    });
    return await Promise.race([attempt, deadline]);
  } finally {
    if (timeoutHandle) {
      clearTimeout(timeoutHandle);
    }
  }
}
function isTransientWebSocketLiveError(error: unknown): boolean {
if (error instanceof MissingDoneEventError) {
return true;
@@ -351,78 +382,87 @@ describe("OpenAI WebSocket e2e", () => {
async () => {
let lastError: unknown;
for (let attempt = 0; attempt < 2; attempt += 1) {
const sid = freshSession(`tool-reasoning-${attempt}`);
try {
const sid = freshSession(`tool-reasoning-${attempt}`);
const completedResponses: ResponseObject[] = [];
openAIWsStreamModule.__testing.setDepsForTest({
createManager: (options) => {
const manager = new openAIWsConnectionModule.OpenAIWebSocketManager(options);
manager.onMessage((event) => {
if (event.type === "response.completed") {
completedResponses.push(event.response);
}
await withWebSocketLiveAttemptTimeout(
`OpenAI WebSocket reasoning metadata attempt ${attempt + 1}`,
75_000,
async () => {
const completedResponses: ResponseObject[] = [];
openAIWsStreamModule.__testing.setDepsForTest({
createManager: (options) => {
const manager = new openAIWsConnectionModule.OpenAIWebSocketManager(options);
manager.onMessage((event) => {
if (event.type === "response.completed") {
completedResponses.push(event.response);
}
});
return manager;
},
});
return manager;
const streamFn = openAIWsStreamModule.createOpenAIWebSocketStreamFn(API_KEY!, sid);
const firstContext = makeToolContext(
"Think carefully, call the tool `noop` with {} first, then after the tool result reply with exactly TOOL_OK.",
);
const firstDone = expectDone(
await collectEvents(
streamFn(model, firstContext, {
transport: "websocket",
toolChoice: "required",
reasoningEffort: "high",
reasoningSummary: "detailed",
maxTokens: 256,
} as unknown as StreamFnParams[2]),
),
);
const firstResponse = completedResponses[0];
expect(firstResponse).toBeDefined();
const rawReasoningItems = (firstResponse?.output ?? []).filter(
(
item,
): item is Extract<OutputItem, { type: "reasoning" | `reasoning.${string}` }> =>
item.type === "reasoning" || item.type.startsWith("reasoning."),
);
const replayableReasoningItems = rawReasoningItems.filter(
(item) => typeof item.id === "string" && item.id.startsWith("rs_"),
);
const thinkingBlocks = extractThinkingBlocks(firstDone);
expect(thinkingBlocks).toHaveLength(replayableReasoningItems.length);
expect(thinkingBlocks.map((block) => block.thinking)).toEqual(
replayableReasoningItems.map((item) => extractReasoningText(item)),
);
expect(
thinkingBlocks.map((block) => parseReasoningSignature(block.thinkingSignature)),
).toEqual(replayableReasoningItems.map((item) => toExpectedReasoningSignature(item)));
const rawToolCall = firstResponse?.output.find(
(item): item is Extract<OutputItem, { type: "function_call" }> =>
item.type === "function_call",
);
expect(rawToolCall).toBeDefined();
const toolCall = extractToolCall(firstDone);
expect(toolCall?.name).toBe(rawToolCall?.name);
expect(toolCall?.id).toBe(
rawToolCall ? `${rawToolCall.call_id}|${rawToolCall.id}` : undefined,
);
const secondDone = await runWebsocketToolFollowupTurn({
streamFn,
context: firstContext,
firstDone,
toolCallId: toolCall!.id,
output: "TOOL_OK",
});
expect(assistantText(secondDone)).toMatch(/TOOL_OK/);
},
});
const streamFn = openAIWsStreamModule.createOpenAIWebSocketStreamFn(API_KEY!, sid);
const firstContext = makeToolContext(
"Think carefully, call the tool `noop` with {} first, then after the tool result reply with exactly TOOL_OK.",
);
const firstDone = expectDone(
await collectEvents(
streamFn(model, firstContext, {
transport: "websocket",
toolChoice: "required",
reasoningEffort: "high",
reasoningSummary: "detailed",
maxTokens: 256,
} as unknown as StreamFnParams[2]),
),
);
const firstResponse = completedResponses[0];
expect(firstResponse).toBeDefined();
const rawReasoningItems = (firstResponse?.output ?? []).filter(
(item): item is Extract<OutputItem, { type: "reasoning" | `reasoning.${string}` }> =>
item.type === "reasoning" || item.type.startsWith("reasoning."),
);
const replayableReasoningItems = rawReasoningItems.filter(
(item) => typeof item.id === "string" && item.id.startsWith("rs_"),
);
const thinkingBlocks = extractThinkingBlocks(firstDone);
expect(thinkingBlocks).toHaveLength(replayableReasoningItems.length);
expect(thinkingBlocks.map((block) => block.thinking)).toEqual(
replayableReasoningItems.map((item) => extractReasoningText(item)),
);
expect(
thinkingBlocks.map((block) => parseReasoningSignature(block.thinkingSignature)),
).toEqual(replayableReasoningItems.map((item) => toExpectedReasoningSignature(item)));
const rawToolCall = firstResponse?.output.find(
(item): item is Extract<OutputItem, { type: "function_call" }> =>
item.type === "function_call",
);
expect(rawToolCall).toBeDefined();
const toolCall = extractToolCall(firstDone);
expect(toolCall?.name).toBe(rawToolCall?.name);
expect(toolCall?.id).toBe(
rawToolCall ? `${rawToolCall.call_id}|${rawToolCall.id}` : undefined,
);
const secondDone = await runWebsocketToolFollowupTurn({
streamFn,
context: firstContext,
firstDone,
toolCallId: toolCall!.id,
output: "TOOL_OK",
});
expect(assistantText(secondDone)).toMatch(/TOOL_OK/);
return;
} catch (error) {
lastError = error;
openAIWsStreamModule.releaseWsSession(sid);
openAIWsStreamModule.__testing.setDepsForTest();
if (!isTransientWebSocketLiveError(error) || attempt === 1) {
throw error;