nettoyage / commentaires / renommage de variables - JWTUser.id deviens JWTUser.identity

This commit is contained in:
Maxime Alves LIRMM@home 2021-04-24 08:56:18 +02:00
parent a2fb70f84b
commit 10b1960f4e
10 changed files with 174 additions and 140 deletions

View File

@ -3,11 +3,11 @@
Base ACL module that contains generic functions for domains ACL Base ACL module that contains generic functions for domains ACL
""" """
import logging import logging
from functools import wraps from functools import wraps
from starlette.authentication import UnauthenticatedUser
from json import JSONDecodeError from json import JSONDecodeError
from starlette.authentication import UnauthenticatedUser
from starlette.exceptions import HTTPException
logger = logging.getLogger('uvicorn.asgi') logger = logging.getLogger('uvicorn.asgi')
@ -30,11 +30,20 @@ def connected(fct=public):
return caller return caller
def args_check(fct): def args_check(fct):
""" Decorator that puts required and optional arguments in scope
For GET requests it uses the query_params
For POST requests it uses the body as JSON
If "check" is present in the query params, nothing is done.
If some required arguments are missing, a 400 status code is sent.
"""
@wraps(fct) @wraps(fct)
async def caller(req, *args, **kwargs): async def caller(req, *args, **kwargs):
if 'check' in req.query_params: if 'check' in req.query_params:
""" Check query param should not read the "args" # Check query param should not read the "args"
"""
return await fct(req, *args, **kwargs) return await fct(req, *args, **kwargs)
if req.method == 'GET': if req.method == 'GET':
@ -47,7 +56,7 @@ def args_check(fct):
data_ = {} data_ = {}
def plural(array: list) -> str: def plural(array: list) -> str:
return len(array) > 1 and 's' or '' return 's' if len(array) > 1 else ''
def comma_list(array: list) -> str: def comma_list(array: list) -> str:
return ', '.join(array) return ', '.join(array)

View File

@ -5,7 +5,6 @@ lib/domain.py The domain-scoped utility functions
import importlib import importlib
import logging import logging
import time
from types import ModuleType from types import ModuleType
from typing import Generator, Dict, List from typing import Generator, Dict, List
@ -94,7 +93,7 @@ def gen_routes(route_params: Dict, path: List, m_router: ModuleType) -> Generato
if params is None: if params is None:
continue continue
if len(params) == 0: if len(params) == 0:
logger.error(f'No ACL for route [{verb}] "/".join(path)') logger.error('No ACL for route [{%s}] %s', verb, "/".join(path))
try: try:
fct_name = get_fct_name(verb, path[-1]) fct_name = get_fct_name(verb, path[-1])
@ -125,7 +124,7 @@ def gen_router_routes(m_router: ModuleType, path: List[str]) -> Generator:
""" """
if not hasattr(m_router, 'ROUTES'): if not hasattr(m_router, 'ROUTES'):
logger.error(f'Missing *ROUTES* constant in *{m_router.__name__}*') logger.error('Missing *ROUTES* constant in *%s*', m_router.__name__)
raise Exception(f'No ROUTES constant for {m_router.__name__}') raise Exception(f'No ROUTES constant for {m_router.__name__}')

View File

@ -13,7 +13,6 @@ from starlette.types import Scope, Send, Receive
from .routes import api_routes from .routes import api_routes
from .domain import d_domains from .domain import d_domains
from ..conf import config_dict
logger = logging.getLogger('uvicorn.asgi') logger = logging.getLogger('uvicorn.asgi')
@ -62,7 +61,7 @@ class DomainMiddleware(BaseHTTPMiddleware):
scope_['config'] = dict(config_section) scope_['config'] = dict(config_section)
except configparser.NoSectionError: except configparser.NoSectionError:
logger.debug( logger.debug(
f'No specific configuration for domain **{current_domain}**') 'No specific configuration for domain **%s**', current_domain)
scope_['config'] = {} scope_['config'] = {}

View File

@ -1,44 +1,26 @@
__LICENSE__ = """ """
BSD 3-Clause License JWT Middleware module
Copyright (c) 2018, Amit Ripshtos Classes:
All rights reserved. - JWTUser : goes in request.user
- JWTAuthenticationBackend
- JWTWebSocketAuthenticationBackend
Redistribution and use in source and binary forms, with or without Raises:
modification, are permitted provided that the following conditions are met: Exception: If configuration has no SECRET or HALFAPI_SECRET is not set
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
""" """
from os import environ from os import environ
import typing
import logging
from uuid import UUID
import jwt import jwt
from uuid import UUID
from starlette.authentication import ( from starlette.authentication import (
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
UnauthenticatedUser) UnauthenticatedUser)
from starlette.requests import HTTPConnection
import logging
logger = logging.getLogger('halfapi') logger = logging.getLogger('halfapi')
try: try:
@ -50,18 +32,22 @@ except ImportError:
try: try:
from ..conf import SECRET from ..conf import SECRET
except ImportError: except ImportError as exc:
logger.warning('Could not import SECRET variable from conf module,'\ logger.warning('Could not import SECRET variable from conf module,'\
' using HALFAPI_SECRET environment variable') ' using HALFAPI_SECRET environment variable')
SECRET = environ.get('HALFAPI_SECRET', False) SECRET = environ.get('HALFAPI_SECRET', False)
if not SECRET: if not SECRET:
raise Exception('Missing HALFAPI_SECRET variable') raise Exception('Missing HALFAPI_SECRET variable') from exc
class JWTUser(BaseUser): class JWTUser(BaseUser):
def __init__(self, id: UUID, token: str, payload: dict) -> None: """ JWTUser class
self.__id = id
Is used to store authentication informations
"""
def __init__(self, user_id: UUID, token: str, payload: dict) -> None:
self.__id = user_id
self.token = token self.token = token
self.payload = payload self.payload = payload
@ -81,10 +67,16 @@ class JWTUser(BaseUser):
return True return True
@property @property
def id(self) -> str: def display_name(self) -> str:
return ' '.join(
(self.payload.get('name'), self.payload.get('firstname')))
@property
def identity(self) -> str:
return self.__id return self.__id
class JWTAuthenticationBackend(AuthenticationBackend): class JWTAuthenticationBackend(AuthenticationBackend):
def __init__(self, secret_key: str = SECRET, def __init__(self, secret_key: str = SECRET,
algorithm: str = 'HS256', prefix: str = 'JWT'): algorithm: str = 'HS256', prefix: str = 'JWT'):
@ -95,11 +87,14 @@ class JWTAuthenticationBackend(AuthenticationBackend):
self.algorithm = algorithm self.algorithm = algorithm
self.prefix = prefix self.prefix = prefix
async def authenticate(self, request): async def authenticate(
if "Authorization" not in request.headers: self, conn: HTTPConnection
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
if "Authorization" not in conn.headers:
return None return None
token = request.headers["Authorization"] token = conn.headers["Authorization"]
try: try:
payload = jwt.decode(token, payload = jwt.decode(token,
key=self.secret_key, key=self.secret_key,
@ -113,32 +108,36 @@ class JWTAuthenticationBackend(AuthenticationBackend):
'Trying to connect using *DEBUG* token in *PRODUCTION* mode') 'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
except jwt.InvalidTokenError as exc: except jwt.InvalidTokenError as exc:
raise AuthenticationError(str(exc)) raise AuthenticationError(str(exc)) from exc
except Exception as exc: except Exception as exc:
logger.error('Authentication error : %s', exc) logger.error('Authentication error : %s', exc)
raise exc raise exc
return AuthCredentials(["authenticated"]), JWTUser( return AuthCredentials(["authenticated"]), JWTUser(
id=payload['user_id'], token=token, payload=payload) user_id=payload['user_id'], token=token, payload=payload)
class JWTWebSocketAuthenticationBackend(AuthenticationBackend): class JWTWebSocketAuthenticationBackend(AuthenticationBackend):
def __init__(self, secret_key: str, algorithm: str = 'HS256', query_param_name: str = 'jwt', def __init__(self, secret_key: str, algorithm: str = 'HS256', query_param_name: str = 'jwt',
id: UUID = None, audience = None): user_id: UUID = None, audience = None):
self.secret_key = secret_key self.secret_key = secret_key
self.algorithm = algorithm self.algorithm = algorithm
self.query_param_name = query_param_name self.query_param_name = query_param_name
self.id = id self.__id = user_id
self.audience = audience self.audience = audience
async def authenticate(self, request): async def authenticate(
if self.query_param_name not in request.query_params: self, conn: HTTPConnection
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
if self.query_param_name not in conn.query_params:
return AuthCredentials(), UnauthenticatedUser() return AuthCredentials(), UnauthenticatedUser()
token = request.query_params[self.query_param_name] token = conn.query_params[self.query_param_name]
try: try:
payload = jwt.decode( payload = jwt.decode(
@ -155,12 +154,12 @@ class JWTWebSocketAuthenticationBackend(AuthenticationBackend):
'Trying to connect using *DEBUG* token in *PRODUCTION* mode') 'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
except jwt.InvalidTokenError as exc: except jwt.InvalidTokenError as exc:
raise AuthenticationError(str(exc)) raise AuthenticationError(str(exc)) from exc
return ( return (
AuthCredentials(["authenticated"]), AuthCredentials(["authenticated"]),
JWTUser( JWTUser(
id=payload['id'], user_id=payload['id'],
token=token, token=token,
payload=payload) payload=payload)
) )

View File

@ -1,12 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from starlette.exceptions import HTTPException
""" """
This is the *query* library that contains all the useful functions to treat our This is the *query* library that contains all the useful functions to treat our
queries queries
Fonction:
- parse_query
""" """
def parse_query(q: str = ""): from starlette.exceptions import HTTPException
def parse_query(q_string: str = ""):
""" """
Returns the fitting Response object according to query parameters. Returns the fitting Response object according to query parameters.
@ -15,7 +19,7 @@ def parse_query(q: str = ""):
It returns a callable function that returns the desired Response object. It returns a callable function that returns the desired Response object.
Parameters: Parameters:
q (str): The query string "q" parameter, in the format q_string (str): The query string "q" parameter, in the format
key0:value0|...|keyN:valueN key0:value0|...|keyN:valueN
Returns: Returns:
@ -61,16 +65,16 @@ def parse_query(q: str = ""):
""" """
params = {} params = {}
if len(q) > 0: if len(q_string) > 0:
try: try:
split_ = lambda x : x.split(':') split_ = lambda x : x.split(':')
params = dict(map(split_, q.split('|'))) params = dict(map(split_, q_string.split('|')))
except ValueError: except ValueError as exc:
raise HTTPException(400) raise HTTPException(400) from exc
split_ = lambda x : x.split(':') split_ = lambda x : x.split(':')
params = dict(map(split_, q.split('|'))) params = dict(map(split_, q_string.split('|')))
def select(obj, fields = []): def select(obj, fields):
if 'limit' in params and int(params['limit']) > 0: if 'limit' in params and int(params['limit']) > 0:
obj.limit(int(params['limit'])) obj.limit(int(params['limit']))

View File

@ -1,7 +1,21 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# builtins # builtins
""" Response module
Contains some base response classes
Classes :
- HJSONResponse
- InternalServerErrorResponse
- NotFoundResponse
- NotImplementedResponse
- ORJSONResponse
- PlainTextResponse
- UnauthorizedResponse
"""
import decimal import decimal
import typing as typ import typing
import orjson import orjson
# asgi framework # asgi framework
@ -47,17 +61,21 @@ class UnauthorizedResponse(Response):
class ORJSONResponse(JSONResponse): class ORJSONResponse(JSONResponse):
""" The response that encodes data into JSON
"""
def __init__(self, content, default=None, **kwargs): def __init__(self, content, default=None, **kwargs):
self.default = default if default is not None else ORJSONResponse.default_cast self.default = default if default is not None else ORJSONResponse.default_cast
super().__init__(content, **kwargs) super().__init__(content, **kwargs)
def render(self, content: typ.Any) -> bytes: def render(self, content: typing.Any) -> bytes:
return orjson.dumps(content, return orjson.dumps(content,
option=orjson.OPT_NON_STR_KEYS, option=orjson.OPT_NON_STR_KEYS,
default=self.default) default=self.default)
@staticmethod @staticmethod
def default_cast(x): def default_cast(typ):
""" Cast the data in JSON-serializable type
"""
str_types = { str_types = {
decimal.Decimal decimal.Decimal
} }
@ -65,14 +83,16 @@ class ORJSONResponse(JSONResponse):
set set
} }
if type(x) in str_types: if type(typ) in str_types:
return str(x) return str(typ)
if type(x) in list_types: if type(typ) in list_types:
return list(x) return list(typ)
raise TypeError(f'Type {type(x)} is not handled by ORJSONResponse') raise TypeError(f'Type {type(typ)} is not handled by ORJSONResponse')
class HJSONResponse(ORJSONResponse): class HJSONResponse(ORJSONResponse):
def render(self, content: typ.Generator): """ The response that encodes generator data into JSON
"""
def render(self, content: typing.Generator):
return super().render(list(content)) return super().render(list(content))

View File

@ -1,5 +1,18 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import time """
Routes module
Fonctions :
- route_acl_decorator
- gen_starlette_routes
- api_routes
- api_acls
- debug_routes
Exception :
- DomainNotFoundError
"""
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
import logging import logging
@ -17,7 +30,8 @@ from halfapi.lib.domain import gen_domain_routes, VERBS
logger = logging.getLogger('uvicorn.asgi') logger = logging.getLogger('uvicorn.asgi')
class DomainNotFoundError(Exception): class DomainNotFoundError(Exception):
pass """ Exception when a domain is not importable
"""
def route_acl_decorator(fct: Callable, params: List[Dict]): def route_acl_decorator(fct: Callable, params: List[Dict]):
""" """
@ -46,10 +60,11 @@ def route_acl_decorator(fct: Callable, params: List[Dict]):
if not passed: if not passed:
logger.debug( logger.debug(
f'ACL FAIL for current route ({fct} - {param.get("acl")})') 'ACL FAIL for current route (%s - %s)', fct, param.get('acl'))
continue continue
logger.debug(f'ACL OK for current route ({fct} - {param.get("acl")})') logger.debug(
'ACL OK for current route (%s - %s)', fct, param.get('acl'))
req.scope['acl_pass'] = param['acl'].__name__ req.scope['acl_pass'] = param['acl'].__name__
if 'args' in param: if 'args' in param:
@ -138,6 +153,8 @@ def api_routes(m_dom: ModuleType) -> Generator:
def api_acls(request): def api_acls(request):
""" Returns the list of possible ACLs
"""
res = {} res = {}
doc = 'doc' in request.query_params doc = 'doc' in request.query_params
for domain, d_domain_acl in request.scope['acl'].items(): for domain, d_domain_acl in request.scope['acl'].items():
@ -152,6 +169,8 @@ def api_acls(request):
def debug_routes(): def debug_routes():
""" Halfapi debug routes definition
"""
async def debug_log(request: Request, *args, **kwargs): async def debug_log(request: Request, *args, **kwargs):
logger.debug('debuglog# %s', {datetime.now().isoformat()}) logger.debug('debuglog# %s', {datetime.now().isoformat()})
logger.info('debuglog# %s', {datetime.now().isoformat()}) logger.info('debuglog# %s', {datetime.now().isoformat()})

View File

@ -1,41 +1,28 @@
from types import ModuleType """ Schemas module
Functions :
- get_api_routes
- schema_json
- schema_dict_dom
- get_acls
Constant :
SCHEMAS (starlette.schemas.SchemaGenerator)
"""
from typing import Dict from typing import Dict
from ..conf import DOMAINSDICT
from .routes import gen_starlette_routes
from .responses import *
from .jwt_middleware import UnauthenticatedUser, JWTUser
from starlette.schemas import SchemaGenerator from starlette.schemas import SchemaGenerator
from starlette.routing import Router
from .routes import gen_starlette_routes, api_acls
from .responses import ORJSONResponse
SCHEMAS = SchemaGenerator( SCHEMAS = SchemaGenerator(
{"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": "1.0"}} {"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": "1.0"}}
) )
"""
example: > {
"dummy_domain": {
"/abc/alphabet/organigramme": {
"fqtn": null,
"params": [
{}
],
"verb": "GET"
},
"/act/personne/": {
"fqtn": "acteur.personne",
"params": [
{}
"verb": "GET"
}
}
}
"""
async def get_api_routes(request, *args, **kwargs): async def get_api_routes(request, *args, **kwargs):
""" """
responses:
200:
description: Returns the current API routes description (HalfAPI 0.2.1) description: Returns the current API routes description (HalfAPI 0.2.1)
as a JSON object as a JSON object
""" """
@ -44,8 +31,6 @@ async def get_api_routes(request, *args, **kwargs):
async def schema_json(request, *args, **kwargs): async def schema_json(request, *args, **kwargs):
""" """
responses:
200:
description: Returns the current API routes description (OpenAPI v3) description: Returns the current API routes description (OpenAPI v3)
as a JSON object as a JSON object
""" """
@ -59,8 +44,7 @@ def schema_dict_dom(d_domains) -> Dict:
Parameters: Parameters:
m_domain (ModuleType): The module to scan for routes d_domains (Dict[str, moduleType]): The module to scan for routes
Returns: Returns:
@ -73,12 +57,7 @@ def schema_dict_dom(d_domains) -> Dict:
async def get_acls(request, *args, **kwargs): async def get_acls(request, *args, **kwargs):
""" """
responses:
200:
description: A dictionnary of the domains and their acls, with the description: A dictionnary of the domains and their acls, with the
result of the acls functions result of the acls functions
""" """
from .routes import api_acls
return ORJSONResponse(api_acls(request)) return ORJSONResponse(api_acls(request))

View File

@ -1,3 +1,10 @@
"""
Timing module
Helpers to gathers stats on requests timing
class HTimingClient
"""
import logging import logging
from timing_asgi import TimingClient from timing_asgi import TimingClient
@ -5,12 +12,11 @@ from timing_asgi import TimingClient
logger = logging.getLogger('uvicorn.asgi') logger = logging.getLogger('uvicorn.asgi')
class HTimingClient(TimingClient): class HTimingClient(TimingClient):
""" Used to redefine TimingClient.timing
"""
def timing(self, metric_name, timing, tags): def timing(self, metric_name, timing, tags):
tags_d = { tags_d = dict(map(lambda elt: elt.split(':'), tags))
key: val
for key, val in map(
lambda elt: elt.split(':'), tags)
}
logger.debug('[TIME:%s][%s] %s %s - %sms', logger.debug('[TIME:%s][%s] %s %s - %sms',
tags_d['time'], metric_name, tags_d['time'], metric_name,
tags_d['http_method'], tags_d['http_status'], tags_d['http_method'], tags_d['http_status'],

View File

@ -145,7 +145,7 @@ def test_JWTUser():
token = '{}' token = '{}'
payload = {} payload = {}
user = JWTUser(uid, token, payload) user = JWTUser(uid, token, payload)
assert user.id == uid assert user.identity == uid
assert user.token == token assert user.token == token
assert user.payload == payload assert user.payload == payload
assert user.is_authenticated == True assert user.is_authenticated == True