Files
openclaw/src/agents/tools/media-tool-shared.ts

560 lines
16 KiB
TypeScript

import { type Api, type Model } from "@mariozechner/pi-ai";
import type { AgentModelConfig } from "../../config/types.agents-shared.js";
import type { OpenClawConfig } from "../../config/types.openclaw.js";
import type { SsrFPolicy } from "../../infra/net/ssrf.js";
import { getDefaultLocalRoots } from "../../media/web-media.js";
import { readSnakeCaseParamRaw } from "../../param-key.js";
import { loadCapabilityManifestSnapshot } from "../../plugins/capability-provider-runtime.js";
import { listAvailableManifestContractValues } from "../../plugins/manifest-contract-eligibility.js";
import {
normalizeOptionalLowercaseString,
normalizeOptionalString,
} from "../../shared/string-coerce.js";
import type { AuthProfileStore } from "../auth-profiles/types.js";
import { normalizeModelRef } from "../model-selection.js";
import { normalizeProviderId } from "../provider-id.js";
import {
ToolInputError,
readNumberParam,
readStringArrayParam,
readStringParam,
} from "./common.js";
import type { ImageModelConfig } from "./image-tool.helpers.js";
import {
getCurrentCapabilityMetadataSnapshot,
hasSnapshotCapabilityAvailability,
} from "./manifest-capability-availability.js";
import {
buildToolModelConfigFromCandidates,
coerceToolModelConfig,
hasAuthForProvider,
hasToolModelConfig,
resolveDefaultModelRef,
type ToolModelConfig,
} from "./model-config.helpers.js";
import { getApiKeyForModel, normalizeWorkspaceDir, requireApiKey } from "./tool-runtime.helpers.js";
type TextToolAttempt = {
provider: string;
model: string;
error: string;
};
type TextToolResult = {
text: string;
provider: string;
model: string;
attempts: TextToolAttempt[];
};
type GenerationModelRef = {
provider: string;
model: string;
};
type ParseGenerationModelRef = (raw: string | undefined) => GenerationModelRef | null;
type MediaReferenceDetailEntry = {
rewrittenFrom?: string;
};
type TaskRunDetailHandle = {
taskId: string;
runId: string;
};
export function applyImageModelConfigDefaults(
cfg: OpenClawConfig | undefined,
imageModelConfig: ImageModelConfig,
): OpenClawConfig | undefined {
return applyAgentDefaultModelConfig(cfg, "imageModel", imageModelConfig);
}
export function applyImageGenerationModelConfigDefaults(
cfg: OpenClawConfig | undefined,
imageGenerationModelConfig: ToolModelConfig,
): OpenClawConfig | undefined {
return applyAgentDefaultModelConfig(cfg, "imageGenerationModel", imageGenerationModelConfig);
}
export function applyVideoGenerationModelConfigDefaults(
cfg: OpenClawConfig | undefined,
videoGenerationModelConfig: ToolModelConfig,
): OpenClawConfig | undefined {
return applyAgentDefaultModelConfig(cfg, "videoGenerationModel", videoGenerationModelConfig);
}
export function applyMusicGenerationModelConfigDefaults(
cfg: OpenClawConfig | undefined,
musicGenerationModelConfig: ToolModelConfig,
): OpenClawConfig | undefined {
return applyAgentDefaultModelConfig(cfg, "musicGenerationModel", musicGenerationModelConfig);
}
export function readGenerationTimeoutMs(args: Record<string, unknown>): number | undefined {
const timeoutMs = readNumberParam(args, "timeoutMs", {
integer: true,
strict: true,
});
if (timeoutMs === undefined) {
return undefined;
}
if (timeoutMs <= 0) {
throw new ToolInputError("timeoutMs must be a positive integer in milliseconds.");
}
return timeoutMs;
}
export function resolveRemoteMediaSsrfPolicy(
cfg: OpenClawConfig | undefined,
): SsrFPolicy | undefined {
return cfg?.tools?.web?.fetch?.ssrfPolicy;
}
function applyAgentDefaultModelConfig(
cfg: OpenClawConfig | undefined,
key: "imageModel" | "imageGenerationModel" | "videoGenerationModel" | "musicGenerationModel",
modelConfig: ToolModelConfig,
): OpenClawConfig | undefined {
if (!cfg) {
return undefined;
}
return {
...cfg,
agents: {
...cfg.agents,
defaults: {
...cfg.agents?.defaults,
[key]: modelConfig,
},
},
};
}
type CapabilityProvider = {
id: string;
aliases?: string[];
defaultModel?: string;
isConfigured?: (ctx: { cfg?: OpenClawConfig; agentDir?: string }) => boolean;
};
type GenerationCapabilityProviderKey =
| "imageGenerationProviders"
| "videoGenerationProviders"
| "musicGenerationProviders";
export function findCapabilityProviderById<T extends CapabilityProvider>(params: {
providers: T[];
providerId?: string;
}): T | undefined {
const selectedProvider = normalizeProviderId(params.providerId ?? "");
return params.providers.find(
(provider) =>
normalizeProviderId(provider.id) === selectedProvider ||
(provider.aliases ?? []).some((alias) => normalizeProviderId(alias) === selectedProvider),
);
}
export function isCapabilityProviderConfigured<T extends CapabilityProvider>(params: {
providers: T[];
provider?: T;
providerId?: string;
cfg?: OpenClawConfig;
agentDir?: string;
authStore?: AuthProfileStore;
}): boolean {
const provider =
params.provider ??
findCapabilityProviderById({
providers: params.providers,
providerId: params.providerId,
});
if (!provider) {
return params.providerId
? hasAuthForProvider({
provider: params.providerId,
agentDir: params.agentDir,
authStore: params.authStore,
})
: false;
}
if (provider.isConfigured) {
return provider.isConfigured({
cfg: params.cfg,
agentDir: params.agentDir,
});
}
return hasAuthForProvider({
provider: provider.id,
agentDir: params.agentDir,
authStore: params.authStore,
});
}
export function resolveSelectedCapabilityProvider<T extends CapabilityProvider>(params: {
providers: T[];
modelConfig: ToolModelConfig;
modelOverride?: string;
parseModelRef: ParseGenerationModelRef;
}): T | undefined {
const selectedRef =
params.parseModelRef(params.modelOverride) ?? params.parseModelRef(params.modelConfig.primary);
if (!selectedRef) {
return undefined;
}
return findCapabilityProviderById({
providers: params.providers,
providerId: selectedRef.provider,
});
}
export function resolveCapabilityModelCandidatesForTool(params: {
cfg?: OpenClawConfig;
agentDir?: string;
authStore?: AuthProfileStore;
providers: CapabilityProvider[];
}): string[] {
const providerDefaults = new Map<string, { ref: string; aliases: string[] }>();
for (const provider of params.providers) {
const providerId = provider.id.trim();
const modelId = provider.defaultModel?.trim();
if (
!providerId ||
!modelId ||
providerDefaults.has(providerId) ||
!isCapabilityProviderConfigured({
providers: params.providers,
provider,
cfg: params.cfg,
agentDir: params.agentDir,
authStore: params.authStore,
})
) {
continue;
}
const aliases = (provider.aliases ?? []).flatMap((alias) => {
const normalized = normalizeProviderId(alias);
return normalized ? [normalized] : [];
});
providerDefaults.set(providerId, { ref: `${providerId}/${modelId}`, aliases });
}
const primaryProvider = resolveDefaultModelRef(params.cfg).provider;
const normalizedPrimaryProvider = normalizeProviderId(primaryProvider);
const providerIds = [...providerDefaults.keys()].toSorted();
const matchesPrimaryProvider = (providerId: string): boolean => {
const entry = providerDefaults.get(providerId);
return (
normalizeProviderId(providerId) === normalizedPrimaryProvider ||
(entry?.aliases ?? []).includes(normalizedPrimaryProvider)
);
};
const orderedProviders = [
...providerIds.filter(matchesPrimaryProvider),
...providerIds.filter((providerId) => !matchesPrimaryProvider(providerId)),
];
const orderedRefs: string[] = [];
const seen = new Set<string>();
for (const providerId of orderedProviders) {
const entry = providerDefaults.get(providerId);
if (!entry || seen.has(entry.ref)) {
continue;
}
seen.add(entry.ref);
orderedRefs.push(entry.ref);
}
return orderedRefs;
}
export function resolveCapabilityModelConfigForTool(params: {
cfg?: OpenClawConfig;
agentDir?: string;
authStore?: AuthProfileStore;
modelConfig?: AgentModelConfig;
providers: CapabilityProvider[];
}): ToolModelConfig | null {
const explicit = coerceToolModelConfig(params.modelConfig);
if (hasToolModelConfig(explicit)) {
return explicit;
}
return buildToolModelConfigFromCandidates({
explicit,
agentDir: params.agentDir,
authStore: params.authStore,
candidates: resolveCapabilityModelCandidatesForTool({
cfg: params.cfg,
agentDir: params.agentDir,
authStore: params.authStore,
providers: params.providers,
}),
isProviderConfigured: (providerId) =>
isCapabilityProviderConfigured({
providers: params.providers,
providerId,
cfg: params.cfg,
agentDir: params.agentDir,
authStore: params.authStore,
}),
});
}
export function hasGenerationToolAvailability(params: {
cfg?: OpenClawConfig;
agentDir?: string;
workspaceDir?: string;
authStore?: AuthProfileStore;
modelConfig?: AgentModelConfig;
providers?: CapabilityProvider[] | (() => CapabilityProvider[]);
providerKey: GenerationCapabilityProviderKey;
}): boolean {
if (hasToolModelConfig(coerceToolModelConfig(params.modelConfig))) {
return true;
}
const providers = typeof params.providers === "function" ? params.providers() : params.providers;
if (providers) {
return providers.some((provider) =>
isCapabilityProviderConfigured({
providers,
provider,
cfg: params.cfg,
agentDir: params.agentDir,
authStore: params.authStore,
}),
);
}
const snapshot =
getCurrentCapabilityMetadataSnapshot({
config: params.cfg,
workspaceDir: params.workspaceDir,
}) ??
loadCapabilityManifestSnapshot({
cfg: params.cfg,
workspaceDir: params.workspaceDir,
});
if (
hasSnapshotCapabilityAvailability({
snapshot,
key: params.providerKey,
config: params.cfg,
authStore: params.authStore,
})
) {
return true;
}
return listAvailableManifestContractValues({
snapshot,
contract: params.providerKey,
config: params.cfg,
}).some((providerId) =>
hasAuthForProvider({
provider: providerId,
agentDir: params.agentDir,
authStore: params.authStore,
}),
);
}
function formatQuotedList(values: readonly string[]): string {
if (values.length === 1) {
return `"${values[0]}"`;
}
if (values.length === 2) {
return `"${values[0]}" or "${values[1]}"`;
}
return `${values
.slice(0, -1)
.map((value) => `"${value}"`)
.join(", ")}, or "${values[values.length - 1]}"`;
}
export function resolveGenerateAction<TAction extends string>(params: {
args: Record<string, unknown>;
allowed: readonly TAction[];
defaultAction: TAction;
}): TAction {
const raw = readStringParam(params.args, "action");
if (!raw) {
return params.defaultAction;
}
const normalized = normalizeOptionalLowercaseString(raw);
if (normalized && (params.allowed as readonly string[]).includes(normalized)) {
return normalized as TAction;
}
throw new ToolInputError(`action must be ${formatQuotedList(params.allowed)}`);
}
export function readBooleanToolParam(
params: Record<string, unknown>,
key: string,
): boolean | undefined {
const raw = readSnakeCaseParamRaw(params, key);
if (typeof raw === "boolean") {
return raw;
}
if (typeof raw === "string") {
const normalized = normalizeOptionalLowercaseString(raw);
if (normalized === "true") {
return true;
}
if (normalized === "false") {
return false;
}
}
return undefined;
}
export function normalizeMediaReferenceInputs(params: {
args: Record<string, unknown>;
singularKey: string;
pluralKey: string;
maxCount: number;
label: string;
}): string[] {
const single = readStringParam(params.args, params.singularKey);
const multiple = readStringArrayParam(params.args, params.pluralKey);
const combined = [...(single ? [single] : []), ...(multiple ?? [])];
const deduped: string[] = [];
const seen = new Set<string>();
for (const candidate of combined) {
const trimmed = candidate.trim();
const dedupe = trimmed.startsWith("@") ? trimmed.slice(1).trim() : trimmed;
if (!dedupe || seen.has(dedupe)) {
continue;
}
seen.add(dedupe);
deduped.push(trimmed);
}
if (deduped.length > params.maxCount) {
throw new ToolInputError(
`Too many ${params.label}: ${deduped.length} provided, maximum is ${params.maxCount}.`,
);
}
return deduped;
}
export function buildMediaReferenceDetails<T extends MediaReferenceDetailEntry>(params: {
entries: readonly T[];
singleKey: string;
pluralKey: string;
getResolvedInput: (entry: T) => string | undefined;
singleRewriteKey?: string;
}): Record<string, unknown> {
if (params.entries.length === 1) {
const entry = params.entries[0];
if (!entry) {
return {};
}
const rewriteKey = params.singleRewriteKey ?? "rewrittenFrom";
return {
[params.singleKey]: params.getResolvedInput(entry),
...(entry.rewrittenFrom ? { [rewriteKey]: entry.rewrittenFrom } : {}),
};
}
if (params.entries.length > 1) {
return {
[params.pluralKey]: params.entries.map((entry) => ({
[params.singleKey]: params.getResolvedInput(entry),
...(entry.rewrittenFrom ? { rewrittenFrom: entry.rewrittenFrom } : {}),
})),
};
}
return {};
}
export function buildTaskRunDetails(
handle: TaskRunDetailHandle | null | undefined,
): Record<string, unknown> {
return handle
? {
task: {
taskId: handle.taskId,
runId: handle.runId,
},
}
: {};
}
export function resolveMediaToolLocalRoots(
workspaceDirRaw: string | undefined,
options?: { workspaceOnly?: boolean },
_mediaSources?: readonly string[],
): string[] {
const workspaceDir = normalizeWorkspaceDir(workspaceDirRaw);
if (options?.workspaceOnly) {
return workspaceDir ? [workspaceDir] : [];
}
const roots = getDefaultLocalRoots();
return workspaceDir ? Array.from(new Set([...roots, workspaceDir])) : [...roots];
}
export function resolvePromptAndModelOverride(
args: Record<string, unknown>,
defaultPrompt: string,
): {
prompt: string;
modelOverride?: string;
} {
const prompt = normalizeOptionalString(args.prompt) ?? defaultPrompt;
const modelOverride = normalizeOptionalString(args.model);
return { prompt, modelOverride };
}
export function buildTextToolResult(
result: TextToolResult,
extraDetails: Record<string, unknown>,
): {
content: Array<{ type: "text"; text: string }>;
details: Record<string, unknown>;
} {
return {
content: [{ type: "text", text: result.text }],
details: {
model: `${result.provider}/${result.model}`,
...extraDetails,
attempts: result.attempts,
},
};
}
export function resolveModelFromRegistry(params: {
modelRegistry: { find: (provider: string, modelId: string) => unknown };
provider: string;
modelId: string;
}): Model<Api> {
const resolvedRef = normalizeModelRef(params.provider, params.modelId, {
allowPluginNormalization: false,
});
let model = params.modelRegistry.find(
resolvedRef.provider,
resolvedRef.model,
) as Model<Api> | null;
if (!model && !resolvedRef.model.includes("/")) {
model = params.modelRegistry.find(
resolvedRef.provider,
`${resolvedRef.provider}/${resolvedRef.model}`,
) as Model<Api> | null;
}
if (!model) {
throw new Error(`Unknown model: ${resolvedRef.provider}/${resolvedRef.model}`);
}
return model;
}
export async function resolveModelRuntimeApiKey(params: {
model: Model<Api>;
cfg: OpenClawConfig | undefined;
agentDir: string;
authStorage: {
setRuntimeApiKey: (provider: string, apiKey: string) => void;
};
}): Promise<string> {
const apiKeyInfo = await getApiKeyForModel({
model: params.model,
cfg: params.cfg,
agentDir: params.agentDir,
});
const apiKey = requireApiKey(apiKeyInfo, params.model.provider);
params.authStorage.setRuntimeApiKey(params.model.provider, apiKey);
return apiKey;
}