diff --git a/apps/android/app/src/main/java/ai/openclaw/app/gateway/GatewaySession.kt b/apps/android/app/src/main/java/ai/openclaw/app/gateway/GatewaySession.kt index 467084f1edb..eaee48981cf 100644 --- a/apps/android/app/src/main/java/ai/openclaw/app/gateway/GatewaySession.kt +++ b/apps/android/app/src/main/java/ai/openclaw/app/gateway/GatewaySession.kt @@ -149,7 +149,10 @@ class GatewaySession( val tls: GatewayTlsParams?, ) - private var desired: DesiredConnection? = null + private val lifecycleLock = Any() + + @Volatile private var desired: DesiredConnection? = null + private var job: Job? = null @Volatile private var currentConnection: Connection? = null @@ -168,26 +171,39 @@ class GatewaySession( options: GatewayConnectOptions, tls: GatewayTlsParams? = null, ) { - desired = DesiredConnection(endpoint, token, bootstrapToken, password, options, tls) - pendingDeviceTokenRetry = false - deviceTokenRetryBudgetUsed = false - reconnectPausedForAuthFailure = false - if (job == null) { - job = scope.launch(Dispatchers.IO) { runLoop() } + val connectionToClose: Connection? + synchronized(lifecycleLock) { + desired = DesiredConnection(endpoint, token, bootstrapToken, password, options, tls) + pendingDeviceTokenRetry = false + deviceTokenRetryBudgetUsed = false + reconnectPausedForAuthFailure = false + connectionToClose = currentConnection + if (job?.isActive != true) { + job = scope.launch(Dispatchers.IO) { runLoop() } + } } + connectionToClose?.closeQuietly() } fun disconnect() { - desired = null - pendingDeviceTokenRetry = false - deviceTokenRetryBudgetUsed = false - reconnectPausedForAuthFailure = false - currentConnection?.closeQuietly() - scope.launch(Dispatchers.IO) { - job?.cancelAndJoin() + val jobToCancel: Job? + val connectionToClose: Connection? + synchronized(lifecycleLock) { + desired = null + pendingDeviceTokenRetry = false + deviceTokenRetryBudgetUsed = false + reconnectPausedForAuthFailure = false + connectionToClose = currentConnection + jobToCancel = job job = null - pluginSurfaceUrls = emptyMap() - mainSessionKey = null + } + connectionToClose?.closeQuietly() + scope.launch(Dispatchers.IO) { + jobToCancel?.cancelAndJoin() + if (desired == null) { + pluginSurfaceUrls = emptyMap() + mainSessionKey = null + } onDisconnected("Offline") } } @@ -963,9 +979,11 @@ class GatewaySession( conn.connect() conn.awaitClose() } finally { - currentConnection = null - pluginSurfaceUrls = emptyMap() - mainSessionKey = null + if (currentConnection === conn) { + currentConnection = null + pluginSurfaceUrls = emptyMap() + mainSessionKey = null + } } } diff --git a/apps/android/app/src/test/java/ai/openclaw/app/gateway/GatewaySessionReconnectTest.kt b/apps/android/app/src/test/java/ai/openclaw/app/gateway/GatewaySessionReconnectTest.kt index 439ff0a410f..1d9c1dd567a 100644 --- a/apps/android/app/src/test/java/ai/openclaw/app/gateway/GatewaySessionReconnectTest.kt +++ b/apps/android/app/src/test/java/ai/openclaw/app/gateway/GatewaySessionReconnectTest.kt @@ -1,10 +1,129 @@ package ai.openclaw.app.gateway +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import okhttp3.Response +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import okhttp3.mockwebserver.Dispatcher +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import okhttp3.mockwebserver.RecordedRequest +import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.RuntimeEnvironment +import org.robolectric.annotation.Config +import java.util.concurrent.ConcurrentLinkedQueue +private const val LIFECYCLE_TEST_TIMEOUT_MS = 8_000L +private const val LIFECYCLE_CONNECT_CHALLENGE_FRAME = + """{"type":"event","event":"connect.challenge","payload":{"nonce":"android-test-nonce"}}""" + +private class ReconnectDeviceAuthStore : DeviceAuthTokenStore { + override fun loadEntry( + deviceId: String, + role: String, + ): DeviceAuthEntry? = null + + override fun saveToken( + deviceId: String, + role: String, + token: String, + scopes: List, + ) = Unit + + override fun clearToken( + deviceId: String, + role: String, + ) = Unit +} + +private data class ReconnectHarness( + val session: GatewaySession, + val sessionJob: Job, +) + +private data class ReconnectServer( + val server: MockWebServer, + val sockets: ConcurrentLinkedQueue, +) { + val port: Int + get() = server.port + + val requestCount: Int + get() = server.requestCount + + fun shutdown() { + sockets.forEach { runCatching { it.cancel() } } + runCatching { server.shutdown() } + .onFailure { err -> + if (err.message != "Gave up waiting for queue to shut down") throw err + } + } +} + +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [34]) class GatewaySessionReconnectTest { + @Test + fun connectToNewGatewayClosesActiveConnectionAndStartsReplacement() = + runBlocking { + val json = Json { ignoreUnknownKeys = true } + val firstConnect = CompletableDeferred() + val firstClosed = CompletableDeferred() + val secondConnect = CompletableDeferred() + val secondClosed = CompletableDeferred() + val firstServer = + startGatewayServer( + json = json, + onClosed = { firstClosed.complete(Unit) }, + ) { webSocket, id, method -> + if (method == "connect") { + firstConnect.complete(Unit) + webSocket.send(connectResponseFrame(id)) + } + } + val secondServer = + startGatewayServer( + json = json, + onClosed = { secondClosed.complete(Unit) }, + ) { webSocket, id, method -> + if (method == "connect") { + secondConnect.complete(Unit) + webSocket.send(connectResponseFrame(id)) + } + } + val harness = createReconnectHarness() + + try { + connectNodeSession(harness.session, firstServer.port) + withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { firstConnect.await() } + + connectNodeSession(harness.session, secondServer.port) + + withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { firstClosed.await() } + withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { secondConnect.await() } + assertEquals(1, secondServer.requestCount) + harness.session.disconnect() + withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { secondClosed.await() } + } finally { + shutdownReconnectHarness(harness, firstServer, secondServer) + } + } + @Test fun bootstrapNodePairingRequiredKeepsReconnectActive() { val error = @@ -113,4 +232,125 @@ class GatewaySessionReconnectTest { ), ) } + + private fun createReconnectHarness(): ReconnectHarness { + val app = RuntimeEnvironment.getApplication() + val sessionJob = SupervisorJob() + val session = + GatewaySession( + scope = CoroutineScope(sessionJob + Dispatchers.Default), + identityStore = DeviceIdentityStore(app), + deviceAuthStore = ReconnectDeviceAuthStore(), + onConnected = { _, _, _ -> }, + onDisconnected = { _ -> }, + onEvent = { _, _ -> }, + onInvoke = { GatewaySession.InvokeResult.ok("""{"handled":true}""") }, + ) + return ReconnectHarness(session = session, sessionJob = sessionJob) + } + + private suspend fun connectNodeSession( + session: GatewaySession, + port: Int, + ) { + session.connect( + endpoint = + GatewayEndpoint( + stableId = "manual|127.0.0.1|$port", + name = "test", + host = "127.0.0.1", + port = port, + tlsEnabled = false, + ), + token = "test-token", + bootstrapToken = null, + password = null, + options = + GatewayConnectOptions( + role = "node", + scopes = listOf("node:invoke"), + caps = emptyList(), + commands = emptyList(), + permissions = emptyMap(), + client = + GatewayClientInfo( + id = "openclaw-android-test", + displayName = "Android Test", + version = "1.0.0-test", + platform = "android", + mode = "node", + instanceId = "android-test-instance", + deviceFamily = "android", + modelIdentifier = "test", + ), + ), + tls = null, + ) + } + + private suspend fun shutdownReconnectHarness( + harness: ReconnectHarness, + vararg servers: ReconnectServer, + ) { + harness.session.disconnect() + harness.sessionJob.cancelAndJoin() + servers.forEach { it.shutdown() } + } + + private fun connectResponseFrame(id: String): String = """{"type":"res","id":"$id","ok":true,"payload":{"snapshot":{"sessionDefaults":{"mainSessionKey":"main"}}}}""" + + private fun startGatewayServer( + json: Json, + onClosed: () -> Unit = {}, + onRequestFrame: (webSocket: WebSocket, id: String, method: String) -> Unit, + ): ReconnectServer { + val sockets = ConcurrentLinkedQueue() + val server = + MockWebServer().apply { + dispatcher = + object : Dispatcher() { + override fun dispatch(request: RecordedRequest): MockResponse = + MockResponse().withWebSocketUpgrade( + object : WebSocketListener() { + override fun onOpen( + webSocket: WebSocket, + response: Response, + ) { + sockets += webSocket + webSocket.send(LIFECYCLE_CONNECT_CHALLENGE_FRAME) + } + + override fun onMessage( + webSocket: WebSocket, + text: String, + ) { + val frame = json.parseToJsonElement(text).jsonObject + if (frame["type"]?.jsonPrimitive?.content != "req") return + val id = frame["id"]?.jsonPrimitive?.content ?: return + val method = frame["method"]?.jsonPrimitive?.content ?: return + onRequestFrame(webSocket, id, method) + } + + override fun onClosing( + webSocket: WebSocket, + code: Int, + reason: String, + ) { + onClosed() + } + + override fun onClosed( + webSocket: WebSocket, + code: Int, + reason: String, + ) { + onClosed() + } + }, + ) + } + start() + } + return ReconnectServer(server = server, sockets = sockets) + } }