From 10b1960f4e080db2a331d1a2a61c5e7df610c324 Mon Sep 17 00:00:00 2001 From: "Maxime Alves LIRMM@home" Date: Sat, 24 Apr 2021 08:56:18 +0200 Subject: [PATCH] nettoyage / commentaires / renommage de variables - JWTUser.id deviens JWTUser.identity --- halfapi/lib/acl.py | 21 +++++-- halfapi/lib/domain.py | 5 +- halfapi/lib/domain_middleware.py | 3 +- halfapi/lib/jwt_middleware.py | 97 ++++++++++++++++---------------- halfapi/lib/query.py | 24 ++++---- halfapi/lib/responses.py | 46 ++++++++++----- halfapi/lib/routes.py | 27 +++++++-- halfapi/lib/schemas.py | 73 +++++++++--------------- halfapi/lib/timing.py | 16 ++++-- tests/test_jwt_middleware.py | 2 +- 10 files changed, 174 insertions(+), 140 deletions(-) diff --git a/halfapi/lib/acl.py b/halfapi/lib/acl.py index 04e92b1..57556c8 100644 --- a/halfapi/lib/acl.py +++ b/halfapi/lib/acl.py @@ -3,11 +3,11 @@ Base ACL module that contains generic functions for domains ACL """ import logging - from functools import wraps -from starlette.authentication import UnauthenticatedUser - from json import JSONDecodeError +from starlette.authentication import UnauthenticatedUser +from starlette.exceptions import HTTPException + logger = logging.getLogger('uvicorn.asgi') @@ -30,11 +30,20 @@ def connected(fct=public): return caller def args_check(fct): + """ Decorator that puts required and optional arguments in scope + + For GET requests it uses the query_params + + For POST requests it uses the body as JSON + + If "check" is present in the query params, nothing is done. + + If some required arguments are missing, a 400 status code is sent. + """ @wraps(fct) async def caller(req, *args, **kwargs): if 'check' in req.query_params: - """ Check query param should not read the "args" - """ + # Check query param should not read the "args" return await fct(req, *args, **kwargs) if req.method == 'GET': @@ -47,7 +56,7 @@ def args_check(fct): data_ = {} def plural(array: list) -> str: - return len(array) > 1 and 's' or '' + return 's' if len(array) > 1 else '' def comma_list(array: list) -> str: return ', '.join(array) diff --git a/halfapi/lib/domain.py b/halfapi/lib/domain.py index e296965..ee0d917 100644 --- a/halfapi/lib/domain.py +++ b/halfapi/lib/domain.py @@ -5,7 +5,6 @@ lib/domain.py The domain-scoped utility functions import importlib import logging -import time from types import ModuleType from typing import Generator, Dict, List @@ -94,7 +93,7 @@ def gen_routes(route_params: Dict, path: List, m_router: ModuleType) -> Generato if params is None: continue if len(params) == 0: - logger.error(f'No ACL for route [{verb}] "/".join(path)') + logger.error('No ACL for route [{%s}] %s', verb, "/".join(path)) try: fct_name = get_fct_name(verb, path[-1]) @@ -125,7 +124,7 @@ def gen_router_routes(m_router: ModuleType, path: List[str]) -> Generator: """ if not hasattr(m_router, 'ROUTES'): - logger.error(f'Missing *ROUTES* constant in *{m_router.__name__}*') + logger.error('Missing *ROUTES* constant in *%s*', m_router.__name__) raise Exception(f'No ROUTES constant for {m_router.__name__}') diff --git a/halfapi/lib/domain_middleware.py b/halfapi/lib/domain_middleware.py index 541e940..5215fec 100644 --- a/halfapi/lib/domain_middleware.py +++ b/halfapi/lib/domain_middleware.py @@ -13,7 +13,6 @@ from starlette.types import Scope, Send, Receive from .routes import api_routes from .domain import d_domains -from ..conf import config_dict logger = logging.getLogger('uvicorn.asgi') @@ -62,7 +61,7 @@ class DomainMiddleware(BaseHTTPMiddleware): scope_['config'] = dict(config_section) except configparser.NoSectionError: logger.debug( - f'No specific configuration for domain **{current_domain}**') + 'No specific configuration for domain **%s**', current_domain) scope_['config'] = {} diff --git a/halfapi/lib/jwt_middleware.py b/halfapi/lib/jwt_middleware.py index 056a3f6..ddfe144 100644 --- a/halfapi/lib/jwt_middleware.py +++ b/halfapi/lib/jwt_middleware.py @@ -1,44 +1,26 @@ -__LICENSE__ = """ -BSD 3-Clause License +""" +JWT Middleware module -Copyright (c) 2018, Amit Ripshtos -All rights reserved. +Classes: + - JWTUser : goes in request.user + - JWTAuthenticationBackend + - JWTWebSocketAuthenticationBackend -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -* Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +Raises: + Exception: If configuration has no SECRET or HALFAPI_SECRET is not set """ from os import environ +import typing +import logging +from uuid import UUID import jwt -from uuid import UUID from starlette.authentication import ( AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials, UnauthenticatedUser) +from starlette.requests import HTTPConnection -import logging logger = logging.getLogger('halfapi') try: @@ -50,18 +32,22 @@ except ImportError: try: from ..conf import SECRET -except ImportError: +except ImportError as exc: logger.warning('Could not import SECRET variable from conf module,'\ ' using HALFAPI_SECRET environment variable') SECRET = environ.get('HALFAPI_SECRET', False) if not SECRET: - raise Exception('Missing HALFAPI_SECRET variable') + raise Exception('Missing HALFAPI_SECRET variable') from exc class JWTUser(BaseUser): - def __init__(self, id: UUID, token: str, payload: dict) -> None: - self.__id = id + """ JWTUser class + + Is used to store authentication informations + """ + def __init__(self, user_id: UUID, token: str, payload: dict) -> None: + self.__id = user_id self.token = token self.payload = payload @@ -81,25 +67,34 @@ class JWTUser(BaseUser): return True @property - def id(self) -> str: + def display_name(self) -> str: + return ' '.join( + (self.payload.get('name'), self.payload.get('firstname'))) + + @property + def identity(self) -> str: return self.__id + class JWTAuthenticationBackend(AuthenticationBackend): def __init__(self, secret_key: str = SECRET, algorithm: str = 'HS256', prefix: str = 'JWT'): if secret_key is None: - raise Exception('Missing secret_key argument for JWTAuthenticationBackend') + raise Exception('Missing secret_key argument for JWTAuthenticationBackend') self.secret_key = secret_key self.algorithm = algorithm self.prefix = prefix - async def authenticate(self, request): - if "Authorization" not in request.headers: + async def authenticate( + self, conn: HTTPConnection + ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]: + + if "Authorization" not in conn.headers: return None - token = request.headers["Authorization"] + token = conn.headers["Authorization"] try: payload = jwt.decode(token, key=self.secret_key, @@ -113,32 +108,36 @@ class JWTAuthenticationBackend(AuthenticationBackend): 'Trying to connect using *DEBUG* token in *PRODUCTION* mode') except jwt.InvalidTokenError as exc: - raise AuthenticationError(str(exc)) + raise AuthenticationError(str(exc)) from exc except Exception as exc: logger.error('Authentication error : %s', exc) raise exc return AuthCredentials(["authenticated"]), JWTUser( - id=payload['user_id'], token=token, payload=payload) + user_id=payload['user_id'], token=token, payload=payload) + class JWTWebSocketAuthenticationBackend(AuthenticationBackend): def __init__(self, secret_key: str, algorithm: str = 'HS256', query_param_name: str = 'jwt', - id: UUID = None, audience = None): + user_id: UUID = None, audience = None): self.secret_key = secret_key self.algorithm = algorithm self.query_param_name = query_param_name - self.id = id + self.__id = user_id self.audience = audience - async def authenticate(self, request): - if self.query_param_name not in request.query_params: + async def authenticate( + self, conn: HTTPConnection + ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]: + + if self.query_param_name not in conn.query_params: return AuthCredentials(), UnauthenticatedUser() - token = request.query_params[self.query_param_name] + token = conn.query_params[self.query_param_name] try: payload = jwt.decode( @@ -155,12 +154,12 @@ class JWTWebSocketAuthenticationBackend(AuthenticationBackend): 'Trying to connect using *DEBUG* token in *PRODUCTION* mode') except jwt.InvalidTokenError as exc: - raise AuthenticationError(str(exc)) + raise AuthenticationError(str(exc)) from exc return ( - AuthCredentials(["authenticated"]), + AuthCredentials(["authenticated"]), JWTUser( - id=payload['id'], + user_id=payload['id'], token=token, payload=payload) ) diff --git a/halfapi/lib/query.py b/halfapi/lib/query.py index eaa1d1d..0e3ebad 100644 --- a/halfapi/lib/query.py +++ b/halfapi/lib/query.py @@ -1,12 +1,16 @@ #!/usr/bin/env python3 -from starlette.exceptions import HTTPException - """ This is the *query* library that contains all the useful functions to treat our queries + +Fonction: + - parse_query """ -def parse_query(q: str = ""): +from starlette.exceptions import HTTPException + + +def parse_query(q_string: str = ""): """ Returns the fitting Response object according to query parameters. @@ -15,7 +19,7 @@ def parse_query(q: str = ""): It returns a callable function that returns the desired Response object. Parameters: - q (str): The query string "q" parameter, in the format + q_string (str): The query string "q" parameter, in the format key0:value0|...|keyN:valueN Returns: @@ -61,16 +65,16 @@ def parse_query(q: str = ""): """ params = {} - if len(q) > 0: + if len(q_string) > 0: try: split_ = lambda x : x.split(':') - params = dict(map(split_, q.split('|'))) - except ValueError: - raise HTTPException(400) + params = dict(map(split_, q_string.split('|'))) + except ValueError as exc: + raise HTTPException(400) from exc split_ = lambda x : x.split(':') - params = dict(map(split_, q.split('|'))) + params = dict(map(split_, q_string.split('|'))) - def select(obj, fields = []): + def select(obj, fields): if 'limit' in params and int(params['limit']) > 0: obj.limit(int(params['limit'])) diff --git a/halfapi/lib/responses.py b/halfapi/lib/responses.py index 7fb0acc..7dbde9b 100644 --- a/halfapi/lib/responses.py +++ b/halfapi/lib/responses.py @@ -1,7 +1,21 @@ #!/usr/bin/env python3 # builtins +""" Response module + +Contains some base response classes + +Classes : + - HJSONResponse + - InternalServerErrorResponse + - NotFoundResponse + - NotImplementedResponse + - ORJSONResponse + - PlainTextResponse + - UnauthorizedResponse + +""" import decimal -import typing as typ +import typing import orjson # asgi framework @@ -19,45 +33,49 @@ __all__ = [ class InternalServerErrorResponse(Response): - """ The 500 Internal Server Error default Response + """ The 500 Internal Server Error default Response """ def __init__(self, *args, **kwargs): super().__init__(status_code=500) class NotFoundResponse(Response): - """ The 404 Not Found default Response + """ The 404 Not Found default Response """ def __init__(self, *args, **kwargs): super().__init__(status_code=404) class NotImplementedResponse(Response): - """ The 501 Not Implemented default Response + """ The 501 Not Implemented default Response """ def __init__(self, *args, **kwargs): super().__init__(status_code=501) class UnauthorizedResponse(Response): - """ The 401 Not Found default Response + """ The 401 Not Found default Response """ def __init__(self, *args, **kwargs): super().__init__(status_code = 401) class ORJSONResponse(JSONResponse): + """ The response that encodes data into JSON + """ def __init__(self, content, default=None, **kwargs): self.default = default if default is not None else ORJSONResponse.default_cast super().__init__(content, **kwargs) - def render(self, content: typ.Any) -> bytes: + def render(self, content: typing.Any) -> bytes: return orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS, default=self.default) @staticmethod - def default_cast(x): + def default_cast(typ): + """ Cast the data in JSON-serializable type + """ str_types = { decimal.Decimal } @@ -65,14 +83,16 @@ class ORJSONResponse(JSONResponse): set } - if type(x) in str_types: - return str(x) - if type(x) in list_types: - return list(x) + if type(typ) in str_types: + return str(typ) + if type(typ) in list_types: + return list(typ) - raise TypeError(f'Type {type(x)} is not handled by ORJSONResponse') + raise TypeError(f'Type {type(typ)} is not handled by ORJSONResponse') class HJSONResponse(ORJSONResponse): - def render(self, content: typ.Generator): + """ The response that encodes generator data into JSON + """ + def render(self, content: typing.Generator): return super().render(list(content)) diff --git a/halfapi/lib/routes.py b/halfapi/lib/routes.py index ebf60fb..293949f 100644 --- a/halfapi/lib/routes.py +++ b/halfapi/lib/routes.py @@ -1,5 +1,18 @@ #!/usr/bin/env python3 -import time +""" +Routes module + +Fonctions : + - route_acl_decorator + - gen_starlette_routes + - api_routes + - api_acls + - debug_routes + +Exception : + - DomainNotFoundError + +""" from datetime import datetime from functools import wraps import logging @@ -17,7 +30,8 @@ from halfapi.lib.domain import gen_domain_routes, VERBS logger = logging.getLogger('uvicorn.asgi') class DomainNotFoundError(Exception): - pass + """ Exception when a domain is not importable + """ def route_acl_decorator(fct: Callable, params: List[Dict]): """ @@ -46,10 +60,11 @@ def route_acl_decorator(fct: Callable, params: List[Dict]): if not passed: logger.debug( - f'ACL FAIL for current route ({fct} - {param.get("acl")})') + 'ACL FAIL for current route (%s - %s)', fct, param.get('acl')) continue - logger.debug(f'ACL OK for current route ({fct} - {param.get("acl")})') + logger.debug( + 'ACL OK for current route (%s - %s)', fct, param.get('acl')) req.scope['acl_pass'] = param['acl'].__name__ if 'args' in param: @@ -138,6 +153,8 @@ def api_routes(m_dom: ModuleType) -> Generator: def api_acls(request): + """ Returns the list of possible ACLs + """ res = {} doc = 'doc' in request.query_params for domain, d_domain_acl in request.scope['acl'].items(): @@ -152,6 +169,8 @@ def api_acls(request): def debug_routes(): + """ Halfapi debug routes definition + """ async def debug_log(request: Request, *args, **kwargs): logger.debug('debuglog# %s', {datetime.now().isoformat()}) logger.info('debuglog# %s', {datetime.now().isoformat()}) diff --git a/halfapi/lib/schemas.py b/halfapi/lib/schemas.py index 4df6af4..b1f69a8 100644 --- a/halfapi/lib/schemas.py +++ b/halfapi/lib/schemas.py @@ -1,53 +1,38 @@ -from types import ModuleType +""" Schemas module + +Functions : + - get_api_routes + - schema_json + - schema_dict_dom + - get_acls + +Constant : + SCHEMAS (starlette.schemas.SchemaGenerator) +""" + from typing import Dict -from ..conf import DOMAINSDICT -from .routes import gen_starlette_routes -from .responses import * -from .jwt_middleware import UnauthenticatedUser, JWTUser + from starlette.schemas import SchemaGenerator -from starlette.routing import Router + +from .routes import gen_starlette_routes, api_acls +from .responses import ORJSONResponse + SCHEMAS = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": "1.0"}} ) -""" -example: > { -"dummy_domain": { - "/abc/alphabet/organigramme": { - "fqtn": null, - "params": [ - {} - ], - "verb": "GET" - }, - "/act/personne/": { - "fqtn": "acteur.personne", - "params": [ - {} - - "verb": "GET" - } - -} -} -""" - async def get_api_routes(request, *args, **kwargs): """ - responses: - 200: - description: Returns the current API routes description (HalfAPI 0.2.1) - as a JSON object + description: Returns the current API routes description (HalfAPI 0.2.1) + as a JSON object """ return ORJSONResponse(request.scope['api']) async def schema_json(request, *args, **kwargs): """ - responses: - 200: - description: Returns the current API routes description (OpenAPI v3) - as a JSON object + description: Returns the current API routes description (OpenAPI v3) + as a JSON object """ return ORJSONResponse( SCHEMAS.get_schema(routes=request.app.routes)) @@ -58,14 +43,13 @@ def schema_dict_dom(d_domains) -> Dict: Returns the API schema of the *m_domain* domain as a python dictionnary Parameters: - - m_domain (ModuleType): The module to scan for routes + d_domains (Dict[str, moduleType]): The module to scan for routes Returns: - + Dict: A dictionnary containing the description of the API using the - | OpenAPI standard + | OpenAPI standard """ return SCHEMAS.get_schema( routes=list(gen_starlette_routes(d_domains))) @@ -73,12 +57,7 @@ def schema_dict_dom(d_domains) -> Dict: async def get_acls(request, *args, **kwargs): """ - responses: - 200: - description: A dictionnary of the domains and their acls, with the - result of the acls functions + description: A dictionnary of the domains and their acls, with the + result of the acls functions """ - - from .routes import api_acls return ORJSONResponse(api_acls(request)) - diff --git a/halfapi/lib/timing.py b/halfapi/lib/timing.py index 5580d13..4e5be63 100644 --- a/halfapi/lib/timing.py +++ b/halfapi/lib/timing.py @@ -1,3 +1,10 @@ +""" +Timing module + +Helpers to gathers stats on requests timing + +class HTimingClient +""" import logging from timing_asgi import TimingClient @@ -5,12 +12,11 @@ from timing_asgi import TimingClient logger = logging.getLogger('uvicorn.asgi') class HTimingClient(TimingClient): + """ Used to redefine TimingClient.timing + """ def timing(self, metric_name, timing, tags): - tags_d = { - key: val - for key, val in map( - lambda elt: elt.split(':'), tags) - } + tags_d = dict(map(lambda elt: elt.split(':'), tags)) + logger.debug('[TIME:%s][%s] %s %s - %sms', tags_d['time'], metric_name, tags_d['http_method'], tags_d['http_status'], diff --git a/tests/test_jwt_middleware.py b/tests/test_jwt_middleware.py index ff8aaaa..a758f72 100644 --- a/tests/test_jwt_middleware.py +++ b/tests/test_jwt_middleware.py @@ -145,7 +145,7 @@ def test_JWTUser(): token = '{}' payload = {} user = JWTUser(uid, token, payload) - assert user.id == uid + assert user.identity == uid assert user.token == token assert user.payload == payload assert user.is_authenticated == True