fix: count unknown-tool retries only when streamed

This commit is contained in:
Bob
2026-04-13 22:09:38 +02:00
committed by Onur Solmaz
parent 891e42beec
commit 176a6d9fa1
2 changed files with 281 additions and 23 deletions

View File

@@ -722,6 +722,225 @@ describe("wrapStreamFnTrimToolCallNames", () => {
]);
});
it("counts the final unknown-tool retry when streamed messages omit the tool name", async () => {
const baseFn = vi.fn(() =>
createFakeStream({
events: [
{
type: "toolcall_delta",
message: { role: "assistant", content: [{ type: "toolCall", name: "" }] },
},
],
resultMessage: {
role: "assistant",
content: [{ type: "toolCall", name: " exec ", arguments: { command: "echo retry" } }],
},
}),
);
const wrappedFn = wrapStreamFnTrimToolCallNames(baseFn as never, new Set(["read"]), {
unknownToolThreshold: 1,
});
const firstStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
await firstStream.result();
const secondStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
for await (const _item of secondStream) {
// drain
}
const secondResult = (await secondStream.result()) as {
role: string;
content: Array<{ type: string; text?: string; name?: string }>;
};
expect(secondResult.role).toBe("assistant");
expect(secondResult.content).toEqual([
expect.objectContaining({
type: "text",
text: expect.stringContaining('"exec"'),
}),
]);
});
it("resets a provisional streamed unknown-tool retry when later chunks resolve to an allowed tool", async () => {
const baseFn = vi
.fn()
.mockImplementationOnce(() =>
createFakeStream({
events: [
{
type: "toolcall_delta",
message: { role: "assistant", content: [{ type: "toolCall", name: " ex " }] },
},
{
type: "toolcall_delta",
message: { role: "assistant", content: [{ type: "toolCall", name: " exec " }] },
},
],
resultMessage: {
role: "assistant",
content: [{ type: "toolCall", name: " exec ", arguments: { command: "echo ok" } }],
},
}),
)
.mockImplementationOnce(() =>
createFakeStream({
events: [],
resultMessage: {
role: "assistant",
content: [{ type: "toolCall", name: " ex ", arguments: { command: "echo retry" } }],
},
}),
);
const wrappedFn = wrapStreamFnTrimToolCallNames(baseFn as never, new Set(["exec"]), {
unknownToolThreshold: 1,
});
const firstStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
for await (const _item of firstStream) {
// drain
}
await firstStream.result();
const secondStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
const secondResult = (await secondStream.result()) as {
role: string;
content: Array<{ type: string; text?: string; name?: string }>;
};
expect(secondResult.role).toBe("assistant");
expect(secondResult.content).toEqual([
expect.objectContaining({
type: "toolCall",
name: "ex",
}),
]);
});
it("keeps processing later streamed messages after one streamed unknown-tool retry was counted", async () => {
const baseFn = vi
.fn()
.mockImplementationOnce(() =>
createFakeStream({
events: [
{
type: "toolcall_delta",
message: { role: "assistant", content: [{ type: "toolCall", name: " re " }] },
},
{
type: "toolcall_delta",
message: { role: "assistant", content: [{ type: "toolCall", name: " read " }] },
},
],
resultMessage: {
role: "assistant",
content: [{ type: "text", text: "resolved to allowed tool" }],
},
}),
)
.mockImplementationOnce(() =>
createFakeStream({
events: [],
resultMessage: {
role: "assistant",
content: [{ type: "toolCall", name: " re ", arguments: { command: "echo retry" } }],
},
}),
);
const wrappedFn = wrapStreamFnTrimToolCallNames(baseFn as never, new Set(["read"]), {
unknownToolThreshold: 1,
});
const firstStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
for await (const _item of firstStream) {
// drain
}
await firstStream.result();
const secondStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
const secondResult = (await secondStream.result()) as {
role: string;
content: Array<{ type: string; text?: string; name?: string }>;
};
expect(secondResult.role).toBe("assistant");
expect(secondResult.content).toEqual([
expect.objectContaining({
type: "toolCall",
name: "re",
}),
]);
});
it("resets a stale unknown-tool streak when a streamed message mixes allowed and unknown tools", async () => {
const baseFn = vi
.fn()
.mockImplementationOnce(() =>
createFakeStream({
events: [],
resultMessage: {
role: "assistant",
content: [{ type: "toolCall", name: " ex ", arguments: { command: "echo first" } }],
},
}),
)
.mockImplementationOnce(() =>
createFakeStream({
events: [
{
type: "toolcall_delta",
message: {
role: "assistant",
content: [
{ type: "toolCall", name: " exec ", arguments: { command: "echo allowed" } },
{ type: "toolCall", name: " ex ", arguments: { command: "echo provisional" } },
],
},
},
],
resultMessage: {
role: "assistant",
content: [{ type: "toolCall", name: " exec ", arguments: { command: "echo ok" } }],
},
}),
)
.mockImplementationOnce(() =>
createFakeStream({
events: [],
resultMessage: {
role: "assistant",
content: [{ type: "toolCall", name: " ex ", arguments: { command: "echo retry" } }],
},
}),
);
const wrappedFn = wrapStreamFnTrimToolCallNames(baseFn as never, new Set(["exec"]), {
unknownToolThreshold: 1,
});
const firstStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
await firstStream.result();
const secondStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
for await (const _item of secondStream) {
// drain
}
await secondStream.result();
const thirdStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never));
const thirdResult = (await thirdStream.result()) as {
role: string;
content: Array<{ type: string; text?: string; name?: string }>;
};
expect(thirdResult.role).toBe("assistant");
expect(thirdResult.content).toEqual([
expect.objectContaining({
type: "toolCall",
name: "ex",
}),
]);
});
it("infers tool names from malformed toolCallId variants when allowlist is present", async () => {
const partialToolCall = { type: "toolCall", id: "functions.read:0", name: "" };
const finalToolCallA = { type: "toolCall", id: "functionsread3", name: "" };

View File

@@ -636,20 +636,26 @@ function trimWhitespaceFromToolCallNamesInMessage(
normalizeToolCallIdsInMessage(message);
}
function collectUnknownToolNameFromMessage(
function classifyToolCallMessage(
message: unknown,
allowedToolNames?: Set<string>,
): string | undefined {
):
| { kind: "none" }
| { kind: "allowed" }
| { kind: "incomplete" }
| { kind: "unknown"; toolName: string } {
if (!message || typeof message !== "object" || !allowedToolNames || allowedToolNames.size === 0) {
return undefined;
return { kind: "none" };
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return undefined;
return { kind: "none" };
}
let unknownToolName: string | undefined;
let sawToolCall = false;
let sawAllowedToolCall = false;
let sawIncompleteToolCall = false;
for (const block of content) {
if (!block || typeof block !== "object") {
continue;
@@ -661,10 +667,12 @@ function collectUnknownToolNameFromMessage(
sawToolCall = true;
const rawName = typeof typedBlock.name === "string" ? typedBlock.name.trim() : "";
if (!rawName) {
return undefined;
sawIncompleteToolCall = true;
continue;
}
if (resolveExactAllowedToolName(rawName, allowedToolNames)) {
return undefined;
sawAllowedToolCall = true;
continue;
}
const normalizedUnknownToolName = normalizeToolName(rawName);
if (!unknownToolName) {
@@ -672,11 +680,20 @@ function collectUnknownToolNameFromMessage(
continue;
}
if (unknownToolName !== normalizedUnknownToolName) {
return undefined;
sawIncompleteToolCall = true;
}
}
return sawToolCall ? unknownToolName : undefined;
if (!sawToolCall) {
return { kind: "none" };
}
if (sawAllowedToolCall) {
return { kind: "allowed" };
}
if (sawIncompleteToolCall) {
return { kind: "incomplete" };
}
return unknownToolName ? { kind: "unknown", toolName: unknownToolName } : { kind: "incomplete" };
}
function rewriteUnknownToolLoopMessage(message: unknown, toolName: string): void {
@@ -694,27 +711,41 @@ function rewriteUnknownToolLoopMessage(message: unknown, toolName: string): void
function guardUnknownToolLoopInMessage(
message: unknown,
state: UnknownToolLoopGuardState,
params: { allowedToolNames?: Set<string>; threshold?: number; countAttempt: boolean },
): void {
params: {
allowedToolNames?: Set<string>;
threshold?: number;
countAttempt: boolean;
resetOnAllowedTool?: boolean;
resetOnMissingUnknownTool?: boolean;
},
): boolean {
const threshold = params.threshold;
if (threshold === undefined || threshold <= 0) {
return;
return false;
}
const unknownToolName = collectUnknownToolNameFromMessage(message, params.allowedToolNames);
if (!unknownToolName) {
if (params.countAttempt) {
const toolCallState = classifyToolCallMessage(message, params.allowedToolNames);
if (toolCallState.kind === "allowed") {
if (params.resetOnAllowedTool === true) {
state.lastUnknownToolName = undefined;
state.count = 0;
}
return;
return false;
}
if (toolCallState.kind !== "unknown") {
if (params.countAttempt && params.resetOnMissingUnknownTool !== false) {
state.lastUnknownToolName = undefined;
state.count = 0;
}
return false;
}
const unknownToolName = toolCallState.toolName;
if (!params.countAttempt) {
if (state.lastUnknownToolName === unknownToolName && state.count > threshold) {
rewriteUnknownToolLoopMessage(message, unknownToolName);
}
return;
return false;
}
if (message && typeof message === "object") {
@@ -722,7 +753,7 @@ function guardUnknownToolLoopInMessage(
if (state.lastUnknownToolName === unknownToolName && state.count > threshold) {
rewriteUnknownToolLoopMessage(message, unknownToolName);
}
return;
return true;
}
state.countedMessages.add(message);
}
@@ -737,6 +768,7 @@ function guardUnknownToolLoopInMessage(
if (state.count > threshold) {
rewriteUnknownToolLoopMessage(message, unknownToolName);
}
return true;
}
function wrapStreamTrimToolCallNames(
@@ -757,6 +789,7 @@ function wrapStreamTrimToolCallNames(
allowedToolNames,
threshold: options?.unknownToolThreshold,
countAttempt: !streamAttemptAlreadyCounted,
resetOnAllowedTool: true,
});
return message;
};
@@ -776,12 +809,18 @@ function wrapStreamTrimToolCallNames(
trimWhitespaceFromToolCallNamesInMessage(event.partial, allowedToolNames);
trimWhitespaceFromToolCallNamesInMessage(event.message, allowedToolNames);
if (event.message && typeof event.message === "object") {
guardUnknownToolLoopInMessage(event.message, unknownToolGuardState, {
allowedToolNames,
threshold: options?.unknownToolThreshold,
countAttempt: true,
});
streamAttemptAlreadyCounted = true;
const countedStreamAttempt = guardUnknownToolLoopInMessage(
event.message,
unknownToolGuardState,
{
allowedToolNames,
threshold: options?.unknownToolThreshold,
countAttempt: !streamAttemptAlreadyCounted,
resetOnAllowedTool: true,
resetOnMissingUnknownTool: false,
},
);
streamAttemptAlreadyCounted ||= countedStreamAttempt;
}
guardUnknownToolLoopInMessage(event.partial, unknownToolGuardState, {
allowedToolNames,