From 5039a35a332a0db578abe4f8c2a288f55ede5dd5 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sun, 3 May 2026 17:21:50 +0100 Subject: [PATCH] fix(discord): preserve tracked reaction targets --- .../monitor/message-handler.process.test.ts | 49 +++++++++++++++++++ .../src/monitor/message-handler.process.ts | 42 ++++++++++++++-- src/channels/status-reactions.test.ts | 18 +++++++ src/channels/status-reactions.ts | 16 ++++-- 4 files changed, 118 insertions(+), 7 deletions(-) diff --git a/extensions/discord/src/monitor/message-handler.process.test.ts b/extensions/discord/src/monitor/message-handler.process.test.ts index 96d1cc371c8..5c8ed7e67b5 100644 --- a/extensions/discord/src/monitor/message-handler.process.test.ts +++ b/extensions/discord/src/monitor/message-handler.process.test.ts @@ -66,6 +66,17 @@ vi.mock("../send.js", () => ({ }, })); +const discordTargetMocks = vi.hoisted(() => ({ + resolveDiscordTargetChannelId: vi.fn(async (target: string) => ({ + channelId: target === "user:u1" ? "dm-u1" : target, + })), +})); + +vi.mock("../send.shared.js", () => ({ + resolveDiscordTargetChannelId: (target: string, opts: unknown) => + discordTargetMocks.resolveDiscordTargetChannelId(target, opts), +})); + vi.mock("../send.messages.js", () => ({ editMessageDiscord: (channelId: string, messageId: string, payload: unknown, opts?: unknown) => deliveryMocks.editMessageDiscord(channelId, messageId, payload, opts), @@ -315,6 +326,7 @@ beforeEach(() => { vi.useRealTimers(); sendMocks.reactMessageDiscord.mockClear(); sendMocks.removeReactionDiscord.mockClear(); + discordTargetMocks.resolveDiscordTargetChannelId.mockClear(); editMessageDiscord.mockClear(); deliverDiscordReply.mockClear(); createDiscordDraftStream.mockClear(); @@ -637,6 +649,43 @@ describe("processDiscordMessage ack reactions", () => { expect(calls).toContainEqual(expect.arrayContaining(["c1", "m1", DEFAULT_EMOJIS.done])); }); + it("resolves tracked reaction to targets like the Discord reaction action", async () => { + vi.useFakeTimers(); + dispatchInboundMessage.mockImplementationOnce(async (params?: DispatchInboundParams) => { + await params?.replyOptions?.onToolStart?.({ + name: "message", + phase: "start", + args: { + action: "react", + to: "user:u1", + messageId: "m1", + emoji: "๐Ÿ“ˆ", + trackToolCalls: true, + }, + }); + await vi.advanceTimersByTimeAsync(DEFAULT_TIMING.debounceMs); + return createNoQueuedDispatchResult(); + }); + + const ctx = await createAutomaticSourceDeliveryContext({ + cfg: { messages: { ackReaction: "๐Ÿ‘€" } }, + }); + + await runProcessDiscordMessage(ctx); + await vi.runAllTimersAsync(); + + expect(discordTargetMocks.resolveDiscordTargetChannelId).toHaveBeenCalledWith( + "user:u1", + expect.objectContaining({ accountId: "default" }), + ); + const calls = sendMocks.reactMessageDiscord.mock.calls as unknown as Array< + [string, string, string] + >; + expect(calls).toContainEqual(expect.arrayContaining(["dm-u1", "m1", "๐Ÿ“ˆ"])); + expect(calls).toContainEqual(expect.arrayContaining(["dm-u1", "m1", "โœ‰๏ธ"])); + expect(calls).toContainEqual(expect.arrayContaining(["dm-u1", "m1", DEFAULT_EMOJIS.done])); + }); + it("shows stall emojis for long no-progress runs", async () => { vi.useFakeTimers(); let releaseDispatch!: () => void; diff --git a/extensions/discord/src/monitor/message-handler.process.ts b/extensions/discord/src/monitor/message-handler.process.ts index a08e6cda2dc..fe7482fbbb0 100644 --- a/extensions/discord/src/monitor/message-handler.process.ts +++ b/extensions/discord/src/monitor/message-handler.process.ts @@ -27,6 +27,8 @@ import { resolveDiscordMaxLinesPerMessage } from "../accounts.js"; import { createDiscordRestClient } from "../client.js"; import { removeReactionDiscord } from "../send.js"; import { editMessageDiscord } from "../send.messages.js"; +import { resolveDiscordTargetChannelId } from "../send.shared.js"; +import { resolveDiscordChannelId } from "../targets.js"; import { createDiscordAckReactionAdapter, createDiscordAckReactionContext, @@ -235,7 +237,29 @@ export async function processDiscordMessage( }); }, }); - const maybeBindStatusReactionsToToolReaction = (payload: ToolStartPayload) => { + const resolveTrackedReactionChannelId = async ( + args: Record, + ): Promise => { + const target = + readToolStringArg(args, "channelId") ?? + readToolStringArg(args, "channel_id") ?? + readToolStringArg(args, "to"); + if (!target) { + return messageChannelId; + } + try { + return resolveDiscordChannelId(target); + } catch { + return ( + await resolveDiscordTargetChannelId(target, { + cfg, + token, + accountId, + }) + ).channelId; + } + }; + const maybeBindStatusReactionsToToolReaction = async (payload: ToolStartPayload) => { if ( sourceRepliesAreToolOnly || cfg.messages?.statusReactions?.enabled === false || @@ -262,8 +286,18 @@ export async function processDiscordMessage( } const trackedMessageId = readToolStringArg(args, "messageId") ?? readToolStringArg(args, "message_id") ?? message.id; - const trackedChannelId = - readToolStringArg(args, "channelId") ?? readToolStringArg(args, "to") ?? messageChannelId; + let trackedChannelId: string; + try { + trackedChannelId = await resolveTrackedReactionChannelId(args); + } catch (err) { + logAckFailure({ + log: logVerbose, + channel: "discord", + target: `${readToolStringArg(args, "to") ?? readToolStringArg(args, "channelId") ?? messageChannelId}/${trackedMessageId}`, + error: err, + }); + return; + } statusReactionTarget = `${trackedChannelId}/${trackedMessageId}`; if (statusReactionsActive) { void statusReactions.clear(); @@ -619,7 +653,7 @@ export async function processDiscordMessage( if (isProcessAborted(abortSignal)) { return; } - maybeBindStatusReactionsToToolReaction(payload); + await maybeBindStatusReactionsToToolReaction(payload); await statusReactions.setTool(payload.name); draftPreview.pushToolProgress( payload.name ? `tool: ${payload.name}` : "tool running", diff --git a/src/channels/status-reactions.test.ts b/src/channels/status-reactions.test.ts index 9907f5bf972..884a1c2a9ed 100644 --- a/src/channels/status-reactions.test.ts +++ b/src/channels/status-reactions.test.ts @@ -110,6 +110,24 @@ describe("resolveToolEmoji", () => { expect(resolveToolEmoji(tool, DEFAULT_EMOJIS)).toBe(expected); }, ); + + it("preserves explicit status emoji overrides before exact tool display emojis", () => { + const emojis = { + ...DEFAULT_EMOJIS, + coding: "๐Ÿงช", + web: "๐Ÿ›ฐ๏ธ", + tool: "๐Ÿ”ง", + }; + const overrides = { + coding: "๐Ÿงช", + web: "๐Ÿ›ฐ๏ธ", + tool: "๐Ÿ”ง", + }; + + expect(resolveToolEmoji("exec", emojis, overrides)).toBe("๐Ÿงช"); + expect(resolveToolEmoji("web_search", emojis, overrides)).toBe("๐Ÿ›ฐ๏ธ"); + expect(resolveToolEmoji("message", emojis, overrides)).toBe("๐Ÿ”ง"); + }); }); describe("createStatusReactionController", () => { diff --git a/src/channels/status-reactions.ts b/src/channels/status-reactions.ts index 92bdeff7850..a67e8540c3e 100644 --- a/src/channels/status-reactions.ts +++ b/src/channels/status-reactions.ts @@ -105,18 +105,28 @@ export const WEB_TOOL_TOKENS: string[] = [ export function resolveToolEmoji( toolName: string | undefined, emojis: Required, + emojiOverrides?: StatusReactionEmojis, ): string { const normalized = normalizeOptionalLowercaseString(toolName) ?? ""; if (!normalized) { return emojis.tool; } + + const category = WEB_TOOL_TOKENS.some((token) => normalized.includes(token)) + ? "web" + : CODING_TOOL_TOKENS.some((token) => normalized.includes(token)) + ? "coding" + : "tool"; + if (emojiOverrides?.[category] !== undefined) { + return emojis[category]; + } if (Object.hasOwn(TOOL_DISPLAY_CONFIG.tools, normalized)) { return resolveToolDisplay({ name: toolName }).emoji; } - if (WEB_TOOL_TOKENS.some((token) => normalized.includes(token))) { + if (category === "web") { return emojis.web; } - if (CODING_TOOL_TOKENS.some((token) => normalized.includes(token))) { + if (category === "coding") { return emojis.coding; } return emojis.tool; @@ -322,7 +332,7 @@ export function createStatusReactionController(params: { } function setTool(toolName?: string): void { - const emoji = resolveToolEmoji(toolName, emojis); + const emoji = resolveToolEmoji(toolName, emojis, params.emojis); scheduleEmoji(emoji); }