Source code for swh.auth.starlette.backends

# Copyright (C) 2023-2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from datetime import datetime
import hashlib
from typing import Any, Dict, Optional, Tuple

from aiocache.base import BaseCache
from jwcrypto.common import JWException
from starlette.authentication import (
    AuthCredentials,
    AuthenticationBackend,
    AuthenticationError,
    SimpleUser,
)
from starlette.requests import HTTPConnection

from swh.auth.keycloak import (
    ExpiredSignatureError,
    KeycloakError,
    KeycloakOpenIDConnect,
    keycloak_error_message,
)


class BearerTokenAuthBackend(AuthenticationBackend):
    """
    Starlette authentication backend using Keycloak OpenID Connect authorization

    A Keycloak server, realm and a cache to store access tokens must be provided
    """

    def __init__(
        self, server_url: str, realm_name: str, client_id: str, cache: BaseCache
    ):
        """
        Args:
            server_url: Keycloak URL
            realm_name: Keycloak realm name
            client_id: Keycloak client ID
            cache: An aiocache cache instance
        """
        self.client_id = client_id
        self.oidc_client = KeycloakOpenIDConnect(
            server_url=server_url,
            realm_name=realm_name,
            client_id=client_id,
        )
        self.cache = cache

    def _get_token_from_header(self, auth_header: str) -> str:
        """Extract the bearer token from an ``Authorization`` header value.

        Args:
            auth_header: raw header value, expected as ``"Bearer <token>"``

        Returns:
            Everything following the first space of the header value.

        Raises:
            AuthenticationError: the header holds no space-separated token or
                its authorization type is not ``Bearer``.
        """
        try:
            auth_type, bearer_token = auth_header.split(" ", 1)
        except ValueError:
            # the unpacking failure carries no useful detail for the caller
            raise AuthenticationError("Invalid auth header") from None
        # HTTP authentication scheme names are case-insensitive (RFC 9110),
        # so "bearer" and "BEARER" are accepted as well
        if auth_type.lower() != "bearer":
            raise AuthenticationError("Invalid or unsupported authorization type")
        return bearer_token

    def _get_token_cache_key(self, refresh_token: str) -> str:
        """Return the cache key under which the access token is stored.

        The key is ``api_token_`` followed by the hexadecimal SHA-1 digest of
        the ASCII-encoded refresh token, so the raw token is never used as key.

        Raises:
            UnicodeEncodeError: the token contains non-ASCII characters.
        """
        hasher = hashlib.sha1()
        hasher.update(refresh_token.encode("ascii"))
        return f"api_token_{hasher.hexdigest()}"

    def _get_new_access_token(self, refresh_token: str) -> Dict[str, Any]:
        """Ask Keycloak for a new access token using a refresh token.

        NOTE(review): the Keycloak client call is not awaited, so it presumably
        blocks the event loop while it runs when called from ``authenticate``
        -- confirm and consider off-loading it if that matters.

        Returns:
            The response of the OIDC client; ``authenticate`` reads its
            ``access_token`` entry.

        Raises:
            AuthenticationError: Keycloak rejected the refresh token.
        """
        try:
            access_token = self.oidc_client.refresh_token(refresh_token)
        except KeycloakError as e:
            raise AuthenticationError(
                "Invalid or expired user token", keycloak_error_message(e)
            ) from e
        return access_token

    def _decode_token(self, access_token: Optional[str]) -> Optional[Dict[str, Any]]:
        """Decode an access token with the OIDC client.

        Args:
            access_token: token to decode; may be ``None`` or empty (for
                instance on a cache miss)

        Returns:
            The decoded token, or ``None`` when no token was given or when
            decoding failed with one of the handled errors.

        Raises:
            JWException: deliberately not handled here, ``authenticate`` relies
                on it to switch to the refresh token code path.
        """
        if not access_token:
            return None
        try:
            decoded_token = self.oidc_client.decode_token(access_token)
        except (KeycloakError, UnicodeEncodeError, ExpiredSignatureError, ValueError):
            # token is either too old or an invalid one
            decoded_token = None
        return decoded_token

    async def _get_decoded_token_from_refresh_token(
        self, refresh_token: str
    ) -> Dict[str, Any]:
        """Return a decoded access token for a refresh token.

        The access token previously stored in the cache is reused when it can
        still be decoded, otherwise a new one is requested and cached.

        Raises:
            AuthenticationError: no access token could be obtained or decoded.
        """
        cache_key = self._get_token_cache_key(refresh_token)
        # read access token from the cache
        access_token = await self.cache.get(cache_key)
        decoded_token = self._decode_token(access_token)
        if decoded_token:
            return decoded_token

        # nothing usable in the cache: renew the access token
        access_token = self._get_new_access_token(refresh_token)["access_token"]
        decoded_token = self._decode_token(access_token)
        if not decoded_token:
            raise AuthenticationError("Access token failed to be decoded")

        # keep the access token in cache for its remaining lifetime
        ttl = int(decoded_token["exp"] - datetime.now().timestamp())
        # A zero or negative TTL does not mean "expire immediately" for the
        # cache (presumably it is either rejected or treated as "no expiry"
        # depending on the backend -- verify), so such a token is not cached.
        if ttl > 0:
            await self.cache.set(cache_key, access_token, ttl=ttl)
        return decoded_token

    async def authenticate(
        self, conn: HTTPConnection
    ) -> Optional[Tuple[AuthCredentials, SimpleUser]]:
        """Authenticate a connection from its ``Authorization`` header.

        Args:
            conn: the incoming HTTP connection

        Returns:
            ``None`` for an anonymous user (no ``Authorization`` header),
            otherwise the credentials, whose scopes are the realm roles
            followed by the roles granted for this client, and the user named
            after the ``preferred_username`` claim of the token.

        Raises:
            AuthenticationError: the header is malformed or no valid access
                token could be obtained from the provided token.
        """
        auth_header = conn.headers.get("Authorization")
        if auth_header is None:
            # anonymous user
            return None

        token = self._get_token_from_header(auth_header)

        try:
            # check if access token was provided in authorization header
            decoded_token = self._decode_token(token)
        except JWException:
            # token is a refresh one so backend handles access token renewal
            decoded_token = await self._get_decoded_token_from_refresh_token(token)
        else:
            if not decoded_token:
                raise AuthenticationError("Access token failed to be decoded")

        # set user scopes; a new list is built so that the role lists held by
        # the decoded token are left untouched
        realm_roles = decoded_token.get("realm_access", {}).get("roles", [])
        resource_access = decoded_token.get("resource_access", {})
        client_roles = resource_access.get(self.client_id, {}).get("roles", [])
        user_scopes = [*realm_roles, *client_roles]

        return AuthCredentials(scopes=user_scopes), SimpleUser(
            decoded_token["preferred_username"]
        )