[JWTmw] add CheckUser to be used when using the "check" flag. Add "user_id" query param to check access of a specific user to a route

This commit is contained in:
Maxime Alves LIRMM 2021-05-28 22:18:58 +02:00
parent 0e5a8ede9d
commit e4e04c6ac1
1 changed files with 62 additions and 10 deletions

View File

@ -20,15 +20,20 @@ from starlette.authentication import (
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
UnauthenticatedUser) UnauthenticatedUser)
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from starlette.exceptions import HTTPException
logger = logging.getLogger('halfapi') logger = logging.getLogger('halfapi')
"""
@OLD : old way to check production setting, we can simply check app's "debug" attribute
try: try:
from ..conf import PRODUCTION from ..conf import PRODUCTION
except ImportError: except ImportError:
logger.warning('Could not import PRODUCTION variable from conf module,'\ logger.warning('Could not import PRODUCTION variable from conf module,'\
' using HALFAPI_PROD environment variable') ' using HALFAPI_PROD environment variable')
PRODUCTION = bool(environ.get('HALFAPI_PROD', False)) PRODUCTION = bool(environ.get('HALFAPI_PROD', False))
"""
try: try:
from ..conf import SECRET from ..conf import SECRET
@ -76,6 +81,28 @@ class JWTUser(BaseUser):
return self.__id return self.__id
class CheckUser(BaseUser):
""" CheckUser class
Is used to call checks with give user_id, to know if it passes the ACLs for
the given route.
It should never be able to run a route function.
"""
def __init__(self, user_id: UUID) -> None:
self.__id = user_id
@property
def is_authenticated(self) -> bool:
return True
@property
def display_name(self) -> str:
return 'check_user'
@property
def id(self) -> str:
return self.__id return self.__id
@ -90,15 +117,25 @@ class JWTAuthenticationBackend(AuthenticationBackend):
self.algorithm = algorithm self.algorithm = algorithm
self.prefix = prefix self.prefix = prefix
@property
def id(self) -> str:
return self.__id
async def authenticate( async def authenticate(
self, conn: HTTPConnection self, conn: HTTPConnection
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]: ) -> typing.Optional[typing.Tuple['AuthCredentials', 'BaseUser']]:
if "Authorization" not in conn.headers:
return None
token = conn.headers["Authorization"] token = conn.headers.get('Authorization')
is_check_call = 'check' in conn.query_params
is_fake_user_id = is_check_call and 'user_id' in conn.query_params
PRODUCTION = conn.scope['app'].debug == False
if not token and not is_check_call:
return AuthCredentials(), UnauthenticatedUser()
try: try:
if token and not is_fake_user_id:
payload = jwt.decode(token, payload = jwt.decode(token,
key=self.secret_key, key=self.secret_key,
algorithms=[self.algorithm], algorithms=[self.algorithm],
@ -106,6 +143,21 @@ class JWTAuthenticationBackend(AuthenticationBackend):
'verify_signature': bool(PRODUCTION) 'verify_signature': bool(PRODUCTION)
}) })
if is_check_call:
if is_fake_user_id:
try:
fake_user_id = UUID(conn.query_params['user_id'])
return AuthCredentials(), CheckUser(fake_user_id)
except ValueError as exc:
raise HTTPException(400, 'user_id parameter not an uuid')
if token:
return AuthCredentials(), CheckUser(payload['user_id'])
else:
return AuthCredentials(), UnauthenticatedUser()
if PRODUCTION and 'debug' in payload.keys() and payload['debug']: if PRODUCTION and 'debug' in payload.keys() and payload['debug']:
raise AuthenticationError( raise AuthenticationError(
'Trying to connect using *DEBUG* token in *PRODUCTION* mode') 'Trying to connect using *DEBUG* token in *PRODUCTION* mode')