fix(android-voice): cancel in-flight speech when speaker muted

This commit is contained in:
Ayaan Zaidi
2026-02-28 19:50:15 +05:30
committed by Ayaan Zaidi
parent 727ae469cf
commit 930e94024a

View File

@@ -27,6 +27,7 @@ import java.net.HttpURLConnection
import java.net.URL
import java.util.UUID
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
@@ -148,6 +149,7 @@ class TalkModeManager(
private var chatSubscribedSessionKey: String? = null
private var configLoaded = false
@Volatile private var playbackEnabled = true
@Volatile private var playbackGeneration = 0L
private var player: MediaPlayer? = null
private var streamingSource: StreamingMediaDataSource? = null
@@ -197,8 +199,10 @@ class TalkModeManager(
}
fun setPlaybackEnabled(enabled: Boolean) {
if (playbackEnabled == enabled) return
playbackEnabled = enabled
if (!enabled) {
playbackGeneration += 1
stopSpeaking()
}
}
@@ -209,9 +213,10 @@ class TalkModeManager(
suspend fun speakAssistantReply(text: String) {
if (!playbackEnabled) return
val playbackToken = playbackGeneration
ensureConfigLoaded()
if (!playbackEnabled) return
playAssistant(text)
ensurePlaybackActive(playbackToken)
playAssistant(text, playbackToken)
}
private fun start() {
@@ -389,8 +394,14 @@ class TalkModeManager(
return
}
Log.d(tag, "assistant text ok chars=${assistant.length}")
playAssistant(assistant)
val playbackToken = playbackGeneration
ensurePlaybackActive(playbackToken)
playAssistant(assistant, playbackToken)
} catch (err: Throwable) {
if (err is CancellationException) {
Log.d(tag, "finalize speech cancelled")
return
}
_statusText.value = "Talk failed: ${err.message ?: err::class.simpleName}"
Log.w(tag, "finalize failed: ${err.message ?: err::class.simpleName}")
}
@@ -507,7 +518,7 @@ class TalkModeManager(
return null
}
private suspend fun playAssistant(text: String) {
private suspend fun playAssistant(text: String, playbackToken: Long) {
val parsed = TalkDirectiveParser.parse(text)
if (parsed.unknownKeys.isNotEmpty()) {
Log.w(tag, "Unknown talk directive keys: ${parsed.unknownKeys}")
@@ -535,6 +546,7 @@ class TalkModeManager(
modelOverrideActive = true
}
}
ensurePlaybackActive(playbackToken)
val apiKey =
apiKey?.trim()?.takeIf { it.isNotEmpty() }
@@ -561,9 +573,10 @@ class TalkModeManager(
if (apiKey.isNullOrEmpty()) {
Log.w(tag, "missing ELEVENLABS_API_KEY; falling back to system voice")
}
ensurePlaybackActive(playbackToken)
_usingFallbackTts.value = true
_statusText.value = "Speaking (System)…"
speakWithSystemTts(cleaned)
speakWithSystemTts(cleaned, playbackToken)
} else {
_usingFallbackTts.value = false
val ttsStarted = SystemClock.elapsedRealtime()
@@ -584,43 +597,71 @@ class TalkModeManager(
language = TalkModeRuntime.validatedLanguage(directive?.language),
latencyTier = TalkModeRuntime.validatedLatencyTier(directive?.latencyTier),
)
streamAndPlay(voiceId = voiceId!!, apiKey = apiKey!!, request = request)
streamAndPlay(voiceId = voiceId!!, apiKey = apiKey!!, request = request, playbackToken = playbackToken)
Log.d(tag, "elevenlabs stream ok durMs=${SystemClock.elapsedRealtime() - ttsStarted}")
}
} catch (err: Throwable) {
if (isPlaybackCancelled(err, playbackToken)) {
Log.d(tag, "assistant speech cancelled")
return
}
Log.w(tag, "speak failed: ${err.message ?: err::class.simpleName}; falling back to system voice")
try {
ensurePlaybackActive(playbackToken)
_usingFallbackTts.value = true
_statusText.value = "Speaking (System)…"
speakWithSystemTts(cleaned)
speakWithSystemTts(cleaned, playbackToken)
} catch (fallbackErr: Throwable) {
if (isPlaybackCancelled(fallbackErr, playbackToken)) {
Log.d(tag, "assistant fallback speech cancelled")
return
}
_statusText.value = "Speak failed: ${fallbackErr.message ?: fallbackErr::class.simpleName}"
Log.w(tag, "system voice failed: ${fallbackErr.message ?: fallbackErr::class.simpleName}")
}
} finally {
_isSpeaking.value = false
}
_isSpeaking.value = false
}
private suspend fun streamAndPlay(voiceId: String, apiKey: String, request: ElevenLabsRequest) {
private suspend fun streamAndPlay(
voiceId: String,
apiKey: String,
request: ElevenLabsRequest,
playbackToken: Long,
) {
ensurePlaybackActive(playbackToken)
stopSpeaking(resetInterrupt = false)
ensurePlaybackActive(playbackToken)
pcmStopRequested = false
val pcmSampleRate = TalkModeRuntime.parsePcmSampleRate(request.outputFormat)
if (pcmSampleRate != null) {
try {
streamAndPlayPcm(voiceId = voiceId, apiKey = apiKey, request = request, sampleRate = pcmSampleRate)
streamAndPlayPcm(
voiceId = voiceId,
apiKey = apiKey,
request = request,
sampleRate = pcmSampleRate,
playbackToken = playbackToken,
)
return
} catch (err: Throwable) {
if (pcmStopRequested) return
if (isPlaybackCancelled(err, playbackToken) || pcmStopRequested) return
Log.w(tag, "pcm playback failed; falling back to mp3: ${err.message ?: err::class.simpleName}")
}
}
streamAndPlayMp3(voiceId = voiceId, apiKey = apiKey, request = request)
ensurePlaybackActive(playbackToken)
streamAndPlayMp3(voiceId = voiceId, apiKey = apiKey, request = request, playbackToken = playbackToken)
}
private suspend fun streamAndPlayMp3(voiceId: String, apiKey: String, request: ElevenLabsRequest) {
private suspend fun streamAndPlayMp3(
voiceId: String,
apiKey: String,
request: ElevenLabsRequest,
playbackToken: Long,
) {
val dataSource = StreamingMediaDataSource()
streamingSource = dataSource
@@ -657,7 +698,7 @@ class TalkModeManager(
val fetchJob =
scope.launch(Dispatchers.IO) {
try {
streamTts(voiceId = voiceId, apiKey = apiKey, request = request, sink = dataSource)
streamTts(voiceId = voiceId, apiKey = apiKey, request = request, sink = dataSource, playbackToken = playbackToken)
fetchError.complete(null)
} catch (err: Throwable) {
dataSource.fail()
@@ -667,8 +708,11 @@ class TalkModeManager(
Log.d(tag, "play start")
try {
ensurePlaybackActive(playbackToken)
prepared.await()
ensurePlaybackActive(playbackToken)
finished.await()
ensurePlaybackActive(playbackToken)
fetchError.await()?.let { throw it }
} finally {
fetchJob.cancel()
@@ -682,7 +726,9 @@ class TalkModeManager(
apiKey: String,
request: ElevenLabsRequest,
sampleRate: Int,
playbackToken: Long,
) {
ensurePlaybackActive(playbackToken)
val minBuffer =
AudioTrack.getMinBufferSize(
sampleRate,
@@ -718,20 +764,22 @@ class TalkModeManager(
Log.d(tag, "pcm play start sampleRate=$sampleRate bufferSize=$bufferSize")
try {
streamPcm(voiceId = voiceId, apiKey = apiKey, request = request, track = track)
streamPcm(voiceId = voiceId, apiKey = apiKey, request = request, track = track, playbackToken = playbackToken)
} finally {
cleanupPcmTrack()
}
Log.d(tag, "pcm play done")
}
private suspend fun speakWithSystemTts(text: String) {
private suspend fun speakWithSystemTts(text: String, playbackToken: Long) {
val trimmed = text.trim()
if (trimmed.isEmpty()) return
ensurePlaybackActive(playbackToken)
val ok = ensureSystemTts()
if (!ok) {
throw IllegalStateException("system TTS unavailable")
}
ensurePlaybackActive(playbackToken)
val tts = systemTts ?: throw IllegalStateException("system TTS unavailable")
val utteranceId = "talk-${UUID.randomUUID()}"
@@ -741,6 +789,7 @@ class TalkModeManager(
systemTtsPendingId = utteranceId
withContext(Dispatchers.Main) {
ensurePlaybackActive(playbackToken)
val params = Bundle()
tts.speak(trimmed, TextToSpeech.QUEUE_FLUSH, params, utteranceId)
}
@@ -751,6 +800,7 @@ class TalkModeManager(
} catch (err: Throwable) {
throw err
}
ensurePlaybackActive(playbackToken)
}
}
@@ -870,6 +920,17 @@ class TalkModeManager(
return true
}
private fun ensurePlaybackActive(playbackToken: Long) {
if (!playbackEnabled || playbackToken != playbackGeneration) {
throw CancellationException("assistant speech cancelled")
}
}
private fun isPlaybackCancelled(err: Throwable?, playbackToken: Long): Boolean {
if (err is CancellationException) return true
return !playbackEnabled || playbackToken != playbackGeneration
}
private suspend fun ensureConfigLoaded() {
if (!configLoaded) {
reloadConfig()
@@ -950,8 +1011,10 @@ class TalkModeManager(
apiKey: String,
request: ElevenLabsRequest,
sink: StreamingMediaDataSource,
playbackToken: Long,
) {
withContext(Dispatchers.IO) {
ensurePlaybackActive(playbackToken)
val conn = openTtsConnection(voiceId = voiceId, apiKey = apiKey, request = request)
try {
val payload = buildRequestPayload(request)
@@ -967,8 +1030,10 @@ class TalkModeManager(
val buffer = ByteArray(8 * 1024)
conn.inputStream.use { input ->
while (true) {
ensurePlaybackActive(playbackToken)
val read = input.read(buffer)
if (read <= 0) break
ensurePlaybackActive(playbackToken)
sink.append(buffer.copyOf(read))
}
}
@@ -984,8 +1049,10 @@ class TalkModeManager(
apiKey: String,
request: ElevenLabsRequest,
track: AudioTrack,
playbackToken: Long,
) {
withContext(Dispatchers.IO) {
ensurePlaybackActive(playbackToken)
val conn = openTtsConnection(voiceId = voiceId, apiKey = apiKey, request = request)
try {
val payload = buildRequestPayload(request)
@@ -1000,21 +1067,21 @@ class TalkModeManager(
val buffer = ByteArray(8 * 1024)
conn.inputStream.use { input ->
while (true) {
if (pcmStopRequested) return@withContext
if (pcmStopRequested || isPlaybackCancelled(null, playbackToken)) return@withContext
val read = input.read(buffer)
if (read <= 0) break
var offset = 0
while (offset < read) {
if (pcmStopRequested) return@withContext
if (pcmStopRequested || isPlaybackCancelled(null, playbackToken)) return@withContext
val wrote =
try {
track.write(buffer, offset, read - offset)
} catch (err: Throwable) {
if (pcmStopRequested) return@withContext
if (pcmStopRequested || isPlaybackCancelled(err, playbackToken)) return@withContext
throw err
}
if (wrote <= 0) {
if (pcmStopRequested) return@withContext
if (pcmStopRequested || isPlaybackCancelled(null, playbackToken)) return@withContext
throw IllegalStateException("AudioTrack write failed: $wrote")
}
offset += wrote