diff --git a/src/agents/agent-tools.message-provider-policy.test.ts b/src/agents/agent-tools.message-provider-policy.test.ts index 265b910e52a..034bbceac34 100644 --- a/src/agents/agent-tools.message-provider-policy.test.ts +++ b/src/agents/agent-tools.message-provider-policy.test.ts @@ -4,21 +4,42 @@ * unsafe or redundant for the active channel. */ import { describe, expect, it } from "vitest"; -import { filterToolNamesByMessageProvider } from "./agent-tools.message-provider-policy.js"; +import { filterToolsByMessageProvider } from "./agent-tools.message-provider-policy.js"; -const DEFAULT_TOOL_NAMES = ["read", "write", "tts", "web_search"]; +const DEFAULT_TOOLS = [ + { name: "read" }, + { name: "write" }, + { name: "tts" }, + { name: "web_search" }, +]; + +function toolNames(tools: readonly { name: string }[]): Set { + return new Set(tools.map((tool) => tool.name)); +} describe("createOpenClawCodingTools message provider policy", () => { it.each(["voice", "VOICE", " Voice ", "discord-voice", "DISCORD-VOICE", " Discord-Voice "])( "does not expose tts tool for normalized voice provider: %s", (messageProvider) => { - const names = new Set(filterToolNamesByMessageProvider(DEFAULT_TOOL_NAMES, messageProvider)); + const names = toolNames(filterToolsByMessageProvider(DEFAULT_TOOLS, messageProvider)); expect(names.has("tts")).toBe(false); }, ); it("keeps tts tool for non-voice providers", () => { - const names = new Set(filterToolNamesByMessageProvider(DEFAULT_TOOL_NAMES, "guildchat")); + const names = toolNames(filterToolsByMessageProvider(DEFAULT_TOOLS, "guildchat")); expect(names.has("tts")).toBe(true); }); + + it("preserves duplicate tool entries while filtering", () => { + const tools = [ + { name: "read", id: 1 }, + { name: "tts", id: 2 }, + { name: "read", id: 3 }, + ]; + expect(filterToolsByMessageProvider(tools, "voice")).toStrictEqual([ + { name: "read", id: 1 }, + { name: "read", id: 3 }, + ]); + }); }); diff --git a/src/agents/agent-tools.message-provider-policy.ts b/src/agents/agent-tools.message-provider-policy.ts index 5be579475b9..c28db6960fe 100644 --- a/src/agents/agent-tools.message-provider-policy.ts +++ b/src/agents/agent-tools.message-provider-policy.ts @@ -14,49 +14,24 @@ const TOOL_ALLOW_BY_MESSAGE_PROVIDER: Readonly node: ["canvas", "image", "pdf", "tts", "web_fetch", "web_search"], }; -/** Filters tool names by the active message-provider allow/deny policy. */ -export function filterToolNamesByMessageProvider( - toolNames: readonly string[], - messageProvider?: string, -): string[] { - const normalizedProvider = normalizeOptionalLowercaseString(messageProvider); - if (!normalizedProvider) { - return [...toolNames]; - } - const allowedTools = TOOL_ALLOW_BY_MESSAGE_PROVIDER[normalizedProvider]; - if (allowedTools && allowedTools.length > 0) { - const allowedSet = new Set(allowedTools); - return toolNames.filter((toolName) => allowedSet.has(toolName)); - } - const deniedTools = TOOL_DENY_BY_MESSAGE_PROVIDER[normalizedProvider]; - if (!deniedTools || deniedTools.length === 0) { - return [...toolNames]; - } - const deniedSet = new Set(deniedTools); - return toolNames.filter((toolName) => !deniedSet.has(toolName)); -} - /** Applies message-provider filtering while preserving duplicate tool entries. */ export function filterToolsByMessageProvider( tools: readonly TTool[], messageProvider?: string, ): TTool[] { - const filteredToolNames = filterToolNamesByMessageProvider( - tools.map((tool) => tool.name), - messageProvider, - ); - const remainingCounts = new Map(); - for (const toolName of filteredToolNames) { - remainingCounts.set(toolName, (remainingCounts.get(toolName) ?? 0) + 1); + const normalizedProvider = normalizeOptionalLowercaseString(messageProvider); + if (!normalizedProvider) { + return [...tools]; } - return tools.filter((tool) => { - // Counted matching preserves the original order and duplicate instances - // after name-level policy filtering. - const remaining = remainingCounts.get(tool.name) ?? 0; - if (remaining <= 0) { - return false; - } - remainingCounts.set(tool.name, remaining - 1); - return true; - }); + const allowedTools = TOOL_ALLOW_BY_MESSAGE_PROVIDER[normalizedProvider]; + if (allowedTools && allowedTools.length > 0) { + const allowedSet = new Set(allowedTools); + return tools.filter((tool) => allowedSet.has(tool.name)); + } + const deniedTools = TOOL_DENY_BY_MESSAGE_PROVIDER[normalizedProvider]; + if (!deniedTools || deniedTools.length === 0) { + return [...tools]; + } + const deniedSet = new Set(deniedTools); + return tools.filter((tool) => !deniedSet.has(tool.name)); }