diff --git a/extensions/matrix/src/matrix/actions/client.ts b/extensions/matrix/src/matrix/actions/client.ts index cf11c494b8d..ae2a589cc88 100644 --- a/extensions/matrix/src/matrix/actions/client.ts +++ b/extensions/matrix/src/matrix/actions/client.ts @@ -58,3 +58,10 @@ export async function withResolvedActionClient( await stopActionClient(resolved, mode); } } + +export async function withStartedActionClient( + opts: MatrixActionClientOpts, + run: (client: MatrixActionClient["client"]) => Promise, +): Promise { + return await withResolvedActionClient({ ...opts, readiness: "started" }, run, "persist"); +} diff --git a/extensions/matrix/src/matrix/actions/devices.test.ts b/extensions/matrix/src/matrix/actions/devices.test.ts index 4a90d5e1925..727719a17d2 100644 --- a/extensions/matrix/src/matrix/actions/devices.test.ts +++ b/extensions/matrix/src/matrix/actions/devices.test.ts @@ -1,9 +1,9 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -const withResolvedActionClientMock = vi.fn(); +const withStartedActionClientMock = vi.fn(); vi.mock("./client.js", () => ({ - withResolvedActionClient: (...args: unknown[]) => withResolvedActionClientMock(...args), + withStartedActionClient: (...args: unknown[]) => withStartedActionClientMock(...args), })); let listMatrixOwnDevices: typeof import("./devices.js").listMatrixOwnDevices; @@ -17,7 +17,7 @@ describe("matrix device actions", () => { }); it("lists own devices on a started client", async () => { - withResolvedActionClientMock.mockImplementation(async (_opts, run) => { + withStartedActionClientMock.mockImplementation(async (_opts, run) => { return await run({ listOwnDevices: vi.fn(async () => [ { @@ -33,10 +33,9 @@ describe("matrix device actions", () => { const result = await listMatrixOwnDevices({ accountId: "poe" }); - expect(withResolvedActionClientMock).toHaveBeenCalledWith( - { accountId: "poe", readiness: "started" }, + expect(withStartedActionClientMock).toHaveBeenCalledWith( + { accountId: "poe" }, expect.any(Function), - "persist", ); expect(result).toEqual([ expect.objectContaining({ @@ -60,7 +59,7 @@ describe("matrix device actions", () => { }, ], })); - withResolvedActionClientMock.mockImplementation(async (_opts, run) => { + withStartedActionClientMock.mockImplementation(async (_opts, run) => { return await run({ listOwnDevices: vi.fn(async () => [ { diff --git a/extensions/matrix/src/matrix/actions/devices.ts b/extensions/matrix/src/matrix/actions/devices.ts index 452f4e72d7e..ab6769cbfb8 100644 --- a/extensions/matrix/src/matrix/actions/devices.ts +++ b/extensions/matrix/src/matrix/actions/devices.ts @@ -1,44 +1,34 @@ import { summarizeMatrixDeviceHealth } from "../device-health.js"; -import { withResolvedActionClient } from "./client.js"; +import { withStartedActionClient } from "./client.js"; import type { MatrixActionClientOpts } from "./types.js"; export async function listMatrixOwnDevices(opts: MatrixActionClientOpts = {}) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => await client.listOwnDevices(), - "persist", - ); + return await withStartedActionClient(opts, async (client) => await client.listOwnDevices()); } export async function pruneMatrixStaleGatewayDevices(opts: MatrixActionClientOpts = {}) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const devices = await client.listOwnDevices(); - const health = summarizeMatrixDeviceHealth(devices); - const staleGatewayDeviceIds = health.staleOpenClawDevices.map((device) => device.deviceId); - const deleted = - staleGatewayDeviceIds.length > 0 - ? await client.deleteOwnDevices(staleGatewayDeviceIds) - : { - currentDeviceId: devices.find((device) => device.current)?.deviceId ?? null, - deletedDeviceIds: [] as string[], - remainingDevices: devices, - }; - return { - before: devices, - staleGatewayDeviceIds, - ...deleted, - }; - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const devices = await client.listOwnDevices(); + const health = summarizeMatrixDeviceHealth(devices); + const staleGatewayDeviceIds = health.staleOpenClawDevices.map((device) => device.deviceId); + const deleted = + staleGatewayDeviceIds.length > 0 + ? await client.deleteOwnDevices(staleGatewayDeviceIds) + : { + currentDeviceId: devices.find((device) => device.current)?.deviceId ?? null, + deletedDeviceIds: [] as string[], + remainingDevices: devices, + }; + return { + before: devices, + staleGatewayDeviceIds, + ...deleted, + }; + }); } export async function getMatrixDeviceHealth(opts: MatrixActionClientOpts = {}) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => summarizeMatrixDeviceHealth(await client.listOwnDevices()), - "persist", + return await withStartedActionClient(opts, async (client) => + summarizeMatrixDeviceHealth(await client.listOwnDevices()), ); } diff --git a/extensions/matrix/src/matrix/actions/verification.test.ts b/extensions/matrix/src/matrix/actions/verification.test.ts index 2d1eb954cb1..668d9fe51fe 100644 --- a/extensions/matrix/src/matrix/actions/verification.test.ts +++ b/extensions/matrix/src/matrix/actions/verification.test.ts @@ -1,6 +1,6 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -const withResolvedActionClientMock = vi.fn(); +const withStartedActionClientMock = vi.fn(); const loadConfigMock = vi.fn(() => ({ channels: { matrix: {}, @@ -16,7 +16,7 @@ vi.mock("../../runtime.js", () => ({ })); vi.mock("./client.js", () => ({ - withResolvedActionClient: (...args: unknown[]) => withResolvedActionClientMock(...args), + withStartedActionClient: (...args: unknown[]) => withStartedActionClientMock(...args), })); let listMatrixVerifications: typeof import("./verification.js").listMatrixVerifications; @@ -45,7 +45,7 @@ describe("matrix verification actions", () => { }, }, }); - withResolvedActionClientMock.mockImplementation(async (_opts, run) => { + withStartedActionClientMock.mockImplementation(async (_opts, run) => { return await run({ crypto: null }); }); @@ -67,7 +67,7 @@ describe("matrix verification actions", () => { }, }, }); - withResolvedActionClientMock.mockImplementation(async (_opts, run) => { + withStartedActionClientMock.mockImplementation(async (_opts, run) => { return await run({ crypto: null }); }); diff --git a/extensions/matrix/src/matrix/actions/verification.ts b/extensions/matrix/src/matrix/actions/verification.ts index adad96e954d..f017d495930 100644 --- a/extensions/matrix/src/matrix/actions/verification.ts +++ b/extensions/matrix/src/matrix/actions/verification.ts @@ -1,7 +1,7 @@ import { getMatrixRuntime } from "../../runtime.js"; import type { CoreConfig } from "../../types.js"; import { formatMatrixEncryptionUnavailableError } from "../encryption-guidance.js"; -import { withResolvedActionClient } from "./client.js"; +import { withStartedActionClient } from "./client.js"; import type { MatrixActionClientOpts } from "./types.js"; function requireCrypto( @@ -24,14 +24,10 @@ function resolveVerificationId(input: string): string { } export async function listMatrixVerifications(opts: MatrixActionClientOpts = {}) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.listVerifications(); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.listVerifications(); + }); } export async function requestMatrixVerification( @@ -42,79 +38,59 @@ export async function requestMatrixVerification( roomId?: string; } = {}, ) { - return await withResolvedActionClient( - { ...params, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, params); - const ownUser = params.ownUser ?? (!params.userId && !params.deviceId && !params.roomId); - return await crypto.requestVerification({ - ownUser, - userId: params.userId?.trim() || undefined, - deviceId: params.deviceId?.trim() || undefined, - roomId: params.roomId?.trim() || undefined, - }); - }, - "persist", - ); + return await withStartedActionClient(params, async (client) => { + const crypto = requireCrypto(client, params); + const ownUser = params.ownUser ?? (!params.userId && !params.deviceId && !params.roomId); + return await crypto.requestVerification({ + ownUser, + userId: params.userId?.trim() || undefined, + deviceId: params.deviceId?.trim() || undefined, + roomId: params.roomId?.trim() || undefined, + }); + }); } export async function acceptMatrixVerification( requestId: string, opts: MatrixActionClientOpts = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.acceptVerification(resolveVerificationId(requestId)); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.acceptVerification(resolveVerificationId(requestId)); + }); } export async function cancelMatrixVerification( requestId: string, opts: MatrixActionClientOpts & { reason?: string; code?: string } = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.cancelVerification(resolveVerificationId(requestId), { - reason: opts.reason?.trim() || undefined, - code: opts.code?.trim() || undefined, - }); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.cancelVerification(resolveVerificationId(requestId), { + reason: opts.reason?.trim() || undefined, + code: opts.code?.trim() || undefined, + }); + }); } export async function startMatrixVerification( requestId: string, opts: MatrixActionClientOpts & { method?: "sas" } = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.startVerification(resolveVerificationId(requestId), opts.method ?? "sas"); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.startVerification(resolveVerificationId(requestId), opts.method ?? "sas"); + }); } export async function generateMatrixVerificationQr( requestId: string, opts: MatrixActionClientOpts = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.generateVerificationQr(resolveVerificationId(requestId)); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.generateVerificationQr(resolveVerificationId(requestId)); + }); } export async function scanMatrixVerificationQr( @@ -122,125 +98,96 @@ export async function scanMatrixVerificationQr( qrDataBase64: string, opts: MatrixActionClientOpts = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - const payload = qrDataBase64.trim(); - if (!payload) { - throw new Error("Matrix QR data is required"); - } - return await crypto.scanVerificationQr(resolveVerificationId(requestId), payload); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + const payload = qrDataBase64.trim(); + if (!payload) { + throw new Error("Matrix QR data is required"); + } + return await crypto.scanVerificationQr(resolveVerificationId(requestId), payload); + }); } export async function getMatrixVerificationSas( requestId: string, opts: MatrixActionClientOpts = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.getVerificationSas(resolveVerificationId(requestId)); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.getVerificationSas(resolveVerificationId(requestId)); + }); } export async function confirmMatrixVerificationSas( requestId: string, opts: MatrixActionClientOpts = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.confirmVerificationSas(resolveVerificationId(requestId)); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.confirmVerificationSas(resolveVerificationId(requestId)); + }); } export async function mismatchMatrixVerificationSas( requestId: string, opts: MatrixActionClientOpts = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.mismatchVerificationSas(resolveVerificationId(requestId)); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.mismatchVerificationSas(resolveVerificationId(requestId)); + }); } export async function confirmMatrixVerificationReciprocateQr( requestId: string, opts: MatrixActionClientOpts = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - return await crypto.confirmVerificationReciprocateQr(resolveVerificationId(requestId)); - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + return await crypto.confirmVerificationReciprocateQr(resolveVerificationId(requestId)); + }); } export async function getMatrixEncryptionStatus( opts: MatrixActionClientOpts & { includeRecoveryKey?: boolean } = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const crypto = requireCrypto(client, opts); - const recoveryKey = await crypto.getRecoveryKey(); - return { - encryptionEnabled: true, - recoveryKeyStored: Boolean(recoveryKey), - recoveryKeyCreatedAt: recoveryKey?.createdAt ?? null, - ...(opts.includeRecoveryKey ? { recoveryKey: recoveryKey?.encodedPrivateKey ?? null } : {}), - pendingVerifications: (await crypto.listVerifications()).length, - }; - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const crypto = requireCrypto(client, opts); + const recoveryKey = await crypto.getRecoveryKey(); + return { + encryptionEnabled: true, + recoveryKeyStored: Boolean(recoveryKey), + recoveryKeyCreatedAt: recoveryKey?.createdAt ?? null, + ...(opts.includeRecoveryKey ? { recoveryKey: recoveryKey?.encodedPrivateKey ?? null } : {}), + pendingVerifications: (await crypto.listVerifications()).length, + }; + }); } export async function getMatrixVerificationStatus( opts: MatrixActionClientOpts & { includeRecoveryKey?: boolean } = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, - async (client) => { - const status = await client.getOwnDeviceVerificationStatus(); - const payload = { - ...status, - pendingVerifications: client.crypto ? (await client.crypto.listVerifications()).length : 0, - }; - if (!opts.includeRecoveryKey) { - return payload; - } - const recoveryKey = client.crypto ? await client.crypto.getRecoveryKey() : null; - return { - ...payload, - recoveryKey: recoveryKey?.encodedPrivateKey ?? null, - }; - }, - "persist", - ); + return await withStartedActionClient(opts, async (client) => { + const status = await client.getOwnDeviceVerificationStatus(); + const payload = { + ...status, + pendingVerifications: client.crypto ? (await client.crypto.listVerifications()).length : 0, + }; + if (!opts.includeRecoveryKey) { + return payload; + } + const recoveryKey = client.crypto ? await client.crypto.getRecoveryKey() : null; + return { + ...payload, + recoveryKey: recoveryKey?.encodedPrivateKey ?? null, + }; + }); } export async function getMatrixRoomKeyBackupStatus(opts: MatrixActionClientOpts = {}) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, + return await withStartedActionClient( + opts, async (client) => await client.getRoomKeyBackupStatus(), - "persist", ); } @@ -248,10 +195,9 @@ export async function verifyMatrixRecoveryKey( recoveryKey: string, opts: MatrixActionClientOpts = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, + return await withStartedActionClient( + opts, async (client) => await client.verifyWithRecoveryKey(recoveryKey), - "persist", ); } @@ -260,13 +206,12 @@ export async function restoreMatrixRoomKeyBackup( recoveryKey?: string; } = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, + return await withStartedActionClient( + opts, async (client) => await client.restoreRoomKeyBackup({ recoveryKey: opts.recoveryKey?.trim() || undefined, }), - "persist", ); } @@ -276,13 +221,12 @@ export async function bootstrapMatrixVerification( forceResetCrossSigning?: boolean; } = {}, ) { - return await withResolvedActionClient( - { ...opts, readiness: "started" }, + return await withStartedActionClient( + opts, async (client) => await client.bootstrapOwnDeviceVerification({ recoveryKey: opts.recoveryKey?.trim() || undefined, forceResetCrossSigning: opts.forceResetCrossSigning === true, }), - "persist", ); }