Agents: adapt pi-ai oauth and payload hooks

This commit is contained in:
Vincent Koc
2026-03-12 10:19:14 -04:00
parent f3be1c828c
commit 2f037f0930
17 changed files with 49 additions and 53 deletions

View File

@@ -29,7 +29,7 @@ describe("createAnthropicPayloadLogger", () => {
],
};
const streamFn: StreamFn = ((model, __, options) => {
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
return {} as never;
}) as StreamFn;

View File

@@ -145,7 +145,7 @@ export function createAnthropicPayloadLogger(params: {
payload: redactedPayload,
payloadDigest: digest(redactedPayload),
});
return options?.onPayload?.(payload);
return options?.onPayload?.(payload, model);
};
return streamFn(model, context, {
...options,

View File

@@ -17,7 +17,7 @@ const { getOAuthApiKeyMock } = vi.hoisted(() => ({
}),
}));
vi.mock("@mariozechner/pi-ai", () => ({
vi.mock("@mariozechner/pi-ai/oauth", () => ({
getOAuthApiKey: getOAuthApiKeyMock,
getOAuthProviders: () => [
{ id: "openai-codex", envApiKey: "OPENAI_API_KEY", oauthTokenEnv: "OPENAI_OAUTH_TOKEN" }, // pragma: allowlist secret

View File

@@ -1,9 +1,5 @@
import {
getOAuthApiKey,
getOAuthProviders,
type OAuthCredentials,
type OAuthProvider,
} from "@mariozechner/pi-ai";
import type { OAuthCredentials, OAuthProvider } from "@mariozechner/pi-ai";
import { getOAuthApiKey, getOAuthProviders } from "@mariozechner/pi-ai/oauth";
import { loadConfig, type OpenClawConfig } from "../../config/config.js";
import { coerceSecretRef } from "../../config/types.secrets.js";
import { withFileLock } from "../../infra/file-lock.js";

View File

@@ -797,7 +797,7 @@ export function createOpenAIWebSocketStreamFn(
...(prevResponseId ? { previous_response_id: prevResponseId } : {}),
...extraParams,
};
const nextPayload = options?.onPayload?.(payload);
const nextPayload = options?.onPayload?.(payload, model);
const requestPayload = (nextPayload ?? payload) as Parameters<
OpenAIWebSocketManager["send"]
>[0];

View File

@@ -208,7 +208,7 @@ describe("applyExtraParamsToAgent", () => {
}) {
const payload = params.payload ?? { store: false };
const baseStreamFn: StreamFn = (model, _context, options) => {
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
return {} as ReturnType<StreamFn>;
};
const agent = { streamFn: baseStreamFn };
@@ -233,7 +233,7 @@ describe("applyExtraParamsToAgent", () => {
}) {
const payload = params.payload ?? {};
const baseStreamFn: StreamFn = (model, _context, options) => {
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
return {} as ReturnType<StreamFn>;
};
const agent = { streamFn: baseStreamFn };
@@ -276,7 +276,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = { model: "deepseek/deepseek-r1" };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -308,7 +308,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = {};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -332,7 +332,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = { reasoning_effort: "high" };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -357,7 +357,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = { reasoning: { max_tokens: 256 } };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -381,7 +381,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = { reasoning_effort: "medium" };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -588,7 +588,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = { thinking: "off" };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -619,7 +619,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = { thinking: "off" };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -650,7 +650,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = {};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -674,7 +674,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = { tool_choice: "required" };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -699,7 +699,7 @@ describe("applyExtraParamsToAgent", () => {
const payloads: Record<string, unknown>[] = [];
const baseStreamFn: StreamFn = (_model, _context, options) => {
const payload: Record<string, unknown> = {};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -749,7 +749,7 @@ describe("applyExtraParamsToAgent", () => {
],
tool_choice: { type: "tool", name: "read" },
};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -793,7 +793,7 @@ describe("applyExtraParamsToAgent", () => {
},
],
};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -832,7 +832,7 @@ describe("applyExtraParamsToAgent", () => {
},
],
};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -896,7 +896,7 @@ describe("applyExtraParamsToAgent", () => {
},
},
};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};
@@ -943,7 +943,7 @@ describe("applyExtraParamsToAgent", () => {
},
},
};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
payloads.push(payload);
return {} as ReturnType<StreamFn>;
};

View File

@@ -298,7 +298,7 @@ export function createAnthropicToolPayloadCompatibilityWrapper(
);
}
}
return originalOnPayload?.(payload);
return originalOnPayload?.(payload, model);
},
});
};

View File

@@ -17,9 +17,9 @@ function applyAndCapture(params: {
}): CapturedCall {
const captured: CapturedCall = {};
const baseStreamFn: StreamFn = (_model, _context, options) => {
const baseStreamFn: StreamFn = (model, _context, options) => {
captured.headers = options?.headers;
options?.onPayload?.({});
options?.onPayload?.({}, model);
return createAssistantMessageEventStream();
};
const agent = { streamFn: baseStreamFn };
@@ -95,9 +95,9 @@ describe("extra-params: Kilocode kilo/auto reasoning", () => {
it("does not inject reasoning.effort for kilo/auto", () => {
let capturedPayload: Record<string, unknown> | undefined;
const baseStreamFn: StreamFn = (_model, _context, options) => {
const baseStreamFn: StreamFn = (model, _context, options) => {
const payload: Record<string, unknown> = { reasoning_effort: "high" };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
capturedPayload = payload;
return createAssistantMessageEventStream();
};
@@ -123,9 +123,9 @@ describe("extra-params: Kilocode kilo/auto reasoning", () => {
it("injects reasoning.effort for non-auto kilocode models", () => {
let capturedPayload: Record<string, unknown> | undefined;
const baseStreamFn: StreamFn = (_model, _context, options) => {
const baseStreamFn: StreamFn = (model, _context, options) => {
const payload: Record<string, unknown> = {};
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
capturedPayload = payload;
return createAssistantMessageEventStream();
};
@@ -156,9 +156,9 @@ describe("extra-params: Kilocode kilo/auto reasoning", () => {
it("does not inject reasoning.effort for x-ai models", () => {
let capturedPayload: Record<string, unknown> | undefined;
const baseStreamFn: StreamFn = (_model, _context, options) => {
const baseStreamFn: StreamFn = (model, _context, options) => {
const payload: Record<string, unknown> = { reasoning_effort: "high" };
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
capturedPayload = payload;
return createAssistantMessageEventStream();
};

View File

@@ -12,8 +12,8 @@ type StreamPayload = {
};
function runOpenRouterPayload(payload: StreamPayload, modelId: string) {
const baseStreamFn: StreamFn = (_model, _context, options) => {
options?.onPayload?.(payload);
const baseStreamFn: StreamFn = (model, _context, options) => {
options?.onPayload?.(payload, model);
return createAssistantMessageEventStream();
};
const agent = { streamFn: baseStreamFn };

View File

@@ -230,7 +230,7 @@ function createGoogleThinkingPayloadWrapper(
thinkingLevel,
});
}
return onPayload?.(payload);
return onPayload?.(payload, model);
},
});
};
@@ -263,7 +263,7 @@ function createZaiToolStreamWrapper(
// Inject tool_stream: true for Z.AI API
(payload as Record<string, unknown>).tool_stream = true;
}
return originalOnPayload?.(payload);
return originalOnPayload?.(payload, model);
},
});
};
@@ -310,7 +310,7 @@ function createParallelToolCallsWrapper(
if (payload && typeof payload === "object") {
(payload as Record<string, unknown>).parallel_tool_calls = enabled;
}
return originalOnPayload?.(payload);
return originalOnPayload?.(payload, model);
},
});
};

View File

@@ -22,7 +22,7 @@ type ToolStreamCase = {
function runToolStreamCase(params: ToolStreamCase) {
const payload: Record<string, unknown> = { model: params.model.id, messages: [] };
const baseStreamFn: StreamFn = (model, _context, options) => {
options?.onPayload?.(payload);
options?.onPayload?.(payload, model);
return {} as ReturnType<StreamFn>;
};
const agent = { streamFn: baseStreamFn };

View File

@@ -60,7 +60,7 @@ export function createSiliconFlowThinkingWrapper(baseStreamFn: StreamFn | undefi
payloadObj.thinking = null;
}
}
return originalOnPayload?.(payload);
return originalOnPayload?.(payload, model);
},
});
};
@@ -106,7 +106,7 @@ export function createMoonshotThinkingWrapper(
payloadObj.tool_choice = "auto";
}
}
return originalOnPayload?.(payload);
return originalOnPayload?.(payload, model);
},
});
};

View File

@@ -197,7 +197,7 @@ export function createOpenAIResponsesContextManagementWrapper(
compactThreshold,
});
}
return originalOnPayload?.(payload);
return originalOnPayload?.(payload, model);
},
});
};
@@ -226,7 +226,7 @@ export function createOpenAIServiceTierWrapper(
payloadObj.service_tier = serviceTier;
}
}
return originalOnPayload?.(payload);
return originalOnPayload?.(payload, model);
},
});
};

View File

@@ -92,7 +92,7 @@ export function createOpenRouterSystemCacheWrapper(baseStreamFn: StreamFn | unde
}
}
}
return originalOnPayload?.(payload);
return originalOnPayload?.(payload, model);
},
});
};
@@ -113,7 +113,7 @@ export function createOpenRouterWrapper(
},
onPayload: (payload) => {
normalizeProxyReasoningPayload(payload, thinkingLevel);
return onPayload?.(payload);
return onPayload?.(payload, model);
},
});
};
@@ -138,7 +138,7 @@ export function createKilocodeWrapper(
},
onPayload: (payload) => {
normalizeProxyReasoningPayload(payload, thinkingLevel);
return onPayload?.(payload);
return onPayload?.(payload, model);
},
});
};

View File

@@ -233,14 +233,14 @@ export function wrapOllamaCompatNumCtx(baseFn: StreamFn | undefined, numCtx: num
...options,
onPayload: (payload: unknown) => {
if (!payload || typeof payload !== "object") {
return options?.onPayload?.(payload);
return options?.onPayload?.(payload, model);
}
const payloadRecord = payload as Record<string, unknown>;
if (!payloadRecord.options || typeof payloadRecord.options !== "object") {
payloadRecord.options = {};
}
(payloadRecord.options as Record<string, unknown>).num_ctx = numCtx;
return options?.onPayload?.(payload);
return options?.onPayload?.(payload, model);
},
});
}

View File

@@ -9,7 +9,7 @@ const mocks = vi.hoisted(() => ({
formatOpenAIOAuthTlsPreflightFix: vi.fn(),
}));
vi.mock("@mariozechner/pi-ai", () => ({
vi.mock("@mariozechner/pi-ai/oauth", () => ({
loginOpenAICodex: mocks.loginOpenAICodex,
}));

View File

@@ -1,5 +1,5 @@
import type { OAuthCredentials } from "@mariozechner/pi-ai";
import { loginOpenAICodex } from "@mariozechner/pi-ai";
import { loginOpenAICodex } from "@mariozechner/pi-ai/oauth";
import type { RuntimeEnv } from "../runtime.js";
import type { WizardPrompter } from "../wizard/prompts.js";
import { createVpsAwareOAuthHandlers } from "./oauth-flow.js";