From 3ebfd38e12a4c16655e70678d8bb935ce945d2a0 Mon Sep 17 00:00:00 2001 From: Gustavo Madeira Santana Date: Thu, 12 Mar 2026 04:36:03 +0000 Subject: [PATCH] Matrix: harden DM and verification routing --- .../matrix/src/matrix/monitor/direct.test.ts | 40 ++++++---- .../matrix/src/matrix/monitor/direct.ts | 79 ++++++------------- .../matrix/src/matrix/monitor/events.test.ts | 57 ++++++++++++- .../src/matrix/monitor/verification-events.ts | 37 ++++++++- .../matrix/src/matrix/send/targets.test.ts | 34 +++++++- extensions/matrix/src/matrix/send/targets.ts | 40 +++++++--- 6 files changed, 199 insertions(+), 88 deletions(-) diff --git a/extensions/matrix/src/matrix/monitor/direct.test.ts b/extensions/matrix/src/matrix/monitor/direct.test.ts index c688a9ee1a7..dfd272cdbc5 100644 --- a/extensions/matrix/src/matrix/monitor/direct.test.ts +++ b/extensions/matrix/src/matrix/monitor/direct.test.ts @@ -7,8 +7,6 @@ function createMockClient(params: { senderDirect?: boolean; selfDirect?: boolean; members?: string[]; - roomName?: string | null; - roomNameError?: unknown; }) { const members = params.members ?? ["@alice:example.org", "@bot:example.org"]; return { @@ -21,12 +19,6 @@ function createMockClient(params: { getRoomStateEvent: vi .fn() .mockImplementation(async (_roomId: string, eventType: string, stateKey: string) => { - if (eventType === "m.room.name") { - if (params.roomNameError) { - throw params.roomNameError; - } - return params.roomName == null ? {} : { name: params.roomName }; - } if (stateKey === "@alice:example.org") { return { is_direct: params.senderDirect === true }; } @@ -61,8 +53,13 @@ describe("createDirectRoomTracker", () => { expect(client.getJoinedRoomMembers).toHaveBeenCalledWith("!room:example.org"); }); - it("does not classify named 2-member rooms as DMs from member count alone", async () => { - const tracker = createDirectRoomTracker(createMockClient({ isDm: false, roomName: "Project" })); + it("does not classify rooms with extra members as DMs", async () => { + const tracker = createDirectRoomTracker( + createMockClient({ + isDm: false, + members: ["@alice:example.org", "@bot:example.org", "@observer:example.org"], + }), + ); await expect( tracker.isDirectMessage({ roomId: "!room:example.org", @@ -71,11 +68,11 @@ describe("createDirectRoomTracker", () => { ).resolves.toBe(false); }); - it("treats missing room names as DM fallback for 2-member rooms", async () => { + it("does not classify 2-member rooms whose sender is not a joined member as DMs", async () => { const tracker = createDirectRoomTracker( createMockClient({ isDm: false, - roomNameError: { errcode: "M_NOT_FOUND" }, + members: ["@mallory:example.org", "@bot:example.org"], }), ); await expect( @@ -83,10 +80,10 @@ describe("createDirectRoomTracker", () => { roomId: "!room:example.org", senderId: "@alice:example.org", }), - ).resolves.toBe(true); + ).resolves.toBe(false); }); - it("uses is_direct member flags when present", async () => { + it("still recognizes exact 2-member rooms when member state also claims is_direct", async () => { const tracker = createDirectRoomTracker(createMockClient({ senderDirect: true })); await expect( tracker.isDirectMessage({ @@ -95,4 +92,19 @@ describe("createDirectRoomTracker", () => { }), ).resolves.toBe(true); }); + + it("ignores member-state is_direct when the room is not a strict DM", async () => { + const tracker = createDirectRoomTracker( + createMockClient({ + senderDirect: true, + members: ["@alice:example.org", "@bot:example.org", "@observer:example.org"], + }), + ); + await expect( + tracker.isDirectMessage({ + roomId: "!room:example.org", + senderId: "@alice:example.org", + }), + ).resolves.toBe(false); + }); }); diff --git a/extensions/matrix/src/matrix/monitor/direct.ts b/extensions/matrix/src/matrix/monitor/direct.ts index a74b2b61975..2a2ae2a9769 100644 --- a/extensions/matrix/src/matrix/monitor/direct.ts +++ b/extensions/matrix/src/matrix/monitor/direct.ts @@ -12,19 +12,11 @@ type DirectRoomTrackerOptions = { const DM_CACHE_TTL_MS = 30_000; -function isMatrixNotFoundError(err: unknown): boolean { - if (typeof err !== "object" || err === null) { - return false; - } - const value = err as { errcode?: string; statusCode?: number }; - return value.errcode === "M_NOT_FOUND" || value.statusCode === 404; -} - export function createDirectRoomTracker(client: MatrixClient, opts: DirectRoomTrackerOptions = {}) { const log = opts.log ?? (() => {}); let lastDmUpdateMs = 0; let cachedSelfUserId: string | null = null; - const memberCountCache = new Map(); + const joinedMembersCache = new Map(); const ensureSelfUserId = async (): Promise => { if (cachedSelfUserId) { @@ -51,36 +43,26 @@ export function createDirectRoomTracker(client: MatrixClient, opts: DirectRoomTr } }; - const resolveMemberCount = async (roomId: string): Promise => { - const cached = memberCountCache.get(roomId); + const resolveJoinedMembers = async (roomId: string): Promise => { + const cached = joinedMembersCache.get(roomId); const now = Date.now(); if (cached && now - cached.ts < DM_CACHE_TTL_MS) { - return cached.count; + return cached.members; } try { const members = await client.getJoinedRoomMembers(roomId); - const count = members.length; - memberCountCache.set(roomId, { count, ts: now }); - return count; + const normalized = members + .filter((entry): entry is string => typeof entry === "string") + .map((entry) => entry.trim()) + .filter(Boolean); + joinedMembersCache.set(roomId, { members: normalized, ts: now }); + return normalized; } catch (err) { - log(`matrix: dm member count failed room=${roomId} (${String(err)})`); + log(`matrix: dm member lookup failed room=${roomId} (${String(err)})`); return null; } }; - const hasDirectFlag = async (roomId: string, userId?: string): Promise => { - const target = userId?.trim(); - if (!target) { - return false; - } - try { - const state = await client.getRoomStateEvent(roomId, "m.room.member", target); - return state?.is_direct === true; - } catch { - return false; - } - }; - return { isDirectMessage: async (params: DirectMessageCheck): Promise => { const { roomId, senderId } = params; @@ -92,35 +74,22 @@ export function createDirectRoomTracker(client: MatrixClient, opts: DirectRoomTr } const selfUserId = params.selfUserId ?? (await ensureSelfUserId()); - const directViaState = - (await hasDirectFlag(roomId, senderId)) || (await hasDirectFlag(roomId, selfUserId ?? "")); - if (directViaState) { - log(`matrix: dm detected via member state room=${roomId}`); + const joinedMembers = await resolveJoinedMembers(roomId); + const normalizedSenderId = senderId?.trim(); + if ( + selfUserId && + normalizedSenderId && + joinedMembers?.length === 2 && + joinedMembers.includes(selfUserId) && + joinedMembers.includes(normalizedSenderId) + ) { + log(`matrix: dm detected via exact 2-member room room=${roomId}`); return true; } - const memberCount = await resolveMemberCount(roomId); - if (memberCount === 2) { - try { - const nameState = (await client.getRoomStateEvent(roomId, "m.room.name", "")) as { - name?: string | null; - } | null; - if (!nameState?.name?.trim()) { - log(`matrix: dm detected via fallback (2 members, no room name) room=${roomId}`); - return true; - } - } catch (err: unknown) { - if (isMatrixNotFoundError(err)) { - log(`matrix: dm detected via fallback (2 members, no room name) room=${roomId}`); - return true; - } - log( - `matrix: dm fallback skipped (room name check failed: ${String(err)}) room=${roomId}`, - ); - } - } - - log(`matrix: dm check room=${roomId} result=group members=${memberCount ?? "unknown"}`); + log( + `matrix: dm check room=${roomId} result=group members=${joinedMembers?.length ?? "unknown"}`, + ); return false; }, }; diff --git a/extensions/matrix/src/matrix/monitor/events.test.ts b/extensions/matrix/src/matrix/monitor/events.test.ts index 5e669a6c11e..b144721f8d2 100644 --- a/extensions/matrix/src/matrix/monitor/events.test.ts +++ b/extensions/matrix/src/matrix/monitor/events.test.ts @@ -19,6 +19,8 @@ function createHarness(params?: { accountId?: string; authEncryption?: boolean; cryptoAvailable?: boolean; + selfUserId?: string; + joinedMembersByRoom?: Record; verifications?: Array<{ id: string; transactionId?: string; @@ -43,6 +45,10 @@ function createHarness(params?: { return client; }), sendMessage, + getUserId: vi.fn(async () => params?.selfUserId ?? "@bot:example.org"), + getJoinedRoomMembers: vi.fn( + async (roomId: string) => params?.joinedMembersByRoom?.[roomId] ?? [], + ), ...(params?.cryptoAvailable === false ? {} : { @@ -157,6 +163,9 @@ describe("registerMatrixMonitorEvents verification routing", () => { it("posts SAS emoji/decimal details when verification summaries expose them", async () => { const { sendMessage, roomEventListener, listVerifications } = createHarness({ + joinedMembersByRoom: { + "!dm:example.org": ["@alice:example.org", "@bot:example.org"], + }, verifications: [ { id: "verification-1", @@ -175,7 +184,7 @@ describe("registerMatrixMonitorEvents verification routing", () => { ], }); - roomEventListener("!room:example.org", { + roomEventListener("!dm:example.org", { event_id: "$start2", sender: "@alice:example.org", type: "m.key.verification.start", @@ -194,6 +203,52 @@ describe("registerMatrixMonitorEvents verification routing", () => { }); }); + it("does not leak SAS details into unrelated non-DM rooms when flow ids do not match", async () => { + const { sendMessage, roomEventListener } = createHarness({ + joinedMembersByRoom: { + "!group:example.org": ["@alice:example.org", "@bot:example.org", "@ops:example.org"], + }, + verifications: [ + { + id: "verification-2", + transactionId: "$different-flow-id", + otherUserId: "@alice:example.org", + updatedAt: new Date("2026-02-25T21:42:54.000Z").toISOString(), + sas: { + decimal: [6158, 1986, 3513], + emoji: [ + ["🎁", "Gift"], + ["🌍", "Globe"], + ["🐴", "Horse"], + ], + }, + }, + ], + }); + + roomEventListener("!group:example.org", { + event_id: "$start-group", + sender: "@alice:example.org", + type: "m.key.verification.start", + origin_server_ts: Date.now(), + content: { + "m.relates_to": { event_id: "$req-group" }, + }, + }); + + await vi.waitFor(() => { + expect(sendMessage).toHaveBeenCalledTimes(1); + }); + expect(getSentNoticeBody(sendMessage, 0)).toContain( + "Matrix verification started with @alice:example.org.", + ); + expect( + (sendMessage.mock.calls as unknown[][]).some((call) => + String((call[1] as { body?: string } | undefined)?.body ?? "").includes("SAS emoji:"), + ), + ).toBe(false); + }); + it("does not emit duplicate SAS notices for the same verification payload", async () => { const { sendMessage, roomEventListener, listVerifications } = createHarness({ verifications: [ diff --git a/extensions/matrix/src/matrix/monitor/verification-events.ts b/extensions/matrix/src/matrix/monitor/verification-events.ts index 4eace5f11d1..776cec17fe5 100644 --- a/extensions/matrix/src/matrix/monitor/verification-events.ts +++ b/extensions/matrix/src/matrix/monitor/verification-events.ts @@ -161,9 +161,30 @@ function resolveSummaryRecency(summary: MatrixVerificationSummaryLike): number { return Number.isFinite(ts) ? ts : 0; } +async function isStrictDirectVerificationRoom(params: { + client: MatrixClient; + roomId: string; + senderId: string; +}): Promise { + const selfUserId = trimMaybeString(await params.client.getUserId().catch(() => null)); + if (!selfUserId) { + return false; + } + const joinedMembers = await params.client.getJoinedRoomMembers(params.roomId).catch(() => null); + if (!Array.isArray(joinedMembers) || joinedMembers.length !== 2) { + return false; + } + const normalizedMembers = joinedMembers + .filter((entry): entry is string => typeof entry === "string") + .map((entry) => entry.trim()) + .filter(Boolean); + return normalizedMembers.includes(selfUserId) && normalizedMembers.includes(params.senderId); +} + async function resolveVerificationSummaryForSignal( client: MatrixClient, params: { + roomId: string; event: MatrixRawEvent; senderId: string; flowId: string | null; @@ -187,7 +208,20 @@ async function resolveVerificationSummaryForSignal( return byTransactionId; } - // Fallback for flows where transaction IDs do not match room event IDs consistently. + // Only fall back by user inside the active DM with that user. Otherwise a + // spoofed verification event in an unrelated room can leak the current SAS + // prompt into that room. + if ( + !(await isStrictDirectVerificationRoom({ + client, + roomId: params.roomId, + senderId: params.senderId, + })) + ) { + return null; + } + + // Fallback for DM flows where transaction IDs do not match room event IDs consistently. const byUser = list .filter((entry) => entry.otherUserId === params.senderId && entry.completed !== true) .sort((a, b) => resolveSummaryRecency(b) - resolveSummaryRecency(a))[0]; @@ -257,6 +291,7 @@ export function createMatrixVerificationEventRouter(params: { const stageNotice = formatVerificationStageNotice({ stage: signal.stage, senderId, event }); const summary = await resolveVerificationSummaryForSignal(params.client, { + roomId, event, senderId, flowId, diff --git a/extensions/matrix/src/matrix/send/targets.test.ts b/extensions/matrix/src/matrix/send/targets.test.ts index d53f15ed0d9..280bf2a6005 100644 --- a/extensions/matrix/src/matrix/send/targets.test.ts +++ b/extensions/matrix/src/matrix/send/targets.test.ts @@ -17,8 +17,9 @@ describe("resolveMatrixRoomId", () => { getAccountData: vi.fn().mockResolvedValue({ [userId]: ["!room:example.org"], }), + getUserId: vi.fn().mockResolvedValue("@bot:example.org"), getJoinedRooms: vi.fn(), - getJoinedRoomMembers: vi.fn(), + getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot:example.org", userId]), setAccountData: vi.fn(), } as unknown as MatrixClient; @@ -37,6 +38,7 @@ describe("resolveMatrixRoomId", () => { const setAccountData = vi.fn().mockResolvedValue(undefined); const client = { getAccountData: vi.fn().mockRejectedValue(new Error("nope")), + getUserId: vi.fn().mockResolvedValue("@bot:example.org"), getJoinedRooms: vi.fn().mockResolvedValue([roomId]), getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot:example.org", userId]), setAccountData, @@ -61,6 +63,7 @@ describe("resolveMatrixRoomId", () => { .mockResolvedValueOnce(["@bot:example.org", userId]); const client = { getAccountData: vi.fn().mockRejectedValue(new Error("nope")), + getUserId: vi.fn().mockResolvedValue("@bot:example.org"), getJoinedRooms: vi.fn().mockResolvedValue(["!bad:example.org", roomId]), getJoinedRoomMembers, setAccountData, @@ -77,6 +80,7 @@ describe("resolveMatrixRoomId", () => { const roomId = "!group:example.org"; const client = { getAccountData: vi.fn().mockRejectedValue(new Error("nope")), + getUserId: vi.fn().mockResolvedValue("@bot:example.org"), getJoinedRooms: vi.fn().mockResolvedValue([roomId]), getJoinedRoomMembers: vi .fn() @@ -98,8 +102,9 @@ describe("resolveMatrixRoomId", () => { getAccountData: vi.fn().mockResolvedValue({ [userId]: [roomId], }), + getUserId: vi.fn().mockResolvedValue("@bot:example.org"), getJoinedRooms: vi.fn(), - getJoinedRoomMembers: vi.fn(), + getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot:example.org", userId]), setAccountData: vi.fn(), resolveRoom: vi.fn(), } as unknown as MatrixClient; @@ -117,8 +122,9 @@ describe("resolveMatrixRoomId", () => { getAccountData: vi.fn().mockResolvedValue({ [userId]: ["!room-a:example.org"], }), + getUserId: vi.fn().mockResolvedValue("@bot-a:example.org"), getJoinedRooms: vi.fn(), - getJoinedRoomMembers: vi.fn(), + getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot-a:example.org", userId]), setAccountData: vi.fn(), resolveRoom: vi.fn(), } as unknown as MatrixClient; @@ -126,8 +132,9 @@ describe("resolveMatrixRoomId", () => { getAccountData: vi.fn().mockResolvedValue({ [userId]: ["!room-b:example.org"], }), + getUserId: vi.fn().mockResolvedValue("@bot-b:example.org"), getJoinedRooms: vi.fn(), - getJoinedRoomMembers: vi.fn(), + getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot-b:example.org", userId]), setAccountData: vi.fn(), resolveRoom: vi.fn(), } as unknown as MatrixClient; @@ -140,6 +147,25 @@ describe("resolveMatrixRoomId", () => { // oxlint-disable-next-line typescript/unbound-method expect(clientB.getAccountData).toHaveBeenCalledTimes(1); }); + + it("ignores m.direct entries that point at shared rooms", async () => { + const userId = "@shared:example.org"; + const client = { + getAccountData: vi.fn().mockResolvedValue({ + [userId]: ["!shared-room:example.org", "!dm-room:example.org"], + }), + getUserId: vi.fn().mockResolvedValue("@bot:example.org"), + getJoinedRooms: vi.fn(), + getJoinedRoomMembers: vi + .fn() + .mockResolvedValueOnce(["@bot:example.org", userId, "@extra:example.org"]) + .mockResolvedValueOnce(["@bot:example.org", userId]), + setAccountData: vi.fn(), + resolveRoom: vi.fn(), + } as unknown as MatrixClient; + + await expect(resolveMatrixRoomId(client, userId)).resolves.toBe("!dm-room:example.org"); + }); }); describe("normalizeThreadId", () => { diff --git a/extensions/matrix/src/matrix/send/targets.ts b/extensions/matrix/src/matrix/send/targets.ts index 66498f993e3..d7abb2c5006 100644 --- a/extensions/matrix/src/matrix/send/targets.ts +++ b/extensions/matrix/src/matrix/send/targets.ts @@ -43,6 +43,26 @@ function setDirectRoomCached(client: MatrixClient, key: string, value: string): } } +async function isStrictDirectRoom( + client: MatrixClient, + roomId: string, + remoteUserId: string, + selfUserId: string | null, +): Promise { + if (!selfUserId) { + return false; + } + let members: string[]; + try { + members = await client.getJoinedRoomMembers(roomId); + } catch { + return false; + } + return ( + members.length === 2 && members.includes(remoteUserId.trim()) && members.includes(selfUserId) + ); +} + async function persistDirectRoom( client: MatrixClient, userId: string, @@ -83,6 +103,7 @@ async function resolveDirectRoomId(client: MatrixClient, userId: string): Promis if (cached) { return cached; } + const selfUserId = (await client.getUserId().catch(() => null))?.trim() || null; // 1) Fast path: use account data (m.direct) for *this* logged-in user (the bot). try { @@ -91,9 +112,11 @@ async function resolveDirectRoomId(client: MatrixClient, userId: string): Promis string[] | undefined >; const list = Array.isArray(directContent?.[trimmed]) ? directContent[trimmed] : []; - if (list && list.length > 0) { - setDirectRoomCached(client, trimmed, list[0]); - return list[0]; + for (const roomId of list) { + if (await isStrictDirectRoom(client, roomId, trimmed, selfUserId)) { + setDirectRoomCached(client, trimmed, roomId); + return roomId; + } } } catch { // Ignore and fall back. @@ -104,16 +127,7 @@ async function resolveDirectRoomId(client: MatrixClient, userId: string): Promis try { const rooms = await client.getJoinedRooms(); for (const roomId of rooms) { - let members: string[]; - try { - members = await client.getJoinedRoomMembers(roomId); - } catch { - continue; - } - if (!members.includes(trimmed)) { - continue; - } - if (members.length === 2) { + if (await isStrictDirectRoom(client, roomId, trimmed, selfUserId)) { setDirectRoomCached(client, trimmed, roomId); await persistDirectRoom(client, trimmed, roomId); return roomId;