[JWTmw] add CheckUser to be used when using the "check" flag. Add "user_id" query param to check access of a specific user to a route
This commit is contained in:
parent
0e5a8ede9d
commit
e4e04c6ac1
|
@ -20,15 +20,20 @@ from starlette.authentication import (
|
||||||
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
|
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
|
||||||
UnauthenticatedUser)
|
UnauthenticatedUser)
|
||||||
from starlette.requests import HTTPConnection
|
from starlette.requests import HTTPConnection
|
||||||
|
from starlette.exceptions import HTTPException
|
||||||
|
|
||||||
logger = logging.getLogger('halfapi')
|
logger = logging.getLogger('halfapi')
|
||||||
|
|
||||||
|
"""
|
||||||
|
@OLD : old way to check production setting, we can simply check app's "debug" attribute
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ..conf import PRODUCTION
|
from ..conf import PRODUCTION
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning('Could not import PRODUCTION variable from conf module,'\
|
logger.warning('Could not import PRODUCTION variable from conf module,'\
|
||||||
' using HALFAPI_PROD environment variable')
|
' using HALFAPI_PROD environment variable')
|
||||||
PRODUCTION = bool(environ.get('HALFAPI_PROD', False))
|
PRODUCTION = bool(environ.get('HALFAPI_PROD', False))
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ..conf import SECRET
|
from ..conf import SECRET
|
||||||
|
@ -76,6 +81,28 @@ class JWTUser(BaseUser):
|
||||||
return self.__id
|
return self.__id
|
||||||
|
|
||||||
|
|
||||||
|
class CheckUser(BaseUser):
|
||||||
|
""" CheckUser class
|
||||||
|
|
||||||
|
Is used to call checks with give user_id, to know if it passes the ACLs for
|
||||||
|
the given route.
|
||||||
|
|
||||||
|
It should never be able to run a route function.
|
||||||
|
"""
|
||||||
|
def __init__(self, user_id: UUID) -> None:
|
||||||
|
self.__id = user_id
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_authenticated(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def display_name(self) -> str:
|
||||||
|
return 'check_user'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> str:
|
||||||
return self.__id
|
return self.__id
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,15 +117,25 @@ class JWTAuthenticationBackend(AuthenticationBackend):
|
||||||
self.algorithm = algorithm
|
self.algorithm = algorithm
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> str:
|
||||||
|
return self.__id
|
||||||
|
|
||||||
async def authenticate(
|
async def authenticate(
|
||||||
self, conn: HTTPConnection
|
self, conn: HTTPConnection
|
||||||
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
|
) -> typing.Optional[typing.Tuple['AuthCredentials', 'BaseUser']]:
|
||||||
|
|
||||||
if "Authorization" not in conn.headers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
token = conn.headers["Authorization"]
|
token = conn.headers.get('Authorization')
|
||||||
|
is_check_call = 'check' in conn.query_params
|
||||||
|
is_fake_user_id = is_check_call and 'user_id' in conn.query_params
|
||||||
|
PRODUCTION = conn.scope['app'].debug == False
|
||||||
|
|
||||||
|
if not token and not is_check_call:
|
||||||
|
return AuthCredentials(), UnauthenticatedUser()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if token and not is_fake_user_id:
|
||||||
payload = jwt.decode(token,
|
payload = jwt.decode(token,
|
||||||
key=self.secret_key,
|
key=self.secret_key,
|
||||||
algorithms=[self.algorithm],
|
algorithms=[self.algorithm],
|
||||||
|
@ -106,6 +143,21 @@ class JWTAuthenticationBackend(AuthenticationBackend):
|
||||||
'verify_signature': bool(PRODUCTION)
|
'verify_signature': bool(PRODUCTION)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if is_check_call:
|
||||||
|
if is_fake_user_id:
|
||||||
|
try:
|
||||||
|
fake_user_id = UUID(conn.query_params['user_id'])
|
||||||
|
|
||||||
|
return AuthCredentials(), CheckUser(fake_user_id)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(400, 'user_id parameter not an uuid')
|
||||||
|
|
||||||
|
if token:
|
||||||
|
return AuthCredentials(), CheckUser(payload['user_id'])
|
||||||
|
else:
|
||||||
|
return AuthCredentials(), UnauthenticatedUser()
|
||||||
|
|
||||||
|
|
||||||
if PRODUCTION and 'debug' in payload.keys() and payload['debug']:
|
if PRODUCTION and 'debug' in payload.keys() and payload['debug']:
|
||||||
raise AuthenticationError(
|
raise AuthenticationError(
|
||||||
'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
|
'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
|
||||||
|
|
Loading…
Reference in New Issue