diff --git a/extensions/discord/src/internal/client.test.ts b/extensions/discord/src/internal/client.test.ts index 38627082615..b6cc0629aec 100644 --- a/extensions/discord/src/internal/client.test.ts +++ b/extensions/discord/src/internal/client.test.ts @@ -150,6 +150,60 @@ describe("Client.deployCommands", () => { expect(deleteRequest).not.toHaveBeenCalled(); }); + it("does not patch live-only command metadata or reordered unordered arrays", async () => { + const client = createInternalTestClient([ + createTestCommand({ + name: "one", + options: [ + { + type: 3, + name: "value", + description: "Value", + required: false, + autocomplete: false, + channel_types: [1, 0], + }, + ], + }), + ]); + const get = vi.fn(async () => [ + { + id: "cmd1", + application_id: "app1", + type: ApplicationCommandType.ChatInput, + name: "one", + name_localized: "one", + description: "one command", + description_localized: "one command", + options: [ + { + type: 3, + name: "value", + description: "Value", + description_localized: "Value", + channel_types: [0, 1], + }, + ], + default_member_permissions: null, + dm_permission: true, + integration_types: [1, 0], + contexts: [2, 1, 0], + guild_id: undefined, + version: "1", + }, + ]); + const patch = vi.fn(async () => undefined); + const post = vi.fn(async () => undefined); + const deleteRequest = vi.fn(async () => undefined); + attachRestMock(client, { get, patch, post, delete: deleteRequest }); + + await client.deployCommands({ mode: "reconcile" }); + + expect(patch).not.toHaveBeenCalled(); + expect(post).not.toHaveBeenCalled(); + expect(deleteRequest).not.toHaveBeenCalled(); + }); + it("skips command deploy when the serialized command set is unchanged", async () => { const client = createInternalTestClient([createTestCommand({ name: "one" })]); const get = vi.fn(async () => []); diff --git a/extensions/discord/src/internal/client.ts b/extensions/discord/src/internal/client.ts index 531b3f6f0c0..7db1f0a10b0 100644 --- a/extensions/discord/src/internal/client.ts +++ b/extensions/discord/src/internal/client.ts @@ -6,7 +6,7 @@ import { DiscordEntityCache } from "./entity-cache.js"; import { DiscordEventQueue, type DiscordEventQueueOptions } from "./event-queue.js"; import { dispatchInteraction } from "./interaction-dispatch.js"; import { RequestClient, type RequestClientOptions } from "./rest.js"; -import type { Guild, GuildMember, User } from "./structures.js"; +import type { Guild, GuildMember, Message, User } from "./structures.js"; export interface Route { method: "GET" | "POST" | "PUT" | "PATCH" | "DELETE"; @@ -49,10 +49,18 @@ export interface ClientOptions { restCacheTtlMs?: number; } +type OneOffComponentResult = + | { success: true; customId: string; message: Message; values?: string[] } + | { success: false; message: Message; reason: "timed out" }; + export class ComponentRegistry< T extends { customId: string; customIdParser?: typeof parseCustomId; type?: number }, > { private entries = new Map(); + private oneOffComponents = new Map< + string, + { message: Message; resolve(result: OneOffComponentResult): void; timer: NodeJS.Timeout } + >(); private wildcardEntries: T[] = []; register(entry: T): void { @@ -90,12 +98,66 @@ export class ComponentRegistry< return true; }); } + + waitForMessageComponent(message: Message, timeoutMs: number): Promise { + const key = createOneOffComponentKey(message.id, message.channelId); + return new Promise((resolve) => { + const existing = this.oneOffComponents.get(key); + if (existing) { + clearTimeout(existing.timer); + existing.resolve({ success: false, message, reason: "timed out" }); + } + const timer = setTimeout( + () => { + this.oneOffComponents.delete(key); + resolve({ success: false, message, reason: "timed out" }); + }, + Math.max(0, timeoutMs), + ); + timer.unref?.(); + this.oneOffComponents.set(key, { + message, + timer, + resolve, + }); + }); + } + + resolveOneOffComponent(params: { + channelId?: string; + customId: string; + messageId?: string; + values?: string[]; + }): boolean { + if (!params.messageId || !params.channelId) { + return false; + } + const entry = this.oneOffComponents.get( + createOneOffComponentKey(params.messageId, params.channelId), + ); + if (!entry) { + return false; + } + clearTimeout(entry.timer); + this.oneOffComponents.delete(createOneOffComponentKey(params.messageId, params.channelId)); + entry.resolve({ + success: true, + customId: params.customId, + message: entry.message, + values: params.values, + }); + return true; + } } function parseRegistryKey(customId: string, parser: typeof parseCustomId = parseCustomId): string { return parser(customId).key; } +function createOneOffComponentKey(messageId: string, channelId: string): string { + return `${messageId}:${channelId}`; +} + export class Client { routes: Route[] = []; plugins: Array<{ id: string; plugin: Plugin }> = []; diff --git a/extensions/discord/src/internal/command-deploy.ts b/extensions/discord/src/internal/command-deploy.ts index 71dc532dec9..11a815637e4 100644 --- a/extensions/discord/src/internal/command-deploy.ts +++ b/extensions/discord/src/internal/command-deploy.ts @@ -157,12 +157,15 @@ function comparableCommand(value: unknown): unknown { return value; } const omit = new Set([ - "id", "application_id", + "description_localized", + "dm_permission", "guild_id", + "id", + "name_localized", + "nsfw", "version", "default_permission", - "nsfw", ]); return stableComparableObject( Object.fromEntries( @@ -171,18 +174,52 @@ function comparableCommand(value: unknown): unknown { ); } -function stableComparableObject(value: unknown): unknown { +const unorderedCommandArrayFields = new Set(["channel_types", "contexts", "integration_types"]); +const subcommandOptionOnlyFields = new Set([ + "contexts", + "default_member_permissions", + "description_localized", + "description_localizations", + "integration_types", + "name_localized", + "name_localizations", +]); + +function stableComparableObject(value: unknown, path: string[] = []): unknown { if (Array.isArray(value)) { - return value.map((entry) => stableComparableObject(entry)); + const normalized = value.map((entry) => stableComparableObject(entry, path)); + const key = path.at(-1); + if ( + key && + unorderedCommandArrayFields.has(key) && + normalized.every( + (entry) => + typeof entry === "string" || typeof entry === "number" || typeof entry === "boolean", + ) + ) { + return normalized.toSorted((left, right) => String(left).localeCompare(String(right))); + } + return normalized; } if (!value || typeof value !== "object") { return value; } return Object.fromEntries( Object.entries(value as Record) - .filter(([, entry]) => entry !== undefined) + .filter(([key, entry]) => { + if (entry === undefined) { + return false; + } + if (path.includes("options") && subcommandOptionOnlyFields.has(key)) { + return false; + } + if ((key === "required" || key === "autocomplete") && entry === false) { + return false; + } + return true; + }) .toSorted(([a], [b]) => a.localeCompare(b)) - .map(([key, entry]) => [key, stableComparableObject(entry)]), + .map(([key, entry]) => [key, stableComparableObject(entry, [...path, key])]), ); } diff --git a/extensions/discord/src/internal/gateway.test.ts b/extensions/discord/src/internal/gateway.test.ts index 8001a8c34a9..ba37791f6d8 100644 --- a/extensions/discord/src/internal/gateway.test.ts +++ b/extensions/discord/src/internal/gateway.test.ts @@ -2,6 +2,7 @@ import { EventEmitter } from "node:events"; import { GatewayCloseCodes, GatewayDispatchEvents, + GatewayIntentBits, GatewayOpcodes, InteractionType, PresenceUpdateStatus, @@ -270,6 +271,29 @@ describe("GatewayPlugin", () => { ); }); + it("rejects gateway payloads that exceed Discord's size limit", () => { + const gateway = new GatewayPlugin({ autoInteractions: false }); + const send = attachOpenSocket(gateway); + + expect(() => + gateway.send({ + op: GatewayOpcodes.PresenceUpdate, + d: { + since: null, + activities: [ + { + name: "x".repeat(4_100), + type: 0, + }, + ], + status: PresenceUpdateStatus.Online, + afk: false, + }, + } as GatewaySendPayload), + ).toThrow(/4096-byte limit/); + expect(send).not.toHaveBeenCalled(); + }); + it("ignores stale socket close events after reconnecting", () => { const gateway = new TestGatewayPlugin({ autoInteractions: false, @@ -330,6 +354,7 @@ describe("GatewayPlugin", () => { it("clears resume state after invalid session false", async () => { vi.useFakeTimers(); + vi.spyOn(Math, "random").mockReturnValue(0); const gateway = new TestGatewayPlugin({ autoInteractions: false, url: "wss://gateway.example.test", @@ -354,6 +379,29 @@ describe("GatewayPlugin", () => { expect(sessionState.sequence).toBeNull(); }); + it("delays invalid-session reconnects by Discord's randomized cooldown floor", async () => { + vi.useFakeTimers(); + vi.spyOn(Math, "random").mockReturnValue(0.75); + const gateway = new TestGatewayPlugin({ + autoInteractions: false, + url: "wss://gateway.example.test", + }); + + gateway.connect(false); + gateway.sockets[0]?.emit("open"); + ( + gateway as unknown as { + handlePayload(payload: { op: number; d: unknown }, resume: boolean): void; + } + ).handlePayload({ op: GatewayOpcodes.InvalidSession, d: true }, true); + + await vi.advanceTimersByTimeAsync(3_999); + expect(gateway.connectCalls).toEqual([false]); + + await vi.advanceTimersByTimeAsync(1); + expect(gateway.connectCalls).toEqual([false, true]); + }); + it("includes close code details when reconnect attempts are exhausted", async () => { vi.useFakeTimers(); const gateway = new TestGatewayPlugin({ @@ -508,4 +556,48 @@ describe("GatewayPlugin", () => { expect.stringContaining(`"op":${GatewayOpcodes.Identify}`), ); }); + + it("validates requestGuildMembers before sending", () => { + const withoutMembersIntent = new GatewayPlugin({ autoInteractions: false }); + attachOpenSocket(withoutMembersIntent); + + expect(() => + withoutMembersIntent.requestGuildMembers({ guild_id: "guild1", query: "", limit: 0 }), + ).toThrow(/GUILD_MEMBERS intent/); + + const withoutPresenceIntent = new GatewayPlugin({ + autoInteractions: false, + intents: GatewayIntentBits.GuildMembers, + }); + attachOpenSocket(withoutPresenceIntent); + + expect(() => + withoutPresenceIntent.requestGuildMembers({ + guild_id: "guild1", + query: "", + limit: 0, + presences: true, + }), + ).toThrow(/GUILD_PRESENCES intent/); + + const valid = new GatewayPlugin({ + autoInteractions: false, + intents: GatewayIntentBits.GuildMembers | GatewayIntentBits.GuildPresences, + }); + const send = attachOpenSocket(valid); + + expect(() => + valid.requestGuildMembers({ + guild_id: "guild1", + limit: 1, + }), + ).toThrow(/query or user_ids/); + + valid.requestGuildMembers({ guild_id: "guild1", query: "", limit: 0, presences: true }); + expect(send).toHaveBeenCalledTimes(1); + expect(JSON.parse(send.mock.calls[0]?.[0] as string)).toEqual({ + op: GatewayOpcodes.RequestGuildMembers, + d: { guild_id: "guild1", query: "", limit: 0, presences: true }, + }); + }); }); diff --git a/extensions/discord/src/internal/gateway.ts b/extensions/discord/src/internal/gateway.ts index 8c6c5cf616c..764ddea219f 100644 --- a/extensions/discord/src/internal/gateway.ts +++ b/extensions/discord/src/internal/gateway.ts @@ -46,6 +46,9 @@ type GatewayPluginOptions = { const READY_STATE_OPEN = 1; const DEFAULT_GATEWAY_URL = "wss://gateway.discord.gg/"; +const DISCORD_GATEWAY_PAYLOAD_LIMIT_BYTES = 4096; +const INVALID_SESSION_MIN_DELAY_MS = 1_000; +const INVALID_SESSION_JITTER_MS = 4_000; function ensureGatewayParams(url: string): string { const parsed = new URL(url); @@ -274,7 +277,11 @@ export class GatewayPlugin extends Plugin { if (!payload.d) { this.resetSessionState(); } - this.scheduleReconnect(payload.d); + this.scheduleReconnect( + payload.d, + undefined, + INVALID_SESSION_MIN_DELAY_MS + Math.floor(Math.random() * INVALID_SESSION_JITTER_MS), + ); break; case GatewayOpcodes.Reconnect: this.scheduleReconnect(true); @@ -347,6 +354,15 @@ export class GatewayPlugin extends Plugin { throw new Error("Discord gateway socket is not open"); } const serialized = JSON.stringify(payload); + const payloadSize = + typeof Buffer !== "undefined" + ? Buffer.byteLength(serialized, "utf8") + : new TextEncoder().encode(serialized).byteLength; + if (payloadSize > DISCORD_GATEWAY_PAYLOAD_LIMIT_BYTES) { + throw new Error( + `Discord gateway payload exceeds ${DISCORD_GATEWAY_PAYLOAD_LIMIT_BYTES}-byte limit`, + ); + } this.outboundLimiter.send(serialized, { critical: skipRateLimit }); } @@ -386,7 +402,7 @@ export class GatewayPlugin extends Plugin { this.sequence = null; } - private scheduleReconnect(resume: boolean, closeCode?: number): void { + private scheduleReconnect(resume: boolean, closeCode?: number, minDelayMs = 0): void { if (!this.shouldReconnect) { return; } @@ -408,7 +424,10 @@ export class GatewayPlugin extends Plugin { ); return; } - const delay = Math.min(30_000, 1_000 * 2 ** Math.min(this.reconnectAttempts, 5)); + const delay = Math.max( + minDelayMs, + Math.min(30_000, 1_000 * 2 ** Math.min(this.reconnectAttempts, 5)), + ); this.reconnectTimer.schedule(delay, () => { this.connect(resume); }); @@ -423,6 +442,15 @@ export class GatewayPlugin extends Plugin { } requestGuildMembers(data: RequestGuildMembersData): void { + if (!this.hasIntent(GatewayIntentBits.GuildMembers)) { + throw new Error("GUILD_MEMBERS intent is required for requestGuildMembers"); + } + if (data.presences && !this.hasIntent(GatewayIntentBits.GuildPresences)) { + throw new Error("GUILD_PRESENCES intent is required when requesting presences"); + } + if (!data.query && data.query !== "" && !data.user_ids) { + throw new Error("Either query or user_ids is required for requestGuildMembers"); + } this.send({ op: GatewayOpcodes.RequestGuildMembers, d: data } as GatewaySendPayload); } diff --git a/extensions/discord/src/internal/interaction-dispatch.ts b/extensions/discord/src/internal/interaction-dispatch.ts index f4b96700aff..18f9e0d4ed4 100644 --- a/extensions/discord/src/internal/interaction-dispatch.ts +++ b/extensions/discord/src/internal/interaction-dispatch.ts @@ -30,6 +30,12 @@ type DispatchClient = Parameters[0] & { commands: BaseCommand[]; componentHandler: { resolve(customId: string, options?: { componentType?: number }): DispatchComponent | undefined; + resolveOneOffComponent(params: { + channelId?: string; + customId: string; + messageId?: string; + values?: string[]; + }): boolean; }; modalHandler: { resolve(customId: string): DispatchModal | undefined }; }; @@ -75,11 +81,22 @@ export async function dispatchInteraction( if (!customId) { return; } + const componentInteraction = interaction as BaseComponentInteraction; + if ( + client.componentHandler.resolveOneOffComponent({ + channelId: readMessageChannelId(rawData), + customId, + messageId: readMessageId(rawData), + values: readComponentValues(rawData), + }) + ) { + await componentInteraction.acknowledge(); + return; + } const component = client.componentHandler.resolve(customId, { componentType: (rawData as { data?: { component_type?: number } }).data?.component_type, }); if (component) { - const componentInteraction = interaction as BaseComponentInteraction; await deferComponentInteractionIfNeeded(component, componentInteraction); await component.run(componentInteraction, parseComponentInteractionData(component, customId)); } @@ -128,3 +145,18 @@ function readInteractionName(rawData: APIInteraction): string | undefined { function readCustomId(rawData: APIInteraction): string | undefined { return (rawData as { data?: { custom_id?: string } }).data?.custom_id; } + +function readComponentValues(rawData: APIInteraction): string[] | undefined { + const values = (rawData as { data?: { values?: unknown } }).data?.values; + return Array.isArray(values) ? values.map(String) : undefined; +} + +function readMessageId(rawData: APIInteraction): string | undefined { + const messageId = (rawData as { message?: { id?: unknown } }).message?.id; + return typeof messageId === "string" ? messageId : undefined; +} + +function readMessageChannelId(rawData: APIInteraction): string | undefined { + const channelId = (rawData as { message?: { channel_id?: unknown } }).message?.channel_id; + return typeof channelId === "string" ? channelId : undefined; +} diff --git a/extensions/discord/src/internal/interactions.test.ts b/extensions/discord/src/internal/interactions.test.ts index 50800e88866..eb3bfbf085b 100644 --- a/extensions/discord/src/internal/interactions.test.ts +++ b/extensions/discord/src/internal/interactions.test.ts @@ -179,6 +179,78 @@ describe("BaseInteraction", () => { expect(interaction.user?.globalName).toBe("Alice Cooper"); expect(interaction.user?.discriminator).toBe("1234"); }); + + it("waits for a one-off component reply without invoking registered handlers", async () => { + const get = vi.fn(async () => ({ + id: "message1", + channel_id: "channel1", + author: { + id: "bot1", + username: "bot", + discriminator: "0000", + global_name: null, + avatar: null, + }, + content: "pick", + timestamp: "2026-05-01T00:00:00.000Z", + })); + const post = vi.fn(async () => undefined); + const client = createInternalTestClient(); + attachRestMock(client, { get, post }); + const interaction = new BaseInteraction( + client, + createInternalInteractionPayload({ id: "interaction1", token: "token1" }), + ); + + const wait = interaction.replyAndWaitForComponent({ content: "pick" }, 1_000); + await vi.waitFor(() => + expect(get).toHaveBeenCalledWith("/webhooks/app1/token1/messages/%40original"), + ); + + await client.handleInteraction( + createInternalComponentInteractionPayload({ + id: "component-interaction1", + token: "component-token1", + data: { custom_id: "button1" }, + message: { + id: "message1", + channel_id: "channel1", + author: { + id: "bot1", + username: "bot", + discriminator: "0000", + global_name: null, + avatar: null, + }, + content: "pick", + timestamp: "2026-05-01T00:00:00.000Z", + edited_timestamp: null, + tts: false, + mention_everyone: false, + mentions: [], + mention_roles: [], + attachments: [], + embeds: [], + pinned: false, + type: 0, + }, + }), + ); + + await expect(wait).resolves.toEqual({ + success: true, + customId: "button1", + message: expect.objectContaining({ id: "message1", channelId: "channel1" }), + values: undefined, + }); + expect(post).toHaveBeenNthCalledWith( + 2, + "/interactions/component-interaction1/component-token1/callback", + { + body: { type: InteractionResponseType.DeferredMessageUpdate }, + }, + ); + }); }); describe("ModalInteraction", () => { diff --git a/extensions/discord/src/internal/interactions.ts b/extensions/discord/src/internal/interactions.ts index 7b2e9a57100..22c74080f33 100644 --- a/extensions/discord/src/internal/interactions.ts +++ b/extensions/discord/src/internal/interactions.ts @@ -7,6 +7,7 @@ import { type APIChannel, type APIInteraction, type APIInteractionDataResolvedChannel, + type APIMessage, type APIMessageComponentInteraction, type APIModalSubmitInteraction, type APIUser, @@ -41,6 +42,15 @@ export { ModalFields } from "./modal-fields.js"; type InteractionClient = StructureClient & { options: { clientId: string }; + componentHandler: { + waitForMessageComponent( + message: Message, + timeoutMs: number, + ): Promise< + | { success: true; customId: string; message: Message; values?: string[] } + | { success: false; message: Message; reason: "timed out" } + >; + }; fetchChannel(id: string): Promise; }; @@ -216,6 +226,16 @@ export class BaseInteraction { ); } + async replyAndWaitForComponent(payload: MessagePayload, timeoutMs = 300_000) { + const result = await this.reply(payload); + const rawMessage = isRawMessage(result) ? result : await this.fetchReply(); + if (!isRawMessage(rawMessage)) { + throw new Error("Discord interaction reply did not return a message"); + } + const message = new Message(this.client, rawMessage as APIMessage); + return await this.client.componentHandler.waitForMessageComponent(message, timeoutMs); + } + async followUp(payload: MessagePayload): Promise { const body = serializePayload(payload); return await createWebhookMessage( @@ -272,6 +292,18 @@ export class BaseComponentInteraction extends BaseInteraction { async showModal(modal: Modal): Promise { return await this.callback(InteractionResponseType.Modal, modal.serialize()); } + + async editAndWaitForComponent( + payload: MessagePayload, + message: Message | null = this.message, + timeoutMs = 300_000, + ) { + if (!message) { + return null; + } + const editedMessage = await message.edit(payload); + return await this.client.componentHandler.waitForMessageComponent(editedMessage, timeoutMs); + } } export class ButtonInteraction extends BaseComponentInteraction {} @@ -335,3 +367,12 @@ export function parseComponentInteractionData( ): ComponentData { return component.customIdParser(customId).data; } + +function isRawMessage(value: unknown): value is { id: string; channel_id: string } { + return ( + Boolean(value) && + typeof value === "object" && + typeof (value as { id?: unknown }).id === "string" && + typeof (value as { channel_id?: unknown }).channel_id === "string" + ); +} diff --git a/extensions/discord/src/internal/rest-scheduler.ts b/extensions/discord/src/internal/rest-scheduler.ts index 19a9c4e2ad0..dc7eb3a12d4 100644 --- a/extensions/discord/src/internal/rest-scheduler.ts +++ b/extensions/discord/src/internal/rest-scheduler.ts @@ -1,12 +1,15 @@ import { RateLimitError, readRetryAfter } from "./rest-errors.js"; import { createBucketKey, createRouteKey, readHeaderNumber, readResetAt } from "./rest-routes.js"; +export type RequestPriority = "critical" | "standard" | "background"; export type RequestQuery = Record; type ScheduledRequest = { method: string; path: string; data?: TData; + enqueuedAt: number; generation: number; + priority: RequestPriority; query?: RequestQuery; routeKey: string; retryCount: number; @@ -14,25 +17,47 @@ type ScheduledRequest = { reject: (reason?: unknown) => void; }; +type LaneQueues = Record>>; + type BucketState = { active: number; bucket?: string; invalidRequests: number; limit?: number; - pending: Array>; + pending: LaneQueues; rateLimitHits: number; remaining?: number; resetAt: number; routeKeys: Set; }; -type RestSchedulerOptions = { - maxConcurrency: number; - maxRateLimitRetries: number; +export type RestSchedulerLaneOptions = { maxQueueSize: number; + staleAfterMs?: number; + weight: number; +}; + +export type RestSchedulerOptions = { + lanes: Record; + maxConcurrency: number; + maxQueueSize: number; + maxRateLimitRetries: number; }; const INVALID_REQUEST_WINDOW_MS = 10 * 60_000; +const requestPriorities = ["critical", "standard", "background"] as const; + +function createLaneQueues(): LaneQueues { + return { + critical: [], + standard: [], + background: [], + }; +} + +function countPending(bucket: BucketState): number { + return requestPriorities.reduce((count, lane) => count + bucket.pending[lane].length, 0); +} export class RestScheduler { private activeWorkers = 0; @@ -40,6 +65,18 @@ export class RestScheduler { private drainTimer: NodeJS.Timeout | undefined; private globalRateLimitUntil = 0; private invalidRequestTimestamps: Array<{ at: number; status: number }> = []; + private laneCursor = 0; + private laneDropped: Record = { + critical: 0, + standard: 0, + background: 0, + }; + private laneSchedule: RequestPriority[]; + private queuedByLane: Record = { + critical: 0, + standard: 0, + background: 0, + }; private queueGeneration = 0; private queuedRequests = 0; private routeBuckets = new Map(); @@ -47,23 +84,35 @@ export class RestScheduler { constructor( private readonly options: RestSchedulerOptions, private readonly executor: (request: ScheduledRequest) => Promise, - ) {} + ) { + this.laneSchedule = this.buildLaneSchedule(options.lanes); + } enqueue(params: { method: string; path: string; data?: TData; + priority: RequestPriority; query?: RequestQuery; }): Promise { if (this.queuedRequests >= this.options.maxQueueSize) { throw new Error("Discord request queue is full"); } + const laneOptions = this.options.lanes[params.priority]; + if (this.queuedByLane[params.priority] >= laneOptions.maxQueueSize) { + this.laneDropped[params.priority] += 1; + throw new Error( + `Discord ${params.priority} request queue is full (${this.queuedByLane[params.priority]} / ${laneOptions.maxQueueSize})`, + ); + } const routeKey = createRouteKey(params.method, params.path); const bucket = this.getBucket(this.routeBuckets.get(routeKey) ?? routeKey); return new Promise((resolve, reject) => { this.queuedRequests += 1; - bucket.pending.push({ + this.queuedByLane[params.priority] += 1; + bucket.pending[params.priority].push({ ...params, + enqueuedAt: Date.now(), generation: this.queueGeneration, routeKey, retryCount: 0, @@ -108,7 +157,10 @@ export class RestScheduler { active: bucket.active, bucket: bucket.bucket, invalidRequests: bucket.invalidRequests, - pending: bucket.pending.length, + pending: countPending(bucket), + pendingByLane: Object.fromEntries( + requestPriorities.map((lane) => [lane, bucket.pending[lane].length]), + ), rateLimitHits: bucket.rateLimitHits, remaining: bucket.remaining, resetAt: bucket.resetAt, @@ -123,6 +175,11 @@ export class RestScheduler { {}, ), queueSize: this.queueSize, + queueSizeByLane: { ...this.queuedByLane }, + droppedByLane: { ...this.laneDropped }, + oldestQueuedByLane: Object.fromEntries( + requestPriorities.map((lane) => [lane, this.getOldestQueuedAge(lane)]), + ), activeWorkers: this.activeWorkers, maxConcurrentWorkers: this.maxConcurrentWorkers, }; @@ -144,7 +201,7 @@ export class RestScheduler { const bucket: BucketState = { active: 0, invalidRequests: 0, - pending: [], + pending: createLaneQueues(), rateLimitHits: 0, resetAt: 0, routeKeys: new Set([key]), @@ -180,7 +237,7 @@ export class RestScheduler { bucket: BucketState, now = Date.now(), ): void { - if (bucket.active > 0 || bucket.pending.length > 0 || this.isBucketRateLimited(bucket, now)) { + if (bucket.active > 0 || countPending(bucket) > 0 || this.isBucketRateLimited(bucket, now)) { return; } for (const routeKey of Array.from(bucket.routeKeys)) { @@ -201,8 +258,10 @@ export class RestScheduler { this.routeBuckets.set(routeKey, bucketKey); const routeBucket = this.buckets.get(routeKey); if (routeBucket && routeBucket !== target) { - target.pending.push(...routeBucket.pending); - routeBucket.pending = []; + for (const lane of requestPriorities) { + target.pending[lane].push(...routeBucket.pending[lane]); + routeBucket.pending[lane] = []; + } if (routeBucket.active === 0) { this.buckets.delete(routeKey); } @@ -302,42 +361,16 @@ export class RestScheduler { } private drainQueues(): void { - const now = Date.now(); - if (this.globalRateLimitUntil > now) { - this.scheduleDrain(this.globalRateLimitUntil - now); - return; - } let nextDelayMs = Number.POSITIVE_INFINITY; - for (const [key, bucket] of this.buckets) { - if (this.activeWorkers >= this.maxConcurrentWorkers) { + while (this.activeWorkers < this.maxConcurrentWorkers) { + const next = this.takeNextQueuedRequest(); + if (!next.queued) { + if (next.waitMs !== undefined) { + nextDelayMs = Math.min(nextDelayMs, next.waitMs); + } break; } - if (bucket.pending.length === 0) { - if (bucket.active !== 0) { - continue; - } - if (this.isBucketRateLimited(bucket, now)) { - nextDelayMs = Math.min(nextDelayMs, bucket.resetAt - now); - continue; - } - this.pruneIdleRouteMappings(key, bucket, now); - if (this.shouldPruneIdleBucket(key)) { - this.buckets.delete(key); - } - continue; - } - if (bucket.active > 0) { - continue; - } - const waitMs = this.getBucketWaitMs(bucket, now); - if (waitMs > 0) { - nextDelayMs = Math.min(nextDelayMs, waitMs); - continue; - } - const queued = bucket.pending.shift(); - if (!queued) { - continue; - } + const { bucket, queued } = next; if (bucket.remaining !== undefined && bucket.remaining > 0) { bucket.remaining -= 1; } @@ -350,6 +383,87 @@ export class RestScheduler { } } + private takeNextQueuedRequest(): + | { bucket: BucketState; queued: ScheduledRequest; waitMs?: never } + | { bucket?: never; queued?: never; waitMs?: number } { + const now = Date.now(); + if (this.globalRateLimitUntil > now) { + return { waitMs: this.globalRateLimitUntil - now }; + } + this.pruneIdleBuckets(now); + let nextDelayMs: number | undefined; + const buckets = Array.from(this.buckets.values()).filter((bucket) => countPending(bucket) > 0); + if (buckets.length === 0) { + return {}; + } + for (let laneOffset = 0; laneOffset < this.laneSchedule.length; laneOffset += 1) { + const lane = this.laneSchedule[(this.laneCursor + laneOffset) % this.laneSchedule.length]; + if (!lane || this.queuedByLane[lane] <= 0) { + continue; + } + for (const bucket of buckets) { + const queue = bucket.pending[lane]; + this.dropStaleHeadRequests(queue, lane, now); + if (queue.length === 0) { + continue; + } + if (bucket.active > 0) { + nextDelayMs = Math.min(nextDelayMs ?? 5, 5); + continue; + } + const waitMs = this.getBucketWaitMs(bucket, now); + if (waitMs > 0) { + nextDelayMs = Math.min(nextDelayMs ?? waitMs, waitMs); + continue; + } + const queued = queue.shift(); + if (!queued) { + continue; + } + this.queuedByLane[lane] = Math.max(0, this.queuedByLane[lane] - 1); + this.laneCursor = (this.laneCursor + laneOffset + 1) % this.laneSchedule.length; + return { bucket, queued }; + } + } + return { waitMs: nextDelayMs }; + } + + private dropStaleHeadRequests( + queue: Array>, + lane: RequestPriority, + now: number, + ): void { + const staleAfterMs = this.options.lanes[lane].staleAfterMs; + if (!staleAfterMs || staleAfterMs <= 0) { + return; + } + while (queue.length > 0 && now - (queue[0]?.enqueuedAt ?? now) > staleAfterMs) { + const stale = queue.shift(); + if (!stale) { + continue; + } + this.queuedRequests = Math.max(0, this.queuedRequests - 1); + this.queuedByLane[lane] = Math.max(0, this.queuedByLane[lane] - 1); + this.laneDropped[lane] += 1; + stale.reject(new Error(`Dropped stale ${lane} request after ${now - stale.enqueuedAt}ms`)); + } + } + + private pruneIdleBuckets(now = Date.now()): void { + for (const [key, bucket] of this.buckets) { + if (bucket.active !== 0 || countPending(bucket) > 0) { + continue; + } + if (this.isBucketRateLimited(bucket, now)) { + continue; + } + this.pruneIdleRouteMappings(key, bucket, now); + if (this.shouldPruneIdleBucket(key)) { + this.buckets.delete(key); + } + } + } + private async runQueuedRequest( queued: ScheduledRequest, bucket: BucketState, @@ -369,7 +483,7 @@ export class RestScheduler { if (!requeued) { this.queuedRequests = Math.max(0, this.queuedRequests - 1); } - if (bucket.active === 0 && bucket.pending.length === 0) { + if (bucket.active === 0 && countPending(bucket) === 0) { for (const routeKey of bucket.routeKeys) { if (this.routeBuckets.get(routeKey) === routeKey) { this.routeBuckets.delete(routeKey); @@ -388,21 +502,50 @@ export class RestScheduler { return false; } const bucketKey = this.routeBuckets.get(queued.routeKey) ?? queued.routeKey; - this.getBucket(bucketKey).pending.push({ + this.getBucket(bucketKey).pending[queued.priority].push({ ...queued, + enqueuedAt: Date.now(), retryCount: queued.retryCount + 1, }); + this.queuedByLane[queued.priority] += 1; return true; } private rejectPending(error: Error | DOMException): void { for (const bucket of this.buckets.values()) { - for (const queued of bucket.pending.splice(0)) { - queued.reject(error); - this.queuedRequests = Math.max(0, this.queuedRequests - 1); + for (const lane of requestPriorities) { + for (const queued of bucket.pending[lane].splice(0)) { + queued.reject(error); + this.queuedRequests = Math.max(0, this.queuedRequests - 1); + this.queuedByLane[lane] = Math.max(0, this.queuedByLane[lane] - 1); + } } } } + + private buildLaneSchedule(lanes: Record) { + const schedule: RequestPriority[] = []; + for (const lane of requestPriorities) { + const weight = Math.max(1, Math.floor(lanes[lane].weight)); + for (let i = 0; i < weight; i += 1) { + schedule.push(lane); + } + } + return schedule.length > 0 ? schedule : [...requestPriorities]; + } + + private getOldestQueuedAge(lane: RequestPriority): number { + const now = Date.now(); + let oldest = 0; + for (const bucket of this.buckets.values()) { + const queued = bucket.pending[lane][0]; + if (!queued) { + continue; + } + oldest = Math.max(oldest, now - queued.enqueuedAt); + } + return oldest; + } } function isGlobalRateLimit(parsed: unknown): boolean { diff --git a/extensions/discord/src/internal/rest.test.ts b/extensions/discord/src/internal/rest.test.ts index aada8d0da95..a8c8b32ec5b 100644 --- a/extensions/discord/src/internal/rest.test.ts +++ b/extensions/discord/src/internal/rest.test.ts @@ -41,6 +41,76 @@ describe("RequestClient", () => { expect(client.queueSize).toBe(0); }); + it("dispatches critical interaction callbacks before older background requests", async () => { + const firstResponse = createDeferred(); + const responses = new Map>([ + ["/guilds/g1/roles", firstResponse.promise], + ["/interactions/123/token/callback", Promise.resolve(createJsonResponse({ ok: "critical" }))], + ["/guilds/g2/roles", Promise.resolve(createJsonResponse({ ok: "background" }))], + ]); + const fetchSpy = vi.fn(async (input: string | URL | Request) => { + const url = + typeof input === "string" ? input : input instanceof URL ? input.toString() : input.url; + const path = new URL(url).pathname.replace(/^\/api\/v\d+/, ""); + const response = responses.get(path); + if (!response) { + throw new Error(`unexpected request ${path}`); + } + return await response; + }); + const client = new RequestClient("test-token", { + fetch: fetchSpy, + scheduler: { maxConcurrency: 1 }, + }); + + const first = client.get("/guilds/g1/roles"); + const background = client.get("/guilds/g2/roles"); + const critical = client.post("/interactions/123/token/callback", { body: { type: 5 } }); + + await vi.waitFor(() => expect(fetchSpy).toHaveBeenCalledTimes(1)); + firstResponse.resolve(createJsonResponse({ ok: "first" })); + + await expect(first).resolves.toEqual({ ok: "first" }); + await expect(critical).resolves.toEqual({ ok: "critical" }); + await expect(background).resolves.toEqual({ ok: "background" }); + expect(fetchSpy.mock.calls.map(([input]) => new URL(readRequestUrl(input)).pathname)).toEqual([ + "/api/v10/guilds/g1/roles", + "/api/v10/interactions/123/token/callback", + "/api/v10/guilds/g2/roles", + ]); + }); + + it("drops stale background requests instead of replaying obsolete reads", async () => { + vi.useFakeTimers(); + vi.setSystemTime(0); + const firstResponse = createDeferred(); + const fetchSpy = vi.fn(async () => await firstResponse.promise); + const client = new RequestClient("test-token", { + fetch: fetchSpy, + scheduler: { + maxConcurrency: 1, + lanes: { background: { staleAfterMs: 50 } }, + }, + }); + + const first = client.get("/guilds/g1/roles"); + const stale = client.get("/guilds/g2/roles"); + await vi.waitFor(() => expect(fetchSpy).toHaveBeenCalledTimes(1)); + + await vi.advanceTimersByTimeAsync(51); + firstResponse.resolve(createJsonResponse({ ok: "first" })); + + await expect(first).resolves.toEqual({ ok: "first" }); + await expect(stale).rejects.toThrow(/Dropped stale background request/); + expect(fetchSpy).toHaveBeenCalledTimes(1); + expect(client.getSchedulerMetrics()).toEqual( + expect.objectContaining({ + droppedByLane: expect.objectContaining({ background: 1 }), + queueSize: 0, + }), + ); + }); + it("runs independent route buckets concurrently", async () => { const channelResponse = createDeferred(); const guildResponse = createDeferred(); @@ -508,3 +578,7 @@ describe("RequestClient", () => { expect(form.get("payload_json")).toBeNull(); }); }); + +function readRequestUrl(input: string | URL | Request): string { + return typeof input === "string" ? input : input instanceof URL ? input.toString() : input.url; +} diff --git a/extensions/discord/src/internal/rest.ts b/extensions/discord/src/internal/rest.ts index b15788d1617..66eeffe6d6b 100644 --- a/extensions/discord/src/internal/rest.ts +++ b/extensions/discord/src/internal/rest.ts @@ -9,14 +9,21 @@ import { readRetryAfter, } from "./rest-errors.js"; import { appendQuery, createRouteKey } from "./rest-routes.js"; -import { RestScheduler, type RequestQuery } from "./rest-scheduler.js"; +import { + RestScheduler, + type RequestPriority as RestRequestPriority, + type RequestQuery, +} from "./rest-scheduler.js"; import { isDiscordRateLimitBody } from "./schemas.js"; export { DiscordError, RateLimitError } from "./rest-errors.js"; export type RuntimeProfile = "serverless" | "persistent"; -export type RequestPriority = "critical" | "standard" | "background"; +export type RequestPriority = RestRequestPriority; export type RequestSchedulerOptions = { + lanes?: Partial< + Record + >; maxConcurrency?: number; maxRateLimitRetries?: number; }; @@ -63,6 +70,11 @@ const defaultOptions = { }; const DEFAULT_MAX_CONCURRENT_WORKERS = 4; +const defaultLaneOptions: Record = { + critical: { weight: 6 }, + standard: { staleAfterMs: 60_000, weight: 3 }, + background: { staleAfterMs: 20_000, weight: 1 }, +}; function coerceResponseBody(raw: string): unknown { if (!raw) { @@ -134,9 +146,13 @@ export class RequestClient { this.options = { ...defaultOptions, ...options }; this.scheduler = new RestScheduler( { + lanes: normalizeSchedulerLanes( + this.options.maxQueueSize ?? defaultOptions.maxQueueSize, + this.options.scheduler?.lanes, + ), maxConcurrency: this.options.scheduler?.maxConcurrency ?? DEFAULT_MAX_CONCURRENT_WORKERS, - maxRateLimitRetries: this.options.scheduler?.maxRateLimitRetries ?? 3, maxQueueSize: this.options.maxQueueSize ?? defaultOptions.maxQueueSize, + maxRateLimitRetries: this.options.scheduler?.maxRateLimitRetries ?? 3, }, async (request) => await this.executeRequest( @@ -177,7 +193,12 @@ export class RequestClient { if (!this.options.queueRequests) { return await this.executeRequest(method, path, params, routeKey); } - return await this.scheduler.enqueue({ method, path, ...params }); + return await this.scheduler.enqueue({ + method, + path, + priority: getRequestPriority(method, path), + ...params, + }); } protected async executeRequest( @@ -258,3 +279,53 @@ export class RequestClient { this.requestControllers.clear(); } } + +function normalizeSchedulerLanes( + maxQueueSize: number, + lanes?: RequestSchedulerOptions["lanes"], +): Record { + const fallbackMaxQueueSize = Math.max(1, Math.floor(maxQueueSize)); + return { + critical: normalizeSchedulerLane("critical", fallbackMaxQueueSize, lanes?.critical), + standard: normalizeSchedulerLane("standard", fallbackMaxQueueSize, lanes?.standard), + background: normalizeSchedulerLane("background", fallbackMaxQueueSize, lanes?.background), + }; +} + +function normalizeSchedulerLane( + lane: RestRequestPriority, + maxQueueSize: number, + options?: { maxQueueSize?: number; staleAfterMs?: number; weight?: number }, +): { maxQueueSize: number; staleAfterMs?: number; weight: number } { + const defaults = defaultLaneOptions[lane]; + return { + maxQueueSize: + options?.maxQueueSize !== undefined + ? Math.max(1, Math.floor(options.maxQueueSize)) + : maxQueueSize, + staleAfterMs: + options?.staleAfterMs !== undefined + ? Math.max(0, Math.floor(options.staleAfterMs)) + : defaults.staleAfterMs, + weight: + options?.weight !== undefined ? Math.max(1, Math.floor(options.weight)) : defaults.weight, + }; +} + +function getRequestPriority(method: string, path: string): RestRequestPriority { + const normalizedMethod = method.toUpperCase(); + const normalizedPath = path.toLowerCase(); + if (/^\/interactions\/\d+\/[^/]+\/callback$/.test(normalizedPath)) { + return "critical"; + } + if ( + normalizedPath.startsWith("/webhooks/") && + (normalizedMethod === "POST" || normalizedMethod === "PATCH" || normalizedMethod === "DELETE") + ) { + return "standard"; + } + if (normalizedMethod !== "GET" && /\/channels\/\d+\/messages/.test(normalizedPath)) { + return "standard"; + } + return "background"; +}