diff --git a/src/agents/models-config.providers.discovery.ts b/src/agents/models-config.providers.discovery.ts
index e8775014041..a6d99afa89f 100644
--- a/src/agents/models-config.providers.discovery.ts
+++ b/src/agents/models-config.providers.discovery.ts
@@ -31,33 +31,20 @@ const log = createSubsystemLogger("agents/model-providers");
 
 const OLLAMA_SHOW_CONCURRENCY = 8;
 const OLLAMA_SHOW_MAX_MODELS = 200;
 
-const SGLANG_BASE_URL = "http://127.0.0.1:30000/v1";
-const SGLANG_DEFAULT_CONTEXT_WINDOW = 128000;
-const SGLANG_DEFAULT_MAX_TOKENS = 8192;
-const SGLANG_DEFAULT_COST = {
+const OPENAI_COMPAT_LOCAL_DEFAULT_CONTEXT_WINDOW = 128000;
+const OPENAI_COMPAT_LOCAL_DEFAULT_MAX_TOKENS = 8192;
+const OPENAI_COMPAT_LOCAL_DEFAULT_COST = {
   input: 0,
   output: 0,
   cacheRead: 0,
   cacheWrite: 0,
 };
+const SGLANG_BASE_URL = "http://127.0.0.1:30000/v1";
+
 const VLLM_BASE_URL = "http://127.0.0.1:8000/v1";
-const VLLM_DEFAULT_CONTEXT_WINDOW = 128000;
-const VLLM_DEFAULT_MAX_TOKENS = 8192;
-const VLLM_DEFAULT_COST = {
-  input: 0,
-  output: 0,
-  cacheRead: 0,
-  cacheWrite: 0,
-};
-type VllmModelsResponse = {
-  data?: Array<{
-    id?: string;
-  }>;
-};
-
-type SglangModelsResponse = {
+type OpenAICompatModelsResponse = {
   data?: Array<{
     id?: string;
   }>;
 };
@@ -112,31 +99,34 @@
   }
 }
 
-async function discoverVllmModels(
-  baseUrl: string,
-  apiKey?: string,
-): Promise<ModelDefinitionConfig[]> {
+async function discoverOpenAICompatibleLocalModels(params: {
+  baseUrl: string;
+  apiKey?: string;
+  label: string;
+  contextWindow?: number;
+  maxTokens?: number;
+}): Promise<ModelDefinitionConfig[]> {
   if (process.env.VITEST || process.env.NODE_ENV === "test") {
     return [];
   }
 
-  const trimmedBaseUrl = baseUrl.trim().replace(/\/+$/, "");
+  const trimmedBaseUrl = params.baseUrl.trim().replace(/\/+$/, "");
   const url = `${trimmedBaseUrl}/models`;
 
   try {
-    const trimmedApiKey = apiKey?.trim();
+    const trimmedApiKey = params.apiKey?.trim();
     const response = await fetch(url, {
       headers: trimmedApiKey ? { Authorization: `Bearer ${trimmedApiKey}` } : undefined,
       signal: AbortSignal.timeout(5000),
     });
     if (!response.ok) {
-      log.warn(`Failed to discover vLLM models: ${response.status}`);
+      log.warn(`Failed to discover ${params.label} models: ${response.status}`);
       return [];
     }
-    const data = (await response.json()) as VllmModelsResponse;
+    const data = (await response.json()) as OpenAICompatModelsResponse;
     const models = data.data ?? [];
     if (models.length === 0) {
-      log.warn("No vLLM models found on local instance");
+      log.warn(`No ${params.label} models found on local instance`);
       return [];
     }
 
@@ -150,62 +140,13 @@
           name: modelId,
           reasoning: isReasoningModelHeuristic(modelId),
           input: ["text"],
-          cost: VLLM_DEFAULT_COST,
-          contextWindow: VLLM_DEFAULT_CONTEXT_WINDOW,
-          maxTokens: VLLM_DEFAULT_MAX_TOKENS,
+          cost: OPENAI_COMPAT_LOCAL_DEFAULT_COST,
+          contextWindow: params.contextWindow ?? OPENAI_COMPAT_LOCAL_DEFAULT_CONTEXT_WINDOW,
+          maxTokens: params.maxTokens ?? OPENAI_COMPAT_LOCAL_DEFAULT_MAX_TOKENS,
         } satisfies ModelDefinitionConfig;
       });
   } catch (error) {
-    log.warn(`Failed to discover vLLM models: ${String(error)}`);
-    return [];
-  }
-}
-
-async function discoverSglangModels(
-  baseUrl: string,
-  apiKey?: string,
-): Promise<ModelDefinitionConfig[]> {
-  if (process.env.VITEST || process.env.NODE_ENV === "test") {
-    return [];
-  }
-
-  const trimmedBaseUrl = baseUrl.trim().replace(/\/+$/, "");
-  const url = `${trimmedBaseUrl}/models`;
-
-  try {
-    const trimmedApiKey = apiKey?.trim();
-    const response = await fetch(url, {
-      headers: trimmedApiKey ? { Authorization: `Bearer ${trimmedApiKey}` } : undefined,
-      signal: AbortSignal.timeout(5000),
-    });
-    if (!response.ok) {
-      log.warn(`Failed to discover SGLang models: ${response.status}`);
-      return [];
-    }
-    const data = (await response.json()) as SglangModelsResponse;
-    const models = data.data ?? [];
-    if (models.length === 0) {
-      log.warn("No SGLang models found on local instance");
-      return [];
-    }
-
-    return models
-      .map((model) => ({ id: typeof model.id === "string" ? model.id.trim() : "" }))
-      .filter((model) => Boolean(model.id))
-      .map((model) => {
-        const modelId = model.id;
-        return {
-          id: modelId,
-          name: modelId,
-          reasoning: isReasoningModelHeuristic(modelId),
-          input: ["text"],
-          cost: SGLANG_DEFAULT_COST,
-          contextWindow: SGLANG_DEFAULT_CONTEXT_WINDOW,
-          maxTokens: SGLANG_DEFAULT_MAX_TOKENS,
-        } satisfies ModelDefinitionConfig;
-      });
-  } catch (error) {
-    log.warn(`Failed to discover SGLang models: ${String(error)}`);
+    log.warn(`Failed to discover ${params.label} models: ${String(error)}`);
     return [];
   }
 }
@@ -257,7 +198,11 @@ export async function buildVllmProvider(params?: {
   apiKey?: string;
 }): Promise {
   const baseUrl = (params?.baseUrl?.trim() || VLLM_BASE_URL).replace(/\/+$/, "");
-  const models = await discoverVllmModels(baseUrl, params?.apiKey);
+  const models = await discoverOpenAICompatibleLocalModels({
+    baseUrl,
+    apiKey: params?.apiKey,
+    label: "vLLM",
+  });
   return {
     baseUrl,
     api: "openai-completions",
@@ -270,7 +215,11 @@ export async function buildSglangProvider(params?: {
   apiKey?: string;
 }): Promise {
   const baseUrl = (params?.baseUrl?.trim() || SGLANG_BASE_URL).replace(/\/+$/, "");
-  const models = await discoverSglangModels(baseUrl, params?.apiKey);
+  const models = await discoverOpenAICompatibleLocalModels({
+    baseUrl,
+    apiKey: params?.apiKey,
+    label: "SGLang",
+  });
   return {
     baseUrl,
     api: "openai-completions",