From b7ad8fd6613d73d1a7df076dbe04ffa585aec929 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sun, 8 Mar 2026 16:03:16 +0000 Subject: [PATCH] fix: fail closed talk provider selection --- .../ai/openclaw/app/voice/TalkModeManager.kt | 16 +++-- .../app/voice/TalkModeConfigParsingTest.kt | 44 ++++++++++++ .../OpenClawKit/TalkConfigParsing.swift | 17 +++-- .../TalkConfigParsingTests.swift | 30 ++++++++ src/config/config.talk-validation.test.ts | 64 +++++++++++++++++ src/config/talk.ts | 6 +- src/config/zod-schema.talk.test.ts | 32 +++++++++ src/config/zod-schema.ts | 72 ++++++++++++------- 8 files changed, 245 insertions(+), 36 deletions(-) diff --git a/apps/android/app/src/main/java/ai/openclaw/app/voice/TalkModeManager.kt b/apps/android/app/src/main/java/ai/openclaw/app/voice/TalkModeManager.kt index e63d012eb0a..98823e0b216 100644 --- a/apps/android/app/src/main/java/ai/openclaw/app/voice/TalkModeManager.kt +++ b/apps/android/app/src/main/java/ai/openclaw/app/voice/TalkModeManager.kt @@ -99,10 +99,18 @@ class TalkModeManager( val providerConfig = value.asObjectOrNull() ?: return@mapNotNull null providerId to providerConfig }?.toMap().orEmpty() - val providerId = - normalizeTalkProviderId(rawProvider) - ?: providers.keys.sorted().firstOrNull() - ?: defaultTalkProvider + val explicitProviderId = normalizeTalkProviderId(rawProvider) + if (explicitProviderId != null) { + if (providers.isNotEmpty() && providers[explicitProviderId] == null) { + return null + } + return TalkProviderConfigSelection( + provider = explicitProviderId, + config = providers[explicitProviderId] ?: buildJsonObject {}, + normalizedPayload = true, + ) + } + val providerId = providers.keys.singleOrNull() ?: return null return TalkProviderConfigSelection( provider = providerId, config = providers[providerId] ?: buildJsonObject {}, diff --git a/apps/android/app/src/test/java/ai/openclaw/app/voice/TalkModeConfigParsingTest.kt b/apps/android/app/src/test/java/ai/openclaw/app/voice/TalkModeConfigParsingTest.kt index 9188436a183..04fae5be2c7 100644 --- a/apps/android/app/src/test/java/ai/openclaw/app/voice/TalkModeConfigParsingTest.kt +++ b/apps/android/app/src/test/java/ai/openclaw/app/voice/TalkModeConfigParsingTest.kt @@ -68,6 +68,50 @@ class TalkModeConfigParsingTest { assertEquals("voice-normalized", selection?.config?.get("voiceId")?.jsonPrimitive?.content) } + @Test + fun rejectsNormalizedTalkProviderPayloadWhenProviderMissingFromProviders() { + val talk = + json.parseToJsonElement( + """ + { + "provider": "acme", + "providers": { + "elevenlabs": { + "voiceId": "voice-normalized" + } + } + } + """.trimIndent(), + ) + .jsonObject + + val selection = TalkModeManager.selectTalkProviderConfig(talk) + assertEquals(null, selection) + } + + @Test + fun rejectsNormalizedTalkProviderPayloadWhenProviderIsAmbiguous() { + val talk = + json.parseToJsonElement( + """ + { + "providers": { + "acme": { + "voiceId": "voice-acme" + }, + "elevenlabs": { + "voiceId": "voice-normalized" + } + } + } + """.trimIndent(), + ) + .jsonObject + + val selection = TalkModeManager.selectTalkProviderConfig(talk) + assertEquals(null, selection) + } + @Test fun fallsBackToLegacyTalkFieldsWhenNormalizedPayloadMissing() { val legacyApiKey = "legacy-key" // pragma: allowlist secret diff --git a/apps/shared/OpenClawKit/Sources/OpenClawKit/TalkConfigParsing.swift b/apps/shared/OpenClawKit/Sources/OpenClawKit/TalkConfigParsing.swift index 05c587b2e9d..0d5ade70e4c 100644 --- a/apps/shared/OpenClawKit/Sources/OpenClawKit/TalkConfigParsing.swift +++ b/apps/shared/OpenClawKit/Sources/OpenClawKit/TalkConfigParsing.swift @@ -31,10 +31,19 @@ public enum TalkConfigParsing { let hasNormalizedPayload = rawProvider != nil || rawProviders != nil if hasNormalizedPayload { let normalizedProviders = self.normalizedTalkProviders(rawProviders) - let providerID = - self.normalizedTalkProviderID(rawProvider) ?? - normalizedProviders.keys.min() ?? - defaultProvider + let explicitProviderID = self.normalizedTalkProviderID(rawProvider) + if let explicitProviderID { + if !normalizedProviders.isEmpty, normalizedProviders[explicitProviderID] == nil { + return nil + } + return TalkProviderConfigSelection( + provider: explicitProviderID, + config: normalizedProviders[explicitProviderID] ?? [:], + normalizedPayload: true) + } + guard normalizedProviders.count == 1, let providerID = normalizedProviders.keys.first else { + return nil + } return TalkProviderConfigSelection( provider: providerID, config: normalizedProviders[providerID] ?? [:], diff --git a/apps/shared/OpenClawKit/Tests/OpenClawKitTests/TalkConfigParsingTests.swift b/apps/shared/OpenClawKit/Tests/OpenClawKitTests/TalkConfigParsingTests.swift index 5edd2ff3368..f710a3497cf 100644 --- a/apps/shared/OpenClawKit/Tests/OpenClawKitTests/TalkConfigParsingTests.swift +++ b/apps/shared/OpenClawKit/Tests/OpenClawKitTests/TalkConfigParsingTests.swift @@ -66,6 +66,36 @@ struct TalkConfigParsingTests { #expect(selection == nil) } + @Test func rejectsNormalizedPayloadWhenProviderMissingFromProviders() { + let talk: [String: AnyCodable] = [ + "provider": AnyCodable("acme"), + "providers": AnyCodable([ + "elevenlabs": [ + "voiceId": "voice-normalized", + ], + ]), + ] + + let selection = TalkConfigParsing.selectProviderConfig(talk, defaultProvider: "elevenlabs") + #expect(selection == nil) + } + + @Test func rejectsNormalizedPayloadWhenMultipleProvidersAndNoProvider() { + let talk: [String: AnyCodable] = [ + "providers": AnyCodable([ + "acme": [ + "voiceId": "voice-acme", + ], + "elevenlabs": [ + "voiceId": "voice-eleven", + ], + ]), + ] + + let selection = TalkConfigParsing.selectProviderConfig(talk, defaultProvider: "elevenlabs") + #expect(selection == nil) + } + @Test func bridgesFoundationDictionary() { let raw: [String: Any] = [ "provider": "elevenlabs", diff --git a/src/config/config.talk-validation.test.ts b/src/config/config.talk-validation.test.ts index 8a0c93ecd3b..cb948d75c75 100644 --- a/src/config/config.talk-validation.test.ts +++ b/src/config/config.talk-validation.test.ts @@ -37,4 +37,68 @@ describe("talk config validation fail-closed behavior", () => { }, ); }); + + it("rejects talk.provider when it does not match talk.providers during config load", async () => { + await withTempHomeConfig( + { + agents: { list: [{ id: "main" }] }, + talk: { + provider: "acme", + providers: { + elevenlabs: { + voiceId: "voice-123", + }, + }, + }, + }, + async () => { + const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {}); + + let thrown: unknown; + try { + loadConfig(); + } catch (error) { + thrown = error; + } + + expect(thrown).toBeInstanceOf(Error); + expect((thrown as { code?: string } | undefined)?.code).toBe("INVALID_CONFIG"); + expect((thrown as Error).message).toMatch(/talk\.provider|talk\.providers|acme/i); + expect(consoleSpy).toHaveBeenCalled(); + }, + ); + }); + + it("rejects multi-provider talk config without talk.provider during config load", async () => { + await withTempHomeConfig( + { + agents: { list: [{ id: "main" }] }, + talk: { + providers: { + acme: { + voiceId: "voice-acme", + }, + elevenlabs: { + voiceId: "voice-eleven", + }, + }, + }, + }, + async () => { + const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {}); + + let thrown: unknown; + try { + loadConfig(); + } catch (error) { + thrown = error; + } + + expect(thrown).toBeInstanceOf(Error); + expect((thrown as { code?: string } | undefined)?.code).toBe("INVALID_CONFIG"); + expect((thrown as Error).message).toMatch(/talk\.provider|required/i); + expect(consoleSpy).toHaveBeenCalled(); + }, + ); + }); }); diff --git a/src/config/talk.ts b/src/config/talk.ts index 2d8f4b79c3d..32c4255a7a4 100644 --- a/src/config/talk.ts +++ b/src/config/talk.ts @@ -158,10 +158,14 @@ function legacyProviderConfigFromTalk( function activeProviderFromTalk(talk: TalkConfig): string | undefined { const provider = normalizeString(talk.provider); + const providers = talk.providers; if (provider) { + if (providers && !(provider in providers)) { + return undefined; + } return provider; } - const providerIds = talk.providers ? Object.keys(talk.providers) : []; + const providerIds = providers ? Object.keys(providers) : []; return providerIds.length === 1 ? providerIds[0] : undefined; } diff --git a/src/config/zod-schema.talk.test.ts b/src/config/zod-schema.talk.test.ts index 6f1f22ebc14..bbb7eb9f89f 100644 --- a/src/config/zod-schema.talk.test.ts +++ b/src/config/zod-schema.talk.test.ts @@ -25,4 +25,36 @@ describe("OpenClawSchema talk validation", () => { }), ).toThrow(/silenceTimeoutMs|number|integer/i); }); + + it("rejects talk.provider when it does not match talk.providers", () => { + expect(() => + OpenClawSchema.parse({ + talk: { + provider: "acme", + providers: { + elevenlabs: { + voiceId: "voice-123", + }, + }, + }, + }), + ).toThrow(/talk\.provider|talk\.providers|missing "acme"/i); + }); + + it("rejects multi-provider talk config without talk.provider", () => { + expect(() => + OpenClawSchema.parse({ + talk: { + providers: { + acme: { + voiceId: "voice-acme", + }, + elevenlabs: { + voiceId: "voice-eleven", + }, + }, + }, + }), + ).toThrow(/talk\.provider|required/i); + }); }); diff --git a/src/config/zod-schema.ts b/src/config/zod-schema.ts index 731909da72d..62b7f2f1513 100644 --- a/src/config/zod-schema.ts +++ b/src/config/zod-schema.ts @@ -159,6 +159,50 @@ const PluginEntrySchema = z }) .strict(); +const TalkProviderEntrySchema = z + .object({ + voiceId: z.string().optional(), + voiceAliases: z.record(z.string(), z.string()).optional(), + modelId: z.string().optional(), + outputFormat: z.string().optional(), + apiKey: SecretInputSchema.optional().register(sensitive), + }) + .catchall(z.unknown()); + +const TalkSchema = z + .object({ + provider: z.string().optional(), + providers: z.record(z.string(), TalkProviderEntrySchema).optional(), + voiceId: z.string().optional(), + voiceAliases: z.record(z.string(), z.string()).optional(), + modelId: z.string().optional(), + outputFormat: z.string().optional(), + apiKey: SecretInputSchema.optional().register(sensitive), + interruptOnSpeech: z.boolean().optional(), + silenceTimeoutMs: z.number().int().positive().optional(), + }) + .strict() + .superRefine((talk, ctx) => { + const provider = talk.provider?.trim().toLowerCase(); + const providers = talk.providers ? Object.keys(talk.providers) : []; + + if (provider && providers.length > 0 && !(provider in talk.providers!)) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + path: ["provider"], + message: `talk.provider must match a key in talk.providers (missing "${provider}")`, + }); + } + + if (!provider && providers.length > 1) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + path: ["provider"], + message: "talk.provider is required when talk.providers defines multiple providers", + }); + } + }); + export const OpenClawSchema = z .object({ $schema: z.string().optional(), @@ -572,33 +616,7 @@ export const OpenClawSchema = z }) .strict() .optional(), - talk: z - .object({ - provider: z.string().optional(), - providers: z - .record( - z.string(), - z - .object({ - voiceId: z.string().optional(), - voiceAliases: z.record(z.string(), z.string()).optional(), - modelId: z.string().optional(), - outputFormat: z.string().optional(), - apiKey: SecretInputSchema.optional().register(sensitive), - }) - .catchall(z.unknown()), - ) - .optional(), - voiceId: z.string().optional(), - voiceAliases: z.record(z.string(), z.string()).optional(), - modelId: z.string().optional(), - outputFormat: z.string().optional(), - apiKey: SecretInputSchema.optional().register(sensitive), - interruptOnSpeech: z.boolean().optional(), - silenceTimeoutMs: z.number().int().positive().optional(), - }) - .strict() - .optional(), + talk: TalkSchema.optional(), gateway: z .object({ port: z.number().int().positive().optional(),