refactor: add provider replay runtime hook surfaces (#59143)

Merged via squash.

Prepared head SHA: 56b41e87a5
Co-authored-by: jalehman <550978+jalehman@users.noreply.github.com>
Co-authored-by: jalehman <550978+jalehman@users.noreply.github.com>
Reviewed-by: @jalehman
This commit is contained in:
Josh Lehman
2026-04-01 13:45:41 -07:00
committed by GitHub
parent ca76e2fedc
commit 71346940ad
15 changed files with 771 additions and 102 deletions

View File

@@ -1,3 +1,4 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import { beforeEach, describe, expect, it, vi } from "vitest";
import {
expectAugmentedCodexCatalog,
@@ -5,7 +6,14 @@ import {
expectCodexMissingAuthHint,
expectedAugmentedOpenaiCodexCatalogEntries,
} from "./provider-runtime.test-support.js";
import type { ProviderPlugin, ProviderRuntimeModel } from "./types.js";
import type {
AnyAgentTool,
ProviderNormalizeToolSchemasContext,
ProviderPlugin,
ProviderRuntimeModel,
ProviderSanitizeReplayHistoryContext,
ProviderValidateReplayTurnsContext,
} from "./types.js";
type ResolvePluginProviders = typeof import("./providers.runtime.js").resolvePluginProviders;
type ResolveCatalogHookProviderPluginIds =
@@ -41,11 +49,15 @@ let resolveProviderBuiltInModelSuppression: typeof import("./provider-runtime.js
let createProviderEmbeddingProvider: typeof import("./provider-runtime.js").createProviderEmbeddingProvider;
let resolveProviderDefaultThinkingLevel: typeof import("./provider-runtime.js").resolveProviderDefaultThinkingLevel;
let resolveProviderModernModelRef: typeof import("./provider-runtime.js").resolveProviderModernModelRef;
let resolveProviderReasoningOutputModeWithPlugin: typeof import("./provider-runtime.js").resolveProviderReasoningOutputModeWithPlugin;
let resolveProviderReplayPolicyWithPlugin: typeof import("./provider-runtime.js").resolveProviderReplayPolicyWithPlugin;
let resolveProviderSyntheticAuthWithPlugin: typeof import("./provider-runtime.js").resolveProviderSyntheticAuthWithPlugin;
let sanitizeProviderReplayHistoryWithPlugin: typeof import("./provider-runtime.js").sanitizeProviderReplayHistoryWithPlugin;
let resolveProviderUsageSnapshotWithPlugin: typeof import("./provider-runtime.js").resolveProviderUsageSnapshotWithPlugin;
let resolveProviderCapabilitiesWithPlugin: typeof import("./provider-runtime.js").resolveProviderCapabilitiesWithPlugin;
let resolveProviderUsageAuthWithPlugin: typeof import("./provider-runtime.js").resolveProviderUsageAuthWithPlugin;
let resolveProviderXHighThinking: typeof import("./provider-runtime.js").resolveProviderXHighThinking;
let normalizeProviderToolSchemasWithPlugin: typeof import("./provider-runtime.js").normalizeProviderToolSchemasWithPlugin;
let normalizeProviderResolvedModelWithPlugin: typeof import("./provider-runtime.js").normalizeProviderResolvedModelWithPlugin;
let prepareProviderDynamicModel: typeof import("./provider-runtime.js").prepareProviderDynamicModel;
let prepareProviderRuntimeAuth: typeof import("./provider-runtime.js").prepareProviderRuntimeAuth;
@@ -53,6 +65,7 @@ let resetProviderRuntimeHookCacheForTest: typeof import("./provider-runtime.js")
let refreshProviderOAuthCredentialWithPlugin: typeof import("./provider-runtime.js").refreshProviderOAuthCredentialWithPlugin;
let resolveProviderRuntimePlugin: typeof import("./provider-runtime.js").resolveProviderRuntimePlugin;
let runProviderDynamicModel: typeof import("./provider-runtime.js").runProviderDynamicModel;
let validateProviderReplayTurnsWithPlugin: typeof import("./provider-runtime.js").validateProviderReplayTurnsWithPlugin;
let wrapProviderStreamFn: typeof import("./provider-runtime.js").wrapProviderStreamFn;
const MODEL: ProviderRuntimeModel = {
@@ -69,6 +82,31 @@ const MODEL: ProviderRuntimeModel = {
};
const DEMO_PROVIDER_ID = "demo";
const EMPTY_MODEL_REGISTRY = { find: () => null } as never;
const DEMO_REPLAY_MESSAGES: AgentMessage[] = [{ role: "user", content: "hello", timestamp: 1 }];
const DEMO_SANITIZED_MESSAGE: AgentMessage = {
role: "assistant",
content: [{ type: "text", text: "sanitized" }],
api: MODEL.api,
provider: MODEL.provider,
model: MODEL.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: 2,
};
const DEMO_TOOL = {
name: "demo-tool",
label: "Demo tool",
description: "Demo tool",
parameters: { type: "object", properties: {} },
execute: vi.fn(async () => ({ content: [], details: undefined })),
} as unknown as AnyAgentTool;
function createOpenAiCatalogProviderPlugin(
overrides: Partial<ProviderPlugin> = {},
@@ -228,11 +266,15 @@ describe("provider-runtime", () => {
createProviderEmbeddingProvider,
resolveProviderDefaultThinkingLevel,
resolveProviderModernModelRef,
resolveProviderReasoningOutputModeWithPlugin,
resolveProviderReplayPolicyWithPlugin,
resolveProviderSyntheticAuthWithPlugin,
sanitizeProviderReplayHistoryWithPlugin,
resolveProviderUsageSnapshotWithPlugin,
resolveProviderCapabilitiesWithPlugin,
resolveProviderUsageAuthWithPlugin,
resolveProviderXHighThinking,
normalizeProviderToolSchemasWithPlugin,
normalizeProviderResolvedModelWithPlugin,
prepareProviderDynamicModel,
prepareProviderRuntimeAuth,
@@ -240,6 +282,7 @@ describe("provider-runtime", () => {
refreshProviderOAuthCredentialWithPlugin,
resolveProviderRuntimePlugin,
runProviderDynamicModel,
validateProviderReplayTurnsWithPlugin,
wrapProviderStreamFn,
} = await import("./provider-runtime.js"));
resetProviderRuntimeHookCacheForTest();
@@ -428,6 +471,28 @@ describe("provider-runtime", () => {
embedBatch: async () => [[1, 0, 0]],
client: { token: "embed-token" },
}));
const buildReplayPolicy = vi.fn(() => ({
sanitizeMode: "full" as const,
toolCallIdMode: "strict9" as const,
allowSyntheticToolResults: true,
}));
const sanitizeReplayHistory = vi.fn(
async ({
messages,
}: Pick<ProviderSanitizeReplayHistoryContext, "messages">): Promise<AgentMessage[]> => [
...messages,
DEMO_SANITIZED_MESSAGE,
],
);
const validateReplayTurns = vi.fn(
async ({
messages,
}: Pick<ProviderValidateReplayTurnsContext, "messages">): Promise<AgentMessage[]> => messages,
);
const normalizeToolSchemas = vi.fn(
({ tools }: Pick<ProviderNormalizeToolSchemasContext, "tools">): AnyAgentTool[] => tools,
);
const resolveReasoningOutputMode = vi.fn(() => "tagged" as const);
const resolveSyntheticAuth = vi.fn(() => ({
apiKey: "demo-local",
source: "models.providers.demo (synthetic local key)",
@@ -478,6 +543,11 @@ describe("provider-runtime", () => {
capabilities: {
providerFamily: "openai",
},
buildReplayPolicy,
sanitizeReplayHistory,
validateReplayTurns,
normalizeToolSchemas,
resolveReasoningOutputMode,
prepareExtraParams: ({ extraParams }) => ({
...extraParams,
transport: "auto",
@@ -608,6 +678,28 @@ describe("provider-runtime", () => {
providerFamily: "openai",
});
expect(
resolveProviderReplayPolicyWithPlugin({
provider: DEMO_PROVIDER_ID,
context: createDemoResolvedModelContext({
modelApi: MODEL.api,
}),
}),
).toMatchObject({
sanitizeMode: "full",
toolCallIdMode: "strict9",
allowSyntheticToolResults: true,
});
expect(
resolveProviderReasoningOutputModeWithPlugin({
provider: DEMO_PROVIDER_ID,
context: createDemoResolvedModelContext({
modelApi: MODEL.api,
}),
}),
).toBe("tagged");
expect(
prepareProviderExtraParams({
provider: DEMO_PROVIDER_ID,
@@ -710,6 +802,34 @@ describe("provider-runtime", () => {
windows: [{ label: "Day", usedPercent: 25 }],
},
},
{
actual: () =>
sanitizeProviderReplayHistoryWithPlugin({
provider: DEMO_PROVIDER_ID,
context: createDemoResolvedModelContext({
modelApi: MODEL.api,
sessionId: "session-1",
messages: DEMO_REPLAY_MESSAGES,
}),
}),
expected: {
1: DEMO_SANITIZED_MESSAGE,
},
},
{
actual: () =>
validateProviderReplayTurnsWithPlugin({
provider: DEMO_PROVIDER_ID,
context: createDemoResolvedModelContext({
modelApi: MODEL.api,
sessionId: "session-1",
messages: DEMO_REPLAY_MESSAGES,
}),
}),
expected: {
0: DEMO_REPLAY_MESSAGES[0],
},
},
]);
expect(
@@ -721,6 +841,16 @@ describe("provider-runtime", () => {
}),
).toBeTypeOf("function");
expect(
normalizeProviderToolSchemasWithPlugin({
provider: DEMO_PROVIDER_ID,
context: createDemoResolvedModelContext({
modelApi: MODEL.api,
tools: [DEMO_TOOL],
}),
}),
).toEqual([DEMO_TOOL]);
expect(
normalizeProviderResolvedModelWithPlugin({
provider: DEMO_PROVIDER_ID,
@@ -855,7 +985,12 @@ describe("provider-runtime", () => {
await expectAugmentedCodexCatalog(augmentModelCatalogWithProviderPlugins);
expectCalledOnce(
buildReplayPolicy,
prepareDynamicModel,
sanitizeReplayHistory,
validateReplayTurns,
normalizeToolSchemas,
resolveReasoningOutputMode,
refreshOAuth,
resolveSyntheticAuth,
buildUnknownModelHint,