nettoyage / commentaires / renommage de variables - JWTUser.id deviens JWTUser.identity

This commit is contained in:
Maxime Alves LIRMM@home 2021-04-24 08:56:18 +02:00
parent a2fb70f84b
commit 10b1960f4e
10 changed files with 174 additions and 140 deletions

View File

@ -3,11 +3,11 @@
Base ACL module that contains generic functions for domains ACL
"""
import logging
from functools import wraps
from starlette.authentication import UnauthenticatedUser
from json import JSONDecodeError
from starlette.authentication import UnauthenticatedUser
from starlette.exceptions import HTTPException
logger = logging.getLogger('uvicorn.asgi')
@ -30,11 +30,20 @@ def connected(fct=public):
return caller
def args_check(fct):
""" Decorator that puts required and optional arguments in scope
For GET requests it uses the query_params
For POST requests it uses the body as JSON
If "check" is present in the query params, nothing is done.
If some required arguments are missing, a 400 status code is sent.
"""
@wraps(fct)
async def caller(req, *args, **kwargs):
if 'check' in req.query_params:
""" Check query param should not read the "args"
"""
# Check query param should not read the "args"
return await fct(req, *args, **kwargs)
if req.method == 'GET':
@ -47,7 +56,7 @@ def args_check(fct):
data_ = {}
def plural(array: list) -> str:
return len(array) > 1 and 's' or ''
return 's' if len(array) > 1 else ''
def comma_list(array: list) -> str:
return ', '.join(array)

View File

@ -5,7 +5,6 @@ lib/domain.py The domain-scoped utility functions
import importlib
import logging
import time
from types import ModuleType
from typing import Generator, Dict, List
@ -94,7 +93,7 @@ def gen_routes(route_params: Dict, path: List, m_router: ModuleType) -> Generato
if params is None:
continue
if len(params) == 0:
logger.error(f'No ACL for route [{verb}] "/".join(path)')
logger.error('No ACL for route [{%s}] %s', verb, "/".join(path))
try:
fct_name = get_fct_name(verb, path[-1])
@ -125,7 +124,7 @@ def gen_router_routes(m_router: ModuleType, path: List[str]) -> Generator:
"""
if not hasattr(m_router, 'ROUTES'):
logger.error(f'Missing *ROUTES* constant in *{m_router.__name__}*')
logger.error('Missing *ROUTES* constant in *%s*', m_router.__name__)
raise Exception(f'No ROUTES constant for {m_router.__name__}')

View File

@ -13,7 +13,6 @@ from starlette.types import Scope, Send, Receive
from .routes import api_routes
from .domain import d_domains
from ..conf import config_dict
logger = logging.getLogger('uvicorn.asgi')
@ -62,7 +61,7 @@ class DomainMiddleware(BaseHTTPMiddleware):
scope_['config'] = dict(config_section)
except configparser.NoSectionError:
logger.debug(
f'No specific configuration for domain **{current_domain}**')
'No specific configuration for domain **%s**', current_domain)
scope_['config'] = {}

View File

@ -1,44 +1,26 @@
__LICENSE__ = """
BSD 3-Clause License
"""
JWT Middleware module
Copyright (c) 2018, Amit Ripshtos
All rights reserved.
Classes:
- JWTUser : goes in request.user
- JWTAuthenticationBackend
- JWTWebSocketAuthenticationBackend
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Raises:
Exception: If configuration has no SECRET or HALFAPI_SECRET is not set
"""
from os import environ
import typing
import logging
from uuid import UUID
import jwt
from uuid import UUID
from starlette.authentication import (
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
UnauthenticatedUser)
from starlette.requests import HTTPConnection
import logging
logger = logging.getLogger('halfapi')
try:
@ -50,18 +32,22 @@ except ImportError:
try:
from ..conf import SECRET
except ImportError:
except ImportError as exc:
logger.warning('Could not import SECRET variable from conf module,'\
' using HALFAPI_SECRET environment variable')
SECRET = environ.get('HALFAPI_SECRET', False)
if not SECRET:
raise Exception('Missing HALFAPI_SECRET variable')
raise Exception('Missing HALFAPI_SECRET variable') from exc
class JWTUser(BaseUser):
def __init__(self, id: UUID, token: str, payload: dict) -> None:
self.__id = id
""" JWTUser class
Is used to store authentication informations
"""
def __init__(self, user_id: UUID, token: str, payload: dict) -> None:
self.__id = user_id
self.token = token
self.payload = payload
@ -81,10 +67,16 @@ class JWTUser(BaseUser):
return True
@property
def id(self) -> str:
def display_name(self) -> str:
return ' '.join(
(self.payload.get('name'), self.payload.get('firstname')))
@property
def identity(self) -> str:
return self.__id
class JWTAuthenticationBackend(AuthenticationBackend):
def __init__(self, secret_key: str = SECRET,
algorithm: str = 'HS256', prefix: str = 'JWT'):
@ -95,11 +87,14 @@ class JWTAuthenticationBackend(AuthenticationBackend):
self.algorithm = algorithm
self.prefix = prefix
async def authenticate(self, request):
if "Authorization" not in request.headers:
async def authenticate(
self, conn: HTTPConnection
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
if "Authorization" not in conn.headers:
return None
token = request.headers["Authorization"]
token = conn.headers["Authorization"]
try:
payload = jwt.decode(token,
key=self.secret_key,
@ -113,32 +108,36 @@ class JWTAuthenticationBackend(AuthenticationBackend):
'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
except jwt.InvalidTokenError as exc:
raise AuthenticationError(str(exc))
raise AuthenticationError(str(exc)) from exc
except Exception as exc:
logger.error('Authentication error : %s', exc)
raise exc
return AuthCredentials(["authenticated"]), JWTUser(
id=payload['user_id'], token=token, payload=payload)
user_id=payload['user_id'], token=token, payload=payload)
class JWTWebSocketAuthenticationBackend(AuthenticationBackend):
def __init__(self, secret_key: str, algorithm: str = 'HS256', query_param_name: str = 'jwt',
id: UUID = None, audience = None):
user_id: UUID = None, audience = None):
self.secret_key = secret_key
self.algorithm = algorithm
self.query_param_name = query_param_name
self.id = id
self.__id = user_id
self.audience = audience
async def authenticate(self, request):
if self.query_param_name not in request.query_params:
async def authenticate(
self, conn: HTTPConnection
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
if self.query_param_name not in conn.query_params:
return AuthCredentials(), UnauthenticatedUser()
token = request.query_params[self.query_param_name]
token = conn.query_params[self.query_param_name]
try:
payload = jwt.decode(
@ -155,12 +154,12 @@ class JWTWebSocketAuthenticationBackend(AuthenticationBackend):
'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
except jwt.InvalidTokenError as exc:
raise AuthenticationError(str(exc))
raise AuthenticationError(str(exc)) from exc
return (
AuthCredentials(["authenticated"]),
JWTUser(
id=payload['id'],
user_id=payload['id'],
token=token,
payload=payload)
)

View File

@ -1,12 +1,16 @@
#!/usr/bin/env python3
from starlette.exceptions import HTTPException
"""
This is the *query* library that contains all the useful functions to treat our
queries
Fonction:
- parse_query
"""
def parse_query(q: str = ""):
from starlette.exceptions import HTTPException
def parse_query(q_string: str = ""):
"""
Returns the fitting Response object according to query parameters.
@ -15,7 +19,7 @@ def parse_query(q: str = ""):
It returns a callable function that returns the desired Response object.
Parameters:
q (str): The query string "q" parameter, in the format
q_string (str): The query string "q" parameter, in the format
key0:value0|...|keyN:valueN
Returns:
@ -61,16 +65,16 @@ def parse_query(q: str = ""):
"""
params = {}
if len(q) > 0:
if len(q_string) > 0:
try:
split_ = lambda x : x.split(':')
params = dict(map(split_, q.split('|')))
except ValueError:
raise HTTPException(400)
params = dict(map(split_, q_string.split('|')))
except ValueError as exc:
raise HTTPException(400) from exc
split_ = lambda x : x.split(':')
params = dict(map(split_, q.split('|')))
params = dict(map(split_, q_string.split('|')))
def select(obj, fields = []):
def select(obj, fields):
if 'limit' in params and int(params['limit']) > 0:
obj.limit(int(params['limit']))

View File

@ -1,7 +1,21 @@
#!/usr/bin/env python3
# builtins
""" Response module
Contains some base response classes
Classes :
- HJSONResponse
- InternalServerErrorResponse
- NotFoundResponse
- NotImplementedResponse
- ORJSONResponse
- PlainTextResponse
- UnauthorizedResponse
"""
import decimal
import typing as typ
import typing
import orjson
# asgi framework
@ -47,17 +61,21 @@ class UnauthorizedResponse(Response):
class ORJSONResponse(JSONResponse):
""" The response that encodes data into JSON
"""
def __init__(self, content, default=None, **kwargs):
self.default = default if default is not None else ORJSONResponse.default_cast
super().__init__(content, **kwargs)
def render(self, content: typ.Any) -> bytes:
def render(self, content: typing.Any) -> bytes:
return orjson.dumps(content,
option=orjson.OPT_NON_STR_KEYS,
default=self.default)
@staticmethod
def default_cast(x):
def default_cast(typ):
""" Cast the data in JSON-serializable type
"""
str_types = {
decimal.Decimal
}
@ -65,14 +83,16 @@ class ORJSONResponse(JSONResponse):
set
}
if type(x) in str_types:
return str(x)
if type(x) in list_types:
return list(x)
if type(typ) in str_types:
return str(typ)
if type(typ) in list_types:
return list(typ)
raise TypeError(f'Type {type(x)} is not handled by ORJSONResponse')
raise TypeError(f'Type {type(typ)} is not handled by ORJSONResponse')
class HJSONResponse(ORJSONResponse):
def render(self, content: typ.Generator):
""" The response that encodes generator data into JSON
"""
def render(self, content: typing.Generator):
return super().render(list(content))

View File

@ -1,5 +1,18 @@
#!/usr/bin/env python3
import time
"""
Routes module
Fonctions :
- route_acl_decorator
- gen_starlette_routes
- api_routes
- api_acls
- debug_routes
Exception :
- DomainNotFoundError
"""
from datetime import datetime
from functools import wraps
import logging
@ -17,7 +30,8 @@ from halfapi.lib.domain import gen_domain_routes, VERBS
logger = logging.getLogger('uvicorn.asgi')
class DomainNotFoundError(Exception):
pass
""" Exception when a domain is not importable
"""
def route_acl_decorator(fct: Callable, params: List[Dict]):
"""
@ -46,10 +60,11 @@ def route_acl_decorator(fct: Callable, params: List[Dict]):
if not passed:
logger.debug(
f'ACL FAIL for current route ({fct} - {param.get("acl")})')
'ACL FAIL for current route (%s - %s)', fct, param.get('acl'))
continue
logger.debug(f'ACL OK for current route ({fct} - {param.get("acl")})')
logger.debug(
'ACL OK for current route (%s - %s)', fct, param.get('acl'))
req.scope['acl_pass'] = param['acl'].__name__
if 'args' in param:
@ -138,6 +153,8 @@ def api_routes(m_dom: ModuleType) -> Generator:
def api_acls(request):
""" Returns the list of possible ACLs
"""
res = {}
doc = 'doc' in request.query_params
for domain, d_domain_acl in request.scope['acl'].items():
@ -152,6 +169,8 @@ def api_acls(request):
def debug_routes():
""" Halfapi debug routes definition
"""
async def debug_log(request: Request, *args, **kwargs):
logger.debug('debuglog# %s', {datetime.now().isoformat()})
logger.info('debuglog# %s', {datetime.now().isoformat()})

View File

@ -1,41 +1,28 @@
from types import ModuleType
""" Schemas module
Functions :
- get_api_routes
- schema_json
- schema_dict_dom
- get_acls
Constant :
SCHEMAS (starlette.schemas.SchemaGenerator)
"""
from typing import Dict
from ..conf import DOMAINSDICT
from .routes import gen_starlette_routes
from .responses import *
from .jwt_middleware import UnauthenticatedUser, JWTUser
from starlette.schemas import SchemaGenerator
from starlette.routing import Router
from .routes import gen_starlette_routes, api_acls
from .responses import ORJSONResponse
SCHEMAS = SchemaGenerator(
{"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": "1.0"}}
)
"""
example: > {
"dummy_domain": {
"/abc/alphabet/organigramme": {
"fqtn": null,
"params": [
{}
],
"verb": "GET"
},
"/act/personne/": {
"fqtn": "acteur.personne",
"params": [
{}
"verb": "GET"
}
}
}
"""
async def get_api_routes(request, *args, **kwargs):
"""
responses:
200:
description: Returns the current API routes description (HalfAPI 0.2.1)
as a JSON object
"""
@ -44,8 +31,6 @@ async def get_api_routes(request, *args, **kwargs):
async def schema_json(request, *args, **kwargs):
"""
responses:
200:
description: Returns the current API routes description (OpenAPI v3)
as a JSON object
"""
@ -59,8 +44,7 @@ def schema_dict_dom(d_domains) -> Dict:
Parameters:
m_domain (ModuleType): The module to scan for routes
d_domains (Dict[str, moduleType]): The module to scan for routes
Returns:
@ -73,12 +57,7 @@ def schema_dict_dom(d_domains) -> Dict:
async def get_acls(request, *args, **kwargs):
"""
responses:
200:
description: A dictionnary of the domains and their acls, with the
result of the acls functions
"""
from .routes import api_acls
return ORJSONResponse(api_acls(request))

View File

@ -1,3 +1,10 @@
"""
Timing module
Helpers to gathers stats on requests timing
class HTimingClient
"""
import logging
from timing_asgi import TimingClient
@ -5,12 +12,11 @@ from timing_asgi import TimingClient
logger = logging.getLogger('uvicorn.asgi')
class HTimingClient(TimingClient):
""" Used to redefine TimingClient.timing
"""
def timing(self, metric_name, timing, tags):
tags_d = {
key: val
for key, val in map(
lambda elt: elt.split(':'), tags)
}
tags_d = dict(map(lambda elt: elt.split(':'), tags))
logger.debug('[TIME:%s][%s] %s %s - %sms',
tags_d['time'], metric_name,
tags_d['http_method'], tags_d['http_status'],

View File

@ -145,7 +145,7 @@ def test_JWTUser():
token = '{}'
payload = {}
user = JWTUser(uid, token, payload)
assert user.id == uid
assert user.identity == uid
assert user.token == token
assert user.payload == payload
assert user.is_authenticated == True