From 5f60479f18c505018bb3350efddf82b66ce0dbc4 Mon Sep 17 00:00:00 2001 From: Shakker Date: Wed, 6 May 2026 23:48:18 +0100 Subject: [PATCH] fix: scope async model runtime hooks --- .../model.skip-pi-discovery-hooks.test.ts | 19 +++++++ src/agents/pi-embedded-runner/model.ts | 53 +++++++++++++++++-- src/agents/pi-embedded-runner/run.ts | 5 +- 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/src/agents/pi-embedded-runner/model.skip-pi-discovery-hooks.test.ts b/src/agents/pi-embedded-runner/model.skip-pi-discovery-hooks.test.ts index 90c037c02e0..1b43374f901 100644 --- a/src/agents/pi-embedded-runner/model.skip-pi-discovery-hooks.test.ts +++ b/src/agents/pi-embedded-runner/model.skip-pi-discovery-hooks.test.ts @@ -59,6 +59,7 @@ describe("resolveModelAsync skipPiDiscovery runtime hooks", () => { it("uses only target-provider dynamic hooks", async () => { const result = await resolveModelAsync("ollama", "llama3.2:latest", "/tmp/agent", undefined, { skipPiDiscovery: true, + workspaceDir: "/tmp/workspace", }); expect(result.error).toBeUndefined(); @@ -70,8 +71,26 @@ describe("resolveModelAsync skipPiDiscovery runtime hooks", () => { expect(mocks.discoverAuthStorage).not.toHaveBeenCalled(); expect(mocks.discoverModels).not.toHaveBeenCalled(); expect(mocks.prepareProviderDynamicModel).toHaveBeenCalledTimes(1); + expect(mocks.prepareProviderDynamicModel).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceDir: "/tmp/workspace", + context: expect.objectContaining({ workspaceDir: "/tmp/workspace" }), + }), + ); expect(mocks.runProviderDynamicModel).toHaveBeenCalledTimes(1); + expect(mocks.runProviderDynamicModel).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceDir: "/tmp/workspace", + context: expect.objectContaining({ workspaceDir: "/tmp/workspace" }), + }), + ); expect(mocks.normalizeProviderResolvedModelWithPlugin).toHaveBeenCalledTimes(1); + expect(mocks.normalizeProviderResolvedModelWithPlugin).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceDir: "/tmp/workspace", + context: expect.objectContaining({ workspaceDir: "/tmp/workspace" }), + }), + ); expect(mocks.applyProviderResolvedModelCompatWithPlugins).not.toHaveBeenCalled(); expect(mocks.applyProviderResolvedTransportWithPlugin).not.toHaveBeenCalled(); expect(mocks.normalizeProviderTransportWithPlugin).not.toHaveBeenCalled(); diff --git a/src/agents/pi-embedded-runner/model.ts b/src/agents/pi-embedded-runner/model.ts index dde23af3f62..08a31b40f45 100644 --- a/src/agents/pi-embedded-runner/model.ts +++ b/src/agents/pi-embedded-runner/model.ts @@ -155,13 +155,17 @@ function canonicalizeLegacyResolvedModel(params: { function applyResolvedTransportFallback(params: { provider: string; cfg?: OpenClawConfig; + workspaceDir?: string; runtimeHooks: ProviderRuntimeHooks; model: Model; }): Model | undefined { const normalized = params.runtimeHooks.normalizeProviderTransportWithPlugin({ provider: params.provider, config: params.cfg, + workspaceDir: params.workspaceDir, context: { + config: params.cfg, + workspaceDir: params.workspaceDir, provider: params.provider, api: params.model.api, baseUrl: params.model.baseUrl, @@ -187,6 +191,7 @@ function normalizeResolvedModel(params: { model: Model; cfg?: OpenClawConfig; agentDir?: string; + workspaceDir?: string; runtimeHooks?: ProviderRuntimeHooks; }): Model { const normalizeModelCost = (cost: unknown): Model["cost"] => { @@ -237,9 +242,11 @@ function normalizeResolvedModel(params: { const pluginNormalized = runtimeHooks.normalizeProviderResolvedModelWithPlugin({ provider: params.provider, config: params.cfg, + workspaceDir: params.workspaceDir, context: { config: params.cfg, agentDir: params.agentDir, + workspaceDir: params.workspaceDir, provider: params.provider, modelId: normalizedInputModel.id, model: normalizedInputModel, @@ -248,9 +255,11 @@ function normalizeResolvedModel(params: { const compatNormalized = runtimeHooks.applyProviderResolvedModelCompatWithPlugins?.({ provider: params.provider, config: params.cfg, + workspaceDir: params.workspaceDir, context: { config: params.cfg, agentDir: params.agentDir, + workspaceDir: params.workspaceDir, provider: params.provider, modelId: normalizedInputModel.id, model: (pluginNormalized ?? normalizedInputModel) as never, @@ -259,9 +268,11 @@ function normalizeResolvedModel(params: { const transportNormalized = runtimeHooks.applyProviderResolvedTransportWithPlugin?.({ provider: params.provider, config: params.cfg, + workspaceDir: params.workspaceDir, context: { config: params.cfg, agentDir: params.agentDir, + workspaceDir: params.workspaceDir, provider: params.provider, modelId: normalizedInputModel.id, model: (compatNormalized ?? pluginNormalized ?? normalizedInputModel) as never, @@ -272,6 +283,7 @@ function normalizeResolvedModel(params: { applyResolvedTransportFallback({ provider: params.provider, cfg: params.cfg, + workspaceDir: params.workspaceDir, runtimeHooks, model: compatNormalized ?? pluginNormalized ?? normalizedInputModel, }); @@ -290,6 +302,7 @@ function resolveProviderTransport(params: { api?: Api | null; baseUrl?: string; cfg?: OpenClawConfig; + workspaceDir?: string; runtimeHooks?: ProviderRuntimeHooks; }): { api?: Api; @@ -299,7 +312,10 @@ function resolveProviderTransport(params: { const normalized = runtimeHooks.normalizeProviderTransportWithPlugin({ provider: params.provider, config: params.cfg, + workspaceDir: params.workspaceDir, context: { + config: params.cfg, + workspaceDir: params.workspaceDir, provider: params.provider, api: params.api, baseUrl: params.baseUrl, @@ -499,6 +515,7 @@ function applyConfiguredProviderOverrides(params: { cfg?: OpenClawConfig; runtimeHooks?: ProviderRuntimeHooks; preferDiscoveredModelMetadata?: boolean; + workspaceDir?: string; }): ProviderRuntimeModel { const { discoveredModel, providerConfig, modelId } = params; const requestTimeoutMs = resolveProviderRequestTimeoutMs(providerConfig?.timeoutSeconds); @@ -582,6 +599,7 @@ function applyConfiguredProviderOverrides(params: { resolveConfiguredProviderDefaultApi(providerConfig), baseUrl: providerConfig.baseUrl ?? discoveredModel.baseUrl, cfg: params.cfg, + workspaceDir: params.workspaceDir, runtimeHooks: params.runtimeHooks, }); const resolvedContextWindow = @@ -635,9 +653,10 @@ function resolveExplicitModelWithRegistry(params: { modelRegistry: ModelRegistry; cfg?: OpenClawConfig; agentDir?: string; + workspaceDir?: string; runtimeHooks?: ProviderRuntimeHooks; }): { kind: "resolved"; model: Model } | { kind: "suppressed" } | undefined { - const { provider, modelId, modelRegistry, cfg, agentDir, runtimeHooks } = params; + const { provider, modelId, modelRegistry, cfg, agentDir, workspaceDir, runtimeHooks } = params; const providerConfig = resolveConfiguredProviderConfig(cfg, provider); const requestTimeoutMs = resolveProviderRequestTimeoutMs(providerConfig?.timeoutSeconds); const inlineMatch = findInlineModelMatch({ @@ -666,6 +685,7 @@ function resolveExplicitModelWithRegistry(params: { provider, cfg, agentDir, + workspaceDir, model: { ...inlineMatch, ...(resolvedParams ? { params: resolvedParams } : {}), @@ -701,8 +721,10 @@ function resolveExplicitModelWithRegistry(params: { modelId, cfg, runtimeHooks, + workspaceDir, }), runtimeHooks, + workspaceDir, }), }; } @@ -726,6 +748,7 @@ function resolveExplicitModelWithRegistry(params: { provider, cfg, agentDir, + workspaceDir, model: { ...fallbackInlineMatch, ...(resolvedParams ? { params: resolvedParams } : {}), @@ -766,6 +789,7 @@ function resolvePluginDynamicModelWithRegistry(params: { context: { config: cfg, agentDir, + workspaceDir, provider, modelId, modelRegistry, @@ -782,12 +806,14 @@ function resolvePluginDynamicModelWithRegistry(params: { modelId, cfg, runtimeHooks, + workspaceDir, preferDiscoveredModelMetadata, }); return normalizeResolvedModel({ provider, cfg, agentDir, + workspaceDir, model: overriddenDynamicModel, runtimeHooks, }); @@ -798,9 +824,10 @@ function resolveConfiguredFallbackModel(params: { modelId: string; cfg?: OpenClawConfig; agentDir?: string; + workspaceDir?: string; runtimeHooks?: ProviderRuntimeHooks; }): Model | undefined { - const { provider, modelId, cfg, agentDir, runtimeHooks } = params; + const { provider, modelId, cfg, agentDir, workspaceDir, runtimeHooks } = params; const providerConfig = resolveConfiguredProviderConfig(cfg, provider); const requestTimeoutMs = resolveProviderRequestTimeoutMs(providerConfig?.timeoutSeconds); const configuredModel = findConfiguredProviderModel(providerConfig, provider, modelId); @@ -825,6 +852,7 @@ function resolveConfiguredFallbackModel(params: { api: resolveConfiguredProviderDefaultApi(providerConfig) ?? "openai-responses", baseUrl: providerConfig?.baseUrl, cfg, + workspaceDir, runtimeHooks, }); const requestConfig = resolveProviderRequestConfig({ @@ -842,6 +870,7 @@ function resolveConfiguredFallbackModel(params: { provider, cfg, agentDir, + workspaceDir, model: attachModelProviderRequestTransport( { id: modelId, @@ -921,6 +950,7 @@ export function resolveModelWithRegistry(params: { modelRegistry: ModelRegistry; cfg?: OpenClawConfig; agentDir?: string; + workspaceDir?: string; runtimeHooks?: ProviderRuntimeHooks; }): Model | undefined { const normalizedRef = { @@ -933,7 +963,8 @@ export function resolveModelWithRegistry(params: { modelId: normalizedRef.model, }; const runtimeHooks = params.runtimeHooks ?? DEFAULT_PROVIDER_RUNTIME_HOOKS; - const workspaceDir = normalizedParams.cfg?.agents?.defaults?.workspace; + const workspaceDir = + normalizedParams.workspaceDir ?? normalizedParams.cfg?.agents?.defaults?.workspace; const explicitModel = resolveExplicitModelWithRegistry(normalizedParams); if (explicitModel?.kind === "suppressed") { return undefined; @@ -978,6 +1009,7 @@ export function resolveModel( modelRegistry?: ModelRegistry; runtimeHooks?: ProviderRuntimeHooks; skipProviderRuntimeHooks?: boolean; + workspaceDir?: string; }, ): { model?: Model; @@ -990,6 +1022,7 @@ export function resolveModel( model: normalizeStaticProviderModelId(normalizeProviderId(provider), modelId), }; const resolvedAgentDir = agentDir ?? resolveDefaultAgentDir(cfg ?? {}); + const workspaceDir = options?.workspaceDir ?? cfg?.agents?.defaults?.workspace; const authStorage = options?.authStorage ?? discoverAuthStorage(resolvedAgentDir); const modelRegistry = options?.modelRegistry ?? discoverModels(authStorage, resolvedAgentDir); const runtimeHooks = resolveRuntimeHooks(options); @@ -999,6 +1032,7 @@ export function resolveModel( modelRegistry, cfg, agentDir: resolvedAgentDir, + workspaceDir, runtimeHooks, }); if (model) { @@ -1011,6 +1045,7 @@ export function resolveModel( modelId: normalizedRef.model, cfg, agentDir: resolvedAgentDir, + workspaceDir, runtimeHooks, }), authStorage, @@ -1030,6 +1065,7 @@ export async function resolveModelAsync( runtimeHooks?: ProviderRuntimeHooks; skipProviderRuntimeHooks?: boolean; skipPiDiscovery?: boolean; + workspaceDir?: string; }, ): Promise<{ model?: Model; @@ -1042,6 +1078,7 @@ export async function resolveModelAsync( model: normalizeStaticProviderModelId(normalizeProviderId(provider), modelId), }; const resolvedAgentDir = agentDir ?? resolveDefaultAgentDir(cfg ?? {}); + const workspaceDir = options?.workspaceDir ?? cfg?.agents?.defaults?.workspace; const emptyDiscoveryStores = options?.skipPiDiscovery && (!options.authStorage || !options.modelRegistry) ? createEmptyPiDiscoveryStores() @@ -1061,6 +1098,7 @@ export async function resolveModelAsync( modelRegistry, cfg, agentDir: resolvedAgentDir, + workspaceDir, runtimeHooks, }); if (explicitModel?.kind === "suppressed") { @@ -1070,6 +1108,7 @@ export async function resolveModelAsync( modelId: normalizedRef.model, cfg, agentDir: resolvedAgentDir, + workspaceDir, runtimeHooks, }), authStorage, @@ -1081,9 +1120,11 @@ export async function resolveModelAsync( await runtimeHooks.prepareProviderDynamicModel({ provider: normalizedRef.provider, config: cfg, + workspaceDir, context: { config: cfg, agentDir: resolvedAgentDir, + workspaceDir, provider: normalizedRef.provider, modelId: normalizedRef.model, modelRegistry, @@ -1096,6 +1137,7 @@ export async function resolveModelAsync( modelRegistry, cfg, agentDir: resolvedAgentDir, + workspaceDir, runtimeHooks, }); }; @@ -1106,6 +1148,7 @@ export async function resolveModelAsync( modelId: normalizedRef.model, cfg, agentDir: resolvedAgentDir, + workspaceDir, runtimeHooks, }) ? explicitModel.model @@ -1126,6 +1169,7 @@ export async function resolveModelAsync( modelId: normalizedRef.model, cfg, agentDir: resolvedAgentDir, + workspaceDir, runtimeHooks, }), authStorage, @@ -1148,6 +1192,7 @@ function buildUnknownModelError(params: { modelId: string; cfg?: OpenClawConfig; agentDir?: string; + workspaceDir?: string; runtimeHooks?: ProviderRuntimeHooks; }): string { const suppressed = buildSuppressedBuiltInModelError({ @@ -1163,10 +1208,12 @@ function buildUnknownModelError(params: { const hint = runtimeHooks.buildProviderUnknownModelHintWithPlugin({ provider: params.provider, config: params.cfg, + workspaceDir: params.workspaceDir, env: process.env, context: { config: params.cfg, agentDir: params.agentDir, + workspaceDir: params.workspaceDir, env: process.env, provider: params.provider, modelId: params.modelId, diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index b4880591cb5..c6a573b3af3 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -514,6 +514,7 @@ export async function runEmbeddedPiAgent( // first generating PI models.json. This keeps one-shot model runs from // blocking on unrelated provider discovery. skipPiDiscovery: true, + workspaceDir: resolvedWorkspace, }, ); const modelResolution = @@ -523,7 +524,9 @@ export async function runEmbeddedPiAgent( await ensureOpenClawModelsJson(params.config, agentDir, { workspaceDir: resolvedWorkspace, }); - return await resolveModelAsync(provider, modelId, agentDir, params.config); + return await resolveModelAsync(provider, modelId, agentDir, params.config, { + workspaceDir: resolvedWorkspace, + }); })(); const { model, error, authStorage, modelRegistry } = modelResolution; if (!model) {