fix: bound static provider catalog listing

This commit is contained in:
Shakker
2026-04-22 03:21:03 +01:00
committed by Shakker
parent d6c7b468ea
commit f3da6e96b7
2 changed files with 181 additions and 8 deletions

View File

@@ -1,4 +1,5 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import type { ProviderPlugin } from "../../plugins/types.js";
import {
loadProviderCatalogModelsForList,
resolveProviderCatalogPluginIdsForFilter,
@@ -23,6 +24,7 @@ const baseParams = {
describe("loadProviderCatalogModelsForList", () => {
afterEach(() => {
vi.useRealTimers();
vi.restoreAllMocks();
});
@@ -51,6 +53,48 @@ describe("loadProviderCatalogModelsForList", () => {
);
});
it("skips static catalogs that exceed the display budget", async () => {
vi.useFakeTimers();
const hungProvider = {
id: "hung",
label: "Hung",
auth: [],
staticCatalog: {
run: async () => new Promise<never>(() => {}),
},
} satisfies ProviderPlugin;
const healthyProvider = {
id: "healthy",
label: "Healthy",
auth: [],
staticCatalog: {
run: async () => ({
provider: {
baseUrl: "https://healthy.example/v1",
models: [{ id: "healthy-model", name: "Healthy Model" }],
},
}),
},
} satisfies ProviderPlugin;
const discovery = await import("../../plugins/provider-discovery.js");
vi.spyOn(discovery, "resolvePluginDiscoveryProviders").mockResolvedValue([
hungProvider,
healthyProvider,
]);
const rowsPromise = loadProviderCatalogModelsForList({
...baseParams,
});
await vi.advanceTimersByTimeAsync(2_000);
await expect(rowsPromise).resolves.toEqual([
expect.objectContaining({
provider: "healthy",
id: "healthy-model",
}),
]);
});
it("recognizes bundled provider hook aliases before the unknown-provider short-circuit", async () => {
await expect(
resolveProviderCatalogPluginIdsForFilter({
@@ -61,6 +105,49 @@ describe("loadProviderCatalogModelsForList", () => {
).resolves.toEqual(["openai"]);
});
it("recognizes trusted workspace provider aliases before the unknown-provider short-circuit", async () => {
const manifestRegistry = await import("../../plugins/manifest-registry.js");
const providers = await import("../../plugins/providers.js");
const discovery = await import("../../plugins/provider-discovery.js");
vi.spyOn(manifestRegistry, "loadPluginManifestRegistry").mockReturnValue({
plugins: [
{
id: "workspace-demo",
origin: "workspace",
providers: ["workspace-demo"],
cliBackends: [],
},
],
diagnostics: [],
} as never);
vi.spyOn(providers, "resolveDiscoveredProviderPluginIds").mockReturnValue(["workspace-demo"]);
vi.spyOn(discovery, "resolvePluginDiscoveryProviders").mockResolvedValue([
{
id: "workspace-demo",
pluginId: "workspace-demo",
label: "Workspace Demo",
aliases: ["workspace-demo-alias"],
auth: [],
staticCatalog: {
run: async () => ({
provider: {
baseUrl: "https://workspace.example/v1",
models: [],
},
}),
},
},
]);
await expect(
resolveProviderCatalogPluginIdsForFilter({
cfg: baseParams.cfg,
env: baseParams.env,
providerFilter: "workspace-demo-alias",
}),
).resolves.toEqual(["workspace-demo"]);
});
it("keeps unknown provider filters eligible for early empty results", async () => {
await expect(
resolveProviderCatalogPluginIdsForFilter({

View File

@@ -4,18 +4,71 @@ import type { ModelProviderConfig } from "../../config/types.models.js";
import type { OpenClawConfig } from "../../config/types.openclaw.js";
import { formatErrorMessage } from "../../infra/errors.js";
import { createSubsystemLogger } from "../../logging/subsystem.js";
import { loadPluginManifestRegistry } from "../../plugins/manifest-registry.js";
import {
groupPluginDiscoveryProvidersByOrder,
normalizePluginDiscoveryResult,
resolvePluginDiscoveryProviders,
runProviderStaticCatalog,
} from "../../plugins/provider-discovery.js";
import { resolveOwningPluginIdsForProvider } from "../../plugins/providers.js";
import {
resolveDiscoveredProviderPluginIds,
resolveOwningPluginIdsForProvider,
} from "../../plugins/providers.js";
import type { ProviderPlugin } from "../../plugins/types.js";
const DISCOVERY_ORDERS = ["simple", "profile", "paired", "late"] as const;
const SELF_HOSTED_DISCOVERY_PROVIDER_IDS = new Set(["lmstudio", "ollama", "sglang", "vllm"]);
const STATIC_CATALOG_TIMEOUT_MS = 2_000;
const log = createSubsystemLogger("models/list-provider-catalog");
function providerMatchesFilterAlias(provider: ProviderPlugin, providerFilter: string): boolean {
return [provider.id, ...(provider.aliases ?? []), ...(provider.hookAliases ?? [])].some(
(providerId) => normalizeProviderId(providerId) === providerFilter,
);
}
async function resolveWorkspacePluginIdsForProviderAlias(params: {
cfg: OpenClawConfig;
env?: NodeJS.ProcessEnv;
providerFilter: string;
}): Promise<string[] | undefined> {
const discoverablePluginIds = new Set(
resolveDiscoveredProviderPluginIds({
config: params.cfg,
env: params.env,
includeUntrustedWorkspacePlugins: false,
}),
);
const workspacePluginIds = loadPluginManifestRegistry({
config: params.cfg,
env: params.env,
})
.plugins.filter(
(plugin) => plugin.origin === "workspace" && discoverablePluginIds.has(plugin.id),
)
.map((plugin) => plugin.id);
if (workspacePluginIds.length === 0) {
return undefined;
}
const providers = await resolvePluginDiscoveryProviders({
config: params.cfg,
env: params.env,
onlyPluginIds: workspacePluginIds,
includeUntrustedWorkspacePlugins: false,
});
const pluginIds = [
...new Set(
providers
.filter((provider) => providerMatchesFilterAlias(provider, params.providerFilter))
.map((provider) => provider.pluginId)
.filter((pluginId): pluginId is string => typeof pluginId === "string" && pluginId !== ""),
),
].toSorted((left, right) => left.localeCompare(right));
return pluginIds.length > 0 ? pluginIds : undefined;
}
export async function resolveProviderCatalogPluginIdsForFilter(params: {
cfg: OpenClawConfig;
env?: NodeJS.ProcessEnv;
@@ -35,7 +88,15 @@ export async function resolveProviderCatalogPluginIdsForFilter(params: {
}
const { resolveProviderContractPluginIdsForProviderAlias } =
await import("../../plugins/contracts/registry.js");
return resolveProviderContractPluginIdsForProviderAlias(providerFilter);
const bundledAliasPluginIds = resolveProviderContractPluginIdsForProviderAlias(providerFilter);
if (bundledAliasPluginIds) {
return bundledAliasPluginIds;
}
return await resolveWorkspacePluginIdsForProviderAlias({
cfg: params.cfg,
env: params.env,
providerFilter,
});
}
function modelFromProviderCatalog(params: {
@@ -60,6 +121,29 @@ function modelFromProviderCatalog(params: {
} as Model<Api>;
}
async function withStaticCatalogTimeout<T>(
providerId: string,
run: () => T | Promise<T>,
): Promise<T> {
let timer: NodeJS.Timeout | undefined;
const timeout = new Promise<never>((_, reject) => {
timer = setTimeout(() => {
reject(
new Error(
`provider static catalog timed out for ${providerId} after ${STATIC_CATALOG_TIMEOUT_MS}ms`,
),
);
}, STATIC_CATALOG_TIMEOUT_MS);
});
try {
return await Promise.race([Promise.resolve().then(run), timeout]);
} finally {
if (timer) {
clearTimeout(timer);
}
}
}
export async function loadProviderCatalogModelsForList(params: {
cfg: OpenClawConfig;
agentDir: string;
@@ -95,12 +179,14 @@ export async function loadProviderCatalogModelsForList(params: {
}
let result: Awaited<ReturnType<typeof runProviderStaticCatalog>> | null;
try {
result = await runProviderStaticCatalog({
provider,
config: params.cfg,
agentDir: params.agentDir,
env,
});
result = await withStaticCatalogTimeout(provider.id, () =>
runProviderStaticCatalog({
provider,
config: params.cfg,
agentDir: params.agentDir,
env,
}),
);
} catch (error) {
log.warn(`provider static catalog failed for ${provider.id}: ${formatErrorMessage(error)}`);
result = null;