fix(android): restart gateway session on reconnect

This commit is contained in:
Ayaan Zaidi
2026-05-18 17:41:36 +05:30
parent 5fb9c0c937
commit f1f92b8656
2 changed files with 277 additions and 19 deletions

View File

@@ -149,7 +149,10 @@ class GatewaySession(
val tls: GatewayTlsParams?,
)
private var desired: DesiredConnection? = null
private val lifecycleLock = Any()
@Volatile private var desired: DesiredConnection? = null
private var job: Job? = null
@Volatile private var currentConnection: Connection? = null
@@ -168,26 +171,39 @@ class GatewaySession(
options: GatewayConnectOptions,
tls: GatewayTlsParams? = null,
) {
desired = DesiredConnection(endpoint, token, bootstrapToken, password, options, tls)
pendingDeviceTokenRetry = false
deviceTokenRetryBudgetUsed = false
reconnectPausedForAuthFailure = false
if (job == null) {
job = scope.launch(Dispatchers.IO) { runLoop() }
val connectionToClose: Connection?
synchronized(lifecycleLock) {
desired = DesiredConnection(endpoint, token, bootstrapToken, password, options, tls)
pendingDeviceTokenRetry = false
deviceTokenRetryBudgetUsed = false
reconnectPausedForAuthFailure = false
connectionToClose = currentConnection
if (job?.isActive != true) {
job = scope.launch(Dispatchers.IO) { runLoop() }
}
}
connectionToClose?.closeQuietly()
}
fun disconnect() {
desired = null
pendingDeviceTokenRetry = false
deviceTokenRetryBudgetUsed = false
reconnectPausedForAuthFailure = false
currentConnection?.closeQuietly()
scope.launch(Dispatchers.IO) {
job?.cancelAndJoin()
val jobToCancel: Job?
val connectionToClose: Connection?
synchronized(lifecycleLock) {
desired = null
pendingDeviceTokenRetry = false
deviceTokenRetryBudgetUsed = false
reconnectPausedForAuthFailure = false
connectionToClose = currentConnection
jobToCancel = job
job = null
pluginSurfaceUrls = emptyMap()
mainSessionKey = null
}
connectionToClose?.closeQuietly()
scope.launch(Dispatchers.IO) {
jobToCancel?.cancelAndJoin()
if (desired == null) {
pluginSurfaceUrls = emptyMap()
mainSessionKey = null
}
onDisconnected("Offline")
}
}
@@ -963,9 +979,11 @@ class GatewaySession(
conn.connect()
conn.awaitClose()
} finally {
currentConnection = null
pluginSurfaceUrls = emptyMap()
mainSessionKey = null
if (currentConnection === conn) {
currentConnection = null
pluginSurfaceUrls = emptyMap()
mainSessionKey = null
}
}
}

View File

@@ -1,10 +1,129 @@
package ai.openclaw.app.gateway
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okhttp3.mockwebserver.Dispatcher
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
import okhttp3.mockwebserver.RecordedRequest
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.RuntimeEnvironment
import org.robolectric.annotation.Config
import java.util.concurrent.ConcurrentLinkedQueue
private const val LIFECYCLE_TEST_TIMEOUT_MS = 8_000L
private const val LIFECYCLE_CONNECT_CHALLENGE_FRAME =
"""{"type":"event","event":"connect.challenge","payload":{"nonce":"android-test-nonce"}}"""
private class ReconnectDeviceAuthStore : DeviceAuthTokenStore {
override fun loadEntry(
deviceId: String,
role: String,
): DeviceAuthEntry? = null
override fun saveToken(
deviceId: String,
role: String,
token: String,
scopes: List<String>,
) = Unit
override fun clearToken(
deviceId: String,
role: String,
) = Unit
}
private data class ReconnectHarness(
val session: GatewaySession,
val sessionJob: Job,
)
private data class ReconnectServer(
val server: MockWebServer,
val sockets: ConcurrentLinkedQueue<WebSocket>,
) {
val port: Int
get() = server.port
val requestCount: Int
get() = server.requestCount
fun shutdown() {
sockets.forEach { runCatching { it.cancel() } }
runCatching { server.shutdown() }
.onFailure { err ->
if (err.message != "Gave up waiting for queue to shut down") throw err
}
}
}
@RunWith(RobolectricTestRunner::class)
@Config(sdk = [34])
class GatewaySessionReconnectTest {
@Test
fun connectToNewGatewayClosesActiveConnectionAndStartsReplacement() =
runBlocking {
val json = Json { ignoreUnknownKeys = true }
val firstConnect = CompletableDeferred<Unit>()
val firstClosed = CompletableDeferred<Unit>()
val secondConnect = CompletableDeferred<Unit>()
val secondClosed = CompletableDeferred<Unit>()
val firstServer =
startGatewayServer(
json = json,
onClosed = { firstClosed.complete(Unit) },
) { webSocket, id, method ->
if (method == "connect") {
firstConnect.complete(Unit)
webSocket.send(connectResponseFrame(id))
}
}
val secondServer =
startGatewayServer(
json = json,
onClosed = { secondClosed.complete(Unit) },
) { webSocket, id, method ->
if (method == "connect") {
secondConnect.complete(Unit)
webSocket.send(connectResponseFrame(id))
}
}
val harness = createReconnectHarness()
try {
connectNodeSession(harness.session, firstServer.port)
withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { firstConnect.await() }
connectNodeSession(harness.session, secondServer.port)
withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { firstClosed.await() }
withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { secondConnect.await() }
assertEquals(1, secondServer.requestCount)
harness.session.disconnect()
withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { secondClosed.await() }
} finally {
shutdownReconnectHarness(harness, firstServer, secondServer)
}
}
@Test
fun bootstrapNodePairingRequiredKeepsReconnectActive() {
val error =
@@ -113,4 +232,125 @@ class GatewaySessionReconnectTest {
),
)
}
private fun createReconnectHarness(): ReconnectHarness {
val app = RuntimeEnvironment.getApplication()
val sessionJob = SupervisorJob()
val session =
GatewaySession(
scope = CoroutineScope(sessionJob + Dispatchers.Default),
identityStore = DeviceIdentityStore(app),
deviceAuthStore = ReconnectDeviceAuthStore(),
onConnected = { _, _, _ -> },
onDisconnected = { _ -> },
onEvent = { _, _ -> },
onInvoke = { GatewaySession.InvokeResult.ok("""{"handled":true}""") },
)
return ReconnectHarness(session = session, sessionJob = sessionJob)
}
private suspend fun connectNodeSession(
session: GatewaySession,
port: Int,
) {
session.connect(
endpoint =
GatewayEndpoint(
stableId = "manual|127.0.0.1|$port",
name = "test",
host = "127.0.0.1",
port = port,
tlsEnabled = false,
),
token = "test-token",
bootstrapToken = null,
password = null,
options =
GatewayConnectOptions(
role = "node",
scopes = listOf("node:invoke"),
caps = emptyList(),
commands = emptyList(),
permissions = emptyMap(),
client =
GatewayClientInfo(
id = "openclaw-android-test",
displayName = "Android Test",
version = "1.0.0-test",
platform = "android",
mode = "node",
instanceId = "android-test-instance",
deviceFamily = "android",
modelIdentifier = "test",
),
),
tls = null,
)
}
private suspend fun shutdownReconnectHarness(
harness: ReconnectHarness,
vararg servers: ReconnectServer,
) {
harness.session.disconnect()
harness.sessionJob.cancelAndJoin()
servers.forEach { it.shutdown() }
}
private fun connectResponseFrame(id: String): String = """{"type":"res","id":"$id","ok":true,"payload":{"snapshot":{"sessionDefaults":{"mainSessionKey":"main"}}}}"""
private fun startGatewayServer(
json: Json,
onClosed: () -> Unit = {},
onRequestFrame: (webSocket: WebSocket, id: String, method: String) -> Unit,
): ReconnectServer {
val sockets = ConcurrentLinkedQueue<WebSocket>()
val server =
MockWebServer().apply {
dispatcher =
object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse =
MockResponse().withWebSocketUpgrade(
object : WebSocketListener() {
override fun onOpen(
webSocket: WebSocket,
response: Response,
) {
sockets += webSocket
webSocket.send(LIFECYCLE_CONNECT_CHALLENGE_FRAME)
}
override fun onMessage(
webSocket: WebSocket,
text: String,
) {
val frame = json.parseToJsonElement(text).jsonObject
if (frame["type"]?.jsonPrimitive?.content != "req") return
val id = frame["id"]?.jsonPrimitive?.content ?: return
val method = frame["method"]?.jsonPrimitive?.content ?: return
onRequestFrame(webSocket, id, method)
}
override fun onClosing(
webSocket: WebSocket,
code: Int,
reason: String,
) {
onClosed()
}
override fun onClosed(
webSocket: WebSocket,
code: Int,
reason: String,
) {
onClosed()
}
},
)
}
start()
}
return ReconnectServer(server = server, sockets = sockets)
}
}