fix(security): require explicit trust for first-time TLS pins

This commit is contained in:
Peter Steinberger
2026-02-14 17:47:13 +01:00
parent d714ac7797
commit 054366dea4
16 changed files with 549 additions and 76 deletions

View File

@@ -25,6 +25,7 @@ class MainViewModel(app: Application) : AndroidViewModel(app) {
val statusText: StateFlow<String> = runtime.statusText val statusText: StateFlow<String> = runtime.statusText
val serverName: StateFlow<String?> = runtime.serverName val serverName: StateFlow<String?> = runtime.serverName
val remoteAddress: StateFlow<String?> = runtime.remoteAddress val remoteAddress: StateFlow<String?> = runtime.remoteAddress
val pendingGatewayTrust: StateFlow<NodeRuntime.GatewayTrustPrompt?> = runtime.pendingGatewayTrust
val isForeground: StateFlow<Boolean> = runtime.isForeground val isForeground: StateFlow<Boolean> = runtime.isForeground
val seamColorArgb: StateFlow<Long> = runtime.seamColorArgb val seamColorArgb: StateFlow<Long> = runtime.seamColorArgb
val mainSessionKey: StateFlow<String> = runtime.mainSessionKey val mainSessionKey: StateFlow<String> = runtime.mainSessionKey
@@ -145,6 +146,14 @@ class MainViewModel(app: Application) : AndroidViewModel(app) {
runtime.disconnect() runtime.disconnect()
} }
fun acceptGatewayTrustPrompt() {
runtime.acceptGatewayTrustPrompt()
}
fun declineGatewayTrustPrompt() {
runtime.declineGatewayTrustPrompt()
}
fun handleCanvasA2UIActionFromWebView(payloadJson: String) { fun handleCanvasA2UIActionFromWebView(payloadJson: String) {
runtime.handleCanvasA2UIActionFromWebView(payloadJson) runtime.handleCanvasA2UIActionFromWebView(payloadJson)
} }

View File

@@ -15,6 +15,7 @@ import ai.openclaw.android.gateway.DeviceIdentityStore
import ai.openclaw.android.gateway.GatewayDiscovery import ai.openclaw.android.gateway.GatewayDiscovery
import ai.openclaw.android.gateway.GatewayEndpoint import ai.openclaw.android.gateway.GatewayEndpoint
import ai.openclaw.android.gateway.GatewaySession import ai.openclaw.android.gateway.GatewaySession
import ai.openclaw.android.gateway.probeGatewayTlsFingerprint
import ai.openclaw.android.node.* import ai.openclaw.android.node.*
import ai.openclaw.android.protocol.OpenClawCanvasA2UIAction import ai.openclaw.android.protocol.OpenClawCanvasA2UIAction
import ai.openclaw.android.voice.TalkModeManager import ai.openclaw.android.voice.TalkModeManager
@@ -166,12 +167,20 @@ class NodeRuntime(context: Context) {
private lateinit var gatewayEventHandler: GatewayEventHandler private lateinit var gatewayEventHandler: GatewayEventHandler
data class GatewayTrustPrompt(
val endpoint: GatewayEndpoint,
val fingerprintSha256: String,
)
private val _isConnected = MutableStateFlow(false) private val _isConnected = MutableStateFlow(false)
val isConnected: StateFlow<Boolean> = _isConnected.asStateFlow() val isConnected: StateFlow<Boolean> = _isConnected.asStateFlow()
private val _statusText = MutableStateFlow("Offline") private val _statusText = MutableStateFlow("Offline")
val statusText: StateFlow<String> = _statusText.asStateFlow() val statusText: StateFlow<String> = _statusText.asStateFlow()
private val _pendingGatewayTrust = MutableStateFlow<GatewayTrustPrompt?>(null)
val pendingGatewayTrust: StateFlow<GatewayTrustPrompt?> = _pendingGatewayTrust.asStateFlow()
private val _mainSessionKey = MutableStateFlow("main") private val _mainSessionKey = MutableStateFlow("main")
val mainSessionKey: StateFlow<String> = _mainSessionKey.asStateFlow() val mainSessionKey: StateFlow<String> = _mainSessionKey.asStateFlow()
@@ -419,6 +428,12 @@ class NodeRuntime(context: Context) {
val host = manualHost.value.trim() val host = manualHost.value.trim()
val port = manualPort.value val port = manualPort.value
if (host.isNotEmpty() && port in 1..65535) { if (host.isNotEmpty() && port in 1..65535) {
// Security: autoconnect only to previously trusted gateways (stored TLS pin).
if (!manualTls.value) return@collect
val stableId = GatewayEndpoint.manual(host = host, port = port).stableId
val storedFingerprint = prefs.loadGatewayTlsFingerprint(stableId)?.trim().orEmpty()
if (storedFingerprint.isEmpty()) return@collect
didAutoConnect = true didAutoConnect = true
connect(GatewayEndpoint.manual(host = host, port = port)) connect(GatewayEndpoint.manual(host = host, port = port))
} }
@@ -528,17 +543,42 @@ class NodeRuntime(context: Context) {
} }
fun connect(endpoint: GatewayEndpoint) { fun connect(endpoint: GatewayEndpoint) {
val tls = connectionManager.resolveTlsParams(endpoint)
if (tls?.required == true && tls.expectedFingerprint.isNullOrBlank()) {
// First-time TLS: capture fingerprint, ask user to verify out-of-band, then store and connect.
_statusText.value = "Verify gateway TLS fingerprint…"
scope.launch {
val fp = probeGatewayTlsFingerprint(endpoint.host, endpoint.port) ?: run {
_statusText.value = "Failed: can't read TLS fingerprint"
return@launch
}
_pendingGatewayTrust.value = GatewayTrustPrompt(endpoint = endpoint, fingerprintSha256 = fp)
}
return
}
connectedEndpoint = endpoint connectedEndpoint = endpoint
operatorStatusText = "Connecting…" operatorStatusText = "Connecting…"
nodeStatusText = "Connecting…" nodeStatusText = "Connecting…"
updateStatus() updateStatus()
val token = prefs.loadGatewayToken() val token = prefs.loadGatewayToken()
val password = prefs.loadGatewayPassword() val password = prefs.loadGatewayPassword()
val tls = connectionManager.resolveTlsParams(endpoint)
operatorSession.connect(endpoint, token, password, connectionManager.buildOperatorConnectOptions(), tls) operatorSession.connect(endpoint, token, password, connectionManager.buildOperatorConnectOptions(), tls)
nodeSession.connect(endpoint, token, password, connectionManager.buildNodeConnectOptions(), tls) nodeSession.connect(endpoint, token, password, connectionManager.buildNodeConnectOptions(), tls)
} }
fun acceptGatewayTrustPrompt() {
val prompt = _pendingGatewayTrust.value ?: return
_pendingGatewayTrust.value = null
prefs.saveGatewayTlsFingerprint(prompt.endpoint.stableId, prompt.fingerprintSha256)
connect(prompt.endpoint)
}
fun declineGatewayTrustPrompt() {
_pendingGatewayTrust.value = null
_statusText.value = "Offline"
}
private fun hasRecordAudioPermission(): Boolean { private fun hasRecordAudioPermission(): Boolean {
return ( return (
ContextCompat.checkSelfPermission(appContext, Manifest.permission.RECORD_AUDIO) == ContextCompat.checkSelfPermission(appContext, Manifest.permission.RECORD_AUDIO) ==
@@ -558,6 +598,7 @@ class NodeRuntime(context: Context) {
fun disconnect() { fun disconnect() {
connectedEndpoint = null connectedEndpoint = null
_pendingGatewayTrust.value = null
operatorSession.disconnect() operatorSession.disconnect()
nodeSession.disconnect() nodeSession.disconnect()
} }

View File

@@ -1,13 +1,20 @@
package ai.openclaw.android.gateway package ai.openclaw.android.gateway
import android.annotation.SuppressLint import android.annotation.SuppressLint
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.net.InetSocketAddress
import java.security.MessageDigest import java.security.MessageDigest
import java.security.SecureRandom import java.security.SecureRandom
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import javax.net.ssl.HttpsURLConnection
import javax.net.ssl.HostnameVerifier import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLContext import javax.net.ssl.SSLContext
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLSocketFactory import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.SNIHostName
import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManagerFactory import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager import javax.net.ssl.X509TrustManager
@@ -59,13 +66,72 @@ fun buildGatewayTlsConfig(
val context = SSLContext.getInstance("TLS") val context = SSLContext.getInstance("TLS")
context.init(null, arrayOf(trustManager), SecureRandom()) context.init(null, arrayOf(trustManager), SecureRandom())
val verifier =
if (expected != null || params.allowTOFU) {
// When pinning, we intentionally ignore hostname mismatch (service discovery often yields IPs).
HostnameVerifier { _, _ -> true }
} else {
HttpsURLConnection.getDefaultHostnameVerifier()
}
return GatewayTlsConfig( return GatewayTlsConfig(
sslSocketFactory = context.socketFactory, sslSocketFactory = context.socketFactory,
trustManager = trustManager, trustManager = trustManager,
hostnameVerifier = HostnameVerifier { _, _ -> true }, hostnameVerifier = verifier,
) )
} }
suspend fun probeGatewayTlsFingerprint(
host: String,
port: Int,
timeoutMs: Int = 3_000,
): String? {
val trimmedHost = host.trim()
if (trimmedHost.isEmpty()) return null
if (port !in 1..65535) return null
return withContext(Dispatchers.IO) {
val trustAll =
@SuppressLint("CustomX509TrustManager")
object : X509TrustManager {
override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {}
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {}
override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
}
val context = SSLContext.getInstance("TLS")
context.init(null, arrayOf(trustAll), SecureRandom())
val socket = (context.socketFactory.createSocket() as SSLSocket)
try {
socket.soTimeout = timeoutMs
socket.connect(InetSocketAddress(trimmedHost, port), timeoutMs)
// Best-effort SNI for hostnames (avoid crashing on IP literals).
try {
if (trimmedHost.any { it.isLetter() }) {
val params = SSLParameters()
params.serverNames = listOf(SNIHostName(trimmedHost))
socket.sslParameters = params
}
} catch (_: Throwable) {
// ignore
}
socket.startHandshake()
val cert = socket.session.peerCertificates.firstOrNull() as? X509Certificate ?: return@withContext null
sha256Hex(cert.encoded)
} catch (_: Throwable) {
null
} finally {
try {
socket.close()
} catch (_: Throwable) {
// ignore
}
}
}
}
private fun defaultTrustManager(): X509TrustManager { private fun defaultTrustManager(): X509TrustManager {
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
factory.init(null as java.security.KeyStore?) factory.init(null as java.security.KeyStore?)

View File

@@ -49,7 +49,7 @@ class ConnectionManager(
return GatewayTlsParams( return GatewayTlsParams(
required = true, required = true,
expectedFingerprint = null, expectedFingerprint = null,
allowTOFU = true, allowTOFU = false,
stableId = stableId, stableId = stableId,
) )
} }
@@ -70,7 +70,7 @@ class ConnectionManager(
return GatewayTlsParams( return GatewayTlsParams(
required = true, required = true,
expectedFingerprint = null, expectedFingerprint = null,
allowTOFU = true, allowTOFU = false,
stableId = stableId, stableId = stableId,
) )
} }

View File

@@ -34,6 +34,7 @@ import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.ExpandLess import androidx.compose.material.icons.filled.ExpandLess
import androidx.compose.material.icons.filled.ExpandMore import androidx.compose.material.icons.filled.ExpandMore
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.ListItem import androidx.compose.material3.ListItem
@@ -42,6 +43,7 @@ import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.RadioButton import androidx.compose.material3.RadioButton
import androidx.compose.material3.Switch import androidx.compose.material3.Switch
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
@@ -89,6 +91,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
val remoteAddress by viewModel.remoteAddress.collectAsState() val remoteAddress by viewModel.remoteAddress.collectAsState()
val gateways by viewModel.gateways.collectAsState() val gateways by viewModel.gateways.collectAsState()
val discoveryStatusText by viewModel.discoveryStatusText.collectAsState() val discoveryStatusText by viewModel.discoveryStatusText.collectAsState()
val pendingTrust by viewModel.pendingGatewayTrust.collectAsState()
val listState = rememberLazyListState() val listState = rememberLazyListState()
val (wakeWordsText, setWakeWordsText) = remember { mutableStateOf("") } val (wakeWordsText, setWakeWordsText) = remember { mutableStateOf("") }
@@ -112,6 +115,31 @@ fun SettingsSheet(viewModel: MainViewModel) {
} }
} }
if (pendingTrust != null) {
val prompt = pendingTrust!!
AlertDialog(
onDismissRequest = { viewModel.declineGatewayTrustPrompt() },
title = { Text("Trust this gateway?") },
text = {
Text(
"First-time TLS connection.\n\n" +
"Verify this SHA-256 fingerprint out-of-band before trusting:\n" +
prompt.fingerprintSha256,
)
},
confirmButton = {
TextButton(onClick = { viewModel.acceptGatewayTrustPrompt() }) {
Text("Trust and connect")
}
},
dismissButton = {
TextButton(onClick = { viewModel.declineGatewayTrustPrompt() }) {
Text("Cancel")
}
},
)
}
LaunchedEffect(wakeWords) { setWakeWordsText(wakeWords.joinToString(", ")) } LaunchedEffect(wakeWords) { setWakeWordsText(wakeWords.joinToString(", ")) }
val commitWakeWords = { val commitWakeWords = {
val parsed = WakeWords.parseIfChanged(wakeWordsText, wakeWords) val parsed = WakeWords.parseIfChanged(wakeWordsText, wakeWords)

View File

@@ -49,7 +49,7 @@ class ConnectionManagerTest {
) )
assertNull(params?.expectedFingerprint) assertNull(params?.expectedFingerprint)
assertEquals(true, params?.allowTOFU) assertEquals(false, params?.allowTOFU)
} }
@Test @Test
@@ -71,7 +71,6 @@ class ConnectionManagerTest {
manualTlsEnabled = true, manualTlsEnabled = true,
) )
assertNull(on?.expectedFingerprint) assertNull(on?.expectedFingerprint)
assertEquals(true, on?.allowTOFU) assertEquals(false, on?.allowTOFU)
} }
} }

View File

@@ -2,6 +2,7 @@ import AVFoundation
import Contacts import Contacts
import CoreLocation import CoreLocation
import CoreMotion import CoreMotion
import CryptoKit
import EventKit import EventKit
import Foundation import Foundation
import OpenClawKit import OpenClawKit
@@ -9,6 +10,7 @@ import Network
import Observation import Observation
import Photos import Photos
import ReplayKit import ReplayKit
import Security
import Speech import Speech
import SwiftUI import SwiftUI
import UIKit import UIKit
@@ -16,14 +18,27 @@ import UIKit
@MainActor @MainActor
@Observable @Observable
final class GatewayConnectionController { final class GatewayConnectionController {
struct TrustPrompt: Identifiable, Equatable {
let stableID: String
let gatewayName: String
let host: String
let port: Int
let fingerprintSha256: String
let isManual: Bool
var id: String { self.stableID }
}
private(set) var gateways: [GatewayDiscoveryModel.DiscoveredGateway] = [] private(set) var gateways: [GatewayDiscoveryModel.DiscoveredGateway] = []
private(set) var discoveryStatusText: String = "Idle" private(set) var discoveryStatusText: String = "Idle"
private(set) var discoveryDebugLog: [GatewayDiscoveryModel.DebugLogEntry] = [] private(set) var discoveryDebugLog: [GatewayDiscoveryModel.DebugLogEntry] = []
private(set) var pendingTrustPrompt: TrustPrompt?
private let discovery = GatewayDiscoveryModel() private let discovery = GatewayDiscoveryModel()
private weak var appModel: NodeAppModel? private weak var appModel: NodeAppModel?
private var didAutoConnect = false private var didAutoConnect = false
private var pendingServiceResolvers: [String: GatewayServiceResolver] = [:] private var pendingServiceResolvers: [String: GatewayServiceResolver] = [:]
private var pendingTrustConnect: (url: URL, stableID: String, isManual: Bool)?
init(appModel: NodeAppModel, startDiscovery: Bool = true) { init(appModel: NodeAppModel, startDiscovery: Bool = true) {
self.appModel = appModel self.appModel = appModel
@@ -58,12 +73,11 @@ final class GatewayConnectionController {
} }
func connect(_ gateway: GatewayDiscoveryModel.DiscoveredGateway) async { func connect(_ gateway: GatewayDiscoveryModel.DiscoveredGateway) async {
await self.connectDiscoveredGateway(gateway, allowTOFU: true) await self.connectDiscoveredGateway(gateway)
} }
private func connectDiscoveredGateway( private func connectDiscoveredGateway(
_ gateway: GatewayDiscoveryModel.DiscoveredGateway, _ gateway: GatewayDiscoveryModel.DiscoveredGateway) async
allowTOFU: Bool) async
{ {
let instanceId = UserDefaults.standard.string(forKey: "node.instanceId")? let instanceId = UserDefaults.standard.string(forKey: "node.instanceId")?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? "" .trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
@@ -73,21 +87,43 @@ final class GatewayConnectionController {
// Resolve the service endpoint (SRV/A/AAAA). TXT is unauthenticated; do not route via TXT. // Resolve the service endpoint (SRV/A/AAAA). TXT is unauthenticated; do not route via TXT.
guard let target = await self.resolveServiceEndpoint(gateway.endpoint) else { return } guard let target = await self.resolveServiceEndpoint(gateway.endpoint) else { return }
let tlsParams = self.resolveDiscoveredTLSParams(gateway: gateway, allowTOFU: allowTOFU) let stableID = gateway.stableID
// Discovery is a LAN operation; refuse unauthenticated plaintext connects.
let tlsRequired = true
let stored = GatewayTLSStore.loadFingerprint(stableID: stableID)
guard gateway.tlsEnabled || stored != nil else { return }
if tlsRequired, stored == nil {
guard let url = self.buildGatewayURL(host: target.host, port: target.port, useTLS: true)
else { return }
guard let fp = await self.probeTLSFingerprint(url: url) else { return }
self.pendingTrustConnect = (url: url, stableID: stableID, isManual: false)
self.pendingTrustPrompt = TrustPrompt(
stableID: stableID,
gatewayName: gateway.name,
host: target.host,
port: target.port,
fingerprintSha256: fp,
isManual: false)
self.appModel?.gatewayStatusText = "Verify gateway TLS fingerprint"
return
}
let tlsParams = stored.map { fp in
GatewayTLSParams(required: true, expectedFingerprint: fp, allowTOFU: false, storeKey: stableID)
}
guard let url = self.buildGatewayURL( guard let url = self.buildGatewayURL(
host: target.host, host: target.host,
port: target.port, port: target.port,
useTLS: tlsParams?.required == true) useTLS: tlsParams?.required == true)
else { return } else { return }
GatewaySettingsStore.saveLastGatewayConnection( GatewaySettingsStore.saveLastGatewayConnectionDiscovered(stableID: stableID, useTLS: true)
host: target.host,
port: target.port,
useTLS: tlsParams?.required == true,
stableID: gateway.stableID)
self.didAutoConnect = true self.didAutoConnect = true
self.startAutoConnect( self.startAutoConnect(
url: url, url: url,
gatewayStableID: gateway.stableID, gatewayStableID: stableID,
tls: tlsParams, tls: tlsParams,
token: token, token: token,
password: password) password: password)
@@ -102,19 +138,34 @@ final class GatewayConnectionController {
guard let resolvedPort = self.resolveManualPort(host: host, port: port, useTLS: resolvedUseTLS) guard let resolvedPort = self.resolveManualPort(host: host, port: port, useTLS: resolvedUseTLS)
else { return } else { return }
let stableID = self.manualStableID(host: host, port: resolvedPort) let stableID = self.manualStableID(host: host, port: resolvedPort)
let tlsParams = self.resolveManualTLSParams( let stored = GatewayTLSStore.loadFingerprint(stableID: stableID)
stableID: stableID, if resolvedUseTLS, stored == nil {
tlsEnabled: resolvedUseTLS, guard let url = self.buildGatewayURL(host: host, port: resolvedPort, useTLS: true) else { return }
allowTOFUReset: self.shouldForceTLS(host: host)) guard let fp = await self.probeTLSFingerprint(url: url) else { return }
self.pendingTrustConnect = (url: url, stableID: stableID, isManual: true)
self.pendingTrustPrompt = TrustPrompt(
stableID: stableID,
gatewayName: "\(host):\(resolvedPort)",
host: host,
port: resolvedPort,
fingerprintSha256: fp,
isManual: true)
self.appModel?.gatewayStatusText = "Verify gateway TLS fingerprint"
return
}
let tlsParams = stored.map { fp in
GatewayTLSParams(required: true, expectedFingerprint: fp, allowTOFU: false, storeKey: stableID)
}
guard let url = self.buildGatewayURL( guard let url = self.buildGatewayURL(
host: host, host: host,
port: resolvedPort, port: resolvedPort,
useTLS: tlsParams?.required == true) useTLS: tlsParams?.required == true)
else { return } else { return }
GatewaySettingsStore.saveLastGatewayConnection( GatewaySettingsStore.saveLastGatewayConnectionManual(
host: host, host: host,
port: resolvedPort, port: resolvedPort,
useTLS: tlsParams?.required == true, useTLS: resolvedUseTLS && tlsParams != nil,
stableID: stableID) stableID: stableID)
self.didAutoConnect = true self.didAutoConnect = true
self.startAutoConnect( self.startAutoConnect(
@@ -127,36 +178,63 @@ final class GatewayConnectionController {
func connectLastKnown() async { func connectLastKnown() async {
guard let last = GatewaySettingsStore.loadLastGatewayConnection() else { return } guard let last = GatewaySettingsStore.loadLastGatewayConnection() else { return }
switch last {
case let .manual(host, port, useTLS, _):
await self.connectManual(host: host, port: port, useTLS: useTLS)
case let .discovered(stableID, _):
guard let gateway = self.gateways.first(where: { $0.stableID == stableID }) else { return }
await self.connectDiscoveredGateway(gateway)
}
}
func clearPendingTrustPrompt() {
self.pendingTrustPrompt = nil
self.pendingTrustConnect = nil
}
func acceptPendingTrustPrompt() async {
guard let pending = self.pendingTrustConnect,
let prompt = self.pendingTrustPrompt,
pending.stableID == prompt.stableID
else { return }
GatewayTLSStore.saveFingerprint(prompt.fingerprintSha256, stableID: pending.stableID)
self.clearPendingTrustPrompt()
if pending.isManual {
GatewaySettingsStore.saveLastGatewayConnectionManual(
host: prompt.host,
port: prompt.port,
useTLS: true,
stableID: pending.stableID)
} else {
GatewaySettingsStore.saveLastGatewayConnectionDiscovered(stableID: pending.stableID, useTLS: true)
}
let instanceId = UserDefaults.standard.string(forKey: "node.instanceId")? let instanceId = UserDefaults.standard.string(forKey: "node.instanceId")?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? "" .trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
let token = GatewaySettingsStore.loadGatewayToken(instanceId: instanceId) let token = GatewaySettingsStore.loadGatewayToken(instanceId: instanceId)
let password = GatewaySettingsStore.loadGatewayPassword(instanceId: instanceId) let password = GatewaySettingsStore.loadGatewayPassword(instanceId: instanceId)
let resolvedUseTLS = last.useTLS let tlsParams = GatewayTLSParams(
let tlsParams = self.resolveManualTLSParams( required: true,
stableID: last.stableID, expectedFingerprint: prompt.fingerprintSha256,
tlsEnabled: resolvedUseTLS, allowTOFU: false,
allowTOFUReset: self.shouldForceTLS(host: last.host)) storeKey: pending.stableID)
guard let url = self.buildGatewayURL(
host: last.host,
port: last.port,
useTLS: tlsParams?.required == true)
else { return }
if resolvedUseTLS != last.useTLS {
GatewaySettingsStore.saveLastGatewayConnection(
host: last.host,
port: last.port,
useTLS: resolvedUseTLS,
stableID: last.stableID)
}
self.didAutoConnect = true self.didAutoConnect = true
self.startAutoConnect( self.startAutoConnect(
url: url, url: pending.url,
gatewayStableID: last.stableID, gatewayStableID: pending.stableID,
tls: tlsParams, tls: tlsParams,
token: token, token: token,
password: password) password: password)
} }
func declinePendingTrustPrompt() {
self.clearPendingTrustPrompt()
self.appModel?.gatewayStatusText = "Offline"
}
private func updateFromDiscovery() { private func updateFromDiscovery() {
let newGateways = self.discovery.gateways let newGateways = self.discovery.gateways
self.gateways = newGateways self.gateways = newGateways
@@ -233,25 +311,30 @@ final class GatewayConnectionController {
} }
if let lastKnown = GatewaySettingsStore.loadLastGatewayConnection() { if let lastKnown = GatewaySettingsStore.loadLastGatewayConnection() {
let resolvedUseTLS = lastKnown.useTLS || self.shouldForceTLS(host: lastKnown.host) if case let .manual(host, port, useTLS, stableID) = lastKnown {
let tlsParams = self.resolveManualTLSParams( let resolvedUseTLS = useTLS || self.shouldForceTLS(host: host)
stableID: lastKnown.stableID, let stored = GatewayTLSStore.loadFingerprint(stableID: stableID)
tlsEnabled: resolvedUseTLS, let tlsParams = stored.map { fp in
allowTOFUReset: self.shouldForceTLS(host: lastKnown.host)) GatewayTLSParams(required: true, expectedFingerprint: fp, allowTOFU: false, storeKey: stableID)
guard let url = self.buildGatewayURL( }
host: lastKnown.host, guard let url = self.buildGatewayURL(
port: lastKnown.port, host: host,
useTLS: tlsParams?.required == true) port: port,
else { return } useTLS: resolvedUseTLS && tlsParams != nil)
else { return }
self.didAutoConnect = true // Security: autoconnect only to previously trusted gateways (stored TLS pin).
self.startAutoConnect( guard tlsParams != nil else { return }
url: url,
gatewayStableID: lastKnown.stableID, self.didAutoConnect = true
tls: tlsParams, self.startAutoConnect(
token: token, url: url,
password: password) gatewayStableID: stableID,
return tls: tlsParams,
token: token,
password: password)
return
}
} }
let preferredStableID = defaults.string(forKey: "gateway.preferredStableID")? let preferredStableID = defaults.string(forKey: "gateway.preferredStableID")?
@@ -270,7 +353,7 @@ final class GatewayConnectionController {
self.didAutoConnect = true self.didAutoConnect = true
Task { [weak self] in Task { [weak self] in
guard let self else { return } guard let self else { return }
await self.connectDiscoveredGateway(target, allowTOFU: false) await self.connectDiscoveredGateway(target)
} }
return return
} }
@@ -282,7 +365,7 @@ final class GatewayConnectionController {
self.didAutoConnect = true self.didAutoConnect = true
Task { [weak self] in Task { [weak self] in
guard let self else { return } guard let self else { return }
await self.connectDiscoveredGateway(gateway, allowTOFU: false) await self.connectDiscoveredGateway(gateway)
} }
return return
} }
@@ -359,7 +442,7 @@ final class GatewayConnectionController {
return GatewayTLSParams( return GatewayTLSParams(
required: true, required: true,
expectedFingerprint: nil, expectedFingerprint: nil,
allowTOFU: allowTOFU, allowTOFU: false,
storeKey: stableID) storeKey: stableID)
} }
@@ -376,13 +459,22 @@ final class GatewayConnectionController {
return GatewayTLSParams( return GatewayTLSParams(
required: true, required: true,
expectedFingerprint: stored, expectedFingerprint: stored,
allowTOFU: stored == nil || allowTOFUReset, allowTOFU: false,
storeKey: stableID) storeKey: stableID)
} }
return nil return nil
} }
private func probeTLSFingerprint(url: URL) async -> String? {
await withCheckedContinuation { continuation in
let probe = GatewayTLSFingerprintProbe(url: url, timeoutSeconds: 3) { fp in
continuation.resume(returning: fp)
}
probe.start()
}
}
private func resolveServiceEndpoint(_ endpoint: NWEndpoint) async -> (host: String, port: Int)? { private func resolveServiceEndpoint(_ endpoint: NWEndpoint) async -> (host: String, port: Int)? {
guard case let .service(name, type, domain, _) = endpoint else { return nil } guard case let .service(name, type, domain, _) = endpoint else { return nil }
let key = "\(domain)|\(type)|\(name)" let key = "\(domain)|\(type)|\(name)"
@@ -692,3 +784,71 @@ extension GatewayConnectionController {
} }
} }
#endif #endif
private final class GatewayTLSFingerprintProbe: NSObject, URLSessionDelegate {
private let url: URL
private let timeoutSeconds: Double
private let onComplete: (String?) -> Void
private var didFinish = false
private var session: URLSession?
private var task: URLSessionWebSocketTask?
init(url: URL, timeoutSeconds: Double, onComplete: @escaping (String?) -> Void) {
self.url = url
self.timeoutSeconds = timeoutSeconds
self.onComplete = onComplete
}
func start() {
let config = URLSessionConfiguration.ephemeral
config.timeoutIntervalForRequest = self.timeoutSeconds
config.timeoutIntervalForResource = self.timeoutSeconds
let session = URLSession(configuration: config, delegate: self, delegateQueue: nil)
self.session = session
let task = session.webSocketTask(with: self.url)
self.task = task
task.resume()
DispatchQueue.global(qos: .utility).asyncAfter(deadline: .now() + self.timeoutSeconds) { [weak self] in
self?.finish(nil)
}
}
func urlSession(
_ session: URLSession,
didReceive challenge: URLAuthenticationChallenge,
completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
) {
guard challenge.protectionSpace.authenticationMethod == NSURLAuthenticationMethodServerTrust,
let trust = challenge.protectionSpace.serverTrust
else {
completionHandler(.performDefaultHandling, nil)
return
}
let fp = GatewayTLSFingerprintProbe.certificateFingerprint(trust)
completionHandler(.cancelAuthenticationChallenge, nil)
self.finish(fp)
}
private func finish(_ fingerprint: String?) {
objc_sync_enter(self)
defer { objc_sync_exit(self) }
guard !self.didFinish else { return }
self.didFinish = true
self.task?.cancel(with: .goingAway, reason: nil)
self.session?.invalidateAndCancel()
self.onComplete(fingerprint)
}
private static func certificateFingerprint(_ trust: SecTrust) -> String? {
guard let chain = SecTrustCopyCertificateChain(trust) as? [SecCertificate],
let cert = chain.first
else {
return nil
}
let data = SecCertificateCopyData(cert) as Data
let digest = SHA256.hash(data: data)
return digest.map { String(format: "%02x", $0) }.joined()
}
}

View File

@@ -13,6 +13,7 @@ enum GatewaySettingsStore {
private static let manualPortDefaultsKey = "gateway.manual.port" private static let manualPortDefaultsKey = "gateway.manual.port"
private static let manualTlsDefaultsKey = "gateway.manual.tls" private static let manualTlsDefaultsKey = "gateway.manual.tls"
private static let discoveryDebugLogsDefaultsKey = "gateway.discovery.debugLogs" private static let discoveryDebugLogsDefaultsKey = "gateway.discovery.debugLogs"
private static let lastGatewayKindDefaultsKey = "gateway.last.kind"
private static let lastGatewayHostDefaultsKey = "gateway.last.host" private static let lastGatewayHostDefaultsKey = "gateway.last.host"
private static let lastGatewayPortDefaultsKey = "gateway.last.port" private static let lastGatewayPortDefaultsKey = "gateway.last.port"
private static let lastGatewayTlsDefaultsKey = "gateway.last.tls" private static let lastGatewayTlsDefaultsKey = "gateway.last.tls"
@@ -114,25 +115,73 @@ enum GatewaySettingsStore {
account: self.gatewayPasswordAccount(instanceId: instanceId)) account: self.gatewayPasswordAccount(instanceId: instanceId))
} }
static func saveLastGatewayConnection(host: String, port: Int, useTLS: Bool, stableID: String) { enum LastGatewayConnection: Equatable {
case manual(host: String, port: Int, useTLS: Bool, stableID: String)
case discovered(stableID: String, useTLS: Bool)
var stableID: String {
switch self {
case let .manual(_, _, _, stableID):
return stableID
case let .discovered(stableID, _):
return stableID
}
}
var useTLS: Bool {
switch self {
case let .manual(_, _, useTLS, _):
return useTLS
case let .discovered(_, useTLS):
return useTLS
}
}
}
private enum LastGatewayKind: String {
case manual
case discovered
}
static func saveLastGatewayConnectionManual(host: String, port: Int, useTLS: Bool, stableID: String) {
let defaults = UserDefaults.standard let defaults = UserDefaults.standard
defaults.set(LastGatewayKind.manual.rawValue, forKey: self.lastGatewayKindDefaultsKey)
defaults.set(host, forKey: self.lastGatewayHostDefaultsKey) defaults.set(host, forKey: self.lastGatewayHostDefaultsKey)
defaults.set(port, forKey: self.lastGatewayPortDefaultsKey) defaults.set(port, forKey: self.lastGatewayPortDefaultsKey)
defaults.set(useTLS, forKey: self.lastGatewayTlsDefaultsKey) defaults.set(useTLS, forKey: self.lastGatewayTlsDefaultsKey)
defaults.set(stableID, forKey: self.lastGatewayStableIDDefaultsKey) defaults.set(stableID, forKey: self.lastGatewayStableIDDefaultsKey)
} }
static func loadLastGatewayConnection() -> (host: String, port: Int, useTLS: Bool, stableID: String)? { static func saveLastGatewayConnectionDiscovered(stableID: String, useTLS: Bool) {
let defaults = UserDefaults.standard let defaults = UserDefaults.standard
defaults.set(LastGatewayKind.discovered.rawValue, forKey: self.lastGatewayKindDefaultsKey)
defaults.removeObject(forKey: self.lastGatewayHostDefaultsKey)
defaults.removeObject(forKey: self.lastGatewayPortDefaultsKey)
defaults.set(useTLS, forKey: self.lastGatewayTlsDefaultsKey)
defaults.set(stableID, forKey: self.lastGatewayStableIDDefaultsKey)
}
static func loadLastGatewayConnection() -> LastGatewayConnection? {
let defaults = UserDefaults.standard
let stableID = defaults.string(forKey: self.lastGatewayStableIDDefaultsKey)?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
guard !stableID.isEmpty else { return nil }
let useTLS = defaults.bool(forKey: self.lastGatewayTlsDefaultsKey)
let kindRaw = defaults.string(forKey: self.lastGatewayKindDefaultsKey)?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
let kind = LastGatewayKind(rawValue: kindRaw) ?? .manual
if kind == .discovered {
return .discovered(stableID: stableID, useTLS: useTLS)
}
let host = defaults.string(forKey: self.lastGatewayHostDefaultsKey)? let host = defaults.string(forKey: self.lastGatewayHostDefaultsKey)?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? "" .trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
let port = defaults.integer(forKey: self.lastGatewayPortDefaultsKey) let port = defaults.integer(forKey: self.lastGatewayPortDefaultsKey)
let useTLS = defaults.bool(forKey: self.lastGatewayTlsDefaultsKey)
let stableID = defaults.string(forKey: self.lastGatewayStableIDDefaultsKey)?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
guard !host.isEmpty, port > 0, port <= 65535, !stableID.isEmpty else { return nil } // Back-compat: older builds persisted manual-style host/port without a kind marker.
return (host: host, port: port, useTLS: useTLS, stableID: stableID) guard !host.isEmpty, port > 0, port <= 65535 else { return nil }
return .manual(host: host, port: port, useTLS: useTLS, stableID: stableID)
} }
static func loadGatewayClientIdOverride(stableID: String) -> String? { static func loadGatewayClientIdOverride(stableID: String) -> String? {

View File

@@ -0,0 +1,42 @@
import SwiftUI
struct GatewayTrustPromptAlert: ViewModifier {
@Environment(GatewayConnectionController.self) private var gatewayController: GatewayConnectionController
private var promptBinding: Binding<GatewayConnectionController.TrustPrompt?> {
Binding(
get: { self.gatewayController.pendingTrustPrompt },
set: { newValue in
if newValue == nil {
self.gatewayController.clearPendingTrustPrompt()
}
})
}
func body(content: Content) -> some View {
content.alert(item: self.promptBinding) { prompt in
Alert(
title: Text("Trust this gateway?"),
message: Text(
"""
First-time TLS connection.
Verify this SHA-256 fingerprint out-of-band before trusting:
\(prompt.fingerprintSha256)
"""),
primaryButton: .cancel(Text("Cancel")) {
self.gatewayController.declinePendingTrustPrompt()
},
secondaryButton: .default(Text("Trust and connect")) {
Task { await self.gatewayController.acceptPendingTrustPrompt() }
})
}
}
}
extension View {
func gatewayTrustPromptAlert() -> some View {
self.modifier(GatewayTrustPromptAlert())
}
}

View File

@@ -21,6 +21,7 @@ struct GatewayOnboardingView: View {
} }
.navigationTitle("Connect Gateway") .navigationTitle("Connect Gateway")
} }
.gatewayTrustPromptAlert()
} }
} }

View File

@@ -52,6 +52,7 @@ struct RootCanvas: View {
CameraFlashOverlay(nonce: self.appModel.cameraFlashNonce) CameraFlashOverlay(nonce: self.appModel.cameraFlashNonce)
} }
} }
.gatewayTrustPromptAlert()
.sheet(item: self.$presentedSheet) { sheet in .sheet(item: self.$presentedSheet) { sheet in
switch sheet { switch sheet {
case .settings: case .settings:

View File

@@ -376,6 +376,7 @@ struct SettingsTab: View {
} }
} }
} }
.gatewayTrustPromptAlert()
} }
@ViewBuilder @ViewBuilder
@@ -388,11 +389,13 @@ struct SettingsTab: View {
.font(.footnote) .font(.footnote)
.foregroundStyle(.secondary) .foregroundStyle(.secondary)
if let lastKnown = GatewaySettingsStore.loadLastGatewayConnection() { if let lastKnown = GatewaySettingsStore.loadLastGatewayConnection(),
case let .manual(host, port, _, _) = lastKnown
{
Button { Button {
Task { await self.connectLastKnown() } Task { await self.connectLastKnown() }
} label: { } label: {
self.lastKnownButtonLabel(host: lastKnown.host, port: lastKnown.port) self.lastKnownButtonLabel(host: host, port: port)
} }
.disabled(self.connectingGatewayID != nil) .disabled(self.connectingGatewayID != nil)
.buttonStyle(.borderedProminent) .buttonStyle(.borderedProminent)

View File

@@ -62,7 +62,7 @@ import Testing
let params = controller._test_resolveDiscoveredTLSParams(gateway: gateway, allowTOFU: true) let params = controller._test_resolveDiscoveredTLSParams(gateway: gateway, allowTOFU: true)
#expect(params?.expectedFingerprint == nil) #expect(params?.expectedFingerprint == nil)
#expect(params?.allowTOFU == true) #expect(params?.allowTOFU == false)
} }
@Test @MainActor func autoconnectRequiresStoredPinForDiscoveredGateways() async { @Test @MainActor func autoconnectRequiresStoredPinForDiscoveredGateways() async {
@@ -77,6 +77,7 @@ import Testing
defaults.removeObject(forKey: "gateway.last.port") defaults.removeObject(forKey: "gateway.last.port")
defaults.removeObject(forKey: "gateway.last.tls") defaults.removeObject(forKey: "gateway.last.tls")
defaults.removeObject(forKey: "gateway.last.stableID") defaults.removeObject(forKey: "gateway.last.stableID")
defaults.removeObject(forKey: "gateway.last.kind")
defaults.removeObject(forKey: "gateway.preferredStableID") defaults.removeObject(forKey: "gateway.preferredStableID")
defaults.set(stableID, forKey: "gateway.lastDiscoveredStableID") defaults.set(stableID, forKey: "gateway.lastDiscoveredStableID")
@@ -102,4 +103,3 @@ import Testing
#expect(controller._test_didAutoConnect() == false) #expect(controller._test_didAutoConnect() == false)
} }
} }

View File

@@ -124,4 +124,76 @@ private func restoreKeychain(_ snapshot: [KeychainEntry: String?]) {
#expect(defaults.string(forKey: "gateway.preferredStableID") == "preferred-from-keychain") #expect(defaults.string(forKey: "gateway.preferredStableID") == "preferred-from-keychain")
#expect(defaults.string(forKey: "gateway.lastDiscoveredStableID") == "last-from-keychain") #expect(defaults.string(forKey: "gateway.lastDiscoveredStableID") == "last-from-keychain")
} }
@Test func lastGateway_manualRoundTrip() {
let keys = [
"gateway.last.kind",
"gateway.last.host",
"gateway.last.port",
"gateway.last.tls",
"gateway.last.stableID",
]
let snapshot = snapshotDefaults(keys)
defer { restoreDefaults(snapshot) }
GatewaySettingsStore.saveLastGatewayConnectionManual(
host: "example.com",
port: 443,
useTLS: true,
stableID: "manual|example.com|443")
let loaded = GatewaySettingsStore.loadLastGatewayConnection()
#expect(loaded == .manual(host: "example.com", port: 443, useTLS: true, stableID: "manual|example.com|443"))
}
@Test func lastGateway_discoveredDoesNotPersistResolvedHostPort() {
let keys = [
"gateway.last.kind",
"gateway.last.host",
"gateway.last.port",
"gateway.last.tls",
"gateway.last.stableID",
]
let snapshot = snapshotDefaults(keys)
defer { restoreDefaults(snapshot) }
// Simulate a prior manual record that included host/port.
applyDefaults([
"gateway.last.host": "10.0.0.99",
"gateway.last.port": 18789,
"gateway.last.tls": true,
"gateway.last.stableID": "manual|10.0.0.99|18789",
"gateway.last.kind": "manual",
])
GatewaySettingsStore.saveLastGatewayConnectionDiscovered(stableID: "gw|abc", useTLS: true)
let defaults = UserDefaults.standard
#expect(defaults.object(forKey: "gateway.last.host") == nil)
#expect(defaults.object(forKey: "gateway.last.port") == nil)
#expect(GatewaySettingsStore.loadLastGatewayConnection() == .discovered(stableID: "gw|abc", useTLS: true))
}
@Test func lastGateway_backCompat_manualLoadsWhenKindMissing() {
let keys = [
"gateway.last.kind",
"gateway.last.host",
"gateway.last.port",
"gateway.last.tls",
"gateway.last.stableID",
]
let snapshot = snapshotDefaults(keys)
defer { restoreDefaults(snapshot) }
applyDefaults([
"gateway.last.kind": nil,
"gateway.last.host": "example.org",
"gateway.last.port": 18789,
"gateway.last.tls": false,
"gateway.last.stableID": "manual|example.org|18789",
])
let loaded = GatewaySettingsStore.loadLastGatewayConnection()
#expect(loaded == .manual(host: "example.org", port: 18789, useTLS: false, stableID: "manual|example.org|18789"))
}
} }

View File

@@ -105,6 +105,7 @@ Security notes:
- Bonjour/mDNS TXT records are **unauthenticated**. Clients must not treat TXT as authoritative routing. - Bonjour/mDNS TXT records are **unauthenticated**. Clients must not treat TXT as authoritative routing.
- Clients should route using the resolved service endpoint (SRV + A/AAAA). Treat `lanHost`, `tailnetDns`, `gatewayPort`, and `gatewayTlsSha256` as hints only. - Clients should route using the resolved service endpoint (SRV + A/AAAA). Treat `lanHost`, `tailnetDns`, `gatewayPort`, and `gatewayTlsSha256` as hints only.
- TLS pinning must never allow an advertised `gatewayTlsSha256` to override a previously stored pin. - TLS pinning must never allow an advertised `gatewayTlsSha256` to override a previously stored pin.
- iOS/Android nodes should treat discovery-based direct connects as **TLS-only** and require explicit user confirmation before trusting a first-time fingerprint.
## Debugging on macOS ## Debugging on macOS

View File

@@ -72,7 +72,8 @@ Security notes:
- Bonjour/mDNS TXT records are **unauthenticated**. Clients must treat TXT values as UX hints only. - Bonjour/mDNS TXT records are **unauthenticated**. Clients must treat TXT values as UX hints only.
- Routing (host/port) should prefer the **resolved service endpoint** (SRV + A/AAAA) over TXT-provided `lanHost`, `tailnetDns`, or `gatewayPort`. - Routing (host/port) should prefer the **resolved service endpoint** (SRV + A/AAAA) over TXT-provided `lanHost`, `tailnetDns`, or `gatewayPort`.
- TLS pinning must never allow an advertised `gatewayTlsSha256` to override a previously stored pin. For first-time connections, require explicit user intent (TOFU or other out-of-band verification). - TLS pinning must never allow an advertised `gatewayTlsSha256` to override a previously stored pin.
- iOS/Android nodes should treat discovery-based direct connects as **TLS-only** and require an explicit “trust this fingerprint” confirmation before storing a first-time pin (out-of-band verification).
Disable/override: Disable/override: