fix(discord): satisfy internal boundary checks

This commit is contained in:
Peter Steinberger
2026-04-29 14:09:44 +01:00
parent 47b3530af3
commit da6135d34c
15 changed files with 113 additions and 60 deletions

View File

@@ -1,8 +1,7 @@
import { GatewayDispatchEvents } from "discord-api-types/v10";
import { getChannel, getGuild, getGuildMember, getUser } from "./api.js";
import type { Client } from "./client.js";
import type { RequestClient } from "./rest.js";
import { Guild, GuildMember, User, channelFactory } from "./structures.js";
import { Guild, GuildMember, User, channelFactory, type StructureClient } from "./structures.js";
type CacheEntry<T> = {
expiresAt: number;
@@ -16,7 +15,7 @@ export class DiscordEntityCache {
constructor(
private readonly params: {
client: Client;
client: StructureClient;
rest: RequestClient | (() => RequestClient);
ttlMs?: number;
},

View File

@@ -1,6 +1,6 @@
import { InteractionType, type APIInteraction } from "discord-api-types/v10";
import type { Client } from "./client.js";
import {
type BaseCommand,
deferCommandInteractionIfNeeded,
resolveFocusedCommandOptionAutocompleteHandler,
} from "./commands.js";
@@ -14,7 +14,30 @@ import {
type RawInteraction,
} from "./interactions.js";
export async function dispatchInteraction(client: Client, rawData: APIInteraction): Promise<void> {
type DispatchComponent = {
defer: boolean | ((interaction: BaseComponentInteraction) => boolean);
ephemeral: boolean | ((interaction: BaseComponentInteraction) => boolean);
run(interaction: BaseComponentInteraction, data: Record<string, unknown>): unknown;
customIdParser(id: string): { data: Record<string, unknown> };
};
type DispatchModal = {
run(interaction: ModalInteraction, data: Record<string, unknown>): unknown;
customIdParser(id: string): { data: Record<string, unknown> };
};
type DispatchClient = Parameters<typeof createInteraction>[0] & {
commands: BaseCommand[];
componentHandler: {
resolve(customId: string, options?: { componentType?: number }): DispatchComponent | undefined;
};
modalHandler: { resolve(customId: string): DispatchModal | undefined };
};
export async function dispatchInteraction(
client: DispatchClient,
rawData: APIInteraction,
): Promise<void> {
const interaction = createInteraction(client, rawData as RawInteraction);
if (rawData.type === InteractionType.ApplicationCommandAutocomplete) {
const command = client.commands.find((entry) => entry.name === readInteractionName(rawData));

View File

@@ -4,8 +4,11 @@ import {
type APIChannel,
type APIInteractionDataResolvedChannel,
} from "discord-api-types/v10";
import type { Client } from "./client.js";
import { channelFactory } from "./structures.js";
import { channelFactory, type DiscordChannel, type StructureClient } from "./structures.js";
type OptionsClient = StructureClient & {
fetchChannel(id: string): Promise<DiscordChannel>;
};
function readFocusedOption(
options: APIApplicationCommandInteractionDataOption[] | undefined,
@@ -50,7 +53,7 @@ function readChildOptions(
export class OptionsHandler {
constructor(
private rawOptions: APIApplicationCommandInteractionDataOption[] | undefined,
private client: Client,
private client: OptionsClient,
private resolvedChannels: Record<string, APIInteractionDataResolvedChannel> | undefined,
) {}

View File

@@ -18,8 +18,6 @@ import {
editWebhookMessage,
getWebhookMessage,
} from "./api.js";
import type { Client } from "./client.js";
import { type ComponentData, type Modal } from "./components.js";
import { OptionsHandler } from "./interaction-options.js";
import {
InteractionResponseController,
@@ -29,11 +27,29 @@ import {
import { extractModalFields, ModalFields } from "./modal-fields.js";
import { serializePayload, type MessagePayload } from "./payload.js";
import { assertDiscordInteractionPayload } from "./schemas.js";
import { channelFactory, Guild, Message, User, type DiscordChannel } from "./structures.js";
import {
channelFactory,
Guild,
Message,
User,
type DiscordChannel,
type StructureClient,
} from "./structures.js";
export { OptionsHandler } from "./interaction-options.js";
export { ModalFields } from "./modal-fields.js";
type InteractionClient = StructureClient & {
options: { clientId: string };
fetchChannel(id: string): Promise<DiscordChannel>;
};
type Modal = {
serialize: () => unknown;
};
type ComponentData = Record<string, unknown>;
export type RawInteraction = APIInteraction & {
token: string;
member?: { user?: APIUser; roles?: string[] };
@@ -71,7 +87,7 @@ function toModalSubmitRawInteraction(rawData: RawInteraction): ModalSubmitRawInt
return rawData as ModalSubmitRawInteraction;
}
function readInteractionUser(rawData: RawInteraction, client: Client): User | null {
function readInteractionUser(rawData: RawInteraction, client: InteractionClient): User | null {
const directUser = "user" in rawData ? rawData.user : undefined;
if (directUser && typeof directUser === "object" && "id" in directUser) {
return new User(client, directUser);
@@ -98,7 +114,7 @@ export class BaseInteraction {
private readonly response = new InteractionResponseController();
constructor(
public client: Client,
public client: InteractionClient,
public rawData: RawInteraction,
) {
this.id = rawData.id;
@@ -214,7 +230,10 @@ export class BaseInteraction {
export class CommandInteraction extends BaseInteraction {
readonly options: OptionsHandler;
constructor(client: Client, rawData: APIApplicationCommandInteraction & RawInteraction) {
constructor(
client: InteractionClient,
rawData: APIApplicationCommandInteraction & RawInteraction,
) {
super(client, rawData);
this.options = new OptionsHandler(
rawData.data.options,
@@ -235,7 +254,7 @@ export class AutocompleteInteraction extends CommandInteraction {
export class BaseComponentInteraction extends BaseInteraction {
readonly values: string[];
constructor(client: Client, rawData: APIMessageComponentInteraction & RawInteraction) {
constructor(client: InteractionClient, rawData: APIMessageComponentInteraction & RawInteraction) {
super(client, rawData);
this.message =
rawData.message && typeof rawData.message === "object"
@@ -264,7 +283,7 @@ export class ChannelSelectMenuInteraction extends BaseComponentInteraction {}
export class ModalInteraction extends BaseInteraction {
readonly fields: ModalFields;
constructor(client: Client, rawData: APIModalSubmitInteraction & RawInteraction) {
constructor(client: InteractionClient, rawData: APIModalSubmitInteraction & RawInteraction) {
super(client, rawData);
this.fields = new ModalFields(
extractModalFields(rawData.data.components ?? []),
@@ -277,7 +296,7 @@ export class ModalInteraction extends BaseInteraction {
}
}
export function createInteraction(client: Client, rawData: RawInteraction) {
export function createInteraction(client: InteractionClient, rawData: RawInteraction) {
assertDiscordInteractionPayload(rawData);
if (rawData.type === InteractionType.ApplicationCommandAutocomplete) {
return new AutocompleteInteraction(client, toCommandRawInteraction(rawData));

View File

@@ -1,6 +1,6 @@
import { Routes } from "discord-api-types/v10";
import { isLiveTestEnabled } from "openclaw/plugin-sdk/test-env";
import { describe, expect, it } from "vitest";
import { isLiveTestEnabled } from "../../../../src/agents/live-test-helpers.js";
import { parseApplicationIdFromToken } from "../probe.js";
import { RequestClient } from "./rest.js";

View File

@@ -1,6 +1,5 @@
import { type APIRole, type APIUser } from "discord-api-types/v10";
import type { Client } from "./client.js";
import { Role, User } from "./structures.js";
import { Role, User, type StructureClient } from "./structures.js";
type ModalResolvedData = {
roles?: Record<string, { id: string; name?: string }>;
@@ -49,7 +48,7 @@ export class ModalFields {
constructor(
private values: Record<string, string | string[]>,
private resolved?: ModalResolvedData,
private client?: Client,
private client?: StructureClient,
) {}
private value(id: string, required: boolean): string | string[] | undefined {

View File

@@ -1,14 +1,4 @@
import { MessageFlags, type APIEmbed } from "discord-api-types/v10";
import type {
BaseMessageInteractiveComponent,
Container,
File,
MediaGallery,
Row,
Section,
Separator,
TextDisplay,
} from "./components.js";
import { Embed } from "./embeds.js";
export type MessagePayloadFile = {
@@ -32,14 +22,10 @@ export type MessagePayloadObject = {
stickers?: [string, string, string] | [string, string] | [string];
};
export type MessagePayload = string | MessagePayloadObject;
export type TopLevelComponents =
| Row<BaseMessageInteractiveComponent>
| Container
| File
| MediaGallery
| Section
| Separator
| TextDisplay;
export type TopLevelComponents = {
isV2?: boolean;
serialize: () => unknown;
};
function clean<T extends Record<string, unknown>>(value: T): T {
return Object.fromEntries(Object.entries(value).filter(([, entry]) => entry !== undefined)) as T;

View File

@@ -1,4 +1,9 @@
import type { RequestData } from "./rest.js";
type RequestData = {
body?: unknown;
multipartStyle?: "message" | "form";
rawBody?: boolean;
headers?: Record<string, string>;
};
export function serializeRequestBody(
data: RequestData | undefined,

View File

@@ -1,4 +1,4 @@
import type { QueuedRequest } from "./rest.js";
type QueryValue = string | number | boolean;
export function createRouteKey(method: string, path: string): string {
return `${method.toUpperCase()} ${path.split("?")[0] ?? path}`;
@@ -38,7 +38,7 @@ export function readResetAt(response: Response): number | undefined {
return reset !== undefined ? reset * 1000 : undefined;
}
export function appendQuery(path: string, query?: QueuedRequest["query"]): string {
export function appendQuery(path: string, query?: Record<string, QueryValue>): string {
if (!query || Object.keys(query).length === 0) {
return path;
}

View File

@@ -17,20 +17,24 @@ import {
pinChannelMessage,
unpinChannelMessage,
} from "./api.js";
import type { Client } from "./client.js";
import { serializePayload, type MessagePayload } from "./payload.js";
import type { RequestClient } from "./rest.js";
type RawOrId<T> = T | string | { id: string; channelId?: string };
export type StructureClient = {
rest: RequestClient;
fetchUser(id: string): Promise<User>;
};
export class Base {
constructor(protected client: Client) {}
constructor(protected client: StructureClient) {}
}
export class User<IsPartial extends boolean = false> extends Base {
protected _rawData: APIUser | null;
readonly id: string;
constructor(client: Client, rawDataOrId: IsPartial extends true ? string : APIUser) {
constructor(client: StructureClient, rawDataOrId: IsPartial extends true ? string : APIUser) {
super(client);
this._rawData = typeof rawDataOrId === "string" ? null : rawDataOrId;
this.id = typeof rawDataOrId === "string" ? rawDataOrId : rawDataOrId.id;
@@ -84,7 +88,7 @@ export class User<IsPartial extends boolean = false> extends Base {
export class Role<IsPartial extends boolean = false> extends Base {
protected _rawData: APIRole | null;
readonly id: string;
constructor(client: Client, rawDataOrId: IsPartial extends true ? string : APIRole) {
constructor(client: StructureClient, rawDataOrId: IsPartial extends true ? string : APIRole) {
super(client);
this._rawData = typeof rawDataOrId === "string" ? null : rawDataOrId;
this.id = typeof rawDataOrId === "string" ? rawDataOrId : rawDataOrId.id;
@@ -97,7 +101,7 @@ export class Role<IsPartial extends boolean = false> extends Base {
export class Guild<IsPartial extends boolean = false> extends Base {
protected _rawData: APIGuild | null;
readonly id: string;
constructor(client: Client, rawDataOrId: IsPartial extends true ? string : APIGuild) {
constructor(client: StructureClient, rawDataOrId: IsPartial extends true ? string : APIGuild) {
super(client);
this._rawData = typeof rawDataOrId === "string" ? null : rawDataOrId;
this.id = typeof rawDataOrId === "string" ? rawDataOrId : rawDataOrId.id;
@@ -109,7 +113,7 @@ export class Guild<IsPartial extends boolean = false> extends Base {
export class GuildMember extends Base {
constructor(
client: Client,
client: StructureClient,
public rawData: APIGuildMember,
) {
super(client);
@@ -130,7 +134,7 @@ export class Message<IsPartial extends boolean = false> extends Base {
readonly id: string;
readonly channelId: string;
constructor(client: Client, rawDataOrIds: RawOrId<APIMessage>) {
constructor(client: StructureClient, rawDataOrIds: RawOrId<APIMessage>) {
super(client);
this._rawData =
typeof rawDataOrIds === "string" || !("author" in rawDataOrIds) ? null : rawDataOrIds;
@@ -257,7 +261,7 @@ export type DiscordChannel = APIChannel & {
};
export function channelFactory(
_client: Client,
_client: StructureClient,
channelData: APIChannel,
_partial?: boolean,
): DiscordChannel {

View File

@@ -1,6 +1,6 @@
import { logVerbose } from "openclaw/plugin-sdk/runtime-env";
import { normalizeOptionalStringifiedId } from "openclaw/plugin-sdk/text-runtime";
import type { ChannelType, Client, Message } from "../internal/discord.js";
import type { ChannelType, Message } from "../internal/discord.js";
import { resolveDiscordChannelInfoSafe } from "./channel-access.js";
export type DiscordChannelInfo = {
@@ -10,6 +10,9 @@ export type DiscordChannelInfo = {
parentId?: string;
ownerId?: string;
};
export type DiscordChannelInfoClient = {
fetchChannel(channelId: string): Promise<unknown>;
};
type DiscordMessageWithChannelId = Message & {
channel_id?: unknown;
@@ -45,7 +48,7 @@ export function resolveDiscordMessageChannelId(params: {
}
export async function resolveDiscordChannelInfo(
client: Client,
client: DiscordChannelInfoClient,
channelId: string,
): Promise<DiscordChannelInfo | null> {
const cached = DISCORD_CHANNEL_INFO_CACHE.get(channelId);
@@ -65,8 +68,13 @@ export async function resolveDiscordChannelInfo(
return null;
}
const channelInfo = resolveDiscordChannelInfoSafe(channel);
const rawChannel = channel as { type?: ChannelType };
const type = (channelInfo.type as ChannelType | undefined) ?? rawChannel.type;
if (type === undefined) {
return null;
}
const payload: DiscordChannelInfo = {
type: (channelInfo.type as ChannelType | undefined) ?? channel.type,
type,
name: channelInfo.name,
topic: channelInfo.topic,
parentId: channelInfo.parentId,

View File

@@ -13,6 +13,7 @@ import type { Message } from "../internal/discord.js";
export {
__resetDiscordChannelInfoCacheForTest,
resolveDiscordChannelInfo,
type DiscordChannelInfoClient,
resolveDiscordMessageChannelId,
type DiscordChannelInfo,
} from "./message-channel-info.js";

View File

@@ -1,4 +1,5 @@
import { ChannelType, type Client } from "../internal/discord.js";
import { ChannelType } from "../internal/discord.js";
import type { DiscordChannelInfoClient } from "./message-utils.js";
import { resolveDiscordThreadLikeChannelContext } from "./thread-channel-context.js";
type DiscordInteractionChannel = {
@@ -21,7 +22,7 @@ export type DiscordNativeInteractionChannelContext = {
export async function resolveDiscordNativeInteractionChannelContext(params: {
channel: DiscordInteractionChannel | null | undefined;
client: Client;
client: DiscordChannelInfoClient;
hasGuild: boolean;
channelIdFallback: string;
}): Promise<DiscordNativeInteractionChannelContext> {

View File

@@ -1,11 +1,15 @@
import { ChannelType, type Client } from "../internal/discord.js";
import { ChannelType } from "../internal/discord.js";
import { normalizeDiscordSlug } from "./allow-list.js";
import {
resolveDiscordChannelIdSafe,
resolveDiscordChannelInfoSafe,
resolveDiscordChannelParentIdSafe,
} from "./channel-access.js";
import { resolveDiscordChannelInfo, type DiscordChannelInfo } from "./message-utils.js";
import {
resolveDiscordChannelInfo,
type DiscordChannelInfo,
type DiscordChannelInfoClient,
} from "./message-utils.js";
import { resolveDiscordThreadParentInfo } from "./threading.js";
export type DiscordThreadLikeChannelContext = {
@@ -44,7 +48,7 @@ function buildFetchedChannelInfo(channel: unknown): DiscordChannelInfo | null {
}
export async function resolveDiscordThreadLikeChannelContext(params: {
client: Client;
client: DiscordChannelInfoClient;
channel: unknown;
channelIdFallback?: string;
channelInfo?: DiscordChannelInfo | null;
@@ -97,7 +101,7 @@ export async function resolveDiscordThreadLikeChannelContext(params: {
}
export async function resolveFetchedDiscordThreadLikeChannelContext(params: {
client: Client;
client: DiscordChannelInfoClient;
channel: unknown;
channelIdFallback?: string;
}): Promise<DiscordThreadLikeChannelContext> {

View File

@@ -26,6 +26,7 @@ import {
} from "./channel-access.js";
import {
resolveDiscordChannelInfo,
type DiscordChannelInfoClient,
resolveDiscordEmbedText,
resolveDiscordForwardedMessagesTextFromSnapshots,
resolveDiscordMessageChannelId,
@@ -203,7 +204,7 @@ export function resolveDiscordThreadChannel(params: {
}
export async function resolveDiscordThreadParentInfo(params: {
client: Client;
client: DiscordChannelInfoClient;
threadChannel: DiscordThreadChannel;
channelInfo: import("./message-utils.js").DiscordChannelInfo | null;
}): Promise<DiscordThreadParentInfo> {