From d55c7ea997c095337d47070b8e66d66ca0a32722 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Tue, 28 Apr 2026 03:45:54 -0700 Subject: [PATCH] fix(plugins): bound prompt memory recall latency --- extensions/memory-lancedb/index.test.ts | 96 +++++++++++++++++++ extensions/memory-lancedb/index.ts | 50 +++++++++- .../hooks.model-override-wiring.test.ts | 38 ++++++++ src/plugins/hooks.ts | 39 ++++++-- 4 files changed, 213 insertions(+), 10 deletions(-) diff --git a/extensions/memory-lancedb/index.test.ts b/extensions/memory-lancedb/index.test.ts index 45747570dfd..e1c0c584208 100644 --- a/extensions/memory-lancedb/index.test.ts +++ b/extensions/memory-lancedb/index.test.ts @@ -602,6 +602,102 @@ describe("memory plugin e2e", () => { } }); + test("bounds auto-recall latency during prompt build", async () => { + vi.useFakeTimers(); + const post = vi.fn(() => new Promise(() => undefined)); + const ensureGlobalUndiciEnvProxyDispatcher = vi.fn(); + const loadLanceDbModule = vi.fn(async () => ({ + connect: vi.fn(async () => ({ + tableNames: vi.fn(async () => ["memories"]), + openTable: vi.fn(async () => ({ + vectorSearch: vi.fn(() => ({ limit: vi.fn(() => ({ toArray: vi.fn(async () => []) })) })), + countRows: vi.fn(async () => 0), + add: vi.fn(async () => undefined), + delete: vi.fn(async () => undefined), + })), + })), + })); + + vi.resetModules(); + vi.doMock("openclaw/plugin-sdk/runtime-env", () => ({ + ensureGlobalUndiciEnvProxyDispatcher, + })); + vi.doMock("openai", () => ({ + default: class MockOpenAI { + post = post; + }, + })); + vi.doMock("./lancedb-runtime.js", () => ({ + loadLanceDbModule, + })); + + try { + const { default: dynamicMemoryPlugin } = await import("./index.js"); + const on = vi.fn(); + const logger = { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }; + const mockApi = { + id: "memory-lancedb", + name: "Memory (LanceDB)", + source: "test", + config: {}, + pluginConfig: { + embedding: { + apiKey: OPENAI_API_KEY, + model: "text-embedding-3-small", + }, + dbPath: getDbPath(), + autoCapture: false, + autoRecall: true, + }, + runtime: {}, + logger, + registerTool: vi.fn(), + registerCli: vi.fn(), + registerService: vi.fn(), + on, + resolvePath: (p: string) => p, + }; + + dynamicMemoryPlugin.register(mockApi as any); + + const beforePromptBuild = on.mock.calls.find( + ([hookName]) => hookName === "before_prompt_build", + )?.[1]; + expect(beforePromptBuild).toBeTypeOf("function"); + + const resultPromise = beforePromptBuild?.( + { prompt: "what editor should i use?", messages: [] }, + {}, + ); + await vi.advanceTimersByTimeAsync(15_000); + + await expect(resultPromise).resolves.toBeUndefined(); + expect(ensureGlobalUndiciEnvProxyDispatcher).toHaveBeenCalledOnce(); + expect(post).toHaveBeenCalledWith( + "/embeddings", + expect.objectContaining({ + maxRetries: 0, + timeout: 15_000, + }), + ); + expect(loadLanceDbModule).not.toHaveBeenCalled(); + expect(logger.warn).toHaveBeenCalledWith( + "memory-lancedb: auto-recall timed out after 15000ms; skipping memory injection to avoid stalling agent startup", + ); + } finally { + vi.doUnmock("openclaw/plugin-sdk/runtime-env"); + vi.doUnmock("openai"); + vi.doUnmock("./lancedb-runtime.js"); + vi.resetModules(); + vi.useRealTimers(); + } + }); + test("uses live runtime config to enable auto-recall after startup disable", async () => { const embeddingsCreate = vi.fn(async () => ({ data: [{ embedding: [0.1, 0.2, 0.3] }], diff --git a/extensions/memory-lancedb/index.ts b/extensions/memory-lancedb/index.ts index 1f9bcfc9e93..c5fa7051361 100644 --- a/extensions/memory-lancedb/index.ts +++ b/extensions/memory-lancedb/index.ts @@ -149,6 +149,7 @@ function resolveAutoCaptureStartIndex( // ============================================================================ const TABLE_NAME = "memories"; +const DEFAULT_AUTO_RECALL_TIMEOUT_MS = 15_000; class MemoryDB { private db: LanceDB.Connection | null = null; @@ -262,7 +263,7 @@ class MemoryDB { // ============================================================================ type Embeddings = { - embed(text: string): Promise; + embed(text: string, options?: { timeoutMs?: number }): Promise; }; class OpenAiCompatibleEmbeddings implements Embeddings { @@ -277,7 +278,7 @@ class OpenAiCompatibleEmbeddings implements Embeddings { this.client = new OpenAI({ apiKey, baseURL: baseUrl }); } - async embed(text: string): Promise { + async embed(text: string, options?: { timeoutMs?: number }): Promise { const params: OpenAI.EmbeddingCreateParams = { model: this.model, input: text, @@ -292,6 +293,7 @@ class OpenAiCompatibleEmbeddings implements Embeddings { // transport and normalize the response ourselves. const response = await this.client.post("/embeddings", { body: params, + ...(options?.timeoutMs ? { timeout: options.timeoutMs, maxRetries: 0 } : {}), }); return normalizeEmbeddingVector(response.data?.[0]?.embedding); } @@ -353,6 +355,32 @@ class ProviderAdapterEmbeddings implements Embeddings { } } +async function runWithTimeout(params: { + timeoutMs: number; + task: () => Promise; +}): Promise<{ status: "ok"; value: T } | { status: "timeout" }> { + let timeout: ReturnType | undefined; + const TIMEOUT = Symbol("timeout"); + const timeoutPromise = new Promise((resolve) => { + timeout = setTimeout(() => resolve(TIMEOUT), params.timeoutMs); + timeout.unref?.(); + }); + const taskPromise = params.task(); + taskPromise.catch(() => undefined); + + try { + const result = await Promise.race([taskPromise, timeoutPromise]); + if (result === TIMEOUT) { + return { status: "timeout" }; + } + return { status: "ok", value: result }; + } finally { + if (timeout) { + clearTimeout(timeout); + } + } +} + function createEmbeddings(api: OpenClawPluginApi, cfg: MemoryConfig): Embeddings { const { provider, model, dimensions, apiKey, baseUrl } = cfg.embedding; if (provider === "openai" && apiKey) { @@ -818,8 +846,22 @@ export default definePluginEntry({ event.prompt, currentCfg.recallMaxChars, ); - const vector = await embeddings.embed(recallQuery); - const results = await db.search(vector, 3, 0.3); + const recall = await runWithTimeout({ + timeoutMs: DEFAULT_AUTO_RECALL_TIMEOUT_MS, + task: async () => { + const vector = await embeddings.embed(recallQuery, { + timeoutMs: DEFAULT_AUTO_RECALL_TIMEOUT_MS, + }); + return await db.search(vector, 3, 0.3); + }, + }); + if (recall.status === "timeout") { + api.logger.warn?.( + `memory-lancedb: auto-recall timed out after ${DEFAULT_AUTO_RECALL_TIMEOUT_MS}ms; skipping memory injection to avoid stalling agent startup`, + ); + return undefined; + } + const results = recall.value; if (results.length === 0) { return undefined; diff --git a/src/plugins/hooks.model-override-wiring.test.ts b/src/plugins/hooks.model-override-wiring.test.ts index 193c15e4041..7ebcf8c8fda 100644 --- a/src/plugins/hooks.model-override-wiring.test.ts +++ b/src/plugins/hooks.model-override-wiring.test.ts @@ -226,6 +226,44 @@ describe("model override pipeline wiring", () => { expectedPrependContext, }); }); + + it("skips timed-out handlers and continues", async () => { + vi.useFakeTimers(); + try { + addBeforePromptBuildHook( + registry, + "slow-plugin", + () => new Promise(() => undefined), + 10, + ); + addBeforePromptBuildHook(registry, "fast-plugin", () => ({ prependContext: "fast" }), 1); + const logger = { + error: vi.fn(), + warn: vi.fn(), + info: vi.fn(), + debug: vi.fn(), + }; + const runner = createHookRunner(registry, { + logger, + modifyingHookTimeoutMsByHook: { before_prompt_build: 5 }, + }); + + const resultPromise = runner.runBeforePromptBuild( + { prompt: "test", messages: [] }, + stubCtx, + ); + await vi.advanceTimersByTimeAsync(5); + + await expect(resultPromise).resolves.toEqual({ prependContext: "fast" }); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining( + "[hooks] before_prompt_build handler from slow-plugin failed: timed out after 5ms", + ), + ); + } finally { + vi.useRealTimers(); + } + }); }); describe("graceful degradation + hook detection", () => { diff --git a/src/plugins/hooks.ts b/src/plugins/hooks.ts index 3f78d7193c7..1429070289b 100644 --- a/src/plugins/hooks.ts +++ b/src/plugins/hooks.ts @@ -160,11 +160,19 @@ export type HookRunnerOptions = { * the runner continues, but the plugin's underlying work is not cancelled. */ voidHookTimeoutMsByHook?: Partial>; + /** + * Optional timeout for modifying hooks. A timed-out hook is logged and skipped, + * but the plugin's underlying work is not cancelled. + */ + modifyingHookTimeoutMsByHook?: Partial>; }; const DEFAULT_VOID_HOOK_TIMEOUT_MS_BY_HOOK: Partial> = { agent_end: 30_000, }; +const DEFAULT_MODIFYING_HOOK_TIMEOUT_MS_BY_HOOK: Partial> = { + before_prompt_build: 15_000, +}; type ModifyingHookPolicy = { mergeResults?: ( @@ -236,6 +244,10 @@ export function createHookRunner( ...DEFAULT_VOID_HOOK_TIMEOUT_MS_BY_HOOK, ...options.voidHookTimeoutMsByHook, }; + const modifyingHookTimeoutMsByHook = { + ...DEFAULT_MODIFYING_HOOK_TIMEOUT_MS_BY_HOOK, + ...options.modifyingHookTimeoutMsByHook, + }; const shouldCatchHookErrors = (hookName: PluginHookName): boolean => catchErrors && (failurePolicyByHook[hookName] ?? "fail-open") === "fail-open"; @@ -385,13 +397,27 @@ export function createHookRunner( return Math.floor(timeoutMs); }; - const withVoidHookTimeout = async (promise: Promise, timeoutMs: number): Promise => { + const getModifyingHookTimeoutMs = (hookName: PluginHookName): number | undefined => { + const timeoutMs = modifyingHookTimeoutMsByHook[hookName]; + if (typeof timeoutMs !== "number" || !Number.isFinite(timeoutMs) || timeoutMs <= 0) { + return undefined; + } + return Math.floor(timeoutMs); + }; + + const withHookTimeout = async ( + promise: Promise, + timeoutMs: number, + options: { unref?: boolean } = {}, + ): Promise => { let timer: ReturnType | undefined; const timeout = new Promise((_, reject) => { timer = setTimeout(() => { reject(new Error(`timed out after ${timeoutMs}ms`)); }, timeoutMs); - timer.unref?.(); + if (options.unref) { + timer.unref?.(); + } }); try { @@ -435,7 +461,7 @@ export function createHookRunner( ); const timeoutMs = getVoidHookTimeoutMs(hookName); if (timeoutMs) { - await withVoidHookTimeout(promise, timeoutMs); + await withHookTimeout(promise, timeoutMs, { unref: true }); } else { await promise; } @@ -468,9 +494,10 @@ export function createHookRunner( for (const hook of hooks) { try { - const handlerResult = await ( - hook.handler as (event: unknown, ctx: unknown) => Promise - )(event, ctx); + const handler = hook.handler as (event: unknown, ctx: unknown) => Promise; + const promise = Promise.resolve(handler(event, ctx)); + const timeoutMs = getModifyingHookTimeoutMs(hookName); + const handlerResult = timeoutMs ? await withHookTimeout(promise, timeoutMs) : await promise; if (handlerResult !== undefined && handlerResult !== null) { if (policy.mergeResults) {