diff --git a/src/agents/live-target-matcher.test.ts b/src/agents/live-target-matcher.test.ts new file mode 100644 index 00000000000..03564a8727b --- /dev/null +++ b/src/agents/live-target-matcher.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, it } from "vitest"; +import { createLiveTargetMatcher } from "./live-target-matcher.js"; + +describe("createLiveTargetMatcher", () => { + it("matches Anthropic-owned models for the claude-cli provider filter", () => { + const matcher = createLiveTargetMatcher({ + providerFilter: new Set(["claude-cli"]), + modelFilter: null, + }); + + expect(matcher.matchesProvider("anthropic")).toBe(true); + expect(matcher.matchesProvider("openai")).toBe(false); + }); + + it("matches Anthropic model refs for claude-cli explicit model filters", () => { + const matcher = createLiveTargetMatcher({ + providerFilter: null, + modelFilter: new Set(["claude-cli/claude-sonnet-4-6"]), + }); + + expect(matcher.matchesModel("anthropic", "claude-sonnet-4-6")).toBe(true); + expect(matcher.matchesModel("anthropic", "claude-opus-4-6")).toBe(false); + }); + + it("keeps direct provider/model matches working", () => { + const matcher = createLiveTargetMatcher({ + providerFilter: new Set(["openrouter"]), + modelFilter: new Set(["openrouter/openai/gpt-5.4"]), + }); + + expect(matcher.matchesProvider("openrouter")).toBe(true); + expect(matcher.matchesModel("openrouter", "openai/gpt-5.4")).toBe(true); + }); +}); diff --git a/src/agents/live-target-matcher.ts b/src/agents/live-target-matcher.ts new file mode 100644 index 00000000000..104b5af2eaf --- /dev/null +++ b/src/agents/live-target-matcher.ts @@ -0,0 +1,156 @@ +import type { OpenClawConfig } from "../config/config.js"; +import { resolveOwningPluginIdsForProvider } from "../plugins/providers.js"; +import { normalizeProviderId } from "./provider-id.js"; + +type ModelTarget = { + raw: string; + provider?: string; + modelId: string; +}; + +function normalizeCsvSet(values: Set | null): Set | null { + if (!values) { + return null; + } + const normalized = new Set(); + for (const value of values) { + const trimmed = value.trim(); + if (!trimmed) { + continue; + } + normalized.add(trimmed); + } + return normalized.size > 0 ? normalized : null; +} + +function parseModelTarget(raw: string): ModelTarget | null { + const trimmed = raw.trim(); + if (!trimmed) { + return null; + } + const slash = trimmed.indexOf("/"); + if (slash === -1) { + return { + raw: trimmed, + modelId: trimmed.toLowerCase(), + }; + } + const provider = normalizeProviderId(trimmed.slice(0, slash)); + const modelId = trimmed + .slice(slash + 1) + .trim() + .toLowerCase(); + if (!provider || !modelId) { + return null; + } + return { + raw: trimmed, + provider, + modelId, + }; +} + +function hasSharedOwner( + left: string, + right: string, + params: { + config?: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; + ownerCache: Map; + }, +): boolean { + const resolveOwners = (provider: string): readonly string[] => { + const normalized = normalizeProviderId(provider); + const cached = params.ownerCache.get(normalized); + if (cached) { + return cached; + } + const owners = + resolveOwningPluginIdsForProvider({ + provider: normalized, + config: params.config, + workspaceDir: params.workspaceDir, + env: params.env, + }) ?? []; + params.ownerCache.set(normalized, owners); + return owners; + }; + + const leftOwners = resolveOwners(left); + const rightOwners = resolveOwners(right); + return leftOwners.some((owner) => rightOwners.includes(owner)); +} + +export function createLiveTargetMatcher(params: { + providerFilter: Set | null; + modelFilter: Set | null; + config?: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; +}) { + const providerFilter = normalizeCsvSet(params.providerFilter); + const modelTargets = [...(normalizeCsvSet(params.modelFilter) ?? [])] + .map((value) => parseModelTarget(value)) + .filter((value): value is ModelTarget => value !== null); + const ownerCache = new Map(); + + return { + matchesProvider(provider: string): boolean { + if (!providerFilter) { + return true; + } + const normalizedProvider = normalizeProviderId(provider); + for (const requested of providerFilter) { + const normalizedRequested = normalizeProviderId(requested); + if (normalizedRequested === normalizedProvider) { + return true; + } + if ( + hasSharedOwner(normalizedRequested, normalizedProvider, { + config: params.config, + workspaceDir: params.workspaceDir, + env: params.env, + ownerCache, + }) + ) { + return true; + } + } + return false; + }, + matchesModel(provider: string, modelId: string): boolean { + if (modelTargets.length === 0) { + return true; + } + const normalizedProvider = normalizeProviderId(provider); + const normalizedModelId = modelId.trim().toLowerCase(); + const directRef = `${normalizedProvider}/${normalizedModelId}`; + for (const target of modelTargets) { + if (target.raw.toLowerCase() === directRef) { + return true; + } + if (target.modelId !== normalizedModelId) { + continue; + } + if (!target.provider) { + return true; + } + if (target.provider === normalizedProvider) { + return true; + } + if ( + hasSharedOwner(target.provider, normalizedProvider, { + config: params.config, + workspaceDir: params.workspaceDir, + env: params.env, + ownerCache, + }) + ) { + return true; + } + } + return false; + }, + }; +} diff --git a/src/agents/models.profiles.live.test.ts b/src/agents/models.profiles.live.test.ts index 25d508f4bfa..77c453a2b32 100644 --- a/src/agents/models.profiles.live.test.ts +++ b/src/agents/models.profiles.live.test.ts @@ -9,6 +9,7 @@ import { isAnthropicRateLimitError, } from "./live-auth-keys.js"; import { isHighSignalLiveModelRef, selectHighSignalLiveItems } from "./live-model-filter.js"; +import { createLiveTargetMatcher } from "./live-target-matcher.js"; import { isLiveProfileKeyModeEnabled, isLiveTestEnabled } from "./live-test-helpers.js"; import { getApiKeyForModel, requireApiKey } from "./model-auth.js"; import { shouldSuppressBuiltInModel } from "./model-suppression.js"; @@ -418,6 +419,12 @@ describeLive("live models (profile keys)", () => { const providers = parseProviderFilter(process.env.OPENCLAW_LIVE_PROVIDERS); const perModelTimeoutMs = toInt(process.env.OPENCLAW_LIVE_MODEL_TIMEOUT_MS, 30_000); const maxModels = toInt(process.env.OPENCLAW_LIVE_MAX_MODELS, 0); + const targetMatcher = createLiveTargetMatcher({ + providerFilter: providers, + modelFilter: filter, + config: cfg, + env: process.env, + }); const failures: Array<{ model: string; error: string }> = []; const skipped: Array<{ model: string; reason: string }> = []; @@ -430,11 +437,11 @@ describeLive("live models (profile keys)", () => { if (shouldSuppressBuiltInModel({ provider: model.provider, id: model.id })) { continue; } - if (providers && !providers.has(model.provider)) { + if (!targetMatcher.matchesProvider(model.provider)) { continue; } const id = `${model.provider}/${model.id}`; - if (filter && !filter.has(id)) { + if (!targetMatcher.matchesModel(model.provider, model.id)) { continue; } if (!filter && useModern) { diff --git a/src/gateway/gateway-models.profiles.live.test.ts b/src/gateway/gateway-models.profiles.live.test.ts index 12a4434598e..01e57d61b24 100644 --- a/src/gateway/gateway-models.profiles.live.test.ts +++ b/src/gateway/gateway-models.profiles.live.test.ts @@ -23,6 +23,7 @@ import { isHighSignalLiveModelRef, selectHighSignalLiveItems, } from "../agents/live-model-filter.js"; +import { createLiveTargetMatcher } from "../agents/live-target-matcher.js"; import { isLiveProfileKeyModeEnabled, isLiveTestEnabled } from "../agents/live-test-helpers.js"; import { getApiKeyForModel } from "../agents/model-auth.js"; import { shouldSuppressBuiltInModel } from "../agents/model-suppression.js"; @@ -823,39 +824,87 @@ async function getFreeGatewayPort(): Promise { throw new Error("failed to acquire a free gateway port block"); } +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + async function connectClient(params: { url: string; token: string }) { + const startedAt = Date.now(); + let attempt = 0; + let lastError: Error | null = null; + + while (Date.now() - startedAt < GATEWAY_LIVE_PROBE_TIMEOUT_MS) { + attempt += 1; + const remainingMs = GATEWAY_LIVE_PROBE_TIMEOUT_MS - (Date.now() - startedAt); + if (remainingMs <= 0) { + break; + } + try { + return await connectClientOnce({ + ...params, + timeoutMs: Math.min(remainingMs, 10_000), + }); + } catch (error) { + lastError = error instanceof Error ? error : new Error(String(error)); + if (!isRetryableGatewayConnectError(lastError) || remainingMs <= 2_000) { + throw lastError; + } + await sleep(Math.min(500 * attempt, 2_000)); + } + } + + throw lastError ?? new Error("gateway connect timeout"); +} + +async function connectClientOnce(params: { url: string; token: string; timeoutMs: number }) { return await new Promise((resolve, reject) => { let settled = false; - const stop = (err?: Error, client?: GatewayClient) => { + let client: GatewayClient | undefined; + const stop = (err?: Error, connectedClient?: GatewayClient) => { if (settled) { return; } settled = true; clearTimeout(timer); if (err) { + if (client) { + void client.stopAndWait({ timeoutMs: 1_000 }).catch(() => {}); + } reject(err); } else { - resolve(client as GatewayClient); + resolve(connectedClient as GatewayClient); } }; - const client = new GatewayClient({ + client = new GatewayClient({ url: params.url, token: params.token, clientName: GATEWAY_CLIENT_NAMES.TEST, clientDisplayName: "vitest-live", clientVersion: "dev", mode: GATEWAY_CLIENT_MODES.TEST, + requestTimeoutMs: params.timeoutMs, + connectChallengeTimeoutMs: params.timeoutMs, onHelloOk: () => stop(undefined, client), onConnectError: (err) => stop(err), onClose: (code, reason) => stop(new Error(`gateway closed during connect (${code}): ${reason}`)), }); - const timer = setTimeout(() => stop(new Error("gateway connect timeout")), 10_000); + const timer = setTimeout(() => stop(new Error("gateway connect timeout")), params.timeoutMs); timer.unref(); client.start(); }); } +function isRetryableGatewayConnectError(error: Error): boolean { + const message = error.message.toLowerCase(); + return ( + message.includes("gateway closed during connect (1000)") || + message.includes("gateway connect timeout") || + message.includes("gateway connect challenge timeout") || + message.includes("gateway request timeout for connect") + ); +} + function extractTranscriptMessageText(message: unknown): string { if (!message || typeof message !== "object") { return ""; @@ -1841,8 +1890,14 @@ describeLive("gateway live (dev agent, profile keys)", () => { const useExplicit = Boolean(rawModels) && !useModern; const filter = useExplicit ? parseFilter(rawModels) : null; const maxModels = GATEWAY_LIVE_MAX_MODELS; + const targetMatcher = createLiveTargetMatcher({ + providerFilter: PROVIDERS, + modelFilter: filter, + config: cfg, + env: process.env, + }); const wanted = filter - ? all.filter((m) => filter.has(`${m.provider}/${m.id}`)) + ? all.filter((m) => targetMatcher.matchesModel(m.provider, m.id)) : all.filter((m) => isHighSignalLiveModelRef({ provider: m.provider, id: m.id })); const candidates: Array> = []; @@ -1851,7 +1906,7 @@ describeLive("gateway live (dev agent, profile keys)", () => { if (shouldSuppressBuiltInModel({ provider: model.provider, id: model.id })) { continue; } - if (PROVIDERS && !PROVIDERS.has(model.provider)) { + if (!targetMatcher.matchesProvider(model.provider)) { continue; } const modelRef = `${model.provider}/${model.id}`;