diff --git a/extensions/alibaba/video-generation-provider.test.ts b/extensions/alibaba/video-generation-provider.test.ts index 56d1bc36002..8f77c1fd03e 100644 --- a/extensions/alibaba/video-generation-provider.test.ts +++ b/extensions/alibaba/video-generation-provider.test.ts @@ -1,71 +1,27 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildAlibabaVideoGenerationProvider } from "./video-generation-provider.js"; +import { beforeAll, describe, expect, it } from "vitest"; +import { + expectDashscopeVideoTaskPoll, + expectSuccessfulDashscopeVideoResult, + mockSuccessfulDashscopeVideoTask, +} from "../../test/helpers/media-generation/dashscope-video-provider.js"; +import { + getProviderHttpMocks, + installProviderHttpMockCleanup, +} from "../../test/helpers/media-generation/provider-http-mocks.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "alibaba-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock } = getProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildAlibabaVideoGenerationProvider: typeof import("./video-generation-provider.js").buildAlibabaVideoGenerationProvider; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildAlibabaVideoGenerationProvider } = await import("./video-generation-provider.js")); +}); + +installProviderHttpMockCleanup(); describe("alibaba video generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("submits async Wan generation, polls task status, and downloads the resulting video", async () => { - postJsonRequestMock.mockResolvedValue({ - response: { - json: async () => ({ - request_id: "req-1", - output: { - task_id: "task-1", - }, - }), - }, - release: vi.fn(async () => {}), - }); - fetchWithTimeoutMock - .mockResolvedValueOnce({ - json: async () => ({ - output: { - task_status: "SUCCEEDED", - results: [{ video_url: "https://example.com/out.mp4" }], - }, - }), - headers: new Headers(), - }) - .mockResolvedValueOnce({ - arrayBuffer: async () => Buffer.from("mp4-bytes"), - headers: new Headers({ "content-type": "video/mp4" }), - }); + mockSuccessfulDashscopeVideoTask({ postJsonRequestMock, fetchWithTimeoutMock }); const provider = buildAlibabaVideoGenerationProvider(); const result = await provider.generateVideo({ @@ -96,22 +52,8 @@ describe("alibaba video generation provider", () => { }), }), ); - expect(fetchWithTimeoutMock).toHaveBeenNthCalledWith( - 1, - "https://dashscope-intl.aliyuncs.com/api/v1/tasks/task-1", - expect.objectContaining({ method: "GET" }), - 120000, - fetch, - ); - expect(result.videos).toHaveLength(1); - expect(result.videos[0]?.mimeType).toBe("video/mp4"); - expect(result.metadata).toEqual( - expect.objectContaining({ - requestId: "req-1", - taskId: "task-1", - taskStatus: "SUCCEEDED", - }), - ); + expectDashscopeVideoTaskPoll(fetchWithTimeoutMock); + expectSuccessfulDashscopeVideoResult(result); }); it("fails fast when reference inputs are local buffers instead of remote URLs", async () => { diff --git a/extensions/alibaba/video-generation-provider.ts b/extensions/alibaba/video-generation-provider.ts index 4f56a753729..c543297152f 100644 --- a/extensions/alibaba/video-generation-provider.ts +++ b/extensions/alibaba/video-generation-provider.ts @@ -1,29 +1,21 @@ import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth"; import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; +import { resolveProviderHttpRequestConfig } from "openclaw/plugin-sdk/provider-http"; import { - assertOkOrThrowHttpError, - postJsonRequest, - resolveProviderHttpRequestConfig, -} from "openclaw/plugin-sdk/provider-http"; -import { - DEFAULT_VIDEO_GENERATION_DURATION_SECONDS, + DASHSCOPE_WAN_VIDEO_CAPABILITIES, + DASHSCOPE_WAN_VIDEO_MODELS, + DEFAULT_DASHSCOPE_WAN_VIDEO_MODEL, DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, - DEFAULT_VIDEO_RESOLUTION_TO_SIZE, - buildDashscopeVideoGenerationInput, - buildDashscopeVideoGenerationParameters, - downloadDashscopeGeneratedVideos, - extractDashscopeVideoUrls, - pollDashscopeVideoTaskUntilComplete, + runDashscopeVideoGenerationTask, } from "openclaw/plugin-sdk/video-generation"; import type { - DashscopeVideoGenerationResponse, VideoGenerationProvider, VideoGenerationRequest, VideoGenerationResult, } from "openclaw/plugin-sdk/video-generation"; const DEFAULT_ALIBABA_VIDEO_BASE_URL = "https://dashscope-intl.aliyuncs.com"; -const DEFAULT_ALIBABA_VIDEO_MODEL = "wan2.6-t2v"; +const DEFAULT_ALIBABA_VIDEO_MODEL = DEFAULT_DASHSCOPE_WAN_VIDEO_MODEL; function resolveAlibabaVideoBaseUrl(req: VideoGenerationRequest): string { return req.cfg?.models?.providers?.alibaba?.baseUrl?.trim() || DEFAULT_ALIBABA_VIDEO_BASE_URL; @@ -38,45 +30,13 @@ export function buildAlibabaVideoGenerationProvider(): VideoGenerationProvider { id: "alibaba", label: "Alibaba Model Studio", defaultModel: DEFAULT_ALIBABA_VIDEO_MODEL, - models: ["wan2.6-t2v", "wan2.6-i2v", "wan2.6-r2v", "wan2.6-r2v-flash", "wan2.7-r2v"], + models: [...DASHSCOPE_WAN_VIDEO_MODELS], isConfigured: ({ agentDir }) => isProviderApiKeyConfigured({ provider: "alibaba", agentDir, }), - capabilities: { - generate: { - maxVideos: 1, - maxDurationSeconds: 10, - supportsSize: true, - supportsAspectRatio: true, - supportsResolution: true, - supportsAudio: true, - supportsWatermark: true, - }, - imageToVideo: { - enabled: true, - maxVideos: 1, - maxInputImages: 1, - maxDurationSeconds: 10, - supportsSize: true, - supportsAspectRatio: true, - supportsResolution: true, - supportsAudio: true, - supportsWatermark: true, - }, - videoToVideo: { - enabled: true, - maxVideos: 1, - maxInputVideos: 4, - maxDurationSeconds: 10, - supportsSize: true, - supportsAspectRatio: true, - supportsResolution: true, - supportsAudio: true, - supportsWatermark: true, - }, - }, + capabilities: DASHSCOPE_WAN_VIDEO_CAPABILITIES, async generateVideo(req): Promise { const fetchFn = fetch; const auth = await resolveApiKeyForProvider({ @@ -105,68 +65,19 @@ export function buildAlibabaVideoGenerationProvider(): VideoGenerationProvider { }); const model = req.model?.trim() || DEFAULT_ALIBABA_VIDEO_MODEL; - const { response, release } = await postJsonRequest({ + return await runDashscopeVideoGenerationTask({ + providerLabel: "Alibaba Wan", + model, + req, url: `${resolveDashscopeAigcApiBaseUrl(baseUrl)}/api/v1/services/aigc/video-generation/video-synthesis`, headers, - body: { - model, - input: buildDashscopeVideoGenerationInput({ - providerLabel: "Alibaba Wan", - req, - }), - parameters: buildDashscopeVideoGenerationParameters( - { - ...req, - durationSeconds: req.durationSeconds ?? DEFAULT_VIDEO_GENERATION_DURATION_SECONDS, - }, - DEFAULT_VIDEO_RESOLUTION_TO_SIZE, - ), - }, + baseUrl: resolveDashscopeAigcApiBaseUrl(baseUrl), timeoutMs: req.timeoutMs, fetchFn, allowPrivateNetwork, dispatcherPolicy, + defaultTimeoutMs: DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, }); - - try { - await assertOkOrThrowHttpError(response, "Alibaba Wan video generation failed"); - const submitted = (await response.json()) as DashscopeVideoGenerationResponse; - const taskId = submitted.output?.task_id?.trim(); - if (!taskId) { - throw new Error("Alibaba Wan video generation response missing task_id"); - } - const completed = await pollDashscopeVideoTaskUntilComplete({ - providerLabel: "Alibaba Wan", - taskId, - headers, - timeoutMs: req.timeoutMs, - fetchFn, - baseUrl: resolveDashscopeAigcApiBaseUrl(baseUrl), - defaultTimeoutMs: DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, - }); - const urls = extractDashscopeVideoUrls(completed); - if (urls.length === 0) { - throw new Error("Alibaba Wan video generation completed without output video URLs"); - } - const videos = await downloadDashscopeGeneratedVideos({ - providerLabel: "Alibaba Wan", - urls, - timeoutMs: req.timeoutMs, - fetchFn, - defaultTimeoutMs: DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, - }); - return { - videos, - model, - metadata: { - requestId: submitted.request_id, - taskId, - taskStatus: completed.output?.task_status, - }, - }; - } finally { - await release(); - } }, }; } diff --git a/extensions/bluebubbles/src/accounts-normalization.ts b/extensions/bluebubbles/src/accounts-normalization.ts new file mode 100644 index 00000000000..bd57b1c824b --- /dev/null +++ b/extensions/bluebubbles/src/accounts-normalization.ts @@ -0,0 +1,107 @@ +import { isBlockedHostnameOrIp } from "openclaw/plugin-sdk/ssrf-runtime"; +import { normalizeBlueBubblesServerUrl } from "./types.js"; + +function asRecord(value: unknown): Record | null { + return value && typeof value === "object" && !Array.isArray(value) + ? (value as Record) + : null; +} + +export function normalizeBlueBubblesPrivateNetworkAliases( + config: T, +): T { + const record = asRecord(config); + if (!record) { + return config; + } + const network = asRecord(record.network); + const canonicalValue = + typeof network?.dangerouslyAllowPrivateNetwork === "boolean" + ? network.dangerouslyAllowPrivateNetwork + : typeof network?.allowPrivateNetwork === "boolean" + ? network.allowPrivateNetwork + : typeof record.dangerouslyAllowPrivateNetwork === "boolean" + ? record.dangerouslyAllowPrivateNetwork + : typeof record.allowPrivateNetwork === "boolean" + ? record.allowPrivateNetwork + : undefined; + + if (canonicalValue === undefined) { + return config; + } + + const { + allowPrivateNetwork: _legacyFlatAllow, + dangerouslyAllowPrivateNetwork: _legacyFlatDanger, + ...rest + } = record; + const { + allowPrivateNetwork: _legacyNetworkAllow, + dangerouslyAllowPrivateNetwork: _legacyNetworkDanger, + ...restNetwork + } = network ?? {}; + + return { + ...rest, + network: { + ...restNetwork, + dangerouslyAllowPrivateNetwork: canonicalValue, + }, + } as T; +} + +export function normalizeBlueBubblesAccountsMap( + accounts: Record | undefined, +): Record | undefined { + if (!accounts) { + return undefined; + } + return Object.fromEntries( + Object.entries(accounts).map(([accountKey, accountConfig]) => [ + accountKey, + normalizeBlueBubblesPrivateNetworkAliases(accountConfig), + ]), + ); +} + +export function resolveBlueBubblesPrivateNetworkConfigValue( + config: object | null | undefined, +): boolean | undefined { + const record = asRecord(config); + if (!record) { + return undefined; + } + const network = asRecord(record.network); + if (typeof network?.dangerouslyAllowPrivateNetwork === "boolean") { + return network.dangerouslyAllowPrivateNetwork; + } + if (typeof network?.allowPrivateNetwork === "boolean") { + return network.allowPrivateNetwork; + } + if (typeof record.dangerouslyAllowPrivateNetwork === "boolean") { + return record.dangerouslyAllowPrivateNetwork; + } + if (typeof record.allowPrivateNetwork === "boolean") { + return record.allowPrivateNetwork; + } + return undefined; +} + +export function resolveBlueBubblesEffectiveAllowPrivateNetworkFromConfig(params: { + baseUrl?: string; + config?: object | null; +}): boolean { + const configuredValue = resolveBlueBubblesPrivateNetworkConfigValue(params.config); + if (configuredValue !== undefined) { + return configuredValue; + } + if (!params.baseUrl) { + return false; + } + try { + const hostname = new URL(normalizeBlueBubblesServerUrl(params.baseUrl)).hostname.trim(); + return Boolean(hostname) && isBlockedHostnameOrIp(hostname); + } catch { + return false; + } +} diff --git a/extensions/bluebubbles/src/accounts.ts b/extensions/bluebubbles/src/accounts.ts index 89374a838cc..59a692ec077 100644 --- a/extensions/bluebubbles/src/accounts.ts +++ b/extensions/bluebubbles/src/accounts.ts @@ -5,8 +5,13 @@ import { } from "openclaw/plugin-sdk/account-resolution"; import { resolveChannelStreamingChunkMode } from "openclaw/plugin-sdk/channel-streaming"; import type { OpenClawConfig } from "openclaw/plugin-sdk/config-runtime"; -import { isBlockedHostnameOrIp } from "openclaw/plugin-sdk/ssrf-runtime"; import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; +import { + normalizeBlueBubblesAccountsMap, + normalizeBlueBubblesPrivateNetworkAliases, + resolveBlueBubblesEffectiveAllowPrivateNetworkFromConfig, + resolveBlueBubblesPrivateNetworkConfigValue as resolveBlueBubblesPrivateNetworkConfigValueFromRecord, +} from "./accounts-normalization.js"; import { hasConfiguredSecretInput, normalizeSecretInputString } from "./secret-input.js"; import { normalizeBlueBubblesServerUrl, type BlueBubblesAccountConfig } from "./types.js"; @@ -25,76 +30,13 @@ const { } = createAccountListHelpers("bluebubbles"); export { listBlueBubblesAccountIds, resolveDefaultBlueBubblesAccountId }; -function asRecord(value: unknown): Record | null { - return value && typeof value === "object" && !Array.isArray(value) - ? (value as Record) - : null; -} - -function normalizeBlueBubblesPrivateNetworkAliases( - config: Record | undefined, -): Record | undefined { - const record = asRecord(config); - if (!record) { - return config; - } - const network = asRecord(record.network); - const canonicalValue = - typeof network?.dangerouslyAllowPrivateNetwork === "boolean" - ? network.dangerouslyAllowPrivateNetwork - : typeof network?.allowPrivateNetwork === "boolean" - ? network.allowPrivateNetwork - : typeof record.dangerouslyAllowPrivateNetwork === "boolean" - ? record.dangerouslyAllowPrivateNetwork - : typeof record.allowPrivateNetwork === "boolean" - ? record.allowPrivateNetwork - : undefined; - - if (canonicalValue === undefined) { - return config; - } - - const { - allowPrivateNetwork: _legacyFlatAllow, - dangerouslyAllowPrivateNetwork: _legacyFlatDanger, - ...rest - } = record; - const { - allowPrivateNetwork: _legacyNetworkAllow, - dangerouslyAllowPrivateNetwork: _legacyNetworkDanger, - ...restNetwork - } = network ?? {}; - - return { - ...rest, - network: { - ...restNetwork, - dangerouslyAllowPrivateNetwork: canonicalValue, - }, - }; -} - -function normalizeBlueBubblesAccountsMap( - accounts: Record> | undefined, -): Record> | undefined { - if (!accounts) { - return undefined; - } - return Object.fromEntries( - Object.entries(accounts).map(([accountKey, accountConfig]) => [ - accountKey, - normalizeBlueBubblesPrivateNetworkAliases(accountConfig) as Partial, - ]), - ); -} - function mergeBlueBubblesAccountConfig( cfg: OpenClawConfig, accountId: string, ): BlueBubblesAccountConfig { const channelConfig = normalizeBlueBubblesPrivateNetworkAliases( cfg.channels?.bluebubbles as BlueBubblesAccountConfig | undefined, - ) as BlueBubblesAccountConfig | undefined; + ); const accounts = normalizeBlueBubblesAccountsMap( cfg.channels?.bluebubbles?.accounts as | Record> @@ -141,43 +83,14 @@ export function resolveBlueBubblesAccount(params: { export function resolveBlueBubblesPrivateNetworkConfigValue( config: BlueBubblesAccountConfig | null | undefined, ): boolean | undefined { - const record = asRecord(config); - if (!record) { - return undefined; - } - const network = asRecord(record.network); - if (typeof network?.dangerouslyAllowPrivateNetwork === "boolean") { - return network.dangerouslyAllowPrivateNetwork; - } - if (typeof network?.allowPrivateNetwork === "boolean") { - return network.allowPrivateNetwork; - } - if (typeof record.dangerouslyAllowPrivateNetwork === "boolean") { - return record.dangerouslyAllowPrivateNetwork; - } - if (typeof record.allowPrivateNetwork === "boolean") { - return record.allowPrivateNetwork; - } - return undefined; + return resolveBlueBubblesPrivateNetworkConfigValueFromRecord(config); } export function resolveBlueBubblesEffectiveAllowPrivateNetwork(params: { baseUrl?: string; config?: BlueBubblesAccountConfig | null; }): boolean { - const configuredValue = resolveBlueBubblesPrivateNetworkConfigValue(params.config); - if (configuredValue !== undefined) { - return configuredValue; - } - if (!params.baseUrl) { - return false; - } - try { - const hostname = new URL(normalizeBlueBubblesServerUrl(params.baseUrl)).hostname.trim(); - return Boolean(hostname) && isBlockedHostnameOrIp(hostname); - } catch { - return false; - } + return resolveBlueBubblesEffectiveAllowPrivateNetworkFromConfig(params); } export function listEnabledBlueBubblesAccounts(cfg: OpenClawConfig): ResolvedBlueBubblesAccount[] { diff --git a/extensions/bluebubbles/src/test-harness.ts b/extensions/bluebubbles/src/test-harness.ts index 4b67bdc53e1..57732ee980b 100644 --- a/extensions/bluebubbles/src/test-harness.ts +++ b/extensions/bluebubbles/src/test-harness.ts @@ -1,7 +1,12 @@ -import { isBlockedHostnameOrIp } from "openclaw/plugin-sdk/ssrf-runtime"; import type { Mock } from "vitest"; import { afterEach, beforeEach, vi } from "vitest"; -import { _setFetchGuardForTesting, normalizeBlueBubblesServerUrl } from "./types.js"; +import { + normalizeBlueBubblesAccountsMap, + normalizeBlueBubblesPrivateNetworkAliases, + resolveBlueBubblesEffectiveAllowPrivateNetworkFromConfig, + resolveBlueBubblesPrivateNetworkConfigValue as resolveBlueBubblesPrivateNetworkConfigValueFromConfig, +} from "./accounts-normalization.js"; +import { _setFetchGuardForTesting } from "./types.js"; export const BLUE_BUBBLES_PRIVATE_API_STATUS = { enabled: true, @@ -28,69 +33,6 @@ export function mockBlueBubblesPrivateApiStatusOnce( mock.mockReturnValueOnce(value); } -function asRecord(value: unknown): Record | null { - return value && typeof value === "object" && !Array.isArray(value) - ? (value as Record) - : null; -} - -function normalizeBlueBubblesPrivateNetworkAliases( - config: Record | undefined, -): Record | undefined { - const record = asRecord(config); - if (!record) { - return config; - } - const network = asRecord(record.network); - const canonicalValue = - typeof network?.dangerouslyAllowPrivateNetwork === "boolean" - ? network.dangerouslyAllowPrivateNetwork - : typeof network?.allowPrivateNetwork === "boolean" - ? network.allowPrivateNetwork - : typeof record.dangerouslyAllowPrivateNetwork === "boolean" - ? record.dangerouslyAllowPrivateNetwork - : typeof record.allowPrivateNetwork === "boolean" - ? record.allowPrivateNetwork - : undefined; - - if (canonicalValue === undefined) { - return config; - } - - const { - allowPrivateNetwork: _legacyFlatAllow, - dangerouslyAllowPrivateNetwork: _legacyFlatDanger, - ...rest - } = record; - const { - allowPrivateNetwork: _legacyNetworkAllow, - dangerouslyAllowPrivateNetwork: _legacyNetworkDanger, - ...restNetwork - } = network ?? {}; - - return { - ...rest, - network: { - ...restNetwork, - dangerouslyAllowPrivateNetwork: canonicalValue, - }, - }; -} - -function normalizeBlueBubblesAccountsMap( - accounts: Record | undefined> | undefined, -): Record | undefined> | undefined { - if (!accounts) { - return undefined; - } - return Object.fromEntries( - Object.entries(accounts).map(([accountKey, accountConfig]) => [ - accountKey, - normalizeBlueBubblesPrivateNetworkAliases(accountConfig), - ]), - ); -} - export function resolveBlueBubblesAccountFromConfig(params: { cfg?: { channels?: { bluebubbles?: Record } }; accountId?: string; @@ -127,48 +69,6 @@ export function resolveBlueBubblesAccountFromConfig(params: { }; } -function resolveBlueBubblesPrivateNetworkConfigValueFromConfig( - config: Record | undefined, -): boolean | undefined { - const record = asRecord(config); - if (!record) { - return undefined; - } - const network = asRecord(record.network); - if (typeof network?.dangerouslyAllowPrivateNetwork === "boolean") { - return network.dangerouslyAllowPrivateNetwork; - } - if (typeof network?.allowPrivateNetwork === "boolean") { - return network.allowPrivateNetwork; - } - if (typeof record.dangerouslyAllowPrivateNetwork === "boolean") { - return record.dangerouslyAllowPrivateNetwork; - } - if (typeof record.allowPrivateNetwork === "boolean") { - return record.allowPrivateNetwork; - } - return undefined; -} - -function resolveBlueBubblesEffectiveAllowPrivateNetworkFromConfig(params: { - baseUrl?: string; - config?: Record; -}) { - const configuredValue = resolveBlueBubblesPrivateNetworkConfigValueFromConfig(params.config); - if (configuredValue !== undefined) { - return configuredValue; - } - if (!params.baseUrl) { - return false; - } - try { - const hostname = new URL(normalizeBlueBubblesServerUrl(params.baseUrl)).hostname.trim(); - return Boolean(hostname) && isBlockedHostnameOrIp(hostname); - } catch { - return false; - } -} - export function createBlueBubblesAccountsMockModule() { return { resolveBlueBubblesAccount: vi.fn(resolveBlueBubblesAccountFromConfig), diff --git a/extensions/browser/src/browser/server-context.remote-profile-tab-ops.fallback.test.ts b/extensions/browser/src/browser/server-context.remote-profile-tab-ops.fallback.test.ts index 7fa6e8b9392..7f30849c820 100644 --- a/extensions/browser/src/browser/server-context.remote-profile-tab-ops.fallback.test.ts +++ b/extensions/browser/src/browser/server-context.remote-profile-tab-ops.fallback.test.ts @@ -1,59 +1,12 @@ -import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; - -type RemoteProfileTestDeps = { - chromeModule: typeof import("./chrome.js"); - InvalidBrowserNavigationUrlError: typeof import("./navigation-guard.js").InvalidBrowserNavigationUrlError; - pwAiModule: typeof import("./pw-ai-module.js"); - closePlaywrightBrowserConnection: typeof import("./pw-session.js").closePlaywrightBrowserConnection; - createBrowserRouteContext: typeof import("./server-context.js").createBrowserRouteContext; - createJsonListFetchMock: typeof import("./server-context.remote-tab-ops.harness.js").createJsonListFetchMock; - createRemoteRouteHarness: typeof import("./server-context.remote-tab-ops.harness.js").createRemoteRouteHarness; - createSequentialPageLister: typeof import("./server-context.remote-tab-ops.harness.js").createSequentialPageLister; - makeState: typeof import("./server-context.remote-tab-ops.harness.js").makeState; - originalFetch: typeof import("./server-context.remote-tab-ops.harness.js").originalFetch; -}; - -async function loadRemoteProfileTestDeps(): Promise { - vi.resetModules(); - await import("./server-context.chrome-test-harness.js"); - const chromeModule = await import("./chrome.js"); - const { InvalidBrowserNavigationUrlError } = await import("./navigation-guard.js"); - const pwAiModule = await import("./pw-ai-module.js"); - const { closePlaywrightBrowserConnection } = await import("./pw-session.js"); - const { createBrowserRouteContext } = await import("./server-context.js"); - const { - createJsonListFetchMock, - createRemoteRouteHarness, - createSequentialPageLister, - makeState, - originalFetch, - } = await import("./server-context.remote-tab-ops.harness.js"); - return { - chromeModule, - InvalidBrowserNavigationUrlError, - pwAiModule, - closePlaywrightBrowserConnection, - createBrowserRouteContext, - createJsonListFetchMock, - createRemoteRouteHarness, - createSequentialPageLister, - makeState, - originalFetch, - }; -} +import { describe, expect, it, vi } from "vitest"; +import { + installRemoteProfileTestLifecycle, + loadRemoteProfileTestDeps, + type RemoteProfileTestDeps, +} from "./server-context.remote-profile-tab-ops.test-helpers.js"; const deps: RemoteProfileTestDeps = await loadRemoteProfileTestDeps(); - -beforeEach(() => { - vi.clearAllMocks(); - globalThis.fetch = deps.originalFetch; -}); - -afterEach(async () => { - await deps.closePlaywrightBrowserConnection().catch(() => {}); - globalThis.fetch = deps.originalFetch; - vi.restoreAllMocks(); -}); +installRemoteProfileTestLifecycle(deps); describe("browser remote profile fallback and attachOnly behavior", () => { it("uses profile-level attachOnly when global attachOnly is false", async () => { diff --git a/extensions/browser/src/browser/server-context.remote-profile-tab-ops.playwright.test.ts b/extensions/browser/src/browser/server-context.remote-profile-tab-ops.playwright.test.ts index 0ed9f68e4ea..d6fba63115c 100644 --- a/extensions/browser/src/browser/server-context.remote-profile-tab-ops.playwright.test.ts +++ b/extensions/browser/src/browser/server-context.remote-profile-tab-ops.playwright.test.ts @@ -1,59 +1,12 @@ -import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; - -type RemoteProfileTestDeps = { - chromeModule: typeof import("./chrome.js"); - InvalidBrowserNavigationUrlError: typeof import("./navigation-guard.js").InvalidBrowserNavigationUrlError; - pwAiModule: typeof import("./pw-ai-module.js"); - closePlaywrightBrowserConnection: typeof import("./pw-session.js").closePlaywrightBrowserConnection; - createBrowserRouteContext: typeof import("./server-context.js").createBrowserRouteContext; - createJsonListFetchMock: typeof import("./server-context.remote-tab-ops.harness.js").createJsonListFetchMock; - createRemoteRouteHarness: typeof import("./server-context.remote-tab-ops.harness.js").createRemoteRouteHarness; - createSequentialPageLister: typeof import("./server-context.remote-tab-ops.harness.js").createSequentialPageLister; - makeState: typeof import("./server-context.remote-tab-ops.harness.js").makeState; - originalFetch: typeof import("./server-context.remote-tab-ops.harness.js").originalFetch; -}; - -async function loadRemoteProfileTestDeps(): Promise { - vi.resetModules(); - await import("./server-context.chrome-test-harness.js"); - const chromeModule = await import("./chrome.js"); - const { InvalidBrowserNavigationUrlError } = await import("./navigation-guard.js"); - const pwAiModule = await import("./pw-ai-module.js"); - const { closePlaywrightBrowserConnection } = await import("./pw-session.js"); - const { createBrowserRouteContext } = await import("./server-context.js"); - const { - createJsonListFetchMock, - createRemoteRouteHarness, - createSequentialPageLister, - makeState, - originalFetch, - } = await import("./server-context.remote-tab-ops.harness.js"); - return { - chromeModule, - InvalidBrowserNavigationUrlError, - pwAiModule, - closePlaywrightBrowserConnection, - createBrowserRouteContext, - createJsonListFetchMock, - createRemoteRouteHarness, - createSequentialPageLister, - makeState, - originalFetch, - }; -} +import { describe, expect, it, vi } from "vitest"; +import { + installRemoteProfileTestLifecycle, + loadRemoteProfileTestDeps, + type RemoteProfileTestDeps, +} from "./server-context.remote-profile-tab-ops.test-helpers.js"; const deps: RemoteProfileTestDeps = await loadRemoteProfileTestDeps(); - -beforeEach(() => { - vi.clearAllMocks(); - globalThis.fetch = deps.originalFetch; -}); - -afterEach(async () => { - await deps.closePlaywrightBrowserConnection().catch(() => {}); - globalThis.fetch = deps.originalFetch; - vi.restoreAllMocks(); -}); +installRemoteProfileTestLifecycle(deps); describe("browser remote profile tab ops via Playwright", () => { it("uses Playwright tab operations when available", async () => { diff --git a/extensions/browser/src/browser/server-context.remote-profile-tab-ops.test-helpers.ts b/extensions/browser/src/browser/server-context.remote-profile-tab-ops.test-helpers.ts new file mode 100644 index 00000000000..df558dc5432 --- /dev/null +++ b/extensions/browser/src/browser/server-context.remote-profile-tab-ops.test-helpers.ts @@ -0,0 +1,56 @@ +import { afterEach, beforeEach, vi } from "vitest"; + +export type RemoteProfileTestDeps = { + chromeModule: typeof import("./chrome.js"); + InvalidBrowserNavigationUrlError: typeof import("./navigation-guard.js").InvalidBrowserNavigationUrlError; + pwAiModule: typeof import("./pw-ai-module.js"); + closePlaywrightBrowserConnection: typeof import("./pw-session.js").closePlaywrightBrowserConnection; + createBrowserRouteContext: typeof import("./server-context.js").createBrowserRouteContext; + createJsonListFetchMock: typeof import("./server-context.remote-tab-ops.harness.js").createJsonListFetchMock; + createRemoteRouteHarness: typeof import("./server-context.remote-tab-ops.harness.js").createRemoteRouteHarness; + createSequentialPageLister: typeof import("./server-context.remote-tab-ops.harness.js").createSequentialPageLister; + makeState: typeof import("./server-context.remote-tab-ops.harness.js").makeState; + originalFetch: typeof import("./server-context.remote-tab-ops.harness.js").originalFetch; +}; + +export async function loadRemoteProfileTestDeps(): Promise { + vi.resetModules(); + await import("./server-context.chrome-test-harness.js"); + const chromeModule = await import("./chrome.js"); + const { InvalidBrowserNavigationUrlError } = await import("./navigation-guard.js"); + const pwAiModule = await import("./pw-ai-module.js"); + const { closePlaywrightBrowserConnection } = await import("./pw-session.js"); + const { createBrowserRouteContext } = await import("./server-context.js"); + const { + createJsonListFetchMock, + createRemoteRouteHarness, + createSequentialPageLister, + makeState, + originalFetch, + } = await import("./server-context.remote-tab-ops.harness.js"); + return { + chromeModule, + InvalidBrowserNavigationUrlError, + pwAiModule, + closePlaywrightBrowserConnection, + createBrowserRouteContext, + createJsonListFetchMock, + createRemoteRouteHarness, + createSequentialPageLister, + makeState, + originalFetch, + }; +} + +export function installRemoteProfileTestLifecycle(deps: RemoteProfileTestDeps): void { + beforeEach(() => { + vi.clearAllMocks(); + globalThis.fetch = deps.originalFetch; + }); + + afterEach(async () => { + await deps.closePlaywrightBrowserConnection().catch(() => {}); + globalThis.fetch = deps.originalFetch; + vi.restoreAllMocks(); + }); +} diff --git a/extensions/byteplus/video-generation-provider.test.ts b/extensions/byteplus/video-generation-provider.test.ts index f39e9c6f272..975668374e3 100644 --- a/extensions/byteplus/video-generation-provider.test.ts +++ b/extensions/byteplus/video-generation-provider.test.ts @@ -1,45 +1,20 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildBytePlusVideoGenerationProvider } from "./video-generation-provider.js"; +import { beforeAll, describe, expect, it, vi } from "vitest"; +import { + getProviderHttpMocks, + installProviderHttpMockCleanup, +} from "../../test/helpers/media-generation/provider-http-mocks.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "byteplus-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock } = getProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildBytePlusVideoGenerationProvider: typeof import("./video-generation-provider.js").buildBytePlusVideoGenerationProvider; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildBytePlusVideoGenerationProvider } = await import("./video-generation-provider.js")); +}); + +installProviderHttpMockCleanup(); describe("byteplus video generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("creates a content-generation task, polls, and downloads the video", async () => { postJsonRequestMock.mockResolvedValue({ response: { diff --git a/extensions/googlechat/src/channel.test.ts b/extensions/googlechat/src/channel.test.ts index 7bf58715e6a..01a5511485f 100644 --- a/extensions/googlechat/src/channel.test.ts +++ b/extensions/googlechat/src/channel.test.ts @@ -62,6 +62,31 @@ function resolveGoogleChatAccountImpl(params: { cfg: OpenClawConfig; accountId?: }; } +function mockGoogleChatOutboundSpaceResolution() { + resolveGoogleChatOutboundSpaceMock.mockImplementation(async ({ target }: { target: string }) => { + const normalized = normalizeGoogleChatTarget(target); + if (!normalized) { + throw new Error("Missing Google Chat target."); + } + return normalized.toLowerCase().startsWith("users/") + ? `spaces/DM-${normalized.slice("users/".length)}` + : normalized.replace(/\/messages\/.+$/, ""); + }); +} + +function mockGoogleChatMediaLoaders() { + loadOutboundMediaFromUrlMock.mockImplementation(async (mediaUrl: string) => ({ + buffer: Buffer.from("default-bytes"), + fileName: mediaUrl.split("/").pop() || "attachment", + contentType: "application/octet-stream", + })); + fetchRemoteMediaMock.mockImplementation(async () => ({ + buffer: Buffer.from("remote-bytes"), + fileName: "remote.png", + contentType: "image/png", + })); +} + vi.mock("./channel.runtime.js", () => { return { googleChatChannelRuntime: { @@ -136,48 +161,14 @@ vi.mock("./channel.deps.runtime.js", () => { }); resolveGoogleChatAccountMock.mockImplementation(resolveGoogleChatAccountImpl); -resolveGoogleChatOutboundSpaceMock.mockImplementation(async ({ target }: { target: string }) => { - const normalized = normalizeGoogleChatTarget(target); - if (!normalized) { - throw new Error("Missing Google Chat target."); - } - return normalized.toLowerCase().startsWith("users/") - ? `spaces/DM-${normalized.slice("users/".length)}` - : normalized.replace(/\/messages\/.+$/, ""); -}); -loadOutboundMediaFromUrlMock.mockImplementation(async (mediaUrl: string) => ({ - buffer: Buffer.from("default-bytes"), - fileName: mediaUrl.split("/").pop() || "attachment", - contentType: "application/octet-stream", -})); -fetchRemoteMediaMock.mockImplementation(async () => ({ - buffer: Buffer.from("remote-bytes"), - fileName: "remote.png", - contentType: "image/png", -})); +mockGoogleChatOutboundSpaceResolution(); +mockGoogleChatMediaLoaders(); afterEach(() => { vi.clearAllMocks(); resolveGoogleChatAccountMock.mockImplementation(resolveGoogleChatAccountImpl); - resolveGoogleChatOutboundSpaceMock.mockImplementation(async ({ target }: { target: string }) => { - const normalized = normalizeGoogleChatTarget(target); - if (!normalized) { - throw new Error("Missing Google Chat target."); - } - return normalized.toLowerCase().startsWith("users/") - ? `spaces/DM-${normalized.slice("users/".length)}` - : normalized.replace(/\/messages\/.+$/, ""); - }); - loadOutboundMediaFromUrlMock.mockImplementation(async (mediaUrl: string) => ({ - buffer: Buffer.from("default-bytes"), - fileName: mediaUrl.split("/").pop() || "attachment", - contentType: "application/octet-stream", - })); - fetchRemoteMediaMock.mockImplementation(async () => ({ - buffer: Buffer.from("remote-bytes"), - fileName: "remote.png", - contentType: "image/png", - })); + mockGoogleChatOutboundSpaceResolution(); + mockGoogleChatMediaLoaders(); }); function createGoogleChatCfg(): OpenClawConfig { diff --git a/extensions/lobster/src/lobster-taskflow.test.ts b/extensions/lobster/src/lobster-taskflow.test.ts index cfedb384bef..5a6ba1533de 100644 --- a/extensions/lobster/src/lobster-taskflow.test.ts +++ b/extensions/lobster/src/lobster-taskflow.test.ts @@ -3,7 +3,7 @@ import type { LobsterRunner } from "./lobster-runner.js"; import { resumeManagedLobsterFlow, runManagedLobsterFlow } from "./lobster-taskflow.js"; import { createFakeTaskFlow } from "./taskflow-test-helpers.js"; -function _expectManagedFlowFailure( +function expectManagedFlowFailure( result: Awaited>, ) { expect(result.ok).toBe(false); @@ -18,6 +18,45 @@ function createRunner(result: Awaited>): Lobste }; } +function createRunFlowParams( + taskFlow: ReturnType, + runner: LobsterRunner, +): Parameters[0] { + return { + taskFlow, + runner, + runnerParams: { + action: "run", + pipeline: "noop", + cwd: process.cwd(), + timeoutMs: 1000, + maxStdoutBytes: 4096, + }, + controllerId: "tests/lobster", + goal: "Run Lobster workflow", + }; +} + +function createResumeFlowParams( + taskFlow: ReturnType, + runner: LobsterRunner, +): Parameters[0] { + return { + taskFlow, + runner, + flowId: "flow-1", + expectedRevision: 4, + runnerParams: { + action: "resume", + token: "resume-1", + approve: true, + cwd: process.cwd(), + timeoutMs: 1000, + maxStdoutBytes: 4096, + }, + }; +} + describe("runManagedLobsterFlow", () => { it("creates a flow and finishes it when Lobster succeeds", async () => { const taskFlow = createFakeTaskFlow(); @@ -28,19 +67,7 @@ describe("runManagedLobsterFlow", () => { requiresApproval: null, }); - const result = await runManagedLobsterFlow({ - taskFlow, - runner, - runnerParams: { - action: "run", - pipeline: "noop", - cwd: process.cwd(), - timeoutMs: 1000, - maxStdoutBytes: 4096, - }, - controllerId: "tests/lobster", - goal: "Run Lobster workflow", - }); + const result = await runManagedLobsterFlow(createRunFlowParams(taskFlow, runner)); expect(result.ok).toBe(true); expect(taskFlow.createManaged).toHaveBeenCalledWith({ @@ -69,19 +96,7 @@ describe("runManagedLobsterFlow", () => { }, }); - const result = await runManagedLobsterFlow({ - taskFlow, - runner, - runnerParams: { - action: "run", - pipeline: "noop", - cwd: process.cwd(), - timeoutMs: 1000, - maxStdoutBytes: 4096, - }, - controllerId: "tests/lobster", - goal: "Run Lobster workflow", - }); + const result = await runManagedLobsterFlow(createRunFlowParams(taskFlow, runner)); expect(result.ok).toBe(true); expect(taskFlow.setWaiting).toHaveBeenCalledWith({ @@ -107,24 +122,9 @@ describe("runManagedLobsterFlow", () => { }, }); - const result = await runManagedLobsterFlow({ - taskFlow, - runner, - runnerParams: { - action: "run", - pipeline: "noop", - cwd: process.cwd(), - timeoutMs: 1000, - maxStdoutBytes: 4096, - }, - controllerId: "tests/lobster", - goal: "Run Lobster workflow", - }); - - expect(result.ok).toBe(false); - if (result.ok) { - throw new Error("expected managed Lobster flow to fail"); - } + const result = expectManagedFlowFailure( + await runManagedLobsterFlow(createRunFlowParams(taskFlow, runner)), + ); expect(result.error.message).toBe("boom"); expect(taskFlow.fail).toHaveBeenCalledWith({ flowId: "flow-1", @@ -138,24 +138,9 @@ describe("runManagedLobsterFlow", () => { run: vi.fn().mockRejectedValue(new Error("crashed")), }; - const result = await runManagedLobsterFlow({ - taskFlow, - runner, - runnerParams: { - action: "run", - pipeline: "noop", - cwd: process.cwd(), - timeoutMs: 1000, - maxStdoutBytes: 4096, - }, - controllerId: "tests/lobster", - goal: "Run Lobster workflow", - }); - - expect(result.ok).toBe(false); - if (result.ok) { - throw new Error("expected managed Lobster flow to fail"); - } + const result = expectManagedFlowFailure( + await runManagedLobsterFlow(createRunFlowParams(taskFlow, runner)), + ); expect(result.error.message).toBe("crashed"); expect(taskFlow.fail).toHaveBeenCalledWith({ flowId: "flow-1", @@ -174,20 +159,7 @@ describe("resumeManagedLobsterFlow", () => { requiresApproval: null, }); - const result = await resumeManagedLobsterFlow({ - taskFlow, - runner, - flowId: "flow-1", - expectedRevision: 4, - runnerParams: { - action: "resume", - token: "resume-1", - approve: true, - cwd: process.cwd(), - timeoutMs: 1000, - maxStdoutBytes: 4096, - }, - }); + const result = await resumeManagedLobsterFlow(createResumeFlowParams(taskFlow, runner)); expect(result.ok).toBe(true); expect(taskFlow.resume).toHaveBeenCalledWith({ @@ -216,25 +188,9 @@ describe("resumeManagedLobsterFlow", () => { requiresApproval: null, }); - const result = await resumeManagedLobsterFlow({ - taskFlow, - runner, - flowId: "flow-1", - expectedRevision: 4, - runnerParams: { - action: "resume", - token: "resume-1", - approve: true, - cwd: process.cwd(), - timeoutMs: 1000, - maxStdoutBytes: 4096, - }, - }); - - expect(result.ok).toBe(false); - if (result.ok) { - throw new Error("expected resumed Lobster flow to fail"); - } + const result = expectManagedFlowFailure( + await resumeManagedLobsterFlow(createResumeFlowParams(taskFlow, runner)), + ); expect(result.error.message).toMatch(/revision_conflict/); expect(runner.run).not.toHaveBeenCalled(); }); @@ -253,20 +209,7 @@ describe("resumeManagedLobsterFlow", () => { }, }); - const result = await resumeManagedLobsterFlow({ - taskFlow, - runner, - flowId: "flow-1", - expectedRevision: 4, - runnerParams: { - action: "resume", - token: "resume-1", - approve: true, - cwd: process.cwd(), - timeoutMs: 1000, - maxStdoutBytes: 4096, - }, - }); + const result = await resumeManagedLobsterFlow(createResumeFlowParams(taskFlow, runner)); expect(result.ok).toBe(true); expect(taskFlow.setWaiting).toHaveBeenCalledWith({ diff --git a/extensions/matrix/src/matrix/accounts.test.ts b/extensions/matrix/src/matrix/accounts.test.ts index c112cea94f7..e28d2229aba 100644 --- a/extensions/matrix/src/matrix/accounts.test.ts +++ b/extensions/matrix/src/matrix/accounts.test.ts @@ -33,6 +33,124 @@ const envKeys = [ getMatrixScopedEnvVarNames("team-ops").accessToken, ]; +type MatrixRoomScopeKey = "groups" | "rooms"; + +function createMatrixAccountConfig(accessToken: string) { + return { + homeserver: "https://matrix.example.org", + accessToken, + }; +} + +function createMatrixScopedEntriesConfig(scopeKey: MatrixRoomScopeKey): CoreConfig { + return { + channels: { + matrix: { + [scopeKey]: { + "!default-room:example.org": { + enabled: true, + account: "default", + }, + "!axis-room:example.org": { + enabled: true, + account: "axis", + }, + "!unassigned-room:example.org": { + enabled: true, + }, + }, + accounts: { + default: createMatrixAccountConfig("default-token"), + axis: createMatrixAccountConfig("axis-token"), + }, + }, + }, + } as unknown as CoreConfig; +} + +function createMatrixTopLevelDefaultScopedEntriesConfig(scopeKey: MatrixRoomScopeKey): CoreConfig { + return { + channels: { + matrix: { + ...createMatrixAccountConfig("default-token"), + [scopeKey]: { + "!default-room:example.org": { + enabled: true, + account: "default", + }, + "!ops-room:example.org": { + enabled: true, + account: "ops", + }, + "!shared-room:example.org": { + enabled: true, + }, + }, + accounts: { + ops: createMatrixAccountConfig("ops-token"), + }, + }, + }, + } as unknown as CoreConfig; +} + +function expectMatrixScopedEntries( + cfg: CoreConfig, + scopeKey: MatrixRoomScopeKey, + accountId: string, + expected: Record, +): void { + expect(resolveMatrixAccount({ cfg, accountId }).config[scopeKey]).toEqual(expected); +} + +function expectMultiAccountMatrixScopedEntries( + cfg: CoreConfig, + scopeKey: MatrixRoomScopeKey, +): void { + expectMatrixScopedEntries(cfg, scopeKey, "default", { + "!default-room:example.org": { + enabled: true, + account: "default", + }, + "!unassigned-room:example.org": { + enabled: true, + }, + }); + expectMatrixScopedEntries(cfg, scopeKey, "axis", { + "!axis-room:example.org": { + enabled: true, + account: "axis", + }, + "!unassigned-room:example.org": { + enabled: true, + }, + }); +} + +function expectTopLevelDefaultMatrixScopedEntries( + cfg: CoreConfig, + scopeKey: MatrixRoomScopeKey, +): void { + expectMatrixScopedEntries(cfg, scopeKey, "default", { + "!default-room:example.org": { + enabled: true, + account: "default", + }, + "!shared-room:example.org": { + enabled: true, + }, + }); + expectMatrixScopedEntries(cfg, scopeKey, "ops", { + "!ops-room:example.org": { + enabled: true, + account: "ops", + }, + "!shared-room:example.org": { + enabled: true, + }, + }); +} + describe("resolveMatrixAccount", () => { let prevEnv: Record = {}; @@ -471,203 +589,25 @@ describe("resolveMatrixAccount", () => { }); it("filters channel-level groups by room account in multi-account setups", () => { - const cfg = { - channels: { - matrix: { - groups: { - "!default-room:example.org": { - enabled: true, - account: "default", - }, - "!axis-room:example.org": { - enabled: true, - account: "axis", - }, - "!unassigned-room:example.org": { - enabled: true, - }, - }, - accounts: { - default: { - homeserver: "https://matrix.example.org", - accessToken: "default-token", - }, - axis: { - homeserver: "https://matrix.example.org", - accessToken: "axis-token", - }, - }, - }, - }, - } as unknown as CoreConfig; - - expect(resolveMatrixAccount({ cfg, accountId: "default" }).config.groups).toEqual({ - "!default-room:example.org": { - enabled: true, - account: "default", - }, - "!unassigned-room:example.org": { - enabled: true, - }, - }); - expect(resolveMatrixAccount({ cfg, accountId: "axis" }).config.groups).toEqual({ - "!axis-room:example.org": { - enabled: true, - account: "axis", - }, - "!unassigned-room:example.org": { - enabled: true, - }, - }); + expectMultiAccountMatrixScopedEntries(createMatrixScopedEntriesConfig("groups"), "groups"); }); it("filters channel-level groups when the default account is configured at the top level", () => { - const cfg = { - channels: { - matrix: { - homeserver: "https://matrix.example.org", - accessToken: "default-token", - groups: { - "!default-room:example.org": { - enabled: true, - account: "default", - }, - "!ops-room:example.org": { - enabled: true, - account: "ops", - }, - "!shared-room:example.org": { - enabled: true, - }, - }, - accounts: { - ops: { - homeserver: "https://matrix.example.org", - accessToken: "ops-token", - }, - }, - }, - }, - } as unknown as CoreConfig; - - expect(resolveMatrixAccount({ cfg, accountId: "default" }).config.groups).toEqual({ - "!default-room:example.org": { - enabled: true, - account: "default", - }, - "!shared-room:example.org": { - enabled: true, - }, - }); - expect(resolveMatrixAccount({ cfg, accountId: "ops" }).config.groups).toEqual({ - "!ops-room:example.org": { - enabled: true, - account: "ops", - }, - "!shared-room:example.org": { - enabled: true, - }, - }); + expectTopLevelDefaultMatrixScopedEntries( + createMatrixTopLevelDefaultScopedEntriesConfig("groups"), + "groups", + ); }); it("filters legacy channel-level rooms by room account in multi-account setups", () => { - const cfg = { - channels: { - matrix: { - rooms: { - "!default-room:example.org": { - enabled: true, - account: "default", - }, - "!axis-room:example.org": { - enabled: true, - account: "axis", - }, - "!unassigned-room:example.org": { - enabled: true, - }, - }, - accounts: { - default: { - homeserver: "https://matrix.example.org", - accessToken: "default-token", - }, - axis: { - homeserver: "https://matrix.example.org", - accessToken: "axis-token", - }, - }, - }, - }, - } as unknown as CoreConfig; - - expect(resolveMatrixAccount({ cfg, accountId: "default" }).config.rooms).toEqual({ - "!default-room:example.org": { - enabled: true, - account: "default", - }, - "!unassigned-room:example.org": { - enabled: true, - }, - }); - expect(resolveMatrixAccount({ cfg, accountId: "axis" }).config.rooms).toEqual({ - "!axis-room:example.org": { - enabled: true, - account: "axis", - }, - "!unassigned-room:example.org": { - enabled: true, - }, - }); + expectMultiAccountMatrixScopedEntries(createMatrixScopedEntriesConfig("rooms"), "rooms"); }); it("filters legacy channel-level rooms when the default account is configured at the top level", () => { - const cfg = { - channels: { - matrix: { - homeserver: "https://matrix.example.org", - accessToken: "default-token", - rooms: { - "!default-room:example.org": { - enabled: true, - account: "default", - }, - "!ops-room:example.org": { - enabled: true, - account: "ops", - }, - "!shared-room:example.org": { - enabled: true, - }, - }, - accounts: { - ops: { - homeserver: "https://matrix.example.org", - accessToken: "ops-token", - }, - }, - }, - }, - } as unknown as CoreConfig; - - expect(resolveMatrixAccount({ cfg, accountId: "default" }).config.rooms).toEqual({ - "!default-room:example.org": { - enabled: true, - account: "default", - }, - "!shared-room:example.org": { - enabled: true, - }, - }); - expect(resolveMatrixAccount({ cfg, accountId: "ops" }).config.rooms).toEqual({ - "!ops-room:example.org": { - enabled: true, - account: "ops", - }, - "!shared-room:example.org": { - enabled: true, - }, - }); + expectTopLevelDefaultMatrixScopedEntries( + createMatrixTopLevelDefaultScopedEntriesConfig("rooms"), + "rooms", + ); }); it("honors injected env when scoping room entries in multi-account setups", () => { diff --git a/extensions/matrix/src/matrix/monitor/handler.body-for-agent.test.ts b/extensions/matrix/src/matrix/monitor/handler.body-for-agent.test.ts index f89cbb313f9..57642745f8e 100644 --- a/extensions/matrix/src/matrix/monitor/handler.body-for-agent.test.ts +++ b/extensions/matrix/src/matrix/monitor/handler.body-for-agent.test.ts @@ -8,6 +8,61 @@ import { import type { MatrixRawEvent } from "./types.js"; describe("createMatrixRoomMessageHandler inbound body formatting", () => { + type MatrixHandlerHarness = ReturnType; + type FinalizedReplyContext = { + ReplyToBody?: string; + ReplyToSender?: string; + ThreadStarterBody?: string; + }; + + function createQuotedReplyVisibilityHarness(contextVisibility: "allowlist" | "allowlist_quote") { + return createMatrixHandlerTestHarness({ + client: { + getEvent: async () => + createMatrixTextMessageEvent({ + eventId: "$quoted", + sender: "@mallory:example.org", + body: "Quoted payload", + }), + }, + isDirectMessage: false, + cfg: { + channels: { + matrix: { + contextVisibility, + }, + }, + }, + groupPolicy: "allowlist", + groupAllowFrom: ["@alice:example.org"], + roomsConfig: { "*": {} }, + replyToMode: "all", + getMemberDisplayName: async (_roomId, userId) => + userId === "@alice:example.org" ? "Alice" : "Mallory", + }); + } + + async function sendQuotedReply(handler: MatrixHandlerHarness["handler"]) { + await handler( + "!room:example.org", + createMatrixTextMessageEvent({ + eventId: "$reply1", + sender: "@alice:example.org", + body: "@room follow up", + relatesTo: { + "m.in_reply_to": { event_id: "$quoted" }, + }, + mentions: { room: true }, + }), + ); + } + + function latestFinalizedReplyContext( + finalizeInboundContext: MatrixHandlerHarness["finalizeInboundContext"], + ) { + return vi.mocked(finalizeInboundContext).mock.calls.at(-1)?.[0] as FinalizedReplyContext; + } + beforeEach(() => { installMatrixMonitorTestRuntime({ matchesMentionPatterns: () => false, @@ -319,95 +374,22 @@ describe("createMatrixRoomMessageHandler inbound body formatting", () => { }); it("drops quoted reply context fetched from non-allowlisted room senders", async () => { - const { handler, finalizeInboundContext } = createMatrixHandlerTestHarness({ - client: { - getEvent: async () => - createMatrixTextMessageEvent({ - eventId: "$quoted", - sender: "@mallory:example.org", - body: "Quoted payload", - }), - }, - isDirectMessage: false, - cfg: { - channels: { - matrix: { - contextVisibility: "allowlist", - }, - }, - }, - groupPolicy: "allowlist", - groupAllowFrom: ["@alice:example.org"], - roomsConfig: { "*": {} }, - replyToMode: "all", - getMemberDisplayName: async (_roomId, userId) => - userId === "@alice:example.org" ? "Alice" : "Mallory", - }); + const { handler, finalizeInboundContext } = createQuotedReplyVisibilityHarness("allowlist"); - await handler( - "!room:example.org", - createMatrixTextMessageEvent({ - eventId: "$reply1", - sender: "@alice:example.org", - body: "@room follow up", - relatesTo: { - "m.in_reply_to": { event_id: "$quoted" }, - }, - mentions: { room: true }, - }), - ); + await sendQuotedReply(handler); - const finalized = vi.mocked(finalizeInboundContext).mock.calls.at(-1)?.[0] as { - ReplyToBody?: string; - ReplyToSender?: string; - }; + const finalized = latestFinalizedReplyContext(finalizeInboundContext); expect(finalized.ReplyToBody).toBeUndefined(); expect(finalized.ReplyToSender).toBeUndefined(); }); it("keeps quoted reply context in allowlist_quote mode", async () => { - const { handler, finalizeInboundContext } = createMatrixHandlerTestHarness({ - client: { - getEvent: async () => - createMatrixTextMessageEvent({ - eventId: "$quoted", - sender: "@mallory:example.org", - body: "Quoted payload", - }), - }, - isDirectMessage: false, - cfg: { - channels: { - matrix: { - contextVisibility: "allowlist_quote", - }, - }, - }, - groupPolicy: "allowlist", - groupAllowFrom: ["@alice:example.org"], - roomsConfig: { "*": {} }, - replyToMode: "all", - getMemberDisplayName: async (_roomId, userId) => - userId === "@alice:example.org" ? "Alice" : "Mallory", - }); + const { handler, finalizeInboundContext } = + createQuotedReplyVisibilityHarness("allowlist_quote"); - await handler( - "!room:example.org", - createMatrixTextMessageEvent({ - eventId: "$reply1", - sender: "@alice:example.org", - body: "@room follow up", - relatesTo: { - "m.in_reply_to": { event_id: "$quoted" }, - }, - mentions: { room: true }, - }), - ); + await sendQuotedReply(handler); - const finalized = vi.mocked(finalizeInboundContext).mock.calls.at(-1)?.[0] as { - ReplyToBody?: string; - ReplyToSender?: string; - }; + const finalized = latestFinalizedReplyContext(finalizeInboundContext); expect(finalized.ReplyToBody).toBe("Quoted payload"); expect(finalized.ReplyToSender).toBe("Mallory"); }); diff --git a/extensions/matrix/src/matrix/monitor/handler.group-history.test.ts b/extensions/matrix/src/matrix/monitor/handler.group-history.test.ts index 12a4095de14..56da3656294 100644 --- a/extensions/matrix/src/matrix/monitor/handler.group-history.test.ts +++ b/extensions/matrix/src/matrix/monitor/handler.group-history.test.ts @@ -68,6 +68,46 @@ function deferred() { return { promise, resolve }; } +function createFinalDeliveryFailureHandler(finalizeInboundContext: (ctx: unknown) => unknown) { + let capturedOnError: + | ((err: unknown, info: { kind: "tool" | "block" | "final" }) => void) + | undefined; + + return createMatrixHandlerTestHarness({ + historyLimit: 20, + groupPolicy: "open", + isDirectMessage: false, + finalizeInboundContext, + dispatchReplyFromConfig: async () => ({ + queuedFinal: true, + counts: { final: 1, block: 0, tool: 0 }, + }), + createReplyDispatcherWithTyping: (params?: { + onError?: (err: unknown, info: { kind: "tool" | "block" | "final" }) => void; + }) => { + capturedOnError = params?.onError; + return { + dispatcher: {}, + replyOptions: {}, + markDispatchIdle: () => {}, + markRunComplete: () => {}, + }; + }, + withReplyDispatcher: async (params: { + dispatcher: { markComplete?: () => void; waitForIdle?: () => Promise }; + run: () => Promise; + onSettled?: () => void | Promise; + }) => { + const result = await params.run(); + capturedOnError?.(new Error("simulated delivery failure"), { kind: "final" }); + params.dispatcher.markComplete?.(); + await params.dispatcher.waitForIdle?.(); + await params.onSettled?.(); + return result; + }, + }); +} + describe("matrix group chat history — scenario 1: basic accumulation", () => { it("pending messages appear in InboundHistory; trigger itself does not", async () => { const finalizeInboundContext = vi.fn((ctx: unknown) => ctx); @@ -447,45 +487,8 @@ describe("matrix group chat history — scenario 2: race condition safety", () = }); it("watermark does not advance when final reply delivery fails (retry sees same history)", async () => { - // Capture the onError callback so we can fire a simulated final delivery failure - let capturedOnError: - | ((err: unknown, info: { kind: "tool" | "block" | "final" }) => void) - | undefined; - const finalizeInboundContext = vi.fn((ctx: unknown) => ctx); - const { handler } = createMatrixHandlerTestHarness({ - historyLimit: 20, - groupPolicy: "open", - isDirectMessage: false, - finalizeInboundContext, - dispatchReplyFromConfig: async () => ({ - queuedFinal: true, - counts: { final: 1, block: 0, tool: 0 }, - }), - createReplyDispatcherWithTyping: (params?: { - onError?: (err: unknown, info: { kind: "tool" | "block" | "final" }) => void; - }) => { - capturedOnError = params?.onError; - return { - dispatcher: {}, - replyOptions: {}, - markDispatchIdle: () => {}, - markRunComplete: () => {}, - }; - }, - withReplyDispatcher: async (params: { - dispatcher: { markComplete?: () => void; waitForIdle?: () => Promise }; - run: () => Promise; - onSettled?: () => void | Promise; - }) => { - const result = await params.run(); - capturedOnError?.(new Error("simulated delivery failure"), { kind: "final" }); - params.dispatcher.markComplete?.(); - await params.dispatcher.waitForIdle?.(); - await params.onSettled?.(); - return result; - }, - }); + const { handler } = createFinalDeliveryFailureHandler(finalizeInboundContext); await handler( DEFAULT_ROOM, @@ -519,44 +522,8 @@ describe("matrix group chat history — scenario 2: race condition safety", () = }); it("retrying the same failed trigger reuses the original history window", async () => { - let capturedOnError: - | ((err: unknown, info: { kind: "tool" | "block" | "final" }) => void) - | undefined; - const finalizeInboundContext = vi.fn((ctx: unknown) => ctx); - const { handler } = createMatrixHandlerTestHarness({ - historyLimit: 20, - groupPolicy: "open", - isDirectMessage: false, - finalizeInboundContext, - dispatchReplyFromConfig: async () => ({ - queuedFinal: true, - counts: { final: 1, block: 0, tool: 0 }, - }), - createReplyDispatcherWithTyping: (params?: { - onError?: (err: unknown, info: { kind: "tool" | "block" | "final" }) => void; - }) => { - capturedOnError = params?.onError; - return { - dispatcher: {}, - replyOptions: {}, - markDispatchIdle: () => {}, - markRunComplete: () => {}, - }; - }, - withReplyDispatcher: async (params: { - dispatcher: { markComplete?: () => void; waitForIdle?: () => Promise }; - run: () => Promise; - onSettled?: () => void | Promise; - }) => { - const result = await params.run(); - capturedOnError?.(new Error("simulated delivery failure"), { kind: "final" }); - params.dispatcher.markComplete?.(); - await params.dispatcher.waitForIdle?.(); - await params.onSettled?.(); - return result; - }, - }); + const { handler } = createFinalDeliveryFailureHandler(finalizeInboundContext); await handler( DEFAULT_ROOM, diff --git a/extensions/matrix/src/matrix/monitor/reaction-events.test.ts b/extensions/matrix/src/matrix/monitor/reaction-events.test.ts index 2c453b48ec7..27b9f046968 100644 --- a/extensions/matrix/src/matrix/monitor/reaction-events.test.ts +++ b/extensions/matrix/src/matrix/monitor/reaction-events.test.ts @@ -8,6 +8,10 @@ import type { CoreConfig } from "../../types.js"; import { handleInboundMatrixReaction } from "./reaction-events.js"; const resolveMatrixApproval = vi.fn(); +type MatrixReactionParams = Parameters[0]; +type MatrixReactionClient = MatrixReactionParams["client"]; +type MatrixReactionCore = MatrixReactionParams["core"]; +type MatrixReactionEvent = MatrixReactionParams["event"]; vi.mock("../../exec-approval-resolver.js", () => ({ isApprovalNotFoundError: (err: unknown) => @@ -56,49 +60,87 @@ function buildCore() { } as unknown as Parameters[0]["core"]; } +function createReactionClient( + getEvent: ReturnType = vi.fn(), +): MatrixReactionClient & { getEvent: ReturnType } { + return { getEvent } as unknown as MatrixReactionClient & { + getEvent: ReturnType; + }; +} + +function createReactionEvent( + params: { + eventId?: string; + targetEventId?: string; + reactionKey?: string; + } = {}, +): MatrixReactionEvent { + return { + event_id: params.eventId ?? "$reaction-1", + sender: "@owner:example.org", + type: "m.reaction", + origin_server_ts: 123, + content: { + "m.relates_to": { + rel_type: "m.annotation", + event_id: params.targetEventId ?? "$approval-msg", + key: params.reactionKey ?? "✅", + }, + }, + } as MatrixReactionEvent; +} + +async function handleReaction(params: { + client: MatrixReactionClient; + core: MatrixReactionCore; + cfg?: CoreConfig; + targetEventId?: string; + reactionKey?: string; +}): Promise { + await handleInboundMatrixReaction({ + client: params.client, + core: params.core, + cfg: params.cfg ?? buildConfig(), + accountId: "default", + roomId: "!ops:example.org", + event: createReactionEvent({ + targetEventId: params.targetEventId, + reactionKey: params.reactionKey, + }), + senderId: "@owner:example.org", + senderLabel: "Owner", + selfUserId: "@bot:example.org", + isDirectMessage: false, + logVerboseMessage: vi.fn(), + }); +} + describe("matrix approval reactions", () => { it("resolves approval reactions instead of enqueueing a generic reaction event", async () => { const core = buildCore(); + const cfg = buildConfig(); registerMatrixApprovalReactionTarget({ roomId: "!ops:example.org", eventId: "$approval-msg", approvalId: "req-123", allowedDecisions: ["allow-once", "allow-always", "deny"], }); - const client = { - getEvent: vi.fn().mockResolvedValue({ + const client = createReactionClient( + vi.fn().mockResolvedValue({ event_id: "$approval-msg", sender: "@bot:example.org", content: { body: "approval prompt" }, }), - } as unknown as Parameters[0]["client"]; + ); - await handleInboundMatrixReaction({ + await handleReaction({ client, core, - cfg: buildConfig(), - accountId: "default", - roomId: "!ops:example.org", - event: { - event_id: "$reaction-1", - origin_server_ts: 123, - content: { - "m.relates_to": { - rel_type: "m.annotation", - event_id: "$approval-msg", - key: "✅", - }, - }, - } as never, - senderId: "@owner:example.org", - senderLabel: "Owner", - selfUserId: "@bot:example.org", - isDirectMessage: false, - logVerboseMessage: vi.fn(), + cfg, }); expect(resolveMatrixApproval).toHaveBeenCalledWith({ - cfg: buildConfig(), + cfg, approvalId: "req-123", decision: "allow-once", senderId: "@owner:example.org", @@ -108,38 +150,21 @@ describe("matrix approval reactions", () => { it("keeps ordinary reactions on bot messages as generic reaction events", async () => { const core = buildCore(); - const client = { - getEvent: vi.fn().mockResolvedValue({ + const client = createReactionClient( + vi.fn().mockResolvedValue({ event_id: "$msg-1", sender: "@bot:example.org", content: { body: "normal bot message", }, }), - } as unknown as Parameters[0]["client"]; + ); - await handleInboundMatrixReaction({ + await handleReaction({ client, core, - cfg: buildConfig(), - accountId: "default", - roomId: "!ops:example.org", - event: { - event_id: "$reaction-1", - origin_server_ts: 123, - content: { - "m.relates_to": { - rel_type: "m.annotation", - event_id: "$msg-1", - key: "👍", - }, - }, - } as never, - senderId: "@owner:example.org", - senderLabel: "Owner", - selfUserId: "@bot:example.org", - isDirectMessage: false, - logVerboseMessage: vi.fn(), + targetEventId: "$msg-1", + reactionKey: "👍", }); expect(resolveMatrixApproval).not.toHaveBeenCalled(); @@ -165,36 +190,19 @@ describe("matrix approval reactions", () => { approvalId: "req-123", allowedDecisions: ["deny"], }); - const client = { - getEvent: vi.fn().mockResolvedValue({ + const client = createReactionClient( + vi.fn().mockResolvedValue({ event_id: "$approval-msg", sender: "@bot:example.org", content: { body: "approval prompt" }, }), - } as unknown as Parameters[0]["client"]; + ); - await handleInboundMatrixReaction({ + await handleReaction({ client, core, cfg, - accountId: "default", - roomId: "!ops:example.org", - event: { - event_id: "$reaction-1", - origin_server_ts: 123, - content: { - "m.relates_to": { - rel_type: "m.annotation", - event_id: "$approval-msg", - key: "❌", - }, - }, - } as never, - senderId: "@owner:example.org", - senderLabel: "Owner", - selfUserId: "@bot:example.org", - isDirectMessage: false, - logVerboseMessage: vi.fn(), + reactionKey: "❌", }); expect(resolveMatrixApproval).toHaveBeenCalledWith({ @@ -214,32 +222,11 @@ describe("matrix approval reactions", () => { approvalId: "req-123", allowedDecisions: ["allow-once"], }); - const client = { - getEvent: vi.fn().mockRejectedValue(new Error("boom")), - } as unknown as Parameters[0]["client"]; + const client = createReactionClient(vi.fn().mockRejectedValue(new Error("boom"))); - await handleInboundMatrixReaction({ + await handleReaction({ client, core, - cfg: buildConfig(), - accountId: "default", - roomId: "!ops:example.org", - event: { - event_id: "$reaction-1", - origin_server_ts: 123, - content: { - "m.relates_to": { - rel_type: "m.annotation", - event_id: "$approval-msg", - key: "✅", - }, - }, - } as never, - senderId: "@owner:example.org", - senderLabel: "Owner", - selfUserId: "@bot:example.org", - isDirectMessage: false, - logVerboseMessage: vi.fn(), }); expect(client.getEvent).not.toHaveBeenCalled(); @@ -266,32 +253,13 @@ describe("matrix approval reactions", () => { approvalId: "plugin:req-123", allowedDecisions: ["allow-once", "deny"], }); - const client = { - getEvent: vi.fn(), - } as unknown as Parameters[0]["client"]; + const client = createReactionClient(); - await handleInboundMatrixReaction({ + await handleReaction({ client, core, cfg, - accountId: "default", - roomId: "!ops:example.org", - event: { - event_id: "$reaction-1", - origin_server_ts: 123, - content: { - "m.relates_to": { - rel_type: "m.annotation", - event_id: "$plugin-approval-msg", - key: "✅", - }, - }, - } as never, - senderId: "@owner:example.org", - senderLabel: "Owner", - selfUserId: "@bot:example.org", - isDirectMessage: false, - logVerboseMessage: vi.fn(), + targetEventId: "$plugin-approval-msg", }); expect(client.getEvent).not.toHaveBeenCalled(); @@ -315,32 +283,12 @@ describe("matrix approval reactions", () => { approvalId: "req-123", allowedDecisions: ["deny"], }); - const client = { - getEvent: vi.fn(), - } as unknown as Parameters[0]["client"]; + const client = createReactionClient(); - await handleInboundMatrixReaction({ + await handleReaction({ client, core, - cfg: buildConfig(), - accountId: "default", - roomId: "!ops:example.org", - event: { - event_id: "$reaction-1", - origin_server_ts: 123, - content: { - "m.relates_to": { - rel_type: "m.annotation", - event_id: "$approval-msg", - key: "❌", - }, - }, - } as never, - senderId: "@owner:example.org", - senderLabel: "Owner", - selfUserId: "@bot:example.org", - isDirectMessage: false, - logVerboseMessage: vi.fn(), + reactionKey: "❌", }); expect(client.getEvent).not.toHaveBeenCalled(); @@ -361,32 +309,14 @@ describe("matrix approval reactions", () => { throw new Error("matrix config missing"); } matrixCfg.reactionNotifications = "off"; - const client = { - getEvent: vi.fn(), - } as unknown as Parameters[0]["client"]; + const client = createReactionClient(); - await handleInboundMatrixReaction({ + await handleReaction({ client, core, cfg, - accountId: "default", - roomId: "!ops:example.org", - event: { - event_id: "$reaction-1", - origin_server_ts: 123, - content: { - "m.relates_to": { - rel_type: "m.annotation", - event_id: "$msg-1", - key: "👍", - }, - }, - } as never, - senderId: "@owner:example.org", - senderLabel: "Owner", - selfUserId: "@bot:example.org", - isDirectMessage: false, - logVerboseMessage: vi.fn(), + targetEventId: "$msg-1", + reactionKey: "👍", }); expect(client.getEvent).not.toHaveBeenCalled(); diff --git a/extensions/matrix/src/matrix/monitor/room-info.test.ts b/extensions/matrix/src/matrix/monitor/room-info.test.ts index bbeb34469ea..c982de51775 100644 --- a/extensions/matrix/src/matrix/monitor/room-info.test.ts +++ b/extensions/matrix/src/matrix/monitor/room-info.test.ts @@ -2,32 +2,51 @@ import { describe, expect, it, vi } from "vitest"; import type { MatrixClient } from "../sdk.js"; import { createMatrixRoomInfoResolver } from "./room-info.js"; -function createClientStub() { +type RoomStateHandler = ( + roomId: string, + eventType: string, + stateKey: string, +) => Promise>; + +type RoomInfoClientStub = MatrixClient & { + getRoomStateEvent: ReturnType; +}; + +function createRoomStateClient(handler: RoomStateHandler): RoomInfoClientStub { return { - getRoomStateEvent: vi.fn( - async ( - roomId: string, - eventType: string, - stateKey: string, - ): Promise> => { - if (eventType === "m.room.name") { - return { name: `Room ${roomId}` }; - } - if (eventType === "m.room.canonical_alias") { - return { - alias: `#alias-${roomId}:example.org`, - alt_aliases: [`#alt-${roomId}:example.org`], - }; - } - if (eventType === "m.room.member") { - return { displayname: `Display ${roomId}:${stateKey}` }; - } - return {}; - }, - ), - } as unknown as MatrixClient & { - getRoomStateEvent: ReturnType; - }; + getRoomStateEvent: vi.fn(handler), + } as unknown as RoomInfoClientStub; +} + +function createClientStub() { + return createRoomStateClient(async (roomId, eventType, stateKey) => { + if (eventType === "m.room.name") { + return { name: `Room ${roomId}` }; + } + if (eventType === "m.room.canonical_alias") { + return { + alias: `#alias-${roomId}:example.org`, + alt_aliases: [`#alt-${roomId}:example.org`], + }; + } + if (eventType === "m.room.member") { + return { displayname: `Display ${roomId}:${stateKey}` }; + } + return {}; + }); +} + +function createMissingMetadataError() { + const err = new Error("M_NOT_FOUND"); + Object.assign(err, { + statusCode: 404, + body: { errcode: "M_NOT_FOUND" }, + }); + return err; +} + +function getRoomStateCallCount(client: RoomInfoClientStub, eventType: string) { + return client.getRoomStateEvent.mock.calls.filter(([, type]) => type === eventType).length; } describe("createMatrixRoomInfoResolver", () => { @@ -59,18 +78,7 @@ describe("createMatrixRoomInfoResolver", () => { }); it("caches fallback user IDs when member display names are missing", async () => { - const client = { - getRoomStateEvent: vi.fn( - async (_roomId: string, eventType: string): Promise> => { - if (eventType === "m.room.member") { - return {}; - } - return {}; - }, - ), - } as unknown as MatrixClient & { - getRoomStateEvent: ReturnType; - }; + const client = createRoomStateClient(async () => ({})); const resolver = createMatrixRoomInfoResolver(client); await expect( @@ -84,16 +92,12 @@ describe("createMatrixRoomInfoResolver", () => { }); it("marks unresolved room metadata when room info lookups fail", async () => { - const client = { - getRoomStateEvent: vi.fn(async (_roomId: string, eventType: string) => { - if (eventType === "m.room.member") { - return {}; - } - throw new Error("room info unavailable"); - }), - } as unknown as MatrixClient & { - getRoomStateEvent: ReturnType; - }; + const client = createRoomStateClient(async (_roomId, eventType) => { + if (eventType === "m.room.member") { + return {}; + } + throw new Error("room info unavailable"); + }); const resolver = createMatrixRoomInfoResolver(client); await expect( @@ -106,21 +110,12 @@ describe("createMatrixRoomInfoResolver", () => { }); it("treats missing room metadata as resolved-empty state", async () => { - const client = { - getRoomStateEvent: vi.fn(async (_roomId: string, eventType: string) => { - if (eventType === "m.room.name" || eventType === "m.room.canonical_alias") { - const err = new Error("M_NOT_FOUND"); - Object.assign(err, { - statusCode: 404, - body: { errcode: "M_NOT_FOUND" }, - }); - throw err; - } - return {}; - }), - } as unknown as MatrixClient & { - getRoomStateEvent: ReturnType; - }; + const client = createRoomStateClient(async (_roomId, eventType) => { + if (eventType === "m.room.name" || eventType === "m.room.canonical_alias") { + throw createMissingMetadataError(); + } + return {}; + }); const resolver = createMatrixRoomInfoResolver(client); await expect( @@ -133,34 +128,24 @@ describe("createMatrixRoomInfoResolver", () => { }); it("retries room metadata after a transient lookup failure", async () => { - const client = { - getRoomStateEvent: vi.fn(async (_roomId: string, eventType: string) => { - if (eventType === "m.room.name") { - if ( - client.getRoomStateEvent.mock.calls.filter(([, type]) => type === eventType).length === - 1 - ) { - throw new Error("name lookup unavailable"); - } - return { name: "Recovered Room" }; + const client = createRoomStateClient(async (_roomId, eventType) => { + if (eventType === "m.room.name") { + if (getRoomStateCallCount(client, eventType) === 1) { + throw new Error("name lookup unavailable"); } - if (eventType === "m.room.canonical_alias") { - if ( - client.getRoomStateEvent.mock.calls.filter(([, type]) => type === eventType).length === - 1 - ) { - throw new Error("alias lookup unavailable"); - } - return { - alias: "#recovered:example.org", - alt_aliases: ["#alt-recovered:example.org"], - }; + return { name: "Recovered Room" }; + } + if (eventType === "m.room.canonical_alias") { + if (getRoomStateCallCount(client, eventType) === 1) { + throw new Error("alias lookup unavailable"); } - return {}; - }), - } as unknown as MatrixClient & { - getRoomStateEvent: ReturnType; - }; + return { + alias: "#recovered:example.org", + alt_aliases: ["#alt-recovered:example.org"], + }; + } + return {}; + }); const resolver = createMatrixRoomInfoResolver(client); await expect( @@ -182,13 +167,9 @@ describe("createMatrixRoomInfoResolver", () => { }); it("caches fallback user IDs when member display-name lookups fail", async () => { - const client = { - getRoomStateEvent: vi.fn(async (): Promise> => { - throw new Error("member lookup failed"); - }), - } as unknown as MatrixClient & { - getRoomStateEvent: ReturnType; - }; + const client = createRoomStateClient(async () => { + throw new Error("member lookup failed"); + }); const resolver = createMatrixRoomInfoResolver(client); await expect( diff --git a/extensions/matrix/src/matrix/monitor/route.test.ts b/extensions/matrix/src/matrix/monitor/route.test.ts index 1eefd21f960..cfe0e5a0be4 100644 --- a/extensions/matrix/src/matrix/monitor/route.test.ts +++ b/extensions/matrix/src/matrix/monitor/route.test.ts @@ -17,6 +17,33 @@ const baseCfg = { }, } satisfies OpenClawConfig; +type RouteBinding = NonNullable[number]; +type RoutePeer = NonNullable; + +function matrixBinding( + agentId: string, + peer?: RoutePeer, + type?: RouteBinding["type"], +): RouteBinding { + return { + ...(type ? { type } : {}), + agentId, + match: { + channel: "matrix", + accountId: "ops", + ...(peer ? { peer } : {}), + }, + } as RouteBinding; +} + +function senderPeer(id = "@alice:example.org"): RoutePeer { + return { kind: "direct", id }; +} + +function dmRoomPeer(id = "!dm:example.org"): RoutePeer { + return { kind: "channel", id }; +} + function resolveDmRoute( cfg: OpenClawConfig, opts: { @@ -46,22 +73,8 @@ describe("resolveMatrixInboundRoute", () => { const cfg = { ...baseCfg, bindings: [ - { - agentId: "room-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "channel", id: "!dm:example.org" }, - }, - }, - { - agentId: "sender-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "direct", id: "@alice:example.org" }, - }, - }, + matrixBinding("room-agent", dmRoomPeer()), + matrixBinding("sender-agent", senderPeer()), ], } satisfies OpenClawConfig; @@ -76,23 +89,7 @@ describe("resolveMatrixInboundRoute", () => { it("uses the DM room as a parent-peer fallback before account-level bindings", () => { const cfg = { ...baseCfg, - bindings: [ - { - agentId: "acp-agent", - match: { - channel: "matrix", - accountId: "ops", - }, - }, - { - agentId: "room-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "channel", id: "!dm:example.org" }, - }, - }, - ], + bindings: [matrixBinding("acp-agent"), matrixBinding("room-agent", dmRoomPeer())], } satisfies OpenClawConfig; const { route, configuredBinding } = resolveDmRoute(cfg); @@ -106,16 +103,7 @@ describe("resolveMatrixInboundRoute", () => { it("can isolate Matrix DMs per room without changing agent selection", () => { const cfg = { ...baseCfg, - bindings: [ - { - agentId: "sender-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "direct", id: "@alice:example.org" }, - }, - }, - ], + bindings: [matrixBinding("sender-agent", senderPeer())], } satisfies OpenClawConfig; const { route, configuredBinding } = resolveDmRoute(cfg, { @@ -134,23 +122,8 @@ describe("resolveMatrixInboundRoute", () => { const cfg = { ...baseCfg, bindings: [ - { - agentId: "room-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "channel", id: "!dm:example.org" }, - }, - }, - { - type: "acp", - agentId: "acp-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "channel", id: "!dm:example.org" }, - }, - }, + matrixBinding("room-agent", dmRoomPeer()), + matrixBinding("acp-agent", dmRoomPeer(), "acp"), ], } satisfies OpenClawConfig; @@ -167,23 +140,8 @@ describe("resolveMatrixInboundRoute", () => { const cfg = { ...baseCfg, bindings: [ - { - agentId: "room-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "channel", id: "!dm:example.org" }, - }, - }, - { - type: "acp", - agentId: "acp-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "channel", id: "!dm:example.org" }, - }, - }, + matrixBinding("room-agent", dmRoomPeer()), + matrixBinding("acp-agent", dmRoomPeer(), "acp"), ], } satisfies OpenClawConfig; @@ -227,22 +185,8 @@ describe("resolveMatrixInboundRoute", () => { const cfg = { ...baseCfg, bindings: [ - { - agentId: "sender-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "direct", id: "@alice:example.org" }, - }, - }, - { - agentId: "room-agent", - match: { - channel: "matrix", - accountId: "ops", - peer: { kind: "channel", id: "!dm:example.org" }, - }, - }, + matrixBinding("sender-agent", senderPeer()), + matrixBinding("room-agent", dmRoomPeer()), ], } satisfies OpenClawConfig; diff --git a/extensions/matrix/src/matrix/thread-bindings.test.ts b/extensions/matrix/src/matrix/thread-bindings.test.ts index 2c5d4c5ab7d..c8dc6c9d077 100644 --- a/extensions/matrix/src/matrix/thread-bindings.test.ts +++ b/extensions/matrix/src/matrix/thread-bindings.test.ts @@ -11,6 +11,7 @@ import { resolveMatrixStoragePaths, writeStorageMeta, } from "./client/storage.js"; +import type { MatrixAuth, MatrixStoragePaths } from "./client/types.js"; import { createMatrixThreadBindingManager, resetMatrixThreadBindingsForTests, @@ -45,6 +46,11 @@ describe("matrix thread bindings", () => { const idleTimeoutMs = 24 * 60 * 60 * 1000; const matrixClient = {} as never; + function resetThreadBindingAdapters() { + __testing.resetSessionBindingAdaptersForTests(); + resetMatrixThreadBindingsForTests(); + } + function currentThreadConversation(params?: { conversationId?: string; parentConversationId?: string; @@ -57,17 +63,32 @@ describe("matrix thread bindings", () => { }; } - async function createStaticThreadBindingManager() { + function createBindingManager( + params: { + auth?: MatrixAuth; + stateDir?: string; + idleTimeoutMs?: number; + maxAgeMs?: number; + enableSweeper?: boolean; + logVerboseMessage?: (message: string) => void; + } = {}, + ) { return createMatrixThreadBindingManager({ accountId, - auth, + auth: params.auth ?? auth, client: matrixClient, - idleTimeoutMs, - maxAgeMs: 0, - enableSweeper: false, + ...(params.stateDir ? { stateDir: params.stateDir } : {}), + idleTimeoutMs: params.idleTimeoutMs ?? idleTimeoutMs, + maxAgeMs: params.maxAgeMs ?? 0, + enableSweeper: params.enableSweeper ?? false, + ...(params.logVerboseMessage ? { logVerboseMessage: params.logVerboseMessage } : {}), }); } + async function createStaticThreadBindingManager() { + return createBindingManager(); + } + async function bindCurrentThread(params?: { targetSessionKey?: string; conversationId?: string; @@ -95,6 +116,16 @@ describe("matrix thread bindings", () => { }); } + function writeAuthStorageMeta(authForMeta: MatrixAuth, storagePaths: MatrixStoragePaths) { + writeStorageMeta({ + storagePaths, + homeserver: authForMeta.homeserver, + userId: authForMeta.userId, + accountId: authForMeta.accountId, + deviceId: authForMeta.deviceId ?? null, + }); + } + async function readPersistedLastActivityAt(bindingsPath: string) { const raw = await fs.readFile(bindingsPath, "utf-8"); const parsed = JSON.parse(raw) as { @@ -103,10 +134,32 @@ describe("matrix thread bindings", () => { return parsed.bindings?.[0]?.lastActivityAt; } + async function expectPersistedThreadBinding( + bindingsPath: string, + expected: { + conversationId: string; + targetSessionKey: string; + parentConversationId?: string; + }, + ) { + await vi.waitFor(async () => { + const persistedRaw = await fs.readFile(bindingsPath, "utf-8"); + expect(JSON.parse(persistedRaw)).toMatchObject({ + version: 1, + bindings: [ + expect.objectContaining({ + conversationId: expected.conversationId, + parentConversationId: expected.parentConversationId ?? "!room:example", + targetSessionKey: expected.targetSessionKey, + }), + ], + }); + }); + } + beforeEach(() => { stateDir = fsSync.mkdtempSync(path.join(os.tmpdir(), "matrix-thread-bindings-")); - __testing.resetSessionBindingAdaptersForTests(); - resetMatrixThreadBindingsForTests(); + resetThreadBindingAdapters(); sendMessageMatrixMock.mockClear(); renameMock.mockReset(); renameMock.mockImplementation(actualRename); @@ -384,50 +437,19 @@ describe("matrix thread bindings", () => { accessToken: "token-new", }; - const initialManager = await createMatrixThreadBindingManager({ - accountId: "ops", - auth: initialAuth, - client: {} as never, - idleTimeoutMs: 24 * 60 * 60 * 1000, - maxAgeMs: 0, - enableSweeper: false, - }); + const initialManager = await createBindingManager({ auth: initialAuth }); - await getSessionBindingService().bind({ - targetSessionKey: "agent:ops:subagent:child", - targetKind: "subagent", - conversation: { - channel: "matrix", - accountId: "ops", - conversationId: "$thread", - parentConversationId: "!room:example", - }, - placement: "current", - }); + await bindCurrentThread(); const initialStoragePaths = resolveMatrixStoragePaths({ ...initialAuth, env: process.env, }); - writeStorageMeta({ - storagePaths: initialStoragePaths, - homeserver: initialAuth.homeserver, - userId: initialAuth.userId, - accountId: initialAuth.accountId, - deviceId: null, - }); + writeAuthStorageMeta(initialAuth, initialStoragePaths); initialManager.stop(); - resetMatrixThreadBindingsForTests(); - __testing.resetSessionBindingAdaptersForTests(); + resetThreadBindingAdapters(); - await createMatrixThreadBindingManager({ - accountId: "ops", - auth: rotatedAuth, - client: {} as never, - idleTimeoutMs: 24 * 60 * 60 * 1000, - maxAgeMs: 0, - enableSweeper: false, - }); + await createBindingManager({ auth: rotatedAuth }); expect( getSessionBindingService().resolveByConversation({ @@ -461,64 +483,24 @@ describe("matrix thread bindings", () => { deviceId: "DEVICE123", }; - const initialManager = await createMatrixThreadBindingManager({ - accountId: "ops", - auth: initialAuth, - client: {} as never, - idleTimeoutMs: 24 * 60 * 60 * 1000, - maxAgeMs: 0, - enableSweeper: false, - }); + const initialManager = await createBindingManager({ auth: initialAuth }); - await getSessionBindingService().bind({ - targetSessionKey: "agent:ops:subagent:child", - targetKind: "subagent", - conversation: { - channel: "matrix", - accountId: "ops", - conversationId: "$thread", - parentConversationId: "!room:example", - }, - placement: "current", - }); + await bindCurrentThread(); const initialStoragePaths = resolveMatrixStoragePaths({ ...initialAuth, env: process.env, }); - writeStorageMeta({ - storagePaths: initialStoragePaths, - homeserver: initialAuth.homeserver, - userId: initialAuth.userId, - accountId: initialAuth.accountId, - deviceId: initialAuth.deviceId, - }); + writeAuthStorageMeta(initialAuth, initialStoragePaths); const initialBindingsPath = path.join(initialStoragePaths.rootDir, "thread-bindings.json"); - await vi.waitFor(async () => { - const persistedRaw = await fs.readFile(initialBindingsPath, "utf-8"); - expect(JSON.parse(persistedRaw)).toMatchObject({ - version: 1, - bindings: [ - expect.objectContaining({ - conversationId: "$thread", - parentConversationId: "!room:example", - targetSessionKey: "agent:ops:subagent:child", - }), - ], - }); + await expectPersistedThreadBinding(initialBindingsPath, { + conversationId: "$thread", + targetSessionKey: "agent:ops:subagent:child", }); initialManager.stop(); - resetMatrixThreadBindingsForTests(); - __testing.resetSessionBindingAdaptersForTests(); + resetThreadBindingAdapters(); - await createMatrixThreadBindingManager({ - accountId: "ops", - auth: rotatedAuth, - client: {} as never, - idleTimeoutMs: 24 * 60 * 60 * 1000, - maxAgeMs: 0, - enableSweeper: false, - }); + await createBindingManager({ auth: rotatedAuth }); expect( getSessionBindingService().resolveByConversation({ @@ -547,36 +529,14 @@ describe("matrix thread bindings", () => { path.join(os.tmpdir(), "matrix-thread-bindings-replacement-"), ); - const initialManager = await createMatrixThreadBindingManager({ - accountId: "ops", - auth, - client: {} as never, + const initialManager = await createBindingManager({ stateDir: initialStateDir, - idleTimeoutMs: 24 * 60 * 60 * 1000, - maxAgeMs: 0, - enableSweeper: false, }); - await getSessionBindingService().bind({ - targetSessionKey: "agent:ops:subagent:child", - targetKind: "subagent", - conversation: { - channel: "matrix", - accountId: "ops", - conversationId: "$thread", - parentConversationId: "!room:example", - }, - placement: "current", - }); + await bindCurrentThread(); - const replacementManager = await createMatrixThreadBindingManager({ - accountId: "ops", - auth, - client: {} as never, + const replacementManager = await createBindingManager({ stateDir: replacementStateDir, - idleTimeoutMs: 24 * 60 * 60 * 1000, - maxAgeMs: 0, - enableSweeper: false, }); expect(replacementManager).not.toBe(initialManager); @@ -590,46 +550,18 @@ describe("matrix thread bindings", () => { }), ).toBeNull(); - await getSessionBindingService().bind({ + await bindCurrentThread({ targetSessionKey: "agent:ops:subagent:replacement", - targetKind: "subagent", - conversation: { - channel: "matrix", - accountId: "ops", - conversationId: "$thread-2", - parentConversationId: "!room:example", - }, - placement: "current", + conversationId: "$thread-2", }); - await vi.waitFor(async () => { - const replacementRaw = await fs.readFile( - resolveBindingsFilePath(replacementStateDir), - "utf-8", - ); - expect(JSON.parse(replacementRaw)).toMatchObject({ - version: 1, - bindings: [ - expect.objectContaining({ - conversationId: "$thread-2", - parentConversationId: "!room:example", - targetSessionKey: "agent:ops:subagent:replacement", - }), - ], - }); + await expectPersistedThreadBinding(resolveBindingsFilePath(replacementStateDir), { + conversationId: "$thread-2", + targetSessionKey: "agent:ops:subagent:replacement", }); - await vi.waitFor(async () => { - const initialRaw = await fs.readFile(resolveBindingsFilePath(initialStateDir), "utf-8"); - expect(JSON.parse(initialRaw)).toMatchObject({ - version: 1, - bindings: [ - expect.objectContaining({ - conversationId: "$thread", - parentConversationId: "!room:example", - targetSessionKey: "agent:ops:subagent:child", - }), - ], - }); + await expectPersistedThreadBinding(resolveBindingsFilePath(initialStateDir), { + conversationId: "$thread", + targetSessionKey: "agent:ops:subagent:child", }); }); diff --git a/extensions/matrix/src/session-route.test.ts b/extensions/matrix/src/session-route.test.ts index efd649ea97a..abf2907d225 100644 --- a/extensions/matrix/src/session-route.test.ts +++ b/extensions/matrix/src/session-route.test.ts @@ -6,6 +6,25 @@ import type { OpenClawConfig } from "./runtime-api.js"; import { resolveMatrixOutboundSessionRoute } from "./session-route.js"; const tempDirs = new Set(); +const currentDmSessionKey = "agent:main:matrix:channel:!dm:example.org"; +type MatrixChannelConfig = NonNullable["matrix"]>; + +const perRoomDmMatrixConfig = { + dm: { + sessionScope: "per-room", + }, +} satisfies MatrixChannelConfig; + +const defaultAccountPerRoomDmMatrixConfig = { + defaultAccount: "ops", + accounts: { + ops: { + dm: { + sessionScope: "per-room", + }, + }, + }, +} satisfies MatrixChannelConfig; function createTempStore(entries: Record): string { const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), "matrix-session-route-")); @@ -15,6 +34,98 @@ function createTempStore(entries: Record): string { return storePath; } +function createMatrixRouteConfig( + entries: Record, + matrix: MatrixChannelConfig = perRoomDmMatrixConfig, +): OpenClawConfig { + return { + session: { + store: createTempStore(entries), + }, + channels: { + matrix, + }, + } satisfies OpenClawConfig; +} + +function createStoredDirectDmSession( + params: { + from?: string; + to?: string; + accountId?: string | null; + nativeChannelId?: string; + nativeDirectUserId?: string; + lastTo?: string; + lastAccountId?: string; + } = {}, +): Record { + const accountId = params.accountId === null ? undefined : (params.accountId ?? "ops"); + const to = params.to ?? "room:!dm:example.org"; + const accountMetadata = accountId ? { accountId } : {}; + const nativeMetadata = { + ...(params.nativeChannelId ? { nativeChannelId: params.nativeChannelId } : {}), + ...(params.nativeDirectUserId ? { nativeDirectUserId: params.nativeDirectUserId } : {}), + }; + return { + sessionId: "sess-1", + updatedAt: Date.now(), + chatType: "direct", + origin: { + chatType: "direct", + from: params.from ?? "matrix:@alice:example.org", + to, + ...nativeMetadata, + ...accountMetadata, + }, + deliveryContext: { + channel: "matrix", + to, + ...accountMetadata, + }, + ...(params.lastTo ? { lastTo: params.lastTo } : {}), + ...(params.lastAccountId ? { lastAccountId: params.lastAccountId } : {}), + }; +} + +function createStoredChannelSession(): Record { + return { + sessionId: "sess-1", + updatedAt: Date.now(), + chatType: "channel", + origin: { + chatType: "channel", + from: "matrix:channel:!ops:example.org", + to: "room:!ops:example.org", + nativeChannelId: "!ops:example.org", + nativeDirectUserId: "@alice:example.org", + accountId: "ops", + }, + deliveryContext: { + channel: "matrix", + to: "room:!ops:example.org", + accountId: "ops", + }, + lastTo: "room:!ops:example.org", + lastAccountId: "ops", + }; +} + +function resolveUserRoute(params: { cfg: OpenClawConfig; accountId?: string; target?: string }) { + const target = params.target ?? "@alice:example.org"; + return resolveMatrixOutboundSessionRoute({ + cfg: params.cfg, + agentId: "main", + ...(params.accountId ? { accountId: params.accountId } : {}), + currentSessionKey: currentDmSessionKey, + target, + resolvedTarget: { + to: target, + kind: "user", + source: "normalized", + }, + }); +} + afterEach(() => { for (const tempDir of tempDirs) { fs.rmSync(tempDir, { recursive: true, force: true }); @@ -24,53 +135,18 @@ afterEach(() => { describe("resolveMatrixOutboundSessionRoute", () => { it("reuses the current DM room session for same-user sends when Matrix DMs are per-room", () => { - const storePath = createTempStore({ - "agent:main:matrix:channel:!dm:example.org": { - sessionId: "sess-1", - updatedAt: Date.now(), - chatType: "direct", - origin: { - chatType: "direct", - from: "matrix:@alice:example.org", - to: "room:!dm:example.org", - accountId: "ops", - }, - deliveryContext: { - channel: "matrix", - to: "room:!dm:example.org", - accountId: "ops", - }, - }, + const cfg = createMatrixRouteConfig({ + [currentDmSessionKey]: createStoredDirectDmSession(), }); - const cfg = { - session: { - store: storePath, - }, - channels: { - matrix: { - dm: { - sessionScope: "per-room", - }, - }, - }, - } satisfies OpenClawConfig; - const route = resolveMatrixOutboundSessionRoute({ + const route = resolveUserRoute({ cfg, - agentId: "main", accountId: "ops", - currentSessionKey: "agent:main:matrix:channel:!dm:example.org", - target: "@alice:example.org", - resolvedTarget: { - to: "@alice:example.org", - kind: "user", - source: "normalized", - }, }); expect(route).toMatchObject({ - sessionKey: "agent:main:matrix:channel:!dm:example.org", - baseSessionKey: "agent:main:matrix:channel:!dm:example.org", + sessionKey: currentDmSessionKey, + baseSessionKey: currentDmSessionKey, peer: { kind: "channel", id: "!dm:example.org" }, chatType: "direct", from: "matrix:@alice:example.org", @@ -79,48 +155,13 @@ describe("resolveMatrixOutboundSessionRoute", () => { }); it("falls back to user-scoped routing when the current session is for another DM peer", () => { - const storePath = createTempStore({ - "agent:main:matrix:channel:!dm:example.org": { - sessionId: "sess-1", - updatedAt: Date.now(), - chatType: "direct", - origin: { - chatType: "direct", - from: "matrix:@bob:example.org", - to: "room:!dm:example.org", - accountId: "ops", - }, - deliveryContext: { - channel: "matrix", - to: "room:!dm:example.org", - accountId: "ops", - }, - }, + const cfg = createMatrixRouteConfig({ + [currentDmSessionKey]: createStoredDirectDmSession({ from: "matrix:@bob:example.org" }), }); - const cfg = { - session: { - store: storePath, - }, - channels: { - matrix: { - dm: { - sessionScope: "per-room", - }, - }, - }, - } satisfies OpenClawConfig; - const route = resolveMatrixOutboundSessionRoute({ + const route = resolveUserRoute({ cfg, - agentId: "main", accountId: "ops", - currentSessionKey: "agent:main:matrix:channel:!dm:example.org", - target: "@alice:example.org", - resolvedTarget: { - to: "@alice:example.org", - kind: "user", - source: "normalized", - }, }); expect(route).toMatchObject({ @@ -134,48 +175,13 @@ describe("resolveMatrixOutboundSessionRoute", () => { }); it("falls back to user-scoped routing when the current session belongs to another Matrix account", () => { - const storePath = createTempStore({ - "agent:main:matrix:channel:!dm:example.org": { - sessionId: "sess-1", - updatedAt: Date.now(), - chatType: "direct", - origin: { - chatType: "direct", - from: "matrix:@alice:example.org", - to: "room:!dm:example.org", - accountId: "ops", - }, - deliveryContext: { - channel: "matrix", - to: "room:!dm:example.org", - accountId: "ops", - }, - }, + const cfg = createMatrixRouteConfig({ + [currentDmSessionKey]: createStoredDirectDmSession(), }); - const cfg = { - session: { - store: storePath, - }, - channels: { - matrix: { - dm: { - sessionScope: "per-room", - }, - }, - }, - } satisfies OpenClawConfig; - const route = resolveMatrixOutboundSessionRoute({ + const route = resolveUserRoute({ cfg, - agentId: "main", accountId: "support", - currentSessionKey: "agent:main:matrix:channel:!dm:example.org", - target: "@alice:example.org", - resolvedTarget: { - to: "@alice:example.org", - kind: "user", - source: "normalized", - }, }); expect(route).toMatchObject({ @@ -189,57 +195,25 @@ describe("resolveMatrixOutboundSessionRoute", () => { }); it("reuses the canonical DM room after user-target outbound metadata overwrites latest to fields", () => { - const storePath = createTempStore({ - "agent:main:matrix:channel:!dm:example.org": { - sessionId: "sess-1", - updatedAt: Date.now(), - chatType: "direct", - origin: { - chatType: "direct", - from: "matrix:@bob:example.org", - to: "room:@bob:example.org", - nativeChannelId: "!dm:example.org", - nativeDirectUserId: "@alice:example.org", - accountId: "ops", - }, - deliveryContext: { - channel: "matrix", - to: "room:@bob:example.org", - accountId: "ops", - }, + const cfg = createMatrixRouteConfig({ + [currentDmSessionKey]: createStoredDirectDmSession({ + from: "matrix:@bob:example.org", + to: "room:@bob:example.org", + nativeChannelId: "!dm:example.org", + nativeDirectUserId: "@alice:example.org", lastTo: "room:@bob:example.org", lastAccountId: "ops", - }, + }), }); - const cfg = { - session: { - store: storePath, - }, - channels: { - matrix: { - dm: { - sessionScope: "per-room", - }, - }, - }, - } satisfies OpenClawConfig; - const route = resolveMatrixOutboundSessionRoute({ + const route = resolveUserRoute({ cfg, - agentId: "main", accountId: "ops", - currentSessionKey: "agent:main:matrix:channel:!dm:example.org", - target: "@alice:example.org", - resolvedTarget: { - to: "@alice:example.org", - kind: "user", - source: "normalized", - }, }); expect(route).toMatchObject({ - sessionKey: "agent:main:matrix:channel:!dm:example.org", - baseSessionKey: "agent:main:matrix:channel:!dm:example.org", + sessionKey: currentDmSessionKey, + baseSessionKey: currentDmSessionKey, peer: { kind: "channel", id: "!dm:example.org" }, chatType: "direct", from: "matrix:@alice:example.org", @@ -248,52 +222,21 @@ describe("resolveMatrixOutboundSessionRoute", () => { }); it("does not reuse the canonical DM room for a different Matrix user after latest metadata drift", () => { - const storePath = createTempStore({ - "agent:main:matrix:channel:!dm:example.org": { - sessionId: "sess-1", - updatedAt: Date.now(), - chatType: "direct", - origin: { - chatType: "direct", - from: "matrix:@bob:example.org", - to: "room:@bob:example.org", - nativeChannelId: "!dm:example.org", - nativeDirectUserId: "@alice:example.org", - accountId: "ops", - }, - deliveryContext: { - channel: "matrix", - to: "room:@bob:example.org", - accountId: "ops", - }, + const cfg = createMatrixRouteConfig({ + [currentDmSessionKey]: createStoredDirectDmSession({ + from: "matrix:@bob:example.org", + to: "room:@bob:example.org", + nativeChannelId: "!dm:example.org", + nativeDirectUserId: "@alice:example.org", lastTo: "room:@bob:example.org", lastAccountId: "ops", - }, + }), }); - const cfg = { - session: { - store: storePath, - }, - channels: { - matrix: { - dm: { - sessionScope: "per-room", - }, - }, - }, - } satisfies OpenClawConfig; - const route = resolveMatrixOutboundSessionRoute({ + const route = resolveUserRoute({ cfg, - agentId: "main", accountId: "ops", - currentSessionKey: "agent:main:matrix:channel:!dm:example.org", target: "@bob:example.org", - resolvedTarget: { - to: "@bob:example.org", - kind: "user", - source: "normalized", - }, }); expect(route).toMatchObject({ @@ -307,52 +250,13 @@ describe("resolveMatrixOutboundSessionRoute", () => { }); it("does not reuse a room after the session metadata was overwritten by a non-DM Matrix send", () => { - const storePath = createTempStore({ - "agent:main:matrix:channel:!dm:example.org": { - sessionId: "sess-1", - updatedAt: Date.now(), - chatType: "channel", - origin: { - chatType: "channel", - from: "matrix:channel:!ops:example.org", - to: "room:!ops:example.org", - nativeChannelId: "!ops:example.org", - nativeDirectUserId: "@alice:example.org", - accountId: "ops", - }, - deliveryContext: { - channel: "matrix", - to: "room:!ops:example.org", - accountId: "ops", - }, - lastTo: "room:!ops:example.org", - lastAccountId: "ops", - }, + const cfg = createMatrixRouteConfig({ + [currentDmSessionKey]: createStoredChannelSession(), }); - const cfg = { - session: { - store: storePath, - }, - channels: { - matrix: { - dm: { - sessionScope: "per-room", - }, - }, - }, - } satisfies OpenClawConfig; - const route = resolveMatrixOutboundSessionRoute({ + const route = resolveUserRoute({ cfg, - agentId: "main", accountId: "ops", - currentSessionKey: "agent:main:matrix:channel:!dm:example.org", - target: "@alice:example.org", - resolvedTarget: { - to: "@alice:example.org", - kind: "user", - source: "normalized", - }, }); expect(route).toMatchObject({ @@ -366,57 +270,20 @@ describe("resolveMatrixOutboundSessionRoute", () => { }); it("uses the effective default Matrix account when accountId is omitted", () => { - const storePath = createTempStore({ - "agent:main:matrix:channel:!dm:example.org": { - sessionId: "sess-1", - updatedAt: Date.now(), - chatType: "direct", - origin: { - chatType: "direct", - from: "matrix:@alice:example.org", - to: "room:!dm:example.org", - accountId: "ops", - }, - deliveryContext: { - channel: "matrix", - to: "room:!dm:example.org", - accountId: "ops", - }, + const cfg = createMatrixRouteConfig( + { + [currentDmSessionKey]: createStoredDirectDmSession(), }, - }); - const cfg = { - session: { - store: storePath, - }, - channels: { - matrix: { - defaultAccount: "ops", - accounts: { - ops: { - dm: { - sessionScope: "per-room", - }, - }, - }, - }, - }, - } satisfies OpenClawConfig; + defaultAccountPerRoomDmMatrixConfig, + ); - const route = resolveMatrixOutboundSessionRoute({ + const route = resolveUserRoute({ cfg, - agentId: "main", - currentSessionKey: "agent:main:matrix:channel:!dm:example.org", - target: "@alice:example.org", - resolvedTarget: { - to: "@alice:example.org", - kind: "user", - source: "normalized", - }, }); expect(route).toMatchObject({ - sessionKey: "agent:main:matrix:channel:!dm:example.org", - baseSessionKey: "agent:main:matrix:channel:!dm:example.org", + sessionKey: currentDmSessionKey, + baseSessionKey: currentDmSessionKey, peer: { kind: "channel", id: "!dm:example.org" }, chatType: "direct", from: "matrix:@alice:example.org", @@ -425,55 +292,20 @@ describe("resolveMatrixOutboundSessionRoute", () => { }); it("reuses the current DM room when stored account metadata is missing", () => { - const storePath = createTempStore({ - "agent:main:matrix:channel:!dm:example.org": { - sessionId: "sess-1", - updatedAt: Date.now(), - chatType: "direct", - origin: { - chatType: "direct", - from: "matrix:@alice:example.org", - to: "room:!dm:example.org", - }, - deliveryContext: { - channel: "matrix", - to: "room:!dm:example.org", - }, + const cfg = createMatrixRouteConfig( + { + [currentDmSessionKey]: createStoredDirectDmSession({ accountId: null }), }, - }); - const cfg = { - session: { - store: storePath, - }, - channels: { - matrix: { - defaultAccount: "ops", - accounts: { - ops: { - dm: { - sessionScope: "per-room", - }, - }, - }, - }, - }, - } satisfies OpenClawConfig; + defaultAccountPerRoomDmMatrixConfig, + ); - const route = resolveMatrixOutboundSessionRoute({ + const route = resolveUserRoute({ cfg, - agentId: "main", - currentSessionKey: "agent:main:matrix:channel:!dm:example.org", - target: "@alice:example.org", - resolvedTarget: { - to: "@alice:example.org", - kind: "user", - source: "normalized", - }, }); expect(route).toMatchObject({ - sessionKey: "agent:main:matrix:channel:!dm:example.org", - baseSessionKey: "agent:main:matrix:channel:!dm:example.org", + sessionKey: currentDmSessionKey, + baseSessionKey: currentDmSessionKey, peer: { kind: "channel", id: "!dm:example.org" }, chatType: "direct", from: "matrix:@alice:example.org", diff --git a/extensions/matrix/src/setup-core.test.ts b/extensions/matrix/src/setup-core.test.ts index e9c63fb0a7a..bbe11dbcf19 100644 --- a/extensions/matrix/src/setup-core.test.ts +++ b/extensions/matrix/src/setup-core.test.ts @@ -2,6 +2,39 @@ import { describe, expect, it } from "vitest"; import { matrixSetupAdapter } from "./setup-core.js"; import type { CoreConfig } from "./types.js"; +function applyOpsAccountConfig(cfg: CoreConfig): CoreConfig { + return matrixSetupAdapter.applyAccountConfig({ + cfg, + accountId: "ops", + input: { + name: "Ops", + homeserver: "https://matrix.example.org", + accessToken: "ops-token", + }, + }) as CoreConfig; +} + +function expectPromotedDefaultAccount(next: CoreConfig): void { + expect(next.channels?.matrix?.accounts?.Default).toMatchObject({ + enabled: true, + deviceName: "Legacy raw key", + homeserver: "https://matrix.example.org", + userId: "@default:example.org", + accessToken: "default-token", + avatarUrl: "mxc://example.org/default-avatar", + }); + expect(next.channels?.matrix?.accounts?.default).toBeUndefined(); +} + +function expectOpsAccount(next: CoreConfig): void { + expect(next.channels?.matrix?.accounts?.ops).toMatchObject({ + name: "Ops", + enabled: true, + homeserver: "https://matrix.example.org", + accessToken: "ops-token", + }); +} + describe("matrixSetupAdapter", () => { it("moves legacy default config before writing a named account", () => { const cfg = { @@ -63,31 +96,10 @@ describe("matrixSetupAdapter", () => { }, } as CoreConfig; - const next = matrixSetupAdapter.applyAccountConfig({ - cfg, - accountId: "ops", - input: { - name: "Ops", - homeserver: "https://matrix.example.org", - accessToken: "ops-token", - }, - }) as CoreConfig; + const next = applyOpsAccountConfig(cfg); - expect(next.channels?.matrix?.accounts?.Default).toMatchObject({ - enabled: true, - deviceName: "Legacy raw key", - homeserver: "https://matrix.example.org", - userId: "@default:example.org", - accessToken: "default-token", - avatarUrl: "mxc://example.org/default-avatar", - }); - expect(next.channels?.matrix?.accounts?.default).toBeUndefined(); - expect(next.channels?.matrix?.accounts?.ops).toMatchObject({ - name: "Ops", - enabled: true, - homeserver: "https://matrix.example.org", - accessToken: "ops-token", - }); + expectPromotedDefaultAccount(next); + expectOpsAccount(next); }); it("reuses an existing raw default-like key during promotion when defaultAccount is unset", () => { @@ -112,35 +124,14 @@ describe("matrixSetupAdapter", () => { }, } as CoreConfig; - const next = matrixSetupAdapter.applyAccountConfig({ - cfg, - accountId: "ops", - input: { - name: "Ops", - homeserver: "https://matrix.example.org", - accessToken: "ops-token", - }, - }) as CoreConfig; + const next = applyOpsAccountConfig(cfg); - expect(next.channels?.matrix?.accounts?.Default).toMatchObject({ - enabled: true, - deviceName: "Legacy raw key", - homeserver: "https://matrix.example.org", - userId: "@default:example.org", - accessToken: "default-token", - avatarUrl: "mxc://example.org/default-avatar", - }); - expect(next.channels?.matrix?.accounts?.default).toBeUndefined(); + expectPromotedDefaultAccount(next); expect(next.channels?.matrix?.accounts?.support).toMatchObject({ homeserver: "https://matrix.example.org", accessToken: "support-token", }); - expect(next.channels?.matrix?.accounts?.ops).toMatchObject({ - name: "Ops", - enabled: true, - homeserver: "https://matrix.example.org", - accessToken: "ops-token", - }); + expectOpsAccount(next); }); it("clears stored auth fields when switching an account to env-backed auth", () => { diff --git a/extensions/memory-wiki/src/status.test.ts b/extensions/memory-wiki/src/status.test.ts index ab192a21ede..f0e58f48516 100644 --- a/extensions/memory-wiki/src/status.test.ts +++ b/extensions/memory-wiki/src/status.test.ts @@ -14,6 +14,30 @@ import { createMemoryWikiTestHarness } from "./test-helpers.js"; const { createVault } = createMemoryWikiTestHarness(); +async function resolveBridgeMissingArtifactsStatus() { + const config = resolveMemoryWikiConfig( + { + vaultMode: "bridge", + bridge: { + enabled: true, + readMemoryArtifacts: true, + }, + }, + { homedir: "/Users/tester" }, + ); + + return resolveMemoryWikiStatus(config, { + appConfig: { + agents: { + list: [{ id: "main", default: true, workspace: "/tmp/workspace" }], + }, + } as OpenClawConfig, + listPublicArtifacts: async () => [], + pathExists: async () => true, + resolveCommand: async () => null, + }); +} + describe("resolveMemoryWikiStatus", () => { it("reports missing vault and missing requested obsidian cli", async () => { const config = resolveMemoryWikiConfig( @@ -61,27 +85,7 @@ describe("resolveMemoryWikiStatus", () => { }); it("warns when bridge mode has no exported memory artifacts", async () => { - const config = resolveMemoryWikiConfig( - { - vaultMode: "bridge", - bridge: { - enabled: true, - readMemoryArtifacts: true, - }, - }, - { homedir: "/Users/tester" }, - ); - - const status = await resolveMemoryWikiStatus(config, { - appConfig: { - agents: { - list: [{ id: "main", default: true, workspace: "/tmp/workspace" }], - }, - } as OpenClawConfig, - listPublicArtifacts: async () => [], - pathExists: async () => true, - resolveCommand: async () => null, - }); + const status = await resolveBridgeMissingArtifactsStatus(); expect(status.bridgePublicArtifactCount).toBe(0); expect(status.warnings.map((warning) => warning.code)).toContain("bridge-artifacts-missing"); @@ -235,27 +239,7 @@ describe("memory wiki doctor", () => { }); it("suggests bridge fixes when no public artifacts are exported", async () => { - const config = resolveMemoryWikiConfig( - { - vaultMode: "bridge", - bridge: { - enabled: true, - readMemoryArtifacts: true, - }, - }, - { homedir: "/Users/tester" }, - ); - - const status = await resolveMemoryWikiStatus(config, { - appConfig: { - agents: { - list: [{ id: "main", default: true, workspace: "/tmp/workspace" }], - }, - } as OpenClawConfig, - listPublicArtifacts: async () => [], - pathExists: async () => true, - resolveCommand: async () => null, - }); + const status = await resolveBridgeMissingArtifactsStatus(); const report = buildMemoryWikiDoctorReport(status); expect(report.fixes.map((fix) => fix.code)).toContain("bridge-artifacts-missing"); diff --git a/extensions/minimax/music-generation-provider.test.ts b/extensions/minimax/music-generation-provider.test.ts index e87da12236f..9a080cdee3f 100644 --- a/extensions/minimax/music-generation-provider.test.ts +++ b/extensions/minimax/music-generation-provider.test.ts @@ -1,60 +1,42 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildMinimaxMusicGenerationProvider } from "./music-generation-provider.js"; +import { beforeAll, describe, expect, it, vi } from "vitest"; +import { + getMinimaxProviderHttpMocks, + installMinimaxProviderHttpMockCleanup, + loadMinimaxMusicGenerationProviderModule, +} from "./provider-http.test-helpers.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "minimax-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock } = getMinimaxProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildMinimaxMusicGenerationProvider: Awaited< + ReturnType +>["buildMinimaxMusicGenerationProvider"]; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildMinimaxMusicGenerationProvider } = await loadMinimaxMusicGenerationProviderModule()); +}); + +installMinimaxProviderHttpMockCleanup(); + +function mockMusicGenerationResponse(json: Record): void { + postJsonRequestMock.mockResolvedValue({ + response: { + json: async () => json, + }, + release: vi.fn(async () => {}), + }); + fetchWithTimeoutMock.mockResolvedValue({ + headers: new Headers({ "content-type": "audio/mpeg" }), + arrayBuffer: async () => Buffer.from("mp3-bytes"), + }); +} describe("minimax music generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("creates music and downloads the generated track", async () => { - postJsonRequestMock.mockResolvedValue({ - response: { - json: async () => ({ - task_id: "task-123", - audio_url: "https://example.com/out.mp3", - lyrics: "our city wakes", - base_resp: { status_code: 0 }, - }), - }, - release: vi.fn(async () => {}), - }); - fetchWithTimeoutMock.mockResolvedValue({ - headers: new Headers({ "content-type": "audio/mpeg" }), - arrayBuffer: async () => Buffer.from("mp3-bytes"), + mockMusicGenerationResponse({ + task_id: "task-123", + audio_url: "https://example.com/out.mp3", + lyrics: "our city wakes", + base_resp: { status_code: 0 }, }); const provider = buildMinimaxMusicGenerationProvider(); @@ -98,20 +80,11 @@ describe("minimax music generation provider", () => { }); it("downloads tracks when url output is returned in data.audio", async () => { - postJsonRequestMock.mockResolvedValue({ - response: { - json: async () => ({ - data: { - audio: "https://example.com/url-audio.mp3", - }, - base_resp: { status_code: 0 }, - }), + mockMusicGenerationResponse({ + data: { + audio: "https://example.com/url-audio.mp3", }, - release: vi.fn(async () => {}), - }); - fetchWithTimeoutMock.mockResolvedValue({ - headers: new Headers({ "content-type": "audio/mpeg" }), - arrayBuffer: async () => Buffer.from("mp3-bytes"), + base_resp: { status_code: 0 }, }); const provider = buildMinimaxMusicGenerationProvider(); @@ -148,19 +121,10 @@ describe("minimax music generation provider", () => { }); it("uses lyrics optimizer when lyrics are omitted", async () => { - postJsonRequestMock.mockResolvedValue({ - response: { - json: async () => ({ - task_id: "task-456", - audio_url: "https://example.com/out.mp3", - base_resp: { status_code: 0 }, - }), - }, - release: vi.fn(async () => {}), - }); - fetchWithTimeoutMock.mockResolvedValue({ - headers: new Headers({ "content-type": "audio/mpeg" }), - arrayBuffer: async () => Buffer.from("mp3-bytes"), + mockMusicGenerationResponse({ + task_id: "task-456", + audio_url: "https://example.com/out.mp3", + base_resp: { status_code: 0 }, }); const provider = buildMinimaxMusicGenerationProvider(); diff --git a/extensions/minimax/provider-http.test-helpers.ts b/extensions/minimax/provider-http.test-helpers.ts new file mode 100644 index 00000000000..c2c90ba9659 --- /dev/null +++ b/extensions/minimax/provider-http.test-helpers.ts @@ -0,0 +1,15 @@ +import { + getProviderHttpMocks, + installProviderHttpMockCleanup, +} from "../../test/helpers/media-generation/provider-http-mocks.js"; + +export const getMinimaxProviderHttpMocks = getProviderHttpMocks; +export const installMinimaxProviderHttpMockCleanup = installProviderHttpMockCleanup; + +export function loadMinimaxMusicGenerationProviderModule() { + return import("./music-generation-provider.js"); +} + +export function loadMinimaxVideoGenerationProviderModule() { + return import("./video-generation-provider.js"); +} diff --git a/extensions/minimax/video-generation-provider.test.ts b/extensions/minimax/video-generation-provider.test.ts index 27e04d10582..dc1f7af5907 100644 --- a/extensions/minimax/video-generation-provider.test.ts +++ b/extensions/minimax/video-generation-provider.test.ts @@ -1,45 +1,23 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildMinimaxVideoGenerationProvider } from "./video-generation-provider.js"; +import { beforeAll, describe, expect, it, vi } from "vitest"; +import { + getMinimaxProviderHttpMocks, + installMinimaxProviderHttpMockCleanup, + loadMinimaxVideoGenerationProviderModule, +} from "./provider-http.test-helpers.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "minimax-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock } = getMinimaxProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildMinimaxVideoGenerationProvider: Awaited< + ReturnType +>["buildMinimaxVideoGenerationProvider"]; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildMinimaxVideoGenerationProvider } = await loadMinimaxVideoGenerationProviderModule()); +}); + +installMinimaxProviderHttpMockCleanup(); describe("minimax video generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("creates a task, polls status, and downloads the generated video", async () => { postJsonRequestMock.mockResolvedValue({ response: { diff --git a/extensions/msteams/src/attachments/graph.test.ts b/extensions/msteams/src/attachments/graph.test.ts index a15dcdd2683..b11486b6e01 100644 --- a/extensions/msteams/src/attachments/graph.test.ts +++ b/extensions/msteams/src/attachments/graph.test.ts @@ -62,6 +62,47 @@ function mockBinaryResponse(data: Uint8Array, status = 200) { return new Response(Buffer.from(data) as BodyInit, { status }); } +type GuardedFetchParams = { url: string; init?: RequestInit }; + +function guardedFetchResult(params: GuardedFetchParams, response: Response) { + return { + response, + release: async () => {}, + finalUrl: params.url, + }; +} + +function mockGraphMediaFetch(options: { + messageId: string; + messageResponse?: unknown; + hostedContents?: unknown[]; + valueResponses?: Record; + fetchCalls?: string[]; +}) { + vi.mocked(fetchWithSsrFGuard).mockImplementation(async (params: GuardedFetchParams) => { + options.fetchCalls?.push(params.url); + const url = params.url; + if (url.endsWith(`/messages/${options.messageId}`) && !url.includes("hostedContents")) { + return guardedFetchResult( + params, + mockFetchResponse(options.messageResponse ?? { body: {}, attachments: [] }), + ); + } + if (url.endsWith("/hostedContents")) { + return guardedFetchResult(params, mockFetchResponse({ value: options.hostedContents ?? [] })); + } + for (const [fragment, response] of Object.entries(options.valueResponses ?? {})) { + if (url.includes(fragment)) { + return guardedFetchResult(params, response); + } + } + if (url.endsWith("/attachments")) { + return guardedFetchResult(params, mockFetchResponse({ value: [] })); + } + return guardedFetchResult(params, mockFetchResponse({}, 404)); + }); +} + describe("downloadMSTeamsGraphMedia hosted content $value fallback", () => { beforeEach(() => { vi.clearAllMocks(); @@ -72,49 +113,13 @@ describe("downloadMSTeamsGraphMedia hosted content $value fallback", () => { const fetchCalls: string[] = []; - vi.mocked(fetchWithSsrFGuard).mockImplementation(async (params: { url: string }) => { - fetchCalls.push(params.url); - const url = params.url; - - // Main message fetch - if (url.endsWith("/messages/msg-1") && !url.includes("hostedContents")) { - return { - response: mockFetchResponse({ body: {}, attachments: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - // hostedContents collection - if (url.endsWith("/hostedContents")) { - return { - response: mockFetchResponse({ - value: [{ id: "hosted-123", contentType: "image/png", contentBytes: null }], - }), - release: async () => {}, - finalUrl: params.url, - }; - } - // $value endpoint (the fallback being tested) - if (url.includes("/hostedContents/hosted-123/$value")) { - return { - response: mockBinaryResponse(imageBytes), - release: async () => {}, - finalUrl: params.url, - }; - } - // attachments collection - if (url.endsWith("/attachments")) { - return { - response: mockFetchResponse({ value: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - return { - response: mockFetchResponse({}, 404), - release: async () => {}, - finalUrl: params.url, - }; + mockGraphMediaFetch({ + messageId: "msg-1", + hostedContents: [{ id: "hosted-123", contentType: "image/png", contentBytes: null }], + valueResponses: { + "/hostedContents/hosted-123/$value": mockBinaryResponse(imageBytes), + }, + fetchCalls, }); const result = await downloadMSTeamsGraphMedia({ @@ -131,36 +136,9 @@ describe("downloadMSTeamsGraphMedia hosted content $value fallback", () => { }); it("skips hosted content when contentBytes is null and id is missing", async () => { - vi.mocked(fetchWithSsrFGuard).mockImplementation(async (params: { url: string }) => { - const url = params.url; - if (url.endsWith("/messages/msg-2") && !url.includes("hostedContents")) { - return { - response: mockFetchResponse({ body: {}, attachments: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/hostedContents")) { - return { - response: mockFetchResponse({ - value: [{ contentType: "image/png", contentBytes: null }], - }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/attachments")) { - return { - response: mockFetchResponse({ value: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - return { - response: mockFetchResponse({}, 404), - release: async () => {}, - finalUrl: params.url, - }; + mockGraphMediaFetch({ + messageId: "msg-2", + hostedContents: [{ contentType: "image/png", contentBytes: null }], }); const result = await downloadMSTeamsGraphMedia({ @@ -176,49 +154,19 @@ describe("downloadMSTeamsGraphMedia hosted content $value fallback", () => { it("skips $value content when Content-Length exceeds maxBytes", async () => { const fetchCalls: string[] = []; - vi.mocked(fetchWithSsrFGuard).mockImplementation(async (params: { url: string }) => { - fetchCalls.push(params.url); - const url = params.url; - if (url.endsWith("/messages/msg-cl") && !url.includes("hostedContents")) { - return { - response: mockFetchResponse({ body: {}, attachments: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/hostedContents")) { - return { - response: mockFetchResponse({ - value: [{ id: "hosted-big", contentType: "image/png", contentBytes: null }], - }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.includes("/hostedContents/hosted-big/$value")) { - // Return a response whose Content-Length exceeds maxBytes - const data = new Uint8Array([0x89, 0x50, 0x4e, 0x47]); - return { - response: new Response(Buffer.from(data) as BodyInit, { + mockGraphMediaFetch({ + messageId: "msg-cl", + hostedContents: [{ id: "hosted-big", contentType: "image/png", contentBytes: null }], + valueResponses: { + "/hostedContents/hosted-big/$value": new Response( + Buffer.from(new Uint8Array([0x89, 0x50, 0x4e, 0x47])) as BodyInit, + { status: 200, headers: { "content-length": "999999999" }, - }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/attachments")) { - return { - response: mockFetchResponse({ value: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - return { - response: mockFetchResponse({}, 404), - release: async () => {}, - finalUrl: params.url, - }; + }, + ), + }, + fetchCalls, }); const result = await downloadMSTeamsGraphMedia({ @@ -237,37 +185,10 @@ describe("downloadMSTeamsGraphMedia hosted content $value fallback", () => { const fetchCalls: string[] = []; const base64Png = Buffer.from([0x89, 0x50, 0x4e, 0x47]).toString("base64"); - vi.mocked(fetchWithSsrFGuard).mockImplementation(async (params: { url: string }) => { - fetchCalls.push(params.url); - const url = params.url; - if (url.endsWith("/messages/msg-3") && !url.includes("hostedContents")) { - return { - response: mockFetchResponse({ body: {}, attachments: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/hostedContents")) { - return { - response: mockFetchResponse({ - value: [{ id: "hosted-456", contentType: "image/png", contentBytes: base64Png }], - }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/attachments")) { - return { - response: mockFetchResponse({ value: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - return { - response: mockFetchResponse({}, 404), - release: async () => {}, - finalUrl: params.url, - }; + mockGraphMediaFetch({ + messageId: "msg-3", + hostedContents: [{ id: "hosted-456", contentType: "image/png", contentBytes: base64Png }], + fetchCalls, }); const result = await downloadMSTeamsGraphMedia({ @@ -283,37 +204,7 @@ describe("downloadMSTeamsGraphMedia hosted content $value fallback", () => { }); it("adds the OpenClaw User-Agent to guarded Graph attachment fetches", async () => { - vi.mocked(fetchWithSsrFGuard).mockImplementation( - async (params: { url: string; init?: RequestInit }) => { - const url = params.url; - if (url.endsWith("/messages/msg-ua") && !url.includes("hostedContents")) { - return { - response: mockFetchResponse({ body: {}, attachments: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/hostedContents")) { - return { - response: mockFetchResponse({ value: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/attachments")) { - return { - response: mockFetchResponse({ value: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - return { - response: mockFetchResponse({}, 404), - release: async () => {}, - finalUrl: params.url, - }; - }, - ); + mockGraphMediaFetch({ messageId: "msg-ua" }); await downloadMSTeamsGraphMedia({ messageUrl: "https://graph.microsoft.com/v1.0/chats/c/messages/msg-ua", @@ -333,43 +224,18 @@ describe("downloadMSTeamsGraphMedia hosted content $value fallback", () => { }); it("adds the OpenClaw User-Agent to Graph shares downloads for reference attachments", async () => { - vi.mocked(fetchWithSsrFGuard).mockImplementation(async (params: { url: string }) => { - const url = params.url; - if (url.endsWith("/messages/msg-share") && !url.includes("hostedContents")) { - return { - response: mockFetchResponse({ - body: {}, - attachments: [ - { - contentType: "reference", - contentUrl: "https://tenant.sharepoint.com/file.docx", - name: "file.docx", - }, - ], - }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/hostedContents")) { - return { - response: mockFetchResponse({ value: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - if (url.endsWith("/attachments")) { - return { - response: mockFetchResponse({ value: [] }), - release: async () => {}, - finalUrl: params.url, - }; - } - return { - response: mockFetchResponse({}, 404), - release: async () => {}, - finalUrl: params.url, - }; + mockGraphMediaFetch({ + messageId: "msg-share", + messageResponse: { + body: {}, + attachments: [ + { + contentType: "reference", + contentUrl: "https://tenant.sharepoint.com/file.docx", + name: "file.docx", + }, + ], + }, }); vi.mocked(safeFetchWithPolicy).mockResolvedValue(new Response(null, { status: 200 })); vi.mocked(downloadAndStoreMSTeamsRemoteMedia).mockImplementation(async (params) => { diff --git a/extensions/msteams/src/graph-messages.actions.test.ts b/extensions/msteams/src/graph-messages.actions.test.ts index 08d0671ea36..147116b52d9 100644 --- a/extensions/msteams/src/graph-messages.actions.test.ts +++ b/extensions/msteams/src/graph-messages.actions.test.ts @@ -1,48 +1,28 @@ -import { beforeEach, describe, expect, it, vi } from "vitest"; +import { beforeAll, describe, expect, it } from "vitest"; import type { OpenClawConfig } from "../runtime-api.js"; import { - pinMessageMSTeams, - reactMessageMSTeams, - unpinMessageMSTeams, - unreactMessageMSTeams, -} from "./graph-messages.js"; + CHANNEL_TO, + CHAT_ID, + TOKEN, + type GraphMessagesTestModule, + getGraphMessagesMockState, + installGraphMessagesMockDefaults, + loadGraphMessagesTestModule, +} from "./graph-messages.test-helpers.js"; -const mockState = vi.hoisted(() => ({ - resolveGraphToken: vi.fn(), - fetchGraphJson: vi.fn(), - postGraphJson: vi.fn(), - postGraphBetaJson: vi.fn(), - deleteGraphRequest: vi.fn(), - findPreferredDmByUserId: vi.fn(), -})); +const mockState = getGraphMessagesMockState(); +installGraphMessagesMockDefaults(); +let pinMessageMSTeams: GraphMessagesTestModule["pinMessageMSTeams"]; +let reactMessageMSTeams: GraphMessagesTestModule["reactMessageMSTeams"]; +let unpinMessageMSTeams: GraphMessagesTestModule["unpinMessageMSTeams"]; +let unreactMessageMSTeams: GraphMessagesTestModule["unreactMessageMSTeams"]; -vi.mock("./graph.js", () => { - return { - resolveGraphToken: mockState.resolveGraphToken, - fetchGraphJson: mockState.fetchGraphJson, - postGraphJson: mockState.postGraphJson, - postGraphBetaJson: mockState.postGraphBetaJson, - deleteGraphRequest: mockState.deleteGraphRequest, - escapeOData: vi.fn((value: string) => value.replaceAll("'", "''")), - }; +beforeAll(async () => { + ({ pinMessageMSTeams, reactMessageMSTeams, unpinMessageMSTeams, unreactMessageMSTeams } = + await loadGraphMessagesTestModule()); }); -vi.mock("./conversation-store-fs.js", () => ({ - createMSTeamsConversationStoreFs: () => ({ - findPreferredDmByUserId: mockState.findPreferredDmByUserId, - }), -})); - -const TOKEN = "test-graph-token"; -const CHAT_ID = "19:abc@thread.tacv2"; -const CHANNEL_TO = "team-id-1/channel-id-1"; - describe("pinMessageMSTeams", () => { - beforeEach(() => { - vi.clearAllMocks(); - mockState.resolveGraphToken.mockResolvedValue(TOKEN); - }); - it("pins a message in a chat", async () => { mockState.postGraphJson.mockResolvedValue({ id: "pinned-1" }); @@ -79,11 +59,6 @@ describe("pinMessageMSTeams", () => { }); describe("unpinMessageMSTeams", () => { - beforeEach(() => { - vi.clearAllMocks(); - mockState.resolveGraphToken.mockResolvedValue(TOKEN); - }); - it("unpins a message from a chat", async () => { mockState.deleteGraphRequest.mockResolvedValue(undefined); @@ -118,11 +93,6 @@ describe("unpinMessageMSTeams", () => { }); describe("reactMessageMSTeams", () => { - beforeEach(() => { - vi.clearAllMocks(); - mockState.resolveGraphToken.mockResolvedValue(TOKEN); - }); - it("sets a like reaction on a chat message", async () => { mockState.postGraphBetaJson.mockResolvedValue(undefined); @@ -211,11 +181,6 @@ describe("reactMessageMSTeams", () => { }); describe("unreactMessageMSTeams", () => { - beforeEach(() => { - vi.clearAllMocks(); - mockState.resolveGraphToken.mockResolvedValue(TOKEN); - }); - it("removes a reaction from a chat message", async () => { mockState.postGraphBetaJson.mockResolvedValue(undefined); diff --git a/extensions/msteams/src/graph-messages.read.test.ts b/extensions/msteams/src/graph-messages.read.test.ts index dfbdee723fd..f10b281b7b5 100644 --- a/extensions/msteams/src/graph-messages.read.test.ts +++ b/extensions/msteams/src/graph-messages.read.test.ts @@ -1,43 +1,27 @@ -import { beforeEach, describe, expect, it, vi } from "vitest"; +import { beforeAll, describe, expect, it } from "vitest"; import type { OpenClawConfig } from "../runtime-api.js"; -import { getMessageMSTeams, listPinsMSTeams, listReactionsMSTeams } from "./graph-messages.js"; +import { + CHANNEL_TO, + CHAT_ID, + TOKEN, + type GraphMessagesTestModule, + getGraphMessagesMockState, + installGraphMessagesMockDefaults, + loadGraphMessagesTestModule, +} from "./graph-messages.test-helpers.js"; -const mockState = vi.hoisted(() => ({ - resolveGraphToken: vi.fn(), - fetchGraphJson: vi.fn(), - postGraphJson: vi.fn(), - postGraphBetaJson: vi.fn(), - deleteGraphRequest: vi.fn(), - findPreferredDmByUserId: vi.fn(), -})); +const mockState = getGraphMessagesMockState(); +installGraphMessagesMockDefaults(); +let getMessageMSTeams: GraphMessagesTestModule["getMessageMSTeams"]; +let listPinsMSTeams: GraphMessagesTestModule["listPinsMSTeams"]; +let listReactionsMSTeams: GraphMessagesTestModule["listReactionsMSTeams"]; -vi.mock("./graph.js", () => { - return { - resolveGraphToken: mockState.resolveGraphToken, - fetchGraphJson: mockState.fetchGraphJson, - postGraphJson: mockState.postGraphJson, - postGraphBetaJson: mockState.postGraphBetaJson, - deleteGraphRequest: mockState.deleteGraphRequest, - escapeOData: vi.fn((value: string) => value.replaceAll("'", "''")), - }; +beforeAll(async () => { + ({ getMessageMSTeams, listPinsMSTeams, listReactionsMSTeams } = + await loadGraphMessagesTestModule()); }); -vi.mock("./conversation-store-fs.js", () => ({ - createMSTeamsConversationStoreFs: () => ({ - findPreferredDmByUserId: mockState.findPreferredDmByUserId, - }), -})); - -const TOKEN = "test-graph-token"; -const CHAT_ID = "19:abc@thread.tacv2"; -const CHANNEL_TO = "team-id-1/channel-id-1"; - describe("getMessageMSTeams", () => { - beforeEach(() => { - vi.clearAllMocks(); - mockState.resolveGraphToken.mockResolvedValue(TOKEN); - }); - it("resolves user: target using graphChatId from store", async () => { mockState.findPreferredDmByUserId.mockResolvedValue({ conversationId: "a:bot-framework-dm-id", @@ -186,11 +170,6 @@ describe("getMessageMSTeams", () => { }); describe("listPinsMSTeams", () => { - beforeEach(() => { - vi.clearAllMocks(); - mockState.resolveGraphToken.mockResolvedValue(TOKEN); - }); - it("lists pinned messages in a chat", async () => { mockState.fetchGraphJson.mockResolvedValue({ value: [ @@ -233,11 +212,6 @@ describe("listPinsMSTeams", () => { }); describe("listReactionsMSTeams", () => { - beforeEach(() => { - vi.clearAllMocks(); - mockState.resolveGraphToken.mockResolvedValue(TOKEN); - }); - it("lists reactions grouped by type with user details", async () => { mockState.fetchGraphJson.mockResolvedValue({ id: "msg-1", diff --git a/extensions/msteams/src/graph-messages.search.test.ts b/extensions/msteams/src/graph-messages.search.test.ts index ce9d34bb126..e4056dc813c 100644 --- a/extensions/msteams/src/graph-messages.search.test.ts +++ b/extensions/msteams/src/graph-messages.search.test.ts @@ -1,43 +1,23 @@ -import { beforeEach, describe, expect, it, vi } from "vitest"; +import { beforeAll, describe, expect, it } from "vitest"; import type { OpenClawConfig } from "../runtime-api.js"; -import { searchMessagesMSTeams } from "./graph-messages.js"; +import { + CHANNEL_TO, + CHAT_ID, + type GraphMessagesTestModule, + getGraphMessagesMockState, + installGraphMessagesMockDefaults, + loadGraphMessagesTestModule, +} from "./graph-messages.test-helpers.js"; -const mockState = vi.hoisted(() => ({ - resolveGraphToken: vi.fn(), - fetchGraphJson: vi.fn(), - postGraphJson: vi.fn(), - postGraphBetaJson: vi.fn(), - deleteGraphRequest: vi.fn(), - findPreferredDmByUserId: vi.fn(), -})); +const mockState = getGraphMessagesMockState(); +installGraphMessagesMockDefaults(); +let searchMessagesMSTeams: GraphMessagesTestModule["searchMessagesMSTeams"]; -vi.mock("./graph.js", () => { - return { - resolveGraphToken: mockState.resolveGraphToken, - fetchGraphJson: mockState.fetchGraphJson, - postGraphJson: mockState.postGraphJson, - postGraphBetaJson: mockState.postGraphBetaJson, - deleteGraphRequest: mockState.deleteGraphRequest, - escapeOData: vi.fn((value: string) => value.replaceAll("'", "''")), - }; +beforeAll(async () => { + ({ searchMessagesMSTeams } = await loadGraphMessagesTestModule()); }); -vi.mock("./conversation-store-fs.js", () => ({ - createMSTeamsConversationStoreFs: () => ({ - findPreferredDmByUserId: mockState.findPreferredDmByUserId, - }), -})); - -const TOKEN = "test-graph-token"; -const CHAT_ID = "19:abc@thread.tacv2"; -const CHANNEL_TO = "team-id-1/channel-id-1"; - describe("searchMessagesMSTeams", () => { - beforeEach(() => { - vi.clearAllMocks(); - mockState.resolveGraphToken.mockResolvedValue(TOKEN); - }); - it("searches chat messages with query string", async () => { mockState.fetchGraphJson.mockResolvedValue({ value: [ diff --git a/extensions/msteams/src/graph-messages.test-helpers.ts b/extensions/msteams/src/graph-messages.test-helpers.ts new file mode 100644 index 00000000000..6fa0e3ede43 --- /dev/null +++ b/extensions/msteams/src/graph-messages.test-helpers.ts @@ -0,0 +1,48 @@ +import { beforeEach, vi } from "vitest"; + +const graphMessagesMockState = vi.hoisted(() => ({ + resolveGraphToken: vi.fn(), + fetchGraphJson: vi.fn(), + postGraphJson: vi.fn(), + postGraphBetaJson: vi.fn(), + deleteGraphRequest: vi.fn(), + findPreferredDmByUserId: vi.fn(), +})); + +vi.mock("./graph.js", () => { + return { + resolveGraphToken: graphMessagesMockState.resolveGraphToken, + fetchGraphJson: graphMessagesMockState.fetchGraphJson, + postGraphJson: graphMessagesMockState.postGraphJson, + postGraphBetaJson: graphMessagesMockState.postGraphBetaJson, + deleteGraphRequest: graphMessagesMockState.deleteGraphRequest, + escapeOData: vi.fn((value: string) => value.replaceAll("'", "''")), + }; +}); + +vi.mock("./conversation-store-fs.js", () => ({ + createMSTeamsConversationStoreFs: () => ({ + findPreferredDmByUserId: graphMessagesMockState.findPreferredDmByUserId, + }), +})); + +export const TOKEN = "test-graph-token"; +export const CHAT_ID = "19:abc@thread.tacv2"; +export const CHANNEL_TO = "team-id-1/channel-id-1"; + +export function getGraphMessagesMockState(): typeof graphMessagesMockState { + return graphMessagesMockState; +} + +export type GraphMessagesTestModule = typeof import("./graph-messages.js"); + +export function loadGraphMessagesTestModule(): Promise { + return import("./graph-messages.js"); +} + +export function installGraphMessagesMockDefaults(): void { + beforeEach(() => { + vi.clearAllMocks(); + graphMessagesMockState.resolveGraphToken.mockResolvedValue(TOKEN); + }); +} diff --git a/extensions/msteams/src/messenger.test.ts b/extensions/msteams/src/messenger.test.ts index 36cae456842..fa938626e0a 100644 --- a/extensions/msteams/src/messenger.test.ts +++ b/extensions/msteams/src/messenger.test.ts @@ -196,6 +196,48 @@ describe("msteams messenger", () => { serviceUrl: "https://service.example.com", }; + async function sendAndCaptureRevokeFallbackReference( + conversation: StoredConversationReference["conversation"], + ) { + const proactiveSent: string[] = []; + let capturedReference: unknown; + const conversationRef: StoredConversationReference = { + activityId: "activity456", + user: { id: "user123", name: "User" }, + agent: { id: "bot123", name: "Bot" }, + conversation, + channelId: "msteams", + serviceUrl: "https://service.example.com", + }; + const adapter: MSTeamsAdapter = { + continueConversation: async (_appId, reference, logic) => { + capturedReference = reference; + await logic({ + sendActivity: createRecordedSendActivity(proactiveSent), + updateActivity: noopUpdateActivity, + deleteActivity: noopDeleteActivity, + }); + }, + process: async () => {}, + updateActivity: noopUpdateActivity, + deleteActivity: noopDeleteActivity, + }; + + await sendMSTeamsMessages({ + replyStyle: "thread", + adapter, + appId: "app123", + conversationRef, + context: createRevokedThreadContext(), + messages: [{ text: "hello" }], + }); + + return { + proactiveSent, + reference: capturedReference as { conversation?: { id?: string }; activityId?: string }, + }; + } + it("sends thread messages via the provided context", async () => { const sent: string[] = []; const ctx = { @@ -409,97 +451,27 @@ describe("msteams messenger", () => { }); it("reconstructs threaded conversation ID for channel revoke fallback", async () => { - const proactiveSent: string[] = []; - let capturedReference: unknown; - - const channelRef: StoredConversationReference = { - activityId: "activity456", - user: { id: "user123", name: "User" }, - agent: { id: "bot123", name: "Bot" }, - conversation: { - id: "19:abc@thread.tacv2;messageid=deadbeef", - conversationType: "channel", - }, - channelId: "msteams", - serviceUrl: "https://service.example.com", - }; - - const ctx = createRevokedThreadContext(); - const adapter: MSTeamsAdapter = { - continueConversation: async (_appId, reference, logic) => { - capturedReference = reference; - await logic({ - sendActivity: createRecordedSendActivity(proactiveSent), - updateActivity: noopUpdateActivity, - deleteActivity: noopDeleteActivity, - }); - }, - process: async () => {}, - updateActivity: noopUpdateActivity, - deleteActivity: noopDeleteActivity, - }; - - await sendMSTeamsMessages({ - replyStyle: "thread", - adapter, - appId: "app123", - conversationRef: channelRef, - context: ctx, - messages: [{ text: "hello" }], + const { proactiveSent, reference } = await sendAndCaptureRevokeFallbackReference({ + id: "19:abc@thread.tacv2;messageid=deadbeef", + conversationType: "channel", }); expect(proactiveSent).toEqual(["hello"]); - const ref = capturedReference as { conversation?: { id?: string }; activityId?: string }; // Conversation ID should include the thread suffix for channel messages - expect(ref.conversation?.id).toBe("19:abc@thread.tacv2;messageid=activity456"); - expect(ref.activityId).toBeUndefined(); + expect(reference.conversation?.id).toBe("19:abc@thread.tacv2;messageid=activity456"); + expect(reference.activityId).toBeUndefined(); }); it("does not add thread suffix for group chat revoke fallback", async () => { - const proactiveSent: string[] = []; - let capturedReference: unknown; - - const groupRef: StoredConversationReference = { - activityId: "activity789", - user: { id: "user123", name: "User" }, - agent: { id: "bot123", name: "Bot" }, - conversation: { - id: "19:group123@thread.v2", - conversationType: "groupChat", - }, - channelId: "msteams", - serviceUrl: "https://service.example.com", - }; - - const ctx = createRevokedThreadContext(); - const adapter: MSTeamsAdapter = { - continueConversation: async (_appId, reference, logic) => { - capturedReference = reference; - await logic({ - sendActivity: createRecordedSendActivity(proactiveSent), - updateActivity: noopUpdateActivity, - deleteActivity: noopDeleteActivity, - }); - }, - process: async () => {}, - updateActivity: noopUpdateActivity, - deleteActivity: noopDeleteActivity, - }; - - await sendMSTeamsMessages({ - replyStyle: "thread", - adapter, - appId: "app123", - conversationRef: groupRef, - context: ctx, - messages: [{ text: "hello" }], + const { proactiveSent, reference } = await sendAndCaptureRevokeFallbackReference({ + id: "19:group123@thread.v2", + conversationType: "groupChat", }); expect(proactiveSent).toEqual(["hello"]); - const ref = capturedReference as { conversation?: { id?: string }; activityId?: string }; // Group chat should NOT have thread suffix — flat conversation - expect(ref.conversation?.id).toBe("19:group123@thread.v2"); - expect(ref.activityId).toBeUndefined(); + expect(reference.conversation?.id).toBe("19:group123@thread.v2"); + expect(reference.activityId).toBeUndefined(); }); it("retries top-level sends on transient (5xx)", async () => { diff --git a/extensions/nextcloud-talk/src/send.cfg-threading.test.ts b/extensions/nextcloud-talk/src/send.cfg-threading.test.ts index 34f87f5402b..9f648a3436c 100644 --- a/extensions/nextcloud-talk/src/send.cfg-threading.test.ts +++ b/extensions/nextcloud-talk/src/send.cfg-threading.test.ts @@ -33,6 +33,21 @@ vi.mock("./send.runtime.js", () => { const { sendMessageNextcloudTalk, sendReactionNextcloudTalk } = await import("./send.js"); +function expectProvidedMessageCfgThreading(cfg: unknown): void { + expectProvidedCfgSkipsRuntimeLoad({ + loadConfig: hoisted.loadConfig, + resolveAccount: hoisted.resolveNextcloudTalkAccount, + cfg, + accountId: "work", + }); + expect(hoisted.resolveMarkdownTableMode).toHaveBeenCalledWith({ + cfg, + channel: "nextcloud-talk", + accountId: "default", + }); + expect(hoisted.convertMarkdownTables).toHaveBeenCalledWith("hello", "preserve"); +} + describe("nextcloud-talk send cfg threading", () => { const fetchMock = vi.fn(); const defaultAccount = { @@ -41,6 +56,17 @@ describe("nextcloud-talk send cfg threading", () => { secret: "secret-value", }; + function mockNextcloudMessageResponse(messageId: number, timestamp: number): void { + fetchMock.mockResolvedValueOnce( + new Response( + JSON.stringify({ + ocs: { data: { id: messageId, timestamp } }, + }), + { status: 200, headers: { "content-type": "application/json" } }, + ), + ); + } + beforeEach(() => { vi.stubGlobal("fetch", fetchMock); // Route the SSRF guard mock through the global fetch mock. @@ -66,32 +92,14 @@ describe("nextcloud-talk send cfg threading", () => { it("uses provided cfg for sendMessage and skips runtime loadConfig", async () => { const cfg = { source: "provided" } as const; - fetchMock.mockResolvedValueOnce( - new Response( - JSON.stringify({ - ocs: { data: { id: 12345, timestamp: 1_706_000_000 } }, - }), - { status: 200, headers: { "content-type": "application/json" } }, - ), - ); + mockNextcloudMessageResponse(12345, 1_706_000_000); const result = await sendMessageNextcloudTalk("room:abc123", "hello", { cfg, accountId: "work", }); - expectProvidedCfgSkipsRuntimeLoad({ - loadConfig: hoisted.loadConfig, - resolveAccount: hoisted.resolveNextcloudTalkAccount, - cfg, - accountId: "work", - }); - expect(hoisted.resolveMarkdownTableMode).toHaveBeenCalledWith({ - cfg, - channel: "nextcloud-talk", - accountId: "default", - }); - expect(hoisted.convertMarkdownTables).toHaveBeenCalledWith("hello", "preserve"); + expectProvidedMessageCfgThreading(cfg); expect(hoisted.record).toHaveBeenCalledWith({ channel: "nextcloud-talk", accountId: "default", @@ -110,32 +118,14 @@ describe("nextcloud-talk send cfg threading", () => { hoisted.record.mockImplementation(() => { throw new Error("Nextcloud Talk runtime not initialized"); }); - fetchMock.mockResolvedValueOnce( - new Response( - JSON.stringify({ - ocs: { data: { id: 12346, timestamp: 1_706_000_001 } }, - }), - { status: 200, headers: { "content-type": "application/json" } }, - ), - ); + mockNextcloudMessageResponse(12346, 1_706_000_001); const result = await sendMessageNextcloudTalk("room:abc123", "hello", { cfg, accountId: "work", }); - expectProvidedCfgSkipsRuntimeLoad({ - loadConfig: hoisted.loadConfig, - resolveAccount: hoisted.resolveNextcloudTalkAccount, - cfg, - accountId: "work", - }); - expect(hoisted.resolveMarkdownTableMode).toHaveBeenCalledWith({ - cfg, - channel: "nextcloud-talk", - accountId: "default", - }); - expect(hoisted.convertMarkdownTables).toHaveBeenCalledWith("hello", "preserve"); + expectProvidedMessageCfgThreading(cfg); expect(result).toEqual({ messageId: "12346", roomToken: "abc123", diff --git a/extensions/nostr/src/channel.outbound.test.ts b/extensions/nostr/src/channel.outbound.test.ts index 8fdc07f5fb4..6dee73580d2 100644 --- a/extensions/nostr/src/channel.outbound.test.ts +++ b/extensions/nostr/src/channel.outbound.test.ts @@ -28,6 +28,40 @@ function createCfg() { }; } +function installOutboundRuntime(convertMarkdownTables = vi.fn((text: string) => text)) { + const resolveMarkdownTableMode = vi.fn(() => "off"); + setNostrRuntime({ + channel: { + text: { + resolveMarkdownTableMode, + convertMarkdownTables, + }, + }, + reply: {}, + } as unknown as PluginRuntime); + return { resolveMarkdownTableMode, convertMarkdownTables }; +} + +async function startOutboundAccount(accountId?: string) { + const sendDm = vi.fn(async () => {}); + const bus = { + sendDm, + close: vi.fn(), + getMetrics: vi.fn(() => ({ counters: {} })), + publishProfile: vi.fn(), + getProfileState: vi.fn(async () => null), + }; + mocks.startNostrBus.mockResolvedValueOnce(bus as unknown); + + const cleanup = (await startNostrGatewayAccount( + createStartAccountContext({ + account: buildResolvedNostrAccount(accountId ? { accountId } : undefined), + }), + )) as { stop: () => void }; + + return { cleanup, sendDm }; +} + describe("nostr outbound cfg threading", () => { afterEach(() => { mocks.normalizePubkey.mockClear(); @@ -35,33 +69,10 @@ describe("nostr outbound cfg threading", () => { }); it("uses resolved cfg when converting markdown tables before send", async () => { - const resolveMarkdownTableMode = vi.fn(() => "off"); - const convertMarkdownTables = vi.fn((text: string) => `converted:${text}`); - setNostrRuntime({ - channel: { - text: { - resolveMarkdownTableMode, - convertMarkdownTables, - }, - }, - reply: {}, - } as unknown as PluginRuntime); - - const sendDm = vi.fn(async () => {}); - const bus = { - sendDm, - close: vi.fn(), - getMetrics: vi.fn(() => ({ counters: {} })), - publishProfile: vi.fn(), - getProfileState: vi.fn(async () => null), - }; - mocks.startNostrBus.mockResolvedValueOnce(bus as unknown); - - const cleanup = (await startNostrGatewayAccount( - createStartAccountContext({ - account: buildResolvedNostrAccount(), - }), - )) as { stop: () => void }; + const { resolveMarkdownTableMode, convertMarkdownTables } = installOutboundRuntime( + vi.fn((text: string) => `converted:${text}`), + ); + const { cleanup, sendDm } = await startOutboundAccount(); const cfg = createCfg(); await nostrOutboundAdapter.sendText({ @@ -84,33 +95,8 @@ describe("nostr outbound cfg threading", () => { }); it("uses the configured defaultAccount when accountId is omitted", async () => { - const resolveMarkdownTableMode = vi.fn(() => "off"); - const convertMarkdownTables = vi.fn((text: string) => text); - setNostrRuntime({ - channel: { - text: { - resolveMarkdownTableMode, - convertMarkdownTables, - }, - }, - reply: {}, - } as unknown as PluginRuntime); - - const sendDm = vi.fn(async () => {}); - const bus = { - sendDm, - close: vi.fn(), - getMetrics: vi.fn(() => ({ counters: {} })), - publishProfile: vi.fn(), - getProfileState: vi.fn(async () => null), - }; - mocks.startNostrBus.mockResolvedValueOnce(bus as unknown); - - const cleanup = (await startNostrGatewayAccount( - createStartAccountContext({ - account: buildResolvedNostrAccount({ accountId: "work" }), - }), - )) as { stop: () => void }; + const { resolveMarkdownTableMode } = installOutboundRuntime(); + const { cleanup, sendDm } = await startOutboundAccount("work"); const cfg = { channels: { diff --git a/extensions/ollama/src/web-search-provider.test.ts b/extensions/ollama/src/web-search-provider.test.ts index 4f74db3e40b..86aa5b467de 100644 --- a/extensions/ollama/src/web-search-provider.test.ts +++ b/extensions/ollama/src/web-search-provider.test.ts @@ -15,6 +15,42 @@ vi.mock("openclaw/plugin-sdk/ssrf-runtime", () => ({ fetchWithSsrFGuard: fetchWithSsrFGuardMock, })); +type OllamaProviderConfigOverride = Partial<{ + api: "ollama"; + apiKey: string; + baseUrl: string; + models: NonNullable< + NonNullable["providers"]>[string] + >["models"]; +}>; + +function createOllamaConfig(provider: OllamaProviderConfigOverride = {}): OpenClawConfig { + return { + models: { + providers: { + ollama: { + baseUrl: "http://ollama.local:11434/v1", + api: "ollama", + models: [], + ...provider, + }, + }, + }, + }; +} + +function createSetupNotes() { + const notes: Array<{ title?: string; message: string }> = []; + return { + notes, + prompter: { + note: async (message: string, title?: string) => { + notes.push({ title, message }); + }, + }, + }; +} + describe("ollama web search provider", () => { beforeEach(() => { fetchWithSsrFGuardMock.mockReset(); @@ -77,17 +113,7 @@ describe("ollama web search provider", () => { const provider = createOllamaWebSearchProvider(); const tool = provider.createTool({ - config: { - models: { - providers: { - ollama: { - baseUrl: "http://ollama.local:11434/v1", - api: "ollama", - models: [], - }, - }, - }, - }, + config: createOllamaConfig(), } as never); if (!tool) { throw new Error("Expected tool definition"); @@ -137,26 +163,12 @@ describe("ollama web search provider", () => { it("warns when Ollama is not reachable during setup without cancelling", async () => { fetchWithSsrFGuardMock.mockRejectedValueOnce(new Error("connect failed")); - const notes: Array<{ title?: string; message: string }> = []; - const config: OpenClawConfig = { - models: { - providers: { - ollama: { - baseUrl: "http://ollama.local:11434/v1", - api: "ollama", - models: [], - }, - }, - }, - }; + const config = createOllamaConfig(); + const { notes, prompter } = createSetupNotes(); const next = await testing.warnOllamaWebSearchPrereqs({ config, - prompter: { - note: async (message: string, title?: string) => { - notes.push({ title, message }); - }, - }, + prompter, }); expect(next).toBe(config); @@ -172,18 +184,12 @@ describe("ollama web search provider", () => { const original = process.env.OLLAMA_API_KEY; try { process.env.OLLAMA_API_KEY = "real-secret-from-env"; - const key = testing.resolveOllamaWebSearchApiKey({ - models: { - providers: { - ollama: { - apiKey: "OLLAMA_API_KEY", - baseUrl: "http://localhost:11434", - api: "ollama", - models: [], - }, - }, - }, - }); + const key = testing.resolveOllamaWebSearchApiKey( + createOllamaConfig({ + apiKey: "OLLAMA_API_KEY", + baseUrl: "http://localhost:11434", + }), + ); expect(key).toBe("real-secret-from-env"); } finally { if (original === undefined) { @@ -214,26 +220,12 @@ describe("ollama web search provider", () => { release: vi.fn(async () => {}), }); - const notes: Array<{ title?: string; message: string }> = []; - const config: OpenClawConfig = { - models: { - providers: { - ollama: { - baseUrl: "http://ollama.local:11434/v1", - api: "ollama", - models: [], - }, - }, - }, - }; + const config = createOllamaConfig(); + const { notes, prompter } = createSetupNotes(); const next = await testing.warnOllamaWebSearchPrereqs({ config, - prompter: { - note: async (message: string, title?: string) => { - notes.push({ title, message }); - }, - }, + prompter, }); expect(next).toBe(config); diff --git a/extensions/openai/video-generation-provider.test.ts b/extensions/openai/video-generation-provider.test.ts index b8582d233b8..79aee1e6ce6 100644 --- a/extensions/openai/video-generation-provider.test.ts +++ b/extensions/openai/video-generation-provider.test.ts @@ -1,45 +1,21 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildOpenAIVideoGenerationProvider } from "./video-generation-provider.js"; +import { beforeAll, describe, expect, it, vi } from "vitest"; +import { + getProviderHttpMocks, + installProviderHttpMockCleanup, +} from "../../test/helpers/media-generation/provider-http-mocks.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "openai-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock, resolveProviderHttpRequestConfigMock } = + getProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildOpenAIVideoGenerationProvider: typeof import("./video-generation-provider.js").buildOpenAIVideoGenerationProvider; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildOpenAIVideoGenerationProvider } = await import("./video-generation-provider.js")); +}); + +installProviderHttpMockCleanup(); describe("openai video generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("uses JSON for text-only Sora requests", async () => { postJsonRequestMock.mockResolvedValue({ response: { diff --git a/extensions/qa-lab/src/lab-server.ts b/extensions/qa-lab/src/lab-server.ts index bba69e0a64f..3a41a3f7816 100644 --- a/extensions/qa-lab/src/lab-server.ts +++ b/extensions/qa-lab/src/lab-server.ts @@ -535,6 +535,49 @@ export async function startQaLabServer(params?: { })(); return runnerModelCatalogPromise; }; + + async function runSelfCheck(): Promise { + latestScenarioRun = withQaLabRunCounts({ + kind: "self-check", + status: "running", + startedAt: new Date().toISOString(), + scenarios: [ + { + id: "qa-self-check", + name: "Synthetic Slack-class roundtrip", + status: "running", + }, + ], + }); + const result = await runQaSelfCheckAgainstState({ + state, + cfg: gateway?.cfg ?? createQaLabConfig(listenUrl), + outputPath: params?.outputPath, + repoRoot, + }); + latestScenarioRun = withQaLabRunCounts({ + kind: "self-check", + status: "completed", + startedAt: latestScenarioRun.startedAt, + finishedAt: new Date().toISOString(), + scenarios: [ + { + id: "qa-self-check", + name: result.scenarioResult.name, + status: result.scenarioResult.status, + details: result.scenarioResult.details, + steps: result.scenarioResult.steps, + }, + ], + }); + latestReport = { + outputPath: result.outputPath, + markdown: result.report, + generatedAt: new Date().toISOString(), + }; + return result; + } + const server = createServer(async (req, res) => { const url = new URL(req.url ?? "/", "http://127.0.0.1"); @@ -644,44 +687,7 @@ export async function startQaLabServer(params?: { writeError(res, 409, "QA suite run already in progress"); return; } - latestScenarioRun = withQaLabRunCounts({ - kind: "self-check", - status: "running", - startedAt: new Date().toISOString(), - scenarios: [ - { - id: "qa-self-check", - name: "Synthetic Slack-class roundtrip", - status: "running", - }, - ], - }); - const result = await runQaSelfCheckAgainstState({ - state, - cfg: gateway?.cfg ?? createQaLabConfig(listenUrl), - outputPath: params?.outputPath, - repoRoot, - }); - latestScenarioRun = withQaLabRunCounts({ - kind: "self-check", - status: "completed", - startedAt: latestScenarioRun.startedAt, - finishedAt: new Date().toISOString(), - scenarios: [ - { - id: "qa-self-check", - name: result.scenarioResult.name, - status: result.scenarioResult.status, - details: result.scenarioResult.details, - steps: result.scenarioResult.steps, - }, - ], - }); - latestReport = { - outputPath: result.outputPath, - markdown: result.report, - generatedAt: new Date().toISOString(), - }; + const result = await runSelfCheck(); writeJson(res, 200, serializeSelfCheck(result)); return; } @@ -846,47 +852,7 @@ export async function startQaLabServer(params?: { setLatestReport(next: QaLabLatestReport | null) { latestReport = next; }, - async runSelfCheck() { - latestScenarioRun = withQaLabRunCounts({ - kind: "self-check", - status: "running", - startedAt: new Date().toISOString(), - scenarios: [ - { - id: "qa-self-check", - name: "Synthetic Slack-class roundtrip", - status: "running", - }, - ], - }); - const result = await runQaSelfCheckAgainstState({ - state, - cfg: gateway?.cfg ?? createQaLabConfig(listenUrl), - outputPath: params?.outputPath, - repoRoot, - }); - latestScenarioRun = withQaLabRunCounts({ - kind: "self-check", - status: "completed", - startedAt: latestScenarioRun.startedAt, - finishedAt: new Date().toISOString(), - scenarios: [ - { - id: "qa-self-check", - name: result.scenarioResult.name, - status: result.scenarioResult.status, - details: result.scenarioResult.details, - steps: result.scenarioResult.steps, - }, - ], - }); - latestReport = { - outputPath: result.outputPath, - markdown: result.report, - generatedAt: new Date().toISOString(), - }; - return result; - }, + runSelfCheck, async stop() { await gateway?.stop(); await new Promise((resolve, reject) => diff --git a/extensions/qqbot/src/utils/platform.test.ts b/extensions/qqbot/src/utils/platform.test.ts index 6791941fda1..246683d6fa5 100644 --- a/extensions/qqbot/src/utils/platform.test.ts +++ b/extensions/qqbot/src/utils/platform.test.ts @@ -11,6 +11,32 @@ import { describe("qqbot local media path remapping", () => { const createdPaths: string[] = []; + function createOpenClawTestRoot() { + const actualHome = getHomeDir(); + const openclawDir = path.join(actualHome, ".openclaw"); + fs.mkdirSync(openclawDir, { recursive: true }); + const testRoot = fs.mkdtempSync(path.join(openclawDir, "qqbot-platform-test-")); + createdPaths.push(testRoot); + return { actualHome, testRootName: path.basename(testRoot) }; + } + + function createQqbotMediaFile(fileName: string) { + const { actualHome, testRootName } = createOpenClawTestRoot(); + const mediaFile = path.join( + actualHome, + ".openclaw", + "media", + "qqbot", + "downloads", + testRootName, + fileName, + ); + fs.mkdirSync(path.dirname(mediaFile), { recursive: true }); + fs.writeFileSync(mediaFile, "image", "utf8"); + createdPaths.push(path.dirname(mediaFile)); + return { actualHome, testRootName, mediaFile }; + } + afterEach(() => { vi.restoreAllMocks(); for (const target of createdPaths.splice(0)) { @@ -19,24 +45,7 @@ describe("qqbot local media path remapping", () => { }); it("remaps missing workspace media paths to the real media directory", () => { - const actualHome = getHomeDir(); - const openclawDir = path.join(actualHome, ".openclaw"); - fs.mkdirSync(openclawDir, { recursive: true }); - const testRoot = fs.mkdtempSync(path.join(openclawDir, "qqbot-platform-test-")); - createdPaths.push(testRoot); - - const mediaFile = path.join( - actualHome, - ".openclaw", - "media", - "qqbot", - "downloads", - path.basename(testRoot), - "example.png", - ); - fs.mkdirSync(path.dirname(mediaFile), { recursive: true }); - fs.writeFileSync(mediaFile, "image", "utf8"); - createdPaths.push(path.dirname(mediaFile)); + const { actualHome, testRootName, mediaFile } = createQqbotMediaFile("example.png"); const missingWorkspacePath = path.join( actualHome, @@ -44,7 +53,7 @@ describe("qqbot local media path remapping", () => { "workspace", "qqbot", "downloads", - path.basename(testRoot), + testRootName, "example.png", ); @@ -52,24 +61,7 @@ describe("qqbot local media path remapping", () => { }); it("leaves existing media paths unchanged", () => { - const actualHome = getHomeDir(); - const openclawDir = path.join(actualHome, ".openclaw"); - fs.mkdirSync(openclawDir, { recursive: true }); - const testRoot = fs.mkdtempSync(path.join(openclawDir, "qqbot-platform-test-")); - createdPaths.push(testRoot); - - const mediaFile = path.join( - actualHome, - ".openclaw", - "media", - "qqbot", - "downloads", - path.basename(testRoot), - "existing.png", - ); - fs.mkdirSync(path.dirname(mediaFile), { recursive: true }); - fs.writeFileSync(mediaFile, "image", "utf8"); - createdPaths.push(path.dirname(mediaFile)); + const { mediaFile } = createQqbotMediaFile("existing.png"); expect(resolveQQBotLocalMediaPath(mediaFile)).toBe(mediaFile); }); @@ -99,41 +91,20 @@ describe("qqbot local media path remapping", () => { }); it("allows structured payload files inside the QQ Bot media directory", () => { - const actualHome = getHomeDir(); - const openclawDir = path.join(actualHome, ".openclaw"); - fs.mkdirSync(openclawDir, { recursive: true }); - const testRoot = fs.mkdtempSync(path.join(openclawDir, "qqbot-platform-test-")); - createdPaths.push(testRoot); - - const mediaFile = path.join( - actualHome, - ".openclaw", - "media", - "qqbot", - "downloads", - path.basename(testRoot), - "allowed.png", - ); - fs.mkdirSync(path.dirname(mediaFile), { recursive: true }); - fs.writeFileSync(mediaFile, "image", "utf8"); - createdPaths.push(path.dirname(mediaFile)); + const { mediaFile } = createQqbotMediaFile("allowed.png"); expect(resolveQQBotPayloadLocalFilePath(mediaFile)).toBe(mediaFile); }); it("blocks structured payload files inside the QQ Bot data directory", () => { - const actualHome = getHomeDir(); - const openclawDir = path.join(actualHome, ".openclaw"); - fs.mkdirSync(openclawDir, { recursive: true }); - const testRoot = fs.mkdtempSync(path.join(openclawDir, "qqbot-platform-test-")); - createdPaths.push(testRoot); + const { actualHome, testRootName } = createOpenClawTestRoot(); const dataFile = path.join( actualHome, ".openclaw", "qqbot", "sessions", - path.basename(testRoot), + testRootName, "session.json", ); fs.mkdirSync(path.dirname(dataFile), { recursive: true }); @@ -144,24 +115,7 @@ describe("qqbot local media path remapping", () => { }); it("allows legacy workspace paths when they remap into QQ Bot media storage", () => { - const actualHome = getHomeDir(); - const openclawDir = path.join(actualHome, ".openclaw"); - fs.mkdirSync(openclawDir, { recursive: true }); - const testRoot = fs.mkdtempSync(path.join(openclawDir, "qqbot-platform-test-")); - createdPaths.push(testRoot); - - const mediaFile = path.join( - actualHome, - ".openclaw", - "media", - "qqbot", - "downloads", - path.basename(testRoot), - "legacy.png", - ); - fs.mkdirSync(path.dirname(mediaFile), { recursive: true }); - fs.writeFileSync(mediaFile, "image", "utf8"); - createdPaths.push(path.dirname(mediaFile)); + const { actualHome, testRootName, mediaFile } = createQqbotMediaFile("legacy.png"); const missingWorkspacePath = path.join( actualHome, @@ -169,7 +123,7 @@ describe("qqbot local media path remapping", () => { "workspace", "qqbot", "downloads", - path.basename(testRoot), + testRootName, "legacy.png", ); diff --git a/extensions/qwen/video-generation-provider.test.ts b/extensions/qwen/video-generation-provider.test.ts index 1b62b93e4fe..fb90b1d59f4 100644 --- a/extensions/qwen/video-generation-provider.test.ts +++ b/extensions/qwen/video-generation-provider.test.ts @@ -1,71 +1,27 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildQwenVideoGenerationProvider } from "./video-generation-provider.js"; +import { beforeAll, describe, expect, it } from "vitest"; +import { + expectDashscopeVideoTaskPoll, + expectSuccessfulDashscopeVideoResult, + mockSuccessfulDashscopeVideoTask, +} from "../../test/helpers/media-generation/dashscope-video-provider.js"; +import { + getProviderHttpMocks, + installProviderHttpMockCleanup, +} from "../../test/helpers/media-generation/provider-http-mocks.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "qwen-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock } = getProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildQwenVideoGenerationProvider: typeof import("./video-generation-provider.js").buildQwenVideoGenerationProvider; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildQwenVideoGenerationProvider } = await import("./video-generation-provider.js")); +}); + +installProviderHttpMockCleanup(); describe("qwen video generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("submits async Wan generation, polls task status, and downloads the resulting video", async () => { - postJsonRequestMock.mockResolvedValue({ - response: { - json: async () => ({ - request_id: "req-1", - output: { - task_id: "task-1", - }, - }), - }, - release: vi.fn(async () => {}), - }); - fetchWithTimeoutMock - .mockResolvedValueOnce({ - json: async () => ({ - output: { - task_status: "SUCCEEDED", - results: [{ video_url: "https://example.com/out.mp4" }], - }, - }), - headers: new Headers(), - }) - .mockResolvedValueOnce({ - arrayBuffer: async () => Buffer.from("mp4-bytes"), - headers: new Headers({ "content-type": "video/mp4" }), - }); + mockSuccessfulDashscopeVideoTask({ postJsonRequestMock, fetchWithTimeoutMock }); const provider = buildQwenVideoGenerationProvider(); const result = await provider.generateVideo({ @@ -90,22 +46,8 @@ describe("qwen video generation provider", () => { }), }), ); - expect(fetchWithTimeoutMock).toHaveBeenNthCalledWith( - 1, - "https://dashscope-intl.aliyuncs.com/api/v1/tasks/task-1", - expect.objectContaining({ method: "GET" }), - 120000, - fetch, - ); - expect(result.videos).toHaveLength(1); - expect(result.videos[0]?.mimeType).toBe("video/mp4"); - expect(result.metadata).toEqual( - expect.objectContaining({ - requestId: "req-1", - taskId: "task-1", - taskStatus: "SUCCEEDED", - }), - ); + expectDashscopeVideoTaskPoll(fetchWithTimeoutMock); + expectSuccessfulDashscopeVideoResult(result); }); it("fails fast when reference inputs are local buffers instead of remote URLs", async () => { @@ -126,31 +68,13 @@ describe("qwen video generation provider", () => { }); it("preserves dedicated coding endpoints for dedicated API keys", async () => { - postJsonRequestMock.mockResolvedValue({ - response: { - json: async () => ({ - request_id: "req-2", - output: { - task_id: "task-2", - }, - }), + mockSuccessfulDashscopeVideoTask( + { + postJsonRequestMock, + fetchWithTimeoutMock, }, - release: vi.fn(async () => {}), - }); - fetchWithTimeoutMock - .mockResolvedValueOnce({ - json: async () => ({ - output: { - task_status: "SUCCEEDED", - results: [{ video_url: "https://example.com/out.mp4" }], - }, - }), - headers: new Headers(), - }) - .mockResolvedValueOnce({ - arrayBuffer: async () => Buffer.from("mp4-bytes"), - headers: new Headers({ "content-type": "video/mp4" }), - }); + { requestId: "req-2", taskId: "task-2" }, + ); const provider = buildQwenVideoGenerationProvider(); await provider.generateVideo({ @@ -174,12 +98,9 @@ describe("qwen video generation provider", () => { url: "https://coding-intl.dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis", }), ); - expect(fetchWithTimeoutMock).toHaveBeenNthCalledWith( - 1, - "https://coding-intl.dashscope.aliyuncs.com/api/v1/tasks/task-2", - expect.objectContaining({ method: "GET" }), - 120000, - fetch, - ); + expectDashscopeVideoTaskPoll(fetchWithTimeoutMock, { + baseUrl: "https://coding-intl.dashscope.aliyuncs.com", + taskId: "task-2", + }); }); }); diff --git a/extensions/qwen/video-generation-provider.ts b/extensions/qwen/video-generation-provider.ts index 00770440472..6d7b3d96be2 100644 --- a/extensions/qwen/video-generation-provider.ts +++ b/extensions/qwen/video-generation-provider.ts @@ -1,22 +1,14 @@ import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth"; import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; +import { resolveProviderHttpRequestConfig } from "openclaw/plugin-sdk/provider-http"; import { - assertOkOrThrowHttpError, - postJsonRequest, - resolveProviderHttpRequestConfig, -} from "openclaw/plugin-sdk/provider-http"; -import { - DEFAULT_VIDEO_GENERATION_DURATION_SECONDS, + DASHSCOPE_WAN_VIDEO_CAPABILITIES, + DASHSCOPE_WAN_VIDEO_MODELS, + DEFAULT_DASHSCOPE_WAN_VIDEO_MODEL, DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, - DEFAULT_VIDEO_RESOLUTION_TO_SIZE, - buildDashscopeVideoGenerationInput, - buildDashscopeVideoGenerationParameters, - downloadDashscopeGeneratedVideos, - extractDashscopeVideoUrls, - pollDashscopeVideoTaskUntilComplete, + runDashscopeVideoGenerationTask, } from "openclaw/plugin-sdk/video-generation"; import type { - DashscopeVideoGenerationResponse, VideoGenerationProvider, VideoGenerationRequest, VideoGenerationResult, @@ -24,7 +16,7 @@ import type { import { QWEN_STANDARD_CN_BASE_URL, QWEN_STANDARD_GLOBAL_BASE_URL } from "./models.js"; const DEFAULT_QWEN_VIDEO_BASE_URL = "https://dashscope-intl.aliyuncs.com"; -const DEFAULT_QWEN_VIDEO_MODEL = "wan2.6-t2v"; +const DEFAULT_QWEN_VIDEO_MODEL = DEFAULT_DASHSCOPE_WAN_VIDEO_MODEL; function resolveQwenVideoBaseUrl(req: VideoGenerationRequest): string { const direct = req.cfg?.models?.providers?.qwen?.baseUrl?.trim(); @@ -66,45 +58,13 @@ export function buildQwenVideoGenerationProvider(): VideoGenerationProvider { id: "qwen", label: "Qwen Cloud", defaultModel: DEFAULT_QWEN_VIDEO_MODEL, - models: ["wan2.6-t2v", "wan2.6-i2v", "wan2.6-r2v", "wan2.6-r2v-flash", "wan2.7-r2v"], + models: [...DASHSCOPE_WAN_VIDEO_MODELS], isConfigured: ({ agentDir }) => isProviderApiKeyConfigured({ provider: "qwen", agentDir, }), - capabilities: { - generate: { - maxVideos: 1, - maxDurationSeconds: 10, - supportsSize: true, - supportsAspectRatio: true, - supportsResolution: true, - supportsAudio: true, - supportsWatermark: true, - }, - imageToVideo: { - enabled: true, - maxVideos: 1, - maxInputImages: 1, - maxDurationSeconds: 10, - supportsSize: true, - supportsAspectRatio: true, - supportsResolution: true, - supportsAudio: true, - supportsWatermark: true, - }, - videoToVideo: { - enabled: true, - maxVideos: 1, - maxInputVideos: 4, - maxDurationSeconds: 10, - supportsSize: true, - supportsAspectRatio: true, - supportsResolution: true, - supportsAudio: true, - supportsWatermark: true, - }, - }, + capabilities: DASHSCOPE_WAN_VIDEO_CAPABILITIES, async generateVideo(req): Promise { const fetchFn = fetch; const auth = await resolveApiKeyForProvider({ @@ -133,68 +93,19 @@ export function buildQwenVideoGenerationProvider(): VideoGenerationProvider { }); const model = req.model?.trim() || DEFAULT_QWEN_VIDEO_MODEL; - const { response, release } = await postJsonRequest({ + return await runDashscopeVideoGenerationTask({ + providerLabel: "Qwen", + model, + req, url: `${resolveDashscopeAigcApiBaseUrl(baseUrl)}/api/v1/services/aigc/video-generation/video-synthesis`, headers, - body: { - model, - input: buildDashscopeVideoGenerationInput({ - providerLabel: "Qwen", - req, - }), - parameters: buildDashscopeVideoGenerationParameters( - { - ...req, - durationSeconds: req.durationSeconds ?? DEFAULT_VIDEO_GENERATION_DURATION_SECONDS, - }, - DEFAULT_VIDEO_RESOLUTION_TO_SIZE, - ), - }, + baseUrl: resolveDashscopeAigcApiBaseUrl(baseUrl), timeoutMs: req.timeoutMs, fetchFn, allowPrivateNetwork, dispatcherPolicy, + defaultTimeoutMs: DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, }); - - try { - await assertOkOrThrowHttpError(response, "Qwen video generation failed"); - const submitted = (await response.json()) as DashscopeVideoGenerationResponse; - const taskId = submitted.output?.task_id?.trim(); - if (!taskId) { - throw new Error("Qwen video generation response missing task_id"); - } - const completed = await pollDashscopeVideoTaskUntilComplete({ - providerLabel: "Qwen", - taskId, - headers, - timeoutMs: req.timeoutMs, - fetchFn, - baseUrl: resolveDashscopeAigcApiBaseUrl(baseUrl), - defaultTimeoutMs: DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, - }); - const urls = extractDashscopeVideoUrls(completed); - if (urls.length === 0) { - throw new Error("Qwen video generation completed without output video URLs"); - } - const videos = await downloadDashscopeGeneratedVideos({ - providerLabel: "Qwen", - urls, - timeoutMs: req.timeoutMs, - fetchFn, - defaultTimeoutMs: DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, - }); - return { - videos, - model, - metadata: { - requestId: submitted.request_id, - taskId, - taskStatus: completed.output?.task_status, - }, - }; - } finally { - await release(); - } }, }; } diff --git a/extensions/runway/video-generation-provider.test.ts b/extensions/runway/video-generation-provider.test.ts index 4c4c7e66958..2504f1a83c4 100644 --- a/extensions/runway/video-generation-provider.test.ts +++ b/extensions/runway/video-generation-provider.test.ts @@ -1,45 +1,20 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildRunwayVideoGenerationProvider } from "./video-generation-provider.js"; +import { beforeAll, describe, expect, it, vi } from "vitest"; +import { + getProviderHttpMocks, + installProviderHttpMockCleanup, +} from "../../test/helpers/media-generation/provider-http-mocks.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "runway-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock } = getProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildRunwayVideoGenerationProvider: typeof import("./video-generation-provider.js").buildRunwayVideoGenerationProvider; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildRunwayVideoGenerationProvider } = await import("./video-generation-provider.js")); +}); + +installProviderHttpMockCleanup(); describe("runway video generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("submits a text-to-video task, polls it, and downloads the output", async () => { postJsonRequestMock.mockResolvedValue({ response: { diff --git a/extensions/slack/src/monitor.tool-result.test.ts b/extensions/slack/src/monitor.tool-result.test.ts index 3d0da66d6f9..68fbc6ba774 100644 --- a/extensions/slack/src/monitor.tool-result.test.ts +++ b/extensions/slack/src/monitor.tool-result.test.ts @@ -222,6 +222,57 @@ describe("monitorSlackProvider tool results", () => { ); } + function setMentionGatedAckConfig(statusReactionsEnabled: boolean) { + slackTestState.config = { + messages: { + responsePrefix: "PFX", + ackReaction: "👀", + ackReactionScope: "group-mentions", + removeAckAfterReply: true, + statusReactions: statusReactionsEnabled + ? { enabled: true, timing: { debounceMs: 0, doneHoldMs: 0, errorHoldMs: 0 } } + : { enabled: false }, + }, + channels: { + slack: { + dm: { enabled: true, policy: "open", allowFrom: ["*"] }, + groupPolicy: "open", + }, + }, + }; + } + + function mockGeneralChannelInfo() { + const client = getSlackClient(); + if (!client) { + throw new Error("Slack client not registered"); + } + const conversations = client.conversations as { + info: ReturnType; + }; + conversations.info.mockResolvedValueOnce({ + channel: { name: "general", is_channel: true }, + }); + } + + async function runMentionGatedChannelMessageAndFlush() { + await runSlackMessageOnce(monitorSlackProvider, { + event: makeSlackMessageEvent({ + text: "<@bot-user> hello", + ts: "456", + channel_type: "channel", + }), + }); + await new Promise((resolve) => setTimeout(resolve, 0)); + await flush(); + } + + function expectReactionNames(names: string[]) { + expect(reactMock.mock.calls.map(([args]) => String((args as { name: string }).name))).toEqual( + names, + ); + } + async function runDefaultMessageAndExpectSentText(expectedText: string) { replyMock.mockResolvedValue({ text: expectedText.replace(/^PFX /, "") }); await runSlackMessageOnce(monitorSlackProvider, { @@ -557,41 +608,9 @@ describe("monitorSlackProvider tool results", () => { it("keeps ack reaction when no reply is delivered and status reactions are disabled", async () => { replyMock.mockResolvedValue(undefined); - slackTestState.config = { - messages: { - responsePrefix: "PFX", - ackReaction: "👀", - ackReactionScope: "group-mentions", - removeAckAfterReply: true, - statusReactions: { enabled: false }, - }, - channels: { - slack: { - dm: { enabled: true, policy: "open", allowFrom: ["*"] }, - groupPolicy: "open", - }, - }, - }; - const client = getSlackClient(); - if (!client) { - throw new Error("Slack client not registered"); - } - const conversations = client.conversations as { - info: ReturnType; - }; - conversations.info.mockResolvedValueOnce({ - channel: { name: "general", is_channel: true }, - }); - - await runSlackMessageOnce(monitorSlackProvider, { - event: makeSlackMessageEvent({ - text: "<@bot-user> hello", - ts: "456", - channel_type: "channel", - }), - }); - await new Promise((resolve) => setTimeout(resolve, 0)); - await flush(); + setMentionGatedAckConfig(false); + mockGeneralChannelInfo(); + await runMentionGatedChannelMessageAndFlush(); expect(sendMock).not.toHaveBeenCalled(); expect(reactMock).toHaveBeenCalledTimes(1); @@ -604,44 +623,9 @@ describe("monitorSlackProvider tool results", () => { it("keeps ack reaction when no reply is delivered and status reactions are enabled", async () => { replyMock.mockResolvedValue(undefined); - slackTestState.config = { - messages: { - responsePrefix: "PFX", - ackReaction: "👀", - ackReactionScope: "group-mentions", - removeAckAfterReply: true, - statusReactions: { - enabled: true, - timing: { debounceMs: 0, doneHoldMs: 0, errorHoldMs: 0 }, - }, - }, - channels: { - slack: { - dm: { enabled: true, policy: "open", allowFrom: ["*"] }, - groupPolicy: "open", - }, - }, - }; - const client = getSlackClient(); - if (!client) { - throw new Error("Slack client not registered"); - } - const conversations = client.conversations as { - info: ReturnType; - }; - conversations.info.mockResolvedValueOnce({ - channel: { name: "general", is_channel: true }, - }); - - await runSlackMessageOnce(monitorSlackProvider, { - event: makeSlackMessageEvent({ - text: "<@bot-user> hello", - ts: "456", - channel_type: "channel", - }), - }); - await new Promise((resolve) => setTimeout(resolve, 0)); - await flush(); + setMentionGatedAckConfig(true); + mockGeneralChannelInfo(); + await runMentionGatedChannelMessageAndFlush(); expect(sendMock).not.toHaveBeenCalled(); expect(reactMock).toHaveBeenCalledTimes(1); @@ -654,53 +638,12 @@ describe("monitorSlackProvider tool results", () => { it("restores ack reaction when dispatch fails before any reply is delivered", async () => { replyMock.mockRejectedValue(new Error("boom")); - slackTestState.config = { - messages: { - responsePrefix: "PFX", - ackReaction: "👀", - ackReactionScope: "group-mentions", - removeAckAfterReply: true, - statusReactions: { - enabled: true, - timing: { debounceMs: 0, doneHoldMs: 0, errorHoldMs: 0 }, - }, - }, - channels: { - slack: { - dm: { enabled: true, policy: "open", allowFrom: ["*"] }, - groupPolicy: "open", - }, - }, - }; - const client = getSlackClient(); - if (!client) { - throw new Error("Slack client not registered"); - } - const conversations = client.conversations as { - info: ReturnType; - }; - conversations.info.mockResolvedValueOnce({ - channel: { name: "general", is_channel: true }, - }); - - await runSlackMessageOnce(monitorSlackProvider, { - event: makeSlackMessageEvent({ - text: "<@bot-user> hello", - ts: "456", - channel_type: "channel", - }), - }); - await new Promise((resolve) => setTimeout(resolve, 0)); - await flush(); + setMentionGatedAckConfig(true); + mockGeneralChannelInfo(); + await runMentionGatedChannelMessageAndFlush(); expect(sendMock).not.toHaveBeenCalled(); - expect(reactMock.mock.calls.map(([args]) => String((args as { name: string }).name))).toEqual([ - "eyes", - "scream", - "eyes", - "eyes", - "scream", - ]); + expectReactionNames(["eyes", "scream", "eyes", "eyes", "scream"]); }); it("replies with pairing code when dmPolicy is pairing and no allowFrom is set", async () => { diff --git a/extensions/slack/src/monitor/message-handler/prepare.thread-context-allowlist.test.ts b/extensions/slack/src/monitor/message-handler/prepare.thread-context-allowlist.test.ts index d3a9ce52fd7..9eb8b7d8d8f 100644 --- a/extensions/slack/src/monitor/message-handler/prepare.thread-context-allowlist.test.ts +++ b/extensions/slack/src/monitor/message-handler/prepare.thread-context-allowlist.test.ts @@ -23,6 +23,86 @@ function makeTmpStorePath() { return path.join(dir, "sessions.json"); } +type ThreadContextCaseParams = { + channel: string; + channelType: SlackMessageEvent["channel_type"]; + user: string; + userName: string; + starterText: string; + followUpText: string; + startTs: string; + replyTs: string; + followUpTs: string; + currentTs: string; + channelsConfig?: Parameters[0]["channelsConfig"]; + resolveChannelName?: (channelId: string) => Promise<{ + name?: string; + type?: SlackMessageEvent["channel_type"]; + topic?: string; + purpose?: string; + }>; +}; + +async function prepareThreadContextCase(params: ThreadContextCaseParams) { + const replies = vi + .fn() + .mockResolvedValueOnce({ + messages: [{ text: params.starterText, user: params.user, ts: params.startTs }], + }) + .mockResolvedValueOnce({ + messages: [ + { text: params.starterText, user: params.user, ts: params.startTs }, + { text: "assistant reply", bot_id: "B1", ts: params.replyTs }, + { text: params.followUpText, user: params.user, ts: params.followUpTs }, + { text: "current message", user: params.user, ts: params.currentTs }, + ], + response_metadata: { next_cursor: "" }, + }); + const ctx = createInboundSlackTestContext({ + cfg: { + session: { store: makeTmpStorePath() }, + channels: { + slack: { + enabled: true, + replyToMode: "all", + groupPolicy: "open", + contextVisibility: "allowlist", + }, + }, + } as OpenClawConfig, + appClient: { conversations: { replies } } as unknown as App["client"], + defaultRequireMention: false, + replyToMode: "all", + channelsConfig: params.channelsConfig, + }); + ctx.allowFrom = ["u-owner"]; + ctx.resolveUserName = async (id: string) => ({ + name: id === params.user ? params.userName : "Owner", + }); + if (params.resolveChannelName) { + ctx.resolveChannelName = params.resolveChannelName; + } + + const prepared = await prepareSlackMessage({ + ctx, + account: createSlackTestAccount({ + replyToMode: "all", + thread: { initialHistoryLimit: 20 }, + }), + message: { + channel: params.channel, + channel_type: params.channelType, + user: params.user, + text: "current message", + ts: params.currentTs, + thread_ts: params.startTs, + } as SlackMessageEvent, + opts: { source: "message" }, + }); + + return { prepared, replies }; +} + describe("prepareSlackMessage thread context allowlists", () => { afterAll(() => { if (fixtureRoot) { @@ -32,64 +112,24 @@ describe("prepareSlackMessage thread context allowlists", () => { }); it("uses room users allowlist for thread context filtering", async () => { - const replies = vi - .fn() - .mockResolvedValueOnce({ - messages: [{ text: "starter from room user", user: "U1", ts: "100.000" }], - }) - .mockResolvedValueOnce({ - messages: [ - { text: "starter from room user", user: "U1", ts: "100.000" }, - { text: "assistant reply", bot_id: "B1", ts: "100.500" }, - { text: "allowed follow-up", user: "U1", ts: "100.800" }, - { text: "current message", user: "U1", ts: "101.000" }, - ], - response_metadata: { next_cursor: "" }, - }); - const storePath = makeTmpStorePath(); - const ctx = createInboundSlackTestContext({ - cfg: { - session: { store: storePath }, - channels: { - slack: { - enabled: true, - replyToMode: "all", - groupPolicy: "open", - contextVisibility: "allowlist", - }, - }, - } as OpenClawConfig, - appClient: { conversations: { replies } } as unknown as App["client"], - defaultRequireMention: false, - replyToMode: "all", + const { prepared, replies } = await prepareThreadContextCase({ + channel: "C123", + channelType: "channel", + user: "U1", + userName: "Alice", + starterText: "starter from room user", + followUpText: "allowed follow-up", + startTs: "100.000", + replyTs: "100.500", + followUpTs: "100.800", + currentTs: "101.000", channelsConfig: { C123: { users: ["U1"], requireMention: false, }, }, - }); - ctx.allowFrom = ["u-owner"]; - ctx.resolveUserName = async (id: string) => ({ - name: id === "U1" ? "Alice" : "Owner", - }); - ctx.resolveChannelName = async () => ({ name: "general", type: "channel" }); - - const prepared = await prepareSlackMessage({ - ctx, - account: createSlackTestAccount({ - replyToMode: "all", - thread: { initialHistoryLimit: 20 }, - }), - message: { - channel: "C123", - channel_type: "channel", - user: "U1", - text: "current message", - ts: "101.000", - thread_ts: "100.000", - } as SlackMessageEvent, - opts: { source: "message" }, + resolveChannelName: async () => ({ name: "general", type: "channel" }), }); expect(prepared).toBeTruthy(); @@ -102,63 +142,23 @@ describe("prepareSlackMessage thread context allowlists", () => { }); it("does not apply the owner allowlist to open-room thread context", async () => { - const replies = vi - .fn() - .mockResolvedValueOnce({ - messages: [{ text: "starter from open room", user: "U2", ts: "200.000" }], - }) - .mockResolvedValueOnce({ - messages: [ - { text: "starter from open room", user: "U2", ts: "200.000" }, - { text: "assistant reply", bot_id: "B1", ts: "200.500" }, - { text: "open-room follow-up", user: "U2", ts: "200.800" }, - { text: "current message", user: "U2", ts: "201.000" }, - ], - response_metadata: { next_cursor: "" }, - }); - const storePath = makeTmpStorePath(); - const ctx = createInboundSlackTestContext({ - cfg: { - session: { store: storePath }, - channels: { - slack: { - enabled: true, - replyToMode: "all", - groupPolicy: "open", - contextVisibility: "allowlist", - }, - }, - } as OpenClawConfig, - appClient: { conversations: { replies } } as unknown as App["client"], - defaultRequireMention: false, - replyToMode: "all", + const { prepared, replies } = await prepareThreadContextCase({ + channel: "C124", + channelType: "channel", + user: "U2", + userName: "Bob", + starterText: "starter from open room", + followUpText: "open-room follow-up", + startTs: "200.000", + replyTs: "200.500", + followUpTs: "200.800", + currentTs: "201.000", channelsConfig: { C124: { requireMention: false, }, }, - }); - ctx.allowFrom = ["u-owner"]; - ctx.resolveUserName = async (id: string) => ({ - name: id === "U2" ? "Bob" : "Owner", - }); - ctx.resolveChannelName = async () => ({ name: "general", type: "channel" }); - - const prepared = await prepareSlackMessage({ - ctx, - account: createSlackTestAccount({ - replyToMode: "all", - thread: { initialHistoryLimit: 20 }, - }), - message: { - channel: "C124", - channel_type: "channel", - user: "U2", - text: "current message", - ts: "201.000", - thread_ts: "200.000", - } as SlackMessageEvent, - opts: { source: "message" }, + resolveChannelName: async () => ({ name: "general", type: "channel" }), }); expect(prepared).toBeTruthy(); @@ -171,57 +171,17 @@ describe("prepareSlackMessage thread context allowlists", () => { }); it("does not apply the owner allowlist to open DMs when dmPolicy is open", async () => { - const replies = vi - .fn() - .mockResolvedValueOnce({ - messages: [{ text: "starter from open dm", user: "U3", ts: "300.000" }], - }) - .mockResolvedValueOnce({ - messages: [ - { text: "starter from open dm", user: "U3", ts: "300.000" }, - { text: "assistant reply", bot_id: "B1", ts: "300.500" }, - { text: "dm follow-up", user: "U3", ts: "300.800" }, - { text: "current message", user: "U3", ts: "301.000" }, - ], - response_metadata: { next_cursor: "" }, - }); - const storePath = makeTmpStorePath(); - const ctx = createInboundSlackTestContext({ - cfg: { - session: { store: storePath }, - channels: { - slack: { - enabled: true, - replyToMode: "all", - groupPolicy: "open", - contextVisibility: "allowlist", - }, - }, - } as OpenClawConfig, - appClient: { conversations: { replies } } as unknown as App["client"], - defaultRequireMention: false, - replyToMode: "all", - }); - ctx.allowFrom = ["u-owner"]; - ctx.resolveUserName = async (id: string) => ({ - name: id === "U3" ? "Dana" : "Owner", - }); - - const prepared = await prepareSlackMessage({ - ctx, - account: createSlackTestAccount({ - replyToMode: "all", - thread: { initialHistoryLimit: 20 }, - }), - message: { - channel: "D300", - channel_type: "im", - user: "U3", - text: "current message", - ts: "301.000", - thread_ts: "300.000", - } as SlackMessageEvent, - opts: { source: "message" }, + const { prepared, replies } = await prepareThreadContextCase({ + channel: "D300", + channelType: "im", + user: "U3", + userName: "Dana", + starterText: "starter from open dm", + followUpText: "dm follow-up", + startTs: "300.000", + replyTs: "300.500", + followUpTs: "300.800", + currentTs: "301.000", }); expect(prepared).toBeTruthy(); @@ -234,57 +194,17 @@ describe("prepareSlackMessage thread context allowlists", () => { }); it("does not apply the owner allowlist to MPIM thread context", async () => { - const replies = vi - .fn() - .mockResolvedValueOnce({ - messages: [{ text: "starter from mpim", user: "U4", ts: "400.000" }], - }) - .mockResolvedValueOnce({ - messages: [ - { text: "starter from mpim", user: "U4", ts: "400.000" }, - { text: "assistant reply", bot_id: "B1", ts: "400.500" }, - { text: "mpim follow-up", user: "U4", ts: "400.800" }, - { text: "current message", user: "U4", ts: "401.000" }, - ], - response_metadata: { next_cursor: "" }, - }); - const storePath = makeTmpStorePath(); - const ctx = createInboundSlackTestContext({ - cfg: { - session: { store: storePath }, - channels: { - slack: { - enabled: true, - replyToMode: "all", - groupPolicy: "open", - contextVisibility: "allowlist", - }, - }, - } as OpenClawConfig, - appClient: { conversations: { replies } } as unknown as App["client"], - defaultRequireMention: false, - replyToMode: "all", - }); - ctx.allowFrom = ["u-owner"]; - ctx.resolveUserName = async (id: string) => ({ - name: id === "U4" ? "Evan" : "Owner", - }); - - const prepared = await prepareSlackMessage({ - ctx, - account: createSlackTestAccount({ - replyToMode: "all", - thread: { initialHistoryLimit: 20 }, - }), - message: { - channel: "G400", - channel_type: "mpim", - user: "U4", - text: "current message", - ts: "401.000", - thread_ts: "400.000", - } as SlackMessageEvent, - opts: { source: "message" }, + const { prepared, replies } = await prepareThreadContextCase({ + channel: "G400", + channelType: "mpim", + user: "U4", + userName: "Evan", + starterText: "starter from mpim", + followUpText: "mpim follow-up", + startTs: "400.000", + replyTs: "400.500", + followUpTs: "400.800", + currentTs: "401.000", }); expect(prepared).toBeTruthy(); diff --git a/extensions/telegram/src/bot-native-commands.menu-test-support.ts b/extensions/telegram/src/bot-native-commands.menu-test-support.ts index 83e7cbebd5e..dd1d34b83be 100644 --- a/extensions/telegram/src/bot-native-commands.menu-test-support.ts +++ b/extensions/telegram/src/bot-native-commands.menu-test-support.ts @@ -21,6 +21,9 @@ type CreateCommandBotResult = { deleteMessage: ReturnType; setMyCommands: ReturnType; }; +type CreateCommandBotParams = { + api?: Record; +}; const skillCommandMocks = vi.hoisted(() => ({ listSkillCommandsForAgents: vi.fn( @@ -67,7 +70,7 @@ export function resetNativeCommandMenuMocks() { emitTelegramMessageSentHooks.mockClear(); } -export function createCommandBot(): CreateCommandBotResult { +export function createCommandBot(params: CreateCommandBotParams = {}): CreateCommandBotResult { const commandHandlers = new Map Promise>(); const sendMessage = vi.fn().mockResolvedValue({ message_id: 999 }); const deleteMessage = vi.fn().mockResolvedValue(true); @@ -77,6 +80,7 @@ export function createCommandBot(): CreateCommandBotResult { setMyCommands, sendMessage, deleteMessage, + ...params.api, }, command: vi.fn((name: string, cb: (ctx: unknown) => Promise) => { commandHandlers.set(name, cb); diff --git a/extensions/telegram/src/bot-native-commands.test.ts b/extensions/telegram/src/bot-native-commands.test.ts index 878eee42bc5..b4b7e38cf80 100644 --- a/extensions/telegram/src/bot-native-commands.test.ts +++ b/extensions/telegram/src/bot-native-commands.test.ts @@ -19,6 +19,54 @@ let registerTelegramNativeCommands: typeof import("./bot-native-commands.js").re let parseTelegramNativeCommandCallbackData: typeof import("./bot-native-commands.js").parseTelegramNativeCommandCallbackData; let resolveTelegramNativeCommandDisableBlockStreaming: typeof import("./bot-native-commands.js").resolveTelegramNativeCommandDisableBlockStreaming; +type CommandBotHarness = ReturnType; +type CommandHandler = (ctx: unknown) => Promise; +type PlugCommandHarnessParams = { + botHarness?: CommandBotHarness; + cfg?: OpenClawConfig; + command?: Record; + args?: string; + result?: Record; + registerOverrides?: Partial[0]>; +}; + +function primePlugCommand(params: PlugCommandHarnessParams = {}) { + pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ + { + name: "plug", + description: "Plugin command", + }, + ] as never); + pluginCommandMocks.matchPluginCommand.mockReturnValue({ + command: { + key: "plug", + requireAuth: false, + ...params.command, + }, + args: params.args, + } as never); + pluginCommandMocks.executePluginCommand.mockResolvedValue( + (params.result ?? { text: "ok" }) as never, + ); +} + +function registerPlugCommand(params: PlugCommandHarnessParams = {}) { + const botHarness = params.botHarness ?? createCommandBot(); + primePlugCommand(params); + registerTelegramNativeCommands({ + ...createNativeCommandTestParams(params.cfg ?? {}, { + bot: botHarness.bot, + ...params.registerOverrides, + }), + }); + const handler = botHarness.commandHandlers.get("plug"); + expect(handler).toBeTruthy(); + return { + ...botHarness, + handler: handler as CommandHandler, + }; +} + describe("registerTelegramNativeCommands", () => { beforeAll(async () => { ({ @@ -230,8 +278,6 @@ describe("registerTelegramNativeCommands", () => { }); it("passes agent-scoped media roots for plugin command replies with media", async () => { - const commandHandlers = new Map Promise>(); - const sendMessage = vi.fn().mockResolvedValue(undefined); const cfg: OpenClawConfig = { agents: { list: [{ id: "main", default: true }, { id: "work" }], @@ -239,38 +285,15 @@ describe("registerTelegramNativeCommands", () => { bindings: [{ agentId: "work", match: { channel: "telegram", accountId: "default" } }], }; - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", + const { handler, sendMessage } = registerPlugCommand({ + cfg, + result: { + text: "with media", + mediaUrl: "/tmp/workspace-work/render.png", }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ - command: { key: "plug", requireAuth: false }, - args: undefined, - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ - text: "with media", - mediaUrl: "/tmp/workspace-work/render.png", - } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams(cfg, { - bot: { - api: { - setMyCommands: vi.fn().mockResolvedValue(undefined), - sendMessage, - }, - command: vi.fn((name: string, cb: (ctx: unknown) => Promise) => { - commandHandlers.set(name, cb); - }), - } as unknown as Parameters[0]["bot"], - }), }); - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.(createPrivateCommandContext()); + await handler(createPrivateCommandContext()); const firstDeliverRepliesCall = deliverReplies.mock.calls.at(0) as [unknown] | undefined; expect(firstDeliverRepliesCall?.[0]).toEqual( @@ -305,36 +328,20 @@ describe("registerTelegramNativeCommands", () => { }); it("uses plugin command metadata to send and edit a Telegram progress placeholder", async () => { - const { bot, commandHandlers, sendMessage, deleteMessage } = createCommandBot(); - - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", - }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ + const { handler, sendMessage, deleteMessage } = registerPlugCommand({ + args: "now", command: { - key: "plug", - requireAuth: false, nativeProgressMessages: { telegram: "Running this command now...\n\nI'll edit this message with the final result when it's ready.", }, }, - args: "now", - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ - text: "Command completed successfully", - } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams({}, { bot }), + result: { + text: "Command completed successfully", + }, }); - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.( + await handler( createPrivateCommandContext({ match: "now", }), @@ -366,38 +373,22 @@ describe("registerTelegramNativeCommands", () => { }); it("preserves Telegram buttons when editing a metadata-driven progress placeholder", async () => { - const { bot, commandHandlers, sendMessage, deleteMessage } = createCommandBot(); - - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", - }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ + const { handler, sendMessage, deleteMessage } = registerPlugCommand({ + args: "now", command: { - key: "plug", - requireAuth: false, nativeProgressMessages: { telegram: "Working on it..." }, }, - args: "now", - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ - text: "Choose an option", - channelData: { - telegram: { - buttons: [[{ text: "Approve", callback_data: "approve" }]], + result: { + text: "Choose an option", + channelData: { + telegram: { + buttons: [[{ text: "Approve", callback_data: "approve" }]], + }, }, }, - } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams({}, { bot }), }); - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.(createPrivateCommandContext({ match: "now" })); + await handler(createPrivateCommandContext({ match: "now" })); expect(sendMessage).toHaveBeenCalledWith(100, "Working on it...", undefined); expect(editMessageTelegram).toHaveBeenCalledWith( @@ -413,34 +404,18 @@ describe("registerTelegramNativeCommands", () => { }); it("falls back to a normal reply when a metadata-driven progress result is not editable", async () => { - const { bot, commandHandlers, sendMessage, deleteMessage } = createCommandBot(); - - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", - }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ + const { handler, sendMessage, deleteMessage } = registerPlugCommand({ + args: "now", command: { - key: "plug", - requireAuth: false, nativeProgressMessages: { telegram: "Working on it..." }, }, - args: "now", - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ - text: "rich output", - mediaUrl: "/tmp/render.png", - } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams({}, { bot }), + result: { + text: "rich output", + mediaUrl: "/tmp/render.png", + }, }); - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.( + await handler( createPrivateCommandContext({ match: "now", }), @@ -457,34 +432,18 @@ describe("registerTelegramNativeCommands", () => { }); it("cleans up the progress placeholder before falling back after an edit failure", async () => { - const { bot, commandHandlers, sendMessage, deleteMessage } = createCommandBot(); - - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", - }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ + const { handler, sendMessage, deleteMessage } = registerPlugCommand({ + args: "now", command: { - key: "plug", - requireAuth: false, nativeProgressMessages: { telegram: "Working on it..." }, }, - args: "now", - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ - text: "Command completed successfully", - } as never); + result: { + text: "Command completed successfully", + }, + }); editMessageTelegram.mockRejectedValueOnce(new Error("message to edit not found")); - registerTelegramNativeCommands({ - ...createNativeCommandTestParams({}, { bot }), - }); - - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.(createPrivateCommandContext({ match: "now" })); + await handler(createPrivateCommandContext({ match: "now" })); expect(sendMessage).toHaveBeenCalledWith(100, "Working on it...", undefined); expect(editMessageTelegram).toHaveBeenCalledTimes(1); @@ -497,53 +456,35 @@ describe("registerTelegramNativeCommands", () => { }); it("cleans up the progress placeholder when Telegram suppresses a local exec approval reply", async () => { - const { bot, commandHandlers, sendMessage, deleteMessage } = createCommandBot(); - - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", - }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ + const { handler, sendMessage, deleteMessage } = registerPlugCommand({ + args: "now", command: { - key: "plug", - requireAuth: false, nativeProgressMessages: { telegram: "Working on it..." }, }, - args: "now", - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ - text: "Approval required.\n\n```txt\n/approve 7f423fdc allow-once\n```", - channelData: { - execApproval: { - approvalId: "7f423fdc-1111-2222-3333-444444444444", - approvalSlug: "7f423fdc", - allowedDecisions: ["allow-once", "allow-always", "deny"], + result: { + text: "Approval required.\n\n```txt\n/approve 7f423fdc allow-once\n```", + channelData: { + execApproval: { + approvalId: "7f423fdc-1111-2222-3333-444444444444", + approvalSlug: "7f423fdc", + allowedDecisions: ["allow-once", "allow-always", "deny"], + }, }, }, - } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams( - { - channels: { - telegram: { - execApprovals: { - enabled: true, - approvers: ["12345"], - target: "dm", - }, + cfg: { + channels: { + telegram: { + execApprovals: { + enabled: true, + approvers: ["12345"], + target: "dm", }, }, }, - { bot }, - ), + }, }); - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.(createPrivateCommandContext({ match: "now" })); + await handler(createPrivateCommandContext({ match: "now" })); expect(sendMessage).toHaveBeenCalledWith(100, "Working on it...", undefined); expect(deleteMessage).toHaveBeenCalledWith(100, 999); @@ -552,48 +493,24 @@ describe("registerTelegramNativeCommands", () => { }); it("sends plugin command error replies silently when silentErrorReplies is enabled", async () => { - const commandHandlers = new Map Promise>(); - const cfg: OpenClawConfig = { - channels: { - telegram: { - silentErrorReplies: true, + const { handler } = registerPlugCommand({ + cfg: { + channels: { + telegram: { + silentErrorReplies: true, + }, }, }, - }; - - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", + result: { + text: "plugin failed", + isError: true, + }, + registerOverrides: { + telegramCfg: { silentErrorReplies: true } as TelegramAccountConfig, }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ - command: { key: "plug", requireAuth: false }, - args: undefined, - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ - text: "plugin failed", - isError: true, - } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams(cfg, { - bot: { - api: { - setMyCommands: vi.fn().mockResolvedValue(undefined), - sendMessage: vi.fn().mockResolvedValue(undefined), - }, - command: vi.fn((name: string, cb: (ctx: unknown) => Promise) => { - commandHandlers.set(name, cb); - }), - } as unknown as Parameters[0]["bot"], - }), - telegramCfg: { silentErrorReplies: true } as TelegramAccountConfig, }); - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.(createPrivateCommandContext()); + await handler(createPrivateCommandContext()); const firstDeliverRepliesCall = deliverReplies.mock.calls.at(0) as [unknown] | undefined; expect(firstDeliverRepliesCall?.[0]).toEqual( @@ -605,40 +522,9 @@ describe("registerTelegramNativeCommands", () => { }); it("forwards topic-scoped binding context to Telegram plugin commands", async () => { - const commandHandlers = new Map Promise>(); + const { handler } = registerPlugCommand(); - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", - }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ - command: { key: "plug", requireAuth: false }, - args: undefined, - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ text: "ok" } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams( - {}, - { - bot: { - api: { - setMyCommands: vi.fn().mockResolvedValue(undefined), - sendMessage: vi.fn().mockResolvedValue(undefined), - }, - command: vi.fn((name: string, cb: (ctx: unknown) => Promise) => { - commandHandlers.set(name, cb); - }), - } as unknown as Parameters[0]["bot"], - }, - ), - }); - - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.({ + await handler({ match: "", message: { message_id: 2, @@ -666,42 +552,12 @@ describe("registerTelegramNativeCommands", () => { }); it("treats Telegram forum #General commands as topic 1 when Telegram omits topic metadata", async () => { - const commandHandlers = new Map Promise>(); const getChat = vi.fn(async () => ({ id: -1001234567890, type: "supergroup", is_forum: true })); - - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", - }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ - command: { key: "plug", requireAuth: false }, - args: undefined, - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ text: "ok" } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams( - {}, - { - bot: { - api: { - setMyCommands: vi.fn().mockResolvedValue(undefined), - sendMessage: vi.fn().mockResolvedValue(undefined), - getChat, - }, - command: vi.fn((name: string, cb: (ctx: unknown) => Promise) => { - commandHandlers.set(name, cb); - }), - } as unknown as Parameters[0]["bot"], - }, - ), + const { handler } = registerPlugCommand({ + botHarness: createCommandBot({ api: { getChat } }), }); - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.({ + await handler({ match: "", message: { message_id: 2, @@ -727,40 +583,9 @@ describe("registerTelegramNativeCommands", () => { }); it("forwards direct-message binding context to Telegram plugin commands", async () => { - const commandHandlers = new Map Promise>(); + const { handler } = registerPlugCommand(); - pluginCommandMocks.getPluginCommandSpecs.mockReturnValue([ - { - name: "plug", - description: "Plugin command", - }, - ] as never); - pluginCommandMocks.matchPluginCommand.mockReturnValue({ - command: { key: "plug", requireAuth: false }, - args: undefined, - } as never); - pluginCommandMocks.executePluginCommand.mockResolvedValue({ text: "ok" } as never); - - registerTelegramNativeCommands({ - ...createNativeCommandTestParams( - {}, - { - bot: { - api: { - setMyCommands: vi.fn().mockResolvedValue(undefined), - sendMessage: vi.fn().mockResolvedValue(undefined), - }, - command: vi.fn((name: string, cb: (ctx: unknown) => Promise) => { - commandHandlers.set(name, cb); - }), - } as unknown as Parameters[0]["bot"], - }, - ), - }); - - const handler = commandHandlers.get("plug"); - expect(handler).toBeTruthy(); - await handler?.(createPrivateCommandContext({ chatId: 100, userId: 200 })); + await handler(createPrivateCommandContext({ chatId: 100, userId: 200 })); expect(pluginCommandMocks.executePluginCommand).toHaveBeenCalledWith( expect.objectContaining({ diff --git a/extensions/telegram/src/button-types.test-helpers.ts b/extensions/telegram/src/button-types.test-helpers.ts new file mode 100644 index 00000000000..163561a705e --- /dev/null +++ b/extensions/telegram/src/button-types.test-helpers.ts @@ -0,0 +1,71 @@ +import { describe, expect, it } from "vitest"; +import { buildTelegramInteractiveButtons, resolveTelegramInlineButtons } from "./button-types.js"; + +export function describeTelegramInteractiveButtonBehavior(): void { + describe("buildTelegramInteractiveButtons", () => { + it("maps shared buttons and selects into Telegram inline rows", () => { + expect( + buildTelegramInteractiveButtons({ + blocks: [ + { + type: "buttons", + buttons: [ + { label: "Approve", value: "approve", style: "success" }, + { label: "Reject", value: "reject", style: "danger" }, + { label: "Later", value: "later" }, + { label: "Archive", value: "archive" }, + ], + }, + { + type: "select", + options: [{ label: "Alpha", value: "alpha" }], + }, + ], + }), + ).toEqual([ + [ + { text: "Approve", callback_data: "approve", style: "success" }, + { text: "Reject", callback_data: "reject", style: "danger" }, + { text: "Later", callback_data: "later", style: undefined }, + ], + [{ text: "Archive", callback_data: "archive", style: undefined }], + [{ text: "Alpha", callback_data: "alpha", style: undefined }], + ]); + }); + }); + + describe("resolveTelegramInlineButtons", () => { + it("prefers explicit buttons over shared interactive blocks", () => { + const explicit = [[{ text: "Keep", callback_data: "keep" }]] as const; + + expect( + resolveTelegramInlineButtons({ + buttons: explicit, + interactive: { + blocks: [ + { + type: "buttons", + buttons: [{ label: "Override", value: "override" }], + }, + ], + }, + }), + ).toBe(explicit); + }); + + it("derives buttons from raw interactive payloads", () => { + expect( + resolveTelegramInlineButtons({ + interactive: { + blocks: [ + { + type: "buttons", + buttons: [{ label: "Retry", value: "retry", style: "primary" }], + }, + ], + }, + }), + ).toEqual([[{ text: "Retry", callback_data: "retry", style: "primary" }]]); + }); + }); +} diff --git a/extensions/telegram/src/button-types.test.ts b/extensions/telegram/src/button-types.test.ts index 9652cbc277f..8fe96fe4abc 100644 --- a/extensions/telegram/src/button-types.test.ts +++ b/extensions/telegram/src/button-types.test.ts @@ -1,37 +1,10 @@ import { describe, expect, it } from "vitest"; -import { buildTelegramInteractiveButtons, resolveTelegramInlineButtons } from "./button-types.js"; +import { buildTelegramInteractiveButtons } from "./button-types.js"; +import { describeTelegramInteractiveButtonBehavior } from "./button-types.test-helpers.js"; -describe("buildTelegramInteractiveButtons", () => { - it("maps shared buttons and selects into Telegram inline rows", () => { - expect( - buildTelegramInteractiveButtons({ - blocks: [ - { - type: "buttons", - buttons: [ - { label: "Approve", value: "approve", style: "success" }, - { label: "Reject", value: "reject", style: "danger" }, - { label: "Later", value: "later" }, - { label: "Archive", value: "archive" }, - ], - }, - { - type: "select", - options: [{ label: "Alpha", value: "alpha" }], - }, - ], - }), - ).toEqual([ - [ - { text: "Approve", callback_data: "approve", style: "success" }, - { text: "Reject", callback_data: "reject", style: "danger" }, - { text: "Later", callback_data: "later", style: undefined }, - ], - [{ text: "Archive", callback_data: "archive", style: undefined }], - [{ text: "Alpha", callback_data: "alpha", style: undefined }], - ]); - }); +describeTelegramInteractiveButtonBehavior(); +describe("buildTelegramInteractiveButtons callback limits", () => { it("drops buttons whose callback payload exceeds Telegram limits", () => { expect( buildTelegramInteractiveButtons({ @@ -48,38 +21,3 @@ describe("buildTelegramInteractiveButtons", () => { ).toEqual([[{ text: "Keep", callback_data: "ok", style: undefined }]]); }); }); - -describe("resolveTelegramInlineButtons", () => { - it("prefers explicit buttons over shared interactive blocks", () => { - const explicit = [[{ text: "Keep", callback_data: "keep" }]] as const; - - expect( - resolveTelegramInlineButtons({ - buttons: explicit, - interactive: { - blocks: [ - { - type: "buttons", - buttons: [{ label: "Override", value: "override" }], - }, - ], - }, - }), - ).toBe(explicit); - }); - - it("derives buttons from raw interactive payloads", () => { - expect( - resolveTelegramInlineButtons({ - interactive: { - blocks: [ - { - type: "buttons", - buttons: [{ label: "Retry", value: "retry", style: "primary" }], - }, - ], - }, - }), - ).toEqual([[{ text: "Retry", callback_data: "retry", style: "primary" }]]); - }); -}); diff --git a/extensions/telegram/src/doctor-contract.ts b/extensions/telegram/src/doctor-contract.ts index dd4153b520b..b4aaa48e1c5 100644 --- a/extensions/telegram/src/doctor-contract.ts +++ b/extensions/telegram/src/doctor-contract.ts @@ -97,6 +97,7 @@ export function normalizeCompatibilityConfig({ entry: updated, pathPrefix: "channels.telegram", changes, + includePreviewChunk: true, resolvedMode: resolveTelegramPreviewStreamMode(updated), includePreviewChunk: true, }); @@ -116,6 +117,7 @@ export function normalizeCompatibilityConfig({ entry: account, pathPrefix: `channels.telegram.accounts.${accountId}`, changes, + includePreviewChunk: true, resolvedMode: resolveTelegramPreviewStreamMode(account), includePreviewChunk: true, }); diff --git a/extensions/telegram/src/inline-buttons.test.ts b/extensions/telegram/src/inline-buttons.test.ts index fbe3196930c..2dd842a7671 100644 --- a/extensions/telegram/src/inline-buttons.test.ts +++ b/extensions/telegram/src/inline-buttons.test.ts @@ -1,5 +1,6 @@ import { describe, expect, it } from "vitest"; -import { buildTelegramInteractiveButtons, resolveTelegramInlineButtons } from "./button-types.js"; +import { buildTelegramInteractiveButtons } from "./button-types.js"; +import { describeTelegramInteractiveButtonBehavior } from "./button-types.test-helpers.js"; import { resolveTelegramTargetChatType } from "./inline-buttons.js"; describe("resolveTelegramTargetChatType", () => { @@ -37,37 +38,9 @@ describe("resolveTelegramTargetChatType", () => { }); }); -describe("buildTelegramInteractiveButtons", () => { - it("maps shared buttons and selects into Telegram inline rows", () => { - expect( - buildTelegramInteractiveButtons({ - blocks: [ - { - type: "buttons", - buttons: [ - { label: "Approve", value: "approve", style: "success" }, - { label: "Reject", value: "reject", style: "danger" }, - { label: "Later", value: "later" }, - { label: "Archive", value: "archive" }, - ], - }, - { - type: "select", - options: [{ label: "Alpha", value: "alpha" }], - }, - ], - }), - ).toEqual([ - [ - { text: "Approve", callback_data: "approve", style: "success" }, - { text: "Reject", callback_data: "reject", style: "danger" }, - { text: "Later", callback_data: "later", style: undefined }, - ], - [{ text: "Archive", callback_data: "archive", style: undefined }], - [{ text: "Alpha", callback_data: "alpha", style: undefined }], - ]); - }); +describeTelegramInteractiveButtonBehavior(); +describe("buildTelegramInteractiveButtons callback rewrites", () => { it("drops shared buttons whose callback data exceeds Telegram's limit", () => { expect( buildTelegramInteractiveButtons({ @@ -112,38 +85,3 @@ describe("buildTelegramInteractiveButtons", () => { ]); }); }); - -describe("resolveTelegramInlineButtons", () => { - it("prefers explicit buttons over shared interactive blocks", () => { - const explicit = [[{ text: "Keep", callback_data: "keep" }]] as const; - - expect( - resolveTelegramInlineButtons({ - buttons: explicit, - interactive: { - blocks: [ - { - type: "buttons", - buttons: [{ label: "Override", value: "override" }], - }, - ], - }, - }), - ).toBe(explicit); - }); - - it("derives buttons from raw interactive payloads", () => { - expect( - resolveTelegramInlineButtons({ - interactive: { - blocks: [ - { - type: "buttons", - buttons: [{ label: "Retry", value: "retry", style: "primary" }], - }, - ], - }, - }), - ).toEqual([[{ text: "Retry", callback_data: "retry", style: "primary" }]]); - }); -}); diff --git a/extensions/telegram/src/polling-session.test.ts b/extensions/telegram/src/polling-session.test.ts index c28b8967d48..662450dcc06 100644 --- a/extensions/telegram/src/polling-session.test.ts +++ b/extensions/telegram/src/polling-session.test.ts @@ -30,6 +30,13 @@ vi.mock("openclaw/plugin-sdk/runtime-env", () => ({ let TelegramPollingSession: typeof import("./polling-session.js").TelegramPollingSession; +type TelegramApiMiddleware = ( + prev: (...args: unknown[]) => Promise, + method: string, + payload: unknown, +) => Promise; +type AsyncVoidFn = () => Promise; + function makeBot() { return { api: { @@ -41,7 +48,10 @@ function makeBot() { }; } -function installPollingStallWatchdogHarness() { +function installPollingStallWatchdogHarness( + dateNowSequence: readonly number[] = [0, 0], + fallbackDateNow = 120_001, +) { let watchdog: (() => void) | undefined; const setIntervalSpy = vi.spyOn(globalThis, "setInterval").mockImplementation((fn) => { watchdog = fn as () => void; @@ -53,11 +63,11 @@ function installPollingStallWatchdogHarness() { return 1 as unknown as ReturnType; }); const clearTimeoutSpy = vi.spyOn(globalThis, "clearTimeout").mockImplementation(() => {}); - const dateNowSpy = vi - .spyOn(Date, "now") - .mockImplementationOnce(() => 0) // lastGetUpdatesAt init - .mockImplementationOnce(() => 0) // lastApiActivityAt init - .mockImplementation(() => 120_001); + const dateNowSpy = vi.spyOn(Date, "now"); + for (const value of dateNowSequence) { + dateNowSpy.mockImplementationOnce(() => value); + } + dateNowSpy.mockImplementation(() => fallbackDateNow); return { async waitForWatchdog() { @@ -114,6 +124,15 @@ function createPollingSessionWithTransportRestart(params: { abortSignal: AbortSignal; telegramTransport: ReturnType; createTelegramTransport: () => ReturnType; +}) { + return createPollingSession(params); +} + +function createPollingSession(params: { + abortSignal: AbortSignal; + log?: (message: string) => void; + telegramTransport?: ReturnType; + createTelegramTransport?: () => ReturnType; }) { return new TelegramPollingSession({ token: "tok", @@ -125,12 +144,47 @@ function createPollingSessionWithTransportRestart(params: { runnerOptions: {}, getLastUpdateId: () => null, persistUpdateId: async () => undefined, - log: () => undefined, + log: params.log ?? (() => undefined), telegramTransport: params.telegramTransport, - createTelegramTransport: params.createTelegramTransport, + ...(params.createTelegramTransport + ? { createTelegramTransport: params.createTelegramTransport } + : {}), }); } +function mockBotCapturingApiMiddleware(botStop: AsyncVoidFn) { + let apiMiddleware: TelegramApiMiddleware | undefined; + createTelegramBotMock.mockReturnValueOnce({ + api: { + deleteWebhook: vi.fn(async () => true), + getUpdates: vi.fn(async () => []), + config: { + use: vi.fn((fn: TelegramApiMiddleware) => { + apiMiddleware = fn; + }), + }, + }, + stop: botStop, + }); + return () => apiMiddleware; +} + +function mockLongRunningPollingCycle(runnerStop: AsyncVoidFn) { + let firstTaskResolve: (() => void) | undefined; + runMock.mockReturnValue({ + task: () => + new Promise((resolve) => { + firstTaskResolve = resolve; + }), + stop: async () => { + await runnerStop(); + firstTaskResolve?.(); + }, + isRunning: () => true, + }); + return () => firstTaskResolve?.(); +} + describe("TelegramPollingSession", () => { beforeAll(async () => { ({ TelegramPollingSession } = await import("./polling-session.js")); @@ -367,91 +421,28 @@ describe("TelegramPollingSession", () => { const abort = new AbortController(); const botStop = vi.fn(async () => undefined); const runnerStop = vi.fn(async () => undefined); - - // Capture the API middleware so we can simulate sendMessage calls - let apiMiddleware: - | (( - prev: (...args: unknown[]) => Promise, - method: string, - payload: unknown, - ) => Promise) - | undefined; - - const bot = { - api: { - deleteWebhook: vi.fn(async () => true), - getUpdates: vi.fn(async () => []), - config: { - use: vi.fn((fn: typeof apiMiddleware) => { - apiMiddleware = fn; - }), - }, - }, - stop: botStop, - }; - createTelegramBotMock.mockReturnValue(bot); - - let firstTaskResolve: (() => void) | undefined; - const firstTask = new Promise((resolve) => { - firstTaskResolve = resolve; - }); - runMock.mockImplementation(() => ({ - task: () => firstTask, - stop: async () => { - await runnerStop(); - firstTaskResolve?.(); - }, - isRunning: () => true, - })); + const getApiMiddleware = mockBotCapturingApiMiddleware(botStop); + const resolveFirstTask = mockLongRunningPollingCycle(runnerStop); // t=0: lastGetUpdatesAt and lastApiActivityAt initialized // t=120_001: watchdog fires (getUpdates stale for 120s) // But right before watchdog, a sendMessage succeeded at t=120_000 - const setIntervalSpy = vi.spyOn(globalThis, "setInterval").mockImplementation((fn) => { - watchdog = fn as () => void; - return 1 as unknown as ReturnType; - }); - const clearIntervalSpy = vi.spyOn(globalThis, "clearInterval").mockImplementation(() => {}); - const setTimeoutSpy = vi.spyOn(globalThis, "setTimeout").mockImplementation((fn) => { - void Promise.resolve().then(() => (fn as () => void)()); - return 1 as unknown as ReturnType; - }); - const clearTimeoutSpy = vi.spyOn(globalThis, "clearTimeout").mockImplementation(() => {}); - const dateNowSpy = vi - .spyOn(Date, "now") - .mockImplementationOnce(() => 0) // lastGetUpdatesAt init - .mockImplementationOnce(() => 0) // lastApiActivityAt init - // All subsequent calls (sendMessage completion + watchdog check) return - // the same value, giving apiIdle = 0 — well below the stall threshold. - .mockImplementation(() => 120_001); + // All subsequent Date.now calls return the same value, giving apiIdle = 0. + const watchdogHarness = installPollingStallWatchdogHarness(); - let watchdog: (() => void) | undefined; const log = vi.fn(); - const session = new TelegramPollingSession({ - token: "tok", - config: {}, - accountId: "default", - runtime: undefined, - proxyFetch: undefined, + const session = createPollingSession({ abortSignal: abort.signal, - runnerOptions: {}, - getLastUpdateId: () => null, - persistUpdateId: async () => undefined, log, - telegramTransport: undefined, }); try { const runPromise = session.runUntilAbort(); - - // Wait for watchdog to be captured - for (let attempt = 0; attempt < 20 && !watchdog; attempt += 1) { - await Promise.resolve(); - } - expect(watchdog).toBeTypeOf("function"); + const watchdog = await watchdogHarness.waitForWatchdog(); // Simulate a sendMessage call through the middleware before watchdog fires. // This updates lastApiActivityAt, proving the network is alive. + const apiMiddleware = getApiMiddleware(); if (apiMiddleware) { const fakePrev = vi.fn(async () => ({ ok: true })); await apiMiddleware(fakePrev, "sendMessage", { chat_id: 123, text: "hello" }); @@ -467,14 +458,10 @@ describe("TelegramPollingSession", () => { // Clean up: abort to end the session abort.abort(); - firstTaskResolve?.(); + resolveFirstTask(); await runPromise; } finally { - setIntervalSpy.mockRestore(); - clearIntervalSpy.mockRestore(); - setTimeoutSpy.mockRestore(); - clearTimeoutSpy.mockRestore(); - dateNowSpy.mockRestore(); + watchdogHarness.restore(); } }); @@ -482,85 +469,26 @@ describe("TelegramPollingSession", () => { const abort = new AbortController(); const botStop = vi.fn(async () => undefined); const runnerStop = vi.fn(async () => undefined); + const getApiMiddleware = mockBotCapturingApiMiddleware(botStop); + const resolveFirstTask = mockLongRunningPollingCycle(runnerStop); - let apiMiddleware: - | (( - prev: (...args: unknown[]) => Promise, - method: string, - payload: unknown, - ) => Promise) - | undefined; - createTelegramBotMock.mockReturnValueOnce({ - api: { - deleteWebhook: vi.fn(async () => true), - getUpdates: vi.fn(async () => []), - config: { - use: vi.fn((fn: typeof apiMiddleware) => { - apiMiddleware = fn; - }), - }, - }, - stop: botStop, - }); + const watchdogHarness = installPollingStallWatchdogHarness([0, 0, 60_000]); - let firstTaskResolve: (() => void) | undefined; - runMock.mockReturnValue({ - task: () => - new Promise((resolve) => { - firstTaskResolve = resolve; - }), - stop: async () => { - await runnerStop(); - firstTaskResolve?.(); - }, - isRunning: () => true, - }); - - // t=0: lastGetUpdatesAt and lastApiActivityAt initialized - const setIntervalSpy = vi.spyOn(globalThis, "setInterval").mockImplementation((fn) => { - watchdog = fn as () => void; - return 1 as unknown as ReturnType; - }); - const clearIntervalSpy = vi.spyOn(globalThis, "clearInterval").mockImplementation(() => {}); - const setTimeoutSpy = vi.spyOn(globalThis, "setTimeout").mockImplementation((fn) => { - void Promise.resolve().then(() => (fn as () => void)()); - return 1 as unknown as ReturnType; - }); - const clearTimeoutSpy = vi.spyOn(globalThis, "clearTimeout").mockImplementation(() => {}); - const dateNowSpy = vi - .spyOn(Date, "now") - .mockImplementationOnce(() => 0) // lastGetUpdatesAt init - .mockImplementationOnce(() => 0) // lastApiActivityAt init - .mockImplementationOnce(() => 60_000) // sendMessage start - .mockImplementation(() => 120_001); - - let watchdog: (() => void) | undefined; const log = vi.fn(); - const session = new TelegramPollingSession({ - token: "tok", - config: {}, - accountId: "default", - runtime: undefined, - proxyFetch: undefined, + const session = createPollingSession({ abortSignal: abort.signal, - runnerOptions: {}, - getLastUpdateId: () => null, - persistUpdateId: async () => undefined, log, - telegramTransport: undefined, }); try { const runPromise = session.runUntilAbort(); - for (let attempt = 0; attempt < 20 && !watchdog; attempt += 1) { - await Promise.resolve(); - } - expect(watchdog).toBeTypeOf("function"); + const watchdog = await watchdogHarness.waitForWatchdog(); // Start an in-flight sendMessage that has NOT yet resolved. // This simulates a slow delivery where the API call is still pending. let resolveSendMessage: ((v: unknown) => void) | undefined; + const apiMiddleware = getApiMiddleware(); if (apiMiddleware) { const slowPrev = vi.fn( () => @@ -586,14 +514,10 @@ describe("TelegramPollingSession", () => { } abort.abort(); - firstTaskResolve?.(); + resolveFirstTask(); await runPromise; } finally { - setIntervalSpy.mockRestore(); - clearIntervalSpy.mockRestore(); - setTimeoutSpy.mockRestore(); - clearTimeoutSpy.mockRestore(); - dateNowSpy.mockRestore(); + watchdogHarness.restore(); } }); @@ -601,82 +525,24 @@ describe("TelegramPollingSession", () => { const abort = new AbortController(); const botStop = vi.fn(async () => undefined); const runnerStop = vi.fn(async () => undefined); + const getApiMiddleware = mockBotCapturingApiMiddleware(botStop); + const resolveFirstTask = mockLongRunningPollingCycle(runnerStop); - let apiMiddleware: - | (( - prev: (...args: unknown[]) => Promise, - method: string, - payload: unknown, - ) => Promise) - | undefined; - createTelegramBotMock.mockReturnValueOnce({ - api: { - deleteWebhook: vi.fn(async () => true), - getUpdates: vi.fn(async () => []), - config: { - use: vi.fn((fn: typeof apiMiddleware) => { - apiMiddleware = fn; - }), - }, - }, - stop: botStop, - }); + const watchdogHarness = installPollingStallWatchdogHarness([0, 0, 1]); - let firstTaskResolve: (() => void) | undefined; - runMock.mockReturnValue({ - task: () => - new Promise((resolve) => { - firstTaskResolve = resolve; - }), - stop: async () => { - await runnerStop(); - firstTaskResolve?.(); - }, - isRunning: () => true, - }); - - const setIntervalSpy = vi.spyOn(globalThis, "setInterval").mockImplementation((fn) => { - watchdog = fn as () => void; - return 1 as unknown as ReturnType; - }); - const clearIntervalSpy = vi.spyOn(globalThis, "clearInterval").mockImplementation(() => {}); - const setTimeoutSpy = vi.spyOn(globalThis, "setTimeout").mockImplementation((fn) => { - void Promise.resolve().then(() => (fn as () => void)()); - return 1 as unknown as ReturnType; - }); - const clearTimeoutSpy = vi.spyOn(globalThis, "clearTimeout").mockImplementation(() => {}); - const dateNowSpy = vi - .spyOn(Date, "now") - .mockImplementationOnce(() => 0) // lastGetUpdatesAt init - .mockImplementationOnce(() => 0) // lastApiActivityAt init - .mockImplementationOnce(() => 1) // sendMessage start - .mockImplementation(() => 120_001); - - let watchdog: (() => void) | undefined; const log = vi.fn(); - const session = new TelegramPollingSession({ - token: "tok", - config: {}, - accountId: "default", - runtime: undefined, - proxyFetch: undefined, + const session = createPollingSession({ abortSignal: abort.signal, - runnerOptions: {}, - getLastUpdateId: () => null, - persistUpdateId: async () => undefined, log, - telegramTransport: undefined, }); try { const runPromise = session.runUntilAbort(); - for (let attempt = 0; attempt < 20 && !watchdog; attempt += 1) { - await Promise.resolve(); - } - expect(watchdog).toBeTypeOf("function"); + const watchdog = await watchdogHarness.waitForWatchdog(); let resolveSendMessage: ((v: unknown) => void) | undefined; + const apiMiddleware = getApiMiddleware(); if (apiMiddleware) { const slowPrev = vi.fn( () => @@ -699,14 +565,10 @@ describe("TelegramPollingSession", () => { } abort.abort(); - firstTaskResolve?.(); + resolveFirstTask(); await runPromise; } finally { - setIntervalSpy.mockRestore(); - clearIntervalSpy.mockRestore(); - setTimeoutSpy.mockRestore(); - clearTimeoutSpy.mockRestore(); - dateNowSpy.mockRestore(); + watchdogHarness.restore(); } }); @@ -714,84 +576,25 @@ describe("TelegramPollingSession", () => { const abort = new AbortController(); const botStop = vi.fn(async () => undefined); const runnerStop = vi.fn(async () => undefined); + const getApiMiddleware = mockBotCapturingApiMiddleware(botStop); + const resolveFirstTask = mockLongRunningPollingCycle(runnerStop); - let apiMiddleware: - | (( - prev: (...args: unknown[]) => Promise, - method: string, - payload: unknown, - ) => Promise) - | undefined; - createTelegramBotMock.mockReturnValueOnce({ - api: { - deleteWebhook: vi.fn(async () => true), - getUpdates: vi.fn(async () => []), - config: { - use: vi.fn((fn: typeof apiMiddleware) => { - apiMiddleware = fn; - }), - }, - }, - stop: botStop, - }); + const watchdogHarness = installPollingStallWatchdogHarness([0, 0, 1, 120_000]); - let firstTaskResolve: (() => void) | undefined; - runMock.mockReturnValue({ - task: () => - new Promise((resolve) => { - firstTaskResolve = resolve; - }), - stop: async () => { - await runnerStop(); - firstTaskResolve?.(); - }, - isRunning: () => true, - }); - - const setIntervalSpy = vi.spyOn(globalThis, "setInterval").mockImplementation((fn) => { - watchdog = fn as () => void; - return 1 as unknown as ReturnType; - }); - const clearIntervalSpy = vi.spyOn(globalThis, "clearInterval").mockImplementation(() => {}); - const setTimeoutSpy = vi.spyOn(globalThis, "setTimeout").mockImplementation((fn) => { - void Promise.resolve().then(() => (fn as () => void)()); - return 1 as unknown as ReturnType; - }); - const clearTimeoutSpy = vi.spyOn(globalThis, "clearTimeout").mockImplementation(() => {}); - const dateNowSpy = vi - .spyOn(Date, "now") - .mockImplementationOnce(() => 0) // lastGetUpdatesAt init - .mockImplementationOnce(() => 0) // lastApiActivityAt init - .mockImplementationOnce(() => 1) // first sendMessage start - .mockImplementationOnce(() => 120_000) // second sendMessage start - .mockImplementation(() => 120_001); - - let watchdog: (() => void) | undefined; const log = vi.fn(); - const session = new TelegramPollingSession({ - token: "tok", - config: {}, - accountId: "default", - runtime: undefined, - proxyFetch: undefined, + const session = createPollingSession({ abortSignal: abort.signal, - runnerOptions: {}, - getLastUpdateId: () => null, - persistUpdateId: async () => undefined, log, - telegramTransport: undefined, }); try { const runPromise = session.runUntilAbort(); - for (let attempt = 0; attempt < 20 && !watchdog; attempt += 1) { - await Promise.resolve(); - } - expect(watchdog).toBeTypeOf("function"); + const watchdog = await watchdogHarness.waitForWatchdog(); let resolveFirstSend: ((v: unknown) => void) | undefined; let resolveSecondSend: ((v: unknown) => void) | undefined; + const apiMiddleware = getApiMiddleware(); if (apiMiddleware) { const firstSendPromise = apiMiddleware( vi.fn( @@ -829,14 +632,10 @@ describe("TelegramPollingSession", () => { } abort.abort(); - firstTaskResolve?.(); + resolveFirstTask(); await runPromise; } finally { - setIntervalSpy.mockRestore(); - clearIntervalSpy.mockRestore(); - setTimeoutSpy.mockRestore(); - clearTimeoutSpy.mockRestore(); - dateNowSpy.mockRestore(); + watchdogHarness.restore(); } }); diff --git a/extensions/together/video-generation-provider.test.ts b/extensions/together/video-generation-provider.test.ts index 2a5fdf226ea..f4adeef241f 100644 --- a/extensions/together/video-generation-provider.test.ts +++ b/extensions/together/video-generation-provider.test.ts @@ -1,45 +1,20 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildTogetherVideoGenerationProvider } from "./video-generation-provider.js"; +import { beforeAll, describe, expect, it, vi } from "vitest"; +import { + getProviderHttpMocks, + installProviderHttpMockCleanup, +} from "../../test/helpers/media-generation/provider-http-mocks.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "together-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock } = getProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildTogetherVideoGenerationProvider: typeof import("./video-generation-provider.js").buildTogetherVideoGenerationProvider; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildTogetherVideoGenerationProvider } = await import("./video-generation-provider.js")); +}); + +installProviderHttpMockCleanup(); describe("together video generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("creates a video, polls completion, and downloads the output", async () => { postJsonRequestMock.mockResolvedValue({ response: { diff --git a/extensions/whatsapp/src/auto-reply/monitor/inbound-dispatch.test.ts b/extensions/whatsapp/src/auto-reply/monitor/inbound-dispatch.test.ts index 342a94dbf79..85b717dff5f 100644 --- a/extensions/whatsapp/src/auto-reply/monitor/inbound-dispatch.test.ts +++ b/extensions/whatsapp/src/auto-reply/monitor/inbound-dispatch.test.ts @@ -99,6 +99,39 @@ function getCapturedDeliver() { )?.dispatcherOptions?.deliver; } +type BufferedReplyParams = Parameters[0]; + +function makeReplyLogger(): BufferedReplyParams["replyLogger"] { + return { + info: () => {}, + warn: () => {}, + error: () => {}, + debug: () => {}, + } as never; +} + +async function dispatchBufferedReply(overrides: Partial = {}) { + const params: BufferedReplyParams = { + cfg: { channels: { whatsapp: { blockStreaming: true } } } as never, + connectionId: "conn", + context: { Body: "hi" }, + conversationId: "+1000", + deliverReply: async () => {}, + groupHistories: new Map(), + groupHistoryKey: "+1000", + maxMediaBytes: 1, + msg: makeMsg(), + rememberSentText: () => {}, + replyLogger: makeReplyLogger(), + replyPipeline: {} as never, + replyResolver: (async () => undefined) as never, + route: makeRoute(), + shouldClearGroupHistory: false, + }; + + return dispatchWhatsAppBufferedReply({ ...params, ...overrides }); +} + describe("whatsapp inbound dispatch", () => { beforeEach(() => { capturedDispatchParams = undefined; @@ -197,29 +230,16 @@ describe("whatsapp inbound dispatch", () => { ["whatsapp:default:group:123@g.us", [{ sender: "Alice (+111)", body: "first" }]], ]); - await dispatchWhatsAppBufferedReply({ - cfg: { channels: { whatsapp: { blockStreaming: true } } } as never, - connectionId: "conn", + await dispatchBufferedReply({ context: { Body: "second" }, conversationId: "123@g.us", - deliverReply: async () => {}, groupHistories, groupHistoryKey: "whatsapp:default:group:123@g.us", - maxMediaBytes: 1, msg: makeMsg({ from: "123@g.us", chatType: "group", senderE164: "+222", }), - rememberSentText: () => {}, - replyLogger: { - info: () => {}, - warn: () => {}, - error: () => {}, - debug: () => {}, - } as never, - replyPipeline: {}, - replyResolver: (async () => undefined) as never, route: makeRoute({ sessionKey: "agent:main:whatsapp:group:123@g.us" }), shouldClearGroupHistory: true, }); @@ -231,27 +251,9 @@ describe("whatsapp inbound dispatch", () => { const deliverReply = vi.fn(async () => undefined); const rememberSentText = vi.fn(); - await dispatchWhatsAppBufferedReply({ - cfg: { channels: { whatsapp: { blockStreaming: true } } } as never, - connectionId: "conn", - context: { Body: "hi" }, - conversationId: "+1000", + await dispatchBufferedReply({ deliverReply, - groupHistories: new Map(), - groupHistoryKey: "+1000", - maxMediaBytes: 1, - msg: makeMsg(), rememberSentText, - replyLogger: { - info: () => {}, - warn: () => {}, - error: () => {}, - debug: () => {}, - } as never, - replyPipeline: {}, - replyResolver: (async () => undefined) as never, - route: makeRoute(), - shouldClearGroupHistory: false, }); const deliver = getCapturedDeliver(); @@ -271,27 +273,9 @@ describe("whatsapp inbound dispatch", () => { const deliverReply = vi.fn(async () => undefined); const rememberSentText = vi.fn(); - await dispatchWhatsAppBufferedReply({ - cfg: { channels: { whatsapp: { blockStreaming: true } } } as never, - connectionId: "conn", - context: { Body: "hi" }, - conversationId: "+1000", + await dispatchBufferedReply({ deliverReply, - groupHistories: new Map(), - groupHistoryKey: "+1000", - maxMediaBytes: 1, - msg: makeMsg(), rememberSentText, - replyLogger: { - info: () => {}, - warn: () => {}, - error: () => {}, - debug: () => {}, - } as never, - replyPipeline: {}, - replyResolver: (async () => undefined) as never, - route: makeRoute(), - shouldClearGroupHistory: false, }); const deliver = getCapturedDeliver(); @@ -307,28 +291,7 @@ describe("whatsapp inbound dispatch", () => { }); it("maps WhatsApp blockStreaming=true to disableBlockStreaming=false", async () => { - await dispatchWhatsAppBufferedReply({ - cfg: { channels: { whatsapp: { blockStreaming: true } } } as never, - connectionId: "conn", - context: { Body: "hi" }, - conversationId: "+1000", - deliverReply: async () => {}, - groupHistories: new Map(), - groupHistoryKey: "+1000", - maxMediaBytes: 1, - msg: makeMsg(), - rememberSentText: () => {}, - replyLogger: { - info: () => {}, - warn: () => {}, - error: () => {}, - debug: () => {}, - } as never, - replyPipeline: {}, - replyResolver: (async () => undefined) as never, - route: makeRoute(), - shouldClearGroupHistory: false, - }); + await dispatchBufferedReply(); expect( ( @@ -340,27 +303,8 @@ describe("whatsapp inbound dispatch", () => { }); it("maps WhatsApp blockStreaming=false to disableBlockStreaming=true", async () => { - await dispatchWhatsAppBufferedReply({ + await dispatchBufferedReply({ cfg: { channels: { whatsapp: { blockStreaming: false } } } as never, - connectionId: "conn", - context: { Body: "hi" }, - conversationId: "+1000", - deliverReply: async () => {}, - groupHistories: new Map(), - groupHistoryKey: "+1000", - maxMediaBytes: 1, - msg: makeMsg(), - rememberSentText: () => {}, - replyLogger: { - info: () => {}, - warn: () => {}, - error: () => {}, - debug: () => {}, - } as never, - replyPipeline: {}, - replyResolver: (async () => undefined) as never, - route: makeRoute(), - shouldClearGroupHistory: false, }); expect( @@ -373,27 +317,8 @@ describe("whatsapp inbound dispatch", () => { }); it("leaves disableBlockStreaming undefined when WhatsApp blockStreaming is unset", async () => { - await dispatchWhatsAppBufferedReply({ + await dispatchBufferedReply({ cfg: { channels: { whatsapp: {} } } as never, - connectionId: "conn", - context: { Body: "hi" }, - conversationId: "+1000", - deliverReply: async () => {}, - groupHistories: new Map(), - groupHistoryKey: "+1000", - maxMediaBytes: 1, - msg: makeMsg(), - rememberSentText: () => {}, - replyLogger: { - info: () => {}, - warn: () => {}, - error: () => {}, - debug: () => {}, - } as never, - replyPipeline: {}, - replyResolver: (async () => undefined) as never, - route: makeRoute(), - shouldClearGroupHistory: false, }); expect( @@ -425,27 +350,9 @@ describe("whatsapp inbound dispatch", () => { ); await expect( - dispatchWhatsAppBufferedReply({ - cfg: { channels: { whatsapp: { blockStreaming: true } } } as never, - connectionId: "conn", - context: { Body: "hi" }, - conversationId: "+1000", + dispatchBufferedReply({ deliverReply, - groupHistories: new Map(), - groupHistoryKey: "+1000", - maxMediaBytes: 1, - msg: makeMsg(), rememberSentText, - replyLogger: { - info: () => {}, - warn: () => {}, - error: () => {}, - debug: () => {}, - } as never, - replyPipeline: {}, - replyResolver: (async () => undefined) as never, - route: makeRoute(), - shouldClearGroupHistory: false, }), ).resolves.toBe(true); @@ -456,27 +363,8 @@ describe("whatsapp inbound dispatch", () => { it("passes sendComposing through as the reply typing callback", async () => { const sendComposing = vi.fn(async () => undefined); - await dispatchWhatsAppBufferedReply({ - cfg: { channels: { whatsapp: { blockStreaming: true } } } as never, - connectionId: "conn", - context: { Body: "hi" }, - conversationId: "+1000", - deliverReply: async () => {}, - groupHistories: new Map(), - groupHistoryKey: "+1000", - maxMediaBytes: 1, + await dispatchBufferedReply({ msg: makeMsg({ sendComposing }), - rememberSentText: () => {}, - replyLogger: { - info: () => {}, - warn: () => {}, - error: () => {}, - debug: () => {}, - } as never, - replyPipeline: {}, - replyResolver: (async () => undefined) as never, - route: makeRoute(), - shouldClearGroupHistory: false, }); expect( diff --git a/extensions/whatsapp/src/channel.setup.test.ts b/extensions/whatsapp/src/channel.setup.test.ts index e4ea8fc448b..9c6fc60d991 100644 --- a/extensions/whatsapp/src/channel.setup.test.ts +++ b/extensions/whatsapp/src/channel.setup.test.ts @@ -2,11 +2,25 @@ import { DEFAULT_ACCOUNT_ID } from "openclaw/plugin-sdk/routing"; import type { RuntimeEnv } from "openclaw/plugin-sdk/runtime-env"; import { beforeEach, describe, expect, it, vi } from "vitest"; import { createQueuedWizardPrompter } from "../../../test/helpers/plugins/setup-wizard.js"; -import { checkWhatsAppHeartbeatReady } from "./heartbeat.js"; import { whatsappApprovalAuth } from "./approval-auth.js"; import { whatsappPlugin } from "./channel.js"; +import { checkWhatsAppHeartbeatReady } from "./heartbeat.js"; import type { OpenClawConfig } from "./runtime-api.js"; import { finalizeWhatsAppSetup } from "./setup-finalize.js"; +import { + createWhatsAppAllowlistModeInput, + createWhatsAppLinkingHarness, + createWhatsAppOwnerAllowlistHarness, + createWhatsAppPersonalPhoneHarness, + createWhatsAppRootAllowFromConfig, + expectNoWhatsAppLoginFollowup, + expectWhatsAppAllowlistModeSetup, + expectWhatsAppLoginFollowup, + expectWhatsAppOpenPolicySetup, + expectWhatsAppOwnerAllowlistSetup, + expectWhatsAppPersonalPhoneSetup, + expectWhatsAppSeparatePhoneDisabledSetup, +} from "./setup-test-helpers.js"; const hoisted = vi.hoisted(() => ({ loginWeb: vi.fn(async () => {}), @@ -129,10 +143,7 @@ describe("whatsapp setup wizard", () => { }); it("applies owner allowlist when forceAllowFrom is enabled", async () => { - const harness = createQueuedWizardPrompter({ - confirmValues: [false], - textValues: ["+1 (555) 555-0123"], - }); + const harness = createWhatsAppOwnerAllowlistHarness(createQueuedWizardPrompter); const result = await runConfigureWithHarness({ harness, @@ -141,14 +152,7 @@ describe("whatsapp setup wizard", () => { expect(result.accountId).toBe(DEFAULT_ACCOUNT_ID); expect(hoisted.loginWeb).not.toHaveBeenCalled(); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(true); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("allowlist"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["+15555550123"]); - expect(harness.text).toHaveBeenCalledWith( - expect.objectContaining({ - message: "Your personal WhatsApp number (the phone you will message from)", - }), - ); + expectWhatsAppOwnerAllowlistSetup(result.cfg, harness); }); it("supports disabled DM policy for separate-phone setup", async () => { @@ -156,38 +160,24 @@ describe("whatsapp setup wizard", () => { selectValues: ["separate", "disabled"], }); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(false); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("disabled"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toBeUndefined(); - expect(harness.text).not.toHaveBeenCalled(); + expectWhatsAppSeparatePhoneDisabledSetup(result.cfg, harness); }); it("normalizes allowFrom entries when list mode is selected", async () => { - const { result } = await runSeparatePhoneFlow({ - selectValues: ["separate", "allowlist", "list"], - textValues: ["+1 (555) 555-0123, +15555550123, *"], - }); + const { result } = await runSeparatePhoneFlow(createWhatsAppAllowlistModeInput()); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(false); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("allowlist"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["+15555550123", "*"]); + expectWhatsAppAllowlistModeSetup(result.cfg); }); it("enables allowlist self-chat mode for personal-phone setup", async () => { hoisted.pathExists.mockResolvedValue(true); - const harness = createQueuedWizardPrompter({ - confirmValues: [false], - selectValues: ["personal"], - textValues: ["+1 (555) 111-2222"], - }); + const harness = createWhatsAppPersonalPhoneHarness(createQueuedWizardPrompter); const result = await runConfigureWithHarness({ harness, }); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(true); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("allowlist"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["+15551112222"]); + expectWhatsAppPersonalPhoneSetup(result.cfg); }); it("forces wildcard allowFrom for open policy without allowFrom follow-up prompts", async () => { @@ -198,28 +188,15 @@ describe("whatsapp setup wizard", () => { const result = await runConfigureWithHarness({ harness, - cfg: { - channels: { - whatsapp: { - allowFrom: ["+15555550123"], - }, - }, - }, + cfg: createWhatsAppRootAllowFromConfig() as OpenClawConfig, }); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(false); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("open"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["*", "+15555550123"]); - expect(harness.select).toHaveBeenCalledTimes(2); - expect(harness.text).not.toHaveBeenCalled(); + expectWhatsAppOpenPolicySetup(result.cfg, harness); }); it("runs WhatsApp login when not linked and user confirms linking", async () => { hoisted.pathExists.mockResolvedValue(false); - const harness = createQueuedWizardPrompter({ - confirmValues: [true], - selectValues: ["separate", "disabled"], - }); + const harness = createWhatsAppLinkingHarness(createQueuedWizardPrompter); const runtime = createRuntime(); await runConfigureWithHarness({ @@ -241,10 +218,7 @@ describe("whatsapp setup wizard", () => { }); expect(hoisted.loginWeb).not.toHaveBeenCalled(); - expect(harness.note).not.toHaveBeenCalledWith( - expect.stringContaining("openclaw channels login"), - "WhatsApp", - ); + expectNoWhatsAppLoginFollowup(harness); }); it("shows follow-up login command note when not linked and linking is skipped", async () => { @@ -257,10 +231,7 @@ describe("whatsapp setup wizard", () => { harness, }); - expect(harness.note).toHaveBeenCalledWith( - expect.stringContaining("openclaw channels login"), - "WhatsApp", - ); + expectWhatsAppLoginFollowup(harness); }); it("heartbeat readiness uses configured defaultAccount for active listener checks", async () => { diff --git a/extensions/whatsapp/src/setup-surface.test.ts b/extensions/whatsapp/src/setup-surface.test.ts index 354f2a08f4a..d70c422ad5e 100644 --- a/extensions/whatsapp/src/setup-surface.test.ts +++ b/extensions/whatsapp/src/setup-surface.test.ts @@ -7,6 +7,23 @@ import { runSetupWizardFinalize, } from "../../../test/helpers/plugins/setup-wizard.js"; import { whatsappSetupWizard } from "./setup-surface.js"; +import { + createWhatsAppAllowlistModeInput, + createWhatsAppLinkingHarness, + createWhatsAppOwnerAllowlistHarness, + createWhatsAppPersonalPhoneHarness, + createWhatsAppRootAllowFromConfig, + createWhatsAppWorkAccountConfig, + expectNoWhatsAppLoginFollowup, + expectWhatsAppAllowlistModeSetup, + expectWhatsAppLoginFollowup, + expectWhatsAppOpenPolicySetup, + expectWhatsAppOwnerAllowlistSetup, + expectWhatsAppPersonalPhoneSetup, + expectWhatsAppSeparatePhoneDisabledSetup, + expectWhatsAppWorkAccountAccessNote, + expectWhatsAppWorkAccountOpenAccess, +} from "./setup-test-helpers.js"; const hoisted = vi.hoisted(() => ({ detectWhatsAppLinked: vi.fn<(cfg: OpenClawConfig, accountId: string) => Promise>( @@ -124,10 +141,7 @@ describe("whatsapp setup wizard", () => { }); it("applies owner allowlist when forceAllowFrom is enabled", async () => { - const harness = createQueuedWizardPrompter({ - confirmValues: [false], - textValues: ["+1 (555) 555-0123"], - }); + const harness = createWhatsAppOwnerAllowlistHarness(createQueuedWizardPrompter); const result = expectFinalizeResult( await runFinalizeWithHarness({ @@ -137,14 +151,7 @@ describe("whatsapp setup wizard", () => { ); expect(hoisted.loginWeb).not.toHaveBeenCalled(); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(true); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("allowlist"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["+15555550123"]); - expect(harness.text).toHaveBeenCalledWith( - expect.objectContaining({ - message: "Your personal WhatsApp number (the phone you will message from)", - }), - ); + expectWhatsAppOwnerAllowlistSetup(result.cfg, harness); }); it("supports disabled DM policy for separate-phone setup", async () => { @@ -152,10 +159,7 @@ describe("whatsapp setup wizard", () => { selectValues: ["separate", "disabled"], }); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(false); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("disabled"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toBeUndefined(); - expect(harness.text).not.toHaveBeenCalled(); + expectWhatsAppSeparatePhoneDisabledSetup(result.cfg, harness); }); it("writes named-account DM policy and allowFrom instead of the channel root", async () => { @@ -168,32 +172,12 @@ describe("whatsapp setup wizard", () => { await runFinalizeWithHarness({ harness, accountId: "work", - cfg: { - channels: { - whatsapp: { - dmPolicy: "disabled", - allowFrom: ["+15555550123"], - accounts: { - work: { - authDir: "/tmp/work", - }, - }, - }, - }, - }, + cfg: createWhatsAppWorkAccountConfig() as OpenClawConfig, }), ); - expect(named.cfg.channels?.whatsapp?.dmPolicy).toBe("disabled"); - expect(named.cfg.channels?.whatsapp?.allowFrom).toEqual(["+15555550123"]); - expect(named.cfg.channels?.whatsapp?.accounts?.work?.dmPolicy).toBe("open"); - expect(named.cfg.channels?.whatsapp?.accounts?.work?.allowFrom).toEqual(["*", "+15555550123"]); - expect(harness.note).toHaveBeenCalledWith( - expect.stringContaining( - "`channels.whatsapp.accounts.work.dmPolicy` + `channels.whatsapp.accounts.work.allowFrom`", - ), - "WhatsApp DM access", - ); + expectWhatsAppWorkAccountOpenAccess(named.cfg); + expectWhatsAppWorkAccountAccessNote(harness); }); it("labels the selected named account in setup status even when not linked", async () => { @@ -261,53 +245,23 @@ describe("whatsapp setup wizard", () => { await runFinalizeWithHarness({ harness, accountId: "", - cfg: { - channels: { - whatsapp: { - defaultAccount: "work", - dmPolicy: "disabled", - allowFrom: ["+15555550123"], - accounts: { - work: { - authDir: "/tmp/work", - }, - }, - }, - }, - }, + cfg: createWhatsAppWorkAccountConfig({ defaultAccount: "work" }) as OpenClawConfig, }), ); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("disabled"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["+15555550123"]); - expect(result.cfg.channels?.whatsapp?.accounts?.work?.dmPolicy).toBe("open"); - expect(result.cfg.channels?.whatsapp?.accounts?.work?.allowFrom).toEqual(["*", "+15555550123"]); - expect(harness.note).toHaveBeenCalledWith( - expect.stringContaining( - "`channels.whatsapp.accounts.work.dmPolicy` + `channels.whatsapp.accounts.work.allowFrom`", - ), - "WhatsApp DM access", - ); + expectWhatsAppWorkAccountOpenAccess(result.cfg); + expectWhatsAppWorkAccountAccessNote(harness); }); it("normalizes allowFrom entries when list mode is selected", async () => { - const { result } = await runSeparatePhoneFlow({ - selectValues: ["separate", "allowlist", "list"], - textValues: ["+1 (555) 555-0123, +15555550123, *"], - }); + const { result } = await runSeparatePhoneFlow(createWhatsAppAllowlistModeInput()); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(false); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("allowlist"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["+15555550123", "*"]); + expectWhatsAppAllowlistModeSetup(result.cfg); }); it("enables allowlist self-chat mode for personal-phone setup", async () => { hoisted.pathExists.mockResolvedValue(true); - const harness = createQueuedWizardPrompter({ - confirmValues: [false], - selectValues: ["personal"], - textValues: ["+1 (555) 111-2222"], - }); + const harness = createWhatsAppPersonalPhoneHarness(createQueuedWizardPrompter); const result = expectFinalizeResult( await runFinalizeWithHarness({ @@ -315,9 +269,7 @@ describe("whatsapp setup wizard", () => { }), ); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(true); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("allowlist"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["+15551112222"]); + expectWhatsAppPersonalPhoneSetup(result.cfg); }); it("forces wildcard allowFrom for open policy without allowFrom follow-up prompts", async () => { @@ -329,29 +281,16 @@ describe("whatsapp setup wizard", () => { const result = expectFinalizeResult( await runFinalizeWithHarness({ harness, - cfg: { - channels: { - whatsapp: { - allowFrom: ["+15555550123"], - }, - }, - }, + cfg: createWhatsAppRootAllowFromConfig() as OpenClawConfig, }), ); - expect(result.cfg.channels?.whatsapp?.selfChatMode).toBe(false); - expect(result.cfg.channels?.whatsapp?.dmPolicy).toBe("open"); - expect(result.cfg.channels?.whatsapp?.allowFrom).toEqual(["*", "+15555550123"]); - expect(harness.select).toHaveBeenCalledTimes(2); - expect(harness.text).not.toHaveBeenCalled(); + expectWhatsAppOpenPolicySetup(result.cfg, harness); }); it("runs WhatsApp login when not linked and user confirms linking", async () => { hoisted.pathExists.mockResolvedValue(false); - const harness = createQueuedWizardPrompter({ - confirmValues: [true], - selectValues: ["separate", "disabled"], - }); + const harness = createWhatsAppLinkingHarness(createQueuedWizardPrompter); const runtime = createRuntime(); await runFinalizeWithHarness({ @@ -373,10 +312,7 @@ describe("whatsapp setup wizard", () => { }); expect(hoisted.loginWeb).not.toHaveBeenCalled(); - expect(harness.note).not.toHaveBeenCalledWith( - expect.stringContaining("openclaw channels login"), - "WhatsApp", - ); + expectNoWhatsAppLoginFollowup(harness); }); it("shows follow-up login command note when not linked and linking is skipped", async () => { @@ -389,9 +325,6 @@ describe("whatsapp setup wizard", () => { harness, }); - expect(harness.note).toHaveBeenCalledWith( - expect.stringContaining("openclaw channels login"), - "WhatsApp", - ); + expectWhatsAppLoginFollowup(harness); }); }); diff --git a/extensions/whatsapp/src/setup-test-helpers.ts b/extensions/whatsapp/src/setup-test-helpers.ts new file mode 100644 index 00000000000..317d9c81a5e --- /dev/null +++ b/extensions/whatsapp/src/setup-test-helpers.ts @@ -0,0 +1,207 @@ +import { expect } from "vitest"; + +type WhatsAppSetupConfig = { + channels?: { + whatsapp?: { + selfChatMode?: boolean; + dmPolicy?: string; + allowFrom?: string[]; + accounts?: Record; + }; + }; +}; + +type WizardPromptHarness = { + text: { (...args: unknown[]): unknown }; + select: { (...args: unknown[]): unknown }; + note: { (...args: unknown[]): unknown }; +}; + +type QueuedWizardPrompterFactory = (params: { + confirmValues?: boolean[]; + selectValues?: string[]; + textValues?: string[]; +}) => T; + +export const WHATSAPP_OWNER_NUMBER_INPUT = "+1 (555) 555-0123"; +export const WHATSAPP_OWNER_NUMBER = "+15555550123"; +export const WHATSAPP_PERSONAL_NUMBER_INPUT = "+1 (555) 111-2222"; +export const WHATSAPP_PERSONAL_NUMBER = "+15551112222"; +export const WHATSAPP_ACCESS_NOTE_TITLE = "WhatsApp DM access"; +export const WHATSAPP_LOGIN_NOTE_TITLE = "WhatsApp"; + +export function createWhatsAppRootAllowFromConfig(): WhatsAppSetupConfig { + return { + channels: { + whatsapp: { + allowFrom: [WHATSAPP_OWNER_NUMBER], + }, + }, + }; +} + +export function createWhatsAppOwnerAllowlistHarness( + createPrompter: QueuedWizardPrompterFactory, +): T { + return createPrompter({ + confirmValues: [false], + textValues: [WHATSAPP_OWNER_NUMBER_INPUT], + }); +} + +export function createWhatsAppPersonalPhoneHarness( + createPrompter: QueuedWizardPrompterFactory, +): T { + return createPrompter({ + confirmValues: [false], + selectValues: ["personal"], + textValues: [WHATSAPP_PERSONAL_NUMBER_INPUT], + }); +} + +export function createWhatsAppLinkingHarness( + createPrompter: QueuedWizardPrompterFactory, +): T { + return createPrompter({ + confirmValues: [true], + selectValues: ["separate", "disabled"], + }); +} + +export function createWhatsAppWorkAccountConfig( + params: { + defaultAccount?: string; + } = {}, +): WhatsAppSetupConfig { + return { + channels: { + whatsapp: { + ...(params.defaultAccount ? { defaultAccount: params.defaultAccount } : {}), + dmPolicy: "disabled", + allowFrom: [WHATSAPP_OWNER_NUMBER], + accounts: { + work: { + authDir: "/tmp/work", + }, + }, + }, + }, + }; +} + +export function createWhatsAppAllowlistModeInput(): { + selectValues: string[]; + textValues: string[]; +} { + return { + selectValues: ["separate", "allowlist", "list"], + textValues: [`${WHATSAPP_OWNER_NUMBER_INPUT}, ${WHATSAPP_OWNER_NUMBER}, *`], + }; +} + +export function expectWhatsAppDmAccess( + cfg: WhatsAppSetupConfig, + expected: { + selfChatMode: boolean; + dmPolicy: string; + allowFrom?: string[]; + }, +): void { + expect(cfg.channels?.whatsapp?.selfChatMode).toBe(expected.selfChatMode); + expect(cfg.channels?.whatsapp?.dmPolicy).toBe(expected.dmPolicy); + if ("allowFrom" in expected) { + expect(cfg.channels?.whatsapp?.allowFrom).toEqual(expected.allowFrom); + } else { + expect(cfg.channels?.whatsapp?.allowFrom).toBeUndefined(); + } +} + +export function expectWhatsAppWorkAccountOpenAccess(cfg: WhatsAppSetupConfig): void { + expect(cfg.channels?.whatsapp?.dmPolicy).toBe("disabled"); + expect(cfg.channels?.whatsapp?.allowFrom).toEqual([WHATSAPP_OWNER_NUMBER]); + expect(cfg.channels?.whatsapp?.accounts?.work?.dmPolicy).toBe("open"); + expect(cfg.channels?.whatsapp?.accounts?.work?.allowFrom).toEqual(["*", WHATSAPP_OWNER_NUMBER]); +} + +export function expectWhatsAppOwnerNumberPrompt(harness: WizardPromptHarness): void { + expect(harness.text).toHaveBeenCalledWith( + expect.objectContaining({ + message: "Your personal WhatsApp number (the phone you will message from)", + }), + ); +} + +export function expectWhatsAppOwnerAllowlistSetup( + cfg: WhatsAppSetupConfig, + harness: WizardPromptHarness, +): void { + expectWhatsAppDmAccess(cfg, { + selfChatMode: true, + dmPolicy: "allowlist", + allowFrom: [WHATSAPP_OWNER_NUMBER], + }); + expectWhatsAppOwnerNumberPrompt(harness); +} + +export function expectWhatsAppSeparatePhoneDisabledSetup( + cfg: WhatsAppSetupConfig, + harness: WizardPromptHarness, +): void { + expectWhatsAppDmAccess(cfg, { + selfChatMode: false, + dmPolicy: "disabled", + }); + expect(harness.text).not.toHaveBeenCalled(); +} + +export function expectWhatsAppAllowlistModeSetup(cfg: WhatsAppSetupConfig): void { + expectWhatsAppDmAccess(cfg, { + selfChatMode: false, + dmPolicy: "allowlist", + allowFrom: [WHATSAPP_OWNER_NUMBER, "*"], + }); +} + +export function expectWhatsAppPersonalPhoneSetup(cfg: WhatsAppSetupConfig): void { + expectWhatsAppDmAccess(cfg, { + selfChatMode: true, + dmPolicy: "allowlist", + allowFrom: [WHATSAPP_PERSONAL_NUMBER], + }); +} + +export function expectWhatsAppOpenPolicySetup( + cfg: WhatsAppSetupConfig, + harness: WizardPromptHarness, +): void { + expectWhatsAppDmAccess(cfg, { + selfChatMode: false, + dmPolicy: "open", + allowFrom: ["*", WHATSAPP_OWNER_NUMBER], + }); + expect(harness.select).toHaveBeenCalledTimes(2); + expect(harness.text).not.toHaveBeenCalled(); +} + +export function expectNoWhatsAppLoginFollowup(harness: WizardPromptHarness): void { + expect(harness.note).not.toHaveBeenCalledWith( + expect.stringContaining("openclaw channels login"), + WHATSAPP_LOGIN_NOTE_TITLE, + ); +} + +export function expectWhatsAppLoginFollowup(harness: WizardPromptHarness): void { + expect(harness.note).toHaveBeenCalledWith( + expect.stringContaining("openclaw channels login"), + WHATSAPP_LOGIN_NOTE_TITLE, + ); +} + +export function expectWhatsAppWorkAccountAccessNote(harness: WizardPromptHarness): void { + expect(harness.note).toHaveBeenCalledWith( + expect.stringContaining( + "`channels.whatsapp.accounts.work.dmPolicy` + `channels.whatsapp.accounts.work.allowFrom`", + ), + WHATSAPP_ACCESS_NOTE_TITLE, + ); +} diff --git a/extensions/whatsapp/src/test-helpers.ts b/extensions/whatsapp/src/test-helpers.ts index eecde36946e..53457e80a95 100644 --- a/extensions/whatsapp/src/test-helpers.ts +++ b/extensions/whatsapp/src/test-helpers.ts @@ -82,6 +82,47 @@ function loadSessionStoreMock(storePath: string) { } } +type BufferedDispatchReplyParams = { + ctx: Record; + replyResolver: (ctx: Record) => Promise | undefined>; + dispatcherOptions: { + deliver: ( + payload: Record, + info: { kind: "tool" | "block" | "final" }, + ) => Promise; + onReplyStart?: (() => Promise) | (() => void); + }; +}; + +function createBufferedDispatchReplyMock() { + return vi.fn(async (params: BufferedDispatchReplyParams) => { + await params.dispatcherOptions.onReplyStart?.(); + const payload = await params.replyResolver(params.ctx); + if (!payload || typeof payload !== "object") { + return { + queuedFinal: false, + counts: { tool: 0, block: 0, final: 0 }, + }; + } + const text = typeof payload.text === "string" ? payload.text.trim() : ""; + const hasMedia = + typeof payload.mediaUrl === "string" || + typeof payload.mediaPath === "string" || + typeof payload.fileUrl === "string"; + if (!text && !hasMedia) { + return { + queuedFinal: false, + counts: { tool: 0, block: 0, final: 0 }, + }; + } + await params.dispatcherOptions.deliver(payload, { kind: "final" }); + return { + queuedFinal: true, + counts: { tool: 0, block: 0, final: 1 }, + }; + }); +} + function resolveChannelContextVisibilityModeMock(params: { cfg: { channels?: Record< @@ -225,44 +266,7 @@ vi.mock("./auto-reply/monitor/inbound-dispatch.runtime.js", () => ({ onModelSelected: undefined, responsePrefix: undefined, }), - dispatchReplyWithBufferedBlockDispatcher: vi.fn( - async (params: { - ctx: Record; - replyResolver: (ctx: Record) => Promise | undefined>; - dispatcherOptions: { - deliver: ( - payload: Record, - info: { kind: "tool" | "block" | "final" }, - ) => Promise; - onReplyStart?: (() => Promise) | (() => void); - }; - }) => { - await params.dispatcherOptions.onReplyStart?.(); - const payload = await params.replyResolver(params.ctx); - if (!payload || typeof payload !== "object") { - return { - queuedFinal: false, - counts: { tool: 0, block: 0, final: 0 }, - }; - } - const text = typeof payload.text === "string" ? payload.text.trim() : ""; - const hasMedia = - typeof payload.mediaUrl === "string" || - typeof payload.mediaPath === "string" || - typeof payload.fileUrl === "string"; - if (!text && !hasMedia) { - return { - queuedFinal: false, - counts: { tool: 0, block: 0, final: 0 }, - }; - } - await params.dispatcherOptions.deliver(payload, { kind: "final" }); - return { - queuedFinal: true, - counts: { tool: 0, block: 0, final: 1 }, - }; - }, - ), + dispatchReplyWithBufferedBlockDispatcher: createBufferedDispatchReplyMock(), finalizeInboundContext: (ctx: T) => ctx, getAgentScopedMediaLocalRoots: () => [] as string[], jidToE164: (jid: string) => { @@ -304,44 +308,7 @@ vi.mock("./auto-reply/monitor/runtime-api.js", () => ({ onModelSelected: undefined, responsePrefix: undefined, }), - dispatchReplyWithBufferedBlockDispatcher: vi.fn( - async (params: { - ctx: Record; - replyResolver: (ctx: Record) => Promise | undefined>; - dispatcherOptions: { - deliver: ( - payload: Record, - info: { kind: "tool" | "block" | "final" }, - ) => Promise; - onReplyStart?: (() => Promise) | (() => void); - }; - }) => { - await params.dispatcherOptions.onReplyStart?.(); - const payload = await params.replyResolver(params.ctx); - if (!payload || typeof payload !== "object") { - return { - queuedFinal: false, - counts: { tool: 0, block: 0, final: 0 }, - }; - } - const text = typeof payload.text === "string" ? payload.text.trim() : ""; - const hasMedia = - typeof payload.mediaUrl === "string" || - typeof payload.mediaPath === "string" || - typeof payload.fileUrl === "string"; - if (!text && !hasMedia) { - return { - queuedFinal: false, - counts: { tool: 0, block: 0, final: 0 }, - }; - } - await params.dispatcherOptions.deliver(payload, { kind: "final" }); - return { - queuedFinal: true, - counts: { tool: 0, block: 0, final: 1 }, - }; - }, - ), + dispatchReplyWithBufferedBlockDispatcher: createBufferedDispatchReplyMock(), finalizeInboundContext: (ctx: T) => ctx, formatInboundEnvelope: (params: { body: string; senderLabel?: string }) => `${params.senderLabel ? `${params.senderLabel}: ` : ""}${params.body}`, diff --git a/extensions/xai/stream.test.ts b/extensions/xai/stream.test.ts index 4787442a2b1..cf14b6ed1a8 100644 --- a/extensions/xai/stream.test.ts +++ b/extensions/xai/stream.test.ts @@ -14,10 +14,12 @@ type XaiTestPayload = Record & { tools?: Array<{ type?: string; function?: Record }>; input?: unknown[]; }; +type XaiStreamApi = Extract; + function captureWrappedModelId(params: { modelId: string; fastMode: boolean; - api?: Extract; + api?: XaiStreamApi; }): string { let capturedModelId = ""; const baseStreamFn: StreamFn = (model) => { @@ -39,6 +41,33 @@ function captureWrappedModelId(params: { return capturedModelId; } +function runXaiToolPayloadWrapper(params: { + payload: Record; + api?: XaiStreamApi; + modelId?: string; + input?: string[]; +}) { + const baseStreamFn: StreamFn = (_model, _context, options) => { + options?.onPayload?.(params.payload, {} as Model); + return {} as ReturnType; + }; + const wrapped = createXaiToolPayloadCompatibilityWrapper(baseStreamFn); + const api = params.api ?? "openai-responses"; + + void wrapped( + { + api, + provider: "xai", + id: + params.modelId ?? + (api === "openai-completions" ? "grok-4-1-fast-reasoning" : "grok-4-fast"), + ...(params.input ? { input: params.input } : {}), + } as Model, + { messages: [] } as Context, + {}, + ); +} + describe("xai stream wrappers", () => { it("rewrites supported Grok models to fast variants when fast mode is enabled", () => { expect(captureWrappedModelId({ modelId: "grok-3", fastMode: true })).toBe("grok-3-fast"); @@ -128,21 +157,7 @@ describe("xai stream wrappers", () => { }, ], }; - const baseStreamFn: StreamFn = (_model, _context, options) => { - options?.onPayload?.(payload, {} as Model<"openai-completions">); - return {} as ReturnType; - }; - const wrapped = createXaiToolPayloadCompatibilityWrapper(baseStreamFn); - - void wrapped( - { - api: "openai-completions", - provider: "xai", - id: "grok-4-1-fast-reasoning", - } as Model<"openai-completions">, - { messages: [] } as Context, - {}, - ); + runXaiToolPayloadWrapper({ payload, api: "openai-completions" }); expect(payload).not.toHaveProperty("reasoning"); expect(payload).not.toHaveProperty("reasoningEffort"); @@ -156,21 +171,7 @@ describe("xai stream wrappers", () => { reasoningEffort: "high", reasoning_effort: "high", }; - const baseStreamFn: StreamFn = (_model, _context, options) => { - options?.onPayload?.(payload, {} as Model<"openai-responses">); - return {} as ReturnType; - }; - const wrapped = createXaiToolPayloadCompatibilityWrapper(baseStreamFn); - - void wrapped( - { - api: "openai-responses", - provider: "xai", - id: "grok-4-fast", - } as Model<"openai-responses">, - { messages: [] } as Context, - {}, - ); + runXaiToolPayloadWrapper({ payload }); expect(payload).not.toHaveProperty("reasoning"); expect(payload).not.toHaveProperty("reasoningEffort"); @@ -194,22 +195,7 @@ describe("xai stream wrappers", () => { }, ], }; - const baseStreamFn: StreamFn = (_model, _context, options) => { - options?.onPayload?.(payload, {} as Model<"openai-responses">); - return {} as ReturnType; - }; - const wrapped = createXaiToolPayloadCompatibilityWrapper(baseStreamFn); - - void wrapped( - { - api: "openai-responses", - provider: "xai", - id: "grok-4-fast", - input: ["text", "image"], - } as Model<"openai-responses">, - { messages: [] } as Context, - {}, - ); + runXaiToolPayloadWrapper({ payload, input: ["text", "image"] }); expect(payload.input).toEqual([ { @@ -252,22 +238,7 @@ describe("xai stream wrappers", () => { }, ], }; - const baseStreamFn: StreamFn = (_model, _context, options) => { - options?.onPayload?.(payload, {} as Model<"openai-responses">); - return {} as ReturnType; - }; - const wrapped = createXaiToolPayloadCompatibilityWrapper(baseStreamFn); - - void wrapped( - { - api: "openai-responses", - provider: "xai", - id: "grok-4-fast", - input: ["text", "image"], - } as Model<"openai-responses">, - { messages: [] } as Context, - {}, - ); + runXaiToolPayloadWrapper({ payload, input: ["text", "image"] }); expect(payload.input).toEqual([ { @@ -322,22 +293,7 @@ describe("xai stream wrappers", () => { }, ], }; - const baseStreamFn: StreamFn = (_model, _context, options) => { - options?.onPayload?.(payload, {} as Model<"openai-responses">); - return {} as ReturnType; - }; - const wrapped = createXaiToolPayloadCompatibilityWrapper(baseStreamFn); - - void wrapped( - { - api: "openai-responses", - provider: "xai", - id: "grok-4-fast", - input: ["text", "image"], - } as Model<"openai-responses">, - { messages: [] } as Context, - {}, - ); + runXaiToolPayloadWrapper({ payload, input: ["text", "image"] }); expect(payload.input).toEqual([ { @@ -386,22 +342,7 @@ describe("xai stream wrappers", () => { }, ], }; - const baseStreamFn: StreamFn = (_model, _context, options) => { - options?.onPayload?.(payload, {} as Model<"openai-responses">); - return {} as ReturnType; - }; - const wrapped = createXaiToolPayloadCompatibilityWrapper(baseStreamFn); - - void wrapped( - { - api: "openai-responses", - provider: "xai", - id: "grok-4-fast", - input: ["text"], - } as Model<"openai-responses">, - { messages: [] } as Context, - {}, - ); + runXaiToolPayloadWrapper({ payload, input: ["text"] }); expect(payload.input).toEqual([ { diff --git a/extensions/xai/video-generation-provider.test.ts b/extensions/xai/video-generation-provider.test.ts index 314ee80ac0f..15560612d15 100644 --- a/extensions/xai/video-generation-provider.test.ts +++ b/extensions/xai/video-generation-provider.test.ts @@ -1,45 +1,20 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { buildXaiVideoGenerationProvider } from "./video-generation-provider.js"; +import { beforeAll, describe, expect, it, vi } from "vitest"; +import { + getProviderHttpMocks, + installProviderHttpMockCleanup, +} from "../../test/helpers/media-generation/provider-http-mocks.js"; -const { - resolveApiKeyForProviderMock, - postJsonRequestMock, - fetchWithTimeoutMock, - assertOkOrThrowHttpErrorMock, - resolveProviderHttpRequestConfigMock, -} = vi.hoisted(() => ({ - resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "xai-key" })), - postJsonRequestMock: vi.fn(), - fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), - resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ - baseUrl: params.baseUrl ?? params.defaultBaseUrl, - allowPrivateNetwork: false, - headers: new Headers(params.defaultHeaders), - dispatcherPolicy: undefined, - })), -})); +const { postJsonRequestMock, fetchWithTimeoutMock } = getProviderHttpMocks(); -vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ - resolveApiKeyForProvider: resolveApiKeyForProviderMock, -})); +let buildXaiVideoGenerationProvider: typeof import("./video-generation-provider.js").buildXaiVideoGenerationProvider; -vi.mock("openclaw/plugin-sdk/provider-http", () => ({ - assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, - fetchWithTimeout: fetchWithTimeoutMock, - postJsonRequest: postJsonRequestMock, - resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, -})); +beforeAll(async () => { + ({ buildXaiVideoGenerationProvider } = await import("./video-generation-provider.js")); +}); + +installProviderHttpMockCleanup(); describe("xai video generation provider", () => { - afterEach(() => { - resolveApiKeyForProviderMock.mockClear(); - postJsonRequestMock.mockReset(); - fetchWithTimeoutMock.mockReset(); - assertOkOrThrowHttpErrorMock.mockClear(); - resolveProviderHttpRequestConfigMock.mockClear(); - }); - it("creates, polls, and downloads a generated video", async () => { postJsonRequestMock.mockResolvedValue({ response: { diff --git a/extensions/zalouser/src/channel.setup.test.ts b/extensions/zalouser/src/channel.setup.test.ts index 0f7d349311a..c30cf501274 100644 --- a/extensions/zalouser/src/channel.setup.test.ts +++ b/extensions/zalouser/src/channel.setup.test.ts @@ -1,52 +1,11 @@ import { mkdtemp, rm } from "node:fs/promises"; import os from "node:os"; import path from "node:path"; -import { createScopedDmSecurityResolver } from "openclaw/plugin-sdk/channel-config-helpers"; import { withEnvAsync } from "openclaw/plugin-sdk/testing"; import { describe, expect, it } from "vitest"; import { createPluginSetupWizardStatus } from "../../../test/helpers/plugins/setup-wizard.js"; -import type { OpenClawConfig } from "../runtime-api.js"; import "./zalo-js.test-mocks.js"; -import { - listZalouserAccountIds, - resolveDefaultZalouserAccountId, - resolveZalouserAccountSync, -} from "./accounts.js"; -import { zalouserSetupAdapter } from "./setup-core.js"; -import { zalouserSetupWizard } from "./setup-surface.js"; - -const zalouserSetupPlugin = { - id: "zalouser", - meta: { - id: "zalouser", - label: "ZaloUser", - selectionLabel: "ZaloUser", - docsPath: "/channels/zalouser", - blurb: "Unofficial Zalo personal account connector.", - }, - capabilities: { - chatTypes: ["direct", "group"] as Array<"direct" | "group">, - }, - config: { - listAccountIds: (cfg: unknown) => listZalouserAccountIds(cfg as never), - defaultAccountId: (cfg: unknown) => resolveDefaultZalouserAccountId(cfg as never), - resolveAccount: (cfg: OpenClawConfig, accountId?: string | null) => - resolveZalouserAccountSync({ cfg, accountId }), - }, - security: { - resolveDmPolicy: createScopedDmSecurityResolver({ - channelKey: "zalouser", - resolvePolicy: (account: ReturnType) => - account.config.dmPolicy, - resolveAllowFrom: (account: ReturnType) => - account.config.allowFrom, - policyPathSuffix: "dmPolicy", - normalizeEntry: (raw: string) => raw.trim().replace(/^(zalouser|zlu):/i, ""), - }), - }, - setup: zalouserSetupAdapter, - setupWizard: zalouserSetupWizard, -} as const; +import { zalouserSetupPlugin } from "./setup-test-helpers.js"; const zalouserSetupGetStatus = createPluginSetupWizardStatus(zalouserSetupPlugin); diff --git a/extensions/zalouser/src/setup-surface.test.ts b/extensions/zalouser/src/setup-surface.test.ts index 0b19625c898..0338a637817 100644 --- a/extensions/zalouser/src/setup-surface.test.ts +++ b/extensions/zalouser/src/setup-surface.test.ts @@ -1,4 +1,3 @@ -import { createScopedDmSecurityResolver } from "openclaw/plugin-sdk/channel-config-helpers"; import { describe, expect, it, vi } from "vitest"; import { createPluginSetupWizardConfigure, @@ -7,46 +6,8 @@ import { } from "../../../test/helpers/plugins/setup-wizard.js"; import type { OpenClawConfig } from "../runtime-api.js"; import "./zalo-js.test-mocks.js"; -import { - listZalouserAccountIds, - resolveDefaultZalouserAccountId, - resolveZalouserAccountSync, -} from "./accounts.js"; -import { zalouserSetupAdapter } from "./setup-core.js"; import { zalouserSetupWizard } from "./setup-surface.js"; - -const zalouserSetupPlugin = { - id: "zalouser", - meta: { - id: "zalouser", - label: "ZaloUser", - selectionLabel: "ZaloUser", - docsPath: "/channels/zalouser", - blurb: "Unofficial Zalo personal account connector.", - }, - capabilities: { - chatTypes: ["direct", "group"] as Array<"direct" | "group">, - }, - config: { - listAccountIds: (cfg: unknown) => listZalouserAccountIds(cfg as never), - defaultAccountId: (cfg: unknown) => resolveDefaultZalouserAccountId(cfg as never), - resolveAccount: (cfg: OpenClawConfig, accountId?: string | null) => - resolveZalouserAccountSync({ cfg, accountId }), - }, - security: { - resolveDmPolicy: createScopedDmSecurityResolver({ - channelKey: "zalouser", - resolvePolicy: (account: ReturnType) => - account.config.dmPolicy, - resolveAllowFrom: (account: ReturnType) => - account.config.allowFrom, - policyPathSuffix: "dmPolicy", - normalizeEntry: (raw: string) => raw.trim().replace(/^(zalouser|zlu):/i, ""), - }), - }, - setup: zalouserSetupAdapter, - setupWizard: zalouserSetupWizard, -} as const; +import { zalouserSetupPlugin } from "./setup-test-helpers.js"; const zalouserConfigure = createPluginSetupWizardConfigure(zalouserSetupPlugin); diff --git a/extensions/zalouser/src/setup-test-helpers.ts b/extensions/zalouser/src/setup-test-helpers.ts new file mode 100644 index 00000000000..73b42d24f02 --- /dev/null +++ b/extensions/zalouser/src/setup-test-helpers.ts @@ -0,0 +1,42 @@ +import { createScopedDmSecurityResolver } from "openclaw/plugin-sdk/channel-config-helpers"; +import type { OpenClawConfig } from "../runtime-api.js"; +import { + listZalouserAccountIds, + resolveDefaultZalouserAccountId, + resolveZalouserAccountSync, +} from "./accounts.js"; +import { zalouserSetupAdapter } from "./setup-core.js"; +import { zalouserSetupWizard } from "./setup-surface.js"; + +export const zalouserSetupPlugin = { + id: "zalouser", + meta: { + id: "zalouser", + label: "ZaloUser", + selectionLabel: "ZaloUser", + docsPath: "/channels/zalouser", + blurb: "Unofficial Zalo personal account connector.", + }, + capabilities: { + chatTypes: ["direct", "group"] as Array<"direct" | "group">, + }, + config: { + listAccountIds: (cfg: unknown) => listZalouserAccountIds(cfg as never), + defaultAccountId: (cfg: unknown) => resolveDefaultZalouserAccountId(cfg as never), + resolveAccount: (cfg: OpenClawConfig, accountId?: string | null) => + resolveZalouserAccountSync({ cfg, accountId }), + }, + security: { + resolveDmPolicy: createScopedDmSecurityResolver({ + channelKey: "zalouser", + resolvePolicy: (account: ReturnType) => + account.config.dmPolicy, + resolveAllowFrom: (account: ReturnType) => + account.config.allowFrom, + policyPathSuffix: "dmPolicy", + normalizeEntry: (raw: string) => raw.trim().replace(/^(zalouser|zlu):/i, ""), + }), + }, + setup: zalouserSetupAdapter, + setupWizard: zalouserSetupWizard, +} as const; diff --git a/src/acp/translator.stop-reason.test.ts b/src/acp/translator.stop-reason.test.ts index ea3e08c9e6f..fd09c7372a3 100644 --- a/src/acp/translator.stop-reason.test.ts +++ b/src/acp/translator.stop-reason.test.ts @@ -12,10 +12,56 @@ type PendingPromptHarness = { runId: string; }; -async function createPendingPromptHarness(): Promise { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; +const DEFAULT_SESSION_ID = "session-1"; +const DEFAULT_SESSION_KEY = "agent:main:main"; +const DEFAULT_PROMPT_TEXT = "hello"; +function createSessionAgentHarness( + request: GatewayClient["request"], + options: { sessionId?: string; sessionKey?: string; cwd?: string } = {}, +) { + const sessionId = options.sessionId ?? DEFAULT_SESSION_ID; + const sessionKey = options.sessionKey ?? DEFAULT_SESSION_KEY; + const sessionStore = createInMemorySessionStore(); + sessionStore.createSession({ + sessionId, + sessionKey, + cwd: options.cwd ?? "/tmp", + }); + const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { + sessionStore, + }); + + return { + agent, + sessionId, + sessionKey, + sessionStore, + }; +} + +function promptAgent( + agent: AcpGatewayAgent, + sessionId = DEFAULT_SESSION_ID, + text = DEFAULT_PROMPT_TEXT, +) { + return agent.prompt({ + sessionId, + prompt: [{ type: "text", text }], + _meta: {}, + } as unknown as PromptRequest); +} + +function observeSettlement(promise: ReturnType) { + const settleSpy = vi.fn(); + void promise.then( + (value) => settleSpy({ kind: "resolve", value }), + (error) => settleSpy({ kind: "reject", error }), + ); + return settleSpy; +} + +async function createPendingPromptHarness(): Promise { let runId: string | undefined; const request = vi.fn(async (method: string, params?: Record) => { if (method === "chat.send") { @@ -25,23 +71,8 @@ async function createPendingPromptHarness(): Promise { return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - - const agent = new AcpGatewayAgent( - createAcpConnection(), - createAcpGateway(request as unknown as GatewayClient["request"]), - { sessionStore }, - ); - const promptPromise = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "hello" }], - _meta: {}, - } as unknown as PromptRequest); + const { agent, sessionId } = createSessionAgentHarness(request); + const promptPromise = promptAgent(agent, sessionId); await vi.waitFor(() => { expect(runId).toBeDefined(); @@ -111,11 +142,7 @@ describe("acp translator stop reason mapping", () => { it("keeps in-flight prompts pending across transient gateway disconnects", async () => { const { agent, promptPromise, runId } = await createPendingPromptHarness(); - const settleSpy = vi.fn(); - void promptPromise.then( - (value) => settleSpy({ kind: "resolve", value }), - (error) => settleSpy({ kind: "reject", error }), - ); + const settleSpy = observeSettlement(promptPromise); agent.handleGatewayDisconnect("1006: connection lost"); await Promise.resolve(); @@ -153,31 +180,15 @@ describe("acp translator stop reason mapping", () => { it("keeps pre-ack send disconnects inside the reconnect grace window", async () => { vi.useFakeTimers(); try { - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId: "session-1", - sessionKey: "agent:main:main", - cwd: "/tmp", - }); const request = vi.fn(async (method: string) => { if (method === "chat.send") { throw new Error("gateway closed (1006): connection lost"); } return {}; }) as GatewayClient["request"]; - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); - const promptPromise = agent.prompt({ - sessionId: "session-1", - prompt: [{ type: "text", text: "hello" }], - _meta: {}, - } as unknown as PromptRequest); - const settleSpy = vi.fn(); - void promptPromise.then( - (value) => settleSpy({ kind: "resolve", value }), - (error) => settleSpy({ kind: "reject", error }), - ); + const { agent, sessionId } = createSessionAgentHarness(request); + const promptPromise = promptAgent(agent, sessionId); + const settleSpy = observeSettlement(promptPromise); await Promise.resolve(); expect(settleSpy).not.toHaveBeenCalled(); @@ -194,8 +205,6 @@ describe("acp translator stop reason mapping", () => { }); it("reconciles a missed final event on reconnect via agent.wait", async () => { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; let runId: string | undefined; const request = vi.fn(async (method: string, params?: Record) => { if (method === "chat.send") { @@ -207,20 +216,8 @@ describe("acp translator stop reason mapping", () => { } return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); - const promptPromise = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "hello" }], - _meta: {}, - } as unknown as PromptRequest); + const { agent, sessionId } = createSessionAgentHarness(request); + const promptPromise = promptAgent(agent, sessionId); await vi.waitFor(() => { expect(runId).toBeDefined(); @@ -243,8 +240,6 @@ describe("acp translator stop reason mapping", () => { it("rechecks accepted prompts at the disconnect deadline after reconnect timeout", async () => { vi.useFakeTimers(); try { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; let waitCount = 0; const request = vi.fn(async (method: string, params?: Record) => { if (method === "chat.send") { @@ -260,25 +255,9 @@ describe("acp translator stop reason mapping", () => { } return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); - const promptPromise = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "hello" }], - _meta: {}, - } as unknown as PromptRequest); - const settleSpy = vi.fn(); - void promptPromise.then( - (value) => settleSpy({ kind: "resolve", value }), - (error) => settleSpy({ kind: "reject", error }), - ); + const { agent, sessionId } = createSessionAgentHarness(request); + const promptPromise = promptAgent(agent, sessionId); + const settleSpy = observeSettlement(promptPromise); await Promise.resolve(); agent.handleGatewayDisconnect("1006: connection lost"); @@ -298,8 +277,6 @@ describe("acp translator stop reason mapping", () => { it("keeps accepted prompts pending when the deadline recheck still reports timeout", async () => { vi.useFakeTimers(); try { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; const request = vi.fn(async (method: string) => { if (method === "chat.send") { return {}; @@ -309,20 +286,8 @@ describe("acp translator stop reason mapping", () => { } return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); - const promptPromise = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "hello" }], - _meta: {}, - } as unknown as PromptRequest); + const { agent, sessionId } = createSessionAgentHarness(request); + const promptPromise = promptAgent(agent, sessionId); await Promise.resolve(); agent.handleGatewayDisconnect("1006: connection lost"); @@ -341,8 +306,6 @@ describe("acp translator stop reason mapping", () => { it("does not clear a newer disconnect deadline while reconnect reconciliation is still running", async () => { vi.useFakeTimers(); try { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; let resolveAgentWait: ((value: { status: "timeout" }) => void) | undefined; let agentWaitCount = 0; const request = vi.fn(async (method: string) => { @@ -360,25 +323,9 @@ describe("acp translator stop reason mapping", () => { } return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); - const promptPromise = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "hello" }], - _meta: {}, - } as unknown as PromptRequest); - const settleSpy = vi.fn(); - void promptPromise.then( - (value) => settleSpy({ kind: "resolve", value }), - (error) => settleSpy({ kind: "reject", error }), - ); + const { agent, sessionId } = createSessionAgentHarness(request); + const promptPromise = promptAgent(agent, sessionId); + const settleSpy = observeSettlement(promptPromise); await Promise.resolve(); agent.handleGatewayDisconnect("1006: first disconnect"); @@ -405,8 +352,6 @@ describe("acp translator stop reason mapping", () => { it("rejects pre-ack prompts when reconnect timeout still finds no run", async () => { vi.useFakeTimers(); try { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; const request = vi.fn(async (method: string) => { if (method === "chat.send") { throw new Error("gateway closed (1006): connection lost"); @@ -416,20 +361,8 @@ describe("acp translator stop reason mapping", () => { } return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); - const promptPromise = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "hello" }], - _meta: {}, - } as unknown as PromptRequest); + const { agent, sessionId } = createSessionAgentHarness(request); + const promptPromise = promptAgent(agent, sessionId); void promptPromise.catch(() => {}); await Promise.resolve(); @@ -449,8 +382,6 @@ describe("acp translator stop reason mapping", () => { }); it("rejects a superseded pre-ack prompt when a newer prompt has replaced the session entry", async () => { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; let promptCount = 0; const request = vi.fn(async (method: string) => { if (method !== "chat.send") { @@ -462,28 +393,12 @@ describe("acp translator stop reason mapping", () => { } return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); + const { agent, sessionId } = createSessionAgentHarness(request); - const firstPrompt = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "first" }], - _meta: {}, - } as unknown as PromptRequest); + const firstPrompt = promptAgent(agent, sessionId, "first"); await Promise.resolve(); - const secondPrompt = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "second" }], - _meta: {}, - } as unknown as PromptRequest); + const secondPrompt = promptAgent(agent, sessionId, "second"); await expect(firstPrompt).rejects.toThrow("gateway closed (1006): connection lost"); await expect(Promise.race([secondPrompt, Promise.resolve("pending")])).resolves.toBe("pending"); @@ -492,8 +407,6 @@ describe("acp translator stop reason mapping", () => { it("rejects stale pre-ack prompts when a superseded send resolves late", async () => { vi.useFakeTimers(); try { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; let firstSendResolve: (() => void) | undefined; let sendCount = 0; const request = vi.fn(async (method: string) => { @@ -511,30 +424,14 @@ describe("acp translator stop reason mapping", () => { } return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); + const { agent, sessionId } = createSessionAgentHarness(request); - const firstPrompt = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "first" }], - _meta: {}, - } as unknown as PromptRequest); + const firstPrompt = promptAgent(agent, sessionId, "first"); void firstPrompt.catch(() => {}); await Promise.resolve(); expect(firstSendResolve).toBeDefined(); - const secondPrompt = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "second" }], - _meta: {}, - } as unknown as PromptRequest); + const secondPrompt = promptAgent(agent, sessionId, "second"); void secondPrompt.catch(() => {}); await Promise.resolve(); expect(sendCount).toBe(2); @@ -598,11 +495,7 @@ describe("acp translator stop reason mapping", () => { prompt: [{ type: "text", text: "pre-ack" }], _meta: {}, } as unknown as PromptRequest); - const acceptedSettleSpy = vi.fn(); - void acceptedPrompt.then( - (value) => acceptedSettleSpy({ kind: "resolve", value }), - (error) => acceptedSettleSpy({ kind: "reject", error }), - ); + observeSettlement(acceptedPrompt); void preAckPrompt.catch(() => {}); await Promise.resolve(); @@ -624,8 +517,6 @@ describe("acp translator stop reason mapping", () => { }); it("reconciles prompts started while the gateway is disconnected", async () => { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; const request = vi.fn(async (method: string) => { if (method === "chat.send") { throw new Error("gateway closed (1006): connection lost"); @@ -635,27 +526,11 @@ describe("acp translator stop reason mapping", () => { } return {}; }) as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); + const { agent, sessionId } = createSessionAgentHarness(request); agent.handleGatewayDisconnect("1006: connection lost"); - const promptPromise = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "hello" }], - _meta: {}, - } as unknown as PromptRequest); - const settleSpy = vi.fn(); - void promptPromise.then( - (value) => settleSpy({ kind: "resolve", value }), - (error) => settleSpy({ kind: "reject", error }), - ); + const promptPromise = promptAgent(agent, sessionId); + const settleSpy = observeSettlement(promptPromise); await Promise.resolve(); agent.handleGatewayReconnect(); @@ -670,8 +545,6 @@ describe("acp translator stop reason mapping", () => { it("does not let a stale disconnect deadline reject a newer prompt on the same session", async () => { vi.useFakeTimers(); try { - const sessionId = "session-1"; - const sessionKey = "agent:main:main"; let sendCount = 0; const requestMock = vi.fn(async (method: string, params?: Record) => { if (method === "chat.send") { @@ -687,21 +560,9 @@ describe("acp translator stop reason mapping", () => { return {}; }); const request = requestMock as GatewayClient["request"]; - const sessionStore = createInMemorySessionStore(); - sessionStore.createSession({ - sessionId, - sessionKey, - cwd: "/tmp", - }); - const agent = new AcpGatewayAgent(createAcpConnection(), createAcpGateway(request), { - sessionStore, - }); + const { agent, sessionId } = createSessionAgentHarness(request); - const firstPrompt = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "first" }], - _meta: {}, - } as unknown as PromptRequest); + const firstPrompt = promptAgent(agent, sessionId, "first"); void firstPrompt.catch(() => {}); await Promise.resolve(); const firstRunId = requestMock.mock.calls[0]?.[1]?.idempotencyKey as string; @@ -710,11 +571,7 @@ describe("acp translator stop reason mapping", () => { agent.handleGatewayReconnect(); await Promise.resolve(); - const secondPrompt = agent.prompt({ - sessionId, - prompt: [{ type: "text", text: "second" }], - _meta: {}, - } as unknown as PromptRequest); + const secondPrompt = promptAgent(agent, sessionId, "second"); await vi.advanceTimersByTimeAsync(5_000); await expect(Promise.race([secondPrompt, Promise.resolve("pending")])).resolves.toBe( diff --git a/src/agents/bundle-mcp-shared.test-harness.ts b/src/agents/bundle-mcp-shared.test-harness.ts new file mode 100644 index 00000000000..de277c10f19 --- /dev/null +++ b/src/agents/bundle-mcp-shared.test-harness.ts @@ -0,0 +1,98 @@ +import fs from "node:fs/promises"; +import { createRequire } from "node:module"; +import path from "node:path"; + +const require = createRequire(import.meta.url); +const SDK_SERVER_MCP_PATH = require.resolve("@modelcontextprotocol/sdk/server/mcp.js"); +const SDK_SERVER_STDIO_PATH = require.resolve("@modelcontextprotocol/sdk/server/stdio.js"); + +export async function writeExecutable(filePath: string, content: string): Promise { + await fs.mkdir(path.dirname(filePath), { recursive: true }); + await fs.writeFile(filePath, content, { encoding: "utf-8", mode: 0o755 }); +} + +export async function writeBundleProbeMcpServer( + filePath: string, + params: { + startupCounterPath?: string; + startupDelayMs?: number; + pidPath?: string; + exitMarkerPath?: string; + } = {}, +): Promise { + await writeExecutable( + filePath, + `#!/usr/bin/env node +import fs from "node:fs"; +import fsp from "node:fs/promises"; +import { setTimeout as delay } from "node:timers/promises"; +import { McpServer } from ${JSON.stringify(SDK_SERVER_MCP_PATH)}; +import { StdioServerTransport } from ${JSON.stringify(SDK_SERVER_STDIO_PATH)}; + +const startupCounterPath = ${JSON.stringify(params.startupCounterPath ?? "")}; +if (startupCounterPath) { + let current = 0; + try { + current = Number.parseInt((await fsp.readFile(startupCounterPath, "utf8")).trim(), 10) || 0; + } catch {} + await fsp.writeFile(startupCounterPath, String(current + 1), "utf8"); +} +const pidPath = ${JSON.stringify(params.pidPath ?? "")}; +if (pidPath) { + await fsp.writeFile(pidPath, String(process.pid), "utf8"); +} +const exitMarkerPath = ${JSON.stringify(params.exitMarkerPath ?? "")}; +if (exitMarkerPath) { + process.once("exit", () => { + try { + fs.writeFileSync(exitMarkerPath, "exited", "utf8"); + } catch {} + }); +} +const startupDelayMs = ${JSON.stringify(params.startupDelayMs ?? 0)}; +if (startupDelayMs > 0) { + await delay(startupDelayMs); +} + +const server = new McpServer({ name: "bundle-probe", version: "1.0.0" }); +server.tool("bundle_probe", "Bundle MCP probe", async () => { + return { + content: [{ type: "text", text: process.env.BUNDLE_PROBE_TEXT ?? "missing-probe-text" }], + }; +}); + +await server.connect(new StdioServerTransport()); +`, + ); +} + +export async function writeClaudeBundle(params: { + pluginRoot: string; + serverScriptPath: string; +}): Promise { + await fs.mkdir(path.join(params.pluginRoot, ".claude-plugin"), { recursive: true }); + await fs.writeFile( + path.join(params.pluginRoot, ".claude-plugin", "plugin.json"), + `${JSON.stringify({ name: "bundle-probe" }, null, 2)}\n`, + "utf-8", + ); + await fs.writeFile( + path.join(params.pluginRoot, ".mcp.json"), + `${JSON.stringify( + { + mcpServers: { + bundleProbe: { + command: "node", + args: [path.relative(params.pluginRoot, params.serverScriptPath)], + env: { + BUNDLE_PROBE_TEXT: "FROM-BUNDLE", + }, + }, + }, + }, + null, + 2, + )}\n`, + "utf-8", + ); +} diff --git a/src/agents/bundle-mcp.test-harness.ts b/src/agents/bundle-mcp.test-harness.ts index 4fa84028d56..13e851d477c 100644 --- a/src/agents/bundle-mcp.test-harness.ts +++ b/src/agents/bundle-mcp.test-harness.ts @@ -1,67 +1,15 @@ -import fs from "node:fs/promises"; import { createRequire } from "node:module"; -import path from "node:path"; +import { + writeBundleProbeMcpServer, + writeClaudeBundle, + writeExecutable, +} from "./bundle-mcp-shared.test-harness.js"; const require = createRequire(import.meta.url); -const SDK_SERVER_MCP_PATH = require.resolve("@modelcontextprotocol/sdk/server/mcp.js"); -const SDK_SERVER_STDIO_PATH = require.resolve("@modelcontextprotocol/sdk/server/stdio.js"); const SDK_CLIENT_INDEX_PATH = require.resolve("@modelcontextprotocol/sdk/client/index.js"); const SDK_CLIENT_STDIO_PATH = require.resolve("@modelcontextprotocol/sdk/client/stdio.js"); -export async function writeExecutable(filePath: string, content: string): Promise { - await fs.mkdir(path.dirname(filePath), { recursive: true }); - await fs.writeFile(filePath, content, { encoding: "utf-8", mode: 0o755 }); -} - -export async function writeBundleProbeMcpServer(filePath: string): Promise { - await writeExecutable( - filePath, - `#!/usr/bin/env node -import { McpServer } from ${JSON.stringify(SDK_SERVER_MCP_PATH)}; -import { StdioServerTransport } from ${JSON.stringify(SDK_SERVER_STDIO_PATH)}; - -const server = new McpServer({ name: "bundle-probe", version: "1.0.0" }); -server.tool("bundle_probe", "Bundle MCP probe", async () => { - return { - content: [{ type: "text", text: process.env.BUNDLE_PROBE_TEXT ?? "missing-probe-text" }], - }; -}); - -await server.connect(new StdioServerTransport()); -`, - ); -} - -export async function writeClaudeBundle(params: { - pluginRoot: string; - serverScriptPath: string; -}): Promise { - await fs.mkdir(path.join(params.pluginRoot, ".claude-plugin"), { recursive: true }); - await fs.writeFile( - path.join(params.pluginRoot, ".claude-plugin", "plugin.json"), - `${JSON.stringify({ name: "bundle-probe" }, null, 2)}\n`, - "utf-8", - ); - await fs.writeFile( - path.join(params.pluginRoot, ".mcp.json"), - `${JSON.stringify( - { - mcpServers: { - bundleProbe: { - command: "node", - args: [path.relative(params.pluginRoot, params.serverScriptPath)], - env: { - BUNDLE_PROBE_TEXT: "FROM-BUNDLE", - }, - }, - }, - }, - null, - 2, - )}\n`, - "utf-8", - ); -} +export { writeBundleProbeMcpServer, writeClaudeBundle, writeExecutable }; export async function writeFakeClaudeCli(filePath: string): Promise { await writeExecutable( diff --git a/src/agents/memory-search.ts b/src/agents/memory-search.ts index 8c85a343ef3..db07f6e6a83 100644 --- a/src/agents/memory-search.ts +++ b/src/agents/memory-search.ts @@ -225,30 +225,7 @@ function mergeConfig( tokens: overrides?.chunking?.tokens ?? defaults?.chunking?.tokens ?? DEFAULT_CHUNK_TOKENS, overlap: overrides?.chunking?.overlap ?? defaults?.chunking?.overlap ?? DEFAULT_CHUNK_OVERLAP, }; - const sync = { - onSessionStart: overrides?.sync?.onSessionStart ?? defaults?.sync?.onSessionStart ?? true, - onSearch: overrides?.sync?.onSearch ?? defaults?.sync?.onSearch ?? true, - watch: overrides?.sync?.watch ?? defaults?.sync?.watch ?? true, - watchDebounceMs: - overrides?.sync?.watchDebounceMs ?? - defaults?.sync?.watchDebounceMs ?? - DEFAULT_WATCH_DEBOUNCE_MS, - intervalMinutes: overrides?.sync?.intervalMinutes ?? defaults?.sync?.intervalMinutes ?? 0, - sessions: { - deltaBytes: - overrides?.sync?.sessions?.deltaBytes ?? - defaults?.sync?.sessions?.deltaBytes ?? - DEFAULT_SESSION_DELTA_BYTES, - deltaMessages: - overrides?.sync?.sessions?.deltaMessages ?? - defaults?.sync?.sessions?.deltaMessages ?? - DEFAULT_SESSION_DELTA_MESSAGES, - postCompactionForce: - overrides?.sync?.sessions?.postCompactionForce ?? - defaults?.sync?.sessions?.postCompactionForce ?? - true, - }, - }; + const sync = resolveSyncConfig(defaults, overrides); const query = { maxResults: overrides?.query?.maxResults ?? defaults?.query?.maxResults ?? DEFAULT_MAX_RESULTS, minScore: overrides?.query?.minScore ?? defaults?.query?.minScore ?? DEFAULT_MIN_SCORE, diff --git a/src/agents/models.profiles.live.test.ts b/src/agents/models.profiles.live.test.ts index 31d79d728c9..1b624ebe317 100644 --- a/src/agents/models.profiles.live.test.ts +++ b/src/agents/models.profiles.live.test.ts @@ -2,6 +2,7 @@ import { type Api, completeSimple, type Model } from "@mariozechner/pi-ai"; import { Type } from "@sinclair/typebox"; import { describe, expect, it } from "vitest"; import { loadConfig } from "../config/config.js"; +import { parseLiveCsvFilter } from "../media-generation/live-test-helpers.js"; import { resolveOpenClawAgentDir } from "./agent-paths.js"; import { collectAnthropicApiKeys, @@ -30,15 +31,7 @@ const LIVE_SETUP_TIMEOUT_MS = Math.max( const describeLive = LIVE ? describe : describe.skip; function parseCsvFilter(raw?: string): Set | null { - const trimmed = raw?.trim(); - if (!trimmed || trimmed === "all") { - return null; - } - const ids = trimmed - .split(",") - .map((s) => s.trim()) - .filter(Boolean); - return ids.length ? new Set(ids) : null; + return parseLiveCsvFilter(raw, { lowercase: false }); } function parseProviderFilter(raw?: string): Set | null { diff --git a/src/agents/pi-bundle-mcp-test-harness.ts b/src/agents/pi-bundle-mcp-test-harness.ts index a5b5de9274b..2d4ad3a7d97 100644 --- a/src/agents/pi-bundle-mcp-test-harness.ts +++ b/src/agents/pi-bundle-mcp-test-harness.ts @@ -3,12 +3,16 @@ import http from "node:http"; import { createRequire } from "node:module"; import os from "node:os"; import path from "node:path"; +import { + writeBundleProbeMcpServer, + writeClaudeBundle, + writeExecutable, +} from "./bundle-mcp-shared.test-harness.js"; import { __testing } from "./pi-bundle-mcp-tools.js"; const require = createRequire(import.meta.url); const SDK_SERVER_MCP_PATH = require.resolve("@modelcontextprotocol/sdk/server/mcp.js"); const SDK_SERVER_SSE_PATH = require.resolve("@modelcontextprotocol/sdk/server/sse.js"); -const SDK_SERVER_STDIO_PATH = require.resolve("@modelcontextprotocol/sdk/server/stdio.js"); const tempDirs: string[] = []; @@ -25,65 +29,7 @@ export async function makeTempDir(prefix: string): Promise { return dir; } -export async function writeExecutable(filePath: string, content: string): Promise { - await fs.mkdir(path.dirname(filePath), { recursive: true }); - await fs.writeFile(filePath, content, { encoding: "utf-8", mode: 0o755 }); -} - -export async function writeBundleProbeMcpServer( - filePath: string, - params: { - startupCounterPath?: string; - startupDelayMs?: number; - pidPath?: string; - exitMarkerPath?: string; - } = {}, -): Promise { - await writeExecutable( - filePath, - `#!/usr/bin/env node -import fs from "node:fs"; -import fsp from "node:fs/promises"; -import { setTimeout as delay } from "node:timers/promises"; -import { McpServer } from ${JSON.stringify(SDK_SERVER_MCP_PATH)}; -import { StdioServerTransport } from ${JSON.stringify(SDK_SERVER_STDIO_PATH)}; - -const startupCounterPath = ${JSON.stringify(params.startupCounterPath ?? "")}; -if (startupCounterPath) { - let current = 0; - try { - current = Number.parseInt((await fsp.readFile(startupCounterPath, "utf8")).trim(), 10) || 0; - } catch {} - await fsp.writeFile(startupCounterPath, String(current + 1), "utf8"); -} -const pidPath = ${JSON.stringify(params.pidPath ?? "")}; -if (pidPath) { - await fsp.writeFile(pidPath, String(process.pid), "utf8"); -} -const exitMarkerPath = ${JSON.stringify(params.exitMarkerPath ?? "")}; -if (exitMarkerPath) { - process.once("exit", () => { - try { - fs.writeFileSync(exitMarkerPath, "exited", "utf8"); - } catch {} - }); -} -const startupDelayMs = ${JSON.stringify(params.startupDelayMs ?? 0)}; -if (startupDelayMs > 0) { - await delay(startupDelayMs); -} - -const server = new McpServer({ name: "bundle-probe", version: "1.0.0" }); -server.tool("bundle_probe", "Bundle MCP probe", async () => { - return { - content: [{ type: "text", text: process.env.BUNDLE_PROBE_TEXT ?? "missing-probe-text" }], - }; -}); - -await server.connect(new StdioServerTransport()); -`, - ); -} +export { writeBundleProbeMcpServer, writeClaudeBundle, writeExecutable }; export async function waitForFileText(filePath: string, timeoutMs = 5_000): Promise { const start = Date.now(); @@ -97,37 +43,6 @@ export async function waitForFileText(filePath: string, timeoutMs = 5_000): Prom throw new Error(`Timed out waiting for ${filePath}`); } -export async function writeClaudeBundle(params: { - pluginRoot: string; - serverScriptPath: string; -}): Promise { - await fs.mkdir(path.join(params.pluginRoot, ".claude-plugin"), { recursive: true }); - await fs.writeFile( - path.join(params.pluginRoot, ".claude-plugin", "plugin.json"), - `${JSON.stringify({ name: "bundle-probe" }, null, 2)}\n`, - "utf-8", - ); - await fs.writeFile( - path.join(params.pluginRoot, ".mcp.json"), - `${JSON.stringify( - { - mcpServers: { - bundleProbe: { - command: "node", - args: [path.relative(params.pluginRoot, params.serverScriptPath)], - env: { - BUNDLE_PROBE_TEXT: "FROM-BUNDLE", - }, - }, - }, - }, - null, - 2, - )}\n`, - "utf-8", - ); -} - export async function startSseProbeServer( probeText = "FROM-SSE", ): Promise<{ port: number; close: () => Promise }> { diff --git a/src/agents/pi-embedded-runner/google-prompt-cache.test.ts b/src/agents/pi-embedded-runner/google-prompt-cache.test.ts index aaaaafd07f0..0fd0644404a 100644 --- a/src/agents/pi-embedded-runner/google-prompt-cache.test.ts +++ b/src/agents/pi-embedded-runner/google-prompt-cache.test.ts @@ -51,50 +51,76 @@ function makeGoogleModel(id = "gemini-3.1-pro-preview") { } satisfies Model<"google-generative-ai">; } +function createCacheFetchMock(params: { name: string; expireTime: string }) { + return vi.fn().mockResolvedValue( + new Response(JSON.stringify(params), { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); +} + +function createCapturingStreamFn(result = "stream") { + let capturedPayload: Record | undefined; + const streamFn = vi.fn( + ( + model: Parameters[0], + _context: Parameters[1], + options: Parameters[2], + ) => { + const payload: Record = {}; + void options?.onPayload?.(payload, model); + capturedPayload = payload; + return result as never; + }, + ); + return { + streamFn, + getCapturedPayload: () => capturedPayload, + }; +} + +function preparePromptCacheStream(params: { + fetchMock: ReturnType; + now: number; + sessionManager: ReturnType; + streamFn: StreamFn; +}) { + return prepareGooglePromptCacheStreamFn( + { + apiKey: "gemini-api-key", + extraParams: { cacheRetention: "long" }, + model: makeGoogleModel(), + modelId: "gemini-3.1-pro-preview", + provider: "google", + sessionManager: params.sessionManager, + streamFn: params.streamFn, + systemPrompt: "Follow policy.", + }, + { + buildGuardedFetch: () => params.fetchMock as typeof fetch, + now: () => params.now, + }, + ); +} + describe("google prompt cache", () => { it("creates cached content from the system prompt and strips that prompt from live requests", async () => { const now = 1_000_000; const entries: SessionCustomEntry[] = []; const sessionManager = makeSessionManager(entries); - const fetchMock = vi.fn().mockResolvedValue( - new Response( - JSON.stringify({ - name: "cachedContents/system-cache-1", - expireTime: new Date(now + 3_600_000).toISOString(), - }), - { status: 200, headers: { "content-type": "application/json" } }, - ), - ); - let capturedPayload: Record | undefined; - const innerStreamFn = vi.fn( - ( - model: Parameters[0], - _context: Parameters[1], - options: Parameters[2], - ) => { - const payload: Record = {}; - void options?.onPayload?.(payload, model); - capturedPayload = payload; - return "stream" as never; - }, - ); + const fetchMock = createCacheFetchMock({ + name: "cachedContents/system-cache-1", + expireTime: new Date(now + 3_600_000).toISOString(), + }); + const { streamFn: innerStreamFn, getCapturedPayload } = createCapturingStreamFn(); - const wrapped = await prepareGooglePromptCacheStreamFn( - { - apiKey: "gemini-api-key", - extraParams: { cacheRetention: "long" }, - model: makeGoogleModel(), - modelId: "gemini-3.1-pro-preview", - provider: "google", - sessionManager, - streamFn: innerStreamFn, - systemPrompt: "Follow policy.", - }, - { - buildGuardedFetch: () => fetchMock as typeof fetch, - now: () => now, - }, - ); + const wrapped = await preparePromptCacheStream({ + fetchMock, + now, + sessionManager, + streamFn: innerStreamFn, + }); expect(wrapped).toBeTypeOf("function"); void wrapped?.( @@ -143,7 +169,7 @@ describe("google prompt cache", () => { }), expect.objectContaining({ temperature: 0.2 }), ); - expect(capturedPayload).toMatchObject({ + expect(getCapturedPayload()).toMatchObject({ cachedContent: "cachedContents/system-cache-1", }); expect(entries).toHaveLength(1); @@ -155,63 +181,26 @@ describe("google prompt cache", () => { const now = 2_000_000; const entries: SessionCustomEntry[] = []; const sessionManager = makeSessionManager(entries); - const fetchMock = vi.fn().mockResolvedValue( - new Response( - JSON.stringify({ - name: "cachedContents/system-cache-2", - expireTime: new Date(now + 3_600_000).toISOString(), - }), - { status: 200, headers: { "content-type": "application/json" } }, - ), - ); + const fetchMock = createCacheFetchMock({ + name: "cachedContents/system-cache-2", + expireTime: new Date(now + 3_600_000).toISOString(), + }); - await prepareGooglePromptCacheStreamFn( - { - apiKey: "gemini-api-key", - extraParams: { cacheRetention: "long" }, - model: makeGoogleModel(), - modelId: "gemini-3.1-pro-preview", - provider: "google", - sessionManager, - streamFn: vi.fn(() => "first" as never), - systemPrompt: "Follow policy.", - }, - { - buildGuardedFetch: () => fetchMock as typeof fetch, - now: () => now, - }, - ); + await preparePromptCacheStream({ + fetchMock, + now, + sessionManager, + streamFn: vi.fn(() => "first" as never), + }); fetchMock.mockClear(); - let capturedPayload: Record | undefined; - const innerStreamFn = vi.fn( - ( - model: Parameters[0], - _context: Parameters[1], - options: Parameters[2], - ) => { - const payload: Record = {}; - void options?.onPayload?.(payload, model); - capturedPayload = payload; - return "second" as never; - }, - ); - const wrapped = await prepareGooglePromptCacheStreamFn( - { - apiKey: "gemini-api-key", - extraParams: { cacheRetention: "long" }, - model: makeGoogleModel(), - modelId: "gemini-3.1-pro-preview", - provider: "google", - sessionManager, - streamFn: innerStreamFn, - systemPrompt: "Follow policy.", - }, - { - buildGuardedFetch: () => fetchMock as typeof fetch, - now: () => now + 30_000, - }, - ); + const { streamFn: innerStreamFn, getCapturedPayload } = createCapturingStreamFn("second"); + const wrapped = await preparePromptCacheStream({ + fetchMock, + now: now + 30_000, + sessionManager, + streamFn: innerStreamFn, + }); void wrapped?.( makeGoogleModel(), @@ -225,7 +214,7 @@ describe("google prompt cache", () => { expect.objectContaining({ systemPrompt: undefined }), expect.any(Object), ); - expect(capturedPayload).toMatchObject({ + expect(getCapturedPayload()).toMatchObject({ cachedContent: "cachedContents/system-cache-2", }); }); @@ -255,45 +244,18 @@ describe("google prompt cache", () => { }, }, ]); - const fetchMock = vi.fn().mockResolvedValue( - new Response( - JSON.stringify({ - name: "cachedContents/system-cache-3", - expireTime: new Date(now + 3_600_000).toISOString(), - }), - { status: 200, headers: { "content-type": "application/json" } }, - ), - ); - let capturedPayload: Record | undefined; - const innerStreamFn = vi.fn( - ( - model: Parameters[0], - _context: Parameters[1], - options: Parameters[2], - ) => { - const payload: Record = {}; - void options?.onPayload?.(payload, model); - capturedPayload = payload; - return "stream" as never; - }, - ); + const fetchMock = createCacheFetchMock({ + name: "cachedContents/system-cache-3", + expireTime: new Date(now + 3_600_000).toISOString(), + }); + const { streamFn: innerStreamFn, getCapturedPayload } = createCapturingStreamFn(); - const wrapped = await prepareGooglePromptCacheStreamFn( - { - apiKey: "gemini-api-key", - extraParams: { cacheRetention: "long" }, - model: makeGoogleModel(), - modelId: "gemini-3.1-pro-preview", - provider: "google", - sessionManager, - streamFn: innerStreamFn, - systemPrompt: "Follow policy.", - }, - { - buildGuardedFetch: () => fetchMock as typeof fetch, - now: () => now, - }, - ); + const wrapped = await preparePromptCacheStream({ + fetchMock, + now, + sessionManager, + streamFn: innerStreamFn, + }); void wrapped?.( makeGoogleModel(), @@ -311,7 +273,7 @@ describe("google prompt cache", () => { expect.objectContaining({ systemPrompt: undefined }), expect.any(Object), ); - expect(capturedPayload).toMatchObject({ + expect(getCapturedPayload()).toMatchObject({ cachedContent: "cachedContents/system-cache-3", }); }); diff --git a/src/agents/pi-embedded-subscribe.handlers.messages.test.ts b/src/agents/pi-embedded-subscribe.handlers.messages.test.ts index 17e6ff6bd45..137b41b232a 100644 --- a/src/agents/pi-embedded-subscribe.handlers.messages.test.ts +++ b/src/agents/pi-embedded-subscribe.handlers.messages.test.ts @@ -10,6 +10,109 @@ import { resolveSilentReplyFallbackText, } from "./pi-embedded-subscribe.handlers.messages.js"; import type { EmbeddedPiSubscribeContext } from "./pi-embedded-subscribe.handlers.types.js"; +import { + createOpenAiResponsesPartial, + createOpenAiResponsesTextBlock, + createOpenAiResponsesTextEvent as createTextUpdateEvent, +} from "./pi-embedded-subscribe.openai-responses.test-helpers.js"; + +function createMessageUpdateContext( + params: { + onAgentEvent?: ReturnType; + onPartialReply?: ReturnType; + flushBlockReplyBuffer?: ReturnType; + debug?: ReturnType; + shouldEmitPartialReplies?: boolean; + } = {}, +) { + return { + params: { + runId: "run-1", + session: { id: "session-1" }, + ...(params.onAgentEvent ? { onAgentEvent: params.onAgentEvent } : {}), + ...(params.onPartialReply ? { onPartialReply: params.onPartialReply } : {}), + }, + state: { + deterministicApprovalPromptPending: false, + deterministicApprovalPromptSent: false, + reasoningStreamOpen: false, + streamReasoning: false, + deltaBuffer: "", + blockBuffer: "", + partialBlockState: { + thinking: false, + final: false, + inlineCode: createInlineCodeState(), + }, + lastStreamedAssistant: undefined, + lastStreamedAssistantCleaned: undefined, + emittedAssistantUpdate: false, + shouldEmitPartialReplies: params.shouldEmitPartialReplies ?? true, + blockReplyBreak: "text_end", + assistantMessageIndex: 0, + }, + log: { debug: params.debug ?? vi.fn() }, + noteLastAssistant: vi.fn(), + stripBlockTags: (text: string) => text, + consumePartialReplyDirectives: vi.fn(() => null), + emitReasoningStream: vi.fn(), + flushBlockReplyBuffer: params.flushBlockReplyBuffer ?? vi.fn(), + } as unknown as EmbeddedPiSubscribeContext; +} + +function createMessageEndContext( + params: { + onAgentEvent?: ReturnType; + onBlockReply?: ReturnType; + emitBlockReply?: ReturnType; + finalizeAssistantTexts?: ReturnType; + consumeReplyDirectives?: ReturnType; + state?: Record; + } = {}, +) { + return { + params: { + runId: "run-1", + session: { id: "session-1" }, + ...(params.onAgentEvent ? { onAgentEvent: params.onAgentEvent } : {}), + ...(params.onBlockReply ? { onBlockReply: params.onBlockReply } : { onBlockReply: vi.fn() }), + }, + state: { + assistantTexts: [], + assistantTextBaseline: 0, + emittedAssistantUpdate: false, + deterministicApprovalPromptPending: false, + deterministicApprovalPromptSent: false, + messagingToolSentTexts: [], + messagingToolSentTextsNormalized: [], + includeReasoning: false, + streamReasoning: false, + blockReplyBreak: "message_end", + deltaBuffer: "Need send.", + blockBuffer: "Need send.", + blockState: { + thinking: false, + final: false, + inlineCode: createInlineCodeState(), + }, + lastStreamedAssistant: undefined, + lastStreamedAssistantCleaned: undefined, + lastReasoningSent: undefined, + reasoningStreamOpen: false, + ...params.state, + }, + noteLastAssistant: vi.fn(), + recordAssistantUsage: vi.fn(), + log: { debug: vi.fn(), warn: vi.fn() }, + stripBlockTags: (text: string) => text, + finalizeAssistantTexts: params.finalizeAssistantTexts ?? vi.fn(), + emitBlockReply: params.emitBlockReply ?? vi.fn(), + consumeReplyDirectives: params.consumeReplyDirectives ?? vi.fn(() => ({ text: "Need send." })), + emitReasoningStream: vi.fn(), + flushBlockReplyBuffer: vi.fn(), + blockChunker: null, + } as unknown as EmbeddedPiSubscribeContext; +} describe("resolveSilentReplyFallbackText", () => { it("replaces NO_REPLY with latest messaging tool text when available", () => { @@ -145,48 +248,20 @@ describe("handleMessageUpdate", () => { const onAgentEvent = vi.fn(); const onPartialReply = vi.fn(); const flushBlockReplyBuffer = vi.fn(); - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - onAgentEvent, - onPartialReply, - }, - state: { - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - reasoningStreamOpen: false, - streamReasoning: false, - deltaBuffer: "", - blockBuffer: "", - partialBlockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, - lastStreamedAssistantCleaned: undefined, - emittedAssistantUpdate: false, - shouldEmitPartialReplies: true, - blockReplyBreak: "text_end", - assistantMessageIndex: 0, - }, - log: { debug: vi.fn() }, - noteLastAssistant: vi.fn(), - stripBlockTags: (text: string) => text, - consumePartialReplyDirectives: vi.fn(() => null), + const ctx = createMessageUpdateContext({ + onAgentEvent, + onPartialReply, flushBlockReplyBuffer, - } as unknown as EmbeddedPiSubscribeContext; + }); - handleMessageUpdate(ctx, { - type: "message_update", - message: { role: "assistant", phase: "commentary", content: [] }, - assistantMessageEvent: { type: "text_delta", delta: "Need send." }, - } as never); - handleMessageUpdate(ctx, { - type: "message_update", - message: { role: "assistant", phase: "commentary", content: [] }, - assistantMessageEvent: { type: "text_end", content: "Need send." }, - } as never); + handleMessageUpdate( + ctx, + createTextUpdateEvent({ type: "text_delta", text: "Need send.", messagePhase: "commentary" }), + ); + handleMessageUpdate( + ctx, + createTextUpdateEvent({ type: "text_end", text: "Need send.", messagePhase: "commentary" }), + ); await Promise.resolve(); @@ -199,53 +274,33 @@ describe("handleMessageUpdate", () => { const onAgentEvent = vi.fn(); const onPartialReply = vi.fn(); const flushBlockReplyBuffer = vi.fn(); - const commentaryBlock = { - type: "text", + const commentaryBlock = createOpenAiResponsesTextBlock({ text: "Need send.", - textSignature: JSON.stringify({ v: 1, id: "msg_sig", phase: "commentary" }), - }; - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - onAgentEvent, - onPartialReply, - }, - state: { - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - reasoningStreamOpen: false, - streamReasoning: false, - deltaBuffer: "", - blockBuffer: "", - partialBlockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, - lastStreamedAssistantCleaned: undefined, - emittedAssistantUpdate: false, - shouldEmitPartialReplies: true, - blockReplyBreak: "text_end", - assistantMessageIndex: 0, - }, - log: { debug: vi.fn() }, - noteLastAssistant: vi.fn(), - stripBlockTags: (text: string) => text, - consumePartialReplyDirectives: vi.fn(() => null), + id: "msg_sig", + phase: "commentary", + }); + const ctx = createMessageUpdateContext({ + onAgentEvent, + onPartialReply, flushBlockReplyBuffer, - } as unknown as EmbeddedPiSubscribeContext; + }); - handleMessageUpdate(ctx, { - type: "message_update", - message: { role: "assistant", content: [commentaryBlock] }, - assistantMessageEvent: { type: "text_delta", delta: "Need send." }, - } as never); - handleMessageUpdate(ctx, { - type: "message_update", - message: { role: "assistant", content: [commentaryBlock] }, - assistantMessageEvent: { type: "text_end", content: "Need send." }, - } as never); + handleMessageUpdate( + ctx, + createTextUpdateEvent({ + type: "text_delta", + text: "Need send.", + content: [commentaryBlock], + }), + ); + handleMessageUpdate( + ctx, + createTextUpdateEvent({ + type: "text_end", + text: "Need send.", + content: [commentaryBlock], + }), + ); await Promise.resolve(); @@ -258,93 +313,42 @@ describe("handleMessageUpdate", () => { it("suppresses commentary partials even when they contain visible text", () => { const onAgentEvent = vi.fn(); - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - onAgentEvent, - }, - state: { - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - reasoningStreamOpen: false, - streamReasoning: false, - deltaBuffer: "", - blockBuffer: "", - partialBlockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, - lastStreamedAssistant: undefined, - lastStreamedAssistantCleaned: undefined, - emittedAssistantUpdate: false, - shouldEmitPartialReplies: false, - blockReplyBreak: "text_end", - }, - log: { debug: vi.fn() }, - noteLastAssistant: vi.fn(), - stripBlockTags: (text: string) => text, - consumePartialReplyDirectives: vi.fn(() => null), - emitReasoningStream: vi.fn(), - flushBlockReplyBuffer: vi.fn(), - } as unknown as EmbeddedPiSubscribeContext; + const ctx = createMessageUpdateContext({ + onAgentEvent, + shouldEmitPartialReplies: false, + }); - handleMessageUpdate(ctx, { - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { + handleMessageUpdate( + ctx, + createTextUpdateEvent({ type: "text_delta", - delta: "Working...", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Working...", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - ], - phase: "commentary", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, - } as never); + text: "Working...", + partial: createOpenAiResponsesPartial({ + text: "Working...", + id: "item_commentary", + signaturePhase: "commentary", + partialPhase: "commentary", + }), + }), + ); expect(onAgentEvent).not.toHaveBeenCalled(); expect(ctx.state.deltaBuffer).toBe(""); expect(ctx.state.blockBuffer).toBe(""); - handleMessageUpdate(ctx, { - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { + handleMessageUpdate( + ctx, + createTextUpdateEvent({ type: "text_delta", - delta: "Done.", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Done.", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - phase: "final_answer", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, - } as never); + text: "Done.", + partial: createOpenAiResponsesPartial({ + text: "Done.", + id: "item_final", + signaturePhase: "final_answer", + partialPhase: "final_answer", + }), + }), + ); expect(onAgentEvent).toHaveBeenCalledTimes(1); expect(onAgentEvent.mock.calls[0]?.[0]).toMatchObject({ @@ -358,42 +362,15 @@ describe("handleMessageUpdate", () => { it("contains synchronous text_end flush failures", async () => { const debug = vi.fn(); - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - }, - state: { - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - reasoningStreamOpen: false, - streamReasoning: false, - deltaBuffer: "", - blockBuffer: "", - partialBlockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, - lastStreamedAssistantCleaned: undefined, - emittedAssistantUpdate: false, - shouldEmitPartialReplies: false, - blockReplyBreak: "text_end", - }, - log: { debug }, - noteLastAssistant: vi.fn(), - stripBlockTags: (text: string) => text, - consumePartialReplyDirectives: vi.fn(() => null), + const ctx = createMessageUpdateContext({ + debug, + shouldEmitPartialReplies: false, flushBlockReplyBuffer: vi.fn(() => { throw new Error("boom"); }), - } as unknown as EmbeddedPiSubscribeContext; + }); - handleMessageUpdate(ctx, { - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { type: "text_end" }, - } as never); + handleMessageUpdate(ctx, createTextUpdateEvent({ type: "text_end", text: "" })); await vi.waitFor(() => { expect(debug).toHaveBeenCalledWith("text_end block reply flush failed: Error: boom"); @@ -406,44 +383,11 @@ describe("handleMessageEnd", () => { const onAgentEvent = vi.fn(); const emitBlockReply = vi.fn(); const finalizeAssistantTexts = vi.fn(); - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - onAgentEvent, - onBlockReply: vi.fn(), - }, - state: { - assistantTexts: [], - assistantTextBaseline: 0, - emittedAssistantUpdate: false, - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - reasoningStreamOpen: false, - includeReasoning: false, - streamReasoning: false, - blockReplyBreak: "message_end", - deltaBuffer: "Need send.", - blockBuffer: "Need send.", - blockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, - lastStreamedAssistant: undefined, - lastStreamedAssistantCleaned: undefined, - }, - noteLastAssistant: vi.fn(), - recordAssistantUsage: vi.fn(), - log: { debug: vi.fn(), warn: vi.fn() }, - stripBlockTags: (text: string) => text, + const ctx = createMessageEndContext({ + onAgentEvent, finalizeAssistantTexts, emitBlockReply, - consumeReplyDirectives: vi.fn(() => ({ text: "Need send." })), - emitReasoningStream: vi.fn(), - flushBlockReplyBuffer: vi.fn(), - blockChunker: null, - } as unknown as EmbeddedPiSubscribeContext; + }); void handleMessageEnd(ctx, { type: "message_end", @@ -464,55 +408,22 @@ describe("handleMessageEnd", () => { const onAgentEvent = vi.fn(); const emitBlockReply = vi.fn(); const finalizeAssistantTexts = vi.fn(); - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - onAgentEvent, - onBlockReply: vi.fn(), - }, - state: { - assistantTexts: [], - assistantTextBaseline: 0, - emittedAssistantUpdate: false, - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - reasoningStreamOpen: false, - includeReasoning: false, - streamReasoning: false, - blockReplyBreak: "message_end", - deltaBuffer: "Need send.", - blockBuffer: "Need send.", - blockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, - lastStreamedAssistant: undefined, - lastStreamedAssistantCleaned: undefined, - }, - noteLastAssistant: vi.fn(), - recordAssistantUsage: vi.fn(), - log: { debug: vi.fn(), warn: vi.fn() }, - stripBlockTags: (text: string) => text, + const ctx = createMessageEndContext({ + onAgentEvent, finalizeAssistantTexts, emitBlockReply, - consumeReplyDirectives: vi.fn(() => ({ text: "Need send." })), - emitReasoningStream: vi.fn(), - flushBlockReplyBuffer: vi.fn(), - blockChunker: null, - } as unknown as EmbeddedPiSubscribeContext; + }); void handleMessageEnd(ctx, { type: "message_end", message: { role: "assistant", content: [ - { - type: "text", + createOpenAiResponsesTextBlock({ text: "Need send.", - textSignature: JSON.stringify({ v: 1, id: "msg_sig", phase: "commentary" }), - }, + id: "msg_sig", + phase: "commentary", + }), ], usage: { input: 1, output: 1, total: 2 }, }, @@ -530,47 +441,20 @@ describe("handleMessageEnd", () => { // input. The non-empty call shouldn't happen for text_end channels (that's // the safety send we're guarding against). const consumeReplyDirectives = vi.fn((text: string) => (text ? { text } : null)); - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - onBlockReply, - }, + const ctx = createMessageEndContext({ + onBlockReply, + emitBlockReply, + consumeReplyDirectives, state: { - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - messagingToolSentTexts: [], - messagingToolSentTextsNormalized: [], - includeReasoning: false, - streamReasoning: false, emittedAssistantUpdate: true, lastStreamedAssistantCleaned: "Hello world", - assistantTexts: [], - assistantTextBaseline: 0, blockReplyBreak: "text_end", // Simulate text_end already delivered this text through emitBlockChunk lastBlockReplyText: "Hello world", - lastReasoningSent: undefined, - reasoningStreamOpen: false, deltaBuffer: "", blockBuffer: "", - blockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, }, - log: { debug: vi.fn() }, - noteLastAssistant: vi.fn(), - recordAssistantUsage: vi.fn(), - stripBlockTags: (text: string) => text, - finalizeAssistantTexts: vi.fn(), - emitBlockReply, - consumeReplyDirectives, - emitReasoningStream: vi.fn(), - flushBlockReplyBuffer: vi.fn(), - blockChunker: null, - } as unknown as EmbeddedPiSubscribeContext; + }); void handleMessageEnd(ctx, { type: "message_end", @@ -592,47 +476,20 @@ describe("handleMessageEnd", () => { const emitBlockReply = vi.fn(); // Same pattern: directive accumulator returns null for empty final flush const consumeReplyDirectives = vi.fn((text: string) => (text ? { text } : null)); - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - onBlockReply, - }, + const ctx = createMessageEndContext({ + onBlockReply, + emitBlockReply, + consumeReplyDirectives, state: { - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - messagingToolSentTexts: [], - messagingToolSentTextsNormalized: [], - includeReasoning: false, - streamReasoning: false, emittedAssistantUpdate: true, lastStreamedAssistantCleaned: "Hello world", - assistantTexts: [], - assistantTextBaseline: 0, blockReplyBreak: "text_end", // text_end delivered via emitBlockChunk which uses different stripping lastBlockReplyText: "Hello world.", - lastReasoningSent: undefined, - reasoningStreamOpen: false, deltaBuffer: "", blockBuffer: "", - blockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, }, - log: { debug: vi.fn() }, - noteLastAssistant: vi.fn(), - recordAssistantUsage: vi.fn(), - stripBlockTags: (text: string) => text, - finalizeAssistantTexts: vi.fn(), - emitBlockReply, - consumeReplyDirectives, - emitReasoningStream: vi.fn(), - flushBlockReplyBuffer: vi.fn(), - blockChunker: null, - } as unknown as EmbeddedPiSubscribeContext; + }); void handleMessageEnd(ctx, { type: "message_end", @@ -652,58 +509,32 @@ describe("handleMessageEnd", () => { it("emits a replacement final assistant event when final_answer appears only at message_end", () => { const onAgentEvent = vi.fn(); - const ctx = { - params: { - runId: "run-1", - session: { id: "session-1" }, - onAgentEvent, - }, + const ctx = createMessageEndContext({ + onAgentEvent, state: { - deterministicApprovalPromptPending: false, - deterministicApprovalPromptSent: false, - messagingToolSentTexts: [], - messagingToolSentTextsNormalized: [], - includeReasoning: false, - streamReasoning: false, emittedAssistantUpdate: true, lastStreamedAssistantCleaned: "Working...", - assistantTexts: [], - assistantTextBaseline: 0, blockReplyBreak: "text_end", - lastReasoningSent: undefined, - reasoningStreamOpen: false, deltaBuffer: "", blockBuffer: "", - blockState: { - thinking: false, - final: false, - inlineCode: createInlineCodeState(), - }, }, - log: { debug: vi.fn() }, - noteLastAssistant: vi.fn(), - recordAssistantUsage: vi.fn(), - stripBlockTags: (text: string) => text, - finalizeAssistantTexts: vi.fn(), - emitReasoningStream: vi.fn(), - blockChunker: null, - } as unknown as EmbeddedPiSubscribeContext; + }); void handleMessageEnd(ctx, { type: "message_end", message: { role: "assistant", content: [ - { - type: "text", + createOpenAiResponsesTextBlock({ text: "Working...", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - { - type: "text", + id: "item_commentary", + phase: "commentary", + }), + createOpenAiResponsesTextBlock({ text: "Done.", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, + id: "item_final", + phase: "final_answer", + }), ], stopReason: "stop", api: "openai-responses", diff --git a/src/agents/pi-embedded-subscribe.openai-responses.test-helpers.ts b/src/agents/pi-embedded-subscribe.openai-responses.test-helpers.ts new file mode 100644 index 00000000000..eea7aa876e7 --- /dev/null +++ b/src/agents/pi-embedded-subscribe.openai-responses.test-helpers.ts @@ -0,0 +1,81 @@ +export type OpenAiResponsesTextEventPhase = "commentary" | "final_answer"; + +export function createOpenAiResponsesTextBlock(params: { + text: string; + id: string; + phase?: OpenAiResponsesTextEventPhase; +}) { + return { + type: "text", + text: params.text, + textSignature: JSON.stringify({ + v: 1, + id: params.id, + ...(params.phase ? { phase: params.phase } : {}), + }), + }; +} + +export function createOpenAiResponsesPartial(params: { + text: string; + id: string; + signaturePhase?: OpenAiResponsesTextEventPhase; + partialPhase?: OpenAiResponsesTextEventPhase; +}) { + return { + role: "assistant", + content: [ + createOpenAiResponsesTextBlock({ + text: params.text, + id: params.id, + phase: params.signaturePhase, + }), + ], + ...(params.partialPhase ? { phase: params.partialPhase } : {}), + stopReason: "stop", + api: "openai-responses", + provider: "openai", + model: "gpt-5.2", + usage: {}, + timestamp: 0, + }; +} + +export function createOpenAiResponsesTextEvent(params: { + type: "text_delta" | "text_end"; + text: string; + delta?: string; + id?: string; + signaturePhase?: OpenAiResponsesTextEventPhase; + partialPhase?: OpenAiResponsesTextEventPhase; + messagePhase?: OpenAiResponsesTextEventPhase; + content?: unknown[]; + partial?: ReturnType; +}) { + const partial = + params.partial ?? + (params.id + ? createOpenAiResponsesPartial({ + text: params.text, + id: params.id, + signaturePhase: params.signaturePhase, + partialPhase: params.partialPhase, + }) + : undefined); + + return { + type: "message_update", + message: { + role: "assistant", + ...(params.messagePhase ? { phase: params.messagePhase } : {}), + content: params.content ?? [], + }, + assistantMessageEvent: { + type: params.type, + ...(params.type === "text_delta" + ? { delta: params.delta ?? params.text } + : { content: params.text }), + ...(partial ? { partial } : {}), + }, + } as never; +} diff --git a/src/agents/pi-embedded-subscribe.subscribe-embedded-pi-session.emits-block-replies-text-end-does-not.test.ts b/src/agents/pi-embedded-subscribe.subscribe-embedded-pi-session.emits-block-replies-text-end-does-not.test.ts index f2a8fc643b7..a254d4e6a76 100644 --- a/src/agents/pi-embedded-subscribe.subscribe-embedded-pi-session.emits-block-replies-text-end-does-not.test.ts +++ b/src/agents/pi-embedded-subscribe.subscribe-embedded-pi-session.emits-block-replies-text-end-does-not.test.ts @@ -5,6 +5,74 @@ import { emitAssistantTextDelta, emitAssistantTextEnd, } from "./pi-embedded-subscribe.e2e-harness.js"; +import { + createOpenAiResponsesTextBlock, + createOpenAiResponsesTextEvent, + type OpenAiResponsesTextEventPhase, +} from "./pi-embedded-subscribe.openai-responses.test-helpers.js"; + +type TextEndBlockReplyHarness = ReturnType; + +function emitOpenAiResponsesTextEvent(params: { + emit: TextEndBlockReplyHarness["emit"]; + type: "text_delta" | "text_end"; + text: string; + delta?: string; + id: string; + signaturePhase?: OpenAiResponsesTextEventPhase; + partialPhase?: OpenAiResponsesTextEventPhase; +}) { + const { emit, ...eventParams } = params; + emit(createOpenAiResponsesTextEvent(eventParams)); +} + +function emitOpenAiResponsesTextDeltaAndEnd(params: { + emit: TextEndBlockReplyHarness["emit"]; + text: string; + delta?: string; + id: string; + phase?: OpenAiResponsesTextEventPhase; +}) { + const { phase, ...eventParams } = params; + emitOpenAiResponsesTextEvent({ + ...eventParams, + type: "text_delta", + signaturePhase: phase, + partialPhase: phase, + }); + emitOpenAiResponsesTextEvent({ + ...eventParams, + type: "text_end", + delta: undefined, + signaturePhase: phase, + partialPhase: phase, + }); +} + +function emitOpenAiResponsesFinalMessageEnd(params: { + emit: TextEndBlockReplyHarness["emit"]; + commentaryText: string; + finalText: string; +}) { + params.emit({ + type: "message_end", + message: { + role: "assistant", + content: [ + createOpenAiResponsesTextBlock({ + text: params.commentaryText, + id: "item_commentary", + phase: "commentary", + }), + createOpenAiResponsesTextBlock({ + text: params.finalText, + id: "item_final", + phase: "final_answer", + }), + ], + } as AssistantMessage, + }); +} describe("subscribeEmbeddedPiSession", () => { it("emits block replies on text_end and does not duplicate on message_end", async () => { @@ -65,53 +133,17 @@ describe("subscribeEmbeddedPiSession", () => { const { emit, subscription } = createTextEndBlockReplyHarness({ onBlockReply }); emit({ type: "message_start", message: { role: "assistant" } }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_delta", - delta: "Legacy answer", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Legacy answer", - textSignature: JSON.stringify({ v: 1, id: "item_legacy" }), - }, - ], - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextEvent({ + emit, + type: "text_delta", + text: "Legacy answer", + id: "item_legacy", }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_end", - content: "Legacy answer", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Legacy answer", - textSignature: JSON.stringify({ v: 1, id: "item_legacy" }), - }, - ], - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextEvent({ + emit, + type: "text_end", + text: "Legacy answer", + id: "item_legacy", }); await Promise.resolve(); @@ -136,131 +168,26 @@ describe("subscribeEmbeddedPiSession", () => { const { emit, subscription } = createTextEndBlockReplyHarness({ onBlockReply }); emit({ type: "message_start", message: { role: "assistant" } }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_delta", - delta: "Working...", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Working...", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - ], - phase: "commentary", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, - }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_end", - content: "Working...", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Working...", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - ], - phase: "commentary", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextDeltaAndEnd({ + emit, + text: "Working...", + id: "item_commentary", + phase: "commentary", }); await Promise.resolve(); expect(onBlockReply).not.toHaveBeenCalled(); expect(subscription.assistantTexts).toEqual([]); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_delta", - delta: "Done.", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Done.", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - phase: "final_answer", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, - }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_end", - content: "Done.", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Done.", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - phase: "final_answer", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextDeltaAndEnd({ + emit, + text: "Done.", + id: "item_final", + phase: "final_answer", }); await Promise.resolve(); - emit({ - type: "message_end", - message: { - role: "assistant", - content: [ - { - type: "text", - text: "Working...", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - { - type: "text", - text: "Done.", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - } as AssistantMessage, - }); + emitOpenAiResponsesFinalMessageEnd({ emit, commentaryText: "Working...", finalText: "Done." }); expect(onBlockReply).toHaveBeenCalledTimes(1); expect(onBlockReply.mock.calls[0]?.[0]?.text).toBe("Done."); @@ -272,109 +199,22 @@ describe("subscribeEmbeddedPiSession", () => { const { emit, subscription } = createTextEndBlockReplyHarness({ onBlockReply }); emit({ type: "message_start", message: { role: "assistant" } }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_delta", - delta: "Hello", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Hello", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - ], - phase: "commentary", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, - }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_end", - content: "Hello", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Hello", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - ], - phase: "commentary", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextDeltaAndEnd({ + emit, + text: "Hello", + id: "item_commentary", + phase: "commentary", }); await Promise.resolve(); expect(onBlockReply).not.toHaveBeenCalled(); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_delta", - delta: " world", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Hello world", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - phase: "final_answer", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, - }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_end", - content: "Hello world", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Hello world", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - phase: "final_answer", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextDeltaAndEnd({ + emit, + text: "Hello world", + delta: " world", + id: "item_final", + phase: "final_answer", }); await Promise.resolve(); @@ -388,53 +228,19 @@ describe("subscribeEmbeddedPiSession", () => { const { emit, subscription } = createTextEndBlockReplyHarness({ onBlockReply }); emit({ type: "message_start", message: { role: "assistant" } }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_delta", - delta: "Done.", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Done.", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextEvent({ + emit, + type: "text_delta", + text: "Done.", + id: "item_final", + signaturePhase: "final_answer", }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_end", - content: "Done.", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Done.", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextEvent({ + emit, + type: "text_end", + text: "Done.", + id: "item_final", + signaturePhase: "final_answer", }); await Promise.resolve(); @@ -448,76 +254,15 @@ describe("subscribeEmbeddedPiSession", () => { const { emit, subscription } = createTextEndBlockReplyHarness({ onBlockReply }); emit({ type: "message_start", message: { role: "assistant" } }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_delta", - delta: "Working...", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Working...", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - ], - phase: "commentary", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, - }); - emit({ - type: "message_update", - message: { role: "assistant", content: [] }, - assistantMessageEvent: { - type: "text_end", - content: "Working...", - partial: { - role: "assistant", - content: [ - { - type: "text", - text: "Working...", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - ], - phase: "commentary", - stopReason: "stop", - api: "openai-responses", - provider: "openai", - model: "gpt-5.2", - usage: {}, - timestamp: 0, - }, - }, + emitOpenAiResponsesTextDeltaAndEnd({ + emit, + text: "Working...", + id: "item_commentary", + phase: "commentary", }); await Promise.resolve(); - emit({ - type: "message_end", - message: { - role: "assistant", - content: [ - { - type: "text", - text: "Working...", - textSignature: JSON.stringify({ v: 1, id: "item_commentary", phase: "commentary" }), - }, - { - type: "text", - text: "Done.", - textSignature: JSON.stringify({ v: 1, id: "item_final", phase: "final_answer" }), - }, - ], - } as AssistantMessage, - }); + emitOpenAiResponsesFinalMessageEnd({ emit, commentaryText: "Working...", finalText: "Done." }); expect(onBlockReply).toHaveBeenCalledTimes(1); expect(onBlockReply.mock.calls[0]?.[0]?.text).toBe("Done."); diff --git a/src/agents/tools/cron-tool.ts b/src/agents/tools/cron-tool.ts index ed79e5b6122..9bea59a8375 100644 --- a/src/agents/tools/cron-tool.ts +++ b/src/agents/tools/cron-tool.ts @@ -182,26 +182,7 @@ const CronPatchObjectSchema = Type.Optional( Type.Object( { name: Type.Optional(Type.String({ description: "Job name" })), - schedule: Type.Optional( - Type.Object( - { - kind: optionalStringEnum(CRON_SCHEDULE_KINDS, { description: "Schedule type" }), - at: Type.Optional(Type.String({ description: "ISO-8601 timestamp (kind=at)" })), - everyMs: Type.Optional( - Type.Number({ description: "Interval in milliseconds (kind=every)" }), - ), - anchorMs: Type.Optional( - Type.Number({ description: "Optional start anchor in milliseconds (kind=every)" }), - ), - expr: Type.Optional(Type.String({ description: "Cron expression (kind=cron)" })), - tz: Type.Optional(Type.String({ description: "IANA timezone (kind=cron)" })), - staggerMs: Type.Optional( - Type.Number({ description: "Random jitter in ms (kind=cron)" }), - ), - }, - { additionalProperties: true }, - ), - ), + schedule: CronScheduleSchema, sessionTarget: Type.Optional(Type.String({ description: "Session target" })), wakeMode: optionalStringEnum(CRON_WAKE_MODES), payload: Type.Optional( @@ -209,29 +190,7 @@ const CronPatchObjectSchema = Type.Optional( toolsAllow: nullableStringArraySchema("Allowed tool ids, or null to clear"), }), ), - delivery: Type.Optional( - Type.Object( - { - mode: optionalStringEnum(CRON_DELIVERY_MODES, { description: "Delivery mode" }), - channel: Type.Optional(Type.String({ description: "Delivery channel" })), - to: Type.Optional(Type.String({ description: "Delivery target" })), - bestEffort: Type.Optional(Type.Boolean()), - accountId: Type.Optional(Type.String({ description: "Account target for delivery" })), - failureDestination: Type.Optional( - Type.Object( - { - channel: Type.Optional(Type.String()), - to: Type.Optional(Type.String()), - accountId: Type.Optional(Type.String()), - mode: optionalStringEnum(["announce", "webhook"] as const), - }, - { additionalProperties: true }, - ), - ), - }, - { additionalProperties: true }, - ), - ), + delivery: CronDeliverySchema, description: Type.Optional(Type.String()), enabled: Type.Optional(Type.Boolean()), deleteAfterRun: Type.Optional(Type.Boolean()), diff --git a/src/agents/tools/nodes-tool.test.ts b/src/agents/tools/nodes-tool.test.ts index 1e0b4473444..91afb2bf5bf 100644 --- a/src/agents/tools/nodes-tool.test.ts +++ b/src/agents/tools/nodes-tool.test.ts @@ -64,6 +64,45 @@ vi.mock("../../cli/nodes-screen.js", () => ({ let createNodesTool: typeof import("./nodes-tool.js").createNodesTool; +function mockNodePairApproveFlow(pendingRequest: { + requiredApproveScopes?: string[]; + commands?: string[]; +}): void { + gatewayMocks.callGatewayTool.mockImplementation(async (method, _opts, params, extra) => { + if (method === "node.pair.list") { + return { + pending: [ + { + requestId: "req-1", + ...pendingRequest, + }, + ], + }; + } + if (method === "node.pair.approve") { + return { ok: true, method, params, extra }; + } + throw new Error(`unexpected method: ${String(method)}`); + }); +} + +function expectNodePairApproveScopes(scopes: string[]): void { + expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( + 1, + "node.pair.list", + {}, + {}, + { scopes: ["operator.pairing"] }, + ); + expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( + 2, + "node.pair.approve", + {}, + { requestId: "req-1" }, + { scopes }, + ); +} + describe("createNodesTool screen_record duration guardrails", () => { beforeAll(async () => { ({ createNodesTool } = await import("./nodes-tool.js")); @@ -212,21 +251,8 @@ describe("createNodesTool screen_record duration guardrails", () => { }); it("uses operator.pairing plus operator.admin to approve exec-capable node pair requests", async () => { - gatewayMocks.callGatewayTool.mockImplementation(async (method, _opts, params, extra) => { - if (method === "node.pair.list") { - return { - pending: [ - { - requestId: "req-1", - requiredApproveScopes: ["operator.pairing", "operator.admin"], - }, - ], - }; - } - if (method === "node.pair.approve") { - return { ok: true, method, params, extra }; - } - throw new Error(`unexpected method: ${String(method)}`); + mockNodePairApproveFlow({ + requiredApproveScopes: ["operator.pairing", "operator.admin"], }); const tool = createNodesTool(); @@ -235,38 +261,12 @@ describe("createNodesTool screen_record duration guardrails", () => { requestId: "req-1", }); - expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( - 1, - "node.pair.list", - {}, - {}, - { scopes: ["operator.pairing"] }, - ); - expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( - 2, - "node.pair.approve", - {}, - { requestId: "req-1" }, - { scopes: ["operator.pairing", "operator.admin"] }, - ); + expectNodePairApproveScopes(["operator.pairing", "operator.admin"]); }); it("uses operator.pairing plus operator.write to approve non-exec node pair requests", async () => { - gatewayMocks.callGatewayTool.mockImplementation(async (method, _opts, params, extra) => { - if (method === "node.pair.list") { - return { - pending: [ - { - requestId: "req-1", - requiredApproveScopes: ["operator.pairing", "operator.write"], - }, - ], - }; - } - if (method === "node.pair.approve") { - return { ok: true, method, params, extra }; - } - throw new Error(`unexpected method: ${String(method)}`); + mockNodePairApproveFlow({ + requiredApproveScopes: ["operator.pairing", "operator.write"], }); const tool = createNodesTool(); @@ -275,38 +275,12 @@ describe("createNodesTool screen_record duration guardrails", () => { requestId: "req-1", }); - expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( - 1, - "node.pair.list", - {}, - {}, - { scopes: ["operator.pairing"] }, - ); - expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( - 2, - "node.pair.approve", - {}, - { requestId: "req-1" }, - { scopes: ["operator.pairing", "operator.write"] }, - ); + expectNodePairApproveScopes(["operator.pairing", "operator.write"]); }); it("uses operator.pairing for commandless node pair requests", async () => { - gatewayMocks.callGatewayTool.mockImplementation(async (method, _opts, params, extra) => { - if (method === "node.pair.list") { - return { - pending: [ - { - requestId: "req-1", - requiredApproveScopes: ["operator.pairing"], - }, - ], - }; - } - if (method === "node.pair.approve") { - return { ok: true, method, params, extra }; - } - throw new Error(`unexpected method: ${String(method)}`); + mockNodePairApproveFlow({ + requiredApproveScopes: ["operator.pairing"], }); const tool = createNodesTool(); @@ -315,38 +289,12 @@ describe("createNodesTool screen_record duration guardrails", () => { requestId: "req-1", }); - expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( - 1, - "node.pair.list", - {}, - {}, - { scopes: ["operator.pairing"] }, - ); - expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( - 2, - "node.pair.approve", - {}, - { requestId: "req-1" }, - { scopes: ["operator.pairing"] }, - ); + expectNodePairApproveScopes(["operator.pairing"]); }); it("falls back to command inspection when the gateway does not advertise required scopes", async () => { - gatewayMocks.callGatewayTool.mockImplementation(async (method, _opts, params, extra) => { - if (method === "node.pair.list") { - return { - pending: [ - { - requestId: "req-1", - commands: ["canvas.snapshot"], - }, - ], - }; - } - if (method === "node.pair.approve") { - return { ok: true, method, params, extra }; - } - throw new Error(`unexpected method: ${String(method)}`); + mockNodePairApproveFlow({ + commands: ["canvas.snapshot"], }); const tool = createNodesTool(); @@ -355,20 +303,7 @@ describe("createNodesTool screen_record duration guardrails", () => { requestId: "req-1", }); - expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( - 1, - "node.pair.list", - {}, - {}, - { scopes: ["operator.pairing"] }, - ); - expect(gatewayMocks.callGatewayTool).toHaveBeenNthCalledWith( - 2, - "node.pair.approve", - {}, - { requestId: "req-1" }, - { scopes: ["operator.pairing", "operator.write"] }, - ); + expectNodePairApproveScopes(["operator.pairing", "operator.write"]); }); it("blocks invokeCommand system.run so exec stays the only shell path", async () => { diff --git a/src/auto-reply/reply/commands-subagents-send.test.ts b/src/auto-reply/reply/commands-subagents-send.test.ts index 0a14d989b6c..762d52c1997 100644 --- a/src/auto-reply/reply/commands-subagents-send.test.ts +++ b/src/auto-reply/reply/commands-subagents-send.test.ts @@ -1,6 +1,5 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -import type { SubagentRunRecord } from "../../agents/subagent-registry.types.js"; -import type { OpenClawConfig } from "../../config/config.js"; +import { buildSubagentsSendContext } from "./commands-subagents.test-helpers.js"; import { handleSubagentsSendAction } from "./commands-subagents/action-send.js"; const sendControlledSubagentMessageMock = vi.hoisted(() => vi.fn()); @@ -11,45 +10,11 @@ vi.mock("./commands-subagents-control.runtime.js", () => ({ steerControlledSubagentRun: steerControlledSubagentRunMock, })); -function buildRun(): SubagentRunRecord { - return { - runId: "run-1", - childSessionKey: "agent:main:subagent:abc", - requesterSessionKey: "agent:main:main", - requesterDisplayKey: "main", - task: "do thing", - cleanup: "keep", - createdAt: 1000, - startedAt: 1000, - }; -} - -function buildContext(params?: { - cfg?: OpenClawConfig; - requesterKey?: string; - runs?: SubagentRunRecord[]; - restTokens?: string[]; -}) { - return { - params: { - cfg: - params?.cfg ?? - ({ - commands: { text: true }, - channels: { whatsapp: { allowFrom: ["*"] } }, - } as OpenClawConfig), - ctx: {}, - command: { - channel: "whatsapp", - to: "test-bot", - }, - }, +const buildContext = () => + buildSubagentsSendContext({ handledPrefix: "/subagents", - requesterKey: params?.requesterKey ?? "agent:main:main", - runs: params?.runs ?? [buildRun()], - restTokens: params?.restTokens ?? ["1", "continue", "with", "follow-up", "details"], - } as Parameters[0]; -} + restTokens: ["1", "continue", "with", "follow-up", "details"], + }); describe("subagents send action", () => { beforeEach(() => { diff --git a/src/auto-reply/reply/commands-subagents-steer.test.ts b/src/auto-reply/reply/commands-subagents-steer.test.ts index cf245a8f364..23f74a4085a 100644 --- a/src/auto-reply/reply/commands-subagents-steer.test.ts +++ b/src/auto-reply/reply/commands-subagents-steer.test.ts @@ -1,6 +1,5 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -import type { SubagentRunRecord } from "../../agents/subagent-registry.types.js"; -import type { OpenClawConfig } from "../../config/config.js"; +import { buildSubagentsSendContext } from "./commands-subagents.test-helpers.js"; import { handleSubagentsSendAction } from "./commands-subagents/action-send.js"; const sendControlledSubagentMessageMock = vi.hoisted(() => vi.fn()); @@ -11,45 +10,11 @@ vi.mock("./commands-subagents-control.runtime.js", () => ({ steerControlledSubagentRun: steerControlledSubagentRunMock, })); -function buildRun(): SubagentRunRecord { - return { - runId: "run-1", - childSessionKey: "agent:main:subagent:abc", - requesterSessionKey: "agent:main:main", - requesterDisplayKey: "main", - task: "do thing", - cleanup: "keep", - createdAt: 1000, - startedAt: 1000, - }; -} - -function buildContext(params?: { - cfg?: OpenClawConfig; - requesterKey?: string; - runs?: SubagentRunRecord[]; - restTokens?: string[]; -}) { - return { - params: { - cfg: - params?.cfg ?? - ({ - commands: { text: true }, - channels: { whatsapp: { allowFrom: ["*"] } }, - } as OpenClawConfig), - ctx: {}, - command: { - channel: "whatsapp", - to: "test-bot", - }, - }, +const buildContext = () => + buildSubagentsSendContext({ handledPrefix: "/steer", - requesterKey: params?.requesterKey ?? "agent:main:main", - runs: params?.runs ?? [buildRun()], - restTokens: params?.restTokens ?? ["1", "check", "timer.ts", "instead"], - } as Parameters[0]; -} + restTokens: ["1", "check", "timer.ts", "instead"], + }); describe("subagents steer action", () => { beforeEach(() => { diff --git a/src/auto-reply/reply/commands-subagents.test-helpers.ts b/src/auto-reply/reply/commands-subagents.test-helpers.ts new file mode 100644 index 00000000000..2b693d6fade --- /dev/null +++ b/src/auto-reply/reply/commands-subagents.test-helpers.ts @@ -0,0 +1,44 @@ +import type { SubagentRunRecord } from "../../agents/subagent-registry.types.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { handleSubagentsSendAction } from "./commands-subagents/action-send.js"; + +export function buildSubagentRun(): SubagentRunRecord { + return { + runId: "run-1", + childSessionKey: "agent:main:subagent:abc", + requesterSessionKey: "agent:main:main", + requesterDisplayKey: "main", + task: "do thing", + cleanup: "keep", + createdAt: 1000, + startedAt: 1000, + }; +} + +export function buildSubagentsSendContext(params?: { + cfg?: OpenClawConfig; + handledPrefix?: string; + requesterKey?: string; + runs?: SubagentRunRecord[]; + restTokens?: string[]; +}) { + return { + params: { + cfg: + params?.cfg ?? + ({ + commands: { text: true }, + channels: { whatsapp: { allowFrom: ["*"] } }, + } as OpenClawConfig), + ctx: {}, + command: { + channel: "whatsapp", + to: "test-bot", + }, + }, + handledPrefix: params?.handledPrefix ?? "/subagents", + requesterKey: params?.requesterKey ?? "agent:main:main", + runs: params?.runs ?? [buildSubagentRun()], + restTokens: params?.restTokens ?? [], + } as Parameters[0]; +} diff --git a/src/auto-reply/reply/dispatch-from-config.acp-abort.test.ts b/src/auto-reply/reply/dispatch-from-config.acp-abort.test.ts index 5612c7084ae..a9011ef54db 100644 --- a/src/auto-reply/reply/dispatch-from-config.acp-abort.test.ts +++ b/src/auto-reply/reply/dispatch-from-config.acp-abort.test.ts @@ -1,6 +1,5 @@ import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; import type { OpenClawConfig } from "../../config/config.js"; -import type { SessionBindingRecord } from "../../infra/outbound/session-binding-service.js"; import type { AcpRuntime, AcpRuntimeEnsureInput, @@ -8,288 +7,27 @@ import type { AcpRuntimeHandle, AcpRuntimeTurnInput, } from "../../plugin-sdk/acp-runtime.js"; -import type { - PluginHookBeforeDispatchResult, - PluginHookReplyDispatchResult, - PluginTargetedInboundClaimOutcome, -} from "../../plugins/hooks.js"; -import { setActivePluginRegistry } from "../../plugins/runtime.js"; -import { - createChannelTestPluginBase, - createTestRegistry, -} from "../../test-utils/channel-plugins.js"; import { createInternalHookEventPayload } from "../../test-utils/internal-hook-event-payload.js"; -import type { ReplyPayload } from "../types.js"; -import type { ReplyDispatcher } from "./reply-dispatcher.js"; +import { + acpManagerRuntimeMocks, + acpMocks, + agentEventMocks, + createDispatcher, + diagnosticMocks, + hookMocks, + internalHookMocks, + mocks, + noAbortResult, + resetPluginTtsAndThreadMocks, + sessionBindingMocks, + sessionStoreMocks, + setDiscordTestRegistry, +} from "./dispatch-from-config.shared.test-harness.js"; import { buildTestCtx } from "./test-ctx.js"; -type AbortResult = { handled: boolean; aborted: boolean; stoppedSubagents?: number }; - -const mocks = vi.hoisted(() => ({ - routeReply: vi.fn(async (_params: unknown) => ({ ok: true, messageId: "mock" })), - tryFastAbortFromMessage: vi.fn<() => Promise>(async () => ({ - handled: false, - aborted: false, - })), -})); -const diagnosticMocks = vi.hoisted(() => ({ - logMessageQueued: vi.fn(), - logMessageProcessed: vi.fn(), - logSessionStateChange: vi.fn(), -})); -const hookMocks = vi.hoisted(() => ({ - registry: { - plugins: [] as Array<{ id: string; status: "loaded" | "disabled" | "error" }>, - }, - runner: { - hasHooks: vi.fn<(hookName?: string) => boolean>(() => false), - runInboundClaim: vi.fn(async () => undefined), - runInboundClaimForPlugin: vi.fn(async () => undefined), - runInboundClaimForPluginOutcome: vi.fn<() => Promise>( - async () => ({ status: "no_handler" as const }), - ), - runMessageReceived: vi.fn(async () => {}), - runBeforeDispatch: vi.fn< - (_event: unknown, _ctx: unknown) => Promise - >(async () => undefined), - runReplyDispatch: vi.fn< - (_event: unknown, _ctx: unknown) => Promise - >(async () => undefined), - }, -})); -const internalHookMocks = vi.hoisted(() => ({ - createInternalHookEvent: vi.fn(), - triggerInternalHook: vi.fn(async () => {}), -})); -const acpMocks = vi.hoisted(() => ({ - listAcpSessionEntries: vi.fn(async () => []), - readAcpSessionEntry: vi.fn<(params: { sessionKey: string; cfg?: OpenClawConfig }) => unknown>( - () => null, - ), - getAcpRuntimeBackend: vi.fn<() => unknown>(() => null), - upsertAcpSessionMeta: vi.fn< - (params: { - sessionKey: string; - cfg?: OpenClawConfig; - mutate: ( - current: Record | undefined, - entry: { acp?: Record } | undefined, - ) => Record | null | undefined; - }) => Promise - >(async () => null), - requireAcpRuntimeBackend: vi.fn<() => unknown>(), -})); -const sessionBindingMocks = vi.hoisted(() => ({ - listBySession: vi.fn<(targetSessionKey: string) => SessionBindingRecord[]>(() => []), - resolveByConversation: vi.fn< - (ref: { - channel: string; - accountId: string; - conversationId: string; - parentConversationId?: string; - }) => SessionBindingRecord | null - >(() => null), - touch: vi.fn(), -})); -const pluginConversationBindingMocks = vi.hoisted(() => ({ - shownFallbackNoticeBindingIds: new Set(), -})); -const sessionStoreMocks = vi.hoisted(() => ({ - currentEntry: undefined as Record | undefined, - loadSessionStore: vi.fn(() => ({})), - resolveStorePath: vi.fn(() => "/tmp/mock-sessions.json"), - resolveSessionStoreEntry: vi.fn(() => ({ existing: sessionStoreMocks.currentEntry })), -})); -const acpManagerRuntimeMocks = vi.hoisted(() => ({ - getAcpSessionManager: vi.fn(), -})); -const agentEventMocks = vi.hoisted(() => ({ - emitAgentEvent: vi.fn(), - onAgentEvent: vi.fn<(listener: unknown) => () => void>(() => () => {}), -})); -const ttsMocks = vi.hoisted(() => ({ - maybeApplyTtsToPayload: vi.fn(async (paramsUnknown: unknown) => { - const params = paramsUnknown as { payload: ReplyPayload }; - return params.payload; - }), - normalizeTtsAutoMode: vi.fn((value: unknown) => (typeof value === "string" ? value : undefined)), - resolveTtsConfig: vi.fn((_cfg: OpenClawConfig) => ({ mode: "final" })), -})); -const threadInfoMocks = vi.hoisted(() => ({ - parseSessionThreadInfo: vi.fn< - (sessionKey: string | undefined) => { - baseSessionKey: string | undefined; - threadId: string | undefined; - } - >(), -})); - -function parseGenericThreadSessionInfo(sessionKey: string | undefined) { - const trimmed = sessionKey?.trim(); - if (!trimmed) { - return { baseSessionKey: undefined, threadId: undefined }; - } - const threadMarker = ":thread:"; - const topicMarker = ":topic:"; - const marker = trimmed.includes(threadMarker) - ? threadMarker - : trimmed.includes(topicMarker) - ? topicMarker - : undefined; - if (!marker) { - return { baseSessionKey: trimmed, threadId: undefined }; - } - const index = trimmed.lastIndexOf(marker); - if (index < 0) { - return { baseSessionKey: trimmed, threadId: undefined }; - } - const baseSessionKey = trimmed.slice(0, index).trim() || undefined; - const threadId = trimmed.slice(index + marker.length).trim() || undefined; - return { baseSessionKey, threadId }; -} - -vi.mock("./route-reply.runtime.js", () => ({ - isRoutableChannel: () => true, - routeReply: mocks.routeReply, -})); -vi.mock("./route-reply.js", () => ({ - isRoutableChannel: () => true, - routeReply: mocks.routeReply, -})); -vi.mock("./abort.runtime.js", () => ({ - tryFastAbortFromMessage: mocks.tryFastAbortFromMessage, - formatAbortReplyText: () => "⚙️ Agent was aborted.", -})); -vi.mock("../../logging/diagnostic.js", () => ({ - logMessageQueued: diagnosticMocks.logMessageQueued, - logMessageProcessed: diagnosticMocks.logMessageProcessed, - logSessionStateChange: diagnosticMocks.logSessionStateChange, -})); -vi.mock("../../config/sessions/thread-info.js", () => ({ - parseSessionThreadInfo: (sessionKey: string | undefined) => - threadInfoMocks.parseSessionThreadInfo(sessionKey), -})); -vi.mock("./dispatch-from-config.runtime.js", () => ({ - createInternalHookEvent: internalHookMocks.createInternalHookEvent, - loadSessionStore: sessionStoreMocks.loadSessionStore, - resolveSessionStoreEntry: sessionStoreMocks.resolveSessionStoreEntry, - resolveStorePath: sessionStoreMocks.resolveStorePath, - triggerInternalHook: internalHookMocks.triggerInternalHook, -})); -vi.mock("../../plugins/hook-runner-global.js", () => ({ - getGlobalHookRunner: () => hookMocks.runner, - getGlobalPluginRegistry: () => hookMocks.registry, -})); -vi.mock("../../acp/runtime/session-meta.js", () => ({ - listAcpSessionEntries: acpMocks.listAcpSessionEntries, - readAcpSessionEntry: acpMocks.readAcpSessionEntry, - upsertAcpSessionMeta: acpMocks.upsertAcpSessionMeta, -})); -vi.mock("../../acp/runtime/registry.js", () => ({ - getAcpRuntimeBackend: acpMocks.getAcpRuntimeBackend, - requireAcpRuntimeBackend: acpMocks.requireAcpRuntimeBackend, -})); -vi.mock("../../infra/outbound/session-binding-service.js", () => ({ - getSessionBindingService: () => ({ - bind: vi.fn(async () => { - throw new Error("bind not mocked"); - }), - getCapabilities: vi.fn(() => ({ - adapterAvailable: true, - bindSupported: true, - unbindSupported: true, - placements: ["current", "child"] as const, - })), - listBySession: (targetSessionKey: string) => - sessionBindingMocks.listBySession(targetSessionKey), - resolveByConversation: sessionBindingMocks.resolveByConversation, - touch: sessionBindingMocks.touch, - unbind: vi.fn(async () => []), - }), -})); -vi.mock("../../infra/agent-events.js", () => ({ - emitAgentEvent: (params: unknown) => agentEventMocks.emitAgentEvent(params), - onAgentEvent: (listener: unknown) => agentEventMocks.onAgentEvent(listener), -})); -vi.mock("../../plugins/conversation-binding.js", () => ({ - buildPluginBindingDeclinedText: () => "Plugin binding request was declined.", - buildPluginBindingErrorText: () => "Plugin binding request failed.", - buildPluginBindingUnavailableText: (binding: { pluginName?: string; pluginId: string }) => - `${binding.pluginName ?? binding.pluginId} is not currently loaded.`, - hasShownPluginBindingFallbackNotice: (bindingId: string) => - pluginConversationBindingMocks.shownFallbackNoticeBindingIds.has(bindingId), - isPluginOwnedSessionBindingRecord: ( - record: SessionBindingRecord | null | undefined, - ): record is SessionBindingRecord => - record?.metadata != null && - typeof record.metadata === "object" && - (record.metadata as { pluginBindingOwner?: string }).pluginBindingOwner === "plugin", - markPluginBindingFallbackNoticeShown: (bindingId: string) => { - pluginConversationBindingMocks.shownFallbackNoticeBindingIds.add(bindingId); - }, - toPluginConversationBinding: (record: SessionBindingRecord) => ({ - bindingId: record.bindingId, - pluginId: "unknown-plugin", - pluginName: undefined, - pluginRoot: "", - channel: record.conversation.channel, - accountId: record.conversation.accountId, - conversationId: record.conversation.conversationId, - parentConversationId: record.conversation.parentConversationId, - }), -})); -vi.mock("./dispatch-acp-manager.runtime.js", () => ({ - getAcpSessionManager: () => acpManagerRuntimeMocks.getAcpSessionManager(), - getSessionBindingService: () => ({ - listBySession: (targetSessionKey: string) => - sessionBindingMocks.listBySession(targetSessionKey), - unbind: vi.fn(async () => []), - }), -})); -vi.mock("../../tts/tts.js", () => ({ - maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), - normalizeTtsAutoMode: (value: unknown) => ttsMocks.normalizeTtsAutoMode(value), - resolveTtsConfig: (cfg: OpenClawConfig) => ttsMocks.resolveTtsConfig(cfg), -})); -vi.mock("../../tts/tts.runtime.js", () => ({ - maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), -})); -vi.mock("../../tts/status-config.js", () => ({ - resolveStatusTtsSnapshot: () => ({ - autoMode: "always", - provider: "auto", - maxLength: 1500, - summarize: true, - }), -})); -vi.mock("./dispatch-acp-tts.runtime.js", () => ({ - maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), -})); -vi.mock("./dispatch-acp-session.runtime.js", () => ({ - readAcpSessionEntry: (params: { sessionKey: string; cfg?: OpenClawConfig }) => - acpMocks.readAcpSessionEntry(params), -})); -vi.mock("../../tts/tts-config.js", () => ({ - normalizeTtsAutoMode: (value: unknown) => ttsMocks.normalizeTtsAutoMode(value), - resolveConfiguredTtsMode: (cfg: OpenClawConfig) => ttsMocks.resolveTtsConfig(cfg).mode, -})); - -const noAbortResult = { handled: false, aborted: false } as const; let dispatchReplyFromConfig: typeof import("./dispatch-from-config.js").dispatchReplyFromConfig; let tryDispatchAcpReplyHook: typeof import("../../plugin-sdk/acp-runtime.js").tryDispatchAcpReplyHook; -function createDispatcher(): ReplyDispatcher { - return { - sendToolResult: vi.fn(() => true), - sendBlockReply: vi.fn(() => true), - sendFinalReply: vi.fn(() => true), - waitForIdle: vi.fn(async () => {}), - getQueuedCounts: vi.fn(() => ({ tool: 0, block: 0, final: 0 })), - getFailedCounts: vi.fn(() => ({ tool: 0, block: 0, final: 0 })), - markComplete: vi.fn(), - }; -} - function shouldUseAcpReplyDispatchHook(eventUnknown: unknown): boolean { const event = eventUnknown as { sessionKey?: string; @@ -393,19 +131,7 @@ describe("dispatchReplyFromConfig ACP abort", () => { }); beforeEach(() => { - const discordTestPlugin = { - ...createChannelTestPluginBase({ - id: "discord", - capabilities: { chatTypes: ["direct"], nativeCommands: true }, - }), - outbound: { - deliveryMode: "direct", - shouldSuppressLocalPayloadPrompt: () => false, - }, - }; - setActivePluginRegistry( - createTestRegistry([{ pluginId: "discord", source: "test", plugin: discordTestPlugin }]), - ); + setDiscordTestRegistry(); acpManagerRuntimeMocks.getAcpSessionManager.mockReset(); acpManagerRuntimeMocks.getAcpSessionManager.mockReturnValue(createMockAcpSessionManager()); hookMocks.runner.hasHooks.mockReset(); @@ -445,20 +171,7 @@ describe("dispatchReplyFromConfig ACP abort", () => { sessionBindingMocks.listBySession.mockReset().mockReturnValue([]); sessionBindingMocks.resolveByConversation.mockReset().mockReturnValue(null); sessionBindingMocks.touch.mockReset(); - pluginConversationBindingMocks.shownFallbackNoticeBindingIds.clear(); - ttsMocks.maybeApplyTtsToPayload - .mockReset() - .mockImplementation(async (paramsUnknown: unknown) => { - const params = paramsUnknown as { payload: ReplyPayload }; - return params.payload; - }); - ttsMocks.normalizeTtsAutoMode - .mockReset() - .mockImplementation((value: unknown) => (typeof value === "string" ? value : undefined)); - ttsMocks.resolveTtsConfig.mockReset().mockReturnValue({ mode: "final" }); - threadInfoMocks.parseSessionThreadInfo - .mockReset() - .mockImplementation(parseGenericThreadSessionInfo); + resetPluginTtsAndThreadMocks(); diagnosticMocks.logMessageQueued.mockReset(); diagnosticMocks.logMessageProcessed.mockReset(); diagnosticMocks.logSessionStateChange.mockReset(); diff --git a/src/auto-reply/reply/dispatch-from-config.reply-dispatch.test.ts b/src/auto-reply/reply/dispatch-from-config.reply-dispatch.test.ts index d4e2e55396e..758facb58dc 100644 --- a/src/auto-reply/reply/dispatch-from-config.reply-dispatch.test.ts +++ b/src/auto-reply/reply/dispatch-from-config.reply-dispatch.test.ts @@ -1,304 +1,26 @@ import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; -import type { OpenClawConfig } from "../../config/config.js"; -import type { SessionBindingRecord } from "../../infra/outbound/session-binding-service.js"; -import type { - PluginHookBeforeDispatchResult, - PluginHookReplyDispatchResult, - PluginTargetedInboundClaimOutcome, -} from "../../plugins/hooks.js"; -import { setActivePluginRegistry } from "../../plugins/runtime.js"; -import { - createChannelTestPluginBase, - createTestRegistry, -} from "../../test-utils/channel-plugins.js"; +import type { PluginHookReplyDispatchResult } from "../../plugins/hooks.js"; import { createInternalHookEventPayload } from "../../test-utils/internal-hook-event-payload.js"; -import type { ReplyPayload } from "../types.js"; -import type { ReplyDispatcher } from "./reply-dispatcher.js"; -import { buildTestCtx } from "./test-ctx.js"; +import { + acpManagerRuntimeMocks, + acpMocks, + agentEventMocks, + createDispatcher, + createHookCtx, + diagnosticMocks, + emptyConfig, + hookMocks, + internalHookMocks, + mocks, + resetPluginTtsAndThreadMocks, + sessionBindingMocks, + sessionStoreMocks, + setDiscordTestRegistry, +} from "./dispatch-from-config.shared.test-harness.js"; -type AbortResult = { handled: boolean; aborted: boolean; stoppedSubagents?: number }; - -const mocks = vi.hoisted(() => ({ - routeReply: vi.fn(async () => ({ ok: true, messageId: "mock" })), - tryFastAbortFromMessage: vi.fn<() => Promise>(async () => ({ - handled: false, - aborted: false, - })), -})); -const diagnosticMocks = vi.hoisted(() => ({ - logMessageQueued: vi.fn(), - logMessageProcessed: vi.fn(), - logSessionStateChange: vi.fn(), -})); -const hookMocks = vi.hoisted(() => ({ - registry: { - plugins: [] as Array<{ id: string; status: "loaded" | "disabled" | "error" }>, - }, - runner: { - hasHooks: vi.fn<(hookName?: string) => boolean>(() => false), - runInboundClaim: vi.fn(async () => undefined), - runInboundClaimForPlugin: vi.fn(async () => undefined), - runInboundClaimForPluginOutcome: vi.fn<() => Promise>( - async () => ({ status: "no_handler" as const }), - ), - runMessageReceived: vi.fn(async () => {}), - runBeforeDispatch: vi.fn< - (_event: unknown, _ctx: unknown) => Promise - >(async () => undefined), - runReplyDispatch: vi.fn< - (_event: unknown, _ctx: unknown) => Promise - >(async () => undefined), - }, -})); -const internalHookMocks = vi.hoisted(() => ({ - createInternalHookEvent: vi.fn(), - triggerInternalHook: vi.fn(async () => {}), -})); -const acpMocks = vi.hoisted(() => ({ - listAcpSessionEntries: vi.fn(async () => []), - readAcpSessionEntry: vi.fn<(params: { sessionKey: string; cfg?: OpenClawConfig }) => unknown>( - () => null, - ), - upsertAcpSessionMeta: vi.fn(async () => null), - requireAcpRuntimeBackend: vi.fn<() => unknown>(), -})); -const sessionBindingMocks = vi.hoisted(() => ({ - listBySession: vi.fn<(targetSessionKey: string) => SessionBindingRecord[]>(() => []), - resolveByConversation: vi.fn< - (ref: { - channel: string; - accountId: string; - conversationId: string; - parentConversationId?: string; - }) => SessionBindingRecord | null - >(() => null), - touch: vi.fn(), -})); -const pluginConversationBindingMocks = vi.hoisted(() => ({ - shownFallbackNoticeBindingIds: new Set(), -})); -const sessionStoreMocks = vi.hoisted(() => ({ - currentEntry: undefined as Record | undefined, - loadSessionStore: vi.fn(() => ({})), - resolveStorePath: vi.fn(() => "/tmp/mock-sessions.json"), - resolveSessionStoreEntry: vi.fn(() => ({ existing: sessionStoreMocks.currentEntry })), -})); -const acpManagerRuntimeMocks = vi.hoisted(() => ({ - getAcpSessionManager: vi.fn(() => ({ - resolveSession: () => ({ kind: "none" as const }), - getObservabilitySnapshot: () => ({ - runtimeCache: { activeSessions: 0, idleTtlMs: 0, evictedTotal: 0 }, - turns: { - active: 0, - queueDepth: 0, - completed: 0, - failed: 0, - averageLatencyMs: 0, - maxLatencyMs: 0, - }, - errorsByCode: {}, - }), - runTurn: vi.fn(), - })), -})); -const agentEventMocks = vi.hoisted(() => ({ - emitAgentEvent: vi.fn(), - onAgentEvent: vi.fn<(listener: unknown) => () => void>(() => () => {}), -})); -const ttsMocks = vi.hoisted(() => ({ - maybeApplyTtsToPayload: vi.fn(async (paramsUnknown: unknown) => { - const params = paramsUnknown as { payload: ReplyPayload }; - return params.payload; - }), - normalizeTtsAutoMode: vi.fn((value: unknown) => (typeof value === "string" ? value : undefined)), - resolveTtsConfig: vi.fn((_cfg: OpenClawConfig) => ({ mode: "final" })), -})); -const threadInfoMocks = vi.hoisted(() => ({ - parseSessionThreadInfo: vi.fn< - (sessionKey: string | undefined) => { - baseSessionKey: string | undefined; - threadId: string | undefined; - } - >(), -})); - -function parseGenericThreadSessionInfo(sessionKey: string | undefined) { - const trimmed = sessionKey?.trim(); - if (!trimmed) { - return { baseSessionKey: undefined, threadId: undefined }; - } - const threadMarker = ":thread:"; - const topicMarker = ":topic:"; - const marker = trimmed.includes(threadMarker) - ? threadMarker - : trimmed.includes(topicMarker) - ? topicMarker - : undefined; - if (!marker) { - return { baseSessionKey: trimmed, threadId: undefined }; - } - const index = trimmed.lastIndexOf(marker); - if (index < 0) { - return { baseSessionKey: trimmed, threadId: undefined }; - } - const baseSessionKey = trimmed.slice(0, index).trim() || undefined; - const threadId = trimmed.slice(index + marker.length).trim() || undefined; - return { baseSessionKey, threadId }; -} - -vi.mock("./route-reply.runtime.js", () => ({ - isRoutableChannel: () => true, - routeReply: mocks.routeReply, -})); -vi.mock("./route-reply.js", () => ({ - isRoutableChannel: () => true, - routeReply: mocks.routeReply, -})); -vi.mock("./abort.runtime.js", () => ({ - tryFastAbortFromMessage: mocks.tryFastAbortFromMessage, - formatAbortReplyText: () => "⚙️ Agent was aborted.", -})); -vi.mock("../../logging/diagnostic.js", () => ({ - logMessageQueued: diagnosticMocks.logMessageQueued, - logMessageProcessed: diagnosticMocks.logMessageProcessed, - logSessionStateChange: diagnosticMocks.logSessionStateChange, -})); -vi.mock("../../config/sessions/thread-info.js", () => ({ - parseSessionThreadInfo: (sessionKey: string | undefined) => - threadInfoMocks.parseSessionThreadInfo(sessionKey), -})); -vi.mock("./dispatch-from-config.runtime.js", () => ({ - createInternalHookEvent: internalHookMocks.createInternalHookEvent, - loadSessionStore: sessionStoreMocks.loadSessionStore, - resolveSessionStoreEntry: sessionStoreMocks.resolveSessionStoreEntry, - resolveStorePath: sessionStoreMocks.resolveStorePath, - triggerInternalHook: internalHookMocks.triggerInternalHook, -})); -vi.mock("../../plugins/hook-runner-global.js", () => ({ - getGlobalHookRunner: () => hookMocks.runner, - getGlobalPluginRegistry: () => hookMocks.registry, -})); -vi.mock("../../acp/runtime/session-meta.js", () => ({ - listAcpSessionEntries: acpMocks.listAcpSessionEntries, - readAcpSessionEntry: acpMocks.readAcpSessionEntry, - upsertAcpSessionMeta: acpMocks.upsertAcpSessionMeta, -})); -vi.mock("../../acp/runtime/registry.js", () => ({ - requireAcpRuntimeBackend: acpMocks.requireAcpRuntimeBackend, -})); -vi.mock("../../infra/outbound/session-binding-service.js", () => ({ - getSessionBindingService: () => ({ - bind: vi.fn(async () => { - throw new Error("bind not mocked"); - }), - getCapabilities: vi.fn(() => ({ - adapterAvailable: true, - bindSupported: true, - unbindSupported: true, - placements: ["current", "child"] as const, - })), - listBySession: (targetSessionKey: string) => - sessionBindingMocks.listBySession(targetSessionKey), - resolveByConversation: sessionBindingMocks.resolveByConversation, - touch: sessionBindingMocks.touch, - unbind: vi.fn(async () => []), - }), -})); -vi.mock("../../infra/agent-events.js", () => ({ - emitAgentEvent: (params: unknown) => agentEventMocks.emitAgentEvent(params), - onAgentEvent: (listener: unknown) => agentEventMocks.onAgentEvent(listener), -})); -vi.mock("../../plugins/conversation-binding.js", () => ({ - buildPluginBindingDeclinedText: () => "Plugin binding request was declined.", - buildPluginBindingErrorText: () => "Plugin binding request failed.", - buildPluginBindingUnavailableText: (binding: { pluginName?: string; pluginId: string }) => - `${binding.pluginName ?? binding.pluginId} is not currently loaded.`, - hasShownPluginBindingFallbackNotice: (bindingId: string) => - pluginConversationBindingMocks.shownFallbackNoticeBindingIds.has(bindingId), - isPluginOwnedSessionBindingRecord: ( - record: SessionBindingRecord | null | undefined, - ): record is SessionBindingRecord => - record?.metadata != null && - typeof record.metadata === "object" && - (record.metadata as { pluginBindingOwner?: string }).pluginBindingOwner === "plugin", - markPluginBindingFallbackNoticeShown: (bindingId: string) => { - pluginConversationBindingMocks.shownFallbackNoticeBindingIds.add(bindingId); - }, - toPluginConversationBinding: (record: SessionBindingRecord) => ({ - bindingId: record.bindingId, - pluginId: "unknown-plugin", - pluginName: undefined, - pluginRoot: "", - channel: record.conversation.channel, - accountId: record.conversation.accountId, - conversationId: record.conversation.conversationId, - parentConversationId: record.conversation.parentConversationId, - }), -})); -vi.mock("./dispatch-acp-manager.runtime.js", () => ({ - getAcpSessionManager: () => acpManagerRuntimeMocks.getAcpSessionManager(), - getSessionBindingService: () => ({ - listBySession: (targetSessionKey: string) => - sessionBindingMocks.listBySession(targetSessionKey), - unbind: vi.fn(async () => []), - }), -})); -vi.mock("../../tts/tts.js", () => ({ - maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), - normalizeTtsAutoMode: (value: unknown) => ttsMocks.normalizeTtsAutoMode(value), - resolveTtsConfig: (cfg: OpenClawConfig) => ttsMocks.resolveTtsConfig(cfg), -})); -vi.mock("../../tts/tts.runtime.js", () => ({ - maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), -})); -vi.mock("../../tts/status-config.js", () => ({ - resolveStatusTtsSnapshot: () => ({ - autoMode: "always", - provider: "auto", - maxLength: 1500, - summarize: true, - }), -})); -vi.mock("./dispatch-acp-tts.runtime.js", () => ({ - maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), -})); -vi.mock("./dispatch-acp-session.runtime.js", () => ({ - readAcpSessionEntry: (params: { sessionKey: string; cfg?: OpenClawConfig }) => - acpMocks.readAcpSessionEntry(params), -})); -vi.mock("../../tts/tts-config.js", () => ({ - normalizeTtsAutoMode: (value: unknown) => ttsMocks.normalizeTtsAutoMode(value), - resolveConfiguredTtsMode: (cfg: OpenClawConfig) => ttsMocks.resolveTtsConfig(cfg).mode, -})); - -const emptyConfig = {} as OpenClawConfig; let dispatchReplyFromConfig: typeof import("./dispatch-from-config.js").dispatchReplyFromConfig; let resetInboundDedupe: typeof import("./inbound-dedupe.js").resetInboundDedupe; -function createDispatcher(): ReplyDispatcher { - return { - sendToolResult: vi.fn(() => true), - sendBlockReply: vi.fn(() => true), - sendFinalReply: vi.fn(() => true), - waitForIdle: vi.fn(async () => {}), - getQueuedCounts: vi.fn(() => ({ tool: 0, block: 0, final: 0 })), - getFailedCounts: vi.fn(() => ({ tool: 0, block: 0, final: 0 })), - markComplete: vi.fn(), - }; -} - -function createHookCtx() { - return buildTestCtx({ - Body: "hello", - BodyForAgent: "hello", - BodyForCommands: "hello", - From: "user1", - Surface: "telegram", - ChatType: "private", - SessionKey: "agent:test:session", - }); -} - describe("dispatchReplyFromConfig reply_dispatch hook", () => { beforeAll(async () => { ({ dispatchReplyFromConfig } = await import("./dispatch-from-config.js")); @@ -306,19 +28,7 @@ describe("dispatchReplyFromConfig reply_dispatch hook", () => { }); beforeEach(() => { - const discordTestPlugin = { - ...createChannelTestPluginBase({ - id: "discord", - capabilities: { chatTypes: ["direct"], nativeCommands: true }, - }), - outbound: { - deliveryMode: "direct", - shouldSuppressLocalPayloadPrompt: () => false, - }, - }; - setActivePluginRegistry( - createTestRegistry([{ pluginId: "discord", source: "test", plugin: discordTestPlugin }]), - ); + setDiscordTestRegistry(); resetInboundDedupe(); mocks.routeReply.mockReset().mockResolvedValue({ ok: true, messageId: "mock" }); mocks.tryFastAbortFromMessage.mockReset().mockResolvedValue({ @@ -347,7 +57,6 @@ describe("dispatchReplyFromConfig reply_dispatch hook", () => { sessionBindingMocks.listBySession.mockReset().mockReturnValue([]); sessionBindingMocks.resolveByConversation.mockReset().mockReturnValue(null); sessionBindingMocks.touch.mockReset(); - pluginConversationBindingMocks.shownFallbackNoticeBindingIds.clear(); sessionStoreMocks.currentEntry = undefined; sessionStoreMocks.loadSessionStore.mockReset().mockReturnValue({}); sessionStoreMocks.resolveStorePath.mockReset().mockReturnValue("/tmp/mock-sessions.json"); @@ -374,19 +83,7 @@ describe("dispatchReplyFromConfig reply_dispatch hook", () => { diagnosticMocks.logMessageQueued.mockReset(); diagnosticMocks.logMessageProcessed.mockReset(); diagnosticMocks.logSessionStateChange.mockReset(); - ttsMocks.maybeApplyTtsToPayload - .mockReset() - .mockImplementation(async (paramsUnknown: unknown) => { - const params = paramsUnknown as { payload: ReplyPayload }; - return params.payload; - }); - ttsMocks.normalizeTtsAutoMode - .mockReset() - .mockImplementation((value: unknown) => (typeof value === "string" ? value : undefined)); - ttsMocks.resolveTtsConfig.mockReset().mockReturnValue({ mode: "final" }); - threadInfoMocks.parseSessionThreadInfo - .mockReset() - .mockImplementation(parseGenericThreadSessionInfo); + resetPluginTtsAndThreadMocks(); }); it("returns handled dispatch results from plugins", async () => { diff --git a/src/auto-reply/reply/dispatch-from-config.shared.test-harness.ts b/src/auto-reply/reply/dispatch-from-config.shared.test-harness.ts new file mode 100644 index 00000000000..612e65a7cb9 --- /dev/null +++ b/src/auto-reply/reply/dispatch-from-config.shared.test-harness.ts @@ -0,0 +1,342 @@ +import { vi } from "vitest"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { SessionBindingRecord } from "../../infra/outbound/session-binding-service.js"; +import type { + PluginHookBeforeDispatchResult, + PluginHookReplyDispatchResult, + PluginTargetedInboundClaimOutcome, +} from "../../plugins/hooks.js"; +import { setActivePluginRegistry } from "../../plugins/runtime.js"; +import { + createChannelTestPluginBase, + createTestRegistry, +} from "../../test-utils/channel-plugins.js"; +import type { ReplyPayload } from "../types.js"; +import type { ReplyDispatcher } from "./reply-dispatcher.js"; +import { buildTestCtx } from "./test-ctx.js"; + +type AbortResult = { handled: boolean; aborted: boolean; stoppedSubagents?: number }; + +const mocks = vi.hoisted(() => ({ + routeReply: vi.fn(async (_params: unknown) => ({ ok: true, messageId: "mock" })), + tryFastAbortFromMessage: vi.fn<() => Promise>(async () => ({ + handled: false, + aborted: false, + })), +})); +const diagnosticMocks = vi.hoisted(() => ({ + logMessageQueued: vi.fn(), + logMessageProcessed: vi.fn(), + logSessionStateChange: vi.fn(), +})); +const hookMocks = vi.hoisted(() => ({ + registry: { + plugins: [] as Array<{ id: string; status: "loaded" | "disabled" | "error" }>, + }, + runner: { + hasHooks: vi.fn<(hookName?: string) => boolean>(() => false), + runInboundClaim: vi.fn(async () => undefined), + runInboundClaimForPlugin: vi.fn(async () => undefined), + runInboundClaimForPluginOutcome: vi.fn<() => Promise>( + async () => ({ status: "no_handler" as const }), + ), + runMessageReceived: vi.fn(async () => {}), + runBeforeDispatch: vi.fn< + (_event: unknown, _ctx: unknown) => Promise + >(async () => undefined), + runReplyDispatch: vi.fn< + (_event: unknown, _ctx: unknown) => Promise + >(async () => undefined), + }, +})); +const internalHookMocks = vi.hoisted(() => ({ + createInternalHookEvent: vi.fn(), + triggerInternalHook: vi.fn(async () => {}), +})); +const acpMocks = vi.hoisted(() => ({ + listAcpSessionEntries: vi.fn(async () => []), + readAcpSessionEntry: vi.fn<(params: { sessionKey: string; cfg?: OpenClawConfig }) => unknown>( + () => null, + ), + getAcpRuntimeBackend: vi.fn<() => unknown>(() => null), + upsertAcpSessionMeta: vi.fn< + (params: { + sessionKey: string; + cfg?: OpenClawConfig; + mutate: ( + current: Record | undefined, + entry: { acp?: Record } | undefined, + ) => Record | null | undefined; + }) => Promise + >(async () => null), + requireAcpRuntimeBackend: vi.fn<() => unknown>(), +})); +const sessionBindingMocks = vi.hoisted(() => ({ + listBySession: vi.fn<(targetSessionKey: string) => SessionBindingRecord[]>(() => []), + resolveByConversation: vi.fn< + (ref: { + channel: string; + accountId: string; + conversationId: string; + parentConversationId?: string; + }) => SessionBindingRecord | null + >(() => null), + touch: vi.fn(), +})); +const pluginConversationBindingMocks = vi.hoisted(() => ({ + shownFallbackNoticeBindingIds: new Set(), +})); +const sessionStoreMocks = vi.hoisted(() => ({ + currentEntry: undefined as Record | undefined, + loadSessionStore: vi.fn(() => ({})), + resolveStorePath: vi.fn(() => "/tmp/mock-sessions.json"), + resolveSessionStoreEntry: vi.fn(() => ({ existing: sessionStoreMocks.currentEntry })), +})); +const acpManagerRuntimeMocks = vi.hoisted(() => ({ + getAcpSessionManager: vi.fn(), +})); +const agentEventMocks = vi.hoisted(() => ({ + emitAgentEvent: vi.fn(), + onAgentEvent: vi.fn<(listener: unknown) => () => void>(() => () => {}), +})); +const ttsMocks = vi.hoisted(() => ({ + maybeApplyTtsToPayload: vi.fn(async (paramsUnknown: unknown) => { + const params = paramsUnknown as { payload: ReplyPayload }; + return params.payload; + }), + normalizeTtsAutoMode: vi.fn((value: unknown) => (typeof value === "string" ? value : undefined)), + resolveTtsConfig: vi.fn((_cfg: OpenClawConfig) => ({ mode: "final" })), +})); +const threadInfoMocks = vi.hoisted(() => ({ + parseSessionThreadInfo: vi.fn< + (sessionKey: string | undefined) => { + baseSessionKey: string | undefined; + threadId: string | undefined; + } + >(), +})); + +export { + acpManagerRuntimeMocks, + acpMocks, + agentEventMocks, + diagnosticMocks, + hookMocks, + internalHookMocks, + mocks, + pluginConversationBindingMocks, + sessionBindingMocks, + sessionStoreMocks, + threadInfoMocks, + ttsMocks, +}; + +export function parseGenericThreadSessionInfo(sessionKey: string | undefined) { + const trimmed = sessionKey?.trim(); + if (!trimmed) { + return { baseSessionKey: undefined, threadId: undefined }; + } + const threadMarker = ":thread:"; + const topicMarker = ":topic:"; + const marker = trimmed.includes(threadMarker) + ? threadMarker + : trimmed.includes(topicMarker) + ? topicMarker + : undefined; + if (!marker) { + return { baseSessionKey: trimmed, threadId: undefined }; + } + const index = trimmed.lastIndexOf(marker); + if (index < 0) { + return { baseSessionKey: trimmed, threadId: undefined }; + } + const baseSessionKey = trimmed.slice(0, index).trim() || undefined; + const threadId = trimmed.slice(index + marker.length).trim() || undefined; + return { baseSessionKey, threadId }; +} + +vi.mock("./route-reply.runtime.js", () => ({ + isRoutableChannel: () => true, + routeReply: mocks.routeReply, +})); +vi.mock("./route-reply.js", () => ({ + isRoutableChannel: () => true, + routeReply: mocks.routeReply, +})); +vi.mock("./abort.runtime.js", () => ({ + tryFastAbortFromMessage: mocks.tryFastAbortFromMessage, + formatAbortReplyText: () => "⚙️ Agent was aborted.", +})); +vi.mock("../../logging/diagnostic.js", () => ({ + logMessageQueued: diagnosticMocks.logMessageQueued, + logMessageProcessed: diagnosticMocks.logMessageProcessed, + logSessionStateChange: diagnosticMocks.logSessionStateChange, +})); +vi.mock("../../config/sessions/thread-info.js", () => ({ + parseSessionThreadInfo: (sessionKey: string | undefined) => + threadInfoMocks.parseSessionThreadInfo(sessionKey), +})); +vi.mock("./dispatch-from-config.runtime.js", () => ({ + createInternalHookEvent: internalHookMocks.createInternalHookEvent, + loadSessionStore: sessionStoreMocks.loadSessionStore, + resolveSessionStoreEntry: sessionStoreMocks.resolveSessionStoreEntry, + resolveStorePath: sessionStoreMocks.resolveStorePath, + triggerInternalHook: internalHookMocks.triggerInternalHook, +})); +vi.mock("../../plugins/hook-runner-global.js", () => ({ + getGlobalHookRunner: () => hookMocks.runner, + getGlobalPluginRegistry: () => hookMocks.registry, +})); +vi.mock("../../acp/runtime/session-meta.js", () => ({ + listAcpSessionEntries: acpMocks.listAcpSessionEntries, + readAcpSessionEntry: acpMocks.readAcpSessionEntry, + upsertAcpSessionMeta: acpMocks.upsertAcpSessionMeta, +})); +vi.mock("../../acp/runtime/registry.js", () => ({ + getAcpRuntimeBackend: acpMocks.getAcpRuntimeBackend, + requireAcpRuntimeBackend: acpMocks.requireAcpRuntimeBackend, +})); +vi.mock("../../infra/outbound/session-binding-service.js", () => ({ + getSessionBindingService: () => ({ + bind: vi.fn(async () => { + throw new Error("bind not mocked"); + }), + getCapabilities: vi.fn(() => ({ + adapterAvailable: true, + bindSupported: true, + unbindSupported: true, + placements: ["current", "child"] as const, + })), + listBySession: (targetSessionKey: string) => + sessionBindingMocks.listBySession(targetSessionKey), + resolveByConversation: sessionBindingMocks.resolveByConversation, + touch: sessionBindingMocks.touch, + unbind: vi.fn(async () => []), + }), +})); +vi.mock("../../infra/agent-events.js", () => ({ + emitAgentEvent: (params: unknown) => agentEventMocks.emitAgentEvent(params), + onAgentEvent: (listener: unknown) => agentEventMocks.onAgentEvent(listener), +})); +vi.mock("../../plugins/conversation-binding.js", () => ({ + buildPluginBindingDeclinedText: () => "Plugin binding request was declined.", + buildPluginBindingErrorText: () => "Plugin binding request failed.", + buildPluginBindingUnavailableText: (binding: { pluginName?: string; pluginId: string }) => + `${binding.pluginName ?? binding.pluginId} is not currently loaded.`, + hasShownPluginBindingFallbackNotice: (bindingId: string) => + pluginConversationBindingMocks.shownFallbackNoticeBindingIds.has(bindingId), + isPluginOwnedSessionBindingRecord: ( + record: SessionBindingRecord | null | undefined, + ): record is SessionBindingRecord => + record?.metadata != null && + typeof record.metadata === "object" && + (record.metadata as { pluginBindingOwner?: string }).pluginBindingOwner === "plugin", + markPluginBindingFallbackNoticeShown: (bindingId: string) => { + pluginConversationBindingMocks.shownFallbackNoticeBindingIds.add(bindingId); + }, + toPluginConversationBinding: (record: SessionBindingRecord) => ({ + bindingId: record.bindingId, + pluginId: "unknown-plugin", + pluginName: undefined, + pluginRoot: "", + channel: record.conversation.channel, + accountId: record.conversation.accountId, + conversationId: record.conversation.conversationId, + parentConversationId: record.conversation.parentConversationId, + }), +})); +vi.mock("./dispatch-acp-manager.runtime.js", () => ({ + getAcpSessionManager: () => acpManagerRuntimeMocks.getAcpSessionManager(), + getSessionBindingService: () => ({ + listBySession: (targetSessionKey: string) => + sessionBindingMocks.listBySession(targetSessionKey), + unbind: vi.fn(async () => []), + }), +})); +vi.mock("../../tts/tts.js", () => ({ + maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), + normalizeTtsAutoMode: (value: unknown) => ttsMocks.normalizeTtsAutoMode(value), + resolveTtsConfig: (cfg: OpenClawConfig) => ttsMocks.resolveTtsConfig(cfg), +})); +vi.mock("../../tts/tts.runtime.js", () => ({ + maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), +})); +vi.mock("../../tts/status-config.js", () => ({ + resolveStatusTtsSnapshot: () => ({ + autoMode: "always", + provider: "auto", + maxLength: 1500, + summarize: true, + }), +})); +vi.mock("./dispatch-acp-tts.runtime.js", () => ({ + maybeApplyTtsToPayload: (params: unknown) => ttsMocks.maybeApplyTtsToPayload(params), +})); +vi.mock("./dispatch-acp-session.runtime.js", () => ({ + readAcpSessionEntry: (params: { sessionKey: string; cfg?: OpenClawConfig }) => + acpMocks.readAcpSessionEntry(params), +})); +vi.mock("../../tts/tts-config.js", () => ({ + normalizeTtsAutoMode: (value: unknown) => ttsMocks.normalizeTtsAutoMode(value), + resolveConfiguredTtsMode: (cfg: OpenClawConfig) => ttsMocks.resolveTtsConfig(cfg).mode, +})); + +export const noAbortResult = { handled: false, aborted: false } as const; +export const emptyConfig = {} as OpenClawConfig; + +export function createDispatcher(): ReplyDispatcher { + const acceptReply = () => true; + const emptyCounts = () => ({ tool: 0, block: 0, final: 0 }); + return { + sendToolResult: vi.fn(acceptReply), + sendBlockReply: vi.fn(acceptReply), + sendFinalReply: vi.fn(acceptReply), + waitForIdle: vi.fn(async () => {}), + getQueuedCounts: vi.fn(emptyCounts), + getFailedCounts: vi.fn(emptyCounts), + markComplete: vi.fn(), + }; +} + +export function resetPluginTtsAndThreadMocks() { + pluginConversationBindingMocks.shownFallbackNoticeBindingIds.clear(); + ttsMocks.maybeApplyTtsToPayload.mockReset().mockImplementation(async (paramsUnknown: unknown) => { + const params = paramsUnknown as { payload: ReplyPayload }; + return params.payload; + }); + ttsMocks.normalizeTtsAutoMode + .mockReset() + .mockImplementation((value: unknown) => (typeof value === "string" ? value : undefined)); + ttsMocks.resolveTtsConfig.mockReset().mockReturnValue({ mode: "final" }); + threadInfoMocks.parseSessionThreadInfo + .mockReset() + .mockImplementation(parseGenericThreadSessionInfo); +} + +export function setDiscordTestRegistry() { + const discordTestPlugin = { + ...createChannelTestPluginBase({ + id: "discord", + capabilities: { chatTypes: ["direct"], nativeCommands: true }, + }), + outbound: { + deliveryMode: "direct", + shouldSuppressLocalPayloadPrompt: () => false, + }, + }; + setActivePluginRegistry( + createTestRegistry([{ pluginId: "discord", source: "test", plugin: discordTestPlugin }]), + ); +} + +export function createHookCtx() { + return buildTestCtx({ + Body: "hello", + BodyForAgent: "hello", + BodyForCommands: "hello", + From: "user1", + Surface: "telegram", + ChatType: "private", + SessionKey: "agent:test:session", + }); +} diff --git a/src/auto-reply/reply/queue.collect.test.ts b/src/auto-reply/reply/queue.collect.test.ts index fc2f0119b8d..dd865fba3a2 100644 --- a/src/auto-reply/reply/queue.collect.test.ts +++ b/src/auto-reply/reply/queue.collect.test.ts @@ -1,60 +1,13 @@ -import { afterAll, beforeAll, describe, expect, it } from "vitest"; -import type { OpenClawConfig } from "../../config/config.js"; -import { defaultRuntime } from "../../runtime.js"; +import { describe, expect, it } from "vitest"; import type { FollowupRun, QueueSettings } from "./queue.js"; import { enqueueFollowupRun, scheduleFollowupDrain } from "./queue.js"; +import { + createDeferred, + createQueueTestRun as createRun, + installQueueRuntimeErrorSilencer, +} from "./queue.test-helpers.js"; -function createDeferred() { - let resolve!: (value: T) => void; - let reject!: (reason?: unknown) => void; - const promise = new Promise((res, rej) => { - resolve = res; - reject = rej; - }); - return { promise, resolve, reject }; -} - -function createRun(params: { - prompt: string; - messageId?: string; - originatingChannel?: FollowupRun["originatingChannel"]; - originatingTo?: string; - originatingAccountId?: string; - originatingThreadId?: string | number; -}): FollowupRun { - return { - prompt: params.prompt, - messageId: params.messageId, - enqueuedAt: Date.now(), - originatingChannel: params.originatingChannel, - originatingTo: params.originatingTo, - originatingAccountId: params.originatingAccountId, - originatingThreadId: params.originatingThreadId, - run: { - agentId: "agent", - agentDir: "/tmp", - sessionId: "sess", - sessionFile: "/tmp/session.json", - workspaceDir: "/tmp", - config: {} as OpenClawConfig, - provider: "openai", - model: "gpt-test", - timeoutMs: 10_000, - blockReplyBreak: "text_end", - }, - }; -} - -let previousRuntimeError: typeof defaultRuntime.error; - -beforeAll(() => { - previousRuntimeError = defaultRuntime.error; - defaultRuntime.error = (() => {}) as typeof defaultRuntime.error; -}); - -afterAll(() => { - defaultRuntime.error = previousRuntimeError; -}); +installQueueRuntimeErrorSilencer(); describe("followup queue collect routing", () => { it("does not collect when destinations differ", async () => { diff --git a/src/auto-reply/reply/queue.dedupe.test.ts b/src/auto-reply/reply/queue.dedupe.test.ts index 5e5c7a8799f..5148bcd3c5e 100644 --- a/src/auto-reply/reply/queue.dedupe.test.ts +++ b/src/auto-reply/reply/queue.dedupe.test.ts @@ -1,65 +1,18 @@ -import { afterAll, beforeAll, beforeEach, describe, expect, it } from "vitest"; +import { beforeEach, describe, expect, it } from "vitest"; import { importFreshModule } from "../../../test/helpers/import-fresh.js"; -import type { OpenClawConfig } from "../../config/config.js"; -import { defaultRuntime } from "../../runtime.js"; import type { FollowupRun, QueueSettings } from "./queue.js"; import { enqueueFollowupRun, resetRecentQueuedMessageIdDedupe, scheduleFollowupDrain, } from "./queue.js"; +import { + createDeferred, + createQueueTestRun as createRun, + installQueueRuntimeErrorSilencer, +} from "./queue.test-helpers.js"; -function createDeferred() { - let resolve!: (value: T) => void; - let reject!: (reason?: unknown) => void; - const promise = new Promise((res, rej) => { - resolve = res; - reject = rej; - }); - return { promise, resolve, reject }; -} - -function createRun(params: { - prompt: string; - messageId?: string; - originatingChannel?: FollowupRun["originatingChannel"]; - originatingTo?: string; - originatingAccountId?: string; - originatingThreadId?: string | number; -}): FollowupRun { - return { - prompt: params.prompt, - messageId: params.messageId, - enqueuedAt: Date.now(), - originatingChannel: params.originatingChannel, - originatingTo: params.originatingTo, - originatingAccountId: params.originatingAccountId, - originatingThreadId: params.originatingThreadId, - run: { - agentId: "agent", - agentDir: "/tmp", - sessionId: "sess", - sessionFile: "/tmp/session.json", - workspaceDir: "/tmp", - config: {} as OpenClawConfig, - provider: "openai", - model: "gpt-test", - timeoutMs: 10_000, - blockReplyBreak: "text_end", - }, - }; -} - -let previousRuntimeError: typeof defaultRuntime.error; - -beforeAll(() => { - previousRuntimeError = defaultRuntime.error; - defaultRuntime.error = (() => {}) as typeof defaultRuntime.error; -}); - -afterAll(() => { - defaultRuntime.error = previousRuntimeError; -}); +installQueueRuntimeErrorSilencer(); describe("followup queue deduplication", () => { beforeEach(() => { diff --git a/src/auto-reply/reply/queue.drain-restart.test.ts b/src/auto-reply/reply/queue.drain-restart.test.ts index f9748f874e6..a06191c5d49 100644 --- a/src/auto-reply/reply/queue.drain-restart.test.ts +++ b/src/auto-reply/reply/queue.drain-restart.test.ts @@ -1,61 +1,14 @@ -import { afterAll, beforeAll, describe, expect, it, vi } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import { importFreshModule } from "../../../test/helpers/import-fresh.js"; -import type { OpenClawConfig } from "../../config/config.js"; -import { defaultRuntime } from "../../runtime.js"; import type { FollowupRun, QueueSettings } from "./queue.js"; import { enqueueFollowupRun, scheduleFollowupDrain } from "./queue.js"; +import { + createDeferred, + createQueueTestRun as createRun, + installQueueRuntimeErrorSilencer, +} from "./queue.test-helpers.js"; -function createDeferred() { - let resolve!: (value: T) => void; - let reject!: (reason?: unknown) => void; - const promise = new Promise((res, rej) => { - resolve = res; - reject = rej; - }); - return { promise, resolve, reject }; -} - -function createRun(params: { - prompt: string; - messageId?: string; - originatingChannel?: FollowupRun["originatingChannel"]; - originatingTo?: string; - originatingAccountId?: string; - originatingThreadId?: string | number; -}): FollowupRun { - return { - prompt: params.prompt, - messageId: params.messageId, - enqueuedAt: Date.now(), - originatingChannel: params.originatingChannel, - originatingTo: params.originatingTo, - originatingAccountId: params.originatingAccountId, - originatingThreadId: params.originatingThreadId, - run: { - agentId: "agent", - agentDir: "/tmp", - sessionId: "sess", - sessionFile: "/tmp/session.json", - workspaceDir: "/tmp", - config: {} as OpenClawConfig, - provider: "openai", - model: "gpt-test", - timeoutMs: 10_000, - blockReplyBreak: "text_end", - }, - }; -} - -let previousRuntimeError: typeof defaultRuntime.error; - -beforeAll(() => { - previousRuntimeError = defaultRuntime.error; - defaultRuntime.error = (() => {}) as typeof defaultRuntime.error; -}); - -afterAll(() => { - defaultRuntime.error = previousRuntimeError; -}); +installQueueRuntimeErrorSilencer(); describe("followup queue drain restart after idle window", () => { it("does not retain stale callbacks when scheduleFollowupDrain runs with an empty queue", async () => { diff --git a/src/auto-reply/reply/queue.test-helpers.ts b/src/auto-reply/reply/queue.test-helpers.ts new file mode 100644 index 00000000000..8a24f21a772 --- /dev/null +++ b/src/auto-reply/reply/queue.test-helpers.ts @@ -0,0 +1,58 @@ +import { afterAll, beforeAll } from "vitest"; +import type { OpenClawConfig } from "../../config/config.js"; +import { defaultRuntime } from "../../runtime.js"; +import type { FollowupRun } from "./queue.js"; + +export function createDeferred() { + let resolve!: (value: T) => void; + let reject!: (reason?: unknown) => void; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { promise, resolve, reject }; +} + +export function createQueueTestRun(params: { + prompt: string; + messageId?: string; + originatingChannel?: FollowupRun["originatingChannel"]; + originatingTo?: string; + originatingAccountId?: string; + originatingThreadId?: string | number; +}): FollowupRun { + return { + prompt: params.prompt, + messageId: params.messageId, + enqueuedAt: Date.now(), + originatingChannel: params.originatingChannel, + originatingTo: params.originatingTo, + originatingAccountId: params.originatingAccountId, + originatingThreadId: params.originatingThreadId, + run: { + agentId: "agent", + agentDir: "/tmp", + sessionId: "sess", + sessionFile: "/tmp/session.json", + workspaceDir: "/tmp", + config: {} as OpenClawConfig, + provider: "openai", + model: "gpt-test", + timeoutMs: 10_000, + blockReplyBreak: "text_end", + }, + }; +} + +export function installQueueRuntimeErrorSilencer(): void { + let previousRuntimeError: typeof defaultRuntime.error; + + beforeAll(() => { + previousRuntimeError = defaultRuntime.error; + defaultRuntime.error = (() => {}) as typeof defaultRuntime.error; + }); + + afterAll(() => { + defaultRuntime.error = previousRuntimeError; + }); +} diff --git a/src/cli/plugins-cli.install.test.ts b/src/cli/plugins-cli.install.test.ts index f3ddea5c9a1..6a9c552aad4 100644 --- a/src/cli/plugins-cli.install.test.ts +++ b/src/cli/plugins-cli.install.test.ts @@ -45,6 +45,14 @@ function createEnabledPluginConfig(pluginId: string): OpenClawConfig { } as OpenClawConfig; } +function createEmptyPluginConfig(): OpenClawConfig { + return { + plugins: { + entries: {}, + }, + } as OpenClawConfig; +} + function createClawHubInstalledConfig(params: { pluginId: string; install: Record; @@ -86,6 +94,134 @@ function createClawHubInstallResult(params: { }; } +function createNpmPluginInstallResult( + pluginId = "demo", +): Awaited> { + return { + ok: true, + pluginId, + targetDir: cliInstallPath(pluginId), + version: "1.2.3", + npmResolution: { + packageName: pluginId, + resolvedVersion: "1.2.3", + tarballUrl: `https://registry.npmjs.org/${pluginId}/-/${pluginId}-1.2.3.tgz`, + }, + }; +} + +function mockClawHubPackageNotFound(packageName: string) { + installPluginFromClawHub.mockResolvedValue({ + ok: false, + error: `ClawHub /api/v1/packages/${packageName} failed (404): Package not found`, + code: "package_not_found", + }); +} + +function primeNpmPluginFallback(pluginId = "demo") { + const cfg = createEmptyPluginConfig(); + const enabledCfg = createEnabledPluginConfig(pluginId); + + loadConfig.mockReturnValue(cfg); + mockClawHubPackageNotFound(pluginId); + installPluginFromNpmSpec.mockResolvedValue(createNpmPluginInstallResult(pluginId)); + enablePluginInConfig.mockReturnValue({ config: enabledCfg }); + recordPluginInstall.mockReturnValue(enabledCfg); + applyExclusiveSlotSelection.mockReturnValue({ + config: enabledCfg, + warnings: [], + }); + + return { cfg, enabledCfg }; +} + +function createPathHookPackInstalledConfig(tmpRoot: string): OpenClawConfig { + return { + hooks: { + internal: { + installs: { + "demo-hooks": { + source: "path", + sourcePath: tmpRoot, + installPath: tmpRoot, + }, + }, + }, + }, + } as OpenClawConfig; +} + +function createNpmHookPackInstalledConfig(): OpenClawConfig { + return { + hooks: { + internal: { + installs: { + "demo-hooks": { + source: "npm", + spec: "@acme/demo-hooks@1.2.3", + }, + }, + }, + }, + } as OpenClawConfig; +} + +function createHookPackInstallResult(targetDir: string): { + ok: true; + hookPackId: string; + hooks: string[]; + targetDir: string; + version: string; +} { + return { + ok: true, + hookPackId: "demo-hooks", + hooks: ["command-audit"], + targetDir, + version: "1.2.3", + }; +} + +function primeHookPackNpmFallback() { + const cfg = {} as OpenClawConfig; + const installedCfg = createNpmHookPackInstalledConfig(); + + loadConfig.mockReturnValue(cfg); + mockClawHubPackageNotFound("@acme/demo-hooks"); + installPluginFromNpmSpec.mockResolvedValue({ + ok: false, + error: "package.json missing openclaw.plugin.json", + }); + installHooksFromNpmSpec.mockResolvedValue({ + ...createHookPackInstallResult("/tmp/hooks/demo-hooks"), + npmResolution: { + name: "@acme/demo-hooks", + spec: "@acme/demo-hooks@1.2.3", + integrity: "sha256-demo", + }, + }); + recordHookInstall.mockReturnValue(installedCfg); + + return { cfg, installedCfg }; +} + +function primeHookPackPathFallback(params: { + tmpRoot: string; + pluginInstallError: string; +}): OpenClawConfig { + const installedCfg = createPathHookPackInstalledConfig(params.tmpRoot); + + loadConfig.mockReturnValue({} as OpenClawConfig); + installPluginFromPath.mockResolvedValueOnce({ + ok: false, + error: params.pluginInstallError, + }); + installHooksFromPath.mockResolvedValueOnce(createHookPackInstallResult(params.tmpRoot)); + recordHookInstall.mockReturnValue(installedCfg); + + return installedCfg; +} + describe("plugins cli install", () => { beforeEach(() => { resetPluginsCliTestState(); @@ -380,44 +516,7 @@ describe("plugins cli install", () => { }); it("falls back to npm when ClawHub does not have the package", async () => { - const cfg = { - plugins: { - entries: {}, - }, - } as OpenClawConfig; - const enabledCfg = { - plugins: { - entries: { - demo: { - enabled: true, - }, - }, - }, - } as OpenClawConfig; - - loadConfig.mockReturnValue(cfg); - installPluginFromClawHub.mockResolvedValue({ - ok: false, - error: "ClawHub /api/v1/packages/demo failed (404): Package not found", - code: "package_not_found", - }); - installPluginFromNpmSpec.mockResolvedValue({ - ok: true, - pluginId: "demo", - targetDir: cliInstallPath("demo"), - version: "1.2.3", - npmResolution: { - packageName: "demo", - resolvedVersion: "1.2.3", - tarballUrl: "https://registry.npmjs.org/demo/-/demo-1.2.3.tgz", - }, - }); - enablePluginInConfig.mockReturnValue({ config: enabledCfg }); - recordPluginInstall.mockReturnValue(enabledCfg); - applyExclusiveSlotSelection.mockReturnValue({ - config: enabledCfg, - warnings: [], - }); + primeNpmPluginFallback(); await runPluginsCommand(["plugins", "install", "demo"]); @@ -455,36 +554,7 @@ describe("plugins cli install", () => { }); it("passes dangerous force unsafe install to npm installs", async () => { - const cfg = { - plugins: { - entries: {}, - }, - } as OpenClawConfig; - const enabledCfg = createEnabledPluginConfig("demo"); - - loadConfig.mockReturnValue(cfg); - installPluginFromClawHub.mockResolvedValue({ - ok: false, - error: "ClawHub /api/v1/packages/demo failed (404): Package not found", - code: "package_not_found", - }); - installPluginFromNpmSpec.mockResolvedValue({ - ok: true, - pluginId: "demo", - targetDir: cliInstallPath("demo"), - version: "1.2.3", - npmResolution: { - packageName: "demo", - resolvedVersion: "1.2.3", - tarballUrl: "https://registry.npmjs.org/demo/-/demo-1.2.3.tgz", - }, - }); - enablePluginInConfig.mockReturnValue({ config: enabledCfg }); - recordPluginInstall.mockReturnValue(enabledCfg); - applyExclusiveSlotSelection.mockReturnValue({ - config: enabledCfg, - warnings: [], - }); + primeNpmPluginFallback(); await runPluginsCommand(["plugins", "install", "demo", "--dangerously-force-unsafe-install"]); @@ -542,35 +612,11 @@ describe("plugins cli install", () => { }); it("passes dangerous force unsafe install to linked hook-pack probe fallback", async () => { - const cfg = {} as OpenClawConfig; const tmpRoot = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-hook-link-")); - const installedCfg = { - hooks: { - internal: { - installs: { - "demo-hooks": { - source: "path", - sourcePath: tmpRoot, - installPath: tmpRoot, - }, - }, - }, - }, - } as OpenClawConfig; - - loadConfig.mockReturnValue(cfg); - installPluginFromPath.mockResolvedValueOnce({ - ok: false, - error: "plugin install probe failed", + primeHookPackPathFallback({ + tmpRoot, + pluginInstallError: "plugin install probe failed", }); - installHooksFromPath.mockResolvedValueOnce({ - ok: true, - hookPackId: "demo-hooks", - hooks: ["command-audit"], - targetDir: tmpRoot, - version: "1.2.3", - }); - recordHookInstall.mockReturnValue(installedCfg); try { await runPluginsCommand([ @@ -594,35 +640,11 @@ describe("plugins cli install", () => { }); it("passes dangerous force unsafe install to local hook-pack fallback installs", async () => { - const cfg = {} as OpenClawConfig; const tmpRoot = fs.mkdtempSync(path.join(os.tmpdir(), "openclaw-hook-install-")); - const installedCfg = { - hooks: { - internal: { - installs: { - "demo-hooks": { - source: "path", - sourcePath: tmpRoot, - installPath: tmpRoot, - }, - }, - }, - }, - } as OpenClawConfig; - - loadConfig.mockReturnValue(cfg); - installPluginFromPath.mockResolvedValueOnce({ - ok: false, - error: "plugin install failed", + primeHookPackPathFallback({ + tmpRoot, + pluginInstallError: "plugin install failed", }); - installHooksFromPath.mockResolvedValueOnce({ - ok: true, - hookPackId: "demo-hooks", - hooks: ["command-audit"], - targetDir: tmpRoot, - version: "1.2.3", - }); - recordHookInstall.mockReturnValue(installedCfg); try { await runPluginsCommand([ @@ -644,36 +666,7 @@ describe("plugins cli install", () => { ); }); it("passes force through as overwrite mode for npm installs", async () => { - const cfg = { - plugins: { - entries: {}, - }, - } as OpenClawConfig; - const enabledCfg = createEnabledPluginConfig("demo"); - - loadConfig.mockReturnValue(cfg); - installPluginFromClawHub.mockResolvedValue({ - ok: false, - error: "ClawHub /api/v1/packages/demo failed (404): Package not found", - code: "package_not_found", - }); - installPluginFromNpmSpec.mockResolvedValue({ - ok: true, - pluginId: "demo", - targetDir: cliInstallPath("demo"), - version: "1.2.3", - npmResolution: { - packageName: "demo", - resolvedVersion: "1.2.3", - tarballUrl: "https://registry.npmjs.org/demo/-/demo-1.2.3.tgz", - }, - }); - enablePluginInConfig.mockReturnValue({ config: enabledCfg }); - recordPluginInstall.mockReturnValue(enabledCfg); - applyExclusiveSlotSelection.mockReturnValue({ - config: enabledCfg, - warnings: [], - }); + primeNpmPluginFallback(); await runPluginsCommand(["plugins", "install", "demo", "--force"]); @@ -699,43 +692,7 @@ describe("plugins cli install", () => { }); it("falls back to installing hook packs from npm specs", async () => { - const cfg = {} as OpenClawConfig; - const installedCfg = { - hooks: { - internal: { - installs: { - "demo-hooks": { - source: "npm", - spec: "@acme/demo-hooks@1.2.3", - }, - }, - }, - }, - } as OpenClawConfig; - - loadConfig.mockReturnValue(cfg); - installPluginFromClawHub.mockResolvedValue({ - ok: false, - error: "ClawHub /api/v1/packages/@acme/demo-hooks failed (404): Package not found", - code: "package_not_found", - }); - installPluginFromNpmSpec.mockResolvedValue({ - ok: false, - error: "package.json missing openclaw.plugin.json", - }); - installHooksFromNpmSpec.mockResolvedValue({ - ok: true, - hookPackId: "demo-hooks", - hooks: ["command-audit"], - targetDir: "/tmp/hooks/demo-hooks", - version: "1.2.3", - npmResolution: { - name: "@acme/demo-hooks", - spec: "@acme/demo-hooks@1.2.3", - integrity: "sha256-demo", - }, - }); - recordHookInstall.mockReturnValue(installedCfg); + const { installedCfg } = primeHookPackNpmFallback(); await runPluginsCommand(["plugins", "install", "@acme/demo-hooks"]); @@ -756,43 +713,7 @@ describe("plugins cli install", () => { }); it("passes force through as overwrite mode for hook-pack npm fallback installs", async () => { - const cfg = {} as OpenClawConfig; - const installedCfg = { - hooks: { - internal: { - installs: { - "demo-hooks": { - source: "npm", - spec: "@acme/demo-hooks@1.2.3", - }, - }, - }, - }, - } as OpenClawConfig; - - loadConfig.mockReturnValue(cfg); - installPluginFromClawHub.mockResolvedValue({ - ok: false, - error: "ClawHub /api/v1/packages/@acme/demo-hooks failed (404): Package not found", - code: "package_not_found", - }); - installPluginFromNpmSpec.mockResolvedValue({ - ok: false, - error: "package.json missing openclaw.plugin.json", - }); - installHooksFromNpmSpec.mockResolvedValue({ - ok: true, - hookPackId: "demo-hooks", - hooks: ["command-audit"], - targetDir: "/tmp/hooks/demo-hooks", - version: "1.2.3", - npmResolution: { - name: "@acme/demo-hooks", - spec: "@acme/demo-hooks@1.2.3", - integrity: "sha256-demo", - }, - }); - recordHookInstall.mockReturnValue(installedCfg); + primeHookPackNpmFallback(); await runPluginsCommand(["plugins", "install", "@acme/demo-hooks", "--force"]); diff --git a/src/commands/onboard-channels.post-write.test.ts b/src/commands/onboard-channels.post-write.test.ts index 2bc7348eec0..67c65891c34 100644 --- a/src/commands/onboard-channels.post-write.test.ts +++ b/src/commands/onboard-channels.post-write.test.ts @@ -3,8 +3,7 @@ import type { OpenClawConfig } from "../config/config.js"; import { setActivePluginRegistry } from "../plugins/runtime.js"; import { createChannelTestPluginBase, createTestRegistry } from "../test-utils/channel-plugins.js"; import type { WizardPrompter } from "../wizard/prompts.js"; -import { getChannelSetupWizardAdapter } from "./channel-setup/registry.js"; -import type { ChannelSetupWizardAdapter } from "./channel-setup/types.js"; +import { patchChannelSetupWizardAdapter } from "./channel-test-helpers.js"; import { createChannelOnboardingPostWriteHookCollector, runCollectedChannelOnboardingPostWriteHooks, @@ -43,73 +42,6 @@ function setMinimalTelegramOnboardingRegistryForTests(): void { ); } -type ChannelSetupWizardAdapterPatch = Partial< - Pick< - ChannelSetupWizardAdapter, - | "afterConfigWritten" - | "configure" - | "configureInteractive" - | "configureWhenConfigured" - | "getStatus" - > ->; - -type PatchedSetupAdapterFields = { - afterConfigWritten?: ChannelSetupWizardAdapter["afterConfigWritten"]; - configure?: ChannelSetupWizardAdapter["configure"]; - configureInteractive?: ChannelSetupWizardAdapter["configureInteractive"]; - configureWhenConfigured?: ChannelSetupWizardAdapter["configureWhenConfigured"]; - getStatus?: ChannelSetupWizardAdapter["getStatus"]; -}; - -function patchChannelOnboardingAdapterForTest(patch: ChannelSetupWizardAdapterPatch): () => void { - const adapter = getChannelSetupWizardAdapter("telegram"); - if (!adapter) { - throw new Error("missing setup adapter for telegram"); - } - - const previous: PatchedSetupAdapterFields = {}; - - if (Object.prototype.hasOwnProperty.call(patch, "getStatus")) { - previous.getStatus = adapter.getStatus; - adapter.getStatus = patch.getStatus ?? adapter.getStatus; - } - if (Object.prototype.hasOwnProperty.call(patch, "afterConfigWritten")) { - previous.afterConfigWritten = adapter.afterConfigWritten; - adapter.afterConfigWritten = patch.afterConfigWritten; - } - if (Object.prototype.hasOwnProperty.call(patch, "configure")) { - previous.configure = adapter.configure; - adapter.configure = patch.configure ?? adapter.configure; - } - if (Object.prototype.hasOwnProperty.call(patch, "configureInteractive")) { - previous.configureInteractive = adapter.configureInteractive; - adapter.configureInteractive = patch.configureInteractive; - } - if (Object.prototype.hasOwnProperty.call(patch, "configureWhenConfigured")) { - previous.configureWhenConfigured = adapter.configureWhenConfigured; - adapter.configureWhenConfigured = patch.configureWhenConfigured; - } - - return () => { - if (Object.prototype.hasOwnProperty.call(patch, "getStatus")) { - adapter.getStatus = previous.getStatus!; - } - if (Object.prototype.hasOwnProperty.call(patch, "afterConfigWritten")) { - adapter.afterConfigWritten = previous.afterConfigWritten; - } - if (Object.prototype.hasOwnProperty.call(patch, "configure")) { - adapter.configure = previous.configure!; - } - if (Object.prototype.hasOwnProperty.call(patch, "configureInteractive")) { - adapter.configureInteractive = previous.configureInteractive; - } - if (Object.prototype.hasOwnProperty.call(patch, "configureWhenConfigured")) { - adapter.configureWhenConfigured = previous.configureWhenConfigured; - } - }; -} - function createPrompter(overrides: Partial): WizardPrompter { return createWizardPrompter( { @@ -159,7 +91,7 @@ describe("setupChannels post-write hooks", () => { } as OpenClawConfig, accountId: "acct-1", })); - const restore = patchChannelOnboardingAdapterForTest({ + const restore = patchChannelSetupWizardAdapter("telegram", { configureInteractive, afterConfigWritten, getStatus: vi.fn(async ({ cfg }: { cfg: OpenClawConfig }) => ({ diff --git a/src/commands/status-overview-surface.test.ts b/src/commands/status-overview-surface.test.ts index 66a8d064257..2324ee4c1a1 100644 --- a/src/commands/status-overview-surface.test.ts +++ b/src/commands/status-overview-surface.test.ts @@ -6,158 +6,85 @@ import { buildStatusOverviewSurfaceFromScan, } from "./status-overview-surface.ts"; +const baseCfg = { update: { channel: "stable" }, gateway: { bind: "loopback" } } as const; +const baseUpdate = { installKind: "git", git: { branch: "main", tag: "v1.2.3" } } as never; +const baseGatewaySnapshot = { + gatewayMode: "remote", + remoteUrlMissing: false, + gatewayConnection: { + url: "wss://gateway.example.com", + urlSource: "config", + message: "Gateway target: wss://gateway.example.com", + }, + gatewayReachable: true, + gatewayProbe: { connectLatencyMs: 42, error: null } as never, + gatewayProbeAuth: { token: "tok" }, + gatewayProbeAuthWarning: "warn-text", + gatewaySelf: { host: "gateway", version: "1.2.3" }, +} as const; +const baseScanFields = { + cfg: baseCfg, + update: baseUpdate, + tailscaleMode: "serve", + tailscaleDns: "box.tail.ts.net", + tailscaleHttpsUrl: "https://box.tail.ts.net", + ...baseGatewaySnapshot, +}; +const baseGatewayService = { + label: "LaunchAgent", + installed: true, + managedByOpenClaw: true, + loadedText: "loaded", + runtimeShort: "running", +}; +const baseNodeService = { + label: "node", + installed: true, + loadedText: "loaded", + runtime: { status: "running", pid: 42 }, +}; +const baseServices = { + gatewayService: baseGatewayService, + nodeService: baseNodeService, + nodeOnlyGateway: null, +}; +const baseOverviewSurface = { + ...baseScanFields, + ...baseServices, +}; + describe("status-overview-surface", () => { it("builds the shared overview surface from a status scan result", () => { expect( buildStatusOverviewSurfaceFromScan({ - scan: { - cfg: { update: { channel: "stable" }, gateway: { bind: "loopback" } }, - update: { installKind: "git", git: { branch: "main", tag: "v1.2.3" } } as never, - tailscaleMode: "serve", - tailscaleDns: "box.tail.ts.net", - tailscaleHttpsUrl: "https://box.tail.ts.net", - gatewayMode: "remote", - remoteUrlMissing: false, - gatewayConnection: { - url: "wss://gateway.example.com", - urlSource: "config", - message: "Gateway target: wss://gateway.example.com", - }, - gatewayReachable: true, - gatewayProbe: { connectLatencyMs: 42, error: null } as never, - gatewayProbeAuth: { token: "tok" }, - gatewayProbeAuthWarning: "warn-text", - gatewaySelf: { host: "gateway", version: "1.2.3" }, - }, - gatewayService: { - label: "LaunchAgent", - installed: true, - managedByOpenClaw: true, - loadedText: "loaded", - runtimeShort: "running", - }, - nodeService: { - label: "node", - installed: true, - loadedText: "loaded", - runtime: { status: "running", pid: 42 }, - }, - nodeOnlyGateway: null, + scan: baseScanFields, + ...baseServices, }), - ).toEqual({ - cfg: { update: { channel: "stable" }, gateway: { bind: "loopback" } }, - update: { installKind: "git", git: { branch: "main", tag: "v1.2.3" } }, - tailscaleMode: "serve", - tailscaleDns: "box.tail.ts.net", - tailscaleHttpsUrl: "https://box.tail.ts.net", - gatewayMode: "remote", - remoteUrlMissing: false, - gatewayConnection: { - url: "wss://gateway.example.com", - urlSource: "config", - message: "Gateway target: wss://gateway.example.com", - }, - gatewayReachable: true, - gatewayProbe: { connectLatencyMs: 42, error: null } as never, - gatewayProbeAuth: { token: "tok" }, - gatewayProbeAuthWarning: "warn-text", - gatewaySelf: { host: "gateway", version: "1.2.3" }, - gatewayService: { - label: "LaunchAgent", - installed: true, - managedByOpenClaw: true, - loadedText: "loaded", - runtimeShort: "running", - }, - nodeService: { - label: "node", - installed: true, - loadedText: "loaded", - runtime: { status: "running", pid: 42 }, - }, - nodeOnlyGateway: null, - }); + ).toEqual(baseOverviewSurface); }); it("builds the shared overview surface from scan overview data", () => { expect( buildStatusOverviewSurfaceFromOverview({ overview: { - cfg: { update: { channel: "stable" }, gateway: { bind: "loopback" } }, - update: { installKind: "git", git: { branch: "main", tag: "v1.2.3" } } as never, + cfg: baseCfg, + update: baseUpdate, tailscaleMode: "serve", tailscaleDns: "box.tail.ts.net", tailscaleHttpsUrl: "https://box.tail.ts.net", - gatewaySnapshot: { - gatewayMode: "remote", - remoteUrlMissing: false, - gatewayConnection: { - url: "wss://gateway.example.com", - urlSource: "config", - message: "Gateway target: wss://gateway.example.com", - }, - gatewayReachable: true, - gatewayProbe: { connectLatencyMs: 42, error: null } as never, - gatewayProbeAuth: { token: "tok" }, - gatewayProbeAuthWarning: "warn-text", - gatewaySelf: { host: "gateway", version: "1.2.3" }, - }, + gatewaySnapshot: baseGatewaySnapshot, } as never, - gatewayService: { - label: "LaunchAgent", - installed: true, - managedByOpenClaw: true, - loadedText: "loaded", - runtimeShort: "running", - }, - nodeService: { - label: "node", - installed: true, - loadedText: "loaded", - runtime: { status: "running", pid: 42 }, - }, - nodeOnlyGateway: null, + ...baseServices, }), - ).toEqual({ - cfg: { update: { channel: "stable" }, gateway: { bind: "loopback" } }, - update: { installKind: "git", git: { branch: "main", tag: "v1.2.3" } }, - tailscaleMode: "serve", - tailscaleDns: "box.tail.ts.net", - tailscaleHttpsUrl: "https://box.tail.ts.net", - gatewayMode: "remote", - remoteUrlMissing: false, - gatewayConnection: { - url: "wss://gateway.example.com", - urlSource: "config", - message: "Gateway target: wss://gateway.example.com", - }, - gatewayReachable: true, - gatewayProbe: { connectLatencyMs: 42, error: null } as never, - gatewayProbeAuth: { token: "tok" }, - gatewayProbeAuthWarning: "warn-text", - gatewaySelf: { host: "gateway", version: "1.2.3" }, - gatewayService: { - label: "LaunchAgent", - installed: true, - managedByOpenClaw: true, - loadedText: "loaded", - runtimeShort: "running", - }, - nodeService: { - label: "node", - installed: true, - loadedText: "loaded", - runtime: { status: "running", pid: 42 }, - }, - nodeOnlyGateway: null, - }); + ).toEqual(baseOverviewSurface); }); it("builds overview rows from the shared surface bundle", () => { expect( buildStatusOverviewRowsFromSurface({ surface: { - cfg: { update: { channel: "stable" }, gateway: { bind: "loopback" } }, + ...baseOverviewSurface, + cfg: baseCfg, update: { installKind: "git", git: { @@ -172,33 +99,11 @@ describe("status-overview-surface", () => { registry: { latestVersion: "2026.4.9" }, } as never, tailscaleMode: "off", - tailscaleDns: "box.tail.ts.net", tailscaleHttpsUrl: null, - gatewayMode: "remote", - remoteUrlMissing: false, gatewayConnection: { url: "wss://gateway.example.com", urlSource: "config", }, - gatewayReachable: true, - gatewayProbe: { connectLatencyMs: 42, error: null } as never, - gatewayProbeAuth: { token: "tok" }, - gatewayProbeAuthWarning: "warn-text", - gatewaySelf: { host: "gateway", version: "1.2.3" }, - gatewayService: { - label: "LaunchAgent", - installed: true, - managedByOpenClaw: true, - loadedText: "loaded", - runtimeShort: "running", - }, - nodeService: { - label: "node", - installed: true, - loadedText: "loaded", - runtime: { status: "running", pid: 42 }, - }, - nodeOnlyGateway: null, }, prefixRows: [{ Item: "OS", Value: "macOS · node 22" }], suffixRows: [{ Item: "Secrets", Value: "none" }], diff --git a/src/commands/status.scan-result.test.ts b/src/commands/status.scan-result.test.ts index 782c29362bd..67d19429456 100644 --- a/src/commands/status.scan-result.test.ts +++ b/src/commands/status.scan-result.test.ts @@ -4,111 +4,25 @@ import { buildColdStartStatusSummary } from "./status.scan.bootstrap-shared.ts"; describe("buildStatusScanResult", () => { it("builds the full shared scan result shape", () => { - expect( - buildStatusScanResult({ - cfg: { gateway: {} }, - sourceConfig: { gateway: {} }, - secretDiagnostics: ["diag"], - osSummary: { - platform: "linux", - arch: "x64", - release: "6.8.0", - label: "linux 6.8.0 (x64)", - }, - tailscaleMode: "serve", - tailscaleDns: "box.tail.ts.net", - tailscaleHttpsUrl: "https://box.tail.ts.net", - update: { - root: "/tmp/openclaw", - installKind: "package", - packageManager: "npm", - }, - gatewaySnapshot: { - gatewayConnection: { - url: "ws://127.0.0.1:18789", - urlSource: "config", - message: "Gateway target: ws://127.0.0.1:18789", - }, - remoteUrlMissing: false, - gatewayMode: "local", - gatewayProbeAuth: { token: "tok" }, - gatewayProbeAuthWarning: "warn", - gatewayProbe: { - ok: true, - url: "ws://127.0.0.1:18789", - connectLatencyMs: 42, - error: null, - close: null, - health: null, - status: null, - presence: null, - configSnapshot: null, - }, - gatewayReachable: true, - gatewaySelf: { host: "gateway" }, - }, - channelIssues: [ - { - channel: "discord", - accountId: "default", - kind: "runtime", - message: "warn", - }, - ], - agentStatus: { - defaultId: "main", - totalSessions: 0, - bootstrapPendingCount: 0, - agents: [ - { - id: "main", - workspaceDir: null, - bootstrapPending: false, - sessionsPath: "/tmp/main.json", - sessionsCount: 0, - lastUpdatedAt: null, - lastActiveAgeMs: null, - }, - ], - }, - channels: { rows: [], details: [] }, - summary: buildColdStartStatusSummary(), - memory: { agentId: "main", backend: "builtin", provider: "sqlite" }, - memoryPlugin: { enabled: true, slot: "memory-core" }, - pluginCompatibility: [ - { - pluginId: "legacy", - code: "legacy-before-agent-start", - severity: "warn", - message: "warn", - }, - ], - }), - ).toEqual({ - cfg: { gateway: {} }, - sourceConfig: { gateway: {} }, - secretDiagnostics: ["diag"], - osSummary: { - platform: "linux", - arch: "x64", - release: "6.8.0", - label: "linux 6.8.0 (x64)", - }, - tailscaleMode: "serve", - tailscaleDns: "box.tail.ts.net", - tailscaleHttpsUrl: "https://box.tail.ts.net", - update: { - root: "/tmp/openclaw", - installKind: "package", - packageManager: "npm", - }, + const osSummary = { + platform: "linux" as const, + arch: "x64", + release: "6.8.0", + label: "linux 6.8.0 (x64)", + }; + const update = { + root: "/tmp/openclaw", + installKind: "package" as const, + packageManager: "npm" as const, + }; + const gatewaySnapshot = { gatewayConnection: { url: "ws://127.0.0.1:18789", - urlSource: "config", + urlSource: "config" as const, message: "Gateway target: ws://127.0.0.1:18789", }, remoteUrlMissing: false, - gatewayMode: "local", + gatewayMode: "local" as const, gatewayProbeAuth: { token: "tok" }, gatewayProbeAuthWarning: "warn", gatewayProbe: { @@ -124,42 +38,87 @@ describe("buildStatusScanResult", () => { }, gatewayReachable: true, gatewaySelf: { host: "gateway" }, - channelIssues: [ - { - channel: "discord", - accountId: "default", - kind: "runtime", - message: "warn", - }, - ], - agentStatus: { - defaultId: "main", - totalSessions: 0, - bootstrapPendingCount: 0, - agents: [ - { - id: "main", - workspaceDir: null, - bootstrapPending: false, - sessionsPath: "/tmp/main.json", - sessionsCount: 0, - lastUpdatedAt: null, - lastActiveAgeMs: null, - }, - ], + }; + const channelIssues = [ + { + channel: "discord", + accountId: "default", + kind: "runtime" as const, + message: "warn", }, - channels: { rows: [], details: [] }, - summary: buildColdStartStatusSummary(), - memory: { agentId: "main", backend: "builtin", provider: "sqlite" }, - memoryPlugin: { enabled: true, slot: "memory-core" }, - pluginCompatibility: [ + ]; + const agentStatus = { + defaultId: "main", + totalSessions: 0, + bootstrapPendingCount: 0, + agents: [ { - pluginId: "legacy", - code: "legacy-before-agent-start", - severity: "warn", - message: "warn", + id: "main", + workspaceDir: null, + bootstrapPending: false, + sessionsPath: "/tmp/main.json", + sessionsCount: 0, + lastUpdatedAt: null, + lastActiveAgeMs: null, }, ], + }; + const channels = { rows: [], details: [] }; + const summary = buildColdStartStatusSummary(); + const memory = { agentId: "main", backend: "builtin" as const, provider: "sqlite" }; + const memoryPlugin = { enabled: true, slot: "memory-core" }; + const pluginCompatibility = [ + { + pluginId: "legacy", + code: "legacy-before-agent-start" as const, + severity: "warn" as const, + message: "warn", + }, + ]; + + expect( + buildStatusScanResult({ + cfg: { gateway: {} }, + sourceConfig: { gateway: {} }, + secretDiagnostics: ["diag"], + osSummary, + tailscaleMode: "serve", + tailscaleDns: "box.tail.ts.net", + tailscaleHttpsUrl: "https://box.tail.ts.net", + update, + gatewaySnapshot, + channelIssues, + agentStatus, + channels, + summary, + memory, + memoryPlugin, + pluginCompatibility, + }), + ).toEqual({ + cfg: { gateway: {} }, + sourceConfig: { gateway: {} }, + secretDiagnostics: ["diag"], + osSummary, + tailscaleMode: "serve", + tailscaleDns: "box.tail.ts.net", + tailscaleHttpsUrl: "https://box.tail.ts.net", + update, + gatewayConnection: gatewaySnapshot.gatewayConnection, + remoteUrlMissing: gatewaySnapshot.remoteUrlMissing, + gatewayMode: gatewaySnapshot.gatewayMode, + gatewayProbeAuth: gatewaySnapshot.gatewayProbeAuth, + gatewayProbeAuthWarning: gatewaySnapshot.gatewayProbeAuthWarning, + gatewayProbe: gatewaySnapshot.gatewayProbe, + gatewayReachable: gatewaySnapshot.gatewayReachable, + gatewaySelf: gatewaySnapshot.gatewaySelf, + channelIssues, + agentStatus, + channels, + summary, + memory, + memoryPlugin, + pluginCompatibility, }); }); }); diff --git a/src/config/io.audit.test.ts b/src/config/io.audit.test.ts index 24ae28224d0..3f6312afb69 100644 --- a/src/config/io.audit.test.ts +++ b/src/config/io.audit.test.ts @@ -10,6 +10,53 @@ import { resolveConfigAuditLogPath, } from "./io.audit.js"; +function createRenameAuditRecord(home: string) { + return finalizeConfigWriteAuditRecord({ + base: createConfigWriteAuditRecordBase({ + configPath: path.join(home, ".openclaw", "openclaw.json"), + env: {} as NodeJS.ProcessEnv, + existsBefore: true, + previousHash: "prev-hash", + nextHash: "next-hash", + previousBytes: 12, + nextBytes: 24, + previousMetadata: { + dev: "10", + ino: "11", + mode: 0o600, + nlink: 1, + uid: 501, + gid: 20, + }, + changedPathCount: 1, + hasMetaBefore: true, + hasMetaAfter: true, + gatewayModeBefore: "local", + gatewayModeAfter: "local", + suspicious: [], + now: "2026-04-07T08:00:00.000Z", + }), + result: "rename", + nextMetadata: { + dev: "12", + ino: "13", + mode: 0o600, + nlink: 1, + uid: 501, + gid: 20, + }, + }); +} + +function readAuditLog(home: string): unknown[] { + const auditPath = path.join(home, ".openclaw", "logs", "config-audit.jsonl"); + return fs + .readFileSync(auditPath, "utf-8") + .trim() + .split("\n") + .map((line) => JSON.parse(line)); +} + describe("config io audit helpers", () => { const suiteRootTracker = createSuiteTempRootTracker({ prefix: "openclaw-config-audit-" }); @@ -149,41 +196,7 @@ describe("config io audit helpers", () => { it("appends JSONL audit entries to the resolved audit path", async () => { const home = await suiteRootTracker.make("append"); - const record = finalizeConfigWriteAuditRecord({ - base: createConfigWriteAuditRecordBase({ - configPath: path.join(home, ".openclaw", "openclaw.json"), - env: {} as NodeJS.ProcessEnv, - existsBefore: true, - previousHash: "prev-hash", - nextHash: "next-hash", - previousBytes: 12, - nextBytes: 24, - previousMetadata: { - dev: "10", - ino: "11", - mode: 0o600, - nlink: 1, - uid: 501, - gid: 20, - }, - changedPathCount: 1, - hasMetaBefore: true, - hasMetaAfter: true, - gatewayModeBefore: "local", - gatewayModeAfter: "local", - suspicious: [], - now: "2026-04-07T08:00:00.000Z", - }), - result: "rename", - nextMetadata: { - dev: "12", - ino: "13", - mode: 0o600, - nlink: 1, - uid: 501, - gid: 20, - }, - }); + const record = createRenameAuditRecord(home); await appendConfigAuditRecord({ fs, @@ -192,10 +205,9 @@ describe("config io audit helpers", () => { record, }); - const auditPath = path.join(home, ".openclaw", "logs", "config-audit.jsonl"); - const lines = fs.readFileSync(auditPath, "utf-8").trim().split("\n"); - expect(lines).toHaveLength(1); - expect(JSON.parse(lines[0])).toMatchObject({ + const records = readAuditLog(home); + expect(records).toHaveLength(1); + expect(records[0]).toMatchObject({ event: "config.write", result: "rename", nextHash: "next-hash", @@ -204,41 +216,7 @@ describe("config io audit helpers", () => { it("also accepts flattened audit record params from legacy call sites", async () => { const home = await suiteRootTracker.make("append-flat"); - const record = finalizeConfigWriteAuditRecord({ - base: createConfigWriteAuditRecordBase({ - configPath: path.join(home, ".openclaw", "openclaw.json"), - env: {} as NodeJS.ProcessEnv, - existsBefore: true, - previousHash: "prev-hash", - nextHash: "next-hash", - previousBytes: 12, - nextBytes: 24, - previousMetadata: { - dev: "10", - ino: "11", - mode: 0o600, - nlink: 1, - uid: 501, - gid: 20, - }, - changedPathCount: 1, - hasMetaBefore: true, - hasMetaAfter: true, - gatewayModeBefore: "local", - gatewayModeAfter: "local", - suspicious: [], - now: "2026-04-07T08:00:00.000Z", - }), - result: "rename", - nextMetadata: { - dev: "12", - ino: "13", - mode: 0o600, - nlink: 1, - uid: 501, - gid: 20, - }, - }); + const record = createRenameAuditRecord(home); await appendConfigAuditRecord({ fs, @@ -247,10 +225,9 @@ describe("config io audit helpers", () => { ...record, }); - const auditPath = path.join(home, ".openclaw", "logs", "config-audit.jsonl"); - const lines = fs.readFileSync(auditPath, "utf-8").trim().split("\n"); - expect(lines).toHaveLength(1); - expect(JSON.parse(lines[0])).toMatchObject({ + const records = readAuditLog(home); + expect(records).toHaveLength(1); + expect(records[0]).toMatchObject({ event: "config.write", result: "rename", nextHash: "next-hash", diff --git a/src/config/io.observe-recovery.ts b/src/config/io.observe-recovery.ts index 4adbdd1aece..b9a352ac2ff 100644 --- a/src/config/io.observe-recovery.ts +++ b/src/config/io.observe-recovery.ts @@ -1,7 +1,11 @@ import crypto from "node:crypto"; import path from "node:path"; import { isRecord } from "../utils.js"; -import { appendConfigAuditRecord, appendConfigAuditRecordSync } from "./io.audit.js"; +import { + appendConfigAuditRecord, + appendConfigAuditRecordSync, + type ConfigObserveAuditRecord, +} from "./io.audit.js"; import { resolveStateDir } from "./paths.js"; export type ObserveRecoveryDeps = { @@ -99,6 +103,86 @@ type ConfigHealthState = { entries?: Record; }; +function createConfigObserveAuditRecord(params: { + ts: string; + configPath: string; + valid: boolean; + current: ConfigHealthFingerprint; + suspicious: string[]; + lastKnownGood: ConfigHealthFingerprint | undefined; + backup: ConfigHealthFingerprint | null | undefined; + clobberedPath: string | null; + restoredFromBackup: boolean; + restoredBackupPath: string | null; +}): ConfigObserveAuditRecord { + return { + ts: params.ts, + source: "config-io", + event: "config.observe", + phase: "read", + configPath: params.configPath, + pid: process.pid, + ppid: process.ppid, + cwd: process.cwd(), + argv: process.argv.slice(0, 8), + execArgv: process.execArgv.slice(0, 8), + exists: true, + valid: params.valid, + hash: params.current.hash, + bytes: params.current.bytes, + mtimeMs: params.current.mtimeMs, + ctimeMs: params.current.ctimeMs, + dev: params.current.dev, + ino: params.current.ino, + mode: params.current.mode, + nlink: params.current.nlink, + uid: params.current.uid, + gid: params.current.gid, + hasMeta: params.current.hasMeta, + gatewayMode: params.current.gatewayMode, + suspicious: params.suspicious, + lastKnownGoodHash: params.lastKnownGood?.hash ?? null, + lastKnownGoodBytes: params.lastKnownGood?.bytes ?? null, + lastKnownGoodMtimeMs: params.lastKnownGood?.mtimeMs ?? null, + lastKnownGoodCtimeMs: params.lastKnownGood?.ctimeMs ?? null, + lastKnownGoodDev: params.lastKnownGood?.dev ?? null, + lastKnownGoodIno: params.lastKnownGood?.ino ?? null, + lastKnownGoodMode: params.lastKnownGood?.mode ?? null, + lastKnownGoodNlink: params.lastKnownGood?.nlink ?? null, + lastKnownGoodUid: params.lastKnownGood?.uid ?? null, + lastKnownGoodGid: params.lastKnownGood?.gid ?? null, + lastKnownGoodGatewayMode: params.lastKnownGood?.gatewayMode ?? null, + backupHash: params.backup?.hash ?? null, + backupBytes: params.backup?.bytes ?? null, + backupMtimeMs: params.backup?.mtimeMs ?? null, + backupCtimeMs: params.backup?.ctimeMs ?? null, + backupDev: params.backup?.dev ?? null, + backupIno: params.backup?.ino ?? null, + backupMode: params.backup?.mode ?? null, + backupNlink: params.backup?.nlink ?? null, + backupUid: params.backup?.uid ?? null, + backupGid: params.backup?.gid ?? null, + backupGatewayMode: params.backup?.gatewayMode ?? null, + clobberedPath: params.clobberedPath, + restoredFromBackup: params.restoredFromBackup, + restoredBackupPath: params.restoredBackupPath, + }; +} + +type ConfigObserveAuditRecordParams = Parameters[0]; + +function createConfigObserveAuditAppendParams( + deps: ObserveRecoveryDeps, + params: ConfigObserveAuditRecordParams, +) { + return { + fs: deps.fs, + env: deps.env, + homedir: deps.homedir, + record: createConfigObserveAuditRecord(params), + }; +} + function hashConfigRaw(raw: string | null): string { return crypto .createHash("sha256") @@ -462,61 +546,20 @@ export async function maybeRecoverSuspiciousConfigRead(params: { params.deps.logger.warn( `Config auto-restored from backup: ${params.configPath} (${suspicious.join(", ")})`, ); - await appendConfigAuditRecord({ - fs: params.deps.fs, - env: params.deps.env, - homedir: params.deps.homedir, - ts: now, - source: "config-io", - event: "config.observe", - phase: "read", - configPath: params.configPath, - pid: process.pid, - ppid: process.ppid, - cwd: process.cwd(), - argv: process.argv.slice(0, 8), - execArgv: process.execArgv.slice(0, 8), - exists: true, - valid: true, - hash: current.hash, - bytes: current.bytes, - mtimeMs: current.mtimeMs, - ctimeMs: current.ctimeMs, - dev: current.dev, - ino: current.ino, - mode: current.mode, - nlink: current.nlink, - uid: current.uid, - gid: current.gid, - hasMeta: current.hasMeta, - gatewayMode: current.gatewayMode, - suspicious, - lastKnownGoodHash: entry.lastKnownGood?.hash ?? null, - lastKnownGoodBytes: entry.lastKnownGood?.bytes ?? null, - lastKnownGoodMtimeMs: entry.lastKnownGood?.mtimeMs ?? null, - lastKnownGoodCtimeMs: entry.lastKnownGood?.ctimeMs ?? null, - lastKnownGoodDev: entry.lastKnownGood?.dev ?? null, - lastKnownGoodIno: entry.lastKnownGood?.ino ?? null, - lastKnownGoodMode: entry.lastKnownGood?.mode ?? null, - lastKnownGoodNlink: entry.lastKnownGood?.nlink ?? null, - lastKnownGoodUid: entry.lastKnownGood?.uid ?? null, - lastKnownGoodGid: entry.lastKnownGood?.gid ?? null, - lastKnownGoodGatewayMode: entry.lastKnownGood?.gatewayMode ?? null, - backupHash: backup?.hash ?? null, - backupBytes: backup?.bytes ?? null, - backupMtimeMs: backup?.mtimeMs ?? null, - backupCtimeMs: backup?.ctimeMs ?? null, - backupDev: backup?.dev ?? null, - backupIno: backup?.ino ?? null, - backupMode: backup?.mode ?? null, - backupNlink: backup?.nlink ?? null, - backupUid: backup?.uid ?? null, - backupGid: backup?.gid ?? null, - backupGatewayMode: backup?.gatewayMode ?? null, - clobberedPath, - restoredFromBackup, - restoredBackupPath: backupPath, - }); + await appendConfigAuditRecord( + createConfigObserveAuditAppendParams(params.deps, { + ts: now, + configPath: params.configPath, + valid: true, + current, + suspicious, + lastKnownGood: entry.lastKnownGood, + backup, + clobberedPath, + restoredFromBackup, + restoredBackupPath: backupPath, + }), + ); healthState = setConfigHealthEntry(healthState, params.configPath, { ...entry, @@ -599,61 +642,20 @@ export function maybeRecoverSuspiciousConfigReadSync(params: { params.deps.logger.warn( `Config auto-restored from backup: ${params.configPath} (${suspicious.join(", ")})`, ); - appendConfigAuditRecordSync({ - fs: params.deps.fs, - env: params.deps.env, - homedir: params.deps.homedir, - ts: now, - source: "config-io", - event: "config.observe", - phase: "read", - configPath: params.configPath, - pid: process.pid, - ppid: process.ppid, - cwd: process.cwd(), - argv: process.argv.slice(0, 8), - execArgv: process.execArgv.slice(0, 8), - exists: true, - valid: true, - hash: current.hash, - bytes: current.bytes, - mtimeMs: current.mtimeMs, - ctimeMs: current.ctimeMs, - dev: current.dev, - ino: current.ino, - mode: current.mode, - nlink: current.nlink, - uid: current.uid, - gid: current.gid, - hasMeta: current.hasMeta, - gatewayMode: current.gatewayMode, - suspicious, - lastKnownGoodHash: entry.lastKnownGood?.hash ?? null, - lastKnownGoodBytes: entry.lastKnownGood?.bytes ?? null, - lastKnownGoodMtimeMs: entry.lastKnownGood?.mtimeMs ?? null, - lastKnownGoodCtimeMs: entry.lastKnownGood?.ctimeMs ?? null, - lastKnownGoodDev: entry.lastKnownGood?.dev ?? null, - lastKnownGoodIno: entry.lastKnownGood?.ino ?? null, - lastKnownGoodMode: entry.lastKnownGood?.mode ?? null, - lastKnownGoodNlink: entry.lastKnownGood?.nlink ?? null, - lastKnownGoodUid: entry.lastKnownGood?.uid ?? null, - lastKnownGoodGid: entry.lastKnownGood?.gid ?? null, - lastKnownGoodGatewayMode: entry.lastKnownGood?.gatewayMode ?? null, - backupHash: backup?.hash ?? null, - backupBytes: backup?.bytes ?? null, - backupMtimeMs: backup?.mtimeMs ?? null, - backupCtimeMs: backup?.ctimeMs ?? null, - backupDev: backup?.dev ?? null, - backupIno: backup?.ino ?? null, - backupMode: backup?.mode ?? null, - backupNlink: backup?.nlink ?? null, - backupUid: backup?.uid ?? null, - backupGid: backup?.gid ?? null, - backupGatewayMode: backup?.gatewayMode ?? null, - clobberedPath, - restoredFromBackup, - restoredBackupPath: backupPath, - }); + appendConfigAuditRecordSync( + createConfigObserveAuditAppendParams(params.deps, { + ts: now, + configPath: params.configPath, + valid: true, + current, + suspicious, + lastKnownGood: entry.lastKnownGood, + backup, + clobberedPath, + restoredFromBackup, + restoredBackupPath: backupPath, + }), + ); healthState = setConfigHealthEntry(healthState, params.configPath, { ...entry, @@ -742,61 +744,20 @@ export async function observeConfigSnapshot( }); deps.logger.warn(`Config observe anomaly: ${snapshot.path} (${suspicious.join(", ")})`); - await appendConfigAuditRecord({ - fs: deps.fs, - env: deps.env, - homedir: deps.homedir, - ts: now, - source: "config-io", - event: "config.observe", - phase: "read", - configPath: snapshot.path, - pid: process.pid, - ppid: process.ppid, - cwd: process.cwd(), - argv: process.argv.slice(0, 8), - execArgv: process.execArgv.slice(0, 8), - exists: true, - valid: snapshot.valid, - hash: current.hash, - bytes: current.bytes, - mtimeMs: current.mtimeMs, - ctimeMs: current.ctimeMs, - dev: current.dev, - ino: current.ino, - mode: current.mode, - nlink: current.nlink, - uid: current.uid, - gid: current.gid, - hasMeta: current.hasMeta, - gatewayMode: current.gatewayMode, - suspicious, - lastKnownGoodHash: entry.lastKnownGood?.hash ?? null, - lastKnownGoodBytes: entry.lastKnownGood?.bytes ?? null, - lastKnownGoodMtimeMs: entry.lastKnownGood?.mtimeMs ?? null, - lastKnownGoodCtimeMs: entry.lastKnownGood?.ctimeMs ?? null, - lastKnownGoodDev: entry.lastKnownGood?.dev ?? null, - lastKnownGoodIno: entry.lastKnownGood?.ino ?? null, - lastKnownGoodMode: entry.lastKnownGood?.mode ?? null, - lastKnownGoodNlink: entry.lastKnownGood?.nlink ?? null, - lastKnownGoodUid: entry.lastKnownGood?.uid ?? null, - lastKnownGoodGid: entry.lastKnownGood?.gid ?? null, - lastKnownGoodGatewayMode: entry.lastKnownGood?.gatewayMode ?? null, - backupHash: backup?.hash ?? null, - backupBytes: backup?.bytes ?? null, - backupMtimeMs: backup?.mtimeMs ?? null, - backupCtimeMs: backup?.ctimeMs ?? null, - backupDev: backup?.dev ?? null, - backupIno: backup?.ino ?? null, - backupMode: backup?.mode ?? null, - backupNlink: backup?.nlink ?? null, - backupUid: backup?.uid ?? null, - backupGid: backup?.gid ?? null, - backupGatewayMode: backup?.gatewayMode ?? null, - clobberedPath, - restoredFromBackup: false, - restoredBackupPath: null, - }); + await appendConfigAuditRecord( + createConfigObserveAuditAppendParams(deps, { + ts: now, + configPath: snapshot.path, + valid: snapshot.valid, + current, + suspicious, + lastKnownGood: entry.lastKnownGood, + backup, + clobberedPath, + restoredFromBackup: false, + restoredBackupPath: null, + }), + ); healthState = setConfigHealthEntry(healthState, snapshot.path, { ...entry, @@ -867,61 +828,20 @@ export function observeConfigSnapshotSync( }); deps.logger.warn(`Config observe anomaly: ${snapshot.path} (${suspicious.join(", ")})`); - appendConfigAuditRecordSync({ - fs: deps.fs, - env: deps.env, - homedir: deps.homedir, - ts: now, - source: "config-io", - event: "config.observe", - phase: "read", - configPath: snapshot.path, - pid: process.pid, - ppid: process.ppid, - cwd: process.cwd(), - argv: process.argv.slice(0, 8), - execArgv: process.execArgv.slice(0, 8), - exists: true, - valid: snapshot.valid, - hash: current.hash, - bytes: current.bytes, - mtimeMs: current.mtimeMs, - ctimeMs: current.ctimeMs, - dev: current.dev, - ino: current.ino, - mode: current.mode, - nlink: current.nlink, - uid: current.uid, - gid: current.gid, - hasMeta: current.hasMeta, - gatewayMode: current.gatewayMode, - suspicious, - lastKnownGoodHash: entry.lastKnownGood?.hash ?? null, - lastKnownGoodBytes: entry.lastKnownGood?.bytes ?? null, - lastKnownGoodMtimeMs: entry.lastKnownGood?.mtimeMs ?? null, - lastKnownGoodCtimeMs: entry.lastKnownGood?.ctimeMs ?? null, - lastKnownGoodDev: entry.lastKnownGood?.dev ?? null, - lastKnownGoodIno: entry.lastKnownGood?.ino ?? null, - lastKnownGoodMode: entry.lastKnownGood?.mode ?? null, - lastKnownGoodNlink: entry.lastKnownGood?.nlink ?? null, - lastKnownGoodUid: entry.lastKnownGood?.uid ?? null, - lastKnownGoodGid: entry.lastKnownGood?.gid ?? null, - lastKnownGoodGatewayMode: entry.lastKnownGood?.gatewayMode ?? null, - backupHash: backup?.hash ?? null, - backupBytes: backup?.bytes ?? null, - backupMtimeMs: backup?.mtimeMs ?? null, - backupCtimeMs: backup?.ctimeMs ?? null, - backupDev: backup?.dev ?? null, - backupIno: backup?.ino ?? null, - backupMode: backup?.mode ?? null, - backupNlink: backup?.nlink ?? null, - backupUid: backup?.uid ?? null, - backupGid: backup?.gid ?? null, - backupGatewayMode: backup?.gatewayMode ?? null, - clobberedPath, - restoredFromBackup: false, - restoredBackupPath: null, - }); + appendConfigAuditRecordSync( + createConfigObserveAuditAppendParams(deps, { + ts: now, + configPath: snapshot.path, + valid: snapshot.valid, + current, + suspicious, + lastKnownGood: entry.lastKnownGood, + backup, + clobberedPath, + restoredFromBackup: false, + restoredBackupPath: null, + }), + ); healthState = setConfigHealthEntry(healthState, snapshot.path, { ...entry, diff --git a/src/config/sessions/explicit-session-key-normalization.test.ts b/src/config/sessions/explicit-session-key-normalization.test.ts index 3cf32846337..055affb3fd5 100644 --- a/src/config/sessions/explicit-session-key-normalization.test.ts +++ b/src/config/sessions/explicit-session-key-normalization.test.ts @@ -1,69 +1,8 @@ -import { afterEach, beforeEach, describe, expect, it } from "vitest"; -import type { MsgContext } from "../../auto-reply/templating.js"; -import type { ChannelPlugin } from "../../channels/plugins/types.js"; -import { resetPluginRuntimeStateForTest, setActivePluginRegistry } from "../../plugins/runtime.js"; -import { - createChannelTestPluginBase, - createTestRegistry, -} from "../../test-utils/channel-plugins.js"; +import { describe, expect, it } from "vitest"; import { normalizeExplicitSessionKey } from "./explicit-session-key-normalization.js"; +import { installDiscordSessionKeyNormalizerFixture, makeCtx } from "./session-key.test-helpers.js"; -function makeCtx(overrides: Partial): MsgContext { - return { - Body: "", - From: "", - To: "", - ...overrides, - } as MsgContext; -} - -beforeEach(() => { - const discordPlugin: ChannelPlugin = { - ...createChannelTestPluginBase({ - id: "discord", - label: "Discord", - docsPath: "/channels/discord", - }), - messaging: { - normalizeExplicitSessionKey: ({ sessionKey, ctx }) => { - const normalizedChatType = ctx.ChatType?.trim().toLowerCase(); - let normalized = sessionKey.trim().toLowerCase(); - if (normalizedChatType !== "direct" && normalizedChatType !== "dm") { - return normalized; - } - normalized = normalized.replace(/^(discord:)dm:/, "$1direct:"); - normalized = normalized.replace(/^(agent:[^:]+:discord:)dm:/, "$1direct:"); - const match = normalized.match(/^((?:agent:[^:]+:)?)discord:channel:([^:]+)$/); - if (!match) { - return normalized; - } - const from = (ctx.From ?? "").trim().toLowerCase(); - const senderId = (ctx.SenderId ?? "").trim().toLowerCase(); - const fromDiscordId = - from.startsWith("discord:") && !from.includes(":channel:") && !from.includes(":group:") - ? from.slice("discord:".length) - : ""; - const directId = senderId || fromDiscordId; - return directId && directId === match[2] - ? `${match[1]}discord:direct:${match[2]}` - : normalized; - }, - }, - }; - setActivePluginRegistry( - createTestRegistry([ - { - pluginId: "discord", - plugin: discordPlugin, - source: "test", - }, - ]), - ); -}); - -afterEach(() => { - resetPluginRuntimeStateForTest(); -}); +installDiscordSessionKeyNormalizerFixture(); describe("normalizeExplicitSessionKey", () => { it("dispatches discord keys through the provider normalizer", () => { diff --git a/src/config/sessions/session-key.test-helpers.ts b/src/config/sessions/session-key.test-helpers.ts new file mode 100644 index 00000000000..cc586fb4e61 --- /dev/null +++ b/src/config/sessions/session-key.test-helpers.ts @@ -0,0 +1,67 @@ +import { afterEach, beforeEach } from "vitest"; +import type { MsgContext } from "../../auto-reply/templating.js"; +import type { ChannelPlugin } from "../../channels/plugins/types.js"; +import { resetPluginRuntimeStateForTest, setActivePluginRegistry } from "../../plugins/runtime.js"; +import { + createChannelTestPluginBase, + createTestRegistry, +} from "../../test-utils/channel-plugins.js"; + +export function makeCtx(overrides: Partial): MsgContext { + return { + Body: "", + From: "", + To: "", + ...overrides, + } as MsgContext; +} + +export function installDiscordSessionKeyNormalizerFixture(): void { + beforeEach(() => { + const discordPlugin: ChannelPlugin = { + ...createChannelTestPluginBase({ + id: "discord", + label: "Discord", + docsPath: "/channels/discord", + }), + messaging: { + normalizeExplicitSessionKey: ({ sessionKey, ctx }) => { + const normalizedChatType = ctx.ChatType?.trim().toLowerCase(); + let normalized = sessionKey.trim().toLowerCase(); + if (normalizedChatType !== "direct" && normalizedChatType !== "dm") { + return normalized; + } + normalized = normalized.replace(/^(discord:)dm:/, "$1direct:"); + normalized = normalized.replace(/^(agent:[^:]+:discord:)dm:/, "$1direct:"); + const match = normalized.match(/^((?:agent:[^:]+:)?)discord:channel:([^:]+)$/); + if (!match) { + return normalized; + } + const from = (ctx.From ?? "").trim().toLowerCase(); + const senderId = (ctx.SenderId ?? "").trim().toLowerCase(); + const fromDiscordId = + from.startsWith("discord:") && !from.includes(":channel:") && !from.includes(":group:") + ? from.slice("discord:".length) + : ""; + const directId = senderId || fromDiscordId; + return directId && directId === match[2] + ? `${match[1]}discord:direct:${match[2]}` + : normalized; + }, + }, + }; + setActivePluginRegistry( + createTestRegistry([ + { + pluginId: "discord", + plugin: discordPlugin, + source: "test", + }, + ]), + ); + }); + + afterEach(() => { + resetPluginRuntimeStateForTest(); + }); +} diff --git a/src/config/sessions/session-key.test.ts b/src/config/sessions/session-key.test.ts index 6f05552f6bc..941200523a9 100644 --- a/src/config/sessions/session-key.test.ts +++ b/src/config/sessions/session-key.test.ts @@ -1,69 +1,8 @@ -import { afterEach, beforeEach, describe, expect, it } from "vitest"; -import type { MsgContext } from "../../auto-reply/templating.js"; -import type { ChannelPlugin } from "../../channels/plugins/types.js"; -import { resetPluginRuntimeStateForTest, setActivePluginRegistry } from "../../plugins/runtime.js"; -import { - createChannelTestPluginBase, - createTestRegistry, -} from "../../test-utils/channel-plugins.js"; +import { describe, expect, it } from "vitest"; import { resolveSessionKey } from "./session-key.js"; +import { installDiscordSessionKeyNormalizerFixture, makeCtx } from "./session-key.test-helpers.js"; -function makeCtx(overrides: Partial): MsgContext { - return { - Body: "", - From: "", - To: "", - ...overrides, - } as MsgContext; -} - -beforeEach(() => { - const discordPlugin: ChannelPlugin = { - ...createChannelTestPluginBase({ - id: "discord", - label: "Discord", - docsPath: "/channels/discord", - }), - messaging: { - normalizeExplicitSessionKey: ({ sessionKey, ctx }) => { - const normalizedChatType = ctx.ChatType?.trim().toLowerCase(); - let normalized = sessionKey.trim().toLowerCase(); - if (normalizedChatType !== "direct" && normalizedChatType !== "dm") { - return normalized; - } - normalized = normalized.replace(/^(discord:)dm:/, "$1direct:"); - normalized = normalized.replace(/^(agent:[^:]+:discord:)dm:/, "$1direct:"); - const match = normalized.match(/^((?:agent:[^:]+:)?)discord:channel:([^:]+)$/); - if (!match) { - return normalized; - } - const from = (ctx.From ?? "").trim().toLowerCase(); - const senderId = (ctx.SenderId ?? "").trim().toLowerCase(); - const fromDiscordId = - from.startsWith("discord:") && !from.includes(":channel:") && !from.includes(":group:") - ? from.slice("discord:".length) - : ""; - const directId = senderId || fromDiscordId; - return directId && directId === match[2] - ? `${match[1]}discord:direct:${match[2]}` - : normalized; - }, - }, - }; - setActivePluginRegistry( - createTestRegistry([ - { - pluginId: "discord", - plugin: discordPlugin, - source: "test", - }, - ]), - ); -}); - -afterEach(() => { - resetPluginRuntimeStateForTest(); -}); +installDiscordSessionKeyNormalizerFixture(); describe("resolveSessionKey", () => { describe("Discord DM session key normalization", () => { diff --git a/src/cron/service/store.test.ts b/src/cron/service/store.test.ts index c9a7b815255..c34f20f3b35 100644 --- a/src/cron/service/store.test.ts +++ b/src/cron/service/store.test.ts @@ -10,49 +10,56 @@ const { logger, makeStorePath } = setupCronServiceSuite({ prefix: "cron-service-store-seam", }); +const STORE_TEST_NOW = Date.parse("2026-03-23T12:00:00.000Z"); + +async function writeSingleJobStore(storePath: string, job: Record) { + await fs.mkdir(path.dirname(storePath), { recursive: true }); + await fs.writeFile( + storePath, + JSON.stringify( + { + version: 1, + jobs: [job], + }, + null, + 2, + ), + "utf8", + ); +} + +function createStoreTestState(storePath: string) { + return createCronServiceState({ + storePath, + cronEnabled: true, + log: logger, + nowMs: () => STORE_TEST_NOW, + enqueueSystemEvent: vi.fn(), + requestHeartbeatNow: vi.fn(), + runIsolatedAgentJob: vi.fn(async () => ({ status: "ok" as const })), + }); +} + describe("cron service store seam coverage", () => { it("loads stored jobs, recomputes next runs, and does not rewrite the store on load", async () => { const { storePath } = await makeStorePath(); - const now = Date.parse("2026-03-23T12:00:00.000Z"); - await fs.mkdir(path.dirname(storePath), { recursive: true }); - await fs.writeFile( - storePath, - JSON.stringify( - { - version: 1, - jobs: [ - { - id: "modern-job", - name: "modern job", - enabled: true, - createdAtMs: now - 60_000, - updatedAtMs: now - 60_000, - schedule: { kind: "every", everyMs: 60_000 }, - sessionTarget: "isolated", - wakeMode: "now", - payload: { kind: "agentTurn", message: "ping" }, - delivery: { mode: "announce", channel: "telegram", to: "123" }, - state: {}, - }, - ], - }, - null, - 2, - ), - "utf8", - ); - - const state = createCronServiceState({ - storePath, - cronEnabled: true, - log: logger, - nowMs: () => now, - enqueueSystemEvent: vi.fn(), - requestHeartbeatNow: vi.fn(), - runIsolatedAgentJob: vi.fn(async () => ({ status: "ok" as const })), + await writeSingleJobStore(storePath, { + id: "modern-job", + name: "modern job", + enabled: true, + createdAtMs: STORE_TEST_NOW - 60_000, + updatedAtMs: STORE_TEST_NOW - 60_000, + schedule: { kind: "every", everyMs: 60_000 }, + sessionTarget: "isolated", + wakeMode: "now", + payload: { kind: "agentTurn", message: "ping" }, + delivery: { mode: "announce", channel: "telegram", to: "123" }, + state: {}, }); + const state = createStoreTestState(storePath); + await ensureLoaded(state); const job = state.store?.jobs[0]; @@ -67,7 +74,7 @@ describe("cron service store seam coverage", () => { channel: "telegram", to: "123", }); - expect(job?.state.nextRunAtMs).toBe(now); + expect(job?.state.nextRunAtMs).toBe(STORE_TEST_NOW); const persisted = JSON.parse(await fs.readFile(storePath, "utf8")) as { jobs: Array>; @@ -93,45 +100,22 @@ describe("cron service store seam coverage", () => { it("normalizes jobId-only jobs in memory so scheduler lookups resolve by stable id", async () => { const { storePath } = await makeStorePath(); - const now = Date.parse("2026-03-23T12:00:00.000Z"); - await fs.mkdir(path.dirname(storePath), { recursive: true }); - await fs.writeFile( - storePath, - JSON.stringify( - { - version: 1, - jobs: [ - { - jobId: "repro-stable-id", - name: "handed", - enabled: true, - createdAtMs: now - 60_000, - updatedAtMs: now - 60_000, - schedule: { kind: "every", everyMs: 60_000 }, - sessionTarget: "main", - wakeMode: "now", - payload: { kind: "systemEvent", text: "tick" }, - state: {}, - }, - ], - }, - null, - 2, - ), - "utf8", - ); - - const state = createCronServiceState({ - storePath, - cronEnabled: true, - log: logger, - nowMs: () => now, - enqueueSystemEvent: vi.fn(), - requestHeartbeatNow: vi.fn(), - runIsolatedAgentJob: vi.fn(async () => ({ status: "ok" as const })), + await writeSingleJobStore(storePath, { + jobId: "repro-stable-id", + name: "handed", + enabled: true, + createdAtMs: STORE_TEST_NOW - 60_000, + updatedAtMs: STORE_TEST_NOW - 60_000, + schedule: { kind: "every", everyMs: 60_000 }, + sessionTarget: "main", + wakeMode: "now", + payload: { kind: "systemEvent", text: "tick" }, + state: {}, }); + const state = createStoreTestState(storePath); + await ensureLoaded(state); expect(logger.warn).toHaveBeenCalledWith( diff --git a/src/gateway/server.talk-config.test.ts b/src/gateway/server.talk-config.test.ts index 742d0be2674..8d77762e161 100644 --- a/src/gateway/server.talk-config.test.ts +++ b/src/gateway/server.talk-config.test.ts @@ -6,11 +6,10 @@ import { publicKeyRawBase64UrlFromPem, signDevicePayload, } from "../infra/device-identity.js"; -import { createEmptyPluginRegistry } from "../plugins/registry-empty.js"; -import { getActivePluginRegistry, setActivePluginRegistry } from "../plugins/runtime.js"; import { withEnvAsync } from "../test-utils/env.js"; import { buildDeviceAuthPayload } from "./device-auth.js"; import { validateTalkConfigResult } from "./protocol/index.js"; +import { withSpeechProviders } from "./talk.test-helpers.js"; import { connectOk, createGatewaySuiteHarness, @@ -137,22 +136,6 @@ async function withTalkConfigConnection( } } -async function withSpeechProviders( - speechProviders: NonNullable["speechProviders"]>, - run: () => Promise, -): Promise { - const previousRegistry = getActivePluginRegistry() ?? createEmptyPluginRegistry(); - setActivePluginRegistry({ - ...createEmptyPluginRegistry(), - speechProviders, - }); - try { - return await run(); - } finally { - setActivePluginRegistry(previousRegistry); - } -} - function expectTalkConfig( talk: TalkConfig | undefined, expected: { diff --git a/src/gateway/server.talk-runtime.test.ts b/src/gateway/server.talk-runtime.test.ts index 14c988b5f54..5f62c6314ba 100644 --- a/src/gateway/server.talk-runtime.test.ts +++ b/src/gateway/server.talk-runtime.test.ts @@ -1,7 +1,9 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -import { createEmptyPluginRegistry } from "../plugins/registry-empty.js"; -import { getActivePluginRegistry, setActivePluginRegistry } from "../plugins/runtime.js"; -import { talkHandlers } from "./server-methods/talk.js"; +import { + invokeTalkSpeakDirect, + type TalkSpeakTestPayload, + withSpeechProviders, +} from "./talk.test-helpers.js"; const synthesizeSpeechMock = vi.hoisted(() => vi.fn(async () => ({ @@ -18,49 +20,43 @@ vi.mock("../tts/tts.js", () => ({ synthesizeSpeech: synthesizeSpeechMock, })); -type TalkSpeakPayload = { - audioBase64?: string; - provider?: string; - outputFormat?: string; -}; +type SpeechProvider = Parameters[0][number]["provider"]; const ALIAS_STUB_VOICE_ID = "VoiceAlias1234567890"; -async function invokeTalkSpeakDirect(params: Record) { - let response: - | { - ok: boolean; - payload?: unknown; - error?: { code?: string; message?: string; details?: unknown }; - } - | undefined; - await talkHandlers["talk.speak"]({ - req: { type: "req", id: "test", method: "talk.speak", params }, - params, - client: null, - isWebchatConnect: () => false, - respond: (ok, payload, error) => { - response = { ok, payload, error }; +async function writeAcmeTalkConfig() { + const { writeConfigFile } = await import("../config/config.js"); + await writeConfigFile({ + talk: { + provider: "acme", + providers: { + acme: { + voiceId: "plugin-voice", + }, + }, }, - context: {} as never, }); - return response; } -async function withSpeechProviders( - speechProviders: NonNullable["speechProviders"]>, - run: () => Promise, -): Promise { - const previousRegistry = getActivePluginRegistry() ?? createEmptyPluginRegistry(); - setActivePluginRegistry({ - ...createEmptyPluginRegistry(), - speechProviders, - }); - try { - return await run(); - } finally { - setActivePluginRegistry(previousRegistry); - } +async function withAcmeSpeechProvider( + synthesize: SpeechProvider["synthesize"], + run: () => Promise, +) { + await withSpeechProviders( + [ + { + pluginId: "acme-plugin", + source: "test", + provider: { + id: "acme", + label: "Acme Speech", + isConfigured: () => true, + synthesize, + }, + }, + ], + run, + ); } describe("gateway talk runtime", () => { @@ -138,43 +134,22 @@ describe("gateway talk runtime", () => { }); it("allows extension speech providers through talk.speak", async () => { - const { writeConfigFile } = await import("../config/config.js"); - await writeConfigFile({ - talk: { - provider: "acme", - providers: { - acme: { - voiceId: "plugin-voice", - }, - }, - }, - }); + await writeAcmeTalkConfig(); - await withSpeechProviders( - [ - { - pluginId: "acme-plugin", - source: "test", - provider: { - id: "acme", - label: "Acme Speech", - isConfigured: () => true, - synthesize: async () => ({ - audioBuffer: Buffer.from([7, 8, 9]), - outputFormat: "mp3", - fileExtension: ".mp3", - voiceCompatible: false, - }), - }, - }, - ], + await withAcmeSpeechProvider( + async () => ({ + audioBuffer: Buffer.from([7, 8, 9]), + outputFormat: "mp3", + fileExtension: ".mp3", + voiceCompatible: false, + }), async () => { const res = await invokeTalkSpeakDirect({ text: "Hello from talk mode.", }); expect(res?.ok, JSON.stringify(res?.error)).toBe(true); - expect((res?.payload as TalkSpeakPayload | undefined)?.provider).toBe("acme"); - expect((res?.payload as TalkSpeakPayload | undefined)?.audioBase64).toBe( + expect((res?.payload as TalkSpeakTestPayload | undefined)?.provider).toBe("acme"); + expect((res?.payload as TalkSpeakTestPayload | undefined)?.audioBase64).toBe( Buffer.from([7, 8, 9]).toString("base64"), ); }, @@ -241,9 +216,11 @@ describe("gateway talk runtime", () => { }); expect(res?.ok, JSON.stringify(res?.error)).toBe(true); - expect((res?.payload as TalkSpeakPayload | undefined)?.provider).toBe("elevenlabs"); - expect((res?.payload as TalkSpeakPayload | undefined)?.outputFormat).toBe("pcm_44100"); - expect((res?.payload as TalkSpeakPayload | undefined)?.audioBase64).toBe( + expect((res?.payload as TalkSpeakTestPayload | undefined)?.provider).toBe("elevenlabs"); + expect((res?.payload as TalkSpeakTestPayload | undefined)?.outputFormat).toBe( + "pcm_44100", + ); + expect((res?.payload as TalkSpeakTestPayload | undefined)?.audioBase64).toBe( Buffer.from([4, 5, 6]).toString("base64"), ); expect(synthesizeSpeechMock).toHaveBeenCalledWith( @@ -280,31 +257,10 @@ describe("gateway talk runtime", () => { }); it("returns synthesis_failed details when the provider rejects synthesis", async () => { - const { writeConfigFile } = await import("../config/config.js"); - await writeConfigFile({ - talk: { - provider: "acme", - providers: { - acme: { - voiceId: "plugin-voice", - }, - }, - }, - }); + await writeAcmeTalkConfig(); - await withSpeechProviders( - [ - { - pluginId: "acme-plugin", - source: "test", - provider: { - id: "acme", - label: "Acme Speech", - isConfigured: () => true, - synthesize: async () => ({}) as never, - }, - }, - ], + await withAcmeSpeechProvider( + async () => ({}) as never, async () => { synthesizeSpeechMock.mockResolvedValue({ success: false, @@ -321,31 +277,10 @@ describe("gateway talk runtime", () => { }); it("rejects empty audio results as invalid_audio_result", async () => { - const { writeConfigFile } = await import("../config/config.js"); - await writeConfigFile({ - talk: { - provider: "acme", - providers: { - acme: { - voiceId: "plugin-voice", - }, - }, - }, - }); + await writeAcmeTalkConfig(); - await withSpeechProviders( - [ - { - pluginId: "acme-plugin", - source: "test", - provider: { - id: "acme", - label: "Acme Speech", - isConfigured: () => true, - synthesize: async () => ({}) as never, - }, - }, - ], + await withAcmeSpeechProvider( + async () => ({}) as never, async () => { synthesizeSpeechMock.mockResolvedValue({ success: true, diff --git a/src/gateway/talk.test-helpers.ts b/src/gateway/talk.test-helpers.ts new file mode 100644 index 00000000000..3a42fcd4413 --- /dev/null +++ b/src/gateway/talk.test-helpers.ts @@ -0,0 +1,48 @@ +import { createEmptyPluginRegistry } from "../plugins/registry-empty.js"; +import { getActivePluginRegistry, setActivePluginRegistry } from "../plugins/runtime.js"; + +export type TalkSpeakTestPayload = { + audioBase64?: string; + provider?: string; + outputFormat?: string; + mimeType?: string; + fileExtension?: string; +}; + +export async function invokeTalkSpeakDirect(params: Record) { + const { talkHandlers } = await import("./server-methods/talk.js"); + let response: + | { + ok: boolean; + payload?: unknown; + error?: { code?: string; message?: string; details?: unknown }; + } + | undefined; + await talkHandlers["talk.speak"]({ + req: { type: "req", id: "test", method: "talk.speak", params }, + params, + client: null, + isWebchatConnect: () => false, + respond: (ok, payload, error) => { + response = { ok, payload, error }; + }, + context: {} as never, + }); + return response; +} + +export async function withSpeechProviders( + speechProviders: NonNullable["speechProviders"]>, + run: () => Promise, +): Promise { + const previousRegistry = getActivePluginRegistry() ?? createEmptyPluginRegistry(); + setActivePluginRegistry({ + ...createEmptyPluginRegistry(), + speechProviders, + }); + try { + return await run(); + } finally { + setActivePluginRegistry(previousRegistry); + } +} diff --git a/src/image-generation/live-test-helpers.ts b/src/image-generation/live-test-helpers.ts index 8d068309528..a457604129b 100644 --- a/src/image-generation/live-test-helpers.ts +++ b/src/image-generation/live-test-helpers.ts @@ -1,7 +1,15 @@ -import type { AuthProfileStore } from "../agents/auth-profiles.js"; import type { OpenClawConfig } from "../config/config.js"; +import { + parseLiveCsvFilter, + parseProviderModelMap, + redactLiveApiKey, + resolveConfiguredLiveProviderModels, + resolveLiveAuthStore, +} from "../media-generation/live-test-helpers.js"; import { normalizeOptionalLowercaseString } from "../shared/string-coerce.js"; +export { parseProviderModelMap, redactLiveApiKey }; + export const DEFAULT_LIVE_IMAGE_MODELS: Record = { fal: "fal/fal-ai/flux/dev", google: "google/gemini-3.1-flash-image-preview", @@ -22,87 +30,17 @@ export function parseCaseFilter(raw?: string): Set | null { return values.length > 0 ? new Set(values) : null; } -export function redactLiveApiKey(value: string | undefined): string { - const trimmed = value?.trim(); - if (!trimmed) { - return "none"; - } - if (trimmed.length <= 12) { - return trimmed; - } - return `${trimmed.slice(0, 8)}...${trimmed.slice(-4)}`; -} - export function parseCsvFilter(raw?: string): Set | null { - const trimmed = raw?.trim(); - if (!trimmed || trimmed === "all") { - return null; - } - const values = trimmed - .split(",") - .map((entry) => entry.trim()) - .filter(Boolean); - return values.length > 0 ? new Set(values) : null; -} - -export function parseProviderModelMap(raw?: string): Map { - const entries = new Map(); - for (const token of raw?.split(",") ?? []) { - const trimmed = token.trim(); - if (!trimmed) { - continue; - } - const slash = trimmed.indexOf("/"); - if (slash <= 0 || slash === trimmed.length - 1) { - continue; - } - const providerId = normalizeOptionalLowercaseString(trimmed.slice(0, slash)); - if (!providerId) { - continue; - } - entries.set(providerId, trimmed); - } - return entries; + return parseLiveCsvFilter(raw, { lowercase: false }); } export function resolveConfiguredLiveImageModels(cfg: OpenClawConfig): Map { - const resolved = new Map(); - const configured = cfg.agents?.defaults?.imageGenerationModel; - const add = (value: string | undefined) => { - const trimmed = value?.trim(); - if (!trimmed) { - return; - } - const slash = trimmed.indexOf("/"); - if (slash <= 0 || slash === trimmed.length - 1) { - return; - } - const providerId = normalizeOptionalLowercaseString(trimmed.slice(0, slash)); - if (!providerId) { - return; - } - resolved.set(providerId, trimmed); - }; - if (typeof configured === "string") { - add(configured); - return resolved; - } - add(configured?.primary); - for (const fallback of configured?.fallbacks ?? []) { - add(fallback); - } - return resolved; + return resolveConfiguredLiveProviderModels(cfg.agents?.defaults?.imageGenerationModel); } export function resolveLiveImageAuthStore(params: { requireProfileKeys: boolean; hasLiveKeys: boolean; -}): AuthProfileStore | undefined { - if (params.requireProfileKeys || !params.hasLiveKeys) { - return undefined; - } - return { - version: 1, - profiles: {}, - }; +}) { + return resolveLiveAuthStore(params); } diff --git a/src/image-generation/model-ref.ts b/src/image-generation/model-ref.ts index 74750005c83..3e198b8938a 100644 --- a/src/image-generation/model-ref.ts +++ b/src/image-generation/model-ref.ts @@ -1,23 +1,7 @@ -import { normalizeOptionalString } from "../shared/string-coerce.js"; +import { parseGenerationModelRef } from "../media-generation/model-ref.js"; export function parseImageGenerationModelRef( raw: string | undefined, ): { provider: string; model: string } | null { - const trimmed = normalizeOptionalString(raw); - if (!trimmed) { - return null; - } - const slashIndex = trimmed.indexOf("/"); - if (slashIndex <= 0 || slashIndex === trimmed.length - 1) { - return null; - } - const provider = normalizeOptionalString(trimmed.slice(0, slashIndex)); - const model = normalizeOptionalString(trimmed.slice(slashIndex + 1)); - if (!provider || !model) { - return null; - } - return { - provider, - model, - }; + return parseGenerationModelRef(raw); } diff --git a/src/image-generation/provider-registry.allowlist.test.ts b/src/image-generation/provider-registry.allowlist.test.ts index bd6e5538a6b..e4a34b4d5ce 100644 --- a/src/image-generation/provider-registry.allowlist.test.ts +++ b/src/image-generation/provider-registry.allowlist.test.ts @@ -1,35 +1,15 @@ -import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import { beforeAll, describe, expect, it } from "vitest"; import type { OpenClawConfig } from "../config/config.js"; -import { createEmptyPluginRegistry } from "../plugins/registry.js"; - -const mocks = vi.hoisted(() => ({ - resolveRuntimePluginRegistry: vi.fn< - (params?: unknown) => ReturnType | undefined - >(() => undefined), - loadPluginManifestRegistry: vi.fn(() => ({ plugins: [], diagnostics: [] })), - withBundledPluginEnablementCompat: vi.fn(({ config }) => config), - withBundledPluginVitestCompat: vi.fn(({ config }) => config), -})); - -vi.mock("../plugins/loader.js", () => ({ - resolveRuntimePluginRegistry: mocks.resolveRuntimePluginRegistry, -})); - -vi.mock("../plugins/manifest-registry.js", () => ({ - loadPluginManifestRegistry: mocks.loadPluginManifestRegistry, -})); - -vi.mock("../plugins/bundled-compat.js", async (importOriginal) => { - const actual = await importOriginal(); - return { - ...actual, - withBundledPluginEnablementCompat: mocks.withBundledPluginEnablementCompat, - withBundledPluginVitestCompat: mocks.withBundledPluginVitestCompat, - }; -}); +import { + createEmptyProviderRegistryAllowlistFallbackRegistry, + getProviderRegistryAllowlistMocks, + installProviderRegistryAllowlistMockDefaults, +} from "../test-utils/provider-registry-allowlist.test-helpers.js"; let getImageGenerationProvider: typeof import("./provider-registry.js").getImageGenerationProvider; let listImageGenerationProviders: typeof import("./provider-registry.js").listImageGenerationProviders; +const mocks = getProviderRegistryAllowlistMocks(); +installProviderRegistryAllowlistMockDefaults(); describe("image-generation provider registry allowlist fallback", () => { beforeAll(async () => { @@ -37,17 +17,6 @@ describe("image-generation provider registry allowlist fallback", () => { await import("./provider-registry.js")); }); - beforeEach(() => { - mocks.resolveRuntimePluginRegistry.mockReset(); - mocks.resolveRuntimePluginRegistry.mockReturnValue(undefined); - mocks.loadPluginManifestRegistry.mockReset(); - mocks.loadPluginManifestRegistry.mockReturnValue({ plugins: [], diagnostics: [] }); - mocks.withBundledPluginEnablementCompat.mockReset(); - mocks.withBundledPluginEnablementCompat.mockImplementation(({ config }) => config); - mocks.withBundledPluginVitestCompat.mockReset(); - mocks.withBundledPluginVitestCompat.mockImplementation(({ config }) => config); - }); - it("adds bundled capability plugin ids to plugins.allow before fallback registry load", () => { const cfg = { plugins: { allow: ["custom-plugin"] } } as OpenClawConfig; const compatConfig = { @@ -69,7 +38,9 @@ describe("image-generation provider registry allowlist fallback", () => { }); mocks.withBundledPluginEnablementCompat.mockReturnValue(compatConfig); mocks.withBundledPluginVitestCompat.mockReturnValue(compatConfig); - mocks.resolveRuntimePluginRegistry.mockImplementation(() => createEmptyPluginRegistry()); + mocks.resolveRuntimePluginRegistry.mockImplementation(() => + createEmptyProviderRegistryAllowlistFallbackRegistry(), + ); expect(listImageGenerationProviders(cfg)).toEqual([]); expect(getImageGenerationProvider("openai", cfg)).toBeUndefined(); diff --git a/src/image-generation/runtime.test.ts b/src/image-generation/runtime.test.ts index fe465a3d822..a9b2d928a19 100644 --- a/src/image-generation/runtime.test.ts +++ b/src/image-generation/runtime.test.ts @@ -1,4 +1,5 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; +import { resetGenerationRuntimeMocks } from "../../test/helpers/media-generation/runtime-test-mocks.js"; import type { OpenClawConfig } from "../config/config.js"; import { generateImage, listRuntimeImageGenerationProviders } from "./runtime.js"; import type { ImageGenerationProvider } from "./types.js"; @@ -68,23 +69,12 @@ vi.mock("./provider-registry.js", () => ({ describe("image-generation runtime", () => { beforeEach(() => { - mocks.createSubsystemLogger.mockClear(); - mocks.describeFailoverError.mockReset(); - mocks.getImageGenerationProvider.mockReset(); - mocks.getProviderEnvVars.mockReset(); - mocks.getProviderEnvVars.mockReturnValue([]); - mocks.resolveProviderAuthEnvVarCandidates.mockReset(); - mocks.resolveProviderAuthEnvVarCandidates.mockReturnValue({}); - mocks.isFailoverError.mockReset(); - mocks.isFailoverError.mockReturnValue(false); - mocks.listImageGenerationProviders.mockReset(); - mocks.listImageGenerationProviders.mockReturnValue([]); - mocks.parseImageGenerationModelRef.mockClear(); - mocks.resolveAgentModelFallbackValues.mockReset(); - mocks.resolveAgentModelFallbackValues.mockReturnValue([]); - mocks.resolveAgentModelPrimaryValue.mockReset(); - mocks.resolveAgentModelPrimaryValue.mockReturnValue(undefined); - mocks.debug.mockReset(); + resetGenerationRuntimeMocks({ + ...mocks, + getProvider: mocks.getImageGenerationProvider, + listProviders: mocks.listImageGenerationProviders, + parseModelRef: mocks.parseImageGenerationModelRef, + }); }); it("generates images through the active image-generation provider", async () => { diff --git a/src/image-generation/runtime.ts b/src/image-generation/runtime.ts index 454a182538b..9838a921910 100644 --- a/src/image-generation/runtime.ts +++ b/src/image-generation/runtime.ts @@ -5,8 +5,8 @@ import type { OpenClawConfig } from "../config/config.js"; import { formatErrorMessage } from "../infra/errors.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; import { + buildMediaGenerationNormalizationMetadata, buildNoCapabilityModelConfiguredMessage, - deriveAspectRatioFromSize, resolveCapabilityModelCandidates, throwCapabilityGenerationFailure, } from "../media-generation/runtime-shared.js"; @@ -122,34 +122,10 @@ export async function generateImage( normalization: sanitized.normalization, metadata: { ...result.metadata, - ...(sanitized.normalization?.size?.requested !== undefined && - sanitized.normalization.size.applied !== undefined - ? { - requestedSize: sanitized.normalization.size.requested, - normalizedSize: sanitized.normalization.size.applied, - } - : {}), - ...(sanitized.normalization?.aspectRatio?.applied !== undefined - ? { - ...(sanitized.normalization.aspectRatio.requested !== undefined - ? { requestedAspectRatio: sanitized.normalization.aspectRatio.requested } - : {}), - normalizedAspectRatio: sanitized.normalization.aspectRatio.applied, - ...(sanitized.normalization.aspectRatio.derivedFrom === "size" && params.size - ? { - requestedSize: params.size, - aspectRatioDerivedFromSize: deriveAspectRatioFromSize(params.size), - } - : {}), - } - : {}), - ...(sanitized.normalization?.resolution?.requested !== undefined && - sanitized.normalization.resolution.applied !== undefined - ? { - requestedResolution: sanitized.normalization.resolution.requested, - normalizedResolution: sanitized.normalization.resolution.applied, - } - : {}), + ...buildMediaGenerationNormalizationMetadata({ + normalization: sanitized.normalization, + requestedSizeForDerivedAspectRatio: params.size, + }), }, ignoredOverrides: sanitized.ignoredOverrides, }; diff --git a/src/infra/approval-handler-bootstrap.test.ts b/src/infra/approval-handler-bootstrap.test.ts index 8e4652ba99b..ad65f5e6b88 100644 --- a/src/infra/approval-handler-bootstrap.test.ts +++ b/src/infra/approval-handler-bootstrap.test.ts @@ -28,6 +28,52 @@ describe("startChannelApprovalHandlerBootstrap", () => { await Promise.resolve(); }; + const createApprovalPlugin = () => + ({ + id: "slack", + meta: { label: "Slack" }, + approvalCapability: { + nativeRuntime: { + availability: { + isConfigured: vi.fn().mockReturnValue(true), + shouldHandle: vi.fn().mockReturnValue(true), + }, + presentation: { + buildPendingPayload: vi.fn(), + buildResolvedResult: vi.fn(), + buildExpiredResult: vi.fn(), + }, + transport: { + prepareTarget: vi.fn(), + deliverPending: vi.fn(), + }, + }, + }, + }) as never; + + const startTestBootstrap = (params: { + channelRuntime: ReturnType; + logger?: unknown; + }) => + startChannelApprovalHandlerBootstrap({ + plugin: createApprovalPlugin(), + cfg: {} as never, + accountId: "default", + channelRuntime: params.channelRuntime, + logger: params.logger as never, + }); + + const registerApprovalContext = ( + channelRuntime: ReturnType, + app: unknown = { ok: true }, + ) => + channelRuntime.runtimeContexts.register({ + channelId: "slack", + accountId: "default", + capability: "approval.native", + context: { app }, + }); + it("starts and stops the shared approval handler from runtime context registration", async () => { const channelRuntime = createRuntimeChannel(); const start = vi.fn().mockResolvedValue(undefined); @@ -37,39 +83,9 @@ describe("startChannelApprovalHandlerBootstrap", () => { stop, }); - const cleanup = await startChannelApprovalHandlerBootstrap({ - plugin: { - id: "slack", - meta: { label: "Slack" }, - approvalCapability: { - nativeRuntime: { - availability: { - isConfigured: vi.fn().mockReturnValue(true), - shouldHandle: vi.fn().mockReturnValue(true), - }, - presentation: { - buildPendingPayload: vi.fn(), - buildResolvedResult: vi.fn(), - buildExpiredResult: vi.fn(), - }, - transport: { - prepareTarget: vi.fn(), - deliverPending: vi.fn(), - }, - }, - }, - } as never, - cfg: {} as never, - accountId: "default", - channelRuntime, - }); + const cleanup = await startTestBootstrap({ channelRuntime }); - const lease = channelRuntime.runtimeContexts.register({ - channelId: "slack", - accountId: "default", - capability: "approval.native", - context: { app: { ok: true } }, - }); + const lease = registerApprovalContext(channelRuntime); await flushTransitions(); expect(createChannelApprovalHandlerFromCapability).toHaveBeenCalled(); @@ -92,39 +108,9 @@ describe("startChannelApprovalHandlerBootstrap", () => { stop, }); - const lease = channelRuntime.runtimeContexts.register({ - channelId: "slack", - accountId: "default", - capability: "approval.native", - context: { app: { ok: true } }, - }); + const lease = registerApprovalContext(channelRuntime); - const cleanup = await startChannelApprovalHandlerBootstrap({ - plugin: { - id: "slack", - meta: { label: "Slack" }, - approvalCapability: { - nativeRuntime: { - availability: { - isConfigured: vi.fn().mockReturnValue(true), - shouldHandle: vi.fn().mockReturnValue(true), - }, - presentation: { - buildPendingPayload: vi.fn(), - buildResolvedResult: vi.fn(), - buildExpiredResult: vi.fn(), - }, - transport: { - prepareTarget: vi.fn(), - deliverPending: vi.fn(), - }, - }, - }, - } as never, - cfg: {} as never, - accountId: "default", - channelRuntime, - }); + const cleanup = await startTestBootstrap({ channelRuntime }); expect(createChannelApprovalHandlerFromCapability).toHaveBeenCalledTimes(1); expect(start).toHaveBeenCalledTimes(1); @@ -147,39 +133,9 @@ describe("startChannelApprovalHandlerBootstrap", () => { }); createChannelApprovalHandlerFromCapability.mockReturnValue(runtimePromise); - const cleanup = await startChannelApprovalHandlerBootstrap({ - plugin: { - id: "slack", - meta: { label: "Slack" }, - approvalCapability: { - nativeRuntime: { - availability: { - isConfigured: vi.fn().mockReturnValue(true), - shouldHandle: vi.fn().mockReturnValue(true), - }, - presentation: { - buildPendingPayload: vi.fn(), - buildResolvedResult: vi.fn(), - buildExpiredResult: vi.fn(), - }, - transport: { - prepareTarget: vi.fn(), - deliverPending: vi.fn(), - }, - }, - }, - } as never, - cfg: {} as never, - accountId: "default", - channelRuntime, - }); + const cleanup = await startTestBootstrap({ channelRuntime }); - const lease = channelRuntime.runtimeContexts.register({ - channelId: "slack", - accountId: "default", - capability: "approval.native", - context: { app: { ok: true } }, - }); + const lease = registerApprovalContext(channelRuntime); await flushTransitions(); const start = vi.fn().mockResolvedValue(undefined); @@ -211,47 +167,12 @@ describe("startChannelApprovalHandlerBootstrap", () => { stop: stopSecond, }); - const cleanup = await startChannelApprovalHandlerBootstrap({ - plugin: { - id: "slack", - meta: { label: "Slack" }, - approvalCapability: { - nativeRuntime: { - availability: { - isConfigured: vi.fn().mockReturnValue(true), - shouldHandle: vi.fn().mockReturnValue(true), - }, - presentation: { - buildPendingPayload: vi.fn(), - buildResolvedResult: vi.fn(), - buildExpiredResult: vi.fn(), - }, - transport: { - prepareTarget: vi.fn(), - deliverPending: vi.fn(), - }, - }, - }, - } as never, - cfg: {} as never, - accountId: "default", - channelRuntime, - }); + const cleanup = await startTestBootstrap({ channelRuntime }); - const firstLease = channelRuntime.runtimeContexts.register({ - channelId: "slack", - accountId: "default", - capability: "approval.native", - context: { app: { ok: "first" } }, - }); + const firstLease = registerApprovalContext(channelRuntime, { ok: "first" }); await flushTransitions(); - const secondLease = channelRuntime.runtimeContexts.register({ - channelId: "slack", - accountId: "default", - capability: "approval.native", - context: { app: { ok: "second" } }, - }); + const secondLease = registerApprovalContext(channelRuntime, { ok: "second" }); await flushTransitions(); expect(createChannelApprovalHandlerFromCapability).toHaveBeenCalledTimes(2); @@ -287,40 +208,9 @@ describe("startChannelApprovalHandlerBootstrap", () => { .mockResolvedValueOnce({ start, stop }) .mockResolvedValueOnce({ start, stop }); - const cleanup = await startChannelApprovalHandlerBootstrap({ - plugin: { - id: "slack", - meta: { label: "Slack" }, - approvalCapability: { - nativeRuntime: { - availability: { - isConfigured: vi.fn().mockReturnValue(true), - shouldHandle: vi.fn().mockReturnValue(true), - }, - presentation: { - buildPendingPayload: vi.fn(), - buildResolvedResult: vi.fn(), - buildExpiredResult: vi.fn(), - }, - transport: { - prepareTarget: vi.fn(), - deliverPending: vi.fn(), - }, - }, - }, - } as never, - cfg: {} as never, - accountId: "default", - channelRuntime, - logger: logger as never, - }); + const cleanup = await startTestBootstrap({ channelRuntime, logger }); - channelRuntime.runtimeContexts.register({ - channelId: "slack", - accountId: "default", - capability: "approval.native", - context: { app: { ok: true } }, - }); + registerApprovalContext(channelRuntime); await flushTransitions(); expect(start).toHaveBeenCalledTimes(1); @@ -349,48 +239,13 @@ describe("startChannelApprovalHandlerBootstrap", () => { .mockResolvedValueOnce({ start: secondStart, stop: secondStop }) .mockResolvedValueOnce({ start: secondStart, stop: secondStop }); - const cleanup = await startChannelApprovalHandlerBootstrap({ - plugin: { - id: "slack", - meta: { label: "Slack" }, - approvalCapability: { - nativeRuntime: { - availability: { - isConfigured: vi.fn().mockReturnValue(true), - shouldHandle: vi.fn().mockReturnValue(true), - }, - presentation: { - buildPendingPayload: vi.fn(), - buildResolvedResult: vi.fn(), - buildExpiredResult: vi.fn(), - }, - transport: { - prepareTarget: vi.fn(), - deliverPending: vi.fn(), - }, - }, - }, - } as never, - cfg: {} as never, - accountId: "default", - channelRuntime, - }); + const cleanup = await startTestBootstrap({ channelRuntime }); - channelRuntime.runtimeContexts.register({ - channelId: "slack", - accountId: "default", - capability: "approval.native", - context: { app: { ok: "first" } }, - }); + registerApprovalContext(channelRuntime, { ok: "first" }); await flushTransitions(); expect(firstStart).toHaveBeenCalledTimes(1); - channelRuntime.runtimeContexts.register({ - channelId: "slack", - accountId: "default", - capability: "approval.native", - context: { app: { ok: "second" } }, - }); + registerApprovalContext(channelRuntime, { ok: "second" }); await flushTransitions(); expect(secondStart).toHaveBeenCalledTimes(1); diff --git a/src/infra/outbound/message-action-runner.context.test.ts b/src/infra/outbound/message-action-runner.context.test.ts index 25d3c702c42..2cf9ab00813 100644 --- a/src/infra/outbound/message-action-runner.context.test.ts +++ b/src/infra/outbound/message-action-runner.context.test.ts @@ -1,172 +1,21 @@ import { afterEach, beforeEach, describe, expect, it } from "vitest"; -import type { - ChannelDirectoryEntryKind, - ChannelMessageActionName, - ChannelMessagingAdapter, - ChannelOutboundAdapter, - ChannelPlugin, -} from "../../channels/plugins/types.js"; +import type { ChannelPlugin } from "../../channels/plugins/types.js"; import type { OpenClawConfig } from "../../config/config.js"; import { setActivePluginRegistry } from "../../plugins/runtime.js"; import { createChannelTestPluginBase, createTestRegistry, } from "../../test-utils/channel-plugins.js"; -import { runMessageAction } from "./message-action-runner.js"; - -const slackConfig = { - channels: { - slack: { - botToken: "xoxb-test", - appToken: "xapp-test", - }, - }, -} as OpenClawConfig; - -const whatsappConfig = { - channels: { - whatsapp: { - allowFrom: ["*"], - }, - }, -} as OpenClawConfig; - -const runDryAction = (params: { - cfg: OpenClawConfig; - action: ChannelMessageActionName; - actionParams: Record; - toolContext?: Record; - abortSignal?: AbortSignal; - sandboxRoot?: string; -}) => - runMessageAction({ - cfg: params.cfg, - action: params.action, - params: params.actionParams as never, - toolContext: params.toolContext as never, - dryRun: true, - abortSignal: params.abortSignal, - sandboxRoot: params.sandboxRoot, - }); - -const runDrySend = (params: { - cfg: OpenClawConfig; - actionParams: Record; - toolContext?: Record; - abortSignal?: AbortSignal; - sandboxRoot?: string; -}) => - runDryAction({ - ...params, - action: "send", - }); - -type ResolvedTestTarget = { to: string; kind: ChannelDirectoryEntryKind }; - -const directOutbound: ChannelOutboundAdapter = { deliveryMode: "direct" }; - -function normalizeSlackTarget(raw: string): string { - const trimmed = raw.trim(); - if (!trimmed) { - return trimmed; - } - if (trimmed.startsWith("#")) { - return trimmed.slice(1).trim(); - } - if (/^channel:/i.test(trimmed)) { - return trimmed.replace(/^channel:/i, "").trim(); - } - if (/^user:/i.test(trimmed)) { - return trimmed.replace(/^user:/i, "").trim(); - } - const mention = trimmed.match(/^<@([A-Z0-9]+)>$/i); - if (mention?.[1]) { - return mention[1]; - } - return trimmed; -} - -function createConfiguredTestPlugin(params: { - id: "slack" | "telegram" | "whatsapp"; - isConfigured: (cfg: OpenClawConfig) => boolean; - normalizeTarget: (raw: string) => string | undefined; - resolveTarget: (input: string) => ResolvedTestTarget | null; -}): ChannelPlugin { - const messaging: ChannelMessagingAdapter = { - normalizeTarget: params.normalizeTarget, - targetResolver: { - looksLikeId: (raw) => Boolean(params.resolveTarget(raw.trim())), - hint: "", - resolveTarget: async (resolverParams) => { - const resolved = params.resolveTarget(resolverParams.input); - return resolved ? { ...resolved, source: "normalized" } : null; - }, - }, - inferTargetChatType: (inferParams) => - params.resolveTarget(inferParams.to)?.kind === "user" ? "direct" : "group", - }; - return { - ...createChannelTestPluginBase({ - id: params.id, - config: { - listAccountIds: () => ["default"], - resolveAccount: () => ({ enabled: true }), - isConfigured: (_account, cfg) => params.isConfigured(cfg), - }, - }), - outbound: directOutbound, - messaging, - }; -} - -const slackTestPlugin = createConfiguredTestPlugin({ - id: "slack", - isConfigured: (cfg) => Boolean(cfg.channels?.slack?.botToken?.trim()), - normalizeTarget: (raw) => normalizeSlackTarget(raw) || undefined, - resolveTarget: (input) => { - const normalized = normalizeSlackTarget(input); - if (!normalized) { - return null; - } - if (/^[A-Z0-9]+$/i.test(normalized)) { - const kind = /^U/i.test(normalized) ? "user" : "group"; - return { to: normalized, kind }; - } - return null; - }, -}); - -const telegramTestPlugin = createConfiguredTestPlugin({ - id: "telegram", - isConfigured: (cfg) => Boolean(cfg.channels?.telegram?.botToken?.trim()), - normalizeTarget: (raw) => raw.trim() || undefined, - resolveTarget: (input) => { - const normalized = input.trim(); - if (!normalized) { - return null; - } - return { - to: normalized.replace(/^telegram:/i, ""), - kind: normalized.startsWith("@") ? "user" : "group", - }; - }, -}); - -const whatsappTestPlugin = createConfiguredTestPlugin({ - id: "whatsapp", - isConfigured: (cfg) => Boolean(cfg.channels?.whatsapp), - normalizeTarget: (raw) => raw.trim() || undefined, - resolveTarget: (input) => { - const normalized = input.trim(); - if (!normalized) { - return null; - } - return { - to: normalized, - kind: normalized.endsWith("@g.us") ? "group" : "user", - }; - }, -}); +import { + directOutbound, + runDryAction, + runDrySend, + slackConfig, + slackTestPlugin, + telegramTestPlugin, + whatsappConfig, + whatsappTestPlugin, +} from "./message-action-runner.test-helpers.js"; const imessageTestPlugin: ChannelPlugin = { ...createChannelTestPluginBase({ diff --git a/src/infra/outbound/message-action-runner.media.test.ts b/src/infra/outbound/message-action-runner.media.test.ts index 8ecfd863846..48786637d79 100644 --- a/src/infra/outbound/message-action-runner.media.test.ts +++ b/src/infra/outbound/message-action-runner.media.test.ts @@ -40,91 +40,11 @@ vi.mock("./outbound-session.js", () => ({ resolveOutboundSessionRoute: vi.fn(async () => null), })); -vi.mock("./message-action-threading.js", () => ({ - resolveAndApplyOutboundThreadId: vi.fn( - ( - actionParams: Record, - context: { - cfg: OpenClawConfig; - to: string; - accountId?: string | null; - toolContext?: Record; - resolveAutoThreadId?: (params: { - cfg: OpenClawConfig; - accountId?: string | null; - to: string; - toolContext?: Record; - replyToId?: string; - }) => string | undefined; - }, - ) => { - const explicit = - typeof actionParams.threadId === "string" ? actionParams.threadId : undefined; - const replyToId = typeof actionParams.replyTo === "string" ? actionParams.replyTo : undefined; - const resolved = - explicit ?? - context.resolveAutoThreadId?.({ - cfg: context.cfg, - accountId: context.accountId, - to: context.to, - toolContext: context.toolContext, - replyToId, - }); - if (resolved && !actionParams.threadId) { - actionParams.threadId = resolved; - } - return resolved ?? undefined; - }, - ), - prepareOutboundMirrorRoute: vi.fn( - async ({ - actionParams, - cfg, - to, - accountId, - toolContext, - agentId, - resolveAutoThreadId, - }: { - actionParams: Record; - cfg: OpenClawConfig; - to: string; - accountId?: string | null; - toolContext?: Record; - agentId?: string; - resolveAutoThreadId?: (params: { - cfg: OpenClawConfig; - accountId?: string | null; - to: string; - toolContext?: Record; - replyToId?: string; - }) => string | undefined; - }) => { - const explicit = - typeof actionParams.threadId === "string" ? actionParams.threadId : undefined; - const replyToId = typeof actionParams.replyTo === "string" ? actionParams.replyTo : undefined; - const resolvedThreadId = - explicit ?? - resolveAutoThreadId?.({ - cfg, - accountId, - to, - toolContext, - replyToId, - }); - if (resolvedThreadId && !actionParams.threadId) { - actionParams.threadId = resolvedThreadId; - } - if (agentId) { - actionParams.__agentId = agentId; - } - return { - resolvedThreadId, - outboundRoute: null, - }; - }, - ), -})); +vi.mock("./message-action-threading.js", async () => { + const { createOutboundThreadingMock } = + await import("./message-action-threading.test-helpers.js"); + return createOutboundThreadingMock(); +}); vi.mock("../../media/web-media.js", async () => { const actual = await vi.importActual( diff --git a/src/infra/outbound/message-action-runner.plugin-dispatch.test.ts b/src/infra/outbound/message-action-runner.plugin-dispatch.test.ts index 8fefa9de9db..fe725e3a4d0 100644 --- a/src/infra/outbound/message-action-runner.plugin-dispatch.test.ts +++ b/src/infra/outbound/message-action-runner.plugin-dispatch.test.ts @@ -47,91 +47,11 @@ vi.mock("../../channels/plugins/bootstrap-registry.js", () => ({ : undefined, })); -vi.mock("./message-action-threading.js", () => ({ - resolveAndApplyOutboundThreadId: vi.fn( - ( - actionParams: Record, - context: { - cfg: OpenClawConfig; - to: string; - accountId?: string | null; - toolContext?: Record; - resolveAutoThreadId?: (params: { - cfg: OpenClawConfig; - accountId?: string | null; - to: string; - toolContext?: Record; - replyToId?: string; - }) => string | undefined; - }, - ) => { - const explicit = - typeof actionParams.threadId === "string" ? actionParams.threadId : undefined; - const replyToId = typeof actionParams.replyTo === "string" ? actionParams.replyTo : undefined; - const resolved = - explicit ?? - context.resolveAutoThreadId?.({ - cfg: context.cfg, - accountId: context.accountId, - to: context.to, - toolContext: context.toolContext, - replyToId, - }); - if (resolved && !actionParams.threadId) { - actionParams.threadId = resolved; - } - return resolved ?? undefined; - }, - ), - prepareOutboundMirrorRoute: vi.fn( - async ({ - actionParams, - cfg, - to, - accountId, - toolContext, - agentId, - resolveAutoThreadId, - }: { - actionParams: Record; - cfg: OpenClawConfig; - to: string; - accountId?: string | null; - toolContext?: Record; - agentId?: string; - resolveAutoThreadId?: (params: { - cfg: OpenClawConfig; - accountId?: string | null; - to: string; - toolContext?: Record; - replyToId?: string; - }) => string | undefined; - }) => { - const explicit = - typeof actionParams.threadId === "string" ? actionParams.threadId : undefined; - const replyToId = typeof actionParams.replyTo === "string" ? actionParams.replyTo : undefined; - const resolvedThreadId = - explicit ?? - resolveAutoThreadId?.({ - cfg, - accountId, - to, - toolContext, - replyToId, - }); - if (resolvedThreadId && !actionParams.threadId) { - actionParams.threadId = resolvedThreadId; - } - if (agentId) { - actionParams.__agentId = agentId; - } - return { - resolvedThreadId, - outboundRoute: null, - }; - }, - ), -})); +vi.mock("./message-action-threading.js", async () => { + const { createOutboundThreadingMock } = + await import("./message-action-threading.test-helpers.js"); + return createOutboundThreadingMock(); +}); function createAlwaysConfiguredPluginConfig(account: Record = { enabled: true }) { return { diff --git a/src/infra/outbound/message-action-runner.poll.test.ts b/src/infra/outbound/message-action-runner.poll.test.ts index e726c0b98fa..2de12b44034 100644 --- a/src/infra/outbound/message-action-runner.poll.test.ts +++ b/src/infra/outbound/message-action-runner.poll.test.ts @@ -27,91 +27,11 @@ vi.mock("./outbound-session.js", () => ({ resolveOutboundSessionRoute: vi.fn(async () => null), })); -vi.mock("./message-action-threading.js", () => ({ - resolveAndApplyOutboundThreadId: vi.fn( - ( - actionParams: Record, - context: { - cfg: OpenClawConfig; - to: string; - accountId?: string | null; - toolContext?: Record; - resolveAutoThreadId?: (params: { - cfg: OpenClawConfig; - accountId?: string | null; - to: string; - toolContext?: Record; - replyToId?: string; - }) => string | undefined; - }, - ) => { - const explicit = - typeof actionParams.threadId === "string" ? actionParams.threadId : undefined; - const replyToId = typeof actionParams.replyTo === "string" ? actionParams.replyTo : undefined; - const resolved = - explicit ?? - context.resolveAutoThreadId?.({ - cfg: context.cfg, - accountId: context.accountId, - to: context.to, - toolContext: context.toolContext, - replyToId, - }); - if (resolved && !actionParams.threadId) { - actionParams.threadId = resolved; - } - return resolved ?? undefined; - }, - ), - prepareOutboundMirrorRoute: vi.fn( - async ({ - actionParams, - cfg, - to, - accountId, - toolContext, - agentId, - resolveAutoThreadId, - }: { - actionParams: Record; - cfg: OpenClawConfig; - to: string; - accountId?: string | null; - toolContext?: Record; - agentId?: string; - resolveAutoThreadId?: (params: { - cfg: OpenClawConfig; - accountId?: string | null; - to: string; - toolContext?: Record; - replyToId?: string; - }) => string | undefined; - }) => { - const explicit = - typeof actionParams.threadId === "string" ? actionParams.threadId : undefined; - const replyToId = typeof actionParams.replyTo === "string" ? actionParams.replyTo : undefined; - const resolvedThreadId = - explicit ?? - resolveAutoThreadId?.({ - cfg, - accountId, - to, - toolContext, - replyToId, - }); - if (resolvedThreadId && !actionParams.threadId) { - actionParams.threadId = resolvedThreadId; - } - if (agentId) { - actionParams.__agentId = agentId; - } - return { - resolvedThreadId, - outboundRoute: null, - }; - }, - ), -})); +vi.mock("./message-action-threading.js", async () => { + const { createOutboundThreadingMock } = + await import("./message-action-threading.test-helpers.js"); + return createOutboundThreadingMock(); +}); const telegramConfig = { channels: { telegram: { diff --git a/src/infra/outbound/message-action-runner.send-validation.test.ts b/src/infra/outbound/message-action-runner.send-validation.test.ts index b7cde5eff8d..e16dc2579c3 100644 --- a/src/infra/outbound/message-action-runner.send-validation.test.ts +++ b/src/infra/outbound/message-action-runner.send-validation.test.ts @@ -1,130 +1,13 @@ import { afterEach, beforeEach, describe, expect, it } from "vitest"; -import type { - ChannelDirectoryEntryKind, - ChannelMessagingAdapter, - ChannelOutboundAdapter, - ChannelPlugin, -} from "../../channels/plugins/types.js"; import type { OpenClawConfig } from "../../config/config.js"; import { setActivePluginRegistry } from "../../plugins/runtime.js"; +import { createTestRegistry } from "../../test-utils/channel-plugins.js"; import { - createChannelTestPluginBase, - createTestRegistry, -} from "../../test-utils/channel-plugins.js"; -import { runMessageAction } from "./message-action-runner.js"; - -const slackConfig = { - channels: { - slack: { - botToken: "xoxb-test", - appToken: "xapp-test", - }, - }, -} as OpenClawConfig; - -const runDrySend = (params: { - cfg: OpenClawConfig; - actionParams: Record; - toolContext?: Record; -}) => - runMessageAction({ - cfg: params.cfg, - action: "send", - params: params.actionParams as never, - toolContext: params.toolContext as never, - dryRun: true, - }); - -type ResolvedTestTarget = { to: string; kind: ChannelDirectoryEntryKind }; - -const directOutbound: ChannelOutboundAdapter = { deliveryMode: "direct" }; - -function normalizeSlackTarget(raw: string): string { - const trimmed = raw.trim(); - if (!trimmed) { - return trimmed; - } - if (trimmed.startsWith("#")) { - return trimmed.slice(1).trim(); - } - if (/^channel:/i.test(trimmed)) { - return trimmed.replace(/^channel:/i, "").trim(); - } - if (/^user:/i.test(trimmed)) { - return trimmed.replace(/^user:/i, "").trim(); - } - const mention = trimmed.match(/^<@([A-Z0-9]+)>$/i); - if (mention?.[1]) { - return mention[1]; - } - return trimmed; -} - -function createConfiguredTestPlugin(params: { - id: "slack" | "telegram"; - isConfigured: (cfg: OpenClawConfig) => boolean; - normalizeTarget: (raw: string) => string | undefined; - resolveTarget: (input: string) => ResolvedTestTarget | null; -}): ChannelPlugin { - const messaging: ChannelMessagingAdapter = { - normalizeTarget: params.normalizeTarget, - targetResolver: { - looksLikeId: (raw) => Boolean(params.resolveTarget(raw.trim())), - hint: "", - resolveTarget: async (resolverParams) => { - const resolved = params.resolveTarget(resolverParams.input); - return resolved ? { ...resolved, source: "normalized" } : null; - }, - }, - inferTargetChatType: (inferParams) => - params.resolveTarget(inferParams.to)?.kind === "user" ? "direct" : "group", - }; - return { - ...createChannelTestPluginBase({ - id: params.id, - config: { - listAccountIds: () => ["default"], - resolveAccount: () => ({ enabled: true }), - isConfigured: (_account, cfg) => params.isConfigured(cfg), - }, - }), - outbound: directOutbound, - messaging, - }; -} - -const slackTestPlugin = createConfiguredTestPlugin({ - id: "slack", - isConfigured: (cfg) => Boolean(cfg.channels?.slack?.botToken?.trim()), - normalizeTarget: (raw) => normalizeSlackTarget(raw) || undefined, - resolveTarget: (input) => { - const normalized = normalizeSlackTarget(input); - if (!normalized) { - return null; - } - if (/^[A-Z0-9]+$/i.test(normalized)) { - const kind = /^U/i.test(normalized) ? "user" : "group"; - return { to: normalized, kind }; - } - return null; - }, -}); - -const telegramTestPlugin = createConfiguredTestPlugin({ - id: "telegram", - isConfigured: (cfg) => Boolean(cfg.channels?.telegram?.botToken?.trim()), - normalizeTarget: (raw) => raw.trim() || undefined, - resolveTarget: (input) => { - const normalized = input.trim(); - if (!normalized) { - return null; - } - return { - to: normalized.replace(/^telegram:/i, ""), - kind: normalized.startsWith("@") ? "user" : "group", - }; - }, -}); + runDrySend, + slackConfig, + slackTestPlugin, + telegramTestPlugin, +} from "./message-action-runner.test-helpers.js"; describe("runMessageAction send validation", () => { beforeEach(() => { diff --git a/src/infra/outbound/message-action-runner.test-helpers.ts b/src/infra/outbound/message-action-runner.test-helpers.ts new file mode 100644 index 00000000000..38a14d9d708 --- /dev/null +++ b/src/infra/outbound/message-action-runner.test-helpers.ts @@ -0,0 +1,164 @@ +import type { + ChannelDirectoryEntryKind, + ChannelMessageActionName, + ChannelMessagingAdapter, + ChannelOutboundAdapter, + ChannelPlugin, +} from "../../channels/plugins/types.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { createChannelTestPluginBase } from "../../test-utils/channel-plugins.js"; +import { runMessageAction } from "./message-action-runner.js"; + +export const slackConfig = { + channels: { + slack: { + botToken: "xoxb-test", + appToken: "xapp-test", + }, + }, +} as OpenClawConfig; + +export const whatsappConfig = { + channels: { + whatsapp: { + allowFrom: ["*"], + }, + }, +} as OpenClawConfig; + +export const directOutbound: ChannelOutboundAdapter = { deliveryMode: "direct" }; + +export const runDryAction = (params: { + cfg: OpenClawConfig; + action: ChannelMessageActionName; + actionParams: Record; + toolContext?: Record; + abortSignal?: AbortSignal; + sandboxRoot?: string; +}) => + runMessageAction({ + cfg: params.cfg, + action: params.action, + params: params.actionParams as never, + toolContext: params.toolContext as never, + dryRun: true, + abortSignal: params.abortSignal, + sandboxRoot: params.sandboxRoot, + }); + +export const runDrySend = (params: { + cfg: OpenClawConfig; + actionParams: Record; + toolContext?: Record; + abortSignal?: AbortSignal; + sandboxRoot?: string; +}) => + runDryAction({ + ...params, + action: "send", + }); + +type ResolvedTestTarget = { to: string; kind: ChannelDirectoryEntryKind }; + +export function normalizeSlackTarget(raw: string): string { + const trimmed = raw.trim(); + if (!trimmed) { + return trimmed; + } + if (trimmed.startsWith("#")) { + return trimmed.slice(1).trim(); + } + if (/^channel:/i.test(trimmed)) { + return trimmed.replace(/^channel:/i, "").trim(); + } + if (/^user:/i.test(trimmed)) { + return trimmed.replace(/^user:/i, "").trim(); + } + const mention = trimmed.match(/^<@([A-Z0-9]+)>$/i); + if (mention?.[1]) { + return mention[1]; + } + return trimmed; +} + +export function createConfiguredTestPlugin(params: { + id: "slack" | "telegram" | "whatsapp"; + isConfigured: (cfg: OpenClawConfig) => boolean; + normalizeTarget: (raw: string) => string | undefined; + resolveTarget: (input: string) => ResolvedTestTarget | null; +}): ChannelPlugin { + const messaging: ChannelMessagingAdapter = { + normalizeTarget: params.normalizeTarget, + targetResolver: { + looksLikeId: (raw) => Boolean(params.resolveTarget(raw.trim())), + hint: "", + resolveTarget: async (resolverParams) => { + const resolved = params.resolveTarget(resolverParams.input); + return resolved ? { ...resolved, source: "normalized" } : null; + }, + }, + inferTargetChatType: (inferParams) => + params.resolveTarget(inferParams.to)?.kind === "user" ? "direct" : "group", + }; + return { + ...createChannelTestPluginBase({ + id: params.id, + config: { + listAccountIds: () => ["default"], + resolveAccount: () => ({ enabled: true }), + isConfigured: (_account, cfg) => params.isConfigured(cfg), + }, + }), + outbound: directOutbound, + messaging, + }; +} + +export const slackTestPlugin = createConfiguredTestPlugin({ + id: "slack", + isConfigured: (cfg) => Boolean(cfg.channels?.slack?.botToken?.trim()), + normalizeTarget: (raw) => normalizeSlackTarget(raw) || undefined, + resolveTarget: (input) => { + const normalized = normalizeSlackTarget(input); + if (!normalized) { + return null; + } + if (/^[A-Z0-9]+$/i.test(normalized)) { + const kind = /^U/i.test(normalized) ? "user" : "group"; + return { to: normalized, kind }; + } + return null; + }, +}); + +export const telegramTestPlugin = createConfiguredTestPlugin({ + id: "telegram", + isConfigured: (cfg) => Boolean(cfg.channels?.telegram?.botToken?.trim()), + normalizeTarget: (raw) => raw.trim() || undefined, + resolveTarget: (input) => { + const normalized = input.trim(); + if (!normalized) { + return null; + } + return { + to: normalized.replace(/^telegram:/i, ""), + kind: normalized.startsWith("@") ? "user" : "group", + }; + }, +}); + +export const whatsappTestPlugin = createConfiguredTestPlugin({ + id: "whatsapp", + isConfigured: (cfg) => Boolean(cfg.channels?.whatsapp), + normalizeTarget: (raw) => raw.trim() || undefined, + resolveTarget: (input) => { + const normalized = input.trim(); + if (!normalized) { + return null; + } + return { + to: normalized, + kind: normalized.endsWith("@g.us") ? "group" : "user", + }; + }, +}); diff --git a/src/infra/outbound/message-action-threading.test-helpers.ts b/src/infra/outbound/message-action-threading.test-helpers.ts new file mode 100644 index 00000000000..ad4ae802933 --- /dev/null +++ b/src/infra/outbound/message-action-threading.test-helpers.ts @@ -0,0 +1,79 @@ +import { vi } from "vitest"; +import type { OpenClawConfig } from "../../config/config.js"; + +type AutoThreadResolver = (params: { + cfg: OpenClawConfig; + accountId?: string | null; + to: string; + toolContext?: Record; + replyToId?: string; +}) => string | undefined; + +type OutboundThreadContext = { + cfg: OpenClawConfig; + to: string; + accountId?: string | null; + toolContext?: Record; + resolveAutoThreadId?: AutoThreadResolver; +}; + +function resolveOutboundThreadId( + actionParams: Record, + context: OutboundThreadContext, +): string | undefined { + const explicit = typeof actionParams.threadId === "string" ? actionParams.threadId : undefined; + const replyToId = typeof actionParams.replyTo === "string" ? actionParams.replyTo : undefined; + const resolved = + explicit ?? + context.resolveAutoThreadId?.({ + cfg: context.cfg, + accountId: context.accountId, + to: context.to, + toolContext: context.toolContext, + replyToId, + }); + if (resolved && !actionParams.threadId) { + actionParams.threadId = resolved; + } + return resolved ?? undefined; +} + +export function createOutboundThreadingMock() { + return { + resolveAndApplyOutboundThreadId: vi.fn(resolveOutboundThreadId), + prepareOutboundMirrorRoute: vi.fn( + async ({ + actionParams, + cfg, + to, + accountId, + toolContext, + agentId, + resolveAutoThreadId, + }: { + actionParams: Record; + cfg: OpenClawConfig; + to: string; + accountId?: string | null; + toolContext?: Record; + agentId?: string; + resolveAutoThreadId?: AutoThreadResolver; + }) => { + const resolvedThreadId = resolveOutboundThreadId(actionParams, { + cfg, + accountId, + to, + toolContext, + resolveAutoThreadId, + }); + if (agentId) { + actionParams.__agentId = agentId; + } + return { + resolvedThreadId, + outboundRoute: null, + }; + }, + ), + }; +} diff --git a/src/media-generation/live-test-helpers.ts b/src/media-generation/live-test-helpers.ts new file mode 100644 index 00000000000..9f063d4acce --- /dev/null +++ b/src/media-generation/live-test-helpers.ts @@ -0,0 +1,101 @@ +import type { AuthProfileStore } from "../agents/auth-profiles.js"; +import { normalizeOptionalLowercaseString } from "../shared/string-coerce.js"; + +type LiveProviderModelConfig = + | string + | { + primary?: string; + fallbacks?: readonly string[]; + } + | undefined; + +export function redactLiveApiKey(value: string | undefined): string { + const trimmed = value?.trim(); + if (!trimmed) { + return "none"; + } + if (trimmed.length <= 12) { + return trimmed; + } + return `${trimmed.slice(0, 8)}...${trimmed.slice(-4)}`; +} + +export function parseLiveCsvFilter( + raw?: string, + options: { lowercase?: boolean } = {}, +): Set | null { + const trimmed = raw?.trim(); + if (!trimmed || trimmed === "all") { + return null; + } + const values = trimmed + .split(",") + .map((entry) => + options.lowercase === false ? entry.trim() : normalizeOptionalLowercaseString(entry), + ) + .filter((entry): entry is string => Boolean(entry)); + return values.length > 0 ? new Set(values) : null; +} + +export function parseProviderModelMap(raw?: string): Map { + const entries = new Map(); + for (const token of raw?.split(",") ?? []) { + const trimmed = token.trim(); + if (!trimmed) { + continue; + } + const slash = trimmed.indexOf("/"); + if (slash <= 0 || slash === trimmed.length - 1) { + continue; + } + const providerId = normalizeOptionalLowercaseString(trimmed.slice(0, slash)); + if (!providerId) { + continue; + } + entries.set(providerId, trimmed); + } + return entries; +} + +export function resolveConfiguredLiveProviderModels( + configured: LiveProviderModelConfig, +): Map { + const resolved = new Map(); + const add = (value: string | undefined) => { + const trimmed = value?.trim(); + if (!trimmed) { + return; + } + const slash = trimmed.indexOf("/"); + if (slash <= 0 || slash === trimmed.length - 1) { + return; + } + const providerId = normalizeOptionalLowercaseString(trimmed.slice(0, slash)); + if (!providerId) { + return; + } + resolved.set(providerId, trimmed); + }; + if (typeof configured === "string") { + add(configured); + return resolved; + } + add(configured?.primary); + for (const fallback of configured?.fallbacks ?? []) { + add(fallback); + } + return resolved; +} + +export function resolveLiveAuthStore(params: { + requireProfileKeys: boolean; + hasLiveKeys: boolean; +}): AuthProfileStore | undefined { + if (params.requireProfileKeys || !params.hasLiveKeys) { + return undefined; + } + return { + version: 1, + profiles: {}, + }; +} diff --git a/src/media-generation/model-ref.ts b/src/media-generation/model-ref.ts new file mode 100644 index 00000000000..5d6382b2f33 --- /dev/null +++ b/src/media-generation/model-ref.ts @@ -0,0 +1,23 @@ +import { normalizeOptionalString } from "../shared/string-coerce.js"; + +export type ParsedGenerationModelRef = { + provider: string; + model: string; +}; + +export function parseGenerationModelRef(raw: string | undefined): ParsedGenerationModelRef | null { + const trimmed = normalizeOptionalString(raw); + if (!trimmed) { + return null; + } + const slashIndex = trimmed.indexOf("/"); + if (slashIndex <= 0 || slashIndex === trimmed.length - 1) { + return null; + } + const provider = normalizeOptionalString(trimmed.slice(0, slashIndex)); + const model = normalizeOptionalString(trimmed.slice(slashIndex + 1)); + if (!provider || !model) { + return null; + } + return { provider, model }; +} diff --git a/src/media-generation/runtime-shared.ts b/src/media-generation/runtime-shared.ts index 9a446306a4f..1dde1ddd31d 100644 --- a/src/media-generation/runtime-shared.ts +++ b/src/media-generation/runtime-shared.ts @@ -26,6 +26,13 @@ export type MediaNormalizationEntry = { supportedValues?: readonly TValue[]; }; +export type MediaGenerationNormalizationMetadataInput = { + size?: MediaNormalizationEntry; + aspectRatio?: MediaNormalizationEntry; + resolution?: MediaNormalizationEntry; + durationSeconds?: MediaNormalizationEntry; +}; + export function hasMediaNormalizationEntry( entry: MediaNormalizationEntry | undefined, ): entry is MediaNormalizationEntry { @@ -401,6 +408,55 @@ export function normalizeDurationToClosestMax( return Math.min(rounded, Math.max(1, Math.round(maxDurationSeconds))); } +export function buildMediaGenerationNormalizationMetadata(params: { + normalization?: MediaGenerationNormalizationMetadataInput; + requestedSizeForDerivedAspectRatio?: string; + includeSupportedDurationSeconds?: boolean; +}): Record { + const metadata: Record = {}; + const { normalization } = params; + if (normalization?.size?.requested !== undefined && normalization.size.applied !== undefined) { + metadata.requestedSize = normalization.size.requested; + metadata.normalizedSize = normalization.size.applied; + } + if (normalization?.aspectRatio?.applied !== undefined) { + if (normalization.aspectRatio.requested !== undefined) { + metadata.requestedAspectRatio = normalization.aspectRatio.requested; + } + metadata.normalizedAspectRatio = normalization.aspectRatio.applied; + if ( + normalization.aspectRatio.derivedFrom === "size" && + params.requestedSizeForDerivedAspectRatio + ) { + metadata.requestedSize = params.requestedSizeForDerivedAspectRatio; + metadata.aspectRatioDerivedFromSize = deriveAspectRatioFromSize( + params.requestedSizeForDerivedAspectRatio, + ); + } + } + if ( + normalization?.resolution?.requested !== undefined && + normalization.resolution.applied !== undefined + ) { + metadata.requestedResolution = normalization.resolution.requested; + metadata.normalizedResolution = normalization.resolution.applied; + } + if ( + normalization?.durationSeconds?.requested !== undefined && + normalization.durationSeconds.applied !== undefined + ) { + metadata.requestedDurationSeconds = normalization.durationSeconds.requested; + metadata.normalizedDurationSeconds = normalization.durationSeconds.applied; + if ( + params.includeSupportedDurationSeconds && + normalization.durationSeconds.supportedValues?.length + ) { + metadata.supportedDurationSeconds = normalization.durationSeconds.supportedValues; + } + } + return metadata; +} + export function throwCapabilityGenerationFailure(params: { capabilityLabel: string; attempts: FallbackAttempt[]; diff --git a/src/media-understanding/provider-registry.allowlist.test.ts b/src/media-understanding/provider-registry.allowlist.test.ts index bd5b3709050..92e264cc5dc 100644 --- a/src/media-understanding/provider-registry.allowlist.test.ts +++ b/src/media-understanding/provider-registry.allowlist.test.ts @@ -1,35 +1,15 @@ -import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import { beforeAll, describe, expect, it } from "vitest"; import type { OpenClawConfig } from "../config/config.js"; -import { createEmptyPluginRegistry } from "../plugins/registry.js"; - -const mocks = vi.hoisted(() => ({ - resolveRuntimePluginRegistry: vi.fn< - (params?: unknown) => ReturnType | undefined - >(() => undefined), - loadPluginManifestRegistry: vi.fn(() => ({ plugins: [], diagnostics: [] })), - withBundledPluginEnablementCompat: vi.fn(({ config }) => config), - withBundledPluginVitestCompat: vi.fn(({ config }) => config), -})); - -vi.mock("../plugins/loader.js", () => ({ - resolveRuntimePluginRegistry: mocks.resolveRuntimePluginRegistry, -})); - -vi.mock("../plugins/manifest-registry.js", () => ({ - loadPluginManifestRegistry: mocks.loadPluginManifestRegistry, -})); - -vi.mock("../plugins/bundled-compat.js", async (importOriginal) => { - const actual = await importOriginal(); - return { - ...actual, - withBundledPluginEnablementCompat: mocks.withBundledPluginEnablementCompat, - withBundledPluginVitestCompat: mocks.withBundledPluginVitestCompat, - }; -}); +import { + createEmptyProviderRegistryAllowlistFallbackRegistry, + getProviderRegistryAllowlistMocks, + installProviderRegistryAllowlistMockDefaults, +} from "../test-utils/provider-registry-allowlist.test-helpers.js"; let buildMediaUnderstandingRegistry: typeof import("./provider-registry.js").buildMediaUnderstandingRegistry; let getMediaUnderstandingProvider: typeof import("./provider-registry.js").getMediaUnderstandingProvider; +const mocks = getProviderRegistryAllowlistMocks(); +installProviderRegistryAllowlistMockDefaults(); describe("media-understanding provider registry allowlist fallback", () => { beforeAll(async () => { @@ -37,17 +17,6 @@ describe("media-understanding provider registry allowlist fallback", () => { await import("./provider-registry.js")); }); - beforeEach(() => { - mocks.resolveRuntimePluginRegistry.mockReset(); - mocks.resolveRuntimePluginRegistry.mockReturnValue(undefined); - mocks.loadPluginManifestRegistry.mockReset(); - mocks.loadPluginManifestRegistry.mockReturnValue({ plugins: [], diagnostics: [] }); - mocks.withBundledPluginEnablementCompat.mockReset(); - mocks.withBundledPluginEnablementCompat.mockImplementation(({ config }) => config); - mocks.withBundledPluginVitestCompat.mockReset(); - mocks.withBundledPluginVitestCompat.mockImplementation(({ config }) => config); - }); - it("adds bundled capability plugin ids to plugins.allow before fallback registry load", () => { const cfg = { plugins: { allow: ["custom-plugin"] } } as OpenClawConfig; const compatConfig = { @@ -69,7 +38,9 @@ describe("media-understanding provider registry allowlist fallback", () => { }); mocks.withBundledPluginEnablementCompat.mockReturnValue(compatConfig); mocks.withBundledPluginVitestCompat.mockReturnValue(compatConfig); - mocks.resolveRuntimePluginRegistry.mockImplementation(() => createEmptyPluginRegistry()); + mocks.resolveRuntimePluginRegistry.mockImplementation(() => + createEmptyProviderRegistryAllowlistFallbackRegistry(), + ); const registry = buildMediaUnderstandingRegistry(undefined, cfg); diff --git a/src/media-understanding/runner.ts b/src/media-understanding/runner.ts index 5b8266effe0..7e7edb1fbe0 100644 --- a/src/media-understanding/runner.ts +++ b/src/media-understanding/runner.ts @@ -461,58 +461,6 @@ async function resolveKeyEntry(params: { return { type: "provider" as const, provider: providerId, model: resolvedModel }; }; - if (capability === "image") { - const activeProvider = params.activeModel?.provider?.trim(); - if (activeProvider) { - const activeEntry = await checkProvider(activeProvider, params.activeModel?.model); - if (activeEntry) { - return activeEntry; - } - } - for (const providerId of resolveConfiguredKeyProviderOrder({ - cfg, - providerRegistry, - capability, - fallbackProviders: resolveAutoMediaKeyProviders({ - cfg, - capability, - providerRegistry, - }), - })) { - const entry = await checkProvider(providerId); - if (entry) { - return entry; - } - } - return null; - } - - if (capability === "video") { - const activeProvider = params.activeModel?.provider?.trim(); - if (activeProvider) { - const activeEntry = await checkProvider(activeProvider, params.activeModel?.model); - if (activeEntry) { - return activeEntry; - } - } - for (const providerId of resolveConfiguredKeyProviderOrder({ - cfg, - providerRegistry, - capability, - fallbackProviders: resolveAutoMediaKeyProviders({ - cfg, - capability, - providerRegistry, - }), - })) { - const entry = await checkProvider(providerId, undefined); - if (entry) { - return entry; - } - } - return null; - } - const activeProvider = params.activeModel?.provider?.trim(); if (activeProvider) { const activeEntry = await checkProvider(activeProvider, params.activeModel?.model); diff --git a/src/music-generation/live-test-helpers.ts b/src/music-generation/live-test-helpers.ts index aa04bfe9560..4d7bb67bd78 100644 --- a/src/music-generation/live-test-helpers.ts +++ b/src/music-generation/live-test-helpers.ts @@ -1,93 +1,30 @@ -import type { AuthProfileStore } from "../agents/auth-profiles.js"; import type { OpenClawConfig } from "../config/config.js"; -import { normalizeOptionalLowercaseString } from "../shared/string-coerce.js"; +import { + parseLiveCsvFilter, + parseProviderModelMap, + redactLiveApiKey, + resolveConfiguredLiveProviderModels, + resolveLiveAuthStore, +} from "../media-generation/live-test-helpers.js"; + +export { parseProviderModelMap, redactLiveApiKey }; export const DEFAULT_LIVE_MUSIC_MODELS: Record = { google: "google/lyria-3-clip-preview", minimax: "minimax/music-2.5+", }; -export function redactLiveApiKey(value: string | undefined): string { - const trimmed = value?.trim(); - if (!trimmed) { - return "none"; - } - if (trimmed.length <= 12) { - return trimmed; - } - return `${trimmed.slice(0, 8)}...${trimmed.slice(-4)}`; -} - export function parseCsvFilter(raw?: string): Set | null { - const trimmed = raw?.trim(); - if (!trimmed || trimmed === "all") { - return null; - } - const values = trimmed - .split(",") - .map((entry) => normalizeOptionalLowercaseString(entry)) - .filter((entry): entry is string => Boolean(entry)); - return values.length > 0 ? new Set(values) : null; -} - -export function parseProviderModelMap(raw?: string): Map { - const entries = new Map(); - for (const token of raw?.split(",") ?? []) { - const trimmed = token.trim(); - if (!trimmed) { - continue; - } - const slash = trimmed.indexOf("/"); - if (slash <= 0 || slash === trimmed.length - 1) { - continue; - } - const providerId = normalizeOptionalLowercaseString(trimmed.slice(0, slash)); - if (!providerId) { - continue; - } - entries.set(providerId, trimmed); - } - return entries; + return parseLiveCsvFilter(raw); } export function resolveConfiguredLiveMusicModels(cfg: OpenClawConfig): Map { - const resolved = new Map(); - const configured = cfg.agents?.defaults?.musicGenerationModel; - const add = (value: string | undefined) => { - const trimmed = value?.trim(); - if (!trimmed) { - return; - } - const slash = trimmed.indexOf("/"); - if (slash <= 0 || slash === trimmed.length - 1) { - return; - } - const providerId = normalizeOptionalLowercaseString(trimmed.slice(0, slash)); - if (!providerId) { - return; - } - resolved.set(providerId, trimmed); - }; - if (typeof configured === "string") { - add(configured); - return resolved; - } - add(configured?.primary); - for (const fallback of configured?.fallbacks ?? []) { - add(fallback); - } - return resolved; + return resolveConfiguredLiveProviderModels(cfg.agents?.defaults?.musicGenerationModel); } export function resolveLiveMusicAuthStore(params: { requireProfileKeys: boolean; hasLiveKeys: boolean; -}): AuthProfileStore | undefined { - if (params.requireProfileKeys || !params.hasLiveKeys) { - return undefined; - } - return { - version: 1, - profiles: {}, - }; +}) { + return resolveLiveAuthStore(params); } diff --git a/src/music-generation/model-ref.ts b/src/music-generation/model-ref.ts index d58562570bc..d4075cab05b 100644 --- a/src/music-generation/model-ref.ts +++ b/src/music-generation/model-ref.ts @@ -1,16 +1,7 @@ +import { parseGenerationModelRef } from "../media-generation/model-ref.js"; + export function parseMusicGenerationModelRef( raw: string | undefined, ): { provider: string; model: string } | null { - const trimmed = raw?.trim(); - if (!trimmed) { - return null; - } - const slashIndex = trimmed.indexOf("/"); - if (slashIndex <= 0 || slashIndex === trimmed.length - 1) { - return null; - } - return { - provider: trimmed.slice(0, slashIndex).trim(), - model: trimmed.slice(slashIndex + 1).trim(), - }; + return parseGenerationModelRef(raw); } diff --git a/src/music-generation/runtime.test.ts b/src/music-generation/runtime.test.ts index 5412f571edf..3c7c65c0e45 100644 --- a/src/music-generation/runtime.test.ts +++ b/src/music-generation/runtime.test.ts @@ -1,4 +1,5 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; +import { resetGenerationRuntimeMocks } from "../../test/helpers/media-generation/runtime-test-mocks.js"; import type { OpenClawConfig } from "../config/config.js"; import { generateMusic, listRuntimeMusicGenerationProviders } from "./runtime.js"; import type { MusicGenerationProvider } from "./types.js"; @@ -68,23 +69,12 @@ vi.mock("./provider-registry.js", () => ({ describe("music-generation runtime", () => { beforeEach(() => { - mocks.createSubsystemLogger.mockClear(); - mocks.describeFailoverError.mockReset(); - mocks.getMusicGenerationProvider.mockReset(); - mocks.getProviderEnvVars.mockReset(); - mocks.getProviderEnvVars.mockReturnValue([]); - mocks.resolveProviderAuthEnvVarCandidates.mockReset(); - mocks.resolveProviderAuthEnvVarCandidates.mockReturnValue({}); - mocks.isFailoverError.mockReset(); - mocks.isFailoverError.mockReturnValue(false); - mocks.listMusicGenerationProviders.mockReset(); - mocks.listMusicGenerationProviders.mockReturnValue([]); - mocks.parseMusicGenerationModelRef.mockClear(); - mocks.resolveAgentModelFallbackValues.mockReset(); - mocks.resolveAgentModelFallbackValues.mockReturnValue([]); - mocks.resolveAgentModelPrimaryValue.mockReset(); - mocks.resolveAgentModelPrimaryValue.mockReturnValue(undefined); - mocks.debug.mockReset(); + resetGenerationRuntimeMocks({ + ...mocks, + getProvider: mocks.getMusicGenerationProvider, + listProviders: mocks.listMusicGenerationProviders, + parseModelRef: mocks.parseMusicGenerationModelRef, + }); }); it("generates tracks through the active music-generation provider", async () => { diff --git a/src/music-generation/runtime.ts b/src/music-generation/runtime.ts index 17358247f1c..1b564d1b0e8 100644 --- a/src/music-generation/runtime.ts +++ b/src/music-generation/runtime.ts @@ -5,6 +5,7 @@ import type { OpenClawConfig } from "../config/config.js"; import { formatErrorMessage } from "../infra/errors.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; import { + buildMediaGenerationNormalizationMetadata, buildNoCapabilityModelConfiguredMessage, resolveCapabilityModelCandidates, throwCapabilityGenerationFailure, @@ -124,13 +125,9 @@ export async function generateMusic( normalization: sanitized.normalization, metadata: { ...result.metadata, - ...(sanitized.normalization?.durationSeconds?.requested !== undefined && - sanitized.normalization.durationSeconds.applied !== undefined - ? { - requestedDurationSeconds: sanitized.normalization.durationSeconds.requested, - normalizedDurationSeconds: sanitized.normalization.durationSeconds.applied, - } - : {}), + ...buildMediaGenerationNormalizationMetadata({ + normalization: sanitized.normalization, + }), }, ignoredOverrides: sanitized.ignoredOverrides, }; diff --git a/src/plugin-sdk/video-generation.ts b/src/plugin-sdk/video-generation.ts index ccf273eec08..f1731ace696 100644 --- a/src/plugin-sdk/video-generation.ts +++ b/src/plugin-sdk/video-generation.ts @@ -139,6 +139,9 @@ type _VideoGenerationSdkCompat = [ ]; export { + DASHSCOPE_WAN_VIDEO_CAPABILITIES, + DASHSCOPE_WAN_VIDEO_MODELS, + DEFAULT_DASHSCOPE_WAN_VIDEO_MODEL, DEFAULT_VIDEO_GENERATION_DURATION_SECONDS, DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, DEFAULT_VIDEO_RESOLUTION_TO_SIZE, @@ -148,6 +151,7 @@ export { extractDashscopeVideoUrls, pollDashscopeVideoTaskUntilComplete, resolveVideoGenerationReferenceUrls, + runDashscopeVideoGenerationTask, } from "../video-generation/dashscope-compatible.js"; export type { DashscopeVideoGenerationResponse } from "../video-generation/dashscope-compatible.js"; diff --git a/src/plugins/clawhub.ts b/src/plugins/clawhub.ts index 557b3f180e3..ea0f86c1e72 100644 --- a/src/plugins/clawhub.ts +++ b/src/plugins/clawhub.ts @@ -114,6 +114,11 @@ type JSZipObjectWithSize = JSZip.JSZipObject & { const CLAWHUB_GENERATED_ARCHIVE_METADATA_FILE = "_meta.json"; +type ClawHubArchiveEntryLimits = { + maxEntryBytes: number; + addArchiveBytes: (bytes: number) => boolean; +}; + export function formatClawHubSpecifier(params: { name: string; version?: string }): string { return `clawhub:${params.name}${params.version ? `@${params.version}` : ""}`; } @@ -309,10 +314,14 @@ function resolveClawHubArchiveVerification( }; } -async function readClawHubArchiveEntryBuffer( +async function readLimitedClawHubArchiveEntry( entry: JSZip.JSZipObject, - limits: { maxEntryBytes: number; addArchiveBytes: (bytes: number) => boolean }, -): Promise { + limits: ClawHubArchiveEntryLimits, + handlers: { + onChunk: (buffer: Buffer) => void; + onEnd: () => T; + }, +): Promise { const hintedSize = (entry as JSZipObjectWithSize)._data?.uncompressedSize; if ( typeof hintedSize === "number" && @@ -325,8 +334,7 @@ async function readClawHubArchiveEntryBuffer( ); } let entryBytes = 0; - const chunks: Buffer[] = []; - return await new Promise((resolve) => { + return await new Promise((resolve) => { let settled = false; const stream = entry.nodeStream("nodebuffer") as NodeJS.ReadableStream & { destroy?: (error?: Error) => void; @@ -360,14 +368,14 @@ async function readClawHubArchiveEntryBuffer( ); return; } - chunks.push(buffer); + handlers.onChunk(buffer); }); stream.once("end", () => { if (settled) { return; } settled = true; - resolve(Buffer.concat(chunks)); + resolve(handlers.onEnd()); }); stream.once("error", (error: unknown) => { if (settled) { @@ -384,78 +392,33 @@ async function readClawHubArchiveEntryBuffer( }); } +async function readClawHubArchiveEntryBuffer( + entry: JSZip.JSZipObject, + limits: ClawHubArchiveEntryLimits, +): Promise { + const chunks: Buffer[] = []; + return await readLimitedClawHubArchiveEntry(entry, limits, { + onChunk(buffer) { + chunks.push(buffer); + }, + onEnd() { + return Buffer.concat(chunks); + }, + }); +} + async function hashClawHubArchiveEntry( entry: JSZip.JSZipObject, - limits: { maxEntryBytes: number; addArchiveBytes: (bytes: number) => boolean }, + limits: ClawHubArchiveEntryLimits, ): Promise { - const hintedSize = (entry as JSZipObjectWithSize)._data?.uncompressedSize; - if ( - typeof hintedSize === "number" && - Number.isFinite(hintedSize) && - hintedSize > limits.maxEntryBytes - ) { - return buildClawHubInstallFailure( - `ClawHub archive fallback verification rejected "${entry.name}" because it exceeds the per-file size limit.`, - CLAWHUB_INSTALL_ERROR_CODE.ARCHIVE_INTEGRITY_MISMATCH, - ); - } - let entryBytes = 0; const digest = createHash("sha256"); - return await new Promise((resolve) => { - let settled = false; - const stream = entry.nodeStream("nodebuffer") as NodeJS.ReadableStream & { - destroy?: (error?: Error) => void; - }; - stream.on("data", (chunk: Buffer | Uint8Array | string) => { - if (settled) { - return; - } - const buffer = - typeof chunk === "string" ? Buffer.from(chunk) : Buffer.from(chunk as Uint8Array); - entryBytes += buffer.byteLength; - if (entryBytes > limits.maxEntryBytes) { - settled = true; - stream.destroy?.(); - resolve( - buildClawHubInstallFailure( - `ClawHub archive fallback verification rejected "${entry.name}" because it exceeds the per-file size limit.`, - CLAWHUB_INSTALL_ERROR_CODE.ARCHIVE_INTEGRITY_MISMATCH, - ), - ); - return; - } - if (!limits.addArchiveBytes(buffer.byteLength)) { - settled = true; - stream.destroy?.(); - resolve( - buildClawHubInstallFailure( - "ClawHub archive fallback verification exceeded the total extracted-size limit.", - CLAWHUB_INSTALL_ERROR_CODE.ARCHIVE_INTEGRITY_MISMATCH, - ), - ); - return; - } + return await readLimitedClawHubArchiveEntry(entry, limits, { + onChunk(buffer) { digest.update(buffer); - }); - stream.once("end", () => { - if (settled) { - return; - } - settled = true; - resolve(digest.digest("hex")); - }); - stream.once("error", (error: unknown) => { - if (settled) { - return; - } - settled = true; - resolve( - buildClawHubInstallFailure( - error instanceof Error ? error.message : String(error), - CLAWHUB_INSTALL_ERROR_CODE.ARCHIVE_INTEGRITY_MISMATCH, - ), - ); - }); + }, + onEnd() { + return digest.digest("hex"); + }, }); } diff --git a/src/plugins/doctor-contract-registry.ts b/src/plugins/doctor-contract-registry.ts index dc986e58341..8aec2d28e57 100644 --- a/src/plugins/doctor-contract-registry.ts +++ b/src/plugins/doctor-contract-registry.ts @@ -1,18 +1,13 @@ import fs from "node:fs"; import path from "node:path"; import { fileURLToPath } from "node:url"; -import { createJiti } from "jiti"; import type { LegacyConfigRule } from "../config/legacy.shared.js"; import type { OpenClawConfig } from "../config/types.js"; import { asNullableRecord } from "../shared/record-coerce.js"; import { discoverOpenClawPlugins } from "./discovery.js"; +import { getCachedPluginJitiLoader, type PluginJitiLoaderCache } from "./jiti-loader-cache.js"; import { loadPluginManifestRegistry } from "./manifest-registry.js"; import { resolvePluginCacheInputs } from "./roots.js"; -import { - buildPluginLoaderAliasMap, - buildPluginLoaderJitiOptions, - shouldPreferNativeJiti, -} from "./sdk-alias.js"; const CONTRACT_API_EXTENSIONS = [".js", ".mjs", ".cjs", ".ts", ".mts", ".cts"] as const; const CURRENT_MODULE_PATH = fileURLToPath(import.meta.url); @@ -40,26 +35,15 @@ type PluginDoctorContractEntry = { normalizeCompatibilityConfig?: PluginDoctorCompatibilityNormalizer; }; -const jitiLoaders = new Map>(); +const jitiLoaders: PluginJitiLoaderCache = new Map(); const doctorContractCache = new Map(); function getJiti(modulePath: string) { - const aliasMap = buildPluginLoaderAliasMap(modulePath, process.argv[1], import.meta.url); - const tryNative = shouldPreferNativeJiti(modulePath); - const cacheKey = JSON.stringify({ - tryNative, - aliasMap: Object.entries(aliasMap).toSorted(([left], [right]) => left.localeCompare(right)), + return getCachedPluginJitiLoader({ + cache: jitiLoaders, + modulePath, + importerUrl: import.meta.url, }); - const cached = jitiLoaders.get(cacheKey); - if (cached) { - return cached; - } - const loader = createJiti(modulePath, { - ...buildPluginLoaderJitiOptions(aliasMap), - tryNative, - }); - jitiLoaders.set(cacheKey, loader); - return loader; } function buildDoctorContractCacheKey(params: { diff --git a/src/plugins/jiti-loader-cache.ts b/src/plugins/jiti-loader-cache.ts new file mode 100644 index 00000000000..226d18afad5 --- /dev/null +++ b/src/plugins/jiti-loader-cache.ts @@ -0,0 +1,36 @@ +import { createJiti } from "jiti"; +import { + buildPluginLoaderAliasMap, + buildPluginLoaderJitiOptions, + shouldPreferNativeJiti, +} from "./sdk-alias.js"; + +export type PluginJitiLoaderCache = Map>; + +export function getCachedPluginJitiLoader(params: { + cache: PluginJitiLoaderCache; + modulePath: string; + importerUrl: string; + argvEntry?: string; +}): ReturnType { + const aliasMap = buildPluginLoaderAliasMap( + params.modulePath, + params.argvEntry ?? process.argv[1], + params.importerUrl, + ); + const tryNative = shouldPreferNativeJiti(params.modulePath); + const cacheKey = JSON.stringify({ + tryNative, + aliasMap: Object.entries(aliasMap).toSorted(([left], [right]) => left.localeCompare(right)), + }); + const cached = params.cache.get(cacheKey); + if (cached) { + return cached; + } + const loader = createJiti(params.modulePath, { + ...buildPluginLoaderJitiOptions(aliasMap), + tryNative, + }); + params.cache.set(cacheKey, loader); + return loader; +} diff --git a/src/plugins/runtime.test.ts b/src/plugins/runtime.test.ts index 7721412ab6d..9b796357f69 100644 --- a/src/plugins/runtime.test.ts +++ b/src/plugins/runtime.test.ts @@ -13,6 +13,7 @@ import { resolveActivePluginHttpRouteRegistry, setActivePluginRegistry, } from "./runtime.js"; +import { createPluginRecord } from "./status.test-helpers.js"; function createRegistryWithRoute(path: string) { const registry = createEmptyPluginRegistry(); @@ -185,68 +186,23 @@ describe("setActivePluginRegistry", () => { it("does not treat bundle-only loaded entries as imported runtime plugins", () => { const registry = createEmptyPluginRegistry(); - registry.plugins.push({ - id: "bundle-only", - name: "Bundle Only", - source: "/tmp/bundle", - origin: "bundled", - enabled: true, - status: "loaded", - format: "bundle", - toolNames: [], - hookNames: [], - channelIds: [], - cliBackendIds: [], - providerIds: [], - speechProviderIds: [], - realtimeTranscriptionProviderIds: [], - realtimeVoiceProviderIds: [], - mediaUnderstandingProviderIds: [], - imageGenerationProviderIds: [], - videoGenerationProviderIds: [], - musicGenerationProviderIds: [], - webFetchProviderIds: [], - webSearchProviderIds: [], - memoryEmbeddingProviderIds: [], - gatewayMethods: [], - cliCommands: [], - services: [], - commands: [], - httpRoutes: 0, - hookCount: 0, - configSchema: true, - }); - registry.plugins.push({ - id: "runtime-plugin", - name: "Runtime Plugin", - source: "/tmp/runtime", - origin: "workspace", - enabled: true, - status: "loaded", - format: "openclaw", - toolNames: [], - hookNames: [], - channelIds: [], - cliBackendIds: [], - providerIds: [], - speechProviderIds: [], - realtimeTranscriptionProviderIds: [], - realtimeVoiceProviderIds: [], - mediaUnderstandingProviderIds: [], - imageGenerationProviderIds: [], - videoGenerationProviderIds: [], - musicGenerationProviderIds: [], - webFetchProviderIds: [], - webSearchProviderIds: [], - memoryEmbeddingProviderIds: [], - gatewayMethods: [], - cliCommands: [], - services: [], - commands: [], - httpRoutes: 0, - hookCount: 0, - configSchema: true, - }); + registry.plugins.push( + createPluginRecord({ + id: "bundle-only", + name: "Bundle Only", + source: "/tmp/bundle", + origin: "bundled", + format: "bundle", + configSchema: true, + }), + createPluginRecord({ + id: "runtime-plugin", + name: "Runtime Plugin", + source: "/tmp/runtime", + format: "openclaw", + configSchema: true, + }), + ); setActivePluginRegistry(registry); @@ -259,43 +215,3 @@ describe("setActivePluginRegistry", () => { expect(listImportedRuntimePluginIds()).toEqual(["broken-plugin"]); }); }); - -describe("setActivePluginRegistry", () => { - beforeEach(() => { - setActivePluginRegistry(createEmptyPluginRegistry()); - }); - - it("does not carry forward httpRoutes when new registry has none", () => { - const oldRegistry = createEmptyPluginRegistry(); - const fakeRoute = makeRoute("/test"); - oldRegistry.httpRoutes.push(fakeRoute); - setActivePluginRegistry(oldRegistry); - expect(getActivePluginRegistry()?.httpRoutes).toHaveLength(1); - - const newRegistry = createEmptyPluginRegistry(); - expect(newRegistry.httpRoutes).toHaveLength(0); - setActivePluginRegistry(newRegistry); - expect(getActivePluginRegistry()?.httpRoutes).toHaveLength(0); - }); - - it("does not carry forward when new registry already has routes", () => { - const oldRegistry = createEmptyPluginRegistry(); - oldRegistry.httpRoutes.push(makeRoute("/old")); - setActivePluginRegistry(oldRegistry); - - const newRegistry = createEmptyPluginRegistry(); - const newRoute = makeRoute("/new"); - newRegistry.httpRoutes.push(newRoute); - setActivePluginRegistry(newRegistry); - expect(getActivePluginRegistry()?.httpRoutes).toHaveLength(1); - expect(getActivePluginRegistry()?.httpRoutes[0]).toEqual(newRoute); - }); - - it("does not carry forward when same registry is set again", () => { - const registry = createEmptyPluginRegistry(); - registry.httpRoutes.push(makeRoute("/test")); - setActivePluginRegistry(registry); - setActivePluginRegistry(registry); - expect(getActivePluginRegistry()?.httpRoutes).toHaveLength(1); - }); -}); diff --git a/src/plugins/setup-registry.ts b/src/plugins/setup-registry.ts index e13b050b72e..1da6e92a4e5 100644 --- a/src/plugins/setup-registry.ts +++ b/src/plugins/setup-registry.ts @@ -1,20 +1,15 @@ import fs from "node:fs"; import path from "node:path"; import { fileURLToPath } from "node:url"; -import { createJiti } from "jiti"; import { normalizeProviderId } from "../agents/provider-id.js"; import type { OpenClawConfig } from "../config/config.js"; import { buildPluginApi } from "./api-builder.js"; import { collectPluginConfigContractMatches } from "./config-contracts.js"; import { discoverOpenClawPlugins } from "./discovery.js"; +import { getCachedPluginJitiLoader, type PluginJitiLoaderCache } from "./jiti-loader-cache.js"; import { loadPluginManifestRegistry } from "./manifest-registry.js"; import { resolvePluginCacheInputs } from "./roots.js"; import type { PluginRuntime } from "./runtime/types.js"; -import { - buildPluginLoaderAliasMap, - buildPluginLoaderJitiOptions, - shouldPreferNativeJiti, -} from "./sdk-alias.js"; import type { CliBackendPlugin, OpenClawPluginModule, @@ -69,7 +64,7 @@ const NOOP_LOGGER: PluginLogger = { error() {}, }; -const jitiLoaders = new Map>(); +const jitiLoaders: PluginJitiLoaderCache = new Map(); const setupRegistryCache = new Map(); const setupProviderCache = new Map(); @@ -80,22 +75,11 @@ export function clearPluginSetupRegistryCache(): void { } function getJiti(modulePath: string) { - const aliasMap = buildPluginLoaderAliasMap(modulePath, process.argv[1], import.meta.url); - const tryNative = shouldPreferNativeJiti(modulePath); - const cacheKey = JSON.stringify({ - tryNative, - aliasMap: Object.entries(aliasMap).toSorted(([left], [right]) => left.localeCompare(right)), + return getCachedPluginJitiLoader({ + cache: jitiLoaders, + modulePath, + importerUrl: import.meta.url, }); - const cached = jitiLoaders.get(cacheKey); - if (cached) { - return cached; - } - const loader = createJiti(modulePath, { - ...buildPluginLoaderJitiOptions(aliasMap), - tryNative, - }); - jitiLoaders.set(cacheKey, loader); - return loader; } function buildSetupRegistryCacheKey(params: { diff --git a/src/plugins/status.test.ts b/src/plugins/status.test.ts index a7d43e952cf..ddd2c9968bd 100644 --- a/src/plugins/status.test.ts +++ b/src/plugins/status.test.ts @@ -211,6 +211,50 @@ function createAutoEnabledStatusConfig( return { rawConfig, autoEnabledConfig }; } +function expectAutoEnabledDemoCompatibilityNoticesPreserveRawConfig() { + const { rawConfig, autoEnabledConfig } = createAutoEnabledStatusConfig( + { + demo: { enabled: true }, + }, + { channels: { demo: { enabled: true } } }, + ); + const autoEnabledReasons = { + demo: ["demo configured"], + }; + applyPluginAutoEnableMock.mockReturnValue({ + config: autoEnabledConfig, + changes: [], + autoEnabledReasons, + }); + setSinglePluginLoadResult( + createPluginRecord({ + id: "demo", + name: "Demo", + description: "Auto-enabled plugin", + origin: "bundled", + hookCount: 1, + }), + { + typedHooks: [createTypedHook({ pluginId: "demo", hookName: "before_agent_start" })], + }, + ); + + expect(buildPluginCompatibilityNotices({ config: rawConfig })).toEqual([ + createCompatibilityNotice({ pluginId: "demo", code: "legacy-before-agent-start" }), + createCompatibilityNotice({ pluginId: "demo", code: "hook-only" }), + ]); + + expectAutoEnabledStatusLoad({ + rawConfig, + }); + expectPluginLoaderCall({ + config: autoEnabledConfig, + activationSourceConfig: rawConfig, + autoEnabledReasons, + loadModules: true, + }); +} + function expectNoCompatibilityWarnings() { expect(buildPluginCompatibilityNotices()).toEqual([]); expect(buildPluginCompatibilityWarnings()).toEqual([]); @@ -405,48 +449,7 @@ describe("plugin status reports", () => { }); it("preserves raw config activation context when compatibility notices build their own report", () => { - const { rawConfig, autoEnabledConfig } = createAutoEnabledStatusConfig( - { - demo: { enabled: true }, - }, - { channels: { demo: { enabled: true } } }, - ); - applyPluginAutoEnableMock.mockReturnValue({ - config: autoEnabledConfig, - changes: [], - autoEnabledReasons: { - demo: ["demo configured"], - }, - }); - setSinglePluginLoadResult( - createPluginRecord({ - id: "demo", - name: "Demo", - description: "Auto-enabled plugin", - origin: "bundled", - hookCount: 1, - }), - { - typedHooks: [createTypedHook({ pluginId: "demo", hookName: "before_agent_start" })], - }, - ); - - expect(buildPluginCompatibilityNotices({ config: rawConfig })).toEqual([ - createCompatibilityNotice({ pluginId: "demo", code: "legacy-before-agent-start" }), - createCompatibilityNotice({ pluginId: "demo", code: "hook-only" }), - ]); - - expectAutoEnabledStatusLoad({ - rawConfig, - }); - expectPluginLoaderCall({ - config: autoEnabledConfig, - activationSourceConfig: rawConfig, - autoEnabledReasons: { - demo: ["demo configured"], - }, - loadModules: true, - }); + expectAutoEnabledDemoCompatibilityNoticesPreserveRawConfig(); }); it("applies the full bundled provider compat chain before loading plugins", () => { @@ -468,48 +471,7 @@ describe("plugin status reports", () => { }); it("preserves raw config activation context for compatibility-derived reports", () => { - const { rawConfig, autoEnabledConfig } = createAutoEnabledStatusConfig( - { - demo: { enabled: true }, - }, - { channels: { demo: { enabled: true } } }, - ); - applyPluginAutoEnableMock.mockReturnValue({ - config: autoEnabledConfig, - changes: [], - autoEnabledReasons: { - demo: ["demo configured"], - }, - }); - setSinglePluginLoadResult( - createPluginRecord({ - id: "demo", - name: "Demo", - description: "Auto-enabled plugin", - origin: "bundled", - hookCount: 1, - }), - { - typedHooks: [createTypedHook({ pluginId: "demo", hookName: "before_agent_start" })], - }, - ); - - expect(buildPluginCompatibilityNotices({ config: rawConfig })).toEqual([ - createCompatibilityNotice({ pluginId: "demo", code: "legacy-before-agent-start" }), - createCompatibilityNotice({ pluginId: "demo", code: "hook-only" }), - ]); - - expectAutoEnabledStatusLoad({ - rawConfig, - }); - expectPluginLoaderCall({ - config: autoEnabledConfig, - activationSourceConfig: rawConfig, - autoEnabledReasons: { - demo: ["demo configured"], - }, - loadModules: true, - }); + expectAutoEnabledDemoCompatibilityNoticesPreserveRawConfig(); }); it("normalizes bundled plugin versions to the core base release", () => { diff --git a/src/plugins/web-search-providers.runtime.test.ts b/src/plugins/web-search-providers.runtime.test.ts index 02cd1f79ea4..f7b35a11e41 100644 --- a/src/plugins/web-search-providers.runtime.test.ts +++ b/src/plugins/web-search-providers.runtime.test.ts @@ -277,6 +277,51 @@ function createRuntimeWebSearchProvider(params: { }; } +function createBraveRuntimeWebSearchProvider() { + return createRuntimeWebSearchProvider({ + pluginId: "brave", + pluginName: "Brave", + id: "brave", + label: "Brave Search", + hint: "Brave runtime provider", + envVar: "BRAVE_API_KEY", + signupUrl: "https://example.com/brave", + credentialPath: "plugins.entries.brave.config.webSearch.apiKey", + }); +} + +function createActiveBraveRegistryFixture(params?: { + includeResolutionWorkspaceDir?: boolean; + activeWorkspaceDir?: string; +}) { + const env = createWebSearchEnv(); + const rawConfig = createBraveAllowConfig(); + const { config, activationSourceConfig, autoEnabledReasons } = + webSearchProvidersSharedModule.resolveBundledWebSearchResolutionConfig({ + config: rawConfig, + bundledAllowlistCompat: true, + ...(params?.includeResolutionWorkspaceDir + ? { workspaceDir: DEFAULT_WEB_SEARCH_WORKSPACE } + : {}), + env, + }); + const { cacheKey } = loaderModule.__testing.resolvePluginLoadCacheContext({ + config, + activationSourceConfig, + autoEnabledReasons, + workspaceDir: DEFAULT_WEB_SEARCH_WORKSPACE, + env, + onlyPluginIds: ["brave"], + cache: false, + activate: false, + }); + const registry = createEmptyPluginRegistry(); + registry.webSearchProviders.push(createBraveRuntimeWebSearchProvider()); + setActivePluginRegistry(registry, cacheKey, "default", params?.activeWorkspaceDir); + + return { env, rawConfig }; +} + function expectRuntimeProviderResolution( providers: ReturnType, expected: readonly string[], @@ -436,38 +481,7 @@ describe("resolvePluginWebSearchProviders", () => { }); it("reuses a compatible active registry for snapshot resolution when config is provided", () => { - const env = createWebSearchEnv(); - const rawConfig = createBraveAllowConfig(); - const { config, activationSourceConfig, autoEnabledReasons } = - webSearchProvidersSharedModule.resolveBundledWebSearchResolutionConfig({ - config: rawConfig, - bundledAllowlistCompat: true, - env, - }); - const { cacheKey } = loaderModule.__testing.resolvePluginLoadCacheContext({ - config, - activationSourceConfig, - autoEnabledReasons, - workspaceDir: DEFAULT_WEB_SEARCH_WORKSPACE, - env, - onlyPluginIds: ["brave"], - cache: false, - activate: false, - }); - const registry = createEmptyPluginRegistry(); - registry.webSearchProviders.push( - createRuntimeWebSearchProvider({ - pluginId: "brave", - pluginName: "Brave", - id: "brave", - label: "Brave Search", - hint: "Brave runtime provider", - envVar: "BRAVE_API_KEY", - signupUrl: "https://example.com/brave", - credentialPath: "plugins.entries.brave.config.webSearch.apiKey", - }), - ); - setActivePluginRegistry(registry, cacheKey); + const { env, rawConfig } = createActiveBraveRegistryFixture(); const providers = resolvePluginWebSearchProviders({ config: rawConfig, @@ -481,39 +495,10 @@ describe("resolvePluginWebSearchProviders", () => { }); it("inherits workspaceDir from the active registry for compatible web-search snapshot reuse", () => { - const env = createWebSearchEnv(); - const rawConfig = createBraveAllowConfig(); - const { config, activationSourceConfig, autoEnabledReasons } = - webSearchProvidersSharedModule.resolveBundledWebSearchResolutionConfig({ - config: rawConfig, - bundledAllowlistCompat: true, - workspaceDir: DEFAULT_WEB_SEARCH_WORKSPACE, - env, - }); - const { cacheKey } = loaderModule.__testing.resolvePluginLoadCacheContext({ - config, - activationSourceConfig, - autoEnabledReasons, - workspaceDir: DEFAULT_WEB_SEARCH_WORKSPACE, - env, - onlyPluginIds: ["brave"], - cache: false, - activate: false, + const { env, rawConfig } = createActiveBraveRegistryFixture({ + includeResolutionWorkspaceDir: true, + activeWorkspaceDir: DEFAULT_WEB_SEARCH_WORKSPACE, }); - const registry = createEmptyPluginRegistry(); - registry.webSearchProviders.push( - createRuntimeWebSearchProvider({ - pluginId: "brave", - pluginName: "Brave", - id: "brave", - label: "Brave Search", - hint: "Brave runtime provider", - envVar: "BRAVE_API_KEY", - signupUrl: "https://example.com/brave", - credentialPath: "plugins.entries.brave.config.webSearch.apiKey", - }), - ); - setActivePluginRegistry(registry, cacheKey, "default", DEFAULT_WEB_SEARCH_WORKSPACE); const providers = resolvePluginWebSearchProviders({ config: rawConfig, @@ -676,38 +661,7 @@ describe("resolvePluginWebSearchProviders", () => { { name: "reuses a compatible active registry for runtime resolution when config is provided", setupRegistry: () => { - const env = createWebSearchEnv(); - const rawConfig = createBraveAllowConfig(); - const { config, activationSourceConfig, autoEnabledReasons } = - webSearchProvidersSharedModule.resolveBundledWebSearchResolutionConfig({ - config: rawConfig, - bundledAllowlistCompat: true, - env, - }); - const { cacheKey } = loaderModule.__testing.resolvePluginLoadCacheContext({ - config, - activationSourceConfig, - autoEnabledReasons, - workspaceDir: DEFAULT_WEB_SEARCH_WORKSPACE, - env, - onlyPluginIds: ["brave"], - cache: false, - activate: false, - }); - const registry = createEmptyPluginRegistry(); - registry.webSearchProviders.push( - createRuntimeWebSearchProvider({ - pluginId: "brave", - pluginName: "Brave", - id: "brave", - label: "Brave Search", - hint: "Brave runtime provider", - envVar: "BRAVE_API_KEY", - signupUrl: "https://example.com/brave", - credentialPath: "plugins.entries.brave.config.webSearch.apiKey", - }), - ); - setActivePluginRegistry(registry, cacheKey); + const { env, rawConfig } = createActiveBraveRegistryFixture(); return { config: rawConfig, bundledAllowlistCompat: true, diff --git a/src/tasks/import-boundary.test-helpers.ts b/src/tasks/import-boundary.test-helpers.ts new file mode 100644 index 00000000000..96981bf27ef --- /dev/null +++ b/src/tasks/import-boundary.test-helpers.ts @@ -0,0 +1,33 @@ +import fs from "node:fs/promises"; +import path from "node:path"; + +const TASK_ROOT = path.resolve(import.meta.dirname); + +export const TASK_BOUNDARY_SRC_ROOT = path.resolve(TASK_ROOT, ".."); + +export function toTaskBoundaryRelativePath(file: string, root = TASK_BOUNDARY_SRC_ROOT): string { + return path.relative(root, file).replaceAll(path.sep, "/"); +} + +export async function listTaskBoundarySourceFiles( + root = TASK_BOUNDARY_SRC_ROOT, +): Promise { + const entries = await fs.readdir(root, { withFileTypes: true }); + const files: string[] = []; + for (const entry of entries) { + const fullPath = path.join(root, entry.name); + if (entry.isDirectory()) { + files.push(...(await listTaskBoundarySourceFiles(fullPath))); + continue; + } + if (!entry.isFile() || !entry.name.endsWith(".ts") || entry.name.endsWith(".test.ts")) { + continue; + } + files.push(fullPath); + } + return files; +} + +export async function readTaskBoundarySource(file: string): Promise { + return fs.readFile(file, "utf8"); +} diff --git a/src/tasks/task-executor-boundary.test.ts b/src/tasks/task-executor-boundary.test.ts index b0c140f0487..7c099a8a0b1 100644 --- a/src/tasks/task-executor-boundary.test.ts +++ b/src/tasks/task-executor-boundary.test.ts @@ -1,9 +1,9 @@ -import fs from "node:fs/promises"; -import path from "node:path"; import { describe, expect, it } from "vitest"; - -const TASK_ROOT = path.resolve(import.meta.dirname); -const SRC_ROOT = path.resolve(TASK_ROOT, ".."); +import { + listTaskBoundarySourceFiles, + readTaskBoundarySource, + toTaskBoundaryRelativePath, +} from "./import-boundary.test-helpers.js"; const RAW_TASK_MUTATORS = [ "createTaskRecord", @@ -19,32 +19,15 @@ const ALLOWED_CALLERS = new Set([ "tasks/task-registry.maintenance.ts", ]); -async function listSourceFiles(root: string): Promise { - const entries = await fs.readdir(root, { withFileTypes: true }); - const files: string[] = []; - for (const entry of entries) { - const fullPath = path.join(root, entry.name); - if (entry.isDirectory()) { - files.push(...(await listSourceFiles(fullPath))); - continue; - } - if (!entry.isFile() || !entry.name.endsWith(".ts") || entry.name.endsWith(".test.ts")) { - continue; - } - files.push(fullPath); - } - return files; -} - describe("task executor boundary", () => { it("keeps raw task lifecycle mutators behind task internals", async () => { const offenders: string[] = []; - for (const file of await listSourceFiles(SRC_ROOT)) { - const relative = path.relative(SRC_ROOT, file).replaceAll(path.sep, "/"); + for (const file of await listTaskBoundarySourceFiles()) { + const relative = toTaskBoundaryRelativePath(file); if (ALLOWED_CALLERS.has(relative)) { continue; } - const source = await fs.readFile(file, "utf8"); + const source = await readTaskBoundarySource(file); for (const symbol of RAW_TASK_MUTATORS) { if (source.includes(`${symbol}(`)) { offenders.push(`${relative}:${symbol}`); diff --git a/src/tasks/task-flow-registry-import-boundary.test.ts b/src/tasks/task-flow-registry-import-boundary.test.ts index 127ae7e627e..d4ca7682649 100644 --- a/src/tasks/task-flow-registry-import-boundary.test.ts +++ b/src/tasks/task-flow-registry-import-boundary.test.ts @@ -1,9 +1,9 @@ -import fs from "node:fs/promises"; -import path from "node:path"; import { describe, expect, it } from "vitest"; - -const TASK_ROOT = path.resolve(import.meta.dirname); -const SRC_ROOT = path.resolve(TASK_ROOT, ".."); +import { + listTaskBoundarySourceFiles, + readTaskBoundarySource, + toTaskBoundaryRelativePath, +} from "./import-boundary.test-helpers.js"; const ALLOWED_IMPORTERS = new Set([ "tasks/task-flow-owner-access.ts", @@ -12,29 +12,12 @@ const ALLOWED_IMPORTERS = new Set([ "tasks/task-flow-runtime-internal.ts", ]); -async function listSourceFiles(root: string): Promise { - const entries = await fs.readdir(root, { withFileTypes: true }); - const files: string[] = []; - for (const entry of entries) { - const fullPath = path.join(root, entry.name); - if (entry.isDirectory()) { - files.push(...(await listSourceFiles(fullPath))); - continue; - } - if (!entry.isFile() || !entry.name.endsWith(".ts") || entry.name.endsWith(".test.ts")) { - continue; - } - files.push(fullPath); - } - return files; -} - describe("task flow registry import boundary", () => { it("keeps direct task-flow-registry imports behind approved task-flow access seams", async () => { const importers: string[] = []; - for (const file of await listSourceFiles(SRC_ROOT)) { - const relative = path.relative(SRC_ROOT, file).replaceAll(path.sep, "/"); - const source = await fs.readFile(file, "utf8"); + for (const file of await listTaskBoundarySourceFiles()) { + const relative = toTaskBoundaryRelativePath(file); + const source = await readTaskBoundarySource(file); if (source.includes("task-flow-registry.js")) { importers.push(relative); } diff --git a/src/tasks/task-registry-import-boundary.test.ts b/src/tasks/task-registry-import-boundary.test.ts index 4b4cc3ca3e4..868304d7fb0 100644 --- a/src/tasks/task-registry-import-boundary.test.ts +++ b/src/tasks/task-registry-import-boundary.test.ts @@ -1,9 +1,9 @@ -import fs from "node:fs/promises"; -import path from "node:path"; import { describe, expect, it } from "vitest"; - -const TASK_ROOT = path.resolve(import.meta.dirname); -const SRC_ROOT = path.resolve(TASK_ROOT, ".."); +import { + listTaskBoundarySourceFiles, + readTaskBoundarySource, + toTaskBoundaryRelativePath, +} from "./import-boundary.test-helpers.js"; const ALLOWED_IMPORTERS = new Set([ "tasks/runtime-internal.ts", @@ -11,29 +11,12 @@ const ALLOWED_IMPORTERS = new Set([ "tasks/task-status-access.ts", ]); -async function listSourceFiles(root: string): Promise { - const entries = await fs.readdir(root, { withFileTypes: true }); - const files: string[] = []; - for (const entry of entries) { - const fullPath = path.join(root, entry.name); - if (entry.isDirectory()) { - files.push(...(await listSourceFiles(fullPath))); - continue; - } - if (!entry.isFile() || !entry.name.endsWith(".ts") || entry.name.endsWith(".test.ts")) { - continue; - } - files.push(fullPath); - } - return files; -} - describe("task registry import boundary", () => { it("keeps direct task-registry imports behind the approved task access seams", async () => { const importers: string[] = []; - for (const file of await listSourceFiles(SRC_ROOT)) { - const relative = path.relative(SRC_ROOT, file).replaceAll(path.sep, "/"); - const source = await fs.readFile(file, "utf8"); + for (const file of await listTaskBoundarySourceFiles()) { + const relative = toTaskBoundaryRelativePath(file); + const source = await readTaskBoundarySource(file); if (source.includes("task-registry.js")) { importers.push(relative); } diff --git a/src/test-utils/provider-registry-allowlist.test-helpers.ts b/src/test-utils/provider-registry-allowlist.test-helpers.ts new file mode 100644 index 00000000000..23e021e07be --- /dev/null +++ b/src/test-utils/provider-registry-allowlist.test-helpers.ts @@ -0,0 +1,59 @@ +import { beforeEach, vi } from "vitest"; +import { createEmptyPluginRegistry } from "../plugins/registry.js"; + +const providerRegistryAllowlistMocks = vi.hoisted(() => ({ + resolveRuntimePluginRegistry: vi.fn< + (params?: unknown) => ReturnType | undefined + >(() => undefined), + loadPluginManifestRegistry: vi.fn(() => ({ plugins: [], diagnostics: [] })), + withBundledPluginEnablementCompat: vi.fn(({ config }) => config), + withBundledPluginVitestCompat: vi.fn(({ config }) => config), +})); + +vi.mock("../plugins/loader.js", () => ({ + resolveRuntimePluginRegistry: providerRegistryAllowlistMocks.resolveRuntimePluginRegistry, +})); + +vi.mock("../plugins/manifest-registry.js", () => ({ + loadPluginManifestRegistry: providerRegistryAllowlistMocks.loadPluginManifestRegistry, +})); + +vi.mock("../plugins/bundled-compat.js", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + withBundledPluginEnablementCompat: + providerRegistryAllowlistMocks.withBundledPluginEnablementCompat, + withBundledPluginVitestCompat: providerRegistryAllowlistMocks.withBundledPluginVitestCompat, + }; +}); + +export function getProviderRegistryAllowlistMocks(): typeof providerRegistryAllowlistMocks { + return providerRegistryAllowlistMocks; +} + +export function createEmptyProviderRegistryAllowlistFallbackRegistry(): ReturnType< + typeof createEmptyPluginRegistry +> { + return createEmptyPluginRegistry(); +} + +export function installProviderRegistryAllowlistMockDefaults(): void { + beforeEach(() => { + providerRegistryAllowlistMocks.resolveRuntimePluginRegistry.mockReset(); + providerRegistryAllowlistMocks.resolveRuntimePluginRegistry.mockReturnValue(undefined); + providerRegistryAllowlistMocks.loadPluginManifestRegistry.mockReset(); + providerRegistryAllowlistMocks.loadPluginManifestRegistry.mockReturnValue({ + plugins: [], + diagnostics: [], + }); + providerRegistryAllowlistMocks.withBundledPluginEnablementCompat.mockReset(); + providerRegistryAllowlistMocks.withBundledPluginEnablementCompat.mockImplementation( + ({ config }) => config, + ); + providerRegistryAllowlistMocks.withBundledPluginVitestCompat.mockReset(); + providerRegistryAllowlistMocks.withBundledPluginVitestCompat.mockImplementation( + ({ config }) => config, + ); + }); +} diff --git a/src/test-utils/web-provider-runtime.test-helpers.ts b/src/test-utils/web-provider-runtime.test-helpers.ts new file mode 100644 index 00000000000..ff61c39fbde --- /dev/null +++ b/src/test-utils/web-provider-runtime.test-helpers.ts @@ -0,0 +1,67 @@ +import type { OpenClawConfig } from "../config/config.js"; +import type { + PluginWebFetchProviderEntry, + PluginWebSearchProviderEntry, +} from "../plugins/types.js"; + +type CommonWebProviderTestParams = { + pluginId: string; + id: string; + credentialPath: string; + autoDetectOrder?: number; + requiresCredential?: boolean; + getCredentialValue?: (config?: Record) => unknown; + getConfiguredCredentialValue?: (config?: OpenClawConfig) => unknown; +}; + +export type WebSearchTestProviderParams = CommonWebProviderTestParams & { + createTool?: PluginWebSearchProviderEntry["createTool"]; +}; + +export type WebFetchTestProviderParams = CommonWebProviderTestParams & { + createTool?: PluginWebFetchProviderEntry["createTool"]; +}; + +function createCommonProviderFields(params: CommonWebProviderTestParams) { + return { + pluginId: params.pluginId, + id: params.id, + label: params.id, + hint: `${params.id} runtime provider`, + envVars: [`${params.id.toUpperCase()}_API_KEY`], + placeholder: `${params.id}-...`, + signupUrl: `https://example.com/${params.id}`, + credentialPath: params.credentialPath, + autoDetectOrder: params.autoDetectOrder, + requiresCredential: params.requiresCredential, + getCredentialValue: params.getCredentialValue ?? (() => undefined), + setCredentialValue: () => {}, + getConfiguredCredentialValue: params.getConfiguredCredentialValue, + }; +} + +function createDefaultProviderTool(providerId: string) { + return { + description: providerId, + parameters: {}, + execute: async (args: Record) => ({ ...args, provider: providerId }), + }; +} + +export function createWebSearchTestProvider( + params: WebSearchTestProviderParams, +): PluginWebSearchProviderEntry { + return { + ...createCommonProviderFields(params), + createTool: params.createTool ?? (() => createDefaultProviderTool(params.id)), + }; +} + +export function createWebFetchTestProvider( + params: WebFetchTestProviderParams, +): PluginWebFetchProviderEntry { + return { + ...createCommonProviderFields(params), + createTool: params.createTool ?? (() => createDefaultProviderTool(params.id)), + }; +} diff --git a/src/video-generation/dashscope-compatible.ts b/src/video-generation/dashscope-compatible.ts index eecf8112d6f..ef1fd24ea74 100644 --- a/src/video-generation/dashscope-compatible.ts +++ b/src/video-generation/dashscope-compatible.ts @@ -1,11 +1,59 @@ -import { assertOkOrThrowHttpError, fetchWithTimeout } from "openclaw/plugin-sdk/provider-http"; +import { + assertOkOrThrowHttpError, + fetchWithTimeout, + postJsonRequest, +} from "openclaw/plugin-sdk/provider-http"; import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js"; import type { GeneratedVideoAsset, + VideoGenerationProviderCapabilities, VideoGenerationRequest, + VideoGenerationResult, VideoGenerationSourceAsset, } from "./types.js"; +export const DEFAULT_DASHSCOPE_WAN_VIDEO_MODEL = "wan2.6-t2v"; +export const DASHSCOPE_WAN_VIDEO_MODELS = [ + DEFAULT_DASHSCOPE_WAN_VIDEO_MODEL, + "wan2.6-i2v", + "wan2.6-r2v", + "wan2.6-r2v-flash", + "wan2.7-r2v", +]; +export const DASHSCOPE_WAN_VIDEO_CAPABILITIES = { + generate: { + maxVideos: 1, + maxDurationSeconds: 10, + supportsSize: true, + supportsAspectRatio: true, + supportsResolution: true, + supportsAudio: true, + supportsWatermark: true, + }, + imageToVideo: { + enabled: true, + maxVideos: 1, + maxInputImages: 1, + maxDurationSeconds: 10, + supportsSize: true, + supportsAspectRatio: true, + supportsResolution: true, + supportsAudio: true, + supportsWatermark: true, + }, + videoToVideo: { + enabled: true, + maxVideos: 1, + maxInputVideos: 4, + maxDurationSeconds: 10, + supportsSize: true, + supportsAspectRatio: true, + supportsResolution: true, + supportsAudio: true, + supportsWatermark: true, + }, +} satisfies VideoGenerationProviderCapabilities; + export const DEFAULT_VIDEO_GENERATION_DURATION_SECONDS = 5; export const DEFAULT_VIDEO_GENERATION_TIMEOUT_MS = 120_000; export const DEFAULT_VIDEO_RESOLUTION_TO_SIZE: Record = { @@ -150,6 +198,85 @@ export async function pollDashscopeVideoTaskUntilComplete(params: { ); } +export async function runDashscopeVideoGenerationTask(params: { + providerLabel: string; + model: string; + req: VideoGenerationRequest; + url: string; + headers: Headers; + baseUrl: string; + timeoutMs?: number; + fetchFn: typeof fetch; + allowPrivateNetwork?: boolean; + dispatcherPolicy?: Parameters[0]["dispatcherPolicy"]; + defaultTimeoutMs?: number; +}): Promise { + const { response, release } = await postJsonRequest({ + url: params.url, + headers: params.headers, + body: { + model: params.model, + input: buildDashscopeVideoGenerationInput({ + providerLabel: params.providerLabel, + req: params.req, + }), + parameters: buildDashscopeVideoGenerationParameters( + { + ...params.req, + durationSeconds: params.req.durationSeconds ?? DEFAULT_VIDEO_GENERATION_DURATION_SECONDS, + }, + DEFAULT_VIDEO_RESOLUTION_TO_SIZE, + ), + }, + timeoutMs: params.timeoutMs, + fetchFn: params.fetchFn, + allowPrivateNetwork: params.allowPrivateNetwork, + dispatcherPolicy: params.dispatcherPolicy, + }); + + try { + await assertOkOrThrowHttpError(response, `${params.providerLabel} video generation failed`); + const submitted = (await response.json()) as DashscopeVideoGenerationResponse; + const taskId = submitted.output?.task_id?.trim(); + if (!taskId) { + throw new Error(`${params.providerLabel} video generation response missing task_id`); + } + const completed = await pollDashscopeVideoTaskUntilComplete({ + providerLabel: params.providerLabel, + taskId, + headers: params.headers, + timeoutMs: params.timeoutMs, + fetchFn: params.fetchFn, + baseUrl: params.baseUrl, + defaultTimeoutMs: params.defaultTimeoutMs ?? DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, + }); + const urls = extractDashscopeVideoUrls(completed); + if (urls.length === 0) { + throw new Error( + `${params.providerLabel} video generation completed without output video URLs`, + ); + } + const videos = await downloadDashscopeGeneratedVideos({ + providerLabel: params.providerLabel, + urls, + timeoutMs: params.timeoutMs, + fetchFn: params.fetchFn, + defaultTimeoutMs: params.defaultTimeoutMs ?? DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, + }); + return { + videos, + model: params.model, + metadata: { + requestId: submitted.request_id, + taskId, + taskStatus: completed.output?.task_status, + }, + }; + } finally { + await release(); + } +} + export async function downloadDashscopeGeneratedVideos(params: { providerLabel: string; urls: string[]; diff --git a/src/video-generation/live-test-helpers.ts b/src/video-generation/live-test-helpers.ts index fec9360e86f..b5f03cb2e6a 100644 --- a/src/video-generation/live-test-helpers.ts +++ b/src/video-generation/live-test-helpers.ts @@ -1,9 +1,14 @@ -import type { AuthProfileStore } from "../agents/auth-profiles.js"; import type { OpenClawConfig } from "../config/config.js"; import { - normalizeLowercaseStringOrEmpty, - normalizeOptionalLowercaseString, -} from "../shared/string-coerce.js"; + parseLiveCsvFilter, + parseProviderModelMap, + redactLiveApiKey, + resolveConfiguredLiveProviderModels, + resolveLiveAuthStore, +} from "../media-generation/live-test-helpers.js"; +import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js"; + +export { parseProviderModelMap, redactLiveApiKey }; export const DEFAULT_LIVE_VIDEO_MODELS: Record = { alibaba: "alibaba/wan2.6-t2v", @@ -33,76 +38,12 @@ export function resolveLiveVideoResolution(params: { return "480P"; } -export function redactLiveApiKey(value: string | undefined): string { - const trimmed = value?.trim(); - if (!trimmed) { - return "none"; - } - if (trimmed.length <= 12) { - return trimmed; - } - return `${trimmed.slice(0, 8)}...${trimmed.slice(-4)}`; -} - export function parseCsvFilter(raw?: string): Set | null { - const trimmed = raw?.trim(); - if (!trimmed || trimmed === "all") { - return null; - } - const values = trimmed - .split(",") - .map((entry) => normalizeOptionalLowercaseString(entry)) - .filter((entry): entry is string => Boolean(entry)); - return values.length > 0 ? new Set(values) : null; -} - -export function parseProviderModelMap(raw?: string): Map { - const entries = new Map(); - for (const token of raw?.split(",") ?? []) { - const trimmed = token.trim(); - if (!trimmed) { - continue; - } - const slash = trimmed.indexOf("/"); - if (slash <= 0 || slash === trimmed.length - 1) { - continue; - } - const providerId = normalizeOptionalLowercaseString(trimmed.slice(0, slash)); - if (!providerId) { - continue; - } - entries.set(providerId, trimmed); - } - return entries; + return parseLiveCsvFilter(raw); } export function resolveConfiguredLiveVideoModels(cfg: OpenClawConfig): Map { - const resolved = new Map(); - const configured = cfg.agents?.defaults?.videoGenerationModel; - const add = (value: string | undefined) => { - const trimmed = value?.trim(); - if (!trimmed) { - return; - } - const slash = trimmed.indexOf("/"); - if (slash <= 0 || slash === trimmed.length - 1) { - return; - } - const providerId = normalizeOptionalLowercaseString(trimmed.slice(0, slash)); - if (!providerId) { - return; - } - resolved.set(providerId, trimmed); - }; - if (typeof configured === "string") { - add(configured); - return resolved; - } - add(configured?.primary); - for (const fallback of configured?.fallbacks ?? []) { - add(fallback); - } - return resolved; + return resolveConfiguredLiveProviderModels(cfg.agents?.defaults?.videoGenerationModel); } export function canRunBufferBackedVideoToVideoLiveLane(params: { @@ -138,12 +79,6 @@ export function canRunBufferBackedImageToVideoLiveLane(params: { export function resolveLiveVideoAuthStore(params: { requireProfileKeys: boolean; hasLiveKeys: boolean; -}): AuthProfileStore | undefined { - if (params.requireProfileKeys || !params.hasLiveKeys) { - return undefined; - } - return { - version: 1, - profiles: {}, - }; +}) { + return resolveLiveAuthStore(params); } diff --git a/src/video-generation/model-ref.ts b/src/video-generation/model-ref.ts index 09cf6079120..1ce67d3c234 100644 --- a/src/video-generation/model-ref.ts +++ b/src/video-generation/model-ref.ts @@ -1,16 +1,7 @@ +import { parseGenerationModelRef } from "../media-generation/model-ref.js"; + export function parseVideoGenerationModelRef( raw: string | undefined, ): { provider: string; model: string } | null { - const trimmed = raw?.trim(); - if (!trimmed) { - return null; - } - const slashIndex = trimmed.indexOf("/"); - if (slashIndex <= 0 || slashIndex === trimmed.length - 1) { - return null; - } - return { - provider: trimmed.slice(0, slashIndex).trim(), - model: trimmed.slice(slashIndex + 1).trim(), - }; + return parseGenerationModelRef(raw); } diff --git a/src/video-generation/runtime.test.ts b/src/video-generation/runtime.test.ts index e870c94e531..ff2bff417fd 100644 --- a/src/video-generation/runtime.test.ts +++ b/src/video-generation/runtime.test.ts @@ -1,4 +1,5 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; +import { resetGenerationRuntimeMocks } from "../../test/helpers/media-generation/runtime-test-mocks.js"; import type { OpenClawConfig } from "../config/config.js"; import { generateVideo, listRuntimeVideoGenerationProviders } from "./runtime.js"; import type { VideoGenerationProvider } from "./types.js"; @@ -68,23 +69,12 @@ vi.mock("./provider-registry.js", () => ({ describe("video-generation runtime", () => { beforeEach(() => { - mocks.createSubsystemLogger.mockClear(); - mocks.describeFailoverError.mockReset(); - mocks.getProviderEnvVars.mockReset(); - mocks.getProviderEnvVars.mockReturnValue([]); - mocks.resolveProviderAuthEnvVarCandidates.mockReset(); - mocks.resolveProviderAuthEnvVarCandidates.mockReturnValue({}); - mocks.getVideoGenerationProvider.mockReset(); - mocks.isFailoverError.mockReset(); - mocks.isFailoverError.mockReturnValue(false); - mocks.listVideoGenerationProviders.mockReset(); - mocks.listVideoGenerationProviders.mockReturnValue([]); - mocks.parseVideoGenerationModelRef.mockClear(); - mocks.resolveAgentModelFallbackValues.mockReset(); - mocks.resolveAgentModelFallbackValues.mockReturnValue([]); - mocks.resolveAgentModelPrimaryValue.mockReset(); - mocks.resolveAgentModelPrimaryValue.mockReturnValue(undefined); - mocks.debug.mockReset(); + resetGenerationRuntimeMocks({ + ...mocks, + getProvider: mocks.getVideoGenerationProvider, + listProviders: mocks.listVideoGenerationProviders, + parseModelRef: mocks.parseVideoGenerationModelRef, + }); }); it("generates videos through the active video-generation provider", async () => { diff --git a/src/video-generation/runtime.ts b/src/video-generation/runtime.ts index 3d58c7a63b3..19b032e94aa 100644 --- a/src/video-generation/runtime.ts +++ b/src/video-generation/runtime.ts @@ -5,8 +5,8 @@ import type { OpenClawConfig } from "../config/config.js"; import { formatErrorMessage } from "../infra/errors.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; import { + buildMediaGenerationNormalizationMetadata, buildNoCapabilityModelConfiguredMessage, - deriveAspectRatioFromSize, resolveCapabilityModelCandidates, throwCapabilityGenerationFailure, } from "../media-generation/runtime-shared.js"; @@ -134,47 +134,11 @@ export async function generateVideo( ignoredOverrides: sanitized.ignoredOverrides, metadata: { ...result.metadata, - ...(sanitized.normalization?.size?.requested !== undefined && - sanitized.normalization.size.applied !== undefined - ? { - requestedSize: sanitized.normalization.size.requested, - normalizedSize: sanitized.normalization.size.applied, - } - : {}), - ...(sanitized.normalization?.aspectRatio?.applied !== undefined - ? { - ...(sanitized.normalization.aspectRatio.requested !== undefined - ? { requestedAspectRatio: sanitized.normalization.aspectRatio.requested } - : {}), - normalizedAspectRatio: sanitized.normalization.aspectRatio.applied, - ...(sanitized.normalization.aspectRatio.derivedFrom === "size" && params.size - ? { - requestedSize: params.size, - aspectRatioDerivedFromSize: deriveAspectRatioFromSize(params.size), - } - : {}), - } - : {}), - ...(sanitized.normalization?.resolution?.requested !== undefined && - sanitized.normalization.resolution.applied !== undefined - ? { - requestedResolution: sanitized.normalization.resolution.requested, - normalizedResolution: sanitized.normalization.resolution.applied, - } - : {}), - ...(sanitized.normalization?.durationSeconds?.requested !== undefined && - sanitized.normalization.durationSeconds.applied !== undefined - ? { - requestedDurationSeconds: sanitized.normalization.durationSeconds.requested, - normalizedDurationSeconds: sanitized.normalization.durationSeconds.applied, - ...(sanitized.normalization.durationSeconds.supportedValues?.length - ? { - supportedDurationSeconds: - sanitized.normalization.durationSeconds.supportedValues, - } - : {}), - } - : {}), + ...buildMediaGenerationNormalizationMetadata({ + normalization: sanitized.normalization, + requestedSizeForDerivedAspectRatio: params.size, + includeSupportedDurationSeconds: true, + }), }, }; } catch (err) { diff --git a/src/web-fetch/runtime.test.ts b/src/web-fetch/runtime.test.ts index 8fe8c7136e1..afe98ba725c 100644 --- a/src/web-fetch/runtime.test.ts +++ b/src/web-fetch/runtime.test.ts @@ -2,6 +2,10 @@ import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vite import type { OpenClawConfig } from "../config/config.js"; import type { PluginWebFetchProviderEntry } from "../plugins/types.js"; import type { RuntimeWebFetchMetadata } from "../secrets/runtime-web-tools.types.js"; +import { + createWebFetchTestProvider, + type WebFetchTestProviderParams, +} from "../test-utils/web-provider-runtime.test-helpers.js"; type TestPluginWebFetchConfig = { webFetch?: { @@ -21,37 +25,49 @@ vi.mock("../plugins/web-fetch-providers.runtime.js", () => ({ resolveRuntimeWebFetchProviders: resolveRuntimeWebFetchProvidersMock, })); -function createProvider(params: { - pluginId: string; - id: string; - credentialPath: string; - autoDetectOrder?: number; - requiresCredential?: boolean; - getCredentialValue?: PluginWebFetchProviderEntry["getCredentialValue"]; - getConfiguredCredentialValue?: PluginWebFetchProviderEntry["getConfiguredCredentialValue"]; - createTool?: PluginWebFetchProviderEntry["createTool"]; -}): PluginWebFetchProviderEntry { +function getFirecrawlApiKey(config?: OpenClawConfig): unknown { + const pluginConfig = config?.plugins?.entries?.firecrawl?.config as + | TestPluginWebFetchConfig + | undefined; + return pluginConfig?.webFetch?.apiKey; +} + +function createFirecrawlProvider( + overrides: Partial = {}, +): PluginWebFetchProviderEntry { + return createWebFetchTestProvider({ + pluginId: "firecrawl", + id: "firecrawl", + credentialPath: "plugins.entries.firecrawl.config.webFetch.apiKey", + autoDetectOrder: 1, + ...overrides, + }); +} + +function createThirdPartyFetchProvider(): PluginWebFetchProviderEntry { + return createWebFetchTestProvider({ + pluginId: "third-party-fetch", + id: "thirdparty", + credentialPath: "plugins.entries.third-party-fetch.config.webFetch.apiKey", + autoDetectOrder: 0, + getConfiguredCredentialValue: () => "runtime-key", + }); +} + +function createFirecrawlPluginConfig(apiKey: unknown): OpenClawConfig { return { - pluginId: params.pluginId, - id: params.id, - label: params.id, - hint: `${params.id} runtime provider`, - envVars: [`${params.id.toUpperCase()}_API_KEY`], - placeholder: `${params.id}-...`, - signupUrl: `https://example.com/${params.id}`, - credentialPath: params.credentialPath, - autoDetectOrder: params.autoDetectOrder, - requiresCredential: params.requiresCredential, - getCredentialValue: params.getCredentialValue ?? (() => undefined), - setCredentialValue: () => {}, - getConfiguredCredentialValue: params.getConfiguredCredentialValue, - createTool: - params.createTool ?? - (() => ({ - description: params.id, - parameters: {}, - execute: async (args) => ({ ...args, provider: params.id }), - })), + plugins: { + entries: { + firecrawl: { + enabled: true, + config: { + webFetch: { + apiKey, + }, + }, + }, + }, + }, }; } @@ -77,38 +93,16 @@ describe("web fetch runtime", () => { }); it("does not auto-detect providers from plugin-owned env SecretRefs without runtime metadata", () => { - const provider = createProvider({ - pluginId: "firecrawl", - id: "firecrawl", - credentialPath: "plugins.entries.firecrawl.config.webFetch.apiKey", - autoDetectOrder: 1, - getConfiguredCredentialValue: (config) => { - const pluginConfig = config?.plugins?.entries?.firecrawl?.config as - | TestPluginWebFetchConfig - | undefined; - return pluginConfig?.webFetch?.apiKey; - }, + const provider = createFirecrawlProvider({ + getConfiguredCredentialValue: getFirecrawlApiKey, }); resolvePluginWebFetchProvidersMock.mockReturnValue([provider]); - const config: OpenClawConfig = { - plugins: { - entries: { - firecrawl: { - enabled: true, - config: { - webFetch: { - apiKey: { - source: "env", - provider: "default", - id: "AWS_SECRET_ACCESS_KEY", - }, - }, - }, - }, - }, - }, - }; + const config = createFirecrawlPluginConfig({ + source: "env", + provider: "default", + id: "AWS_SECRET_ACCESS_KEY", + }); vi.stubEnv("FIRECRAWL_API_KEY", ""); @@ -116,11 +110,7 @@ describe("web fetch runtime", () => { }); it("prefers the runtime-selected provider when metadata is available", async () => { - const provider = createProvider({ - pluginId: "firecrawl", - id: "firecrawl", - credentialPath: "plugins.entries.firecrawl.config.webFetch.apiKey", - autoDetectOrder: 1, + const provider = createFirecrawlProvider({ createTool: ({ runtimeMetadata }) => ({ description: "firecrawl", parameters: {}, @@ -162,12 +152,7 @@ describe("web fetch runtime", () => { }); it("auto-detects providers from provider-declared env vars", () => { - const provider = createProvider({ - pluginId: "firecrawl", - id: "firecrawl", - credentialPath: "plugins.entries.firecrawl.config.webFetch.apiKey", - autoDetectOrder: 1, - }); + const provider = createFirecrawlProvider(); resolvePluginWebFetchProvidersMock.mockReturnValue([provider]); vi.stubEnv("FIRECRAWL_API_KEY", "firecrawl-env-key"); @@ -179,11 +164,7 @@ describe("web fetch runtime", () => { }); it("falls back to auto-detect when the configured provider is invalid", () => { - const provider = createProvider({ - pluginId: "firecrawl", - id: "firecrawl", - credentialPath: "plugins.entries.firecrawl.config.webFetch.apiKey", - autoDetectOrder: 1, + const provider = createFirecrawlProvider({ getConfiguredCredentialValue: () => "firecrawl-key", }); resolvePluginWebFetchProvidersMock.mockReturnValue([provider]); @@ -204,20 +185,10 @@ describe("web fetch runtime", () => { }); it("keeps sandboxed web fetch on bundled providers even when runtime providers are preferred", () => { - const bundled = createProvider({ - pluginId: "firecrawl", - id: "firecrawl", - credentialPath: "plugins.entries.firecrawl.config.webFetch.apiKey", - autoDetectOrder: 1, + const bundled = createFirecrawlProvider({ getConfiguredCredentialValue: () => "bundled-key", }); - const runtimeOnly = createProvider({ - pluginId: "third-party-fetch", - id: "thirdparty", - credentialPath: "plugins.entries.third-party-fetch.config.webFetch.apiKey", - autoDetectOrder: 0, - getConfiguredCredentialValue: () => "runtime-key", - }); + const runtimeOnly = createThirdPartyFetchProvider(); resolvePluginWebFetchProvidersMock.mockReturnValue([bundled]); resolveRuntimeWebFetchProvidersMock.mockReturnValue([runtimeOnly]); @@ -231,20 +202,10 @@ describe("web fetch runtime", () => { }); it("keeps non-sandboxed web fetch on bundled providers even when runtime providers are preferred", () => { - const bundled = createProvider({ - pluginId: "firecrawl", - id: "firecrawl", - credentialPath: "plugins.entries.firecrawl.config.webFetch.apiKey", - autoDetectOrder: 1, + const bundled = createFirecrawlProvider({ getConfiguredCredentialValue: () => "bundled-key", }); - const runtimeOnly = createProvider({ - pluginId: "third-party-fetch", - id: "thirdparty", - credentialPath: "plugins.entries.third-party-fetch.config.webFetch.apiKey", - autoDetectOrder: 0, - getConfiguredCredentialValue: () => "runtime-key", - }); + const runtimeOnly = createThirdPartyFetchProvider(); resolvePluginWebFetchProvidersMock.mockReturnValue([bundled]); resolveRuntimeWebFetchProvidersMock.mockReturnValue([runtimeOnly]); diff --git a/src/web-search/runtime.test.ts b/src/web-search/runtime.test.ts index 5d5a29461fc..9ed5c26f36a 100644 --- a/src/web-search/runtime.test.ts +++ b/src/web-search/runtime.test.ts @@ -1,6 +1,10 @@ import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; import type { OpenClawConfig } from "../config/config.js"; import type { PluginWebSearchProviderEntry } from "../plugins/types.js"; +import { + createWebSearchTestProvider, + type WebSearchTestProviderParams, +} from "../test-utils/web-provider-runtime.test-helpers.js"; type TestPluginWebSearchConfig = { webSearch?: { @@ -20,40 +24,78 @@ vi.mock("../plugins/web-search-providers.runtime.js", () => ({ resolveRuntimeWebSearchProviders: resolveRuntimeWebSearchProvidersMock, })); -function createProvider(params: { - pluginId: string; - id: string; - credentialPath: string; - autoDetectOrder?: number; - requiresCredential?: boolean; - getCredentialValue?: PluginWebSearchProviderEntry["getCredentialValue"]; - getConfiguredCredentialValue?: PluginWebSearchProviderEntry["getConfiguredCredentialValue"]; - createTool?: PluginWebSearchProviderEntry["createTool"]; -}): PluginWebSearchProviderEntry { +function createCustomSearchTool() { return { - pluginId: params.pluginId, - id: params.id, - label: params.id, - hint: `${params.id} runtime provider`, - envVars: [`${params.id.toUpperCase()}_API_KEY`], - placeholder: `${params.id}-...`, - signupUrl: `https://example.com/${params.id}`, - credentialPath: params.credentialPath, - autoDetectOrder: params.autoDetectOrder, - requiresCredential: params.requiresCredential, - getCredentialValue: params.getCredentialValue ?? (() => undefined), - setCredentialValue: () => {}, - getConfiguredCredentialValue: params.getConfiguredCredentialValue, - createTool: - params.createTool ?? - (() => ({ - description: params.id, - parameters: {}, - execute: async (args) => ({ ...args, provider: params.id }), - })), + description: "custom", + parameters: {}, + execute: async (args: Record) => ({ ...args, ok: true }), }; } +function getCustomSearchApiKey(config?: OpenClawConfig): unknown { + const pluginConfig = config?.plugins?.entries?.["custom-search"]?.config as + | TestPluginWebSearchConfig + | undefined; + return pluginConfig?.webSearch?.apiKey; +} + +function createCustomSearchProvider( + overrides: Partial = {}, +): PluginWebSearchProviderEntry { + return createWebSearchTestProvider({ + pluginId: "custom-search", + id: "custom", + credentialPath: "plugins.entries.custom-search.config.webSearch.apiKey", + autoDetectOrder: 1, + getConfiguredCredentialValue: getCustomSearchApiKey, + createTool: createCustomSearchTool, + ...overrides, + }); +} + +function createCustomSearchConfig(apiKey: unknown): OpenClawConfig { + return { + plugins: { + entries: { + "custom-search": { + enabled: true, + config: { + webSearch: { + apiKey, + }, + }, + }, + }, + }, + }; +} + +function createGoogleSearchProvider( + overrides: Partial = {}, +): PluginWebSearchProviderEntry { + return createWebSearchTestProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + ...overrides, + }); +} + +function createDuckDuckGoSearchProvider( + overrides: Partial = {}, +): PluginWebSearchProviderEntry { + return createWebSearchTestProvider({ + pluginId: "duckduckgo", + id: "duckduckgo", + credentialPath: "", + autoDetectOrder: 100, + requiresCredential: false, + ...overrides, + }); +} + describe("web search runtime", () => { let runWebSearch: typeof import("./runtime.js").runWebSearch; let activateSecretsRuntimeSnapshot: typeof import("../secrets/runtime.js").activateSecretsRuntimeSnapshot; @@ -78,17 +120,9 @@ describe("web search runtime", () => { it("executes searches through the active plugin registry", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "custom-search", - id: "custom", + createCustomSearchProvider({ credentialPath: "tools.web.search.custom.apiKey", - autoDetectOrder: 1, getCredentialValue: () => "configured", - createTool: () => ({ - description: "custom", - parameters: {}, - execute: async (args) => ({ ...args, ok: true }), - }), }), ]); @@ -104,40 +138,11 @@ describe("web search runtime", () => { }); it("auto-detects a provider from canonical plugin-owned credentials", async () => { - const provider = createProvider({ - pluginId: "custom-search", - id: "custom", - credentialPath: "plugins.entries.custom-search.config.webSearch.apiKey", - autoDetectOrder: 1, - getConfiguredCredentialValue: (config) => { - const pluginConfig = config?.plugins?.entries?.["custom-search"]?.config as - | TestPluginWebSearchConfig - | undefined; - return pluginConfig?.webSearch?.apiKey; - }, - createTool: () => ({ - description: "custom", - parameters: {}, - execute: async (args) => ({ ...args, ok: true }), - }), - }); + const provider = createCustomSearchProvider(); resolveRuntimeWebSearchProvidersMock.mockReturnValue([provider]); resolvePluginWebSearchProvidersMock.mockReturnValue([provider]); - const config: OpenClawConfig = { - plugins: { - entries: { - "custom-search": { - enabled: true, - config: { - webSearch: { - apiKey: "custom-config-key", - }, - }, - }, - }, - }, - }; + const config = createCustomSearchConfig("custom-config-key"); await expect( runWebSearch({ @@ -151,44 +156,15 @@ describe("web search runtime", () => { }); it("treats non-env SecretRefs as configured credentials for provider auto-detect", async () => { - const provider = createProvider({ - pluginId: "custom-search", - id: "custom", - credentialPath: "plugins.entries.custom-search.config.webSearch.apiKey", - autoDetectOrder: 1, - getConfiguredCredentialValue: (config) => { - const pluginConfig = config?.plugins?.entries?.["custom-search"]?.config as - | TestPluginWebSearchConfig - | undefined; - return pluginConfig?.webSearch?.apiKey; - }, - createTool: () => ({ - description: "custom", - parameters: {}, - execute: async (args) => ({ ...args, ok: true }), - }), - }); + const provider = createCustomSearchProvider(); resolveRuntimeWebSearchProvidersMock.mockReturnValue([provider]); resolvePluginWebSearchProvidersMock.mockReturnValue([provider]); - const config: OpenClawConfig = { - plugins: { - entries: { - "custom-search": { - enabled: true, - config: { - webSearch: { - apiKey: { - source: "file", - provider: "vault", - id: "/providers/custom-search/apiKey", - }, - }, - }, - }, - }, - }, - }; + const config = createCustomSearchConfig({ + source: "file", + provider: "vault", + id: "/providers/custom-search/apiKey", + }); await expect( runWebSearch({ @@ -203,12 +179,7 @@ describe("web search runtime", () => { it("falls back to a keyless provider when no credentials are available", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "duckduckgo", - id: "duckduckgo", - credentialPath: "", - autoDetectOrder: 100, - requiresCredential: false, + createDuckDuckGoSearchProvider({ getCredentialValue: () => "duckduckgo-no-key-needed", }), ]); @@ -226,7 +197,7 @@ describe("web search runtime", () => { it("prefers the active runtime-selected provider when callers omit runtime metadata", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ + createWebSearchTestProvider({ pluginId: "alpha-search", id: "alpha", credentialPath: "tools.web.search.alpha.apiKey", @@ -242,7 +213,7 @@ describe("web search runtime", () => { }), }), }), - createProvider({ + createWebSearchTestProvider({ pluginId: "beta-search", id: "beta", credentialPath: "tools.web.search.beta.apiKey", @@ -292,12 +263,7 @@ describe("web search runtime", () => { it("falls back to another provider when auto-selected search execution fails", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", + createGoogleSearchProvider({ createTool: () => ({ description: "google", parameters: {}, @@ -306,18 +272,7 @@ describe("web search runtime", () => { }, }), }), - createProvider({ - pluginId: "duckduckgo", - id: "duckduckgo", - credentialPath: "", - autoDetectOrder: 100, - requiresCredential: false, - createTool: () => ({ - description: "duckduckgo", - parameters: {}, - execute: async (args) => ({ ...args, provider: "duckduckgo" }), - }), - }), + createDuckDuckGoSearchProvider(), ]); await expect( @@ -333,19 +288,8 @@ describe("web search runtime", () => { it("does not prebuild fallback provider tools before attempting the selected provider", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", - createTool: () => ({ - description: "google", - parameters: {}, - execute: async (args) => ({ ...args, provider: "google" }), - }), - }), - createProvider({ + createGoogleSearchProvider(), + createWebSearchTestProvider({ pluginId: "broken-fallback", id: "broken-fallback", credentialPath: "", @@ -370,12 +314,7 @@ describe("web search runtime", () => { it("does not fall back when the provider came from explicit config selection", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", + createGoogleSearchProvider({ createTool: () => ({ description: "google", parameters: {}, @@ -384,18 +323,7 @@ describe("web search runtime", () => { }, }), }), - createProvider({ - pluginId: "duckduckgo", - id: "duckduckgo", - credentialPath: "", - autoDetectOrder: 100, - requiresCredential: false, - createTool: () => ({ - description: "duckduckgo", - parameters: {}, - execute: async (args) => ({ ...args, provider: "duckduckgo" }), - }), - }), + createDuckDuckGoSearchProvider(), ]); await expect( @@ -416,12 +344,7 @@ describe("web search runtime", () => { it("does not fall back when the caller explicitly selects a provider", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", + createGoogleSearchProvider({ createTool: () => ({ description: "google", parameters: {}, @@ -430,13 +353,7 @@ describe("web search runtime", () => { }, }), }), - createProvider({ - pluginId: "duckduckgo", - id: "duckduckgo", - credentialPath: "", - autoDetectOrder: 100, - requiresCredential: false, - }), + createDuckDuckGoSearchProvider(), ]); await expect( @@ -450,21 +367,10 @@ describe("web search runtime", () => { it("fails fast when an explicit provider cannot create a tool", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", + createGoogleSearchProvider({ createTool: () => null, }), - createProvider({ - pluginId: "duckduckgo", - id: "duckduckgo", - credentialPath: "", - autoDetectOrder: 100, - requiresCredential: false, - }), + createDuckDuckGoSearchProvider(), ]); await expect( @@ -478,20 +384,8 @@ describe("web search runtime", () => { it("fails fast when the caller explicitly selects an unknown provider", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", - }), - createProvider({ - pluginId: "duckduckgo", - id: "duckduckgo", - credentialPath: "", - autoDetectOrder: 100, - requiresCredential: false, - }), + createGoogleSearchProvider(), + createDuckDuckGoSearchProvider(), ]); await expect( @@ -505,23 +399,12 @@ describe("web search runtime", () => { it("still falls back when config names an unknown provider id", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", + createGoogleSearchProvider({ createTool: () => { throw new Error("google aborted"); }, }), - createProvider({ - pluginId: "duckduckgo", - id: "duckduckgo", - credentialPath: "", - autoDetectOrder: 100, - requiresCredential: false, - }), + createDuckDuckGoSearchProvider(), ]); await expect( @@ -547,14 +430,8 @@ describe("web search runtime", () => { }); it("honors preferRuntimeProviders during execution", async () => { - const configuredProvider = createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", - }); - const runtimeProvider = createProvider({ + const configuredProvider = createGoogleSearchProvider(); + const runtimeProvider = createWebSearchTestProvider({ pluginId: "runtime-search", id: "runtime-search", credentialPath: "", @@ -592,20 +469,10 @@ describe("web search runtime", () => { it("returns a clear error when every fallback-capable provider is unavailable", async () => { resolveRuntimeWebSearchProvidersMock.mockReturnValue([ - createProvider({ - pluginId: "google", - id: "google", - credentialPath: "tools.web.search.google.apiKey", - autoDetectOrder: 1, - getCredentialValue: () => "configured", + createGoogleSearchProvider({ createTool: () => null, }), - createProvider({ - pluginId: "duckduckgo", - id: "duckduckgo", - credentialPath: "", - autoDetectOrder: 100, - requiresCredential: false, + createDuckDuckGoSearchProvider({ createTool: () => null, }), ]); diff --git a/test/git-hooks-pre-commit.test.ts b/test/git-hooks-pre-commit.test.ts index a2f637005d2..9bd5f2fd00f 100644 --- a/test/git-hooks-pre-commit.test.ts +++ b/test/git-hooks-pre-commit.test.ts @@ -26,6 +26,33 @@ function writeExecutable(dir: string, name: string, contents: string): void { }); } +function installPreCommitFixture(dir: string): string { + mkdirSync(path.join(dir, "git-hooks"), { recursive: true }); + mkdirSync(path.join(dir, "scripts", "pre-commit"), { recursive: true }); + symlinkSync( + path.join(process.cwd(), "git-hooks", "pre-commit"), + path.join(dir, "git-hooks", "pre-commit"), + ); + writeFileSync( + path.join(dir, "scripts", "pre-commit", "run-node-tool.sh"), + "#!/usr/bin/env bash\nexit 0\n", + { + encoding: "utf8", + mode: 0o755, + }, + ); + writeFileSync( + path.join(dir, "scripts", "pre-commit", "filter-staged-files.mjs"), + "process.exit(0);\n", + "utf8", + ); + + const fakeBinDir = path.join(dir, "bin"); + mkdirSync(fakeBinDir, { recursive: true }); + writeExecutable(fakeBinDir, "node", "#!/usr/bin/env bash\nexit 0\n"); + return fakeBinDir; +} + afterEach(() => { cleanupTempDirs(tempDirs); }); @@ -36,28 +63,7 @@ describe("git-hooks/pre-commit (integration)", () => { run(dir, "git", ["init", "-q", "--initial-branch=main"]); // Use the real hook script and lightweight helper stubs. - mkdirSync(path.join(dir, "git-hooks"), { recursive: true }); - mkdirSync(path.join(dir, "scripts", "pre-commit"), { recursive: true }); - symlinkSync( - path.join(process.cwd(), "git-hooks", "pre-commit"), - path.join(dir, "git-hooks", "pre-commit"), - ); - writeFileSync( - path.join(dir, "scripts", "pre-commit", "run-node-tool.sh"), - "#!/usr/bin/env bash\nexit 0\n", - { - encoding: "utf8", - mode: 0o755, - }, - ); - writeFileSync( - path.join(dir, "scripts", "pre-commit", "filter-staged-files.mjs"), - "process.exit(0);\n", - "utf8", - ); - const fakeBinDir = path.join(dir, "bin"); - mkdirSync(fakeBinDir, { recursive: true }); - writeExecutable(fakeBinDir, "node", "#!/usr/bin/env bash\nexit 0\n"); + const fakeBinDir = installPreCommitFixture(dir); // The hook ends with `pnpm check`, but this fixture is only exercising staged-file handling. // Stub pnpm too so Windows CI does not invoke a real package-manager command in the temp repo. writeExecutable(fakeBinDir, "pnpm", "#!/usr/bin/env bash\nexit 0\n"); @@ -82,31 +88,10 @@ describe("git-hooks/pre-commit (integration)", () => { const dir = makeTempRepoRoot(tempDirs, "openclaw-pre-commit-yolo-"); run(dir, "git", ["init", "-q", "--initial-branch=main"]); - mkdirSync(path.join(dir, "git-hooks"), { recursive: true }); - mkdirSync(path.join(dir, "scripts", "pre-commit"), { recursive: true }); - symlinkSync( - path.join(process.cwd(), "git-hooks", "pre-commit"), - path.join(dir, "git-hooks", "pre-commit"), - ); - writeFileSync( - path.join(dir, "scripts", "pre-commit", "run-node-tool.sh"), - "#!/usr/bin/env bash\nexit 0\n", - { - encoding: "utf8", - mode: 0o755, - }, - ); - writeFileSync( - path.join(dir, "scripts", "pre-commit", "filter-staged-files.mjs"), - "process.exit(0);\n", - "utf8", - ); + const fakeBinDir = installPreCommitFixture(dir); writeFileSync(path.join(dir, "package.json"), '{"name":"tmp"}\n', "utf8"); writeFileSync(path.join(dir, "pnpm-lock.yaml"), "lockfileVersion: '9.0'\n", "utf8"); - const fakeBinDir = path.join(dir, "bin"); - mkdirSync(fakeBinDir, { recursive: true }); - writeExecutable(fakeBinDir, "node", "#!/usr/bin/env bash\nexit 0\n"); writeExecutable( fakeBinDir, "pnpm", diff --git a/test/helpers/media-generation/dashscope-video-provider.ts b/test/helpers/media-generation/dashscope-video-provider.ts new file mode 100644 index 00000000000..efae883402c --- /dev/null +++ b/test/helpers/media-generation/dashscope-video-provider.ts @@ -0,0 +1,118 @@ +import type { VideoGenerationResult } from "openclaw/plugin-sdk/video-generation"; +import { expect, vi } from "vitest"; + +type ClearableMock = { + mockClear(): unknown; +}; + +type ResettableMock = { + mockReset(): unknown; +}; + +type ResolvableMock = { + mockResolvedValue(value: unknown): unknown; +}; + +type ChainableResolvedValueMock = ResettableMock & { + mockResolvedValueOnce(value: unknown): ChainableResolvedValueMock; +}; + +export type DashscopeVideoProviderMocks = { + resolveApiKeyForProviderMock: ClearableMock; + postJsonRequestMock: ResettableMock & ResolvableMock; + fetchWithTimeoutMock: ChainableResolvedValueMock; + assertOkOrThrowHttpErrorMock: ClearableMock; + resolveProviderHttpRequestConfigMock: ClearableMock; +}; + +export function resetDashscopeVideoProviderMocks(mocks: DashscopeVideoProviderMocks): void { + mocks.resolveApiKeyForProviderMock.mockClear(); + mocks.postJsonRequestMock.mockReset(); + mocks.fetchWithTimeoutMock.mockReset(); + mocks.assertOkOrThrowHttpErrorMock.mockClear(); + mocks.resolveProviderHttpRequestConfigMock.mockClear(); +} + +export function mockSuccessfulDashscopeVideoTask( + mocks: Pick, + params: { + requestId?: string; + taskId?: string; + taskStatus?: string; + videoUrl?: string; + } = {}, +): void { + const { + requestId = "req-1", + taskId = "task-1", + taskStatus = "SUCCEEDED", + videoUrl = "https://example.com/out.mp4", + } = params; + mocks.postJsonRequestMock.mockResolvedValue({ + response: { + json: async () => ({ + request_id: requestId, + output: { + task_id: taskId, + }, + }), + }, + release: vi.fn(async () => {}), + }); + mocks.fetchWithTimeoutMock + .mockResolvedValueOnce({ + json: async () => ({ + output: { + task_status: taskStatus, + results: [{ video_url: videoUrl }], + }, + }), + headers: new Headers(), + }) + .mockResolvedValueOnce({ + arrayBuffer: async () => Buffer.from("mp4-bytes"), + headers: new Headers({ "content-type": "video/mp4" }), + }); +} + +export function expectDashscopeVideoTaskPoll( + fetchWithTimeoutMock: ChainableResolvedValueMock, + params: { + baseUrl?: string; + taskId?: string; + timeoutMs?: number; + } = {}, +): void { + const { + baseUrl = "https://dashscope-intl.aliyuncs.com", + taskId = "task-1", + timeoutMs = 120_000, + } = params; + expect(fetchWithTimeoutMock).toHaveBeenNthCalledWith( + 1, + `${baseUrl}/api/v1/tasks/${taskId}`, + expect.objectContaining({ method: "GET" }), + timeoutMs, + fetch, + ); +} + +export function expectSuccessfulDashscopeVideoResult( + result: VideoGenerationResult, + params: { + requestId?: string; + taskId?: string; + taskStatus?: string; + } = {}, +): void { + const { requestId = "req-1", taskId = "task-1", taskStatus = "SUCCEEDED" } = params; + expect(result.videos).toHaveLength(1); + expect(result.videos[0]?.mimeType).toBe("video/mp4"); + expect(result.metadata).toEqual( + expect.objectContaining({ + requestId, + taskId, + taskStatus, + }), + ); +} diff --git a/test/helpers/media-generation/provider-http-mocks.ts b/test/helpers/media-generation/provider-http-mocks.ts new file mode 100644 index 00000000000..337469973ac --- /dev/null +++ b/test/helpers/media-generation/provider-http-mocks.ts @@ -0,0 +1,44 @@ +import type { resolveProviderHttpRequestConfig } from "openclaw/plugin-sdk/provider-http"; +import { afterEach, vi } from "vitest"; + +type ResolveProviderHttpRequestConfigParams = Parameters< + typeof resolveProviderHttpRequestConfig +>[0]; + +const providerHttpMocks = vi.hoisted(() => ({ + resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "provider-key" })), + postJsonRequestMock: vi.fn(), + fetchWithTimeoutMock: vi.fn(), + assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), + resolveProviderHttpRequestConfigMock: vi.fn((params: ResolveProviderHttpRequestConfigParams) => ({ + baseUrl: params.baseUrl ?? params.defaultBaseUrl, + allowPrivateNetwork: false, + headers: new Headers(params.defaultHeaders), + dispatcherPolicy: undefined, + })), +})); + +vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ + resolveApiKeyForProvider: providerHttpMocks.resolveApiKeyForProviderMock, +})); + +vi.mock("openclaw/plugin-sdk/provider-http", () => ({ + assertOkOrThrowHttpError: providerHttpMocks.assertOkOrThrowHttpErrorMock, + fetchWithTimeout: providerHttpMocks.fetchWithTimeoutMock, + postJsonRequest: providerHttpMocks.postJsonRequestMock, + resolveProviderHttpRequestConfig: providerHttpMocks.resolveProviderHttpRequestConfigMock, +})); + +export function getProviderHttpMocks() { + return providerHttpMocks; +} + +export function installProviderHttpMockCleanup(): void { + afterEach(() => { + providerHttpMocks.resolveApiKeyForProviderMock.mockClear(); + providerHttpMocks.postJsonRequestMock.mockReset(); + providerHttpMocks.fetchWithTimeoutMock.mockReset(); + providerHttpMocks.assertOkOrThrowHttpErrorMock.mockClear(); + providerHttpMocks.resolveProviderHttpRequestConfigMock.mockClear(); + }); +} diff --git a/test/helpers/media-generation/runtime-test-mocks.ts b/test/helpers/media-generation/runtime-test-mocks.ts new file mode 100644 index 00000000000..f17bd6de238 --- /dev/null +++ b/test/helpers/media-generation/runtime-test-mocks.ts @@ -0,0 +1,45 @@ +type ClearableMock = { + mockClear(): unknown; +}; + +type ResettableMock = { + mockReset(): unknown; +}; + +type ResettableReturnMock = ResettableMock & { + mockReturnValue(value: unknown): unknown; +}; + +export type GenerationRuntimeMocks = { + createSubsystemLogger: ClearableMock; + describeFailoverError: ResettableMock; + getProvider: ResettableReturnMock; + getProviderEnvVars: ResettableReturnMock; + resolveProviderAuthEnvVarCandidates: ResettableReturnMock; + isFailoverError: ResettableReturnMock; + listProviders: ResettableReturnMock; + parseModelRef: ClearableMock; + resolveAgentModelFallbackValues: ResettableReturnMock; + resolveAgentModelPrimaryValue: ResettableReturnMock; + debug: ResettableMock; +}; + +export function resetGenerationRuntimeMocks(mocks: GenerationRuntimeMocks): void { + mocks.createSubsystemLogger.mockClear(); + mocks.describeFailoverError.mockReset(); + mocks.getProvider.mockReset(); + mocks.getProviderEnvVars.mockReset(); + mocks.getProviderEnvVars.mockReturnValue([]); + mocks.resolveProviderAuthEnvVarCandidates.mockReset(); + mocks.resolveProviderAuthEnvVarCandidates.mockReturnValue({}); + mocks.isFailoverError.mockReset(); + mocks.isFailoverError.mockReturnValue(false); + mocks.listProviders.mockReset(); + mocks.listProviders.mockReturnValue([]); + mocks.parseModelRef.mockClear(); + mocks.resolveAgentModelFallbackValues.mockReset(); + mocks.resolveAgentModelFallbackValues.mockReturnValue([]); + mocks.resolveAgentModelPrimaryValue.mockReset(); + mocks.resolveAgentModelPrimaryValue.mockReturnValue(undefined); + mocks.debug.mockReset(); +} diff --git a/test/scripts/postinstall-bundled-plugins.test.ts b/test/scripts/postinstall-bundled-plugins.test.ts index 874a733f81a..ab61bec0a74 100644 --- a/test/scripts/postinstall-bundled-plugins.test.ts +++ b/test/scripts/postinstall-bundled-plugins.test.ts @@ -31,10 +31,14 @@ async function writePluginPackage( } describe("bundled plugin postinstall", () => { - function createBareNpmRunner(args: string[]) { + function createNpmInstallArgs(...packages: string[]) { + return ["install", "--omit=dev", "--no-save", "--package-lock=false", ...packages]; + } + + function createBareNpmRunner(packages: string[]) { return { command: "npm", - args, + args: createNpmInstallArgs(...packages), env: { HOME: "/tmp/home", PATH: "/tmp/node/bin", @@ -43,6 +47,24 @@ describe("bundled plugin postinstall", () => { }; } + function expectNpmInstallSpawn( + spawnSync: ReturnType, + packageRoot: string, + packages: string[], + ) { + expect(spawnSync).toHaveBeenCalledWith("npm", createNpmInstallArgs(...packages), { + cwd: packageRoot, + encoding: "utf8", + env: { + HOME: "/tmp/home", + PATH: "/tmp/node/bin", + }, + shell: false, + stdio: "pipe", + windowsVerbatimArguments: undefined, + }); + } + it("clears global npm config before nested installs", () => { expect( createNestedNpmInstallEnv({ @@ -70,13 +92,7 @@ describe("bundled plugin postinstall", () => { env: { HOME: "/tmp/home" }, extensionsDir, packageRoot, - npmRunner: createBareNpmRunner([ - "install", - "--omit=dev", - "--no-save", - "--package-lock=false", - "acpx@0.4.1", - ]), + npmRunner: createBareNpmRunner(["acpx@0.4.1"]), spawnSync, log: { log: vi.fn(), warn: vi.fn() }, }); @@ -103,32 +119,12 @@ describe("bundled plugin postinstall", () => { }, extensionsDir, packageRoot, - npmRunner: createBareNpmRunner([ - "install", - "--omit=dev", - "--no-save", - "--package-lock=false", - "acpx@0.4.1", - ]), + npmRunner: createBareNpmRunner(["acpx@0.4.1"]), spawnSync, log: { log: vi.fn(), warn: vi.fn() }, }); - expect(spawnSync).toHaveBeenCalledWith( - "npm", - ["install", "--omit=dev", "--no-save", "--package-lock=false", "acpx@0.4.1"], - { - cwd: packageRoot, - encoding: "utf8", - env: { - HOME: "/tmp/home", - PATH: "/tmp/node/bin", - }, - shell: false, - stdio: "pipe", - windowsVerbatimArguments: undefined, - }, - ); + expectNpmInstallSpawn(spawnSync, packageRoot, ["acpx@0.4.1"]); }); it("skips reinstall when the bundled sentinel package already exists", async () => { @@ -237,40 +233,12 @@ describe("bundled plugin postinstall", () => { }, extensionsDir, packageRoot, - npmRunner: createBareNpmRunner([ - "install", - "--omit=dev", - "--no-save", - "--package-lock=false", - "@slack/web-api@7.11.0", - "grammy@1.38.4", - ]), + npmRunner: createBareNpmRunner(["@slack/web-api@7.11.0", "grammy@1.38.4"]), spawnSync, log: { log: vi.fn(), warn: vi.fn() }, }); - expect(spawnSync).toHaveBeenCalledWith( - "npm", - [ - "install", - "--omit=dev", - "--no-save", - "--package-lock=false", - "@slack/web-api@7.11.0", - "grammy@1.38.4", - ], - { - cwd: packageRoot, - encoding: "utf8", - env: { - HOME: "/tmp/home", - PATH: "/tmp/node/bin", - }, - shell: false, - stdio: "pipe", - windowsVerbatimArguments: undefined, - }, - ); + expectNpmInstallSpawn(spawnSync, packageRoot, ["@slack/web-api@7.11.0", "grammy@1.38.4"]); }); it("installs only missing bundled plugin runtime deps", async () => { @@ -301,32 +269,12 @@ describe("bundled plugin postinstall", () => { }, extensionsDir, packageRoot, - npmRunner: createBareNpmRunner([ - "install", - "--omit=dev", - "--no-save", - "--package-lock=false", - "grammy@1.38.4", - ]), + npmRunner: createBareNpmRunner(["grammy@1.38.4"]), spawnSync, log: { log: vi.fn(), warn: vi.fn() }, }); - expect(spawnSync).toHaveBeenCalledWith( - "npm", - ["install", "--omit=dev", "--no-save", "--package-lock=false", "grammy@1.38.4"], - { - cwd: packageRoot, - encoding: "utf8", - env: { - HOME: "/tmp/home", - PATH: "/tmp/node/bin", - }, - shell: false, - stdio: "pipe", - windowsVerbatimArguments: undefined, - }, - ); + expectNpmInstallSpawn(spawnSync, packageRoot, ["grammy@1.38.4"]); }); it("installs bundled plugin deps when npm location is global", async () => { @@ -347,31 +295,11 @@ describe("bundled plugin postinstall", () => { }, extensionsDir, packageRoot, - npmRunner: createBareNpmRunner([ - "install", - "--omit=dev", - "--no-save", - "--package-lock=false", - "grammy@1.38.4", - ]), + npmRunner: createBareNpmRunner(["grammy@1.38.4"]), spawnSync, log: { log: vi.fn(), warn: vi.fn() }, }); - expect(spawnSync).toHaveBeenCalledWith( - "npm", - ["install", "--omit=dev", "--no-save", "--package-lock=false", "grammy@1.38.4"], - { - cwd: packageRoot, - encoding: "utf8", - env: { - HOME: "/tmp/home", - PATH: "/tmp/node/bin", - }, - shell: false, - stdio: "pipe", - windowsVerbatimArguments: undefined, - }, - ); + expectNpmInstallSpawn(spawnSync, packageRoot, ["grammy@1.38.4"]); }); });