From 5201c422513db414dea6044021801b1bde639c8c Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Wed, 29 Apr 2026 17:31:50 +0100 Subject: [PATCH] refactor(discord): split messaging runtime actions --- .../src/actions/runtime.messaging.messages.ts | 205 ++++++ .../actions/runtime.messaging.reactions.ts | 67 ++ .../src/actions/runtime.messaging.runtime.ts | 69 ++ .../src/actions/runtime.messaging.send.ts | 230 +++++++ .../src/actions/runtime.messaging.shared.ts | 97 +++ .../discord/src/actions/runtime.messaging.ts | 602 +----------------- extensions/discord/src/internal/rest.test.ts | 88 +-- .../internal/test-builders.test-support.ts | 34 + .../discord/src/proxy-request-client.test.ts | 34 +- extensions/discord/src/voice/manager.ts | 5 +- 10 files changed, 788 insertions(+), 643 deletions(-) create mode 100644 extensions/discord/src/actions/runtime.messaging.messages.ts create mode 100644 extensions/discord/src/actions/runtime.messaging.reactions.ts create mode 100644 extensions/discord/src/actions/runtime.messaging.runtime.ts create mode 100644 extensions/discord/src/actions/runtime.messaging.send.ts create mode 100644 extensions/discord/src/actions/runtime.messaging.shared.ts diff --git a/extensions/discord/src/actions/runtime.messaging.messages.ts b/extensions/discord/src/actions/runtime.messaging.messages.ts new file mode 100644 index 00000000000..7069bb6a17d --- /dev/null +++ b/extensions/discord/src/actions/runtime.messaging.messages.ts @@ -0,0 +1,205 @@ +import { + jsonResult, + readNumberParam, + readStringArrayParam, + readStringParam, +} from "../runtime-api.js"; +import { discordMessagingActionRuntime } from "./runtime.messaging.runtime.js"; +import type { DiscordMessagingActionContext } from "./runtime.messaging.shared.js"; + +function parseDiscordMessageLink(link: string) { + const normalized = link.trim(); + const match = normalized.match( + /^(?:https?:\/\/)?(?:ptb\.|canary\.)?discord(?:app)?\.com\/channels\/(\d+)\/(\d+)\/(\d+)(?:\/?|\?.*)$/i, + ); + if (!match) { + throw new Error( + "Invalid Discord message link. Expected https://discord.com/channels///.", + ); + } + return { + guildId: match[1], + channelId: match[2], + messageId: match[3], + }; +} + +export async function handleDiscordMessageManagementAction(ctx: DiscordMessagingActionContext) { + switch (ctx.action) { + case "permissions": { + if (!ctx.isActionEnabled("permissions")) { + throw new Error("Discord permissions are disabled."); + } + const channelId = ctx.resolveChannelId(); + const permissions = await discordMessagingActionRuntime.fetchChannelPermissionsDiscord( + channelId, + ctx.withOpts(), + ); + return jsonResult({ ok: true, permissions }); + } + case "fetchMessage": { + if (!ctx.isActionEnabled("messages")) { + throw new Error("Discord message reads are disabled."); + } + const messageLink = readStringParam(ctx.params, "messageLink"); + let guildId = readStringParam(ctx.params, "guildId"); + let channelId = readStringParam(ctx.params, "channelId"); + let messageId = readStringParam(ctx.params, "messageId"); + if (messageLink) { + const parsed = parseDiscordMessageLink(messageLink); + guildId = parsed.guildId; + channelId = parsed.channelId; + messageId = parsed.messageId; + } + if (!guildId || !channelId || !messageId) { + throw new Error( + "Discord message fetch requires guildId, channelId, and messageId (or a valid messageLink).", + ); + } + const message = await discordMessagingActionRuntime.fetchMessageDiscord( + channelId, + messageId, + ctx.withOpts(), + ); + return jsonResult({ + ok: true, + message: ctx.normalizeMessage(message), + guildId, + channelId, + messageId, + }); + } + case "readMessages": { + if (!ctx.isActionEnabled("messages")) { + throw new Error("Discord message reads are disabled."); + } + const channelId = ctx.resolveChannelId(); + const query = { + limit: readNumberParam(ctx.params, "limit"), + before: readStringParam(ctx.params, "before"), + after: readStringParam(ctx.params, "after"), + around: readStringParam(ctx.params, "around"), + }; + const messages = await discordMessagingActionRuntime.readMessagesDiscord( + channelId, + query, + ctx.withOpts(), + ); + return jsonResult({ + ok: true, + messages: messages.map((message) => ctx.normalizeMessage(message)), + }); + } + case "editMessage": { + if (!ctx.isActionEnabled("messages")) { + throw new Error("Discord message edits are disabled."); + } + const channelId = ctx.resolveChannelId(); + const messageId = readStringParam(ctx.params, "messageId", { + required: true, + }); + const content = readStringParam(ctx.params, "content", { + required: true, + }); + const message = await discordMessagingActionRuntime.editMessageDiscord( + channelId, + messageId, + { content }, + ctx.withOpts(), + ); + return jsonResult({ ok: true, message }); + } + case "deleteMessage": { + if (!ctx.isActionEnabled("messages")) { + throw new Error("Discord message deletes are disabled."); + } + const channelId = ctx.resolveChannelId(); + const messageId = readStringParam(ctx.params, "messageId", { + required: true, + }); + await discordMessagingActionRuntime.deleteMessageDiscord( + channelId, + messageId, + ctx.withOpts(), + ); + return jsonResult({ ok: true }); + } + case "pinMessage": { + if (!ctx.isActionEnabled("pins")) { + throw new Error("Discord pins are disabled."); + } + const channelId = ctx.resolveChannelId(); + const messageId = readStringParam(ctx.params, "messageId", { + required: true, + }); + await discordMessagingActionRuntime.pinMessageDiscord(channelId, messageId, ctx.withOpts()); + return jsonResult({ ok: true }); + } + case "unpinMessage": { + if (!ctx.isActionEnabled("pins")) { + throw new Error("Discord pins are disabled."); + } + const channelId = ctx.resolveChannelId(); + const messageId = readStringParam(ctx.params, "messageId", { + required: true, + }); + await discordMessagingActionRuntime.unpinMessageDiscord(channelId, messageId, ctx.withOpts()); + return jsonResult({ ok: true }); + } + case "listPins": { + if (!ctx.isActionEnabled("pins")) { + throw new Error("Discord pins are disabled."); + } + const channelId = ctx.resolveChannelId(); + const pins = await discordMessagingActionRuntime.listPinsDiscord(channelId, ctx.withOpts()); + return jsonResult({ ok: true, pins: pins.map((pin) => ctx.normalizeMessage(pin)) }); + } + case "searchMessages": { + if (!ctx.isActionEnabled("search")) { + throw new Error("Discord search is disabled."); + } + const guildId = readStringParam(ctx.params, "guildId", { + required: true, + }); + const content = readStringParam(ctx.params, "content", { + required: true, + }); + const channelId = readStringParam(ctx.params, "channelId"); + const channelIds = readStringArrayParam(ctx.params, "channelIds"); + const authorId = readStringParam(ctx.params, "authorId"); + const authorIds = readStringArrayParam(ctx.params, "authorIds"); + const limit = readNumberParam(ctx.params, "limit"); + const channelIdList = [...(channelIds ?? []), ...(channelId ? [channelId] : [])]; + const authorIdList = [...(authorIds ?? []), ...(authorId ? [authorId] : [])]; + const results = await discordMessagingActionRuntime.searchMessagesDiscord( + { + guildId, + content, + channelIds: channelIdList.length ? channelIdList : undefined, + authorIds: authorIdList.length ? authorIdList : undefined, + limit, + }, + ctx.withOpts(), + ); + if (!results || typeof results !== "object") { + return jsonResult({ ok: true, results }); + } + const resultsRecord = results as Record; + const messages = resultsRecord.messages; + const normalizedMessages = Array.isArray(messages) + ? messages.map((group) => + Array.isArray(group) ? group.map((msg) => ctx.normalizeMessage(msg)) : group, + ) + : messages; + return jsonResult({ + ok: true, + results: { + ...resultsRecord, + messages: normalizedMessages, + }, + }); + } + default: + return undefined; + } +} diff --git a/extensions/discord/src/actions/runtime.messaging.reactions.ts b/extensions/discord/src/actions/runtime.messaging.reactions.ts new file mode 100644 index 00000000000..c64a2873eca --- /dev/null +++ b/extensions/discord/src/actions/runtime.messaging.reactions.ts @@ -0,0 +1,67 @@ +import { + jsonResult, + readNumberParam, + readReactionParams, + readStringParam, +} from "../runtime-api.js"; +import { discordMessagingActionRuntime } from "./runtime.messaging.runtime.js"; +import type { DiscordMessagingActionContext } from "./runtime.messaging.shared.js"; + +export async function handleDiscordReactionMessagingAction(ctx: DiscordMessagingActionContext) { + switch (ctx.action) { + case "react": { + if (!ctx.isActionEnabled("reactions")) { + throw new Error("Discord reactions are disabled."); + } + const channelId = await ctx.resolveReactionChannelId(); + const messageId = readStringParam(ctx.params, "messageId", { + required: true, + }); + const { emoji, remove, isEmpty } = readReactionParams(ctx.params, { + removeErrorMessage: "Emoji is required to remove a Discord reaction.", + }); + if (remove) { + await discordMessagingActionRuntime.removeReactionDiscord( + channelId, + messageId, + emoji, + ctx.withReactionRuntimeOptions(), + ); + return jsonResult({ ok: true, removed: emoji }); + } + if (isEmpty) { + const removed = await discordMessagingActionRuntime.removeOwnReactionsDiscord( + channelId, + messageId, + ctx.withReactionRuntimeOptions(), + ); + return jsonResult({ ok: true, removed: removed.removed }); + } + await discordMessagingActionRuntime.reactMessageDiscord( + channelId, + messageId, + emoji, + ctx.withReactionRuntimeOptions(), + ); + return jsonResult({ ok: true, added: emoji }); + } + case "reactions": { + if (!ctx.isActionEnabled("reactions")) { + throw new Error("Discord reactions are disabled."); + } + const channelId = await ctx.resolveReactionChannelId(); + const messageId = readStringParam(ctx.params, "messageId", { + required: true, + }); + const limit = readNumberParam(ctx.params, "limit"); + const reactions = await discordMessagingActionRuntime.fetchReactionsDiscord( + channelId, + messageId, + ctx.withReactionRuntimeOptions({ limit }), + ); + return jsonResult({ ok: true, reactions }); + } + default: + return undefined; + } +} diff --git a/extensions/discord/src/actions/runtime.messaging.runtime.ts b/extensions/discord/src/actions/runtime.messaging.runtime.ts new file mode 100644 index 00000000000..592d4e61f71 --- /dev/null +++ b/extensions/discord/src/actions/runtime.messaging.runtime.ts @@ -0,0 +1,69 @@ +import { readDiscordComponentSpec } from "../components.js"; +import type { OpenClawConfig } from "../runtime-api.js"; +import { sendDiscordComponentMessage } from "../send.components.js"; +import { + createThreadDiscord, + deleteMessageDiscord, + editMessageDiscord, + fetchChannelPermissionsDiscord, + fetchMessageDiscord, + fetchReactionsDiscord, + listPinsDiscord, + listThreadsDiscord, + pinMessageDiscord, + reactMessageDiscord, + readMessagesDiscord, + removeOwnReactionsDiscord, + removeReactionDiscord, + searchMessagesDiscord, + sendMessageDiscord, + sendPollDiscord, + sendStickerDiscord, + sendVoiceMessageDiscord, + unpinMessageDiscord, +} from "../send.js"; +import { resolveDiscordTargetChannelId } from "../send.shared.js"; +import { resolveDiscordChannelId } from "../targets.js"; + +export const discordMessagingActionRuntime = { + createThreadDiscord, + deleteMessageDiscord, + editMessageDiscord, + fetchChannelPermissionsDiscord, + fetchMessageDiscord, + fetchReactionsDiscord, + listPinsDiscord, + listThreadsDiscord, + pinMessageDiscord, + reactMessageDiscord, + readDiscordComponentSpec, + readMessagesDiscord, + removeOwnReactionsDiscord, + removeReactionDiscord, + resolveDiscordReactionTargetChannelId, + resolveDiscordChannelId, + searchMessagesDiscord, + sendDiscordComponentMessage, + sendMessageDiscord, + sendPollDiscord, + sendStickerDiscord, + sendVoiceMessageDiscord, + unpinMessageDiscord, +}; + +export async function resolveDiscordReactionTargetChannelId(params: { + target: string; + cfg: OpenClawConfig; + accountId?: string; +}): Promise { + try { + return resolveDiscordChannelId(params.target); + } catch { + return ( + await resolveDiscordTargetChannelId(params.target, { + cfg: params.cfg, + accountId: params.accountId, + }) + ).channelId; + } +} diff --git a/extensions/discord/src/actions/runtime.messaging.send.ts b/extensions/discord/src/actions/runtime.messaging.send.ts new file mode 100644 index 00000000000..7f479f6fb57 --- /dev/null +++ b/extensions/discord/src/actions/runtime.messaging.send.ts @@ -0,0 +1,230 @@ +import { + assertMediaNotDataUrl, + jsonResult, + readBooleanParam, + readNumberParam, + readStringArrayParam, + readStringParam, + resolvePollMaxSelections, +} from "../runtime-api.js"; +import type { DiscordSendComponents, DiscordSendEmbeds } from "../send.shared.js"; +import { discordMessagingActionRuntime } from "./runtime.messaging.runtime.js"; +import type { DiscordMessagingActionContext } from "./runtime.messaging.shared.js"; + +function hasDiscordComponentObjectKeys(value: unknown): value is Record { + return Boolean( + value && + typeof value === "object" && + !Array.isArray(value) && + Object.keys(value as Record).length > 0, + ); +} + +export async function handleDiscordMessageSendAction(ctx: DiscordMessagingActionContext) { + switch (ctx.action) { + case "sticker": { + if (!ctx.isActionEnabled("stickers")) { + throw new Error("Discord stickers are disabled."); + } + const to = readStringParam(ctx.params, "to", { required: true }); + const content = readStringParam(ctx.params, "content"); + const stickerIds = readStringArrayParam(ctx.params, "stickerIds", { + required: true, + label: "stickerIds", + }); + await discordMessagingActionRuntime.sendStickerDiscord( + to, + stickerIds, + ctx.withOpts({ content }), + ); + return jsonResult({ ok: true }); + } + case "poll": { + if (!ctx.isActionEnabled("polls")) { + throw new Error("Discord polls are disabled."); + } + const to = readStringParam(ctx.params, "to", { required: true }); + const content = readStringParam(ctx.params, "content"); + const question = readStringParam(ctx.params, "question", { + required: true, + }); + const answers = readStringArrayParam(ctx.params, "answers", { + required: true, + label: "answers", + }); + const allowMultiselect = readBooleanParam(ctx.params, "allowMultiselect"); + const durationHours = readNumberParam(ctx.params, "durationHours"); + const maxSelections = resolvePollMaxSelections(answers.length, allowMultiselect); + await discordMessagingActionRuntime.sendPollDiscord( + to, + { question, options: answers, maxSelections, durationHours }, + ctx.withOpts({ content }), + ); + return jsonResult({ ok: true }); + } + case "sendMessage": { + if (!ctx.isActionEnabled("messages")) { + throw new Error("Discord message sends are disabled."); + } + const to = readStringParam(ctx.params, "to", { required: true }); + const asVoice = ctx.params.asVoice === true; + const silent = ctx.params.silent === true; + const rawComponents = ctx.params.components; + const componentSpec = hasDiscordComponentObjectKeys(rawComponents) + ? discordMessagingActionRuntime.readDiscordComponentSpec(rawComponents) + : null; + const components: DiscordSendComponents | undefined = + Array.isArray(rawComponents) || typeof rawComponents === "function" + ? (rawComponents as DiscordSendComponents) + : undefined; + const content = readStringParam(ctx.params, "content", { + required: !asVoice && !componentSpec && !components, + allowEmpty: true, + }); + const mediaUrl = + readStringParam(ctx.params, "mediaUrl", { trim: false }) ?? + readStringParam(ctx.params, "path", { trim: false }) ?? + readStringParam(ctx.params, "filePath", { trim: false }); + const filename = readStringParam(ctx.params, "filename"); + const replyTo = readStringParam(ctx.params, "replyTo"); + const rawEmbeds = ctx.params.embeds; + const embeds: DiscordSendEmbeds | undefined = Array.isArray(rawEmbeds) + ? (rawEmbeds as DiscordSendEmbeds) + : undefined; + const sessionKey = readStringParam(ctx.params, "__sessionKey"); + const agentId = readStringParam(ctx.params, "__agentId"); + + if (componentSpec) { + if (asVoice) { + throw new Error("Discord components cannot be sent as voice messages."); + } + if (embeds?.length) { + throw new Error("Discord components cannot include embeds."); + } + const normalizedContent = content?.trim() ? content : undefined; + const payload = componentSpec.text + ? componentSpec + : { ...componentSpec, text: normalizedContent }; + const result = await discordMessagingActionRuntime.sendDiscordComponentMessage( + to, + payload, + { + ...ctx.withOpts(), + silent, + replyTo: replyTo ?? undefined, + sessionKey: sessionKey ?? undefined, + agentId: agentId ?? undefined, + mediaUrl: mediaUrl ?? undefined, + filename: filename ?? undefined, + }, + ); + return jsonResult({ ok: true, result, components: true }); + } + + if (asVoice) { + if (!mediaUrl) { + throw new Error( + "Voice messages require a media file reference (mediaUrl, path, or filePath).", + ); + } + if (content && content.trim()) { + throw new Error( + "Voice messages cannot include text content (Discord limitation). Remove the content parameter.", + ); + } + assertMediaNotDataUrl(mediaUrl); + const result = await discordMessagingActionRuntime.sendVoiceMessageDiscord(to, mediaUrl, { + ...ctx.withOpts(), + replyTo, + silent, + }); + return jsonResult({ ok: true, result, voiceMessage: true }); + } + + const result = await discordMessagingActionRuntime.sendMessageDiscord(to, content ?? "", { + ...ctx.withOpts(), + mediaUrl, + filename: filename ?? undefined, + mediaLocalRoots: ctx.options?.mediaLocalRoots, + mediaReadFile: ctx.options?.mediaReadFile, + replyTo, + components, + embeds, + silent, + }); + return jsonResult({ ok: true, result }); + } + case "threadCreate": { + if (!ctx.isActionEnabled("threads")) { + throw new Error("Discord threads are disabled."); + } + const channelId = ctx.resolveChannelId(); + const name = readStringParam(ctx.params, "name", { required: true }); + const messageId = readStringParam(ctx.params, "messageId"); + const content = readStringParam(ctx.params, "content"); + const autoArchiveMinutes = readNumberParam(ctx.params, "autoArchiveMinutes"); + const appliedTags = readStringArrayParam(ctx.params, "appliedTags"); + const payload = { + name, + messageId, + autoArchiveMinutes, + content, + appliedTags: appliedTags ?? undefined, + }; + const thread = await discordMessagingActionRuntime.createThreadDiscord( + channelId, + payload, + ctx.withOpts(), + ); + return jsonResult({ ok: true, thread }); + } + case "threadList": { + if (!ctx.isActionEnabled("threads")) { + throw new Error("Discord threads are disabled."); + } + const guildId = readStringParam(ctx.params, "guildId", { + required: true, + }); + const channelId = readStringParam(ctx.params, "channelId"); + const includeArchived = readBooleanParam(ctx.params, "includeArchived"); + const before = readStringParam(ctx.params, "before"); + const limit = readNumberParam(ctx.params, "limit"); + const threads = await discordMessagingActionRuntime.listThreadsDiscord( + { + guildId, + channelId, + includeArchived, + before, + limit, + }, + ctx.withOpts(), + ); + return jsonResult({ ok: true, threads }); + } + case "threadReply": { + if (!ctx.isActionEnabled("threads")) { + throw new Error("Discord threads are disabled."); + } + const channelId = ctx.resolveChannelId(); + const content = readStringParam(ctx.params, "content", { + required: true, + }); + const mediaUrl = readStringParam(ctx.params, "mediaUrl"); + const replyTo = readStringParam(ctx.params, "replyTo"); + const result = await discordMessagingActionRuntime.sendMessageDiscord( + `channel:${channelId}`, + content, + { + ...ctx.withOpts(), + mediaUrl, + mediaLocalRoots: ctx.options?.mediaLocalRoots, + mediaReadFile: ctx.options?.mediaReadFile, + replyTo, + }, + ); + return jsonResult({ ok: true, result }); + } + default: + return undefined; + } +} diff --git a/extensions/discord/src/actions/runtime.messaging.shared.ts b/extensions/discord/src/actions/runtime.messaging.shared.ts new file mode 100644 index 00000000000..a11cf1666d2 --- /dev/null +++ b/extensions/discord/src/actions/runtime.messaging.shared.ts @@ -0,0 +1,97 @@ +import type { AgentToolResult } from "@mariozechner/pi-agent-core"; +import { resolveDefaultDiscordAccountId } from "../accounts.js"; +import { createDiscordRuntimeAccountContext } from "../client.js"; +import { + type ActionGate, + readStringParam, + type DiscordActionConfig, + type OpenClawConfig, + withNormalizedTimestamp, +} from "../runtime-api.js"; +import type { DiscordReactOpts } from "../send.types.js"; +import { discordMessagingActionRuntime } from "./runtime.messaging.runtime.js"; +import { createDiscordActionOptions } from "./runtime.shared.js"; + +export type DiscordMessagingActionOptions = { + mediaLocalRoots?: readonly string[]; + mediaReadFile?: (filePath: string) => Promise; +}; + +export type DiscordMessagingActionContext = { + action: string; + params: Record; + isActionEnabled: ActionGate; + cfg: OpenClawConfig; + options?: DiscordMessagingActionOptions; + accountId?: string; + resolveChannelId: () => string; + resolveReactionChannelId: () => Promise; + withOpts: (extra?: Record) => { cfg: OpenClawConfig; accountId?: string }; + withReactionRuntimeOptions: = Record>( + extra?: T, + ) => DiscordReactOpts & T; + normalizeMessage: (message: unknown) => unknown; +}; + +export type DiscordMessagingActionHandler = ( + ctx: DiscordMessagingActionContext, +) => Promise | undefined>; + +export function createDiscordMessagingActionContext(params: { + action: string; + input: Record; + isActionEnabled: ActionGate; + cfg: OpenClawConfig; + options?: DiscordMessagingActionOptions; +}): DiscordMessagingActionContext { + const accountId = readStringParam(params.input, "accountId"); + const cfgOptions = { cfg: params.cfg }; + const withOpts = (extra?: Record) => + createDiscordActionOptions({ cfg: params.cfg, accountId, extra }); + const resolvedReactionAccountId = accountId ?? resolveDefaultDiscordAccountId(params.cfg); + const reactionRuntimeOptions = resolvedReactionAccountId + ? createDiscordRuntimeAccountContext({ + cfg: params.cfg, + accountId: resolvedReactionAccountId, + }) + : cfgOptions; + return { + action: params.action, + params: params.input, + isActionEnabled: params.isActionEnabled, + cfg: params.cfg, + options: params.options, + accountId, + resolveChannelId: () => + discordMessagingActionRuntime.resolveDiscordChannelId( + readStringParam(params.input, "channelId", { + required: true, + }), + ), + resolveReactionChannelId: async () => { + const target = + readStringParam(params.input, "channelId") ?? + readStringParam(params.input, "to", { required: true }); + return await discordMessagingActionRuntime.resolveDiscordReactionTargetChannelId({ + target, + cfg: params.cfg, + accountId: resolvedReactionAccountId, + }); + }, + withOpts, + withReactionRuntimeOptions: (extra) => + ({ + ...(reactionRuntimeOptions ?? cfgOptions), + ...extra, + }) as DiscordReactOpts & NonNullable, + normalizeMessage: (message: unknown) => { + if (!message || typeof message !== "object") { + return message; + } + return withNormalizedTimestamp( + message as Record, + (message as { timestamp?: unknown }).timestamp, + ); + }, + }; +} diff --git a/extensions/discord/src/actions/runtime.messaging.ts b/extensions/discord/src/actions/runtime.messaging.ts index 2bcf9f0cefd..9a3550595e9 100644 --- a/extensions/discord/src/actions/runtime.messaging.ts +++ b/extensions/discord/src/actions/runtime.messaging.ts @@ -1,594 +1,40 @@ import type { AgentToolResult } from "@mariozechner/pi-agent-core"; -import { resolveDefaultDiscordAccountId } from "../accounts.js"; -import { createDiscordRuntimeAccountContext } from "../client.js"; -import { readDiscordComponentSpec } from "../components.js"; +import type { ActionGate, DiscordActionConfig, OpenClawConfig } from "../runtime-api.js"; +import { handleDiscordMessageManagementAction } from "./runtime.messaging.messages.js"; +import { handleDiscordReactionMessagingAction } from "./runtime.messaging.reactions.js"; +import { handleDiscordMessageSendAction } from "./runtime.messaging.send.js"; import { - assertMediaNotDataUrl, - type ActionGate, - jsonResult, - readNumberParam, - readReactionParams, - readStringArrayParam, - readStringParam, - resolvePollMaxSelections, - type DiscordActionConfig, - type OpenClawConfig, - withNormalizedTimestamp, - readBooleanParam, -} from "../runtime-api.js"; -import { sendDiscordComponentMessage } from "../send.components.js"; -import { - createThreadDiscord, - deleteMessageDiscord, - editMessageDiscord, - fetchChannelPermissionsDiscord, - fetchMessageDiscord, - fetchReactionsDiscord, - listPinsDiscord, - listThreadsDiscord, - pinMessageDiscord, - reactMessageDiscord, - readMessagesDiscord, - removeOwnReactionsDiscord, - removeReactionDiscord, - searchMessagesDiscord, - sendMessageDiscord, - sendPollDiscord, - sendStickerDiscord, - sendVoiceMessageDiscord, - unpinMessageDiscord, -} from "../send.js"; -import { - resolveDiscordTargetChannelId, - type DiscordSendComponents, - type DiscordSendEmbeds, -} from "../send.shared.js"; -import { resolveDiscordChannelId } from "../targets.js"; -import { createDiscordActionOptions } from "./runtime.shared.js"; - -export const discordMessagingActionRuntime = { - createThreadDiscord, - deleteMessageDiscord, - editMessageDiscord, - fetchChannelPermissionsDiscord, - fetchMessageDiscord, - fetchReactionsDiscord, - listPinsDiscord, - listThreadsDiscord, - pinMessageDiscord, - reactMessageDiscord, - readDiscordComponentSpec, - readMessagesDiscord, - removeOwnReactionsDiscord, - removeReactionDiscord, + createDiscordMessagingActionContext, + type DiscordMessagingActionOptions, +} from "./runtime.messaging.shared.js"; +export { + discordMessagingActionRuntime, resolveDiscordReactionTargetChannelId, - resolveDiscordChannelId, - searchMessagesDiscord, - sendDiscordComponentMessage, - sendMessageDiscord, - sendPollDiscord, - sendStickerDiscord, - sendVoiceMessageDiscord, - unpinMessageDiscord, -}; - -export async function resolveDiscordReactionTargetChannelId(params: { - target: string; - cfg: OpenClawConfig; - accountId?: string; -}): Promise { - try { - return resolveDiscordChannelId(params.target); - } catch { - return ( - await resolveDiscordTargetChannelId(params.target, { - cfg: params.cfg, - accountId: params.accountId, - }) - ).channelId; - } -} - -function hasDiscordComponentObjectKeys(value: unknown): value is Record { - return Boolean( - value && - typeof value === "object" && - !Array.isArray(value) && - Object.keys(value as Record).length > 0, - ); -} - -function parseDiscordMessageLink(link: string) { - const normalized = link.trim(); - const match = normalized.match( - /^(?:https?:\/\/)?(?:ptb\.|canary\.)?discord(?:app)?\.com\/channels\/(\d+)\/(\d+)\/(\d+)(?:\/?|\?.*)$/i, - ); - if (!match) { - throw new Error( - "Invalid Discord message link. Expected https://discord.com/channels///.", - ); - } - return { - guildId: match[1], - channelId: match[2], - messageId: match[3], - }; -} +} from "./runtime.messaging.runtime.js"; export async function handleDiscordMessagingAction( action: string, params: Record, isActionEnabled: ActionGate, cfg: OpenClawConfig, - options?: { - mediaLocalRoots?: readonly string[]; - mediaReadFile?: (filePath: string) => Promise; - }, + options?: DiscordMessagingActionOptions, ): Promise> { - const resolveChannelId = () => - discordMessagingActionRuntime.resolveDiscordChannelId( - readStringParam(params, "channelId", { - required: true, - }), - ); - const accountId = readStringParam(params, "accountId"); if (!cfg) { throw new Error("Discord messaging actions require a resolved runtime config."); } - const cfgOptions = { cfg }; - const withOpts = (extra?: Record) => - createDiscordActionOptions({ cfg, accountId, extra }); - const resolvedReactionAccountId = accountId ?? resolveDefaultDiscordAccountId(cfg); - const resolveReactionChannelId = async () => { - const target = - readStringParam(params, "channelId") ?? readStringParam(params, "to", { required: true }); - return await discordMessagingActionRuntime.resolveDiscordReactionTargetChannelId({ - target, - cfg, - accountId: resolvedReactionAccountId, - }); - }; - const reactionRuntimeOptions = resolvedReactionAccountId - ? createDiscordRuntimeAccountContext({ - cfg, - accountId: resolvedReactionAccountId, - }) - : cfgOptions; - const withReactionRuntimeOptions = (extra?: Record) => ({ - ...(reactionRuntimeOptions ?? cfgOptions), - ...extra, + const ctx = createDiscordMessagingActionContext({ + action, + input: params, + isActionEnabled, + cfg, + options, }); - const normalizeMessage = (message: unknown) => { - if (!message || typeof message !== "object") { - return message; - } - return withNormalizedTimestamp( - message as Record, - (message as { timestamp?: unknown }).timestamp, - ); - }; - switch (action) { - case "react": { - if (!isActionEnabled("reactions")) { - throw new Error("Discord reactions are disabled."); - } - const channelId = await resolveReactionChannelId(); - const messageId = readStringParam(params, "messageId", { - required: true, - }); - const { emoji, remove, isEmpty } = readReactionParams(params, { - removeErrorMessage: "Emoji is required to remove a Discord reaction.", - }); - if (remove) { - await discordMessagingActionRuntime.removeReactionDiscord( - channelId, - messageId, - emoji, - withReactionRuntimeOptions(), - ); - return jsonResult({ ok: true, removed: emoji }); - } - if (isEmpty) { - const removed = await discordMessagingActionRuntime.removeOwnReactionsDiscord( - channelId, - messageId, - withReactionRuntimeOptions(), - ); - return jsonResult({ ok: true, removed: removed.removed }); - } - await discordMessagingActionRuntime.reactMessageDiscord( - channelId, - messageId, - emoji, - withReactionRuntimeOptions(), - ); - return jsonResult({ ok: true, added: emoji }); - } - case "reactions": { - if (!isActionEnabled("reactions")) { - throw new Error("Discord reactions are disabled."); - } - const channelId = await resolveReactionChannelId(); - const messageId = readStringParam(params, "messageId", { - required: true, - }); - const limit = readNumberParam(params, "limit"); - const reactions = await discordMessagingActionRuntime.fetchReactionsDiscord( - channelId, - messageId, - withReactionRuntimeOptions({ limit }), - ); - return jsonResult({ ok: true, reactions }); - } - case "sticker": { - if (!isActionEnabled("stickers")) { - throw new Error("Discord stickers are disabled."); - } - const to = readStringParam(params, "to", { required: true }); - const content = readStringParam(params, "content"); - const stickerIds = readStringArrayParam(params, "stickerIds", { - required: true, - label: "stickerIds", - }); - await discordMessagingActionRuntime.sendStickerDiscord(to, stickerIds, withOpts({ content })); - return jsonResult({ ok: true }); - } - case "poll": { - if (!isActionEnabled("polls")) { - throw new Error("Discord polls are disabled."); - } - const to = readStringParam(params, "to", { required: true }); - const content = readStringParam(params, "content"); - const question = readStringParam(params, "question", { - required: true, - }); - const answers = readStringArrayParam(params, "answers", { - required: true, - label: "answers", - }); - const allowMultiselect = readBooleanParam(params, "allowMultiselect"); - const durationHours = readNumberParam(params, "durationHours"); - const maxSelections = resolvePollMaxSelections(answers.length, allowMultiselect); - await discordMessagingActionRuntime.sendPollDiscord( - to, - { question, options: answers, maxSelections, durationHours }, - withOpts({ content }), - ); - return jsonResult({ ok: true }); - } - case "permissions": { - if (!isActionEnabled("permissions")) { - throw new Error("Discord permissions are disabled."); - } - const channelId = resolveChannelId(); - const permissions = await discordMessagingActionRuntime.fetchChannelPermissionsDiscord( - channelId, - withOpts(), - ); - return jsonResult({ ok: true, permissions }); - } - case "fetchMessage": { - if (!isActionEnabled("messages")) { - throw new Error("Discord message reads are disabled."); - } - const messageLink = readStringParam(params, "messageLink"); - let guildId = readStringParam(params, "guildId"); - let channelId = readStringParam(params, "channelId"); - let messageId = readStringParam(params, "messageId"); - if (messageLink) { - const parsed = parseDiscordMessageLink(messageLink); - guildId = parsed.guildId; - channelId = parsed.channelId; - messageId = parsed.messageId; - } - if (!guildId || !channelId || !messageId) { - throw new Error( - "Discord message fetch requires guildId, channelId, and messageId (or a valid messageLink).", - ); - } - const message = await discordMessagingActionRuntime.fetchMessageDiscord( - channelId, - messageId, - withOpts(), - ); - return jsonResult({ - ok: true, - message: normalizeMessage(message), - guildId, - channelId, - messageId, - }); - } - case "readMessages": { - if (!isActionEnabled("messages")) { - throw new Error("Discord message reads are disabled."); - } - const channelId = resolveChannelId(); - const query = { - limit: readNumberParam(params, "limit"), - before: readStringParam(params, "before"), - after: readStringParam(params, "after"), - around: readStringParam(params, "around"), - }; - const messages = await discordMessagingActionRuntime.readMessagesDiscord( - channelId, - query, - withOpts(), - ); - return jsonResult({ - ok: true, - messages: messages.map((message) => normalizeMessage(message)), - }); - } - case "sendMessage": { - if (!isActionEnabled("messages")) { - throw new Error("Discord message sends are disabled."); - } - const to = readStringParam(params, "to", { required: true }); - const asVoice = params.asVoice === true; - const silent = params.silent === true; - const rawComponents = params.components; - const componentSpec = hasDiscordComponentObjectKeys(rawComponents) - ? discordMessagingActionRuntime.readDiscordComponentSpec(rawComponents) - : null; - const components: DiscordSendComponents | undefined = - Array.isArray(rawComponents) || typeof rawComponents === "function" - ? (rawComponents as DiscordSendComponents) - : undefined; - const content = readStringParam(params, "content", { - required: !asVoice && !componentSpec && !components, - allowEmpty: true, - }); - const mediaUrl = - readStringParam(params, "mediaUrl", { trim: false }) ?? - readStringParam(params, "path", { trim: false }) ?? - readStringParam(params, "filePath", { trim: false }); - const filename = readStringParam(params, "filename"); - const replyTo = readStringParam(params, "replyTo"); - const rawEmbeds = params.embeds; - const embeds: DiscordSendEmbeds | undefined = Array.isArray(rawEmbeds) - ? (rawEmbeds as DiscordSendEmbeds) - : undefined; - const sessionKey = readStringParam(params, "__sessionKey"); - const agentId = readStringParam(params, "__agentId"); - - if (componentSpec) { - if (asVoice) { - throw new Error("Discord components cannot be sent as voice messages."); - } - if (embeds?.length) { - throw new Error("Discord components cannot include embeds."); - } - const normalizedContent = content?.trim() ? content : undefined; - const payload = componentSpec.text - ? componentSpec - : { ...componentSpec, text: normalizedContent }; - const result = await discordMessagingActionRuntime.sendDiscordComponentMessage( - to, - payload, - { - ...withOpts(), - silent, - replyTo: replyTo ?? undefined, - sessionKey: sessionKey ?? undefined, - agentId: agentId ?? undefined, - mediaUrl: mediaUrl ?? undefined, - filename: filename ?? undefined, - }, - ); - return jsonResult({ ok: true, result, components: true }); - } - - // Handle voice message sending - if (asVoice) { - if (!mediaUrl) { - throw new Error( - "Voice messages require a media file reference (mediaUrl, path, or filePath).", - ); - } - if (content && content.trim()) { - throw new Error( - "Voice messages cannot include text content (Discord limitation). Remove the content parameter.", - ); - } - assertMediaNotDataUrl(mediaUrl); - const result = await discordMessagingActionRuntime.sendVoiceMessageDiscord(to, mediaUrl, { - ...withOpts(), - replyTo, - silent, - }); - return jsonResult({ ok: true, result, voiceMessage: true }); - } - - const result = await discordMessagingActionRuntime.sendMessageDiscord(to, content ?? "", { - ...withOpts(), - mediaUrl, - filename: filename ?? undefined, - mediaLocalRoots: options?.mediaLocalRoots, - mediaReadFile: options?.mediaReadFile, - replyTo, - components, - embeds, - silent, - }); - return jsonResult({ ok: true, result }); - } - case "editMessage": { - if (!isActionEnabled("messages")) { - throw new Error("Discord message edits are disabled."); - } - const channelId = resolveChannelId(); - const messageId = readStringParam(params, "messageId", { - required: true, - }); - const content = readStringParam(params, "content", { - required: true, - }); - const message = await discordMessagingActionRuntime.editMessageDiscord( - channelId, - messageId, - { content }, - withOpts(), - ); - return jsonResult({ ok: true, message }); - } - case "deleteMessage": { - if (!isActionEnabled("messages")) { - throw new Error("Discord message deletes are disabled."); - } - const channelId = resolveChannelId(); - const messageId = readStringParam(params, "messageId", { - required: true, - }); - await discordMessagingActionRuntime.deleteMessageDiscord(channelId, messageId, withOpts()); - return jsonResult({ ok: true }); - } - case "threadCreate": { - if (!isActionEnabled("threads")) { - throw new Error("Discord threads are disabled."); - } - const channelId = resolveChannelId(); - const name = readStringParam(params, "name", { required: true }); - const messageId = readStringParam(params, "messageId"); - const content = readStringParam(params, "content"); - const autoArchiveMinutes = readNumberParam(params, "autoArchiveMinutes"); - const appliedTags = readStringArrayParam(params, "appliedTags"); - const payload = { - name, - messageId, - autoArchiveMinutes, - content, - appliedTags: appliedTags ?? undefined, - }; - const thread = await discordMessagingActionRuntime.createThreadDiscord( - channelId, - payload, - withOpts(), - ); - return jsonResult({ ok: true, thread }); - } - case "threadList": { - if (!isActionEnabled("threads")) { - throw new Error("Discord threads are disabled."); - } - const guildId = readStringParam(params, "guildId", { - required: true, - }); - const channelId = readStringParam(params, "channelId"); - const includeArchived = readBooleanParam(params, "includeArchived"); - const before = readStringParam(params, "before"); - const limit = readNumberParam(params, "limit"); - const threads = await discordMessagingActionRuntime.listThreadsDiscord( - { - guildId, - channelId, - includeArchived, - before, - limit, - }, - withOpts(), - ); - return jsonResult({ ok: true, threads }); - } - case "threadReply": { - if (!isActionEnabled("threads")) { - throw new Error("Discord threads are disabled."); - } - const channelId = resolveChannelId(); - const content = readStringParam(params, "content", { - required: true, - }); - const mediaUrl = readStringParam(params, "mediaUrl"); - const replyTo = readStringParam(params, "replyTo"); - const result = await discordMessagingActionRuntime.sendMessageDiscord( - `channel:${channelId}`, - content, - { - ...withOpts(), - mediaUrl, - mediaLocalRoots: options?.mediaLocalRoots, - mediaReadFile: options?.mediaReadFile, - replyTo, - }, - ); - return jsonResult({ ok: true, result }); - } - case "pinMessage": { - if (!isActionEnabled("pins")) { - throw new Error("Discord pins are disabled."); - } - const channelId = resolveChannelId(); - const messageId = readStringParam(params, "messageId", { - required: true, - }); - await discordMessagingActionRuntime.pinMessageDiscord(channelId, messageId, withOpts()); - return jsonResult({ ok: true }); - } - case "unpinMessage": { - if (!isActionEnabled("pins")) { - throw new Error("Discord pins are disabled."); - } - const channelId = resolveChannelId(); - const messageId = readStringParam(params, "messageId", { - required: true, - }); - await discordMessagingActionRuntime.unpinMessageDiscord(channelId, messageId, withOpts()); - return jsonResult({ ok: true }); - } - case "listPins": { - if (!isActionEnabled("pins")) { - throw new Error("Discord pins are disabled."); - } - const channelId = resolveChannelId(); - const pins = await discordMessagingActionRuntime.listPinsDiscord(channelId, withOpts()); - return jsonResult({ ok: true, pins: pins.map((pin) => normalizeMessage(pin)) }); - } - case "searchMessages": { - if (!isActionEnabled("search")) { - throw new Error("Discord search is disabled."); - } - const guildId = readStringParam(params, "guildId", { - required: true, - }); - const content = readStringParam(params, "content", { - required: true, - }); - const channelId = readStringParam(params, "channelId"); - const channelIds = readStringArrayParam(params, "channelIds"); - const authorId = readStringParam(params, "authorId"); - const authorIds = readStringArrayParam(params, "authorIds"); - const limit = readNumberParam(params, "limit"); - const channelIdList = [...(channelIds ?? []), ...(channelId ? [channelId] : [])]; - const authorIdList = [...(authorIds ?? []), ...(authorId ? [authorId] : [])]; - const results = await discordMessagingActionRuntime.searchMessagesDiscord( - { - guildId, - content, - channelIds: channelIdList.length ? channelIdList : undefined, - authorIds: authorIdList.length ? authorIdList : undefined, - limit, - }, - withOpts(), - ); - if (!results || typeof results !== "object") { - return jsonResult({ ok: true, results }); - } - const resultsRecord = results as Record; - const messages = resultsRecord.messages; - const normalizedMessages = Array.isArray(messages) - ? messages.map((group) => - Array.isArray(group) ? group.map((msg) => normalizeMessage(msg)) : group, - ) - : messages; - return jsonResult({ - ok: true, - results: { - ...resultsRecord, - messages: normalizedMessages, - }, - }); - } - default: + return ( + (await handleDiscordReactionMessagingAction(ctx)) ?? + (await handleDiscordMessageSendAction(ctx)) ?? + (await handleDiscordMessageManagementAction(ctx)) ?? + (() => { throw new Error(`Unknown action: ${action}`); - } + })() + ); } diff --git a/extensions/discord/src/internal/rest.test.ts b/extensions/discord/src/internal/rest.test.ts index f7373ad73fd..f2cd5503e5d 100644 --- a/extensions/discord/src/internal/rest.test.ts +++ b/extensions/discord/src/internal/rest.test.ts @@ -1,14 +1,7 @@ import { afterEach, describe, expect, it, vi } from "vitest"; import { serializeRequestBody } from "./rest-body.js"; import { RequestClient } from "./rest.js"; - -function createDeferred() { - let resolve: ((value: T) => void) | undefined; - const promise = new Promise((res) => { - resolve = res; - }); - return { promise, resolve: resolve! }; -} +import { createDeferred, createJsonResponse } from "./test-builders.test-support.js"; describe("RequestClient", () => { afterEach(() => { @@ -19,7 +12,7 @@ describe("RequestClient", () => { const firstResponse = createDeferred(); const queuedResponses = [ firstResponse.promise, - Promise.resolve(new Response(JSON.stringify({ ok: true }), { status: 200 })), + Promise.resolve(createJsonResponse({ ok: true })), ]; const fetchSpy = vi.fn(async () => { const response = queuedResponses.shift(); @@ -39,7 +32,7 @@ describe("RequestClient", () => { expect(client.queueSize).toBe(2); await expect(client.get("/users/@me")).rejects.toThrow(/queue is full/); - firstResponse.resolve(new Response(JSON.stringify({ id: "u1" }), { status: 200 })); + firstResponse.resolve(createJsonResponse({ id: "u1" })); await expect(first).resolves.toEqual({ id: "u1" }); await expect(second).resolves.toEqual({ ok: true }); @@ -65,16 +58,20 @@ describe("RequestClient", () => { await vi.waitFor(() => expect(fetchSpy).toHaveBeenCalledTimes(2)); channelResponse.resolve( - new Response(JSON.stringify({ id: "channel" }), { - status: 200, - headers: { "X-RateLimit-Bucket": "channel-messages", "X-RateLimit-Remaining": "1" }, - }), + createJsonResponse( + { id: "channel" }, + { + headers: { "X-RateLimit-Bucket": "channel-messages", "X-RateLimit-Remaining": "1" }, + }, + ), ); guildResponse.resolve( - new Response(JSON.stringify({ id: "guild" }), { - status: 200, - headers: { "X-RateLimit-Bucket": "guild-roles", "X-RateLimit-Remaining": "1" }, - }), + createJsonResponse( + { id: "guild" }, + { + headers: { "X-RateLimit-Bucket": "guild-roles", "X-RateLimit-Remaining": "1" }, + }, + ), ); await expect(Promise.all([channel, guild])).resolves.toEqual([ @@ -86,10 +83,12 @@ describe("RequestClient", () => { it("prunes idle route buckets and mappings after Discord bucket remapping", async () => { const client = new RequestClient("test-token", { fetch: async () => - new Response(JSON.stringify({ id: "first" }), { - status: 200, - headers: { "X-RateLimit-Bucket": "channel-messages" }, - }), + createJsonResponse( + { id: "first" }, + { + headers: { "X-RateLimit-Bucket": "channel-messages" }, + }, + ), }); await expect(client.get("/channels/c1/messages")).resolves.toEqual({ id: "first" }); @@ -105,25 +104,29 @@ describe("RequestClient", () => { vi.setSystemTime(0); const responses = [ Promise.resolve( - new Response(JSON.stringify({ id: "first" }), { - status: 200, - headers: { - "X-RateLimit-Bucket": "channel-messages", - "X-RateLimit-Limit": "1", - "X-RateLimit-Remaining": "0", - "X-RateLimit-Reset-After": "0.1", + createJsonResponse( + { id: "first" }, + { + headers: { + "X-RateLimit-Bucket": "channel-messages", + "X-RateLimit-Limit": "1", + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset-After": "0.1", + }, }, - }), + ), ), Promise.resolve( - new Response(JSON.stringify({ id: "second" }), { - status: 200, - headers: { - "X-RateLimit-Bucket": "channel-messages", - "X-RateLimit-Limit": "1", - "X-RateLimit-Remaining": "1", + createJsonResponse( + { id: "second" }, + { + headers: { + "X-RateLimit-Bucket": "channel-messages", + "X-RateLimit-Limit": "1", + "X-RateLimit-Remaining": "1", + }, }, - }), + ), ), ]; const fetchSpy = vi.fn(async () => { @@ -176,10 +179,13 @@ describe("RequestClient", () => { const client = new RequestClient("test-token", { queueRequests: false, fetch: async () => - new Response(JSON.stringify({ message: "Forbidden", code: 50013 }), { - status: 403, - headers: { "X-RateLimit-Bucket": "permissions" }, - }), + createJsonResponse( + { message: "Forbidden", code: 50013 }, + { + status: 403, + headers: { "X-RateLimit-Bucket": "permissions" }, + }, + ), }); await expect(client.get("/channels/c1/messages")).rejects.toMatchObject({ status: 403 }); diff --git a/extensions/discord/src/internal/test-builders.test-support.ts b/extensions/discord/src/internal/test-builders.test-support.ts index 38f3bca8314..cac6e66d66f 100644 --- a/extensions/discord/src/internal/test-builders.test-support.ts +++ b/extensions/discord/src/internal/test-builders.test-support.ts @@ -24,6 +24,40 @@ export type FakeRestClient = RequestClient & { enqueueResponse: (value: unknown) => void; }; +export function createDeferred() { + let resolve: ((value: T) => void) | undefined; + const promise = new Promise((res) => { + resolve = res; + }); + return { promise, resolve: resolve! }; +} + +export function createJsonResponse(body: unknown, init?: ResponseInit): Response { + return new Response(JSON.stringify(body), { + status: 200, + ...init, + }); +} + +export function createAbortableFetchMock() { + let receivedSignal: AbortSignal | undefined; + const fetch = vi.fn( + (_input: string | URL | Request, init?: RequestInit) => + new Promise((_resolve, reject) => { + receivedSignal = init?.signal ?? undefined; + init?.signal?.addEventListener("abort", () => { + reject(new DOMException("The operation was aborted.", "AbortError")); + }); + }), + ); + return { + fetch, + get receivedSignal() { + return receivedSignal; + }, + }; +} + export function createInternalTestClient(commands: BaseCommand[] = []): Client { return new Client( { diff --git a/extensions/discord/src/proxy-request-client.test.ts b/extensions/discord/src/proxy-request-client.test.ts index c6da2d665bf..c96ab991249 100644 --- a/extensions/discord/src/proxy-request-client.test.ts +++ b/extensions/discord/src/proxy-request-client.test.ts @@ -1,4 +1,8 @@ import { describe, expect, it, vi } from "vitest"; +import { + createAbortableFetchMock, + createJsonResponse, +} from "./internal/test-builders.test-support.js"; import { createDiscordRequestClient, DISCORD_REST_TIMEOUT_MS } from "./proxy-request-client.js"; describe("createDiscordRequestClient", () => { @@ -6,7 +10,7 @@ describe("createDiscordRequestClient", () => { const fetchSpy = vi.fn(async (_input: string | URL | Request, init?: RequestInit) => { expect(init?.signal).toBeDefined(); expect(init!.signal!.aborted).toBe(false); - return new Response(JSON.stringify([]), { status: 200 }); + return createJsonResponse([]); }); const client = createDiscordRequestClient("Bot test-token", { @@ -19,14 +23,7 @@ describe("createDiscordRequestClient", () => { }); it("lets the REST client abort hanging proxied requests after its timeout", async () => { - const fetchSpy = vi.fn( - (_input: string | URL | Request, init?: RequestInit) => - new Promise((_resolve, reject) => { - init?.signal?.addEventListener("abort", () => { - reject(new DOMException("The operation was aborted.", "AbortError")); - }); - }), - ); + const { fetch: fetchSpy } = createAbortableFetchMock(); const client = createDiscordRequestClient("Bot test-token", { fetch: fetchSpy as never, @@ -38,30 +35,21 @@ describe("createDiscordRequestClient", () => { }, 1_000); it("lets abortAllRequests cancel active proxied fetches", async () => { - let receivedSignal: AbortSignal | undefined; - const fetchSpy = vi.fn( - (_input: string | URL | Request, init?: RequestInit) => - new Promise((_resolve, reject) => { - receivedSignal = init?.signal ?? undefined; - init?.signal?.addEventListener("abort", () => { - reject(new DOMException("The operation was aborted.", "AbortError")); - }); - }), - ); + const abortable = createAbortableFetchMock(); const client = createDiscordRequestClient("Bot test-token", { - fetch: fetchSpy as never, + fetch: abortable.fetch as never, queueRequests: false, timeout: 5_000, }); const request = client.get("/channels/123/messages"); - await vi.waitFor(() => expect(fetchSpy).toHaveBeenCalledTimes(1)); + await vi.waitFor(() => expect(abortable.fetch).toHaveBeenCalledTimes(1)); client.abortAllRequests(); await expect(request).rejects.toThrow(); - expect(receivedSignal?.aborted).toBe(true); + expect(abortable.receivedSignal?.aborted).toBe(true); }); it("provides the REST client's timeout signal even without a caller signal", async () => { @@ -69,7 +57,7 @@ describe("createDiscordRequestClient", () => { const fetchSpy = vi.fn(async (_input: string | URL | Request, init?: RequestInit) => { receivedSignal = init?.signal ?? undefined; - return new Response(JSON.stringify({}), { status: 200 }); + return createJsonResponse({}); }); const client = createDiscordRequestClient("Bot test-token", { diff --git a/extensions/discord/src/voice/manager.ts b/extensions/discord/src/voice/manager.ts index c44cf0f1654..0ada40e2e58 100644 --- a/extensions/discord/src/voice/manager.ts +++ b/extensions/discord/src/voice/manager.ts @@ -98,7 +98,10 @@ export class DiscordVoiceManager { this.botUserId = params.botUserId; this.voiceEnabled = params.discordConfig.voice?.enabled !== false; this.ownerAllowFrom = - resolveDiscordAccountAllowFrom({ cfg: params.cfg, accountId: params.accountId }) ?? []; + resolveDiscordAccountAllowFrom({ cfg: params.cfg, accountId: params.accountId }) ?? + params.discordConfig.allowFrom ?? + params.discordConfig.dm?.allowFrom ?? + []; this.speakerContext = new DiscordVoiceSpeakerContextResolver({ client: params.client, ownerAllowFrom: this.ownerAllowFrom,