[lib] route_acl_decorator becomes HalfRoute.acl_decorator, creation of HalfRoute that wraps starlette.route
This commit is contained in:
parent
ad6877a7e9
commit
47d81c048f
|
@ -0,0 +1,95 @@
|
|||
""" HalfRoute
|
||||
|
||||
Child class of starlette.routing.Route
|
||||
"""
|
||||
from functools import partial, wraps
|
||||
|
||||
from typing import Callable, Coroutine, List, Dict
|
||||
from types import FunctionType
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.routing import Route
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from .logging import logger
|
||||
|
||||
class HalfRoute(Route):
|
||||
""" HalfRoute
|
||||
"""
|
||||
def __init__(self, path, fct, params, method):
|
||||
logger.info('HalfRoute creation: %s', params)
|
||||
super().__init__(
|
||||
path,
|
||||
HalfRoute.acl_decorator(
|
||||
fct,
|
||||
params
|
||||
),
|
||||
methods=[method])
|
||||
|
||||
@staticmethod
|
||||
def acl_decorator(fct: Callable = None, params: List[Dict] = None) -> Coroutine:
|
||||
"""
|
||||
Decorator for async functions that calls pre-conditions functions
|
||||
and appends kwargs to the target function
|
||||
|
||||
|
||||
Parameters:
|
||||
fct (Callable):
|
||||
The function to decorate
|
||||
|
||||
params List[Dict]:
|
||||
A list of dicts that have an "acl" key that points to a function
|
||||
|
||||
Returns:
|
||||
async function
|
||||
"""
|
||||
|
||||
if not params:
|
||||
params = []
|
||||
|
||||
if not fct:
|
||||
return partial(HalfRoute.acl_decorator, params=params)
|
||||
|
||||
|
||||
@wraps(fct)
|
||||
async def caller(req: Request, *args, **kwargs):
|
||||
for param in params:
|
||||
if param.get('acl'):
|
||||
passed = param['acl'](req, *args, **kwargs)
|
||||
if isinstance(passed, FunctionType):
|
||||
passed = param['acl']()(req, *args, **kwargs)
|
||||
|
||||
if not passed:
|
||||
logger.debug(
|
||||
'ACL FAIL for current route (%s - %s)', fct, param.get('acl'))
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
'ACL OK for current route (%s - %s)', fct, param.get('acl'))
|
||||
|
||||
req.scope['acl_pass'] = param['acl'].__name__
|
||||
|
||||
if 'args' in param:
|
||||
req.scope['args'] = param['args']
|
||||
logger.debug(
|
||||
'Args for current route (%s)', param.get('args'))
|
||||
|
||||
|
||||
|
||||
if 'check' in req.query_params:
|
||||
return PlainTextResponse(param['acl'].__name__)
|
||||
|
||||
logger.debug('acl_decorator %s', param)
|
||||
return await fct(
|
||||
req, *args,
|
||||
**{
|
||||
**kwargs,
|
||||
})
|
||||
|
||||
if 'check' in req.query_params:
|
||||
return PlainTextResponse('')
|
||||
|
||||
raise HTTPException(401)
|
||||
|
||||
return caller
|
|
@ -3,7 +3,6 @@
|
|||
Routes module
|
||||
|
||||
Fonctions :
|
||||
- route_acl_decorator
|
||||
- gen_domain_routes
|
||||
- gen_starlette_routes
|
||||
- api_routes
|
||||
|
@ -13,18 +12,16 @@ Exception :
|
|||
- DomainNotFoundError
|
||||
|
||||
"""
|
||||
from datetime import datetime
|
||||
from functools import partial, wraps
|
||||
from typing import Callable, Coroutine, List, Dict, Generator, Tuple, Any
|
||||
import inspect
|
||||
|
||||
from typing import Coroutine, Dict, Generator, Tuple, Any
|
||||
from types import ModuleType, FunctionType
|
||||
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.routing import Route
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, PlainTextResponse
|
||||
import yaml
|
||||
|
||||
from halfapi.lib.domain import gen_router_routes, domain_acls
|
||||
from halfapi.lib.responses import ORJSONResponse
|
||||
from .domain import gen_router_routes, domain_acls, route_decorator
|
||||
from .responses import ORJSONResponse
|
||||
from ..half_route import HalfRoute
|
||||
from ..conf import DOMAINSDICT
|
||||
|
||||
from ..logging import logger
|
||||
|
@ -50,66 +47,6 @@ def JSONRoute(data: Any) -> Coroutine:
|
|||
return wrapped
|
||||
|
||||
|
||||
def route_acl_decorator(fct: Callable = None, params: List[Dict] = None) -> Coroutine:
|
||||
"""
|
||||
Decorator for async functions that calls pre-conditions functions
|
||||
and appends kwargs to the target function
|
||||
|
||||
|
||||
Parameters:
|
||||
fct (Callable):
|
||||
The function to decorate
|
||||
|
||||
params List[Dict]:
|
||||
A list of dicts that have an "acl" key that points to a function
|
||||
|
||||
Returns:
|
||||
async function
|
||||
"""
|
||||
|
||||
if not params:
|
||||
params = []
|
||||
|
||||
if not fct:
|
||||
return partial(route_acl_decorator, params=params)
|
||||
|
||||
|
||||
@wraps(fct)
|
||||
async def caller(req: Request, *args, **kwargs):
|
||||
for param in params:
|
||||
if param.get('acl'):
|
||||
passed = param['acl'](req, *args, **kwargs)
|
||||
if isinstance(passed, FunctionType):
|
||||
passed = param['acl']()(req, *args, **kwargs)
|
||||
|
||||
if not passed:
|
||||
logger.debug(
|
||||
'ACL FAIL for current route (%s - %s)', fct, param.get('acl'))
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
'ACL OK for current route (%s - %s)', fct, param.get('acl'))
|
||||
|
||||
req.scope['acl_pass'] = param['acl'].__name__
|
||||
if 'args' in param:
|
||||
req.scope['args'] = param['args']
|
||||
|
||||
if 'check' in req.query_params:
|
||||
return PlainTextResponse(param['acl'].__name__)
|
||||
|
||||
return await fct(
|
||||
req, *args,
|
||||
**{
|
||||
**kwargs,
|
||||
**param
|
||||
})
|
||||
|
||||
if 'check' in req.query_params:
|
||||
return PlainTextResponse('')
|
||||
|
||||
raise HTTPException(401)
|
||||
|
||||
return caller
|
||||
|
||||
|
||||
def gen_domain_routes(m_domain: ModuleType):
|
||||
|
@ -120,17 +57,26 @@ def gen_domain_routes(m_domain: ModuleType):
|
|||
m_domains: ModuleType
|
||||
|
||||
Returns:
|
||||
Generator(Route)
|
||||
Generator(HalfRoute)
|
||||
"""
|
||||
for path, verb, m_router, fct, params in gen_router_routes(m_domain, []):
|
||||
yield (
|
||||
Route(f'/{path}',
|
||||
route_acl_decorator(
|
||||
fct,
|
||||
params
|
||||
),
|
||||
methods=[verb])
|
||||
)
|
||||
for path, method, m_router, fct, params in gen_router_routes(m_domain, []):
|
||||
yield HalfRoute(f'/{path}', fct, params, method)
|
||||
|
||||
|
||||
def gen_schema_routes(schema: Dict):
|
||||
"""
|
||||
Yields the Route objects according to a given schema
|
||||
"""
|
||||
for path, methods in schema.items():
|
||||
for verb, definition in methods.items():
|
||||
fct = definition.pop('fct')
|
||||
acls = definition.pop('acls')
|
||||
# TODO: Check what to do with gen_routes, it is almost the same function
|
||||
if not inspect.iscoroutinefunction(fct):
|
||||
yield HalfRoute(path, route_decorator(fct), acls, verb)
|
||||
else:
|
||||
yield HalfRoute(path, fct, acls, verb)
|
||||
|
||||
|
||||
def gen_starlette_routes(d_domains: Dict[str, ModuleType]) -> Generator:
|
||||
"""
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.testclient import TestClient
|
||||
from halfapi.lib.routes import route_acl_decorator
|
||||
from halfapi.half_route import HalfRoute
|
||||
from halfapi.lib import acl
|
||||
|
||||
def test_acl_Check(dummy_app, token_debug_false_builder):
|
||||
|
@ -9,7 +9,7 @@ def test_acl_Check(dummy_app, token_debug_false_builder):
|
|||
A request with ?check should always return a 200 status code
|
||||
"""
|
||||
|
||||
@route_acl_decorator(params=[{'acl':acl.public}])
|
||||
@HalfRoute.acl_decorator(params=[{'acl':acl.public}])
|
||||
async def test_route_public(request, **kwargs):
|
||||
raise Exception('Should not raise')
|
||||
return PlainTextResponse('ok')
|
||||
|
@ -20,7 +20,7 @@ def test_acl_Check(dummy_app, token_debug_false_builder):
|
|||
resp = test_client.get('/test_public?check')
|
||||
assert resp.status_code == 200
|
||||
|
||||
@route_acl_decorator(params=[{'acl':acl.private}])
|
||||
@HalfRoute.acl_decorator(params=[{'acl':acl.private}])
|
||||
async def test_route_private(request, **kwargs):
|
||||
raise Exception('Should not raise')
|
||||
return PlainTextResponse('ok')
|
||||
|
|
Loading…
Reference in New Issue