mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-01 20:31:19 +00:00
test: stabilize trigger handling and hook e2e tests
This commit is contained in:
@@ -4,14 +4,12 @@ import path from "node:path";
|
||||
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import { registerLogTransport, resetLogger, setLoggerOverride } from "../logging/logger.js";
|
||||
import { redactIdentifier } from "../logging/redact-identifier.js";
|
||||
import type { AuthProfileFailureReason } from "./auth-profiles.js";
|
||||
import type { EmbeddedRunAttemptResult } from "./pi-embedded-runner/run/types.js";
|
||||
|
||||
const runEmbeddedAttemptMock = vi.fn<(params: unknown) => Promise<EmbeddedRunAttemptResult>>();
|
||||
const resolveCopilotApiTokenMock = vi.fn();
|
||||
const COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token";
|
||||
const { computeBackoffMock, sleepWithAbortMock } = vi.hoisted(() => ({
|
||||
computeBackoffMock: vi.fn(
|
||||
(
|
||||
@@ -22,63 +20,121 @@ const { computeBackoffMock, sleepWithAbortMock } = vi.hoisted(() => ({
|
||||
sleepWithAbortMock: vi.fn(async (_ms: number, _abortSignal?: AbortSignal) => undefined),
|
||||
}));
|
||||
|
||||
vi.mock("./pi-embedded-runner/run/attempt.js", () => ({
|
||||
runEmbeddedAttempt: (params: unknown) => runEmbeddedAttemptMock(params),
|
||||
}));
|
||||
|
||||
vi.mock("../infra/backoff.js", () => ({
|
||||
computeBackoff: (
|
||||
policy: { initialMs: number; maxMs: number; factor: number; jitter: number },
|
||||
attempt: number,
|
||||
) => computeBackoffMock(policy, attempt),
|
||||
sleepWithAbort: (ms: number, abortSignal?: AbortSignal) => sleepWithAbortMock(ms, abortSignal),
|
||||
}));
|
||||
|
||||
vi.mock("../../extensions/github-copilot/token.js", () => ({
|
||||
DEFAULT_COPILOT_API_BASE_URL: "https://api.individual.githubcopilot.com",
|
||||
resolveCopilotApiToken: (...args: unknown[]) => resolveCopilotApiTokenMock(...args),
|
||||
}));
|
||||
|
||||
vi.mock("./pi-embedded-runner/compact.js", () => ({
|
||||
compactEmbeddedPiSessionDirect: vi.fn(async () => {
|
||||
throw new Error("compact should not run in auth profile rotation tests");
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("./models-config.js", async (importOriginal) => {
|
||||
const mod = await importOriginal<typeof import("./models-config.js")>();
|
||||
return {
|
||||
...mod,
|
||||
ensureOpenClawModelsJson: vi.fn(async () => ({ wrote: false })),
|
||||
};
|
||||
});
|
||||
const installRunEmbeddedMocks = () => {
|
||||
vi.doMock("../plugins/hook-runner-global.js", () => ({
|
||||
getGlobalHookRunner: vi.fn(() => undefined),
|
||||
}));
|
||||
vi.doMock("../context-engine/index.js", () => ({
|
||||
ensureContextEnginesInitialized: vi.fn(),
|
||||
resolveContextEngine: vi.fn(async () => ({
|
||||
dispose: async () => undefined,
|
||||
})),
|
||||
}));
|
||||
vi.doMock("./runtime-plugins.js", () => ({
|
||||
ensureRuntimePluginsLoaded: vi.fn(),
|
||||
}));
|
||||
vi.doMock("./pi-embedded-runner/model.js", () => ({
|
||||
resolveModelAsync: async (provider: string, modelId: string) => ({
|
||||
model: {
|
||||
id: modelId,
|
||||
name: modelId,
|
||||
api: "openai-responses",
|
||||
provider,
|
||||
baseUrl:
|
||||
provider === "github-copilot" ? "https://api.copilot.example" : "https://example.com",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 16_000,
|
||||
maxTokens: 2048,
|
||||
},
|
||||
error: undefined,
|
||||
authStorage: {
|
||||
setRuntimeApiKey: vi.fn(),
|
||||
},
|
||||
modelRegistry: {},
|
||||
}),
|
||||
}));
|
||||
vi.doMock("./pi-embedded-runner/run/attempt.js", () => ({
|
||||
runEmbeddedAttempt: (params: unknown) => runEmbeddedAttemptMock(params),
|
||||
}));
|
||||
vi.doMock("../plugins/provider-runtime.runtime.js", () => ({
|
||||
prepareProviderRuntimeAuth: async (params: {
|
||||
provider: string;
|
||||
context: { apiKey: string };
|
||||
}) => {
|
||||
if (params.provider !== "github-copilot") {
|
||||
return undefined;
|
||||
}
|
||||
const token = await resolveCopilotApiTokenMock(params.context.apiKey);
|
||||
return {
|
||||
apiKey: token.token,
|
||||
baseUrl: token.baseUrl,
|
||||
expiresAt: token.expiresAt,
|
||||
};
|
||||
},
|
||||
}));
|
||||
vi.doMock("../infra/backoff.js", () => ({
|
||||
computeBackoff: (
|
||||
policy: { initialMs: number; maxMs: number; factor: number; jitter: number },
|
||||
attempt: number,
|
||||
) => computeBackoffMock(policy, attempt),
|
||||
sleepWithAbort: (ms: number, abortSignal?: AbortSignal) => sleepWithAbortMock(ms, abortSignal),
|
||||
}));
|
||||
vi.doMock("./pi-embedded-runner/compact.js", () => ({
|
||||
compactEmbeddedPiSessionDirect: vi.fn(async () => {
|
||||
throw new Error("compact should not run in auth profile rotation tests");
|
||||
}),
|
||||
}));
|
||||
vi.doMock("./models-config.js", async (importOriginal) => {
|
||||
const mod = await importOriginal<typeof import("./models-config.js")>();
|
||||
return {
|
||||
...mod,
|
||||
ensureOpenClawModelsJson: vi.fn(async () => ({ wrote: false })),
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
let runEmbeddedPiAgent: typeof import("./pi-embedded-runner/run.js").runEmbeddedPiAgent;
|
||||
let unregisterLogTransport: (() => void) | undefined;
|
||||
let registerLogTransportFn: typeof import("../logging/logger.js").registerLogTransport;
|
||||
let resetLoggerFn: typeof import("../logging/logger.js").resetLogger;
|
||||
let setLoggerOverrideFn: typeof import("../logging/logger.js").setLoggerOverride;
|
||||
const originalFetch = globalThis.fetch;
|
||||
|
||||
beforeAll(async () => {
|
||||
vi.resetModules();
|
||||
installRunEmbeddedMocks();
|
||||
({ runEmbeddedPiAgent } = await import("./pi-embedded-runner/run.js"));
|
||||
({
|
||||
registerLogTransport: registerLogTransportFn,
|
||||
resetLogger: resetLoggerFn,
|
||||
setLoggerOverride: setLoggerOverrideFn,
|
||||
} = await import("../logging/logger.js"));
|
||||
});
|
||||
|
||||
async function runEmbeddedPiAgentInline(
|
||||
params: Parameters<typeof runEmbeddedPiAgent>[0],
|
||||
): Promise<Awaited<ReturnType<typeof runEmbeddedPiAgent>>> {
|
||||
return await runEmbeddedPiAgent({
|
||||
...params,
|
||||
enqueue: async (task) => await task(),
|
||||
});
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
runEmbeddedAttemptMock.mockClear();
|
||||
runEmbeddedAttemptMock.mockReset();
|
||||
runEmbeddedAttemptMock.mockImplementation(async () => {
|
||||
throw new Error("unexpected extra runEmbeddedAttempt call");
|
||||
});
|
||||
resolveCopilotApiTokenMock.mockReset();
|
||||
resolveCopilotApiTokenMock.mockImplementation(async () => {
|
||||
throw new Error("unexpected extra Copilot token refresh");
|
||||
});
|
||||
globalThis.fetch = vi.fn(async (input: string | URL | Request) => {
|
||||
const url = typeof input === "string" ? input : input instanceof URL ? input.href : input.url;
|
||||
if (url !== COPILOT_TOKEN_URL) {
|
||||
throw new Error(`Unexpected fetch in test: ${url}`);
|
||||
}
|
||||
const token = await resolveCopilotApiTokenMock();
|
||||
return {
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({
|
||||
token: token.token,
|
||||
expires_at: Math.floor(token.expiresAt / 1000),
|
||||
}),
|
||||
} as Response;
|
||||
throw new Error(`Unexpected fetch in test: ${url}`);
|
||||
}) as typeof fetch;
|
||||
computeBackoffMock.mockClear();
|
||||
sleepWithAbortMock.mockClear();
|
||||
@@ -88,8 +144,8 @@ afterEach(() => {
|
||||
globalThis.fetch = originalFetch;
|
||||
unregisterLogTransport?.();
|
||||
unregisterLogTransport = undefined;
|
||||
setLoggerOverride(null);
|
||||
resetLogger();
|
||||
setLoggerOverrideFn(null);
|
||||
resetLoggerFn();
|
||||
});
|
||||
|
||||
const baseUsage = {
|
||||
@@ -324,7 +380,7 @@ async function runAutoPinnedOpenAiTurn(params: {
|
||||
runId: string;
|
||||
authProfileId?: string;
|
||||
}) {
|
||||
await runEmbeddedPiAgent({
|
||||
await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: params.sessionKey,
|
||||
sessionFile: path.join(params.workspaceDir, "session.jsonl"),
|
||||
@@ -368,7 +424,7 @@ async function runAutoPinnedRotationCase(params: {
|
||||
sessionKey: string;
|
||||
runId: string;
|
||||
}) {
|
||||
runEmbeddedAttemptMock.mockClear();
|
||||
runEmbeddedAttemptMock.mockReset();
|
||||
return withAgentWorkspace(async ({ agentDir, workspaceDir }) => {
|
||||
await writeAuthStore(agentDir);
|
||||
mockFailedThenSuccessfulAttempt(params.errorMessage);
|
||||
@@ -390,7 +446,7 @@ async function runAutoPinnedPromptErrorRotationCase(params: {
|
||||
sessionKey: string;
|
||||
runId: string;
|
||||
}) {
|
||||
runEmbeddedAttemptMock.mockClear();
|
||||
runEmbeddedAttemptMock.mockReset();
|
||||
return withAgentWorkspace(async ({ agentDir, workspaceDir }) => {
|
||||
await writeAuthStore(agentDir);
|
||||
mockPromptErrorThenSuccessfulAttempt(params.errorMessage);
|
||||
@@ -486,7 +542,7 @@ async function runTurnWithCooldownSeed(params: {
|
||||
});
|
||||
mockSingleSuccessfulAttempt();
|
||||
|
||||
await runEmbeddedPiAgent({
|
||||
await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: params.sessionKey,
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -518,7 +574,9 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
resolveCopilotApiTokenMock
|
||||
.mockResolvedValueOnce({
|
||||
token: "copilot-initial",
|
||||
expiresAt: now + 2 * 60 * 1000,
|
||||
// Keep expiry beyond the runtime refresh margin so the test only
|
||||
// exercises auth-error refresh, not the background scheduler.
|
||||
expiresAt: now + 10 * 60 * 1000,
|
||||
source: "mock",
|
||||
baseUrl: "https://api.copilot.example",
|
||||
})
|
||||
@@ -549,7 +607,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await runEmbeddedPiAgent({
|
||||
await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:copilot-auth-error",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -582,13 +640,14 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
resolveCopilotApiTokenMock
|
||||
.mockResolvedValueOnce({
|
||||
token: "copilot-initial",
|
||||
expiresAt: now + 2 * 60 * 1000,
|
||||
// Avoid an immediate scheduled refresh racing the explicit auth retry.
|
||||
expiresAt: now + 10 * 60 * 1000,
|
||||
source: "mock",
|
||||
baseUrl: "https://api.copilot.example",
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
token: "copilot-refresh-1",
|
||||
expiresAt: now + 4 * 60 * 1000,
|
||||
expiresAt: now + 10 * 60 * 1000,
|
||||
source: "mock",
|
||||
baseUrl: "https://api.copilot.example",
|
||||
})
|
||||
@@ -633,7 +692,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await runEmbeddedPiAgent({
|
||||
await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:copilot-auth-repeat",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -647,7 +706,6 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
timeoutMs: 5_000,
|
||||
runId: "run:copilot-auth-repeat",
|
||||
});
|
||||
|
||||
expect(runEmbeddedAttemptMock).toHaveBeenCalledTimes(4);
|
||||
expect(resolveCopilotApiTokenMock).toHaveBeenCalledTimes(3);
|
||||
} finally {
|
||||
@@ -682,7 +740,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const runPromise = runEmbeddedPiAgent({
|
||||
const runPromise = runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:copilot-shutdown",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -744,12 +802,12 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
|
||||
it("logs structured failover decision metadata for overloaded assistant rotation", async () => {
|
||||
const records: Array<Record<string, unknown>> = [];
|
||||
setLoggerOverride({
|
||||
setLoggerOverrideFn({
|
||||
level: "trace",
|
||||
consoleLevel: "silent",
|
||||
file: path.join(os.tmpdir(), `openclaw-auth-rotation-${Date.now()}.log`),
|
||||
});
|
||||
unregisterLogTransport = registerLogTransport((record) => {
|
||||
unregisterLogTransport = registerLogTransportFn((record) => {
|
||||
records.push(record);
|
||||
});
|
||||
|
||||
@@ -858,7 +916,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
const result = await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:compaction-timeout",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -887,7 +945,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
|
||||
mockSingleErrorAttempt({ errorMessage: "rate limit" });
|
||||
|
||||
await runEmbeddedPiAgent({
|
||||
await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:user",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -935,7 +993,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await runEmbeddedPiAgent({
|
||||
await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:mismatch",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -977,7 +1035,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
});
|
||||
|
||||
await expect(
|
||||
runEmbeddedPiAgent({
|
||||
runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:cooldown-failover",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -1021,7 +1079,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
const result = await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:cooldown-probe",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -1069,7 +1127,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
const result = await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:overloaded-cooldown-probe",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -1117,7 +1175,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
const result = await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:billing-cooldown-probe-no-fallbacks",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -1148,7 +1206,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
});
|
||||
|
||||
await expect(
|
||||
runEmbeddedPiAgent({
|
||||
runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:support:cooldown-failover",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -1193,7 +1251,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
});
|
||||
|
||||
await expect(
|
||||
runEmbeddedPiAgent({
|
||||
runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:disabled-failover",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -1227,7 +1285,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
await fs.writeFile(authPath, JSON.stringify({ version: 1, profiles: {}, usageStats: {} }));
|
||||
|
||||
await expect(
|
||||
runEmbeddedPiAgent({
|
||||
runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:auth-unavailable",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
@@ -1265,7 +1323,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
|
||||
let thrown: unknown;
|
||||
try {
|
||||
await runEmbeddedPiAgent({
|
||||
await runEmbeddedPiAgentInline({
|
||||
sessionId: "session:test",
|
||||
sessionKey: "agent:test:billing-failover-active-model",
|
||||
sessionFile: path.join(workspaceDir, "session.jsonl"),
|
||||
|
||||
@@ -1,58 +1,71 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { resetDiagnosticSessionStateForTest } from "../logging/diagnostic-session-state.js";
|
||||
import { getGlobalHookRunner } from "../plugins/hook-runner-global.js";
|
||||
import { toClientToolDefinitions, toToolDefinitions } from "./pi-tool-definition-adapter.js";
|
||||
import { wrapToolWithAbortSignal } from "./pi-tools.abort.js";
|
||||
import {
|
||||
__testing as beforeToolCallTesting,
|
||||
consumeAdjustedParamsForToolCall,
|
||||
wrapToolWithBeforeToolCallHook,
|
||||
} from "./pi-tools.before-tool-call.js";
|
||||
initializeGlobalHookRunner,
|
||||
resetGlobalHookRunner,
|
||||
} from "../plugins/hook-runner-global.js";
|
||||
import { createMockPluginRegistry } from "../plugins/hooks.test-helpers.js";
|
||||
|
||||
vi.mock("../plugins/hook-runner-global.js", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("../plugins/hook-runner-global.js")>();
|
||||
return {
|
||||
...actual,
|
||||
getGlobalHookRunner: vi.fn(),
|
||||
};
|
||||
type ToolDefinitionAdapterModule = typeof import("./pi-tool-definition-adapter.js");
|
||||
type PiToolsAbortModule = typeof import("./pi-tools.abort.js");
|
||||
type BeforeToolCallModule = typeof import("./pi-tools.before-tool-call.js");
|
||||
|
||||
type ToClientToolDefinitions = ToolDefinitionAdapterModule["toClientToolDefinitions"];
|
||||
type ToToolDefinitions = ToolDefinitionAdapterModule["toToolDefinitions"];
|
||||
type WrapToolWithAbortSignal = PiToolsAbortModule["wrapToolWithAbortSignal"];
|
||||
type BeforeToolCallTesting = BeforeToolCallModule["__testing"];
|
||||
type ConsumeAdjustedParamsForToolCall = BeforeToolCallModule["consumeAdjustedParamsForToolCall"];
|
||||
type WrapToolWithBeforeToolCallHook = BeforeToolCallModule["wrapToolWithBeforeToolCallHook"];
|
||||
|
||||
let toClientToolDefinitions!: ToClientToolDefinitions;
|
||||
let toToolDefinitions!: ToToolDefinitions;
|
||||
let wrapToolWithAbortSignal!: WrapToolWithAbortSignal;
|
||||
let beforeToolCallTesting!: BeforeToolCallTesting;
|
||||
let consumeAdjustedParamsForToolCall!: ConsumeAdjustedParamsForToolCall;
|
||||
let wrapToolWithBeforeToolCallHook!: WrapToolWithBeforeToolCallHook;
|
||||
|
||||
beforeEach(async () => {
|
||||
if (!wrapToolWithBeforeToolCallHook) {
|
||||
({ toClientToolDefinitions, toToolDefinitions } =
|
||||
await import("./pi-tool-definition-adapter.js"));
|
||||
({ wrapToolWithAbortSignal } = await import("./pi-tools.abort.js"));
|
||||
({
|
||||
__testing: beforeToolCallTesting,
|
||||
consumeAdjustedParamsForToolCall,
|
||||
wrapToolWithBeforeToolCallHook,
|
||||
} = await import("./pi-tools.before-tool-call.js"));
|
||||
}
|
||||
});
|
||||
|
||||
const mockGetGlobalHookRunner = vi.mocked(getGlobalHookRunner);
|
||||
type BeforeToolCallHandlerMock = ReturnType<typeof vi.fn>;
|
||||
|
||||
type HookRunnerMock = {
|
||||
hasHooks: ReturnType<typeof vi.fn>;
|
||||
runBeforeToolCall: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
|
||||
function installMockHookRunner(params?: {
|
||||
hasHooksReturn?: boolean;
|
||||
function installBeforeToolCallHook(params?: {
|
||||
enabled?: boolean;
|
||||
runBeforeToolCallImpl?: (...args: unknown[]) => unknown;
|
||||
}) {
|
||||
const hookRunner: HookRunnerMock = {
|
||||
hasHooks:
|
||||
params?.hasHooksReturn === undefined
|
||||
? vi.fn()
|
||||
: vi.fn(() => params.hasHooksReturn as boolean),
|
||||
runBeforeToolCall: params?.runBeforeToolCallImpl
|
||||
? vi.fn(params.runBeforeToolCallImpl)
|
||||
: vi.fn(),
|
||||
};
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
mockGetGlobalHookRunner.mockReturnValue(hookRunner as any);
|
||||
return hookRunner;
|
||||
}): BeforeToolCallHandlerMock {
|
||||
resetGlobalHookRunner();
|
||||
const handler = params?.runBeforeToolCallImpl
|
||||
? vi.fn(params.runBeforeToolCallImpl)
|
||||
: vi.fn(async () => undefined);
|
||||
if (params?.enabled === false) {
|
||||
return handler;
|
||||
}
|
||||
initializeGlobalHookRunner(createMockPluginRegistry([{ hookName: "before_tool_call", handler }]));
|
||||
return handler;
|
||||
}
|
||||
|
||||
describe("before_tool_call hook integration", () => {
|
||||
let hookRunner: HookRunnerMock;
|
||||
let beforeToolCallHook: BeforeToolCallHandlerMock;
|
||||
|
||||
beforeEach(() => {
|
||||
resetGlobalHookRunner();
|
||||
resetDiagnosticSessionStateForTest();
|
||||
beforeToolCallTesting.adjustedParamsByToolCallId.clear();
|
||||
hookRunner = installMockHookRunner();
|
||||
beforeToolCallHook = installBeforeToolCallHook();
|
||||
});
|
||||
|
||||
it("executes tool normally when no hook is registered", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(false);
|
||||
beforeToolCallHook = installBeforeToolCallHook({ enabled: false });
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "Read", execute } as any, {
|
||||
@@ -63,7 +76,7 @@ describe("before_tool_call hook integration", () => {
|
||||
|
||||
await tool.execute("call-1", { path: "/tmp/file" }, undefined, extensionContext);
|
||||
|
||||
expect(hookRunner.runBeforeToolCall).not.toHaveBeenCalled();
|
||||
expect(beforeToolCallHook).not.toHaveBeenCalled();
|
||||
expect(execute).toHaveBeenCalledWith(
|
||||
"call-1",
|
||||
{ path: "/tmp/file" },
|
||||
@@ -73,8 +86,9 @@ describe("before_tool_call hook integration", () => {
|
||||
});
|
||||
|
||||
it("allows hook to modify parameters", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockResolvedValue({ params: { mode: "safe" } });
|
||||
beforeToolCallHook = installBeforeToolCallHook({
|
||||
runBeforeToolCallImpl: async () => ({ params: { mode: "safe" } }),
|
||||
});
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "exec", execute } as any);
|
||||
@@ -91,10 +105,11 @@ describe("before_tool_call hook integration", () => {
|
||||
});
|
||||
|
||||
it("blocks tool execution when hook returns block=true", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockResolvedValue({
|
||||
block: true,
|
||||
blockReason: "blocked",
|
||||
beforeToolCallHook = installBeforeToolCallHook({
|
||||
runBeforeToolCallImpl: async () => ({
|
||||
block: true,
|
||||
blockReason: "blocked",
|
||||
}),
|
||||
});
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
@@ -108,8 +123,11 @@ describe("before_tool_call hook integration", () => {
|
||||
});
|
||||
|
||||
it("continues execution when hook throws", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockRejectedValue(new Error("boom"));
|
||||
beforeToolCallHook = installBeforeToolCallHook({
|
||||
runBeforeToolCallImpl: async () => {
|
||||
throw new Error("boom");
|
||||
},
|
||||
});
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "read", execute } as any);
|
||||
@@ -126,8 +144,9 @@ describe("before_tool_call hook integration", () => {
|
||||
});
|
||||
|
||||
it("normalizes non-object params for hook contract", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockResolvedValue(undefined);
|
||||
beforeToolCallHook = installBeforeToolCallHook({
|
||||
runBeforeToolCallImpl: async () => undefined,
|
||||
});
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
const tool = wrapToolWithBeforeToolCallHook({ name: "ReAd", execute } as any, {
|
||||
@@ -140,7 +159,7 @@ describe("before_tool_call hook integration", () => {
|
||||
|
||||
await tool.execute("call-5", "not-an-object", undefined, extensionContext);
|
||||
|
||||
expect(hookRunner.runBeforeToolCall).toHaveBeenCalledWith(
|
||||
expect(beforeToolCallHook).toHaveBeenCalledWith(
|
||||
{
|
||||
toolName: "read",
|
||||
params: {},
|
||||
@@ -159,10 +178,12 @@ describe("before_tool_call hook integration", () => {
|
||||
});
|
||||
|
||||
it("keeps adjusted params isolated per run when toolCallId collides", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall
|
||||
.mockResolvedValueOnce({ params: { marker: "A" } })
|
||||
.mockResolvedValueOnce({ params: { marker: "B" } });
|
||||
beforeToolCallHook = installBeforeToolCallHook({
|
||||
runBeforeToolCallImpl: vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ params: { marker: "A" } })
|
||||
.mockResolvedValueOnce({ params: { marker: "B" } }),
|
||||
});
|
||||
const execute = vi.fn().mockResolvedValue({ content: [], details: { ok: true } });
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
const toolA = wrapToolWithBeforeToolCallHook({ name: "Read", execute } as any, {
|
||||
@@ -192,12 +213,12 @@ describe("before_tool_call hook integration", () => {
|
||||
});
|
||||
|
||||
describe("before_tool_call hook deduplication (#15502)", () => {
|
||||
let hookRunner: HookRunnerMock;
|
||||
let beforeToolCallHook: BeforeToolCallHandlerMock;
|
||||
|
||||
beforeEach(() => {
|
||||
resetGlobalHookRunner();
|
||||
resetDiagnosticSessionStateForTest();
|
||||
hookRunner = installMockHookRunner({
|
||||
hasHooksReturn: true,
|
||||
beforeToolCallHook = installBeforeToolCallHook({
|
||||
runBeforeToolCallImpl: async () => undefined,
|
||||
});
|
||||
});
|
||||
@@ -221,7 +242,7 @@ describe("before_tool_call hook deduplication (#15502)", () => {
|
||||
extensionContext,
|
||||
);
|
||||
|
||||
expect(hookRunner.runBeforeToolCall).toHaveBeenCalledTimes(1);
|
||||
expect(beforeToolCallHook).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("fires hook exactly once when tool goes through wrap + abort + toToolDefinitions", async () => {
|
||||
@@ -246,21 +267,21 @@ describe("before_tool_call hook deduplication (#15502)", () => {
|
||||
extensionContext,
|
||||
);
|
||||
|
||||
expect(hookRunner.runBeforeToolCall).toHaveBeenCalledTimes(1);
|
||||
expect(beforeToolCallHook).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("before_tool_call hook integration for client tools", () => {
|
||||
let hookRunner: HookRunnerMock;
|
||||
|
||||
beforeEach(() => {
|
||||
resetGlobalHookRunner();
|
||||
resetDiagnosticSessionStateForTest();
|
||||
hookRunner = installMockHookRunner();
|
||||
installBeforeToolCallHook();
|
||||
});
|
||||
|
||||
it("passes modified params to client tool callbacks", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
hookRunner.runBeforeToolCall.mockResolvedValue({ params: { extra: true } });
|
||||
installBeforeToolCallHook({
|
||||
runBeforeToolCallImpl: async () => ({ params: { extra: true } }),
|
||||
});
|
||||
const onClientToolCall = vi.fn();
|
||||
const [tool] = toClientToolDefinitions(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user