Matrix: harden DM and verification routing

This commit is contained in:
Gustavo Madeira Santana
2026-03-12 04:36:03 +00:00
parent 63fc3c780b
commit 3ebfd38e12
6 changed files with 199 additions and 88 deletions

View File

@@ -7,8 +7,6 @@ function createMockClient(params: {
senderDirect?: boolean;
selfDirect?: boolean;
members?: string[];
roomName?: string | null;
roomNameError?: unknown;
}) {
const members = params.members ?? ["@alice:example.org", "@bot:example.org"];
return {
@@ -21,12 +19,6 @@ function createMockClient(params: {
getRoomStateEvent: vi
.fn()
.mockImplementation(async (_roomId: string, eventType: string, stateKey: string) => {
if (eventType === "m.room.name") {
if (params.roomNameError) {
throw params.roomNameError;
}
return params.roomName == null ? {} : { name: params.roomName };
}
if (stateKey === "@alice:example.org") {
return { is_direct: params.senderDirect === true };
}
@@ -61,8 +53,13 @@ describe("createDirectRoomTracker", () => {
expect(client.getJoinedRoomMembers).toHaveBeenCalledWith("!room:example.org");
});
it("does not classify named 2-member rooms as DMs from member count alone", async () => {
const tracker = createDirectRoomTracker(createMockClient({ isDm: false, roomName: "Project" }));
it("does not classify rooms with extra members as DMs", async () => {
const tracker = createDirectRoomTracker(
createMockClient({
isDm: false,
members: ["@alice:example.org", "@bot:example.org", "@observer:example.org"],
}),
);
await expect(
tracker.isDirectMessage({
roomId: "!room:example.org",
@@ -71,11 +68,11 @@ describe("createDirectRoomTracker", () => {
).resolves.toBe(false);
});
it("treats missing room names as DM fallback for 2-member rooms", async () => {
it("does not classify 2-member rooms whose sender is not a joined member as DMs", async () => {
const tracker = createDirectRoomTracker(
createMockClient({
isDm: false,
roomNameError: { errcode: "M_NOT_FOUND" },
members: ["@mallory:example.org", "@bot:example.org"],
}),
);
await expect(
@@ -83,10 +80,10 @@ describe("createDirectRoomTracker", () => {
roomId: "!room:example.org",
senderId: "@alice:example.org",
}),
).resolves.toBe(true);
).resolves.toBe(false);
});
it("uses is_direct member flags when present", async () => {
it("still recognizes exact 2-member rooms when member state also claims is_direct", async () => {
const tracker = createDirectRoomTracker(createMockClient({ senderDirect: true }));
await expect(
tracker.isDirectMessage({
@@ -95,4 +92,19 @@ describe("createDirectRoomTracker", () => {
}),
).resolves.toBe(true);
});
it("ignores member-state is_direct when the room is not a strict DM", async () => {
const tracker = createDirectRoomTracker(
createMockClient({
senderDirect: true,
members: ["@alice:example.org", "@bot:example.org", "@observer:example.org"],
}),
);
await expect(
tracker.isDirectMessage({
roomId: "!room:example.org",
senderId: "@alice:example.org",
}),
).resolves.toBe(false);
});
});

View File

@@ -12,19 +12,11 @@ type DirectRoomTrackerOptions = {
const DM_CACHE_TTL_MS = 30_000;
function isMatrixNotFoundError(err: unknown): boolean {
if (typeof err !== "object" || err === null) {
return false;
}
const value = err as { errcode?: string; statusCode?: number };
return value.errcode === "M_NOT_FOUND" || value.statusCode === 404;
}
export function createDirectRoomTracker(client: MatrixClient, opts: DirectRoomTrackerOptions = {}) {
const log = opts.log ?? (() => {});
let lastDmUpdateMs = 0;
let cachedSelfUserId: string | null = null;
const memberCountCache = new Map<string, { count: number; ts: number }>();
const joinedMembersCache = new Map<string, { members: string[]; ts: number }>();
const ensureSelfUserId = async (): Promise<string | null> => {
if (cachedSelfUserId) {
@@ -51,36 +43,26 @@ export function createDirectRoomTracker(client: MatrixClient, opts: DirectRoomTr
}
};
const resolveMemberCount = async (roomId: string): Promise<number | null> => {
const cached = memberCountCache.get(roomId);
const resolveJoinedMembers = async (roomId: string): Promise<string[] | null> => {
const cached = joinedMembersCache.get(roomId);
const now = Date.now();
if (cached && now - cached.ts < DM_CACHE_TTL_MS) {
return cached.count;
return cached.members;
}
try {
const members = await client.getJoinedRoomMembers(roomId);
const count = members.length;
memberCountCache.set(roomId, { count, ts: now });
return count;
const normalized = members
.filter((entry): entry is string => typeof entry === "string")
.map((entry) => entry.trim())
.filter(Boolean);
joinedMembersCache.set(roomId, { members: normalized, ts: now });
return normalized;
} catch (err) {
log(`matrix: dm member count failed room=${roomId} (${String(err)})`);
log(`matrix: dm member lookup failed room=${roomId} (${String(err)})`);
return null;
}
};
const hasDirectFlag = async (roomId: string, userId?: string): Promise<boolean> => {
const target = userId?.trim();
if (!target) {
return false;
}
try {
const state = await client.getRoomStateEvent(roomId, "m.room.member", target);
return state?.is_direct === true;
} catch {
return false;
}
};
return {
isDirectMessage: async (params: DirectMessageCheck): Promise<boolean> => {
const { roomId, senderId } = params;
@@ -92,35 +74,22 @@ export function createDirectRoomTracker(client: MatrixClient, opts: DirectRoomTr
}
const selfUserId = params.selfUserId ?? (await ensureSelfUserId());
const directViaState =
(await hasDirectFlag(roomId, senderId)) || (await hasDirectFlag(roomId, selfUserId ?? ""));
if (directViaState) {
log(`matrix: dm detected via member state room=${roomId}`);
const joinedMembers = await resolveJoinedMembers(roomId);
const normalizedSenderId = senderId?.trim();
if (
selfUserId &&
normalizedSenderId &&
joinedMembers?.length === 2 &&
joinedMembers.includes(selfUserId) &&
joinedMembers.includes(normalizedSenderId)
) {
log(`matrix: dm detected via exact 2-member room room=${roomId}`);
return true;
}
const memberCount = await resolveMemberCount(roomId);
if (memberCount === 2) {
try {
const nameState = (await client.getRoomStateEvent(roomId, "m.room.name", "")) as {
name?: string | null;
} | null;
if (!nameState?.name?.trim()) {
log(`matrix: dm detected via fallback (2 members, no room name) room=${roomId}`);
return true;
}
} catch (err: unknown) {
if (isMatrixNotFoundError(err)) {
log(`matrix: dm detected via fallback (2 members, no room name) room=${roomId}`);
return true;
}
log(
`matrix: dm fallback skipped (room name check failed: ${String(err)}) room=${roomId}`,
);
}
}
log(`matrix: dm check room=${roomId} result=group members=${memberCount ?? "unknown"}`);
log(
`matrix: dm check room=${roomId} result=group members=${joinedMembers?.length ?? "unknown"}`,
);
return false;
},
};

View File

@@ -19,6 +19,8 @@ function createHarness(params?: {
accountId?: string;
authEncryption?: boolean;
cryptoAvailable?: boolean;
selfUserId?: string;
joinedMembersByRoom?: Record<string, string[]>;
verifications?: Array<{
id: string;
transactionId?: string;
@@ -43,6 +45,10 @@ function createHarness(params?: {
return client;
}),
sendMessage,
getUserId: vi.fn(async () => params?.selfUserId ?? "@bot:example.org"),
getJoinedRoomMembers: vi.fn(
async (roomId: string) => params?.joinedMembersByRoom?.[roomId] ?? [],
),
...(params?.cryptoAvailable === false
? {}
: {
@@ -157,6 +163,9 @@ describe("registerMatrixMonitorEvents verification routing", () => {
it("posts SAS emoji/decimal details when verification summaries expose them", async () => {
const { sendMessage, roomEventListener, listVerifications } = createHarness({
joinedMembersByRoom: {
"!dm:example.org": ["@alice:example.org", "@bot:example.org"],
},
verifications: [
{
id: "verification-1",
@@ -175,7 +184,7 @@ describe("registerMatrixMonitorEvents verification routing", () => {
],
});
roomEventListener("!room:example.org", {
roomEventListener("!dm:example.org", {
event_id: "$start2",
sender: "@alice:example.org",
type: "m.key.verification.start",
@@ -194,6 +203,52 @@ describe("registerMatrixMonitorEvents verification routing", () => {
});
});
it("does not leak SAS details into unrelated non-DM rooms when flow ids do not match", async () => {
const { sendMessage, roomEventListener } = createHarness({
joinedMembersByRoom: {
"!group:example.org": ["@alice:example.org", "@bot:example.org", "@ops:example.org"],
},
verifications: [
{
id: "verification-2",
transactionId: "$different-flow-id",
otherUserId: "@alice:example.org",
updatedAt: new Date("2026-02-25T21:42:54.000Z").toISOString(),
sas: {
decimal: [6158, 1986, 3513],
emoji: [
["🎁", "Gift"],
["🌍", "Globe"],
["🐴", "Horse"],
],
},
},
],
});
roomEventListener("!group:example.org", {
event_id: "$start-group",
sender: "@alice:example.org",
type: "m.key.verification.start",
origin_server_ts: Date.now(),
content: {
"m.relates_to": { event_id: "$req-group" },
},
});
await vi.waitFor(() => {
expect(sendMessage).toHaveBeenCalledTimes(1);
});
expect(getSentNoticeBody(sendMessage, 0)).toContain(
"Matrix verification started with @alice:example.org.",
);
expect(
(sendMessage.mock.calls as unknown[][]).some((call) =>
String((call[1] as { body?: string } | undefined)?.body ?? "").includes("SAS emoji:"),
),
).toBe(false);
});
it("does not emit duplicate SAS notices for the same verification payload", async () => {
const { sendMessage, roomEventListener, listVerifications } = createHarness({
verifications: [

View File

@@ -161,9 +161,30 @@ function resolveSummaryRecency(summary: MatrixVerificationSummaryLike): number {
return Number.isFinite(ts) ? ts : 0;
}
async function isStrictDirectVerificationRoom(params: {
client: MatrixClient;
roomId: string;
senderId: string;
}): Promise<boolean> {
const selfUserId = trimMaybeString(await params.client.getUserId().catch(() => null));
if (!selfUserId) {
return false;
}
const joinedMembers = await params.client.getJoinedRoomMembers(params.roomId).catch(() => null);
if (!Array.isArray(joinedMembers) || joinedMembers.length !== 2) {
return false;
}
const normalizedMembers = joinedMembers
.filter((entry): entry is string => typeof entry === "string")
.map((entry) => entry.trim())
.filter(Boolean);
return normalizedMembers.includes(selfUserId) && normalizedMembers.includes(params.senderId);
}
async function resolveVerificationSummaryForSignal(
client: MatrixClient,
params: {
roomId: string;
event: MatrixRawEvent;
senderId: string;
flowId: string | null;
@@ -187,7 +208,20 @@ async function resolveVerificationSummaryForSignal(
return byTransactionId;
}
// Fallback for flows where transaction IDs do not match room event IDs consistently.
// Only fall back by user inside the active DM with that user. Otherwise a
// spoofed verification event in an unrelated room can leak the current SAS
// prompt into that room.
if (
!(await isStrictDirectVerificationRoom({
client,
roomId: params.roomId,
senderId: params.senderId,
}))
) {
return null;
}
// Fallback for DM flows where transaction IDs do not match room event IDs consistently.
const byUser = list
.filter((entry) => entry.otherUserId === params.senderId && entry.completed !== true)
.sort((a, b) => resolveSummaryRecency(b) - resolveSummaryRecency(a))[0];
@@ -257,6 +291,7 @@ export function createMatrixVerificationEventRouter(params: {
const stageNotice = formatVerificationStageNotice({ stage: signal.stage, senderId, event });
const summary = await resolveVerificationSummaryForSignal(params.client, {
roomId,
event,
senderId,
flowId,

View File

@@ -17,8 +17,9 @@ describe("resolveMatrixRoomId", () => {
getAccountData: vi.fn().mockResolvedValue({
[userId]: ["!room:example.org"],
}),
getUserId: vi.fn().mockResolvedValue("@bot:example.org"),
getJoinedRooms: vi.fn(),
getJoinedRoomMembers: vi.fn(),
getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot:example.org", userId]),
setAccountData: vi.fn(),
} as unknown as MatrixClient;
@@ -37,6 +38,7 @@ describe("resolveMatrixRoomId", () => {
const setAccountData = vi.fn().mockResolvedValue(undefined);
const client = {
getAccountData: vi.fn().mockRejectedValue(new Error("nope")),
getUserId: vi.fn().mockResolvedValue("@bot:example.org"),
getJoinedRooms: vi.fn().mockResolvedValue([roomId]),
getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot:example.org", userId]),
setAccountData,
@@ -61,6 +63,7 @@ describe("resolveMatrixRoomId", () => {
.mockResolvedValueOnce(["@bot:example.org", userId]);
const client = {
getAccountData: vi.fn().mockRejectedValue(new Error("nope")),
getUserId: vi.fn().mockResolvedValue("@bot:example.org"),
getJoinedRooms: vi.fn().mockResolvedValue(["!bad:example.org", roomId]),
getJoinedRoomMembers,
setAccountData,
@@ -77,6 +80,7 @@ describe("resolveMatrixRoomId", () => {
const roomId = "!group:example.org";
const client = {
getAccountData: vi.fn().mockRejectedValue(new Error("nope")),
getUserId: vi.fn().mockResolvedValue("@bot:example.org"),
getJoinedRooms: vi.fn().mockResolvedValue([roomId]),
getJoinedRoomMembers: vi
.fn()
@@ -98,8 +102,9 @@ describe("resolveMatrixRoomId", () => {
getAccountData: vi.fn().mockResolvedValue({
[userId]: [roomId],
}),
getUserId: vi.fn().mockResolvedValue("@bot:example.org"),
getJoinedRooms: vi.fn(),
getJoinedRoomMembers: vi.fn(),
getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot:example.org", userId]),
setAccountData: vi.fn(),
resolveRoom: vi.fn(),
} as unknown as MatrixClient;
@@ -117,8 +122,9 @@ describe("resolveMatrixRoomId", () => {
getAccountData: vi.fn().mockResolvedValue({
[userId]: ["!room-a:example.org"],
}),
getUserId: vi.fn().mockResolvedValue("@bot-a:example.org"),
getJoinedRooms: vi.fn(),
getJoinedRoomMembers: vi.fn(),
getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot-a:example.org", userId]),
setAccountData: vi.fn(),
resolveRoom: vi.fn(),
} as unknown as MatrixClient;
@@ -126,8 +132,9 @@ describe("resolveMatrixRoomId", () => {
getAccountData: vi.fn().mockResolvedValue({
[userId]: ["!room-b:example.org"],
}),
getUserId: vi.fn().mockResolvedValue("@bot-b:example.org"),
getJoinedRooms: vi.fn(),
getJoinedRoomMembers: vi.fn(),
getJoinedRoomMembers: vi.fn().mockResolvedValue(["@bot-b:example.org", userId]),
setAccountData: vi.fn(),
resolveRoom: vi.fn(),
} as unknown as MatrixClient;
@@ -140,6 +147,25 @@ describe("resolveMatrixRoomId", () => {
// oxlint-disable-next-line typescript/unbound-method
expect(clientB.getAccountData).toHaveBeenCalledTimes(1);
});
it("ignores m.direct entries that point at shared rooms", async () => {
const userId = "@shared:example.org";
const client = {
getAccountData: vi.fn().mockResolvedValue({
[userId]: ["!shared-room:example.org", "!dm-room:example.org"],
}),
getUserId: vi.fn().mockResolvedValue("@bot:example.org"),
getJoinedRooms: vi.fn(),
getJoinedRoomMembers: vi
.fn()
.mockResolvedValueOnce(["@bot:example.org", userId, "@extra:example.org"])
.mockResolvedValueOnce(["@bot:example.org", userId]),
setAccountData: vi.fn(),
resolveRoom: vi.fn(),
} as unknown as MatrixClient;
await expect(resolveMatrixRoomId(client, userId)).resolves.toBe("!dm-room:example.org");
});
});
describe("normalizeThreadId", () => {

View File

@@ -43,6 +43,26 @@ function setDirectRoomCached(client: MatrixClient, key: string, value: string):
}
}
async function isStrictDirectRoom(
client: MatrixClient,
roomId: string,
remoteUserId: string,
selfUserId: string | null,
): Promise<boolean> {
if (!selfUserId) {
return false;
}
let members: string[];
try {
members = await client.getJoinedRoomMembers(roomId);
} catch {
return false;
}
return (
members.length === 2 && members.includes(remoteUserId.trim()) && members.includes(selfUserId)
);
}
async function persistDirectRoom(
client: MatrixClient,
userId: string,
@@ -83,6 +103,7 @@ async function resolveDirectRoomId(client: MatrixClient, userId: string): Promis
if (cached) {
return cached;
}
const selfUserId = (await client.getUserId().catch(() => null))?.trim() || null;
// 1) Fast path: use account data (m.direct) for *this* logged-in user (the bot).
try {
@@ -91,9 +112,11 @@ async function resolveDirectRoomId(client: MatrixClient, userId: string): Promis
string[] | undefined
>;
const list = Array.isArray(directContent?.[trimmed]) ? directContent[trimmed] : [];
if (list && list.length > 0) {
setDirectRoomCached(client, trimmed, list[0]);
return list[0];
for (const roomId of list) {
if (await isStrictDirectRoom(client, roomId, trimmed, selfUserId)) {
setDirectRoomCached(client, trimmed, roomId);
return roomId;
}
}
} catch {
// Ignore and fall back.
@@ -104,16 +127,7 @@ async function resolveDirectRoomId(client: MatrixClient, userId: string): Promis
try {
const rooms = await client.getJoinedRooms();
for (const roomId of rooms) {
let members: string[];
try {
members = await client.getJoinedRoomMembers(roomId);
} catch {
continue;
}
if (!members.includes(trimmed)) {
continue;
}
if (members.length === 2) {
if (await isStrictDirectRoom(client, roomId, trimmed, selfUserId)) {
setDirectRoomCached(client, trimmed, roomId);
await persistDirectRoom(client, trimmed, roomId);
return roomId;