Skip to content

Commit

Permalink
KTOR-7644 Make re-auth status codes configurable (#4420)
Browse files Browse the repository at this point in the history
Some services use 403 instead of 401. Changing them might be impossible. With this change Ktor can flexibly work with any broken service.

---------

Co-authored-by: Osip Fatkullin <[email protected]>
  • Loading branch information
wkornewald and osipxd authored Nov 19, 2024
1 parent 0ea62fc commit 6e0eb10
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 18 deletions.
3 changes: 2 additions & 1 deletion buildSrc/src/main/kotlin/test/server/tests/Auth.kt
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ internal fun Application.authTestServer() {
val token = call.request.headers["Authorization"]
if (token.isNullOrEmpty() || token.contains("invalid")) {
call.response.header(HttpHeaders.WWWAuthenticate, "Bearer realm=\"TestServer\"")
call.respond(HttpStatusCode.Unauthorized)
val status = call.request.queryParameters["status"]?.toIntOrNull() ?: 401
call.respond(HttpStatusCode.fromValue(status))
return@get
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
public final class io/ktor/client/plugins/auth/AuthConfig {
public fun <init> ()V
public final fun getProviders ()Ljava/util/List;
public final fun isUnauthorizedResponse ()Lkotlin/jvm/functions/Function2;
public final fun reAuthorizeOnResponse (Lkotlin/jvm/functions/Function2;)V
}

public final class io/ktor/client/plugins/auth/AuthKt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ final class io.ktor.client.plugins.auth/AuthConfig { // io.ktor.client.plugins.a

final val providers // io.ktor.client.plugins.auth/AuthConfig.providers|{}providers[0]
final fun <get-providers>(): kotlin.collections/MutableList<io.ktor.client.plugins.auth/AuthProvider> // io.ktor.client.plugins.auth/AuthConfig.providers.<get-providers>|<get-providers>(){}[0]

final var isUnauthorizedResponse // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse|{}isUnauthorizedResponse[0]
final fun <get-isUnauthorizedResponse>(): kotlin.coroutines/SuspendFunction1<io.ktor.client.statement/HttpResponse, kotlin/Boolean> // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse.<get-isUnauthorizedResponse>|<get-isUnauthorizedResponse>(){}[0]

final fun reAuthorizeOnResponse(kotlin.coroutines/SuspendFunction1<io.ktor.client.statement/HttpResponse, kotlin/Boolean>) // io.ktor.client.plugins.auth/AuthConfig.reAuthorizeOnResponse|reAuthorizeOnResponse(kotlin.coroutines.SuspendFunction1<io.ktor.client.statement.HttpResponse,kotlin.Boolean>){}[0]
}

final val io.ktor.client.plugins.auth/Auth // io.ktor.client.plugins.auth/Auth|{}Auth[0]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
/*
* Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.plugins.auth

import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.api.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.auth.*
import io.ktor.util.*
Expand All @@ -23,9 +23,36 @@ private class AtomicCounter {
val atomic = atomic(0)
}

/**
* Configuration used by [Auth] plugin.
*/
@KtorDsl
public class AuthConfig {
/**
* [AuthProvider] list to use.
*/
public val providers: MutableList<AuthProvider> = mutableListOf()

/**
* The currently set function to control whether a response is unauthorized and should trigger a refresh / re-auth.
*
* By default checks against HTTP status 401.
*
* You can set this value via [reAuthorizeOnResponse].
*/
@InternalAPI
public var isUnauthorizedResponse: suspend (HttpResponse) -> Boolean = { it.status == HttpStatusCode.Unauthorized }
private set

/**
* Sets a custom function to control whether a response is unauthorized and should trigger a refresh / re-auth.
*
* Use this to change the value of [isUnauthorizedResponse].
*/
public fun reAuthorizeOnResponse(block: suspend (HttpResponse) -> Boolean) {
@OptIn(InternalAPI::class)
isUnauthorizedResponse = block
}
}

/**
Expand All @@ -39,8 +66,9 @@ public val AuthCircuitBreaker: AttributeKey<Unit> = AttributeKey("auth-request")
*
* You can learn more from [Authentication and authorization](https://ktor.io/docs/auth.html).
*
* [providers] - list of auth providers to use.
* @see [AuthConfig] for configuration options.
*/
@OptIn(InternalAPI::class)
public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthConfig) {
val providers = pluginConfig.providers.toList()

Expand All @@ -50,7 +78,6 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
val tokenVersionsAttributeKey =
AttributeKey<MutableMap<AuthProvider, Int>>("ProviderVersionAttributeKey")

@OptIn(InternalAPI::class)
fun findProvider(
call: HttpClientCall,
candidateProviders: Set<AuthProvider>
Expand All @@ -64,10 +91,10 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
}

authHeaders.isEmpty() -> {
LOGGER.trace(
"401 response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " +
LOGGER.trace {
"Unauthorized response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " +
"Can not add or refresh token"
)
}
null
}

Expand All @@ -88,9 +115,9 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
val requestTokenVersion = requestTokenVersions[provider]

if (requestTokenVersion != null && requestTokenVersion >= tokenVersion.atomic.value) {
LOGGER.trace("Refreshing token for ${call.request.url}")
LOGGER.trace { "Refreshing token for ${call.request.url}" }
if (!provider.refreshToken(call.response)) {
LOGGER.trace("Refreshing token failed for ${call.request.url}")
LOGGER.trace { "Refreshing token failed for ${call.request.url}" }
return false
} else {
requestTokenVersions[provider] = tokenVersion.atomic.incrementAndGet()
Expand All @@ -99,7 +126,6 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
return true
}

@OptIn(InternalAPI::class)
suspend fun Send.Sender.executeWithNewToken(
call: HttpClientCall,
provider: AuthProvider,
Expand All @@ -111,13 +137,13 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
provider.addRequestHeaders(request, authHeader)
request.attributes.put(AuthCircuitBreaker, Unit)

LOGGER.trace("Sending new request to ${call.request.url}")
LOGGER.trace { "Sending new request to ${call.request.url}" }
return proceed(request)
}

onRequest { request, _ ->
providers.filter { it.sendWithoutRequest(request) }.forEach { provider ->
LOGGER.trace("Adding auth headers for ${request.url} from provider $provider")
LOGGER.trace { "Adding auth headers for ${request.url} from provider $provider" }
val tokenVersion = tokenVersions.computeIfAbsent(provider) { AtomicCounter() }
val requestTokenVersions = request.attributes
.computeIfAbsent(tokenVersionsAttributeKey) { mutableMapOf() }
Expand All @@ -128,22 +154,22 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon

on(Send) { originalRequest ->
val origin = proceed(originalRequest)
if (origin.response.status != HttpStatusCode.Unauthorized) return@on origin
if (!pluginConfig.isUnauthorizedResponse(origin.response)) return@on origin
if (origin.request.attributes.contains(AuthCircuitBreaker)) return@on origin

var call = origin

val candidateProviders = HashSet(providers)

while (call.response.status == HttpStatusCode.Unauthorized) {
LOGGER.trace("Received 401 for ${call.request.url}")
while (pluginConfig.isUnauthorizedResponse(call.response)) {
LOGGER.trace { "Unauthorized response for ${call.request.url}" }

val (provider, authHeader) = findProvider(call, candidateProviders) ?: run {
LOGGER.trace("Can not find auth provider for ${call.request.url}")
LOGGER.trace { "Can not find auth provider for ${call.request.url}" }
return@on call
}

LOGGER.trace("Using provider $provider for ${call.request.url}")
LOGGER.trace { "Using provider $provider for ${call.request.url}" }

candidateProviders.remove(provider)
if (!refreshTokenIfNeeded(call, provider, originalRequest)) return@on call
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,27 @@ class AuthTest : ClientLoader() {
}
}

@Test
fun testForbiddenBearerAuthWithInvalidAccessAndValidRefreshTokens() = clientTests {
config {
install(Auth) {
reAuthorizeOnResponse { it.status == HttpStatusCode.Forbidden }
bearer {
refreshTokens { BearerTokens("valid", "refresh") }
loadTokens { BearerTokens("invalid", "refresh") }
}
}

expectSuccess = false
}

test { client ->
client.prepareGet("$TEST_SERVER/auth/bearer/test-refresh?status=403").execute {
assertEquals(HttpStatusCode.OK, it.status)
}
}
}

// The return of refreshTokenFun is null, cause it should not be called at all, if loadTokensFun returns valid tokens
@Test
fun testUnauthorizedBearerAuthWithValidAccessTokenAndInvalidRefreshToken() = clientTests {
Expand Down

0 comments on commit 6e0eb10

Please sign in to comment.