From e9987ffc3aa0928486ce541952ce47713f6cad66 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Wed, 6 May 2026 05:30:22 +0100 Subject: [PATCH] fix: clamp xAI live gateway thinking --- extensions/xai/provider-policy-api.ts | 5 + .../gateway-models.profiles.live.test.ts | 254 ++++++++++++++++-- src/plugins/provider-runtime.test.ts | 26 ++ src/plugins/provider-runtime.ts | 4 + 4 files changed, 266 insertions(+), 23 deletions(-) create mode 100644 extensions/xai/provider-policy-api.ts diff --git a/extensions/xai/provider-policy-api.ts b/extensions/xai/provider-policy-api.ts new file mode 100644 index 00000000000..ed179c7ec95 --- /dev/null +++ b/extensions/xai/provider-policy-api.ts @@ -0,0 +1,5 @@ +import type { ProviderThinkingProfile } from "openclaw/plugin-sdk/plugin-entry"; + +export function resolveThinkingProfile(): ProviderThinkingProfile { + return { levels: [{ id: "off" }], defaultLevel: "off" }; +} diff --git a/src/gateway/gateway-models.profiles.live.test.ts b/src/gateway/gateway-models.profiles.live.test.ts index 7dfcedd4095..9e4bb969ffb 100644 --- a/src/gateway/gateway-models.profiles.live.test.ts +++ b/src/gateway/gateway-models.profiles.live.test.ts @@ -3,7 +3,12 @@ import fs from "node:fs/promises"; import { createServer } from "node:net"; import os from "node:os"; import path from "node:path"; -import type { Api, Model } from "@mariozechner/pi-ai"; +import { + clampThinkingLevel, + type Api, + type Model, + type ModelThinkingLevel, +} from "@mariozechner/pi-ai"; import { afterEach, describe, expect, it } from "vitest"; import { resolveAgentWorkspaceDir, resolveDefaultAgentDir } from "../agents/agent-scope.js"; import { ensureAuthProfileStore, saveAuthProfileStore } from "../agents/auth-profiles/store.js"; @@ -36,6 +41,7 @@ import { clearRuntimeConfigSnapshot, getRuntimeConfig } from "../config/io.js"; import type { ModelsConfig, ModelProviderConfig, OpenClawConfig } from "../config/types.js"; import { isTruthyEnvValue } from "../infra/env.js"; import { normalizeGoogleModelId } from "../plugin-sdk/google-model-id.js"; +import { resolveProviderThinkingProfile } from "../plugins/provider-runtime.js"; import { DEFAULT_AGENT_ID } from "../routing/session-key.js"; import { stripAssistantInternalScaffolding } from "../shared/text/assistant-visible-text.js"; import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js"; @@ -587,6 +593,95 @@ describe("resolveGatewayLiveMaxModels", () => { }); }); +function createGatewayLiveTestModel(provider: string, id: string): Model { + return { + provider, + id, + name: id, + api: "openai-responses", + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 1_000, + maxTokens: 100, + reasoning: false, + } as Model; +} + +describe("resolveExplicitLiveModelCandidates", () => { + it("uses targeted registry lookup for explicit provider/model filters", () => { + const model = createGatewayLiveTestModel("xai", "grok-4.3"); + const matcher = createLiveTargetMatcher({ + providerFilter: new Set(["xai"]), + modelFilter: new Set(["xai/grok-4.3"]), + env: {}, + }); + const candidates = resolveExplicitLiveModelCandidates({ + modelRegistry: { + find(provider, modelId) { + expect(provider).toBe("xai"); + expect(modelId).toBe("grok-4.3"); + return model; + }, + getAll() { + throw new Error("explicit model lookup should not enumerate registry"); + }, + }, + modelFilter: new Set(["xai/grok-4.3"]), + providerFilter: new Set(["xai"]), + targetMatcher: matcher, + }); + + expect(candidates).toEqual([model]); + }); + + it("falls back to enumeration for ambiguous model-only filters", () => { + const matcher = createLiveTargetMatcher({ + providerFilter: null, + modelFilter: new Set(["grok-4.3"]), + env: {}, + }); + + expect( + resolveExplicitLiveModelCandidates({ + modelRegistry: { + find() { + throw new Error("ambiguous model-only lookup should not use direct find"); + }, + getAll() { + return []; + }, + }, + modelFilter: new Set(["grok-4.3"]), + providerFilter: null, + targetMatcher: matcher, + }), + ).toBeNull(); + }); +}); + +describe("resolveGatewayLiveModelThinkingLevel", () => { + it("clamps requested thinking to levels supported by model metadata", () => { + expect( + resolveGatewayLiveModelThinkingLevel({ + cfg: {}, + model: { + ...createGatewayLiveTestModel("xai", "grok-4.3"), + reasoning: true, + thinkingLevelMap: { + off: null, + minimal: null, + low: null, + medium: null, + high: null, + xhigh: null, + }, + }, + requestedLevel: "low", + }), + ).toBe("off"); + }); +}); + function isGoogleModelNotFoundText(text: string): boolean { const trimmed = text.trim(); if (!trimmed) { @@ -1281,6 +1376,101 @@ type GatewayModelSuiteParams = { providerOverrides?: Record; }; +type LiveModelRegistry = { + find(provider: string, modelId: string): Model | null | undefined; + getAll(): Array>; +}; + +function parseExplicitLiveModelRef( + raw: string, + providerFilter: Set | null, +): { provider: string; modelId: string } | null { + const trimmed = raw.trim(); + if (!trimmed) { + return null; + } + const slash = trimmed.indexOf("/"); + if (slash !== -1) { + const provider = normalizeProviderId(trimmed.slice(0, slash)); + const modelId = trimmed.slice(slash + 1).trim(); + return provider && modelId ? { provider, modelId } : null; + } + if (!providerFilter || providerFilter.size !== 1) { + return null; + } + const [provider] = [...providerFilter]; + return provider ? { provider: normalizeProviderId(provider), modelId: trimmed } : null; +} + +function resolveExplicitLiveModelCandidates(params: { + modelRegistry: LiveModelRegistry; + modelFilter: Set | null; + providerFilter: Set | null; + targetMatcher: ReturnType; +}): Array> | null { + if (!params.modelFilter || params.modelFilter.size === 0) { + return null; + } + const candidates: Array> = []; + const seen = new Set(); + for (const raw of params.modelFilter) { + const ref = parseExplicitLiveModelRef(raw, params.providerFilter); + if (!ref) { + return null; + } + const model = params.modelRegistry.find(ref.provider, ref.modelId); + if (!model) { + return null; + } + if ( + !params.targetMatcher.matchesProvider(model.provider) || + !params.targetMatcher.matchesModel(model.provider, model.id) + ) { + return null; + } + const key = `${normalizeProviderId(model.provider)}/${model.id.toLowerCase()}`; + if (!seen.has(key)) { + seen.add(key); + candidates.push(model); + } + } + return candidates; +} + +function resolveGatewayLiveModelThinkingLevel(params: { + cfg: OpenClawConfig; + model: Model; + requestedLevel: string; +}): string { + const { model, requestedLevel } = params; + const normalized = requestedLevel.trim() as ModelThinkingLevel; + if (!["off", "minimal", "low", "medium", "high", "xhigh"].includes(normalized)) { + return requestedLevel; + } + const profile = resolveProviderThinkingProfile({ + provider: model.provider, + config: params.cfg, + context: { + provider: model.provider, + modelId: model.id, + reasoning: model.reasoning, + }, + }); + if (profile) { + const levelIds = profile.levels.map((level) => level.id); + if (levelIds.includes(normalized)) { + return normalized; + } + if (profile.defaultLevel) { + return profile.defaultLevel; + } + if (levelIds.length === 1) { + return levelIds[0] ?? requestedLevel; + } + } + return clampThinkingLevel(model, normalized); +} + function buildLiveGatewayConfig(params: { cfg: OpenClawConfig; candidates: Array>; @@ -1549,6 +1739,14 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { for (const [index, model] of params.candidates.entries()) { const modelKey = `${model.provider}/${model.id}`; const progressLabel = `[${params.label}] ${index + 1}/${total} ${modelKey}`; + const thinkingLevel = resolveGatewayLiveModelThinkingLevel({ + cfg: params.cfg, + model, + requestedLevel: params.thinkingLevel, + }); + if (thinkingLevel !== params.thinkingLevel) { + logProgress(`${progressLabel}: thinking ${params.thinkingLevel} -> ${thinkingLevel}`); + } // Use a separate session per model: live providers can finalize late after // skip/retry paths, and a reset on a reused key does not isolate those // delayed transcript writes from the next model probe. @@ -1589,7 +1787,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { modelKey, message: "Explain in 2-3 sentences how the JavaScript event loop handles microtasks vs macrotasks. Must mention both words: microtask and macrotask.", - thinkingLevel: params.thinkingLevel, + thinkingLevel, context: `${progressLabel}: prompt`, }); if (!text) { @@ -1601,7 +1799,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { modelKey, message: "Explain in 2-3 sentences how the JavaScript event loop handles microtasks vs macrotasks. Must mention both words: microtask and macrotask.", - thinkingLevel: params.thinkingLevel, + thinkingLevel, context: `${progressLabel}: prompt-retry`, }); } @@ -1650,7 +1848,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { modelKey, message: "Answer in exactly two short sentences. Include the exact lowercase words microtask and macrotask. No bullets.", - thinkingLevel: params.thinkingLevel, + thinkingLevel, context: `${progressLabel}: prompt-keyword-retry`, }); if (retryText) { @@ -1697,7 +1895,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { : "OpenClaw live tool probe (local, safe): " + `use the tool named \`read\` (or \`Read\`) with JSON arguments {"path":"${toolProbePath}"}. ` + "Then reply with the two nonce values you read (include both).", - thinkingLevel: params.thinkingLevel, + thinkingLevel, context: `${progressLabel}: tool-read`, }); if ( @@ -1768,7 +1966,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { `mkdir -p "${tempDir}" && printf '%s' '${nonceC}' > "${toolWritePath}". ` + `Then use the tool named \`read\` (or \`Read\`) with JSON arguments {"path":"${toolWritePath}"}. ` + "Finally reply including the nonce text you read back.", - thinkingLevel: params.thinkingLevel, + thinkingLevel, context: `${progressLabel}: tool-exec`, }); if ( @@ -1836,7 +2034,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { content: imageBase64, }, ], - thinkingLevel: params.thinkingLevel, + thinkingLevel, context: `${progressLabel}: image`, }); if ( @@ -1883,7 +2081,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { idempotencyKey: `idem-${runId2}-1`, modelKey, message: `Call the tool named \`read\` (or \`Read\`) on "${toolProbePath}". Do not write any other text.`, - thinkingLevel: params.thinkingLevel, + thinkingLevel, context: `${progressLabel}: tool-only-regression-first`, }); assertNoReasoningTags({ @@ -1899,7 +2097,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { idempotencyKey: `idem-${runId2}-2`, modelKey, message: `Now answer: what are the values of nonceA and nonceB in "${toolProbePath}"? Reply with exactly: ${nonceA} ${nonceB}.`, - thinkingLevel: params.thinkingLevel, + thinkingLevel, context: `${progressLabel}: tool-only-regression-second`, }); assertNoReasoningTags({ @@ -1919,7 +2117,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) { sessionKey, modelKey, label: progressLabel, - thinkingLevel: params.thinkingLevel, + thinkingLevel, }); } return "done"; @@ -2171,7 +2369,6 @@ describeLive("gateway live (dev agent, profile keys)", () => { const agentDir = resolveDefaultAgentDir(cfg); const authStorage = discoverAuthStorage(agentDir); const modelRegistry = discoverModels(authStorage, agentDir); - const all = modelRegistry.getAll(); const rawModels = process.env.OPENCLAW_LIVE_GATEWAY_MODELS?.trim(); const useModern = !rawModels || rawModels === "modern" || rawModels === "all"; @@ -2184,18 +2381,29 @@ describeLive("gateway live (dev agent, profile keys)", () => { config: cfg, env: process.env, }); - const wanted = filter - ? all.filter((m) => targetMatcher.matchesModel(m.provider, m.id)) - : all.filter( - (m) => - !shouldExcludeProviderFromDefaultHighSignalLiveSweep({ - provider: m.provider, - useExplicitModels: useExplicit, - providerFilter: PROVIDERS, - config: cfg, - env: process.env, - }) && isHighSignalLiveModelRef({ provider: m.provider, id: m.id }), - ); + let wanted = useExplicit + ? resolveExplicitLiveModelCandidates({ + modelRegistry, + modelFilter: filter, + providerFilter: PROVIDERS, + targetMatcher, + }) + : null; + if (!wanted) { + const all = modelRegistry.getAll(); + wanted = filter + ? all.filter((m) => targetMatcher.matchesModel(m.provider, m.id)) + : all.filter( + (m) => + !shouldExcludeProviderFromDefaultHighSignalLiveSweep({ + provider: m.provider, + useExplicitModels: useExplicit, + providerFilter: PROVIDERS, + config: cfg, + env: process.env, + }) && isHighSignalLiveModelRef({ provider: m.provider, id: m.id }), + ); + } const candidates: Array> = []; const skipped: Array<{ model: string; error: string }> = []; diff --git a/src/plugins/provider-runtime.test.ts b/src/plugins/provider-runtime.test.ts index ef1250553fd..288a99ff075 100644 --- a/src/plugins/provider-runtime.test.ts +++ b/src/plugins/provider-runtime.test.ts @@ -70,6 +70,7 @@ let resolveProviderStreamFn: typeof import("./provider-runtime.js").resolveProvi let resolveProviderCacheTtlEligibility: typeof import("./provider-runtime.js").resolveProviderCacheTtlEligibility; let resolveProviderBinaryThinking: typeof import("./provider-runtime.js").resolveProviderBinaryThinking; let createProviderEmbeddingProvider: typeof import("./provider-runtime.js").createProviderEmbeddingProvider; +let resolveProviderThinkingProfile: typeof import("./provider-runtime.js").resolveProviderThinkingProfile; let resolveProviderDefaultThinkingLevel: typeof import("./provider-runtime.js").resolveProviderDefaultThinkingLevel; let resolveProviderModernModelRef: typeof import("./provider-runtime.js").resolveProviderModernModelRef; let resolveProviderReasoningOutputModeWithPlugin: typeof import("./provider-runtime.js").resolveProviderReasoningOutputModeWithPlugin; @@ -295,6 +296,7 @@ describe("provider-runtime", () => { resolveProviderCacheTtlEligibility, resolveProviderBinaryThinking, createProviderEmbeddingProvider, + resolveProviderThinkingProfile, resolveProviderDefaultThinkingLevel, resolveProviderModernModelRef, resolveProviderReasoningOutputModeWithPlugin, @@ -1154,6 +1156,30 @@ describe("provider-runtime", () => { expect(resolvePluginProvidersMock).not.toHaveBeenCalled(); }); + it("resolves thinking profiles from bundled policy surface before runtime plugins", () => { + const resolveThinkingProfile = vi.fn(() => ({ + levels: [{ id: "off" as const }], + defaultLevel: "off" as const, + })); + resolveBundledProviderPolicySurfaceMock.mockReturnValue({ + resolveThinkingProfile, + }); + + expect( + resolveProviderThinkingProfile({ + provider: "xai", + context: { + provider: "xai", + modelId: "grok-4.3", + reasoning: true, + }, + }), + ).toEqual({ levels: [{ id: "off" }], defaultLevel: "off" }); + + expect(resolveThinkingProfile).toHaveBeenCalledTimes(1); + expect(resolvePluginProvidersMock).not.toHaveBeenCalled(); + }); + it("resolves provider config defaults through owner plugins", () => { resolvePluginProvidersMock.mockReturnValue([ { diff --git a/src/plugins/provider-runtime.ts b/src/plugins/provider-runtime.ts index 345a0d05a84..09467e24e57 100644 --- a/src/plugins/provider-runtime.ts +++ b/src/plugins/provider-runtime.ts @@ -763,6 +763,10 @@ export function resolveProviderThinkingProfile(params: { env?: NodeJS.ProcessEnv; context: ProviderDefaultThinkingPolicyContext; }): ProviderThinkingProfile | null | undefined { + const bundledSurface = resolveBundledProviderPolicySurface(params.provider); + if (bundledSurface?.resolveThinkingProfile) { + return bundledSurface.resolveThinkingProfile(params.context) ?? undefined; + } return resolveProviderRuntimePlugin(params)?.resolveThinkingProfile?.(params.context); }