diff --git a/halfapi/lib/jwt_middleware.py b/halfapi/lib/jwt_middleware.py index 9501fc1..31cb7d5 100644 --- a/halfapi/lib/jwt_middleware.py +++ b/halfapi/lib/jwt_middleware.py @@ -20,15 +20,20 @@ from starlette.authentication import ( AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials, UnauthenticatedUser) from starlette.requests import HTTPConnection +from starlette.exceptions import HTTPException logger = logging.getLogger('halfapi') +""" +@OLD : old way to check production setting, we can simply check app's "debug" attribute + try: from ..conf import PRODUCTION except ImportError: logger.warning('Could not import PRODUCTION variable from conf module,'\ ' using HALFAPI_PROD environment variable') PRODUCTION = bool(environ.get('HALFAPI_PROD', False)) +""" try: from ..conf import SECRET @@ -76,6 +81,28 @@ class JWTUser(BaseUser): return self.__id +class CheckUser(BaseUser): + """ CheckUser class + + Is used to call checks with give user_id, to know if it passes the ACLs for + the given route. + + It should never be able to run a route function. + """ + def __init__(self, user_id: UUID) -> None: + self.__id = user_id + + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return 'check_user' + + @property + def id(self) -> str: return self.__id @@ -90,21 +117,46 @@ class JWTAuthenticationBackend(AuthenticationBackend): self.algorithm = algorithm self.prefix = prefix + @property + def id(self) -> str: + return self.__id + async def authenticate( self, conn: HTTPConnection - ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]: + ) -> typing.Optional[typing.Tuple['AuthCredentials', 'BaseUser']]: - if "Authorization" not in conn.headers: - return None - token = conn.headers["Authorization"] + token = conn.headers.get('Authorization') + is_check_call = 'check' in conn.query_params + is_fake_user_id = is_check_call and 'user_id' in conn.query_params + PRODUCTION = conn.scope['app'].debug == False + + if not token and not is_check_call: + return AuthCredentials(), UnauthenticatedUser() + try: - payload = jwt.decode(token, - key=self.secret_key, - algorithms=[self.algorithm], - options={ - 'verify_signature': bool(PRODUCTION) - }) + if token and not is_fake_user_id: + payload = jwt.decode(token, + key=self.secret_key, + algorithms=[self.algorithm], + options={ + 'verify_signature': bool(PRODUCTION) + }) + + if is_check_call: + if is_fake_user_id: + try: + fake_user_id = UUID(conn.query_params['user_id']) + + return AuthCredentials(), CheckUser(fake_user_id) + except ValueError as exc: + raise HTTPException(400, 'user_id parameter not an uuid') + + if token: + return AuthCredentials(), CheckUser(payload['user_id']) + else: + return AuthCredentials(), UnauthenticatedUser() + if PRODUCTION and 'debug' in payload.keys() and payload['debug']: raise AuthenticationError(