[JWTmw] add CheckUser to be used when using the "check" flag. Add "user_id" query param to check access of a specific user to a route
This commit is contained in:
parent
0e5a8ede9d
commit
e4e04c6ac1
|
@ -20,15 +20,20 @@ from starlette.authentication import (
|
|||
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
|
||||
UnauthenticatedUser)
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
logger = logging.getLogger('halfapi')
|
||||
|
||||
"""
|
||||
@OLD : old way to check production setting, we can simply check app's "debug" attribute
|
||||
|
||||
try:
|
||||
from ..conf import PRODUCTION
|
||||
except ImportError:
|
||||
logger.warning('Could not import PRODUCTION variable from conf module,'\
|
||||
' using HALFAPI_PROD environment variable')
|
||||
PRODUCTION = bool(environ.get('HALFAPI_PROD', False))
|
||||
"""
|
||||
|
||||
try:
|
||||
from ..conf import SECRET
|
||||
|
@ -76,6 +81,28 @@ class JWTUser(BaseUser):
|
|||
return self.__id
|
||||
|
||||
|
||||
class CheckUser(BaseUser):
|
||||
""" CheckUser class
|
||||
|
||||
Is used to call checks with give user_id, to know if it passes the ACLs for
|
||||
the given route.
|
||||
|
||||
It should never be able to run a route function.
|
||||
"""
|
||||
def __init__(self, user_id: UUID) -> None:
|
||||
self.__id = user_id
|
||||
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return 'check_user'
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.__id
|
||||
|
||||
|
||||
|
@ -90,21 +117,46 @@ class JWTAuthenticationBackend(AuthenticationBackend):
|
|||
self.algorithm = algorithm
|
||||
self.prefix = prefix
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.__id
|
||||
|
||||
async def authenticate(
|
||||
self, conn: HTTPConnection
|
||||
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
|
||||
) -> typing.Optional[typing.Tuple['AuthCredentials', 'BaseUser']]:
|
||||
|
||||
if "Authorization" not in conn.headers:
|
||||
return None
|
||||
|
||||
token = conn.headers["Authorization"]
|
||||
token = conn.headers.get('Authorization')
|
||||
is_check_call = 'check' in conn.query_params
|
||||
is_fake_user_id = is_check_call and 'user_id' in conn.query_params
|
||||
PRODUCTION = conn.scope['app'].debug == False
|
||||
|
||||
if not token and not is_check_call:
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token,
|
||||
key=self.secret_key,
|
||||
algorithms=[self.algorithm],
|
||||
options={
|
||||
'verify_signature': bool(PRODUCTION)
|
||||
})
|
||||
if token and not is_fake_user_id:
|
||||
payload = jwt.decode(token,
|
||||
key=self.secret_key,
|
||||
algorithms=[self.algorithm],
|
||||
options={
|
||||
'verify_signature': bool(PRODUCTION)
|
||||
})
|
||||
|
||||
if is_check_call:
|
||||
if is_fake_user_id:
|
||||
try:
|
||||
fake_user_id = UUID(conn.query_params['user_id'])
|
||||
|
||||
return AuthCredentials(), CheckUser(fake_user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(400, 'user_id parameter not an uuid')
|
||||
|
||||
if token:
|
||||
return AuthCredentials(), CheckUser(payload['user_id'])
|
||||
else:
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
|
||||
|
||||
if PRODUCTION and 'debug' in payload.keys() and payload['debug']:
|
||||
raise AuthenticationError(
|
||||
|
|
Loading…
Reference in New Issue