nettoyage / commentaires / renommage de variables - JWTUser.id deviens JWTUser.identity
This commit is contained in:
parent
a2fb70f84b
commit
10b1960f4e
|
@ -3,11 +3,11 @@
|
|||
Base ACL module that contains generic functions for domains ACL
|
||||
"""
|
||||
import logging
|
||||
|
||||
from functools import wraps
|
||||
from starlette.authentication import UnauthenticatedUser
|
||||
|
||||
from json import JSONDecodeError
|
||||
from starlette.authentication import UnauthenticatedUser
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
|
||||
logger = logging.getLogger('uvicorn.asgi')
|
||||
|
||||
|
@ -30,11 +30,20 @@ def connected(fct=public):
|
|||
return caller
|
||||
|
||||
def args_check(fct):
|
||||
""" Decorator that puts required and optional arguments in scope
|
||||
|
||||
For GET requests it uses the query_params
|
||||
|
||||
For POST requests it uses the body as JSON
|
||||
|
||||
If "check" is present in the query params, nothing is done.
|
||||
|
||||
If some required arguments are missing, a 400 status code is sent.
|
||||
"""
|
||||
@wraps(fct)
|
||||
async def caller(req, *args, **kwargs):
|
||||
if 'check' in req.query_params:
|
||||
""" Check query param should not read the "args"
|
||||
"""
|
||||
# Check query param should not read the "args"
|
||||
return await fct(req, *args, **kwargs)
|
||||
|
||||
if req.method == 'GET':
|
||||
|
@ -47,7 +56,7 @@ def args_check(fct):
|
|||
data_ = {}
|
||||
|
||||
def plural(array: list) -> str:
|
||||
return len(array) > 1 and 's' or ''
|
||||
return 's' if len(array) > 1 else ''
|
||||
def comma_list(array: list) -> str:
|
||||
return ', '.join(array)
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ lib/domain.py The domain-scoped utility functions
|
|||
|
||||
import importlib
|
||||
import logging
|
||||
import time
|
||||
from types import ModuleType
|
||||
from typing import Generator, Dict, List
|
||||
|
||||
|
@ -94,7 +93,7 @@ def gen_routes(route_params: Dict, path: List, m_router: ModuleType) -> Generato
|
|||
if params is None:
|
||||
continue
|
||||
if len(params) == 0:
|
||||
logger.error(f'No ACL for route [{verb}] "/".join(path)')
|
||||
logger.error('No ACL for route [{%s}] %s', verb, "/".join(path))
|
||||
|
||||
try:
|
||||
fct_name = get_fct_name(verb, path[-1])
|
||||
|
@ -125,7 +124,7 @@ def gen_router_routes(m_router: ModuleType, path: List[str]) -> Generator:
|
|||
"""
|
||||
|
||||
if not hasattr(m_router, 'ROUTES'):
|
||||
logger.error(f'Missing *ROUTES* constant in *{m_router.__name__}*')
|
||||
logger.error('Missing *ROUTES* constant in *%s*', m_router.__name__)
|
||||
raise Exception(f'No ROUTES constant for {m_router.__name__}')
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@ from starlette.types import Scope, Send, Receive
|
|||
|
||||
from .routes import api_routes
|
||||
from .domain import d_domains
|
||||
from ..conf import config_dict
|
||||
|
||||
logger = logging.getLogger('uvicorn.asgi')
|
||||
|
||||
|
@ -62,7 +61,7 @@ class DomainMiddleware(BaseHTTPMiddleware):
|
|||
scope_['config'] = dict(config_section)
|
||||
except configparser.NoSectionError:
|
||||
logger.debug(
|
||||
f'No specific configuration for domain **{current_domain}**')
|
||||
'No specific configuration for domain **%s**', current_domain)
|
||||
scope_['config'] = {}
|
||||
|
||||
|
||||
|
|
|
@ -1,44 +1,26 @@
|
|||
__LICENSE__ = """
|
||||
BSD 3-Clause License
|
||||
"""
|
||||
JWT Middleware module
|
||||
|
||||
Copyright (c) 2018, Amit Ripshtos
|
||||
All rights reserved.
|
||||
Classes:
|
||||
- JWTUser : goes in request.user
|
||||
- JWTAuthenticationBackend
|
||||
- JWTWebSocketAuthenticationBackend
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
Raises:
|
||||
Exception: If configuration has no SECRET or HALFAPI_SECRET is not set
|
||||
"""
|
||||
|
||||
from os import environ
|
||||
import typing
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from uuid import UUID
|
||||
from starlette.authentication import (
|
||||
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
|
||||
UnauthenticatedUser)
|
||||
from starlette.requests import HTTPConnection
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger('halfapi')
|
||||
|
||||
try:
|
||||
|
@ -50,18 +32,22 @@ except ImportError:
|
|||
|
||||
try:
|
||||
from ..conf import SECRET
|
||||
except ImportError:
|
||||
except ImportError as exc:
|
||||
logger.warning('Could not import SECRET variable from conf module,'\
|
||||
' using HALFAPI_SECRET environment variable')
|
||||
SECRET = environ.get('HALFAPI_SECRET', False)
|
||||
if not SECRET:
|
||||
raise Exception('Missing HALFAPI_SECRET variable')
|
||||
raise Exception('Missing HALFAPI_SECRET variable') from exc
|
||||
|
||||
|
||||
|
||||
class JWTUser(BaseUser):
|
||||
def __init__(self, id: UUID, token: str, payload: dict) -> None:
|
||||
self.__id = id
|
||||
""" JWTUser class
|
||||
|
||||
Is used to store authentication informations
|
||||
"""
|
||||
def __init__(self, user_id: UUID, token: str, payload: dict) -> None:
|
||||
self.__id = user_id
|
||||
self.token = token
|
||||
self.payload = payload
|
||||
|
||||
|
@ -81,10 +67,16 @@ class JWTUser(BaseUser):
|
|||
return True
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
def display_name(self) -> str:
|
||||
return ' '.join(
|
||||
(self.payload.get('name'), self.payload.get('firstname')))
|
||||
|
||||
@property
|
||||
def identity(self) -> str:
|
||||
return self.__id
|
||||
|
||||
|
||||
|
||||
class JWTAuthenticationBackend(AuthenticationBackend):
|
||||
def __init__(self, secret_key: str = SECRET,
|
||||
algorithm: str = 'HS256', prefix: str = 'JWT'):
|
||||
|
@ -95,11 +87,14 @@ class JWTAuthenticationBackend(AuthenticationBackend):
|
|||
self.algorithm = algorithm
|
||||
self.prefix = prefix
|
||||
|
||||
async def authenticate(self, request):
|
||||
if "Authorization" not in request.headers:
|
||||
async def authenticate(
|
||||
self, conn: HTTPConnection
|
||||
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
|
||||
|
||||
if "Authorization" not in conn.headers:
|
||||
return None
|
||||
|
||||
token = request.headers["Authorization"]
|
||||
token = conn.headers["Authorization"]
|
||||
try:
|
||||
payload = jwt.decode(token,
|
||||
key=self.secret_key,
|
||||
|
@ -113,32 +108,36 @@ class JWTAuthenticationBackend(AuthenticationBackend):
|
|||
'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
|
||||
|
||||
except jwt.InvalidTokenError as exc:
|
||||
raise AuthenticationError(str(exc))
|
||||
raise AuthenticationError(str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.error('Authentication error : %s', exc)
|
||||
raise exc
|
||||
|
||||
|
||||
return AuthCredentials(["authenticated"]), JWTUser(
|
||||
id=payload['user_id'], token=token, payload=payload)
|
||||
user_id=payload['user_id'], token=token, payload=payload)
|
||||
|
||||
|
||||
|
||||
class JWTWebSocketAuthenticationBackend(AuthenticationBackend):
|
||||
|
||||
def __init__(self, secret_key: str, algorithm: str = 'HS256', query_param_name: str = 'jwt',
|
||||
id: UUID = None, audience = None):
|
||||
user_id: UUID = None, audience = None):
|
||||
self.secret_key = secret_key
|
||||
self.algorithm = algorithm
|
||||
self.query_param_name = query_param_name
|
||||
self.id = id
|
||||
self.__id = user_id
|
||||
self.audience = audience
|
||||
|
||||
|
||||
async def authenticate(self, request):
|
||||
if self.query_param_name not in request.query_params:
|
||||
async def authenticate(
|
||||
self, conn: HTTPConnection
|
||||
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
|
||||
|
||||
if self.query_param_name not in conn.query_params:
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
|
||||
token = request.query_params[self.query_param_name]
|
||||
token = conn.query_params[self.query_param_name]
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
|
@ -155,12 +154,12 @@ class JWTWebSocketAuthenticationBackend(AuthenticationBackend):
|
|||
'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
|
||||
|
||||
except jwt.InvalidTokenError as exc:
|
||||
raise AuthenticationError(str(exc))
|
||||
raise AuthenticationError(str(exc)) from exc
|
||||
|
||||
return (
|
||||
AuthCredentials(["authenticated"]),
|
||||
JWTUser(
|
||||
id=payload['id'],
|
||||
user_id=payload['id'],
|
||||
token=token,
|
||||
payload=payload)
|
||||
)
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
#!/usr/bin/env python3
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
"""
|
||||
This is the *query* library that contains all the useful functions to treat our
|
||||
queries
|
||||
|
||||
Fonction:
|
||||
- parse_query
|
||||
"""
|
||||
|
||||
def parse_query(q: str = ""):
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
|
||||
def parse_query(q_string: str = ""):
|
||||
"""
|
||||
Returns the fitting Response object according to query parameters.
|
||||
|
||||
|
@ -15,7 +19,7 @@ def parse_query(q: str = ""):
|
|||
It returns a callable function that returns the desired Response object.
|
||||
|
||||
Parameters:
|
||||
q (str): The query string "q" parameter, in the format
|
||||
q_string (str): The query string "q" parameter, in the format
|
||||
key0:value0|...|keyN:valueN
|
||||
|
||||
Returns:
|
||||
|
@ -61,16 +65,16 @@ def parse_query(q: str = ""):
|
|||
"""
|
||||
|
||||
params = {}
|
||||
if len(q) > 0:
|
||||
if len(q_string) > 0:
|
||||
try:
|
||||
split_ = lambda x : x.split(':')
|
||||
params = dict(map(split_, q.split('|')))
|
||||
except ValueError:
|
||||
raise HTTPException(400)
|
||||
params = dict(map(split_, q_string.split('|')))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(400) from exc
|
||||
split_ = lambda x : x.split(':')
|
||||
params = dict(map(split_, q.split('|')))
|
||||
params = dict(map(split_, q_string.split('|')))
|
||||
|
||||
def select(obj, fields = []):
|
||||
def select(obj, fields):
|
||||
|
||||
if 'limit' in params and int(params['limit']) > 0:
|
||||
obj.limit(int(params['limit']))
|
||||
|
|
|
@ -1,7 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
# builtins
|
||||
""" Response module
|
||||
|
||||
Contains some base response classes
|
||||
|
||||
Classes :
|
||||
- HJSONResponse
|
||||
- InternalServerErrorResponse
|
||||
- NotFoundResponse
|
||||
- NotImplementedResponse
|
||||
- ORJSONResponse
|
||||
- PlainTextResponse
|
||||
- UnauthorizedResponse
|
||||
|
||||
"""
|
||||
import decimal
|
||||
import typing as typ
|
||||
import typing
|
||||
import orjson
|
||||
|
||||
# asgi framework
|
||||
|
@ -47,17 +61,21 @@ class UnauthorizedResponse(Response):
|
|||
|
||||
|
||||
class ORJSONResponse(JSONResponse):
|
||||
""" The response that encodes data into JSON
|
||||
"""
|
||||
def __init__(self, content, default=None, **kwargs):
|
||||
self.default = default if default is not None else ORJSONResponse.default_cast
|
||||
super().__init__(content, **kwargs)
|
||||
|
||||
def render(self, content: typ.Any) -> bytes:
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
return orjson.dumps(content,
|
||||
option=orjson.OPT_NON_STR_KEYS,
|
||||
default=self.default)
|
||||
|
||||
@staticmethod
|
||||
def default_cast(x):
|
||||
def default_cast(typ):
|
||||
""" Cast the data in JSON-serializable type
|
||||
"""
|
||||
str_types = {
|
||||
decimal.Decimal
|
||||
}
|
||||
|
@ -65,14 +83,16 @@ class ORJSONResponse(JSONResponse):
|
|||
set
|
||||
}
|
||||
|
||||
if type(x) in str_types:
|
||||
return str(x)
|
||||
if type(x) in list_types:
|
||||
return list(x)
|
||||
if type(typ) in str_types:
|
||||
return str(typ)
|
||||
if type(typ) in list_types:
|
||||
return list(typ)
|
||||
|
||||
raise TypeError(f'Type {type(x)} is not handled by ORJSONResponse')
|
||||
raise TypeError(f'Type {type(typ)} is not handled by ORJSONResponse')
|
||||
|
||||
|
||||
class HJSONResponse(ORJSONResponse):
|
||||
def render(self, content: typ.Generator):
|
||||
""" The response that encodes generator data into JSON
|
||||
"""
|
||||
def render(self, content: typing.Generator):
|
||||
return super().render(list(content))
|
||||
|
|
|
@ -1,5 +1,18 @@
|
|||
#!/usr/bin/env python3
|
||||
import time
|
||||
"""
|
||||
Routes module
|
||||
|
||||
Fonctions :
|
||||
- route_acl_decorator
|
||||
- gen_starlette_routes
|
||||
- api_routes
|
||||
- api_acls
|
||||
- debug_routes
|
||||
|
||||
Exception :
|
||||
- DomainNotFoundError
|
||||
|
||||
"""
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
@ -17,7 +30,8 @@ from halfapi.lib.domain import gen_domain_routes, VERBS
|
|||
logger = logging.getLogger('uvicorn.asgi')
|
||||
|
||||
class DomainNotFoundError(Exception):
|
||||
pass
|
||||
""" Exception when a domain is not importable
|
||||
"""
|
||||
|
||||
def route_acl_decorator(fct: Callable, params: List[Dict]):
|
||||
"""
|
||||
|
@ -46,10 +60,11 @@ def route_acl_decorator(fct: Callable, params: List[Dict]):
|
|||
|
||||
if not passed:
|
||||
logger.debug(
|
||||
f'ACL FAIL for current route ({fct} - {param.get("acl")})')
|
||||
'ACL FAIL for current route (%s - %s)', fct, param.get('acl'))
|
||||
continue
|
||||
|
||||
logger.debug(f'ACL OK for current route ({fct} - {param.get("acl")})')
|
||||
logger.debug(
|
||||
'ACL OK for current route (%s - %s)', fct, param.get('acl'))
|
||||
|
||||
req.scope['acl_pass'] = param['acl'].__name__
|
||||
if 'args' in param:
|
||||
|
@ -138,6 +153,8 @@ def api_routes(m_dom: ModuleType) -> Generator:
|
|||
|
||||
|
||||
def api_acls(request):
|
||||
""" Returns the list of possible ACLs
|
||||
"""
|
||||
res = {}
|
||||
doc = 'doc' in request.query_params
|
||||
for domain, d_domain_acl in request.scope['acl'].items():
|
||||
|
@ -152,6 +169,8 @@ def api_acls(request):
|
|||
|
||||
|
||||
def debug_routes():
|
||||
""" Halfapi debug routes definition
|
||||
"""
|
||||
async def debug_log(request: Request, *args, **kwargs):
|
||||
logger.debug('debuglog# %s', {datetime.now().isoformat()})
|
||||
logger.info('debuglog# %s', {datetime.now().isoformat()})
|
||||
|
|
|
@ -1,53 +1,38 @@
|
|||
from types import ModuleType
|
||||
""" Schemas module
|
||||
|
||||
Functions :
|
||||
- get_api_routes
|
||||
- schema_json
|
||||
- schema_dict_dom
|
||||
- get_acls
|
||||
|
||||
Constant :
|
||||
SCHEMAS (starlette.schemas.SchemaGenerator)
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from ..conf import DOMAINSDICT
|
||||
from .routes import gen_starlette_routes
|
||||
from .responses import *
|
||||
from .jwt_middleware import UnauthenticatedUser, JWTUser
|
||||
|
||||
from starlette.schemas import SchemaGenerator
|
||||
from starlette.routing import Router
|
||||
|
||||
from .routes import gen_starlette_routes, api_acls
|
||||
from .responses import ORJSONResponse
|
||||
|
||||
SCHEMAS = SchemaGenerator(
|
||||
{"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": "1.0"}}
|
||||
)
|
||||
|
||||
"""
|
||||
example: > {
|
||||
"dummy_domain": {
|
||||
"/abc/alphabet/organigramme": {
|
||||
"fqtn": null,
|
||||
"params": [
|
||||
{}
|
||||
],
|
||||
"verb": "GET"
|
||||
},
|
||||
"/act/personne/": {
|
||||
"fqtn": "acteur.personne",
|
||||
"params": [
|
||||
{}
|
||||
|
||||
"verb": "GET"
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
async def get_api_routes(request, *args, **kwargs):
|
||||
"""
|
||||
responses:
|
||||
200:
|
||||
description: Returns the current API routes description (HalfAPI 0.2.1)
|
||||
as a JSON object
|
||||
description: Returns the current API routes description (HalfAPI 0.2.1)
|
||||
as a JSON object
|
||||
"""
|
||||
return ORJSONResponse(request.scope['api'])
|
||||
|
||||
|
||||
async def schema_json(request, *args, **kwargs):
|
||||
"""
|
||||
responses:
|
||||
200:
|
||||
description: Returns the current API routes description (OpenAPI v3)
|
||||
as a JSON object
|
||||
description: Returns the current API routes description (OpenAPI v3)
|
||||
as a JSON object
|
||||
"""
|
||||
return ORJSONResponse(
|
||||
SCHEMAS.get_schema(routes=request.app.routes))
|
||||
|
@ -59,8 +44,7 @@ def schema_dict_dom(d_domains) -> Dict:
|
|||
|
||||
Parameters:
|
||||
|
||||
m_domain (ModuleType): The module to scan for routes
|
||||
|
||||
d_domains (Dict[str, moduleType]): The module to scan for routes
|
||||
|
||||
Returns:
|
||||
|
||||
|
@ -73,12 +57,7 @@ def schema_dict_dom(d_domains) -> Dict:
|
|||
|
||||
async def get_acls(request, *args, **kwargs):
|
||||
"""
|
||||
responses:
|
||||
200:
|
||||
description: A dictionnary of the domains and their acls, with the
|
||||
result of the acls functions
|
||||
description: A dictionnary of the domains and their acls, with the
|
||||
result of the acls functions
|
||||
"""
|
||||
|
||||
from .routes import api_acls
|
||||
return ORJSONResponse(api_acls(request))
|
||||
|
||||
|
|
|
@ -1,3 +1,10 @@
|
|||
"""
|
||||
Timing module
|
||||
|
||||
Helpers to gathers stats on requests timing
|
||||
|
||||
class HTimingClient
|
||||
"""
|
||||
import logging
|
||||
|
||||
from timing_asgi import TimingClient
|
||||
|
@ -5,12 +12,11 @@ from timing_asgi import TimingClient
|
|||
logger = logging.getLogger('uvicorn.asgi')
|
||||
|
||||
class HTimingClient(TimingClient):
|
||||
""" Used to redefine TimingClient.timing
|
||||
"""
|
||||
def timing(self, metric_name, timing, tags):
|
||||
tags_d = {
|
||||
key: val
|
||||
for key, val in map(
|
||||
lambda elt: elt.split(':'), tags)
|
||||
}
|
||||
tags_d = dict(map(lambda elt: elt.split(':'), tags))
|
||||
|
||||
logger.debug('[TIME:%s][%s] %s %s - %sms',
|
||||
tags_d['time'], metric_name,
|
||||
tags_d['http_method'], tags_d['http_status'],
|
||||
|
|
|
@ -145,7 +145,7 @@ def test_JWTUser():
|
|||
token = '{}'
|
||||
payload = {}
|
||||
user = JWTUser(uid, token, payload)
|
||||
assert user.id == uid
|
||||
assert user.identity == uid
|
||||
assert user.token == token
|
||||
assert user.payload == payload
|
||||
assert user.is_authenticated == True
|
||||
|
|
Loading…
Reference in New Issue