Skip to content

Commit

Permalink
More checks for task cancellation and tests (#44)
Browse files Browse the repository at this point in the history
### Motivation

In our fallback, buffered implementation, we did not use a task
cancellation handler so were not proactively cancelling the URLSession
task when the Swift concurrency task was cancelled. Additionally, while
we _did_ have a task cancellation handler in the streaming
implementation, so the URLSession task would be cancelled, we were not
actively checking for task cancellation as often as we could.

### Modifications

- Added more cooperative task cancellation.
- Added tests for both implementations that when the parent task for the
client request is cancelled that we get something sensible. Note that in
some cases, the request will succeed. In the cases where the request
fails, it will surface as a `ClientError` to the user where the
`underlyingError` is either `Swift.CancellationError` or `URLError` with
`code == .cancelled`.

### Result

More cooperative task and URLSession task cancellation and more thorough
tests.

### Test Plan

Added unit tests.
  • Loading branch information
simonjbeaumont authored Dec 11, 2023
1 parent 144464e commit aac0a82
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import Foundation
task = dataTask(with: urlRequest)
}
return try await withTaskCancellationHandler {
try Task.checkCancellation()
let delegate = BidirectionalStreamingURLSessionDelegate(
requestBody: requestBody,
requestStreamBufferSize: requestStreamBufferSize,
Expand All @@ -47,8 +48,10 @@ import Foundation
length: .init(from: response),
iterationBehavior: .single
)
try Task.checkCancellation()
return (try HTTPResponse(response), responseBody)
} onCancel: {
debug("Concurrency task cancelled, cancelling URLSession task.")
task.cancel()
}
}
Expand Down
50 changes: 35 additions & 15 deletions Sources/OpenAPIURLSession/URLSessionTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import class Foundation.FileHandle
#if canImport(FoundationNetworking)
@preconcurrency import struct FoundationNetworking.URLRequest
import class FoundationNetworking.URLSession
import class FoundationNetworking.URLSessionTask
import class FoundationNetworking.URLResponse
import class FoundationNetworking.HTTPURLResponse
#endif
Expand Down Expand Up @@ -243,31 +244,50 @@ extension URLSession {
func bufferedRequest(for request: HTTPRequest, baseURL: URL, requestBody: HTTPBody?) async throws -> (
HTTPResponse, HTTPBody?
) {
try Task.checkCancellation()
var urlRequest = try URLRequest(request, baseURL: baseURL)
if let requestBody { urlRequest.httpBody = try await Data(collecting: requestBody, upTo: .max) }
try Task.checkCancellation()

/// Use `dataTask(with:completionHandler:)` here because `data(for:[delegate:]) async` is only available on
/// Darwin platforms newer than our minimum deployment target, and not at all on Linux.
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
continuation in
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
if let error {
continuation.resume(throwing: error)
return
let taskBox: LockedValueBox<URLSessionTask?> = .init(nil)
return try await withTaskCancellationHandler {
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
continuation in
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
if let error {
continuation.resume(throwing: error)
return
}
guard let response else {
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
return
}
continuation.resume(with: .success((response, data)))
}
guard let response else {
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
return
// Swift concurrency task cancelled here.
taskBox.withLockedValue { boxedTask in
guard task.state == .suspended else {
debug("URLSession task cannot be resumed, probably because it was cancelled by onCancel.")
return
}
task.resume()
boxedTask = task
}
continuation.resume(with: .success((response, data)))
}
task.resume()
}

let maybeResponseBody = maybeResponseBodyData.map { data in
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
let maybeResponseBody = maybeResponseBodyData.map { data in
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
}
return (try HTTPResponse(response), maybeResponseBody)
} onCancel: {
taskBox.withLockedValue { boxedTask in
debug("Concurrency task cancelled, cancelling URLSession task.")
boxedTask?.cancel()
boxedTask = nil
}
}
return (try HTTPResponse(response), maybeResponseBody)
}
}

Expand Down
240 changes: 240 additions & 0 deletions Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftOpenAPIGenerator open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
#if canImport(Darwin)

import Foundation
import HTTPTypes
import NIO
import OpenAPIRuntime
import XCTest
@testable import OpenAPIURLSession

enum CancellationPoint: CaseIterable {
case beforeSendingHead
case beforeSendingRequestBody
case partwayThroughSendingRequestBody
case beforeConsumingResponseBody
case partwayThroughConsumingResponseBody
case afterConsumingResponseBody
}

func testTaskCancelled(_ cancellationPoint: CancellationPoint, transport: URLSessionTransport) async throws {
let requestPath = "/hello/world"
let requestBodyElements = ["Hello,", "world!"]
let requestBodySequence = MockAsyncSequence(elementsToVend: requestBodyElements, gatingProduction: true)
let requestBody = HTTPBody(
requestBodySequence,
length: .known(Int64(requestBodyElements.joined().lengthOfBytes(using: .utf8))),
iterationBehavior: .single
)

let responseBodyMessage = "Hey!"

let taskShouldCancel = XCTestExpectation(description: "Concurrency task cancelled")
let taskCancelled = XCTestExpectation(description: "Concurrency task cancelled")

try await withThrowingTaskGroup(of: Void.self) { group in
let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in
try await connectionChannel.executeThenClose { inbound, outbound in
var requestPartIterator = inbound.makeAsyncIterator()
var accumulatedBody = ByteBuffer()
while let requestPart = try await requestPartIterator.next() {
switch requestPart {
case .head(let head):
XCTAssertEqual(head.uri, requestPath)
XCTAssertEqual(head.method, .POST)
case .body(let buffer): accumulatedBody.writeImmutableBuffer(buffer)
case .end:
switch cancellationPoint {
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody,
.afterConsumingResponseBody:
XCTAssertEqual(
String(decoding: accumulatedBody.readableBytesView, as: UTF8.self),
requestBodyElements.joined()
)
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody: break
}
try await outbound.write(.head(.init(version: .http1_1, status: .ok)))
try await outbound.write(.body(ByteBuffer(string: responseBodyMessage)))
try await outbound.write(.end(nil))
}
}
}
}
debug("Server running on 127.0.0.1:\(serverPort)")

let task = Task {
if case .beforeSendingHead = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}
debug("Client starting request")
async let (asyncResponse, asyncResponseBody) = try await transport.send(
HTTPRequest(method: .post, scheme: nil, authority: nil, path: requestPath),
body: requestBody,
baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!,
operationID: "unused"
)

if case .beforeSendingRequestBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

requestBodySequence.openGate(for: 1)

if case .partwayThroughSendingRequestBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

requestBodySequence.openGate()

let (response, maybeResponseBody) = try await (asyncResponse, asyncResponseBody)

debug("Client received response head: \(response)")
XCTAssertEqual(response.status, .ok)
let responseBody = try XCTUnwrap(maybeResponseBody)

if case .beforeConsumingResponseBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

var iterator = responseBody.makeAsyncIterator()

_ = try await iterator.next()

if case .partwayThroughConsumingResponseBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

while try await iterator.next() != nil {

}

if case .afterConsumingResponseBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

}

await fulfillment(of: [taskShouldCancel])
task.cancel()
taskCancelled.fulfill()

switch transport.configuration.implementation {
case .buffering:
switch cancellationPoint {
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
try await task.value
}
case .streaming:
switch cancellationPoint {
case .beforeSendingHead:
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
case .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
await XCTAssertThrowsError(try await task.value) { error in
guard let urlError = error as? URLError else {
XCTFail()
return
}
XCTAssertEqual(urlError.code, .cancelled)
}
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
try await task.value
}
}

group.cancelAll()
}

}

func fulfillment(
of expectations: [XCTestExpectation],
timeout seconds: TimeInterval = .infinity,
enforceOrder enforceOrderOfFulfillment: Bool = false,
file: StaticString = #file,
line: UInt = #line
) async {
guard
case .completed = await XCTWaiter.fulfillment(
of: expectations,
timeout: seconds,
enforceOrder: enforceOrderOfFulfillment
)
else {
XCTFail("Expectation was not fulfilled", file: file, line: line)
return
}
}

extension URLSessionTransportBufferedTests {
func testCancellation_beforeSendingHead() async throws {
try await testTaskCancelled(.beforeSendingHead, transport: transport)
}

func testCancellation_beforeSendingRequestBody() async throws {
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
}

func testCancellation_partwayThroughSendingRequestBody() async throws {
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
}

func testCancellation_beforeConsumingResponseBody() async throws {
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
}

func testCancellation_partwayThroughConsumingResponseBody() async throws {
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
}

func testCancellation_afterConsumingResponseBody() async throws {
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
}
}

extension URLSessionTransportStreamingTests {
func testCancellation_beforeSendingHead() async throws {
try await testTaskCancelled(.beforeSendingHead, transport: transport)
}

func testCancellation_beforeSendingRequestBody() async throws {
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
}

func testCancellation_partwayThroughSendingRequestBody() async throws {
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
}

func testCancellation_beforeConsumingResponseBody() async throws {
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
}

func testCancellation_partwayThroughConsumingResponseBody() async throws {
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
}

func testCancellation_afterConsumingResponseBody() async throws {
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
}
}

#endif // canImport(Darwin)
8 changes: 4 additions & 4 deletions Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class URLSessionTransportConverterTests: XCTestCase {

// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
class URLSessionTransportBufferedTests: XCTestCase {
var transport: (any ClientTransport)!
var transport: URLSessionTransport!

static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }

Expand All @@ -66,7 +66,7 @@ class URLSessionTransportBufferedTests: XCTestCase {

func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }

func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }

#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {
Expand All @@ -89,7 +89,7 @@ class URLSessionTransportBufferedTests: XCTestCase {

// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
class URLSessionTransportStreamingTests: XCTestCase {
var transport: (any ClientTransport)!
var transport: URLSessionTransport!

static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }

Expand All @@ -107,7 +107,7 @@ class URLSessionTransportStreamingTests: XCTestCase {

func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }

func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }

#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {
Expand Down

0 comments on commit aac0a82

Please sign in to comment.