fix: fail closed talk provider selection

This commit is contained in:
Peter Steinberger
2026-03-08 16:03:16 +00:00
parent ca5e352c53
commit b7ad8fd661
8 changed files with 245 additions and 36 deletions

View File

@@ -99,10 +99,18 @@ class TalkModeManager(
val providerConfig = value.asObjectOrNull() ?: return@mapNotNull null
providerId to providerConfig
}?.toMap().orEmpty()
val providerId =
normalizeTalkProviderId(rawProvider)
?: providers.keys.sorted().firstOrNull()
?: defaultTalkProvider
val explicitProviderId = normalizeTalkProviderId(rawProvider)
if (explicitProviderId != null) {
if (providers.isNotEmpty() && providers[explicitProviderId] == null) {
return null
}
return TalkProviderConfigSelection(
provider = explicitProviderId,
config = providers[explicitProviderId] ?: buildJsonObject {},
normalizedPayload = true,
)
}
val providerId = providers.keys.singleOrNull() ?: return null
return TalkProviderConfigSelection(
provider = providerId,
config = providers[providerId] ?: buildJsonObject {},

View File

@@ -68,6 +68,50 @@ class TalkModeConfigParsingTest {
assertEquals("voice-normalized", selection?.config?.get("voiceId")?.jsonPrimitive?.content)
}
/**
 * Fail-closed case: the payload names an explicit `provider` ("acme") while
 * `providers` only contains an "elevenlabs" entry. The test expects
 * [TalkModeManager.selectTalkProviderConfig] to return null for this payload.
 */
@Test
fun rejectsNormalizedTalkProviderPayloadWhenProviderMissingFromProviders() {
// Fixture: "provider" points at an id that has no entry under "providers".
val talk =
json.parseToJsonElement(
"""
{
"provider": "acme",
"providers": {
"elevenlabs": {
"voiceId": "voice-normalized"
}
}
}
""".trimIndent(),
)
.jsonObject
val selection = TalkModeManager.selectTalkProviderConfig(talk)
// No selection is expected when the explicit provider is not configured.
assertEquals(null, selection)
}
/**
 * Fail-closed case: the payload has no `provider` field and `providers` contains
 * two entries ("acme" and "elevenlabs"). The test expects
 * [TalkModeManager.selectTalkProviderConfig] to return null for this payload.
 */
@Test
fun rejectsNormalizedTalkProviderPayloadWhenProviderIsAmbiguous() {
// Fixture: two configured providers and no explicit "provider" key.
val talk =
json.parseToJsonElement(
"""
{
"providers": {
"acme": {
"voiceId": "voice-acme"
},
"elevenlabs": {
"voiceId": "voice-normalized"
}
}
}
""".trimIndent(),
)
.jsonObject
val selection = TalkModeManager.selectTalkProviderConfig(talk)
// No selection is expected when the choice between providers is ambiguous.
assertEquals(null, selection)
}
@Test
fun fallsBackToLegacyTalkFieldsWhenNormalizedPayloadMissing() {
val legacyApiKey = "legacy-key" // pragma: allowlist secret

View File

@@ -31,10 +31,19 @@ public enum TalkConfigParsing {
let hasNormalizedPayload = rawProvider != nil || rawProviders != nil
if hasNormalizedPayload {
let normalizedProviders = self.normalizedTalkProviders(rawProviders)
let providerID =
self.normalizedTalkProviderID(rawProvider) ??
normalizedProviders.keys.min() ??
defaultProvider
let explicitProviderID = self.normalizedTalkProviderID(rawProvider)
if let explicitProviderID {
if !normalizedProviders.isEmpty, normalizedProviders[explicitProviderID] == nil {
return nil
}
return TalkProviderConfigSelection(
provider: explicitProviderID,
config: normalizedProviders[explicitProviderID] ?? [:],
normalizedPayload: true)
}
guard normalizedProviders.count == 1, let providerID = normalizedProviders.keys.first else {
return nil
}
return TalkProviderConfigSelection(
provider: providerID,
config: normalizedProviders[providerID] ?? [:],

View File

@@ -66,6 +66,36 @@ struct TalkConfigParsingTests {
#expect(selection == nil)
}
/// An explicit `provider` ("acme") with no matching entry under `providers`
/// is expected to produce no selection at all.
@Test func rejectsNormalizedPayloadWhenProviderMissingFromProviders() {
    let payload: [String: AnyCodable] = [
        "provider": AnyCodable("acme"),
        "providers": AnyCodable(["elevenlabs": ["voiceId": "voice-normalized"]]),
    ]
    #expect(TalkConfigParsing.selectProviderConfig(payload, defaultProvider: "elevenlabs") == nil)
}
/// Two configured providers and no explicit `provider` key is expected to
/// produce no selection, even though a default provider is supplied.
@Test func rejectsNormalizedPayloadWhenMultipleProvidersAndNoProvider() {
    let payload: [String: AnyCodable] = [
        "providers": AnyCodable([
            "acme": ["voiceId": "voice-acme"],
            "elevenlabs": ["voiceId": "voice-eleven"],
        ]),
    ]
    #expect(TalkConfigParsing.selectProviderConfig(payload, defaultProvider: "elevenlabs") == nil)
}
@Test func bridgesFoundationDictionary() {
let raw: [String: Any] = [
"provider": "elevenlabs",

View File

@@ -37,4 +37,68 @@ describe("talk config validation fail-closed behavior", () => {
},
);
});
// talk.provider names "acme" while talk.providers only defines "elevenlabs";
// loading such a config is expected to throw an INVALID_CONFIG error and log it.
it("rejects talk.provider when it does not match talk.providers during config load", async () => {
  await withTempHomeConfig(
    {
      agents: { list: [{ id: "main" }] },
      talk: {
        provider: "acme",
        providers: {
          elevenlabs: {
            voiceId: "voice-123",
          },
        },
      },
    },
    async () => {
      // Silence console.error for the duration of the load, but keep the calls observable.
      const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {});
      try {
        let thrown: unknown;
        try {
          loadConfig();
        } catch (error) {
          thrown = error;
        }
        expect(thrown).toBeInstanceOf(Error);
        expect((thrown as { code?: string } | undefined)?.code).toBe("INVALID_CONFIG");
        expect((thrown as Error).message).toMatch(/talk\.provider|talk\.providers|acme/i);
        expect(consoleSpy).toHaveBeenCalled();
      } finally {
        // Restore the real console.error so the mock and its call history
        // cannot leak into later tests, even when an assertion above fails.
        consoleSpy.mockRestore();
      }
    },
  );
});
// Two providers are configured and talk.provider is absent;
// loading such a config is expected to throw an INVALID_CONFIG error and log it.
it("rejects multi-provider talk config without talk.provider during config load", async () => {
  await withTempHomeConfig(
    {
      agents: { list: [{ id: "main" }] },
      talk: {
        providers: {
          acme: {
            voiceId: "voice-acme",
          },
          elevenlabs: {
            voiceId: "voice-eleven",
          },
        },
      },
    },
    async () => {
      // Silence console.error for the duration of the load, but keep the calls observable.
      const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {});
      try {
        let thrown: unknown;
        try {
          loadConfig();
        } catch (error) {
          thrown = error;
        }
        expect(thrown).toBeInstanceOf(Error);
        expect((thrown as { code?: string } | undefined)?.code).toBe("INVALID_CONFIG");
        expect((thrown as Error).message).toMatch(/talk\.provider|required/i);
        expect(consoleSpy).toHaveBeenCalled();
      } finally {
        // Restore the real console.error so the mock and its call history
        // cannot leak into later tests, even when an assertion above fails.
        consoleSpy.mockRestore();
      }
    },
  );
});
});

View File

@@ -158,10 +158,14 @@ function legacyProviderConfigFromTalk(
/**
 * Picks the active talk provider id from a talk config, failing closed.
 *
 * - When `talk.provider` (after `normalizeString`) is set and `talk.providers`
 *   exists, the provider must be one of the keys of `talk.providers`;
 *   otherwise `undefined` is returned.
 * - When `talk.provider` is set and `talk.providers` is absent, the provider is
 *   returned as-is.
 * - When `talk.provider` is not set, the single key of `talk.providers` is
 *   returned; zero or several keys yield `undefined`.
 *
 * @param talk Talk configuration to inspect.
 * @returns The selected provider id, or `undefined` when no unambiguous,
 *          configured provider can be determined.
 */
function activeProviderFromTalk(talk: TalkConfig): string | undefined {
  const provider = normalizeString(talk.provider);
  const providers = talk.providers;
  // Own enumerable keys only, computed once and shared by both branches.
  const providerIds = providers ? Object.keys(providers) : [];
  if (provider) {
    // Membership is checked against own keys: the `in` operator would also
    // accept inherited names such as "constructor" or "__proto__".
    if (providers && !providerIds.includes(provider)) {
      return undefined;
    }
    return provider;
  }
  return providerIds.length === 1 ? providerIds[0] : undefined;
}

View File

@@ -25,4 +25,36 @@ describe("OpenClawSchema talk validation", () => {
}),
).toThrow(/silenceTimeoutMs|number|integer/i);
});
// Schema-level check: talk.provider ("acme") has no entry in talk.providers.
it("rejects talk.provider when it does not match talk.providers", () => {
  const parseMismatchedTalk = () =>
    OpenClawSchema.parse({
      talk: {
        provider: "acme",
        providers: { elevenlabs: { voiceId: "voice-123" } },
      },
    });

  expect(parseMismatchedTalk).toThrow(/talk\.provider|talk\.providers|missing "acme"/i);
});
// Schema-level check: several providers are configured but talk.provider is absent.
it("rejects multi-provider talk config without talk.provider", () => {
  const parseAmbiguousTalk = () =>
    OpenClawSchema.parse({
      talk: {
        providers: {
          acme: { voiceId: "voice-acme" },
          elevenlabs: { voiceId: "voice-eleven" },
        },
      },
    });

  expect(parseAmbiguousTalk).toThrow(/talk\.provider|required/i);
});
});

View File

@@ -159,6 +159,50 @@ const PluginEntrySchema = z
})
.strict();
// Schema for a single entry under `talk.providers`.
// All named fields are optional; any additional keys are accepted as unknown
// values via `catchall`, so provider-specific settings are not rejected here.
const TalkProviderEntrySchema = z
.object({
voiceId: z.string().optional(),
voiceAliases: z.record(z.string(), z.string()).optional(),
modelId: z.string().optional(),
outputFormat: z.string().optional(),
// Registered in the `sensitive` registry.
apiKey: SecretInputSchema.optional().register(sensitive),
})
.catchall(z.unknown());
// Schema for the top-level `talk` config section.
// Unknown keys are rejected (`strict`). On top of the per-field checks, the
// refinement enforces fail-closed provider selection:
//   1. a set `talk.provider` must be one of the keys of a non-empty `talk.providers`;
//   2. `talk.provider` is required as soon as `talk.providers` has more than one key.
const TalkSchema = z
  .object({
    provider: z.string().optional(),
    providers: z.record(z.string(), TalkProviderEntrySchema).optional(),
    voiceId: z.string().optional(),
    voiceAliases: z.record(z.string(), z.string()).optional(),
    modelId: z.string().optional(),
    outputFormat: z.string().optional(),
    apiKey: SecretInputSchema.optional().register(sensitive),
    interruptOnSpeech: z.boolean().optional(),
    silenceTimeoutMs: z.number().int().positive().optional(),
  })
  .strict()
  .superRefine((talk, ctx) => {
    // The provider id is compared in trimmed, lower-cased form.
    const provider = talk.provider?.trim().toLowerCase();
    const providerIds = talk.providers ? Object.keys(talk.providers) : [];
    // Membership is checked against own keys only: the `in` operator would also
    // accept inherited names such as "constructor", letting an unconfigured
    // provider slip through validation.
    if (provider && providerIds.length > 0 && !providerIds.includes(provider)) {
      ctx.addIssue({
        code: z.ZodIssueCode.custom,
        path: ["provider"],
        message: `talk.provider must match a key in talk.providers (missing "${provider}")`,
      });
    }
    if (!provider && providerIds.length > 1) {
      ctx.addIssue({
        code: z.ZodIssueCode.custom,
        path: ["provider"],
        message: "talk.provider is required when talk.providers defines multiple providers",
      });
    }
  });
export const OpenClawSchema = z
.object({
$schema: z.string().optional(),
@@ -572,33 +616,7 @@ export const OpenClawSchema = z
})
.strict()
.optional(),
talk: z
.object({
provider: z.string().optional(),
providers: z
.record(
z.string(),
z
.object({
voiceId: z.string().optional(),
voiceAliases: z.record(z.string(), z.string()).optional(),
modelId: z.string().optional(),
outputFormat: z.string().optional(),
apiKey: SecretInputSchema.optional().register(sensitive),
})
.catchall(z.unknown()),
)
.optional(),
voiceId: z.string().optional(),
voiceAliases: z.record(z.string(), z.string()).optional(),
modelId: z.string().optional(),
outputFormat: z.string().optional(),
apiKey: SecretInputSchema.optional().register(sensitive),
interruptOnSpeech: z.boolean().optional(),
silenceTimeoutMs: z.number().int().positive().optional(),
})
.strict()
.optional(),
talk: TalkSchema.optional(),
gateway: z
.object({
port: z.number().int().positive().optional(),