mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-24 06:39:49 +00:00
fix(android): restart gateway session on reconnect
This commit is contained in:
@@ -149,7 +149,10 @@ class GatewaySession(
|
||||
val tls: GatewayTlsParams?,
|
||||
)
|
||||
|
||||
private var desired: DesiredConnection? = null
|
||||
private val lifecycleLock = Any()
|
||||
|
||||
@Volatile private var desired: DesiredConnection? = null
|
||||
|
||||
private var job: Job? = null
|
||||
|
||||
@Volatile private var currentConnection: Connection? = null
|
||||
@@ -168,26 +171,39 @@ class GatewaySession(
|
||||
options: GatewayConnectOptions,
|
||||
tls: GatewayTlsParams? = null,
|
||||
) {
|
||||
desired = DesiredConnection(endpoint, token, bootstrapToken, password, options, tls)
|
||||
pendingDeviceTokenRetry = false
|
||||
deviceTokenRetryBudgetUsed = false
|
||||
reconnectPausedForAuthFailure = false
|
||||
if (job == null) {
|
||||
job = scope.launch(Dispatchers.IO) { runLoop() }
|
||||
val connectionToClose: Connection?
|
||||
synchronized(lifecycleLock) {
|
||||
desired = DesiredConnection(endpoint, token, bootstrapToken, password, options, tls)
|
||||
pendingDeviceTokenRetry = false
|
||||
deviceTokenRetryBudgetUsed = false
|
||||
reconnectPausedForAuthFailure = false
|
||||
connectionToClose = currentConnection
|
||||
if (job?.isActive != true) {
|
||||
job = scope.launch(Dispatchers.IO) { runLoop() }
|
||||
}
|
||||
}
|
||||
connectionToClose?.closeQuietly()
|
||||
}
|
||||
|
||||
fun disconnect() {
|
||||
desired = null
|
||||
pendingDeviceTokenRetry = false
|
||||
deviceTokenRetryBudgetUsed = false
|
||||
reconnectPausedForAuthFailure = false
|
||||
currentConnection?.closeQuietly()
|
||||
scope.launch(Dispatchers.IO) {
|
||||
job?.cancelAndJoin()
|
||||
val jobToCancel: Job?
|
||||
val connectionToClose: Connection?
|
||||
synchronized(lifecycleLock) {
|
||||
desired = null
|
||||
pendingDeviceTokenRetry = false
|
||||
deviceTokenRetryBudgetUsed = false
|
||||
reconnectPausedForAuthFailure = false
|
||||
connectionToClose = currentConnection
|
||||
jobToCancel = job
|
||||
job = null
|
||||
pluginSurfaceUrls = emptyMap()
|
||||
mainSessionKey = null
|
||||
}
|
||||
connectionToClose?.closeQuietly()
|
||||
scope.launch(Dispatchers.IO) {
|
||||
jobToCancel?.cancelAndJoin()
|
||||
if (desired == null) {
|
||||
pluginSurfaceUrls = emptyMap()
|
||||
mainSessionKey = null
|
||||
}
|
||||
onDisconnected("Offline")
|
||||
}
|
||||
}
|
||||
@@ -963,9 +979,11 @@ class GatewaySession(
|
||||
conn.connect()
|
||||
conn.awaitClose()
|
||||
} finally {
|
||||
currentConnection = null
|
||||
pluginSurfaceUrls = emptyMap()
|
||||
mainSessionKey = null
|
||||
if (currentConnection === conn) {
|
||||
currentConnection = null
|
||||
pluginSurfaceUrls = emptyMap()
|
||||
mainSessionKey = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,129 @@
|
||||
package ai.openclaw.app.gateway
|
||||
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.cancelAndJoin
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.withTimeout
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.jsonObject
|
||||
import kotlinx.serialization.json.jsonPrimitive
|
||||
import okhttp3.Response
|
||||
import okhttp3.WebSocket
|
||||
import okhttp3.WebSocketListener
|
||||
import okhttp3.mockwebserver.Dispatcher
|
||||
import okhttp3.mockwebserver.MockResponse
|
||||
import okhttp3.mockwebserver.MockWebServer
|
||||
import okhttp3.mockwebserver.RecordedRequest
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Assert.assertFalse
|
||||
import org.junit.Assert.assertTrue
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.robolectric.RobolectricTestRunner
|
||||
import org.robolectric.RuntimeEnvironment
|
||||
import org.robolectric.annotation.Config
|
||||
import java.util.concurrent.ConcurrentLinkedQueue
|
||||
|
||||
private const val LIFECYCLE_TEST_TIMEOUT_MS = 8_000L
|
||||
private const val LIFECYCLE_CONNECT_CHALLENGE_FRAME =
|
||||
"""{"type":"event","event":"connect.challenge","payload":{"nonce":"android-test-nonce"}}"""
|
||||
|
||||
private class ReconnectDeviceAuthStore : DeviceAuthTokenStore {
|
||||
override fun loadEntry(
|
||||
deviceId: String,
|
||||
role: String,
|
||||
): DeviceAuthEntry? = null
|
||||
|
||||
override fun saveToken(
|
||||
deviceId: String,
|
||||
role: String,
|
||||
token: String,
|
||||
scopes: List<String>,
|
||||
) = Unit
|
||||
|
||||
override fun clearToken(
|
||||
deviceId: String,
|
||||
role: String,
|
||||
) = Unit
|
||||
}
|
||||
|
||||
private data class ReconnectHarness(
|
||||
val session: GatewaySession,
|
||||
val sessionJob: Job,
|
||||
)
|
||||
|
||||
private data class ReconnectServer(
|
||||
val server: MockWebServer,
|
||||
val sockets: ConcurrentLinkedQueue<WebSocket>,
|
||||
) {
|
||||
val port: Int
|
||||
get() = server.port
|
||||
|
||||
val requestCount: Int
|
||||
get() = server.requestCount
|
||||
|
||||
fun shutdown() {
|
||||
sockets.forEach { runCatching { it.cancel() } }
|
||||
runCatching { server.shutdown() }
|
||||
.onFailure { err ->
|
||||
if (err.message != "Gave up waiting for queue to shut down") throw err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@RunWith(RobolectricTestRunner::class)
|
||||
@Config(sdk = [34])
|
||||
class GatewaySessionReconnectTest {
|
||||
@Test
|
||||
fun connectToNewGatewayClosesActiveConnectionAndStartsReplacement() =
|
||||
runBlocking {
|
||||
val json = Json { ignoreUnknownKeys = true }
|
||||
val firstConnect = CompletableDeferred<Unit>()
|
||||
val firstClosed = CompletableDeferred<Unit>()
|
||||
val secondConnect = CompletableDeferred<Unit>()
|
||||
val secondClosed = CompletableDeferred<Unit>()
|
||||
val firstServer =
|
||||
startGatewayServer(
|
||||
json = json,
|
||||
onClosed = { firstClosed.complete(Unit) },
|
||||
) { webSocket, id, method ->
|
||||
if (method == "connect") {
|
||||
firstConnect.complete(Unit)
|
||||
webSocket.send(connectResponseFrame(id))
|
||||
}
|
||||
}
|
||||
val secondServer =
|
||||
startGatewayServer(
|
||||
json = json,
|
||||
onClosed = { secondClosed.complete(Unit) },
|
||||
) { webSocket, id, method ->
|
||||
if (method == "connect") {
|
||||
secondConnect.complete(Unit)
|
||||
webSocket.send(connectResponseFrame(id))
|
||||
}
|
||||
}
|
||||
val harness = createReconnectHarness()
|
||||
|
||||
try {
|
||||
connectNodeSession(harness.session, firstServer.port)
|
||||
withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { firstConnect.await() }
|
||||
|
||||
connectNodeSession(harness.session, secondServer.port)
|
||||
|
||||
withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { firstClosed.await() }
|
||||
withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { secondConnect.await() }
|
||||
assertEquals(1, secondServer.requestCount)
|
||||
harness.session.disconnect()
|
||||
withTimeout(LIFECYCLE_TEST_TIMEOUT_MS) { secondClosed.await() }
|
||||
} finally {
|
||||
shutdownReconnectHarness(harness, firstServer, secondServer)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun bootstrapNodePairingRequiredKeepsReconnectActive() {
|
||||
val error =
|
||||
@@ -113,4 +232,125 @@ class GatewaySessionReconnectTest {
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
private fun createReconnectHarness(): ReconnectHarness {
|
||||
val app = RuntimeEnvironment.getApplication()
|
||||
val sessionJob = SupervisorJob()
|
||||
val session =
|
||||
GatewaySession(
|
||||
scope = CoroutineScope(sessionJob + Dispatchers.Default),
|
||||
identityStore = DeviceIdentityStore(app),
|
||||
deviceAuthStore = ReconnectDeviceAuthStore(),
|
||||
onConnected = { _, _, _ -> },
|
||||
onDisconnected = { _ -> },
|
||||
onEvent = { _, _ -> },
|
||||
onInvoke = { GatewaySession.InvokeResult.ok("""{"handled":true}""") },
|
||||
)
|
||||
return ReconnectHarness(session = session, sessionJob = sessionJob)
|
||||
}
|
||||
|
||||
private suspend fun connectNodeSession(
|
||||
session: GatewaySession,
|
||||
port: Int,
|
||||
) {
|
||||
session.connect(
|
||||
endpoint =
|
||||
GatewayEndpoint(
|
||||
stableId = "manual|127.0.0.1|$port",
|
||||
name = "test",
|
||||
host = "127.0.0.1",
|
||||
port = port,
|
||||
tlsEnabled = false,
|
||||
),
|
||||
token = "test-token",
|
||||
bootstrapToken = null,
|
||||
password = null,
|
||||
options =
|
||||
GatewayConnectOptions(
|
||||
role = "node",
|
||||
scopes = listOf("node:invoke"),
|
||||
caps = emptyList(),
|
||||
commands = emptyList(),
|
||||
permissions = emptyMap(),
|
||||
client =
|
||||
GatewayClientInfo(
|
||||
id = "openclaw-android-test",
|
||||
displayName = "Android Test",
|
||||
version = "1.0.0-test",
|
||||
platform = "android",
|
||||
mode = "node",
|
||||
instanceId = "android-test-instance",
|
||||
deviceFamily = "android",
|
||||
modelIdentifier = "test",
|
||||
),
|
||||
),
|
||||
tls = null,
|
||||
)
|
||||
}
|
||||
|
||||
private suspend fun shutdownReconnectHarness(
|
||||
harness: ReconnectHarness,
|
||||
vararg servers: ReconnectServer,
|
||||
) {
|
||||
harness.session.disconnect()
|
||||
harness.sessionJob.cancelAndJoin()
|
||||
servers.forEach { it.shutdown() }
|
||||
}
|
||||
|
||||
private fun connectResponseFrame(id: String): String = """{"type":"res","id":"$id","ok":true,"payload":{"snapshot":{"sessionDefaults":{"mainSessionKey":"main"}}}}"""
|
||||
|
||||
private fun startGatewayServer(
|
||||
json: Json,
|
||||
onClosed: () -> Unit = {},
|
||||
onRequestFrame: (webSocket: WebSocket, id: String, method: String) -> Unit,
|
||||
): ReconnectServer {
|
||||
val sockets = ConcurrentLinkedQueue<WebSocket>()
|
||||
val server =
|
||||
MockWebServer().apply {
|
||||
dispatcher =
|
||||
object : Dispatcher() {
|
||||
override fun dispatch(request: RecordedRequest): MockResponse =
|
||||
MockResponse().withWebSocketUpgrade(
|
||||
object : WebSocketListener() {
|
||||
override fun onOpen(
|
||||
webSocket: WebSocket,
|
||||
response: Response,
|
||||
) {
|
||||
sockets += webSocket
|
||||
webSocket.send(LIFECYCLE_CONNECT_CHALLENGE_FRAME)
|
||||
}
|
||||
|
||||
override fun onMessage(
|
||||
webSocket: WebSocket,
|
||||
text: String,
|
||||
) {
|
||||
val frame = json.parseToJsonElement(text).jsonObject
|
||||
if (frame["type"]?.jsonPrimitive?.content != "req") return
|
||||
val id = frame["id"]?.jsonPrimitive?.content ?: return
|
||||
val method = frame["method"]?.jsonPrimitive?.content ?: return
|
||||
onRequestFrame(webSocket, id, method)
|
||||
}
|
||||
|
||||
override fun onClosing(
|
||||
webSocket: WebSocket,
|
||||
code: Int,
|
||||
reason: String,
|
||||
) {
|
||||
onClosed()
|
||||
}
|
||||
|
||||
override fun onClosed(
|
||||
webSocket: WebSocket,
|
||||
code: Int,
|
||||
reason: String,
|
||||
) {
|
||||
onClosed()
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
start()
|
||||
}
|
||||
return ReconnectServer(server = server, sockets = sockets)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user