diff --git a/extensions/matrix/src/matrix/monitor/direct.test.ts b/extensions/matrix/src/matrix/monitor/direct.test.ts index 9e8fcb2a030..905771a4b85 100644 --- a/extensions/matrix/src/matrix/monitor/direct.test.ts +++ b/extensions/matrix/src/matrix/monitor/direct.test.ts @@ -8,14 +8,14 @@ function createMockClient(params: { selfDirect?: boolean; members?: string[]; }) { - const members = params.members ?? ["@alice:example.org", "@bot:example.org"]; + let members = params.members ?? ["@alice:example.org", "@bot:example.org"]; return { dms: { update: vi.fn().mockResolvedValue(undefined), isDm: vi.fn().mockReturnValue(params.isDm === true), }, getUserId: vi.fn().mockResolvedValue("@bot:example.org"), - getJoinedRoomMembers: vi.fn().mockResolvedValue(members), + getJoinedRoomMembers: vi.fn().mockImplementation(async () => members), getRoomStateEvent: vi .fn() .mockImplementation(async (_roomId: string, eventType: string, stateKey: string) => { @@ -27,6 +27,9 @@ function createMockClient(params: { } return {}; }), + __setMembers(next: string[]) { + members = next; + }, } as unknown as MatrixClient; } @@ -98,6 +101,33 @@ describe("createDirectRoomTracker", () => { ).resolves.toBe(false); }); + it("re-checks room membership after invalidation when a DM gains extra members", async () => { + const client = createMockClient({ isDm: true }); + const tracker = createDirectRoomTracker(client); + + await expect( + tracker.isDirectMessage({ + roomId: "!room:example.org", + senderId: "@alice:example.org", + }), + ).resolves.toBe(true); + + (client as MatrixClient & { __setMembers: (members: string[]) => void }).__setMembers([ + "@alice:example.org", + "@bot:example.org", + "@mallory:example.org", + ]); + + tracker.invalidateRoom("!room:example.org"); + + await expect( + tracker.isDirectMessage({ + roomId: "!room:example.org", + senderId: "@alice:example.org", + }), + ).resolves.toBe(false); + }); + it("still recognizes exact 2-member rooms when member state also claims is_direct", async () => { const tracker = createDirectRoomTracker(createMockClient({ senderDirect: true })); await expect( diff --git a/extensions/matrix/src/matrix/monitor/direct.ts b/extensions/matrix/src/matrix/monitor/direct.ts index 38518673655..e58580f04f0 100644 --- a/extensions/matrix/src/matrix/monitor/direct.ts +++ b/extensions/matrix/src/matrix/monitor/direct.ts @@ -64,6 +64,11 @@ export function createDirectRoomTracker(client: MatrixClient, opts: DirectRoomTr }; return { + invalidateRoom: (roomId: string): void => { + joinedMembersCache.delete(roomId); + lastDmUpdateMs = 0; + log(`matrix: invalidated dm cache room=${roomId}`); + }, isDirectMessage: async (params: DirectMessageCheck): Promise => { const { roomId, senderId } = params; await refreshDmCache(); diff --git a/extensions/matrix/src/matrix/monitor/events.test.ts b/extensions/matrix/src/matrix/monitor/events.test.ts index af352953d43..fb7e28cde4e 100644 --- a/extensions/matrix/src/matrix/monitor/events.test.ts +++ b/extensions/matrix/src/matrix/monitor/events.test.ts @@ -37,6 +37,7 @@ function createHarness(params?: { const onRoomMessage = vi.fn(async () => {}); const listVerifications = vi.fn(async () => params?.verifications ?? []); const sendMessage = vi.fn(async () => "$notice"); + const invalidateRoom = vi.fn(); const logger = { info: vi.fn(), warn: vi.fn(), error: vi.fn() }; const formatNativeDependencyHint = vi.fn(() => "install hint"); const client = { @@ -66,6 +67,9 @@ function createHarness(params?: { accountId: params?.accountId ?? "default", encryption: params?.authEncryption ?? true, } as MatrixAuth, + directTracker: { + invalidateRoom, + }, logVerboseMessage: vi.fn(), warnedEncryptedRooms: new Set(), warnedCryptoMissingRooms: new Set(), @@ -82,6 +86,7 @@ function createHarness(params?: { return { onRoomMessage, sendMessage, + invalidateRoom, roomEventListener, listVerifications, logger, @@ -117,6 +122,23 @@ describe("registerMatrixMonitorEvents verification routing", () => { expect(sendMessage).not.toHaveBeenCalled(); }); + it("invalidates direct-room membership cache on room member events", async () => { + const { invalidateRoom, roomEventListener } = createHarness(); + + roomEventListener("!room:example.org", { + event_id: "$member1", + sender: "@alice:example.org", + state_key: "@mallory:example.org", + type: EventType.RoomMember, + origin_server_ts: Date.now(), + content: { + membership: "join", + }, + }); + + expect(invalidateRoom).toHaveBeenCalledWith("!room:example.org"); + }); + it("posts verification request notices directly into the room", async () => { const { onRoomMessage, sendMessage, roomMessageListener } = createHarness(); if (!roomMessageListener) { diff --git a/extensions/matrix/src/matrix/monitor/events.ts b/extensions/matrix/src/matrix/monitor/events.ts index 4020b5f7dcb..f4718efc59f 100644 --- a/extensions/matrix/src/matrix/monitor/events.ts +++ b/extensions/matrix/src/matrix/monitor/events.ts @@ -11,6 +11,9 @@ export function registerMatrixMonitorEvents(params: { cfg: CoreConfig; client: MatrixClient; auth: MatrixAuth; + directTracker?: { + invalidateRoom: (roomId: string) => void; + }; logVerboseMessage: (message: string) => void; warnedEncryptedRooms: Set; warnedCryptoMissingRooms: Set; @@ -22,6 +25,7 @@ export function registerMatrixMonitorEvents(params: { cfg, client, auth, + directTracker, logVerboseMessage, warnedEncryptedRooms, warnedCryptoMissingRooms, @@ -68,6 +72,7 @@ export function registerMatrixMonitorEvents(params: { ); client.on("room.invite", (roomId: string, event: MatrixRawEvent) => { + directTracker?.invalidateRoom(roomId); const eventId = event?.event_id ?? "unknown"; const sender = event?.sender ?? "unknown"; const isDirect = (event?.content as { is_direct?: boolean } | undefined)?.is_direct === true; @@ -77,6 +82,7 @@ export function registerMatrixMonitorEvents(params: { }); client.on("room.join", (roomId: string, event: MatrixRawEvent) => { + directTracker?.invalidateRoom(roomId); const eventId = event?.event_id ?? "unknown"; logVerboseMessage(`matrix: join room=${roomId} id=${eventId}`); }); @@ -105,6 +111,7 @@ export function registerMatrixMonitorEvents(params: { return; } if (eventType === EventType.RoomMember) { + directTracker?.invalidateRoom(roomId); const membership = (event?.content as { membership?: string } | undefined)?.membership; const stateKey = (event as { state_key?: string }).state_key ?? ""; logVerboseMessage( diff --git a/extensions/matrix/src/matrix/monitor/index.ts b/extensions/matrix/src/matrix/monitor/index.ts index 96314fc6aea..49ea6f99376 100644 --- a/extensions/matrix/src/matrix/monitor/index.ts +++ b/extensions/matrix/src/matrix/monitor/index.ts @@ -212,6 +212,7 @@ export async function monitorMatrixProvider(opts: MonitorMatrixOpts = {}): Promi cfg, client, auth, + directTracker, logVerboseMessage, warnedEncryptedRooms, warnedCryptoMissingRooms,