fix(tools): defer media generation provider discovery

This commit is contained in:
Ayaan Zaidi
2026-05-01 08:23:33 +05:30
parent e0fe02fb09
commit 60bdb96f2c
7 changed files with 193 additions and 32 deletions

View File

@@ -297,6 +297,7 @@ describe("createImageGenerateTool", () => {
});
it("infers the canonical OpenAI image model from provider readiness without explicit config", () => {
vi.stubEnv("OPENAI_API_KEY", "openai-test");
const isConfigured = vi.fn(({ agentDir }: { agentDir?: string }) => agentDir === "/tmp/agent");
vi.spyOn(imageGenerationRuntime, "listRuntimeImageGenerationProviders").mockReturnValue([
{

View File

@@ -39,6 +39,7 @@ import { decodeDataUrl } from "./image-tool.helpers.js";
import {
applyImageGenerationModelConfigDefaults,
buildMediaReferenceDetails,
hasGenerationToolAvailability,
isCapabilityProviderConfigured,
normalizeMediaReferenceInputs,
readGenerationTimeoutMs,
@@ -567,16 +568,16 @@ export function createImageGenerateTool(options?: {
fsPolicy?: ToolFsPolicy;
}): AnyAgentTool | null {
const cfg = options?.config ?? getRuntimeConfig();
const imageGenerationModelConfig = resolveImageGenerationModelConfigForTool({
cfg,
agentDir: options?.agentDir,
});
if (!imageGenerationModelConfig) {
if (
!hasGenerationToolAvailability({
cfg,
agentDir: options?.agentDir,
modelConfig: cfg.agents?.defaults?.imageGenerationModel,
providerKey: "imageGenerationProviders",
})
) {
return null;
}
const effectiveCfg =
applyImageGenerationModelConfigDefaults(cfg, imageGenerationModelConfig) ?? cfg;
const remoteMediaSsrfPolicy = resolveRemoteMediaSsrfPolicy(effectiveCfg);
const sandboxConfig =
options?.sandbox && options.sandbox.root.trim()
? {
@@ -596,7 +597,7 @@ export function createImageGenerateTool(options?: {
const params = args as Record<string, unknown>;
const action = resolveAction(params);
if (action === "list") {
const runtimeProviders = listRuntimeImageGenerationProviders({ config: effectiveCfg });
const runtimeProviders = listRuntimeImageGenerationProviders({ config: cfg });
const providers = runtimeProviders.map((provider) =>
Object.assign(
{ id: provider.id },
@@ -607,7 +608,7 @@ export function createImageGenerateTool(options?: {
configured: isCapabilityProviderConfigured({
providers: runtimeProviders,
provider,
cfg: effectiveCfg,
cfg,
agentDir: options?.agentDir,
}),
authEnvVars: getImageGenerationProviderAuthEnvVars(provider.id),
@@ -657,6 +658,16 @@ export function createImageGenerateTool(options?: {
};
}
const imageGenerationModelConfig = resolveImageGenerationModelConfigForTool({
cfg,
agentDir: options?.agentDir,
});
if (!imageGenerationModelConfig) {
throw new ToolInputError("No image-generation model configured.");
}
const effectiveCfg =
applyImageGenerationModelConfigDefaults(cfg, imageGenerationModelConfig) ?? cfg;
const remoteMediaSsrfPolicy = resolveRemoteMediaSsrfPolicy(effectiveCfg);
const prompt = readStringParam(params, "prompt", { required: true });
const imageInputs = normalizeReferenceImages(params);
const model = readStringParam(params, "model");

View File

@@ -4,6 +4,7 @@ import type { OpenClawConfig } from "../../config/types.openclaw.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { getDefaultLocalRoots } from "../../media/web-media.js";
import { readSnakeCaseParamRaw } from "../../param-key.js";
import { resolveBundledCapabilityProviderIds } from "../../plugins/capability-provider-runtime.js";
import {
normalizeOptionalLowercaseString,
normalizeOptionalString,
@@ -131,6 +132,11 @@ type CapabilityProvider = {
isConfigured?: (ctx: { cfg?: OpenClawConfig; agentDir?: string }) => boolean;
};
/**
 * Provider-list keys for the three media-generation capabilities (image, video, music).
 * Accepted as `providerKey` by `hasGenerationToolAvailability`, which forwards it as the
 * `key` argument of `resolveBundledCapabilityProviderIds`.
 */
type GenerationCapabilityProviderKey =
| "imageGenerationProviders"
| "videoGenerationProviders"
| "musicGenerationProviders";
export function findCapabilityProviderById<T extends CapabilityProvider>(params: {
providers: T[];
providerId?: string;
@@ -271,6 +277,21 @@ export function resolveCapabilityModelConfigForTool(params: {
});
}
/**
 * Decides whether a media-generation tool should be offered at all.
 *
 * The check has two stages:
 * 1. If the supplied `modelConfig` (after `coerceToolModelConfig`) counts as a tool model
 *    config according to `hasToolModelConfig`, the tool is available.
 * 2. Otherwise the bundled provider ids for `providerKey` are looked up and the tool is
 *    available as soon as `hasAuthForProvider` reports true for any of them.
 *
 * @param params.cfg - Config forwarded to the bundled provider id lookup.
 * @param params.agentDir - Agent directory forwarded to the per-provider auth check.
 * @param params.modelConfig - Explicit model config for this generation capability, if any.
 * @param params.providerKey - Which generation provider list to inspect.
 * @returns `true` when either stage succeeds, `false` otherwise.
 */
export function hasGenerationToolAvailability(params: {
  cfg?: OpenClawConfig;
  agentDir?: string;
  modelConfig?: AgentModelConfig;
  providerKey: GenerationCapabilityProviderKey;
}): boolean {
  const { cfg, agentDir, modelConfig, providerKey } = params;

  // Stage 1: an explicit model config short-circuits provider discovery entirely.
  const explicitModelConfig = coerceToolModelConfig(modelConfig);
  if (hasToolModelConfig(explicitModelConfig)) {
    return true;
  }

  // Stage 2: stop at the first bundled provider that has auth.
  const bundledProviderIds = resolveBundledCapabilityProviderIds({ key: providerKey, cfg });
  for (const providerId of bundledProviderIds) {
    if (hasAuthForProvider({ provider: providerId, agentDir })) {
      return true;
    }
  }
  return false;
}
function formatQuotedList(values: readonly string[]): string {
if (values.length === 1) {
return `"${values[0]}"`;

View File

@@ -33,6 +33,7 @@ import {
applyMusicGenerationModelConfigDefaults,
buildMediaReferenceDetails,
buildTaskRunDetails,
hasGenerationToolAvailability,
normalizeMediaReferenceInputs,
readBooleanToolParam,
readGenerationTimeoutMs,
@@ -495,11 +496,14 @@ export function createMusicGenerateTool(options?: {
scheduleBackgroundWork?: MusicGenerateBackgroundScheduler;
}): AnyAgentTool | null {
const cfg: OpenClawConfig = options?.config ?? getRuntimeConfig();
const musicGenerationModelConfig = resolveMusicGenerationModelConfigForTool({
cfg,
agentDir: options?.agentDir,
});
if (!musicGenerationModelConfig) {
if (
!hasGenerationToolAvailability({
cfg,
agentDir: options?.agentDir,
modelConfig: cfg.agents?.defaults?.musicGenerationModel,
providerKey: "musicGenerationProviders",
})
) {
return null;
}
@@ -523,17 +527,25 @@ export function createMusicGenerateTool(options?: {
execute: async (_toolCallId, rawArgs) => {
const args = rawArgs as Record<string, unknown>;
const action = resolveAction(args);
const effectiveCfg =
applyMusicGenerationModelConfigDefaults(cfg, musicGenerationModelConfig) ?? cfg;
if (action === "list") {
return createMusicGenerateListActionResult(effectiveCfg);
return createMusicGenerateListActionResult(cfg);
}
if (action === "status") {
return createMusicGenerateStatusActionResult(options?.agentSessionKey);
}
const musicGenerationModelConfig = resolveMusicGenerationModelConfigForTool({
cfg,
agentDir: options?.agentDir,
});
if (!musicGenerationModelConfig) {
throw new ToolInputError("No music-generation model configured.");
}
const effectiveCfg =
applyMusicGenerationModelConfigDefaults(cfg, musicGenerationModelConfig) ?? cfg;
const duplicateGuardResult = createMusicGenerateDuplicateGuardResult(
options?.agentSessionKey,
);

View File

@@ -36,6 +36,7 @@ import {
applyVideoGenerationModelConfigDefaults,
buildMediaReferenceDetails,
buildTaskRunDetails,
hasGenerationToolAvailability,
normalizeMediaReferenceInputs,
readBooleanToolParam,
readGenerationTimeoutMs,
@@ -802,11 +803,14 @@ export function createVideoGenerateTool(options?: {
scheduleBackgroundWork?: VideoGenerateBackgroundScheduler;
}): AnyAgentTool | null {
const cfg: OpenClawConfig = options?.config ?? getRuntimeConfig();
const videoGenerationModelConfig = resolveVideoGenerationModelConfigForTool({
cfg,
agentDir: options?.agentDir,
});
if (!videoGenerationModelConfig) {
if (
!hasGenerationToolAvailability({
cfg,
agentDir: options?.agentDir,
modelConfig: cfg.agents?.defaults?.videoGenerationModel,
providerKey: "videoGenerationProviders",
})
) {
return null;
}
@@ -830,18 +834,26 @@ export function createVideoGenerateTool(options?: {
execute: async (_toolCallId, rawArgs) => {
const args = rawArgs as Record<string, unknown>;
const action = resolveAction(args);
const effectiveCfg =
applyVideoGenerationModelConfigDefaults(cfg, videoGenerationModelConfig) ?? cfg;
const remoteMediaSsrfPolicy = resolveRemoteMediaSsrfPolicy(effectiveCfg);
if (action === "list") {
return createVideoGenerateListActionResult(effectiveCfg);
return createVideoGenerateListActionResult(cfg);
}
if (action === "status") {
return createVideoGenerateStatusActionResult(options?.agentSessionKey);
}
const videoGenerationModelConfig = resolveVideoGenerationModelConfigForTool({
cfg,
agentDir: options?.agentDir,
});
if (!videoGenerationModelConfig) {
throw new ToolInputError("No video-generation model configured.");
}
const effectiveCfg =
applyVideoGenerationModelConfigDefaults(cfg, videoGenerationModelConfig) ?? cfg;
const remoteMediaSsrfPolicy = resolveRemoteMediaSsrfPolicy(effectiveCfg);
const duplicateGuardResult = createVideoGenerateDuplicateGuardResult(
options?.agentSessionKey,
);

View File

@@ -536,6 +536,40 @@ describe("resolvePluginCapabilityProviders", () => {
});
});
it("reuses capability snapshot loads for the same config object", () => {
  const { cfg, enablementCompat } = createCompatChainConfig();

  // Registry handed back by the mocked loader: one "openai" media-understanding provider.
  const snapshotRegistry = createEmptyPluginRegistry();
  snapshotRegistry.mediaUnderstandingProviders.push({
    pluginId: "openai",
    pluginName: "openai",
    source: "test",
    provider: {
      id: "openai",
      capabilities: ["image"],
    },
  } as never);

  setBundledCapabilityFixture("mediaUnderstandingProviders");
  mocks.withBundledPluginEnablementCompat.mockReturnValue(enablementCompat);
  mocks.withBundledPluginVitestCompat.mockReturnValue(enablementCompat);
  // Calls without load options yield nothing; calls with options return the snapshot.
  mocks.resolveRuntimePluginRegistry.mockImplementation((loadParams?: unknown) => {
    if (loadParams === undefined) {
      return undefined;
    }
    return snapshotRegistry;
  });

  // Resolve twice with the very same config object; both must see the provider.
  for (let attempt = 0; attempt < 2; attempt += 1) {
    expectResolvedCapabilityProviderIds(
      resolvePluginCapabilityProviders({ key: "mediaUnderstandingProviders", cfg }),
      ["openai"],
    );
  }

  // Only loader calls that carried options count as snapshot loads; expect exactly one.
  const loadsWithOptions = mocks.resolveRuntimePluginRegistry.mock.calls.filter(
    ([loadOptions]) => loadOptions !== undefined,
  );
  expect(loadsWithOptions).toHaveLength(1);
});
it("resolves manifest-derived capability plugin ids for equivalent config snapshots independently", () => {
const first = createCompatChainConfig();
const second = createCompatChainConfig();

View File

@@ -4,7 +4,11 @@ import {
withBundledPluginEnablementCompat,
withBundledPluginVitestCompat,
} from "./bundled-compat.js";
import { resolveRuntimePluginRegistry, type PluginLoadOptions } from "./loader.js";
import {
resolvePluginRegistryLoadCacheKey,
resolveRuntimePluginRegistry,
type PluginLoadOptions,
} from "./loader.js";
import { loadPluginManifestRegistryForPluginRegistry } from "./plugin-registry.js";
import type { PluginRegistry } from "./registry-types.js";
@@ -30,6 +34,12 @@ type CapabilityContractKey =
type CapabilityProviderForKey<K extends CapabilityProviderRegistryKey> =
PluginRegistry[K][number] extends { provider: infer T } ? T : never;
// Union of the PluginRegistry value types found under any capability provider registry key.
type CapabilityProviderEntries = PluginRegistry[CapabilityProviderRegistryKey];
// Two-level cache of loaded provider entries.
// Outer level: keyed by the config object identity (WeakMap, so an entry does not keep
// its config object alive).
// Inner level: keyed by the string produced by resolveCapabilityProviderSnapshotCacheKey.
const capabilityProviderSnapshotCache = new WeakMap<
OpenClawConfig,
Map<string, CapabilityProviderEntries>
>();
const CAPABILITY_CONTRACT_KEY: Record<CapabilityProviderRegistryKey, CapabilityContractKey> = {
memoryEmbeddingProviders: "memoryEmbeddingProviders",
@@ -64,6 +74,25 @@ function resolveBundledCapabilityCompatPluginIds(params: {
.toSorted((left, right) => left.localeCompare(right));
}
/**
 * Lists the provider ids that bundled plugins declare for one capability.
 *
 * Reads the plugin manifest registry (with `includeDisabled: true`), keeps only plugins whose
 * `origin` is `"bundled"`, and collects the ids listed under the contract key that
 * `CAPABILITY_CONTRACT_KEY` maps `params.key` to.
 *
 * @param params.key - Capability provider registry key to look up.
 * @param params.cfg - Optional config passed through to the manifest registry loader.
 * @returns De-duplicated provider ids, sorted with `localeCompare`.
 */
export function resolveBundledCapabilityProviderIds(params: {
  key: CapabilityProviderRegistryKey;
  cfg?: OpenClawConfig;
}): string[] {
  const contractKey = CAPABILITY_CONTRACT_KEY[params.key];
  const manifestRegistry = loadPluginManifestRegistryForPluginRegistry({
    config: params.cfg,
    env: process.env,
    includeDisabled: true,
  });

  const uniqueProviderIds = new Set<string>();
  for (const plugin of manifestRegistry.plugins) {
    if (plugin.origin !== "bundled") {
      continue;
    }
    // A plugin without contracts for this capability contributes nothing.
    for (const providerId of plugin.contracts?.[contractKey] ?? []) {
      uniqueProviderIds.add(providerId);
    }
  }

  return Array.from(uniqueProviderIds).toSorted((left, right) => left.localeCompare(right));
}
function resolveCapabilityProviderConfig(params: {
key: CapabilityProviderRegistryKey;
cfg?: OpenClawConfig;
@@ -101,6 +130,30 @@ function createCapabilityProviderFallbackLoadOptions(params: {
return loadOptions;
}
/**
 * Returns the per-config snapshot cache map, creating and registering it on first use.
 *
 * @param cfg - Config object used as the cache identity.
 * @returns The map associated with `cfg`, or `undefined` when no config was given
 *   (in which case callers get no caching).
 */
function resolveCapabilityProviderSnapshotCache(
  cfg: OpenClawConfig | undefined,
): Map<string, CapabilityProviderEntries> | undefined {
  if (!cfg) {
    return undefined;
  }

  const existing = capabilityProviderSnapshotCache.get(cfg);
  if (existing) {
    return existing;
  }

  const created = new Map<string, CapabilityProviderEntries>();
  capabilityProviderSnapshotCache.set(cfg, created);
  return created;
}
/**
 * Builds the inner cache key for a capability snapshot.
 *
 * The key is the JSON serialization of the capability registry key together with the
 * value returned by `resolvePluginRegistryLoadCacheKey` for the given load options.
 *
 * @param params.key - Capability provider registry key being resolved.
 * @param params.loadOptions - Plugin load options the snapshot would be loaded with.
 * @returns A string usable as a `Map` key.
 */
function resolveCapabilityProviderSnapshotCacheKey(params: {
  key: CapabilityProviderRegistryKey;
  loadOptions: PluginLoadOptions;
}): string {
  const { key, loadOptions } = params;
  const load = resolvePluginRegistryLoadCacheKey(loadOptions);
  return JSON.stringify({ key, load });
}
function findProviderById<K extends CapabilityProviderRegistryKey>(
entries: PluginRegistry[K],
providerId: string,
@@ -246,8 +299,17 @@ export function resolvePluginCapabilityProvider<K extends CapabilityProviderRegi
pluginIds,
installBundledRuntimeDeps: params.installBundledRuntimeDeps,
});
const registry = resolveRuntimePluginRegistry(loadOptions);
return findProviderById(registry?.[params.key] ?? [], params.providerId);
const cache = resolveCapabilityProviderSnapshotCache(params.cfg);
const cacheKey = cache
? resolveCapabilityProviderSnapshotCacheKey({ key: params.key, loadOptions })
: "";
let loadedProviders = cache?.get(cacheKey) as PluginRegistry[K] | undefined;
if (!loadedProviders) {
const registry = resolveRuntimePluginRegistry(loadOptions);
loadedProviders = registry?.[params.key] ?? [];
cache?.set(cacheKey, loadedProviders as CapabilityProviderEntries);
}
return findProviderById(loadedProviders, params.providerId);
}
export function resolvePluginCapabilityProviders<K extends CapabilityProviderRegistryKey>(params: {
@@ -291,8 +353,16 @@ export function resolvePluginCapabilityProviders<K extends CapabilityProviderReg
pluginIds,
installBundledRuntimeDeps: params.installBundledRuntimeDeps,
});
const registry = resolveRuntimePluginRegistry(loadOptions);
const loadedProviders = registry?.[params.key] ?? [];
const cache = resolveCapabilityProviderSnapshotCache(params.cfg);
const cacheKey = cache
? resolveCapabilityProviderSnapshotCacheKey({ key: params.key, loadOptions })
: "";
let loadedProviders = cache?.get(cacheKey) as PluginRegistry[K] | undefined;
if (!loadedProviders) {
const registry = resolveRuntimePluginRegistry(loadOptions);
loadedProviders = registry?.[params.key] ?? [];
cache?.set(cacheKey, loadedProviders as CapabilityProviderEntries);
}
if (params.key !== "memoryEmbeddingProviders") {
const mergeLoadedProviders =
activeProviders.length > 0