Skip to content

Commit

Permalink
Improve Sendable conformance for HBWebSocket (#30)
Browse files Browse the repository at this point in the history
* Start of work to separate auto-ping from websocket

* Finish making HBWebSocket Sendable

* comment, value -> wrapped
  • Loading branch information
adam-fowler authored Oct 3, 2023
1 parent 90b7605 commit 28d5056
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 68 deletions.
3 changes: 3 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ let package = Package(
.library(name: "HummingbirdWSCore", targets: ["HummingbirdWSCore"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"),
.package(url: "https://github.com/hummingbird-project/hummingbird-core.git", from: "1.1.0"),
.package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "1.4.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.32.1"),
Expand All @@ -32,11 +33,13 @@ let package = Package(
]),
.target(name: "HummingbirdWebSocket", dependencies: [
.byName(name: "HummingbirdWSCore"),
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "Hummingbird", package: "hummingbird"),
]),
.testTarget(name: "HummingbirdWebSocketTests", dependencies: [
.byName(name: "HummingbirdWebSocket"),
.byName(name: "HummingbirdWSClient"),
.product(name: "Atomics", package: "swift-atomics"),
]),
]
)
180 changes: 112 additions & 68 deletions Sources/HummingbirdWSCore/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,132 @@
//
//===----------------------------------------------------------------------===//

import Atomics
import NIOCore
import NIOWebSocket

/// WebSocket object
public final class HBWebSocket {
public final class HBWebSocket: Sendable {
public enum SocketType: Sendable {
case client
case server
}

private class AutoPingTaskManager {
private var waitingOnPong: Bool
private var pingData: ByteBuffer
private var autoPingTask: Scheduled<Void>?

init(channel: Channel) {
self.waitingOnPong = false
self.pingData = channel.allocator.buffer(capacity: 16)
self.autoPingTask = nil
channel.closeFuture.whenComplete { _ in
self.shutdown()
}
}

func shutdown() {
self.autoPingTask?.cancel()
}

/// Send ping and setup task to check for pong and send new ping
func initiateAutoPing(interval: TimeAmount, ws: HBWebSocket) {
self.autoPingTask = ws.channel.eventLoop.scheduleTask(in: interval) {
if self.waitingOnPong {
// We never received a pong from our last ping, so the connection has timed out
let promise = ws.channel.eventLoop.makePromise(of: Void.self)
ws.close(code: .goingAway, promise: promise)
promise.futureResult.whenComplete { _ in
// Usually, closing a WebSocket is done by sending the close frame and waiting
// for the peer to respond with their close frame. We are in a timeout situation,
// so the other side likely will never send the close frame. We just close the
// channel ourselves.
ws.channel.close(mode: .all, promise: nil)
}

} else {
ws.sendPing().whenSuccess {
self.waitingOnPong = true
self.initiateAutoPing(interval: interval, ws: ws)
}
}
}
}

/// Respond to pong from client. Verify contents of pong and clear waitingOnPong flag
func receivedPong(frame: WebSocketFrame, ws: HBWebSocket) {
let frameData = frame.unmaskedData
guard frameData == self.pingData else {
ws.close(code: .goingAway, promise: nil)
return
}
self.waitingOnPong = false
ws.pongCallback.load(ordering: .relaxed).wrapped?(ws)
}

/// Send ping message
/// - Parameter promise: promise that is completed when ping message has been sent
func sendPing(ws: HBWebSocket, promise: EventLoopPromise<Void>?) {
if self.waitingOnPong {
promise?.succeed(())
return
}
// creating random payload
let random = (0..<16).map { _ in UInt8.random(in: 0...255) }
self.pingData.clear()
self.pingData.writeBytes(random)

ws.send(buffer: self.pingData, opcode: .ping, promise: promise)
}
}

public let channel: Channel
@inlinable public var eventLoop: EventLoop {
return self.channel.eventLoop
}

let type: SocketType
public typealias ReadCallback = @Sendable (WebSocketData, HBWebSocket) -> Void
public typealias CloseCallback = @Sendable (HBWebSocket) -> Void
public typealias PongCallback = @Sendable (HBWebSocket) -> Void

private var waitingOnPong: Bool = false
private var pingData: ByteBuffer
private var autoPingTask: Scheduled<Void>?
// wrapper class for type that conforms to AtomicReference
final class AtomicContainer<Value: Sendable>: AtomicReference, Sendable {
let wrapped: Value

init(_ value: Value) {
self.wrapped = value
}
}

let type: SocketType
private let autoPingManager: NIOLoopBound<AutoPingTaskManager>
private let isClosed: ManagedAtomic<Bool>
private let pongCallback: ManagedAtomic<AtomicContainer<PongCallback?>>
private let readCallback: ManagedAtomic<AtomicContainer<ReadCallback?>>

public init(channel: Channel, type: SocketType) {
self.channel = channel
self.isClosed = false
self.pongCallback = nil
self.readCallback = nil
self.pingData = channel.allocator.buffer(capacity: 16)
self.isClosed = .init(false)
self.pongCallback = .init(.init(nil))
self.readCallback = .init(.init(nil))
self.autoPingManager = .init(.init(channel: channel), eventLoop: channel.eventLoop)
self.type = type
}

/// Set callback to be called whenever WebSocket receives data
public func onRead(_ cb: @escaping ReadCallback) {
self.readCallback = cb
@preconcurrency public func onRead(_ cb: @escaping ReadCallback) {
self.readCallback.store(.init(cb), ordering: .relaxed)
}

/// Set callback to be called whenever WebSocket receives a pong
public func onPong(_ cb: @escaping PongCallback) {
self.pongCallback = cb
@preconcurrency public func onPong(_ cb: @escaping PongCallback) {
self.pongCallback.store(.init(cb), ordering: .relaxed)
}

/// Set callback to be called whenever WebSocket channel is closed
public func onClose(_ cb: @escaping CloseCallback) {
@preconcurrency public func onClose(_ cb: @escaping CloseCallback) {
self.channel.closeFuture.whenComplete { _ in
self.autoPingTask?.cancel()
cb(self)
}
}
Expand Down Expand Up @@ -97,11 +179,11 @@ public final class HBWebSocket {
/// - code: Close reason
/// - promise: promise that is completed when close has been sent
public func close(code: WebSocketErrorCode = .normalClosure, promise: EventLoopPromise<Void>?) {
guard self.isClosed == false else {
guard self.isClosed.load(ordering: .relaxed) == false else {
promise?.succeed(())
return
}
self.isClosed = true
self.isClosed.store(true, ordering: .relaxed)

var buffer = self.channel.allocator.buffer(capacity: 2)
buffer.write(webSocketErrorCode: code)
Expand All @@ -119,17 +201,12 @@ public final class HBWebSocket {
/// Send ping message
/// - Parameter promise: promise that is completed when ping message has been sent
public func sendPing(promise: EventLoopPromise<Void>?) {
self.channel.eventLoop.execute {
if self.waitingOnPong {
promise?.succeed(())
return
if self.channel.eventLoop.inEventLoop {
self.autoPingManager.value.sendPing(ws: self, promise: promise)
} else {
self.channel.eventLoop.execute {
self.autoPingManager.value.sendPing(ws: self, promise: promise)
}
// creating random payload
let random = (0..<16).map { _ in UInt8.random(in: 0...255) }
self.pingData.clear()
self.pingData.writeBytes(random)

self.send(buffer: self.pingData, opcode: .ping, promise: promise)
}
}

Expand All @@ -149,30 +226,17 @@ public final class HBWebSocket {
guard self.channel.isActive else {
return
}
self.autoPingTask = self.channel.eventLoop.scheduleTask(in: interval) {
if self.waitingOnPong {
// We never received a pong from our last ping, so the connection has timed out
let promise = self.channel.eventLoop.makePromise(of: Void.self)
self.close(code: .goingAway, promise: promise)
promise.futureResult.whenComplete { _ in
// Usually, closing a WebSocket is done by sending the close frame and waiting
// for the peer to respond with their close frame. We are in a timeout situation,
// so the other side likely will never send the close frame. We just close the
// channel ourselves.
self.channel.close(mode: .all, promise: nil)
}

} else {
self.sendPing().whenSuccess {
self.waitingOnPong = true
self.initiateAutoPing(interval: interval)
}
if self.channel.eventLoop.inEventLoop {
self.autoPingManager.value.initiateAutoPing(interval: interval, ws: self)
} else {
self.channel.eventLoop.execute {
self.autoPingManager.value.initiateAutoPing(interval: interval, ws: self)
}
}
}

func read(_ data: WebSocketData) {
self.readCallback?(data, self)
self.readCallback.load(ordering: .relaxed).wrapped?(data, self)
}

/// Send web socket frame to server
Expand All @@ -189,13 +253,7 @@ public final class HBWebSocket {

/// Respond to pong from client. Verify contents of pong and clear waitingOnPong flag
func receivedPong(frame: WebSocketFrame) {
let frameData = frame.unmaskedData
guard frameData == self.pingData else {
self.close(code: .goingAway, promise: nil)
return
}
self.waitingOnPong = false
self.pongCallback?(self)
self.autoPingManager.value.receivedPong(frame: frame, ws: self)
}

/// Respond to ping from client
Expand All @@ -209,7 +267,7 @@ public final class HBWebSocket {

func receivedClose(frame: WebSocketFrame) {
// Handle a received close frame. We're just going to close.
self.isClosed = true
self.isClosed.store(true, ordering: .relaxed)
self.channel.close(promise: nil)
}

Expand All @@ -229,14 +287,6 @@ public final class HBWebSocket {
let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) }
return WebSocketMaskingKey(bytes)
}

public typealias ReadCallback = (WebSocketData, HBWebSocket) -> Void
public typealias CloseCallback = (HBWebSocket) -> Void
public typealias PongCallback = (HBWebSocket) -> Void

private var pongCallback: PongCallback?
private var readCallback: ReadCallback?
private var isClosed: Bool = false
}

@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
Expand Down Expand Up @@ -275,9 +325,3 @@ extension HBWebSocket {
}
}
}

#if compiler(>=5.6)
// HBWebSocket can be set to Sendable because ping data which is mutable is
// managed internally and is only ever changed on the event loop
extension HBWebSocket: @unchecked Sendable {}
#endif // compiler(>=5.6)

0 comments on commit 28d5056

Please sign in to comment.