[lib] route_acl_decorator becomes HalfRoute.acl_decorator, creation of HalfRoute that wraps starlette.route

This commit is contained in:
Maxime Alves LIRMM@home 2021-11-29 05:42:26 +01:00
parent ad6877a7e9
commit 47d81c048f
3 changed files with 124 additions and 83 deletions

95
halfapi/half_route.py Normal file
View File

@ -0,0 +1,95 @@
""" HalfRoute
Child class of starlette.routing.Route
"""
from functools import partial, wraps
from typing import Callable, Coroutine, List, Dict
from types import FunctionType
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.exceptions import HTTPException
from .logging import logger
class HalfRoute(Route):
""" HalfRoute
"""
def __init__(self, path, fct, params, method):
logger.info('HalfRoute creation: %s', params)
super().__init__(
path,
HalfRoute.acl_decorator(
fct,
params
),
methods=[method])
@staticmethod
def acl_decorator(fct: Callable = None, params: List[Dict] = None) -> Coroutine:
"""
Decorator for async functions that calls pre-conditions functions
and appends kwargs to the target function
Parameters:
fct (Callable):
The function to decorate
params List[Dict]:
A list of dicts that have an "acl" key that points to a function
Returns:
async function
"""
if not params:
params = []
if not fct:
return partial(HalfRoute.acl_decorator, params=params)
@wraps(fct)
async def caller(req: Request, *args, **kwargs):
for param in params:
if param.get('acl'):
passed = param['acl'](req, *args, **kwargs)
if isinstance(passed, FunctionType):
passed = param['acl']()(req, *args, **kwargs)
if not passed:
logger.debug(
'ACL FAIL for current route (%s - %s)', fct, param.get('acl'))
continue
logger.debug(
'ACL OK for current route (%s - %s)', fct, param.get('acl'))
req.scope['acl_pass'] = param['acl'].__name__
if 'args' in param:
req.scope['args'] = param['args']
logger.debug(
'Args for current route (%s)', param.get('args'))
if 'check' in req.query_params:
return PlainTextResponse(param['acl'].__name__)
logger.debug('acl_decorator %s', param)
return await fct(
req, *args,
**{
**kwargs,
})
if 'check' in req.query_params:
return PlainTextResponse('')
raise HTTPException(401)
return caller

View File

@ -3,7 +3,6 @@
Routes module
Fonctions :
- route_acl_decorator
- gen_domain_routes
- gen_starlette_routes
- api_routes
@ -13,18 +12,16 @@ Exception :
- DomainNotFoundError
"""
from datetime import datetime
from functools import partial, wraps
from typing import Callable, Coroutine, List, Dict, Generator, Tuple, Any
import inspect
from typing import Coroutine, Dict, Generator, Tuple, Any
from types import ModuleType, FunctionType
from starlette.exceptions import HTTPException
from starlette.routing import Route
from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse
import yaml
from halfapi.lib.domain import gen_router_routes, domain_acls
from halfapi.lib.responses import ORJSONResponse
from .domain import gen_router_routes, domain_acls, route_decorator
from .responses import ORJSONResponse
from ..half_route import HalfRoute
from ..conf import DOMAINSDICT
from ..logging import logger
@ -50,66 +47,6 @@ def JSONRoute(data: Any) -> Coroutine:
return wrapped
def route_acl_decorator(fct: Callable = None, params: List[Dict] = None) -> Coroutine:
"""
Decorator for async functions that calls pre-conditions functions
and appends kwargs to the target function
Parameters:
fct (Callable):
The function to decorate
params List[Dict]:
A list of dicts that have an "acl" key that points to a function
Returns:
async function
"""
if not params:
params = []
if not fct:
return partial(route_acl_decorator, params=params)
@wraps(fct)
async def caller(req: Request, *args, **kwargs):
for param in params:
if param.get('acl'):
passed = param['acl'](req, *args, **kwargs)
if isinstance(passed, FunctionType):
passed = param['acl']()(req, *args, **kwargs)
if not passed:
logger.debug(
'ACL FAIL for current route (%s - %s)', fct, param.get('acl'))
continue
logger.debug(
'ACL OK for current route (%s - %s)', fct, param.get('acl'))
req.scope['acl_pass'] = param['acl'].__name__
if 'args' in param:
req.scope['args'] = param['args']
if 'check' in req.query_params:
return PlainTextResponse(param['acl'].__name__)
return await fct(
req, *args,
**{
**kwargs,
**param
})
if 'check' in req.query_params:
return PlainTextResponse('')
raise HTTPException(401)
return caller
def gen_domain_routes(m_domain: ModuleType):
@ -120,17 +57,26 @@ def gen_domain_routes(m_domain: ModuleType):
m_domains: ModuleType
Returns:
Generator(Route)
Generator(HalfRoute)
"""
for path, verb, m_router, fct, params in gen_router_routes(m_domain, []):
yield (
Route(f'/{path}',
route_acl_decorator(
fct,
params
),
methods=[verb])
)
for path, method, m_router, fct, params in gen_router_routes(m_domain, []):
yield HalfRoute(f'/{path}', fct, params, method)
def gen_schema_routes(schema: Dict):
"""
Yields the Route objects according to a given schema
"""
for path, methods in schema.items():
for verb, definition in methods.items():
fct = definition.pop('fct')
acls = definition.pop('acls')
# TODO: Check what to do with gen_routes, it is almost the same function
if not inspect.iscoroutinefunction(fct):
yield HalfRoute(path, route_decorator(fct), acls, verb)
else:
yield HalfRoute(path, fct, acls, verb)
def gen_starlette_routes(d_domains: Dict[str, ModuleType]) -> Generator:
"""

View File

@ -1,7 +1,7 @@
import pytest
from starlette.responses import PlainTextResponse
from starlette.testclient import TestClient
from halfapi.lib.routes import route_acl_decorator
from halfapi.half_route import HalfRoute
from halfapi.lib import acl
def test_acl_Check(dummy_app, token_debug_false_builder):
@ -9,7 +9,7 @@ def test_acl_Check(dummy_app, token_debug_false_builder):
A request with ?check should always return a 200 status code
"""
@route_acl_decorator(params=[{'acl':acl.public}])
@HalfRoute.acl_decorator(params=[{'acl':acl.public}])
async def test_route_public(request, **kwargs):
raise Exception('Should not raise')
return PlainTextResponse('ok')
@ -20,7 +20,7 @@ def test_acl_Check(dummy_app, token_debug_false_builder):
resp = test_client.get('/test_public?check')
assert resp.status_code == 200
@route_acl_decorator(params=[{'acl':acl.private}])
@HalfRoute.acl_decorator(params=[{'acl':acl.private}])
async def test_route_private(request, **kwargs):
raise Exception('Should not raise')
return PlainTextResponse('ok')