fix(hooks): harden before_tool_call hook runner to fail-closed on error [AI] (#59822)

* fix: address issue

* fix: address PR review feedback

* docs: add changelog entry for PR merge

* docs: normalize changelog entry placement

---------

Co-authored-by: Devin Robison <drobison@nvidia.com>
This commit is contained in:
pgondhi987
2026-04-04 04:14:35 +05:30
committed by GitHub
parent 1322aa2ba2
commit e19dce0aed
7 changed files with 82 additions and 15 deletions

View File

@@ -40,6 +40,9 @@ export function initializeGlobalHookRunner(registry: PluginRegistry): void {
error: (msg) => log.error(msg),
},
catchErrors: true,
failurePolicyByHook: {
before_tool_call: "fail-closed",
},
});
const hookCount = registry.hooks.length;

View File

@@ -163,6 +163,32 @@ describe("before_tool_call terminal block semantics", () => {
expect(result?.block).toBe(true);
expect(low).not.toHaveBeenCalled();
});
it("throws for before_tool_call when configured as fail-closed", async () => {
addStaticTestHooks(registry, {
hookName: "before_tool_call",
hooks: [
{
pluginId: "failing",
result: {},
priority: 100,
handler: () => {
throw new Error("boom");
},
},
],
});
const runner = createHookRunner(registry, {
catchErrors: true,
failurePolicyByHook: {
before_tool_call: "fail-closed",
},
});
await expect(runner.runBeforeToolCall(toolEvent, toolCtx)).rejects.toThrow(
"before_tool_call handler from failing failed: Error: boom",
);
});
});
describe("message_sending terminal cancel semantics", () => {

View File

@@ -124,10 +124,17 @@ export type HookRunnerLogger = {
error: (message: string) => void;
};
export type HookFailurePolicy = "fail-open" | "fail-closed";
export type HookRunnerOptions = {
logger?: HookRunnerLogger;
/** If true, errors in hooks will be caught and logged instead of thrown */
catchErrors?: boolean;
/**
* Optional per-hook failure policy.
* Defaults to fail-open unless explicitly overridden for a hook name.
*/
failurePolicyByHook?: Partial<Record<PluginHookName, HookFailurePolicy>>;
};
type ModifyingHookPolicy<K extends PluginHookName, TResult> = {
@@ -186,6 +193,10 @@ function getHooksForNameAndPlugin<K extends PluginHookName>(
export function createHookRunner(registry: PluginRegistry, options: HookRunnerOptions = {}) {
const logger = options.logger;
const catchErrors = options.catchErrors ?? true;
const failurePolicyByHook = options.failurePolicyByHook ?? {};
const shouldCatchHookErrors = (hookName: PluginHookName): boolean =>
catchErrors && (failurePolicyByHook[hookName] ?? "fail-open") === "fail-open";
const firstDefined = <T>(prev: T | undefined, next: T | undefined): T | undefined => prev ?? next;
const lastDefined = <T>(prev: T | undefined, next: T | undefined): T | undefined => next ?? prev;
@@ -255,7 +266,7 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp
const msg = `[hooks] ${params.hookName} handler from ${params.pluginId} failed: ${String(
params.error,
)}`;
if (catchErrors) {
if (shouldCatchHookErrors(params.hookName)) {
logger?.error(msg);
return;
}
@@ -797,7 +808,7 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp
const msg =
`[hooks] tool_result_persist handler from ${hook.pluginId} returned a Promise; ` +
`this hook is synchronous and the result was ignored.`;
if (catchErrors) {
if (shouldCatchHookErrors("tool_result_persist")) {
logger?.warn?.(msg);
continue;
}
@@ -810,7 +821,7 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp
}
} catch (err) {
const msg = `[hooks] tool_result_persist handler from ${hook.pluginId} failed: ${String(err)}`;
if (catchErrors) {
if (shouldCatchHookErrors("tool_result_persist")) {
logger?.error(msg);
} else {
throw new Error(msg, { cause: err });
@@ -862,7 +873,7 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp
const msg =
`[hooks] before_message_write handler from ${hook.pluginId} returned a Promise; ` +
`this hook is synchronous and the result was ignored.`;
if (catchErrors) {
if (shouldCatchHookErrors("before_message_write")) {
logger?.warn?.(msg);
continue;
}
@@ -882,7 +893,7 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp
}
} catch (err) {
const msg = `[hooks] before_message_write handler from ${hook.pluginId} failed: ${String(err)}`;
if (catchErrors) {
if (shouldCatchHookErrors("before_message_write")) {
logger?.error(msg);
} else {
throw new Error(msg, { cause: err });