diff --git a/halfapi/half_route.py b/halfapi/half_route.py new file mode 100644 index 0000000..df93670 --- /dev/null +++ b/halfapi/half_route.py @@ -0,0 +1,95 @@ +""" HalfRoute + +Child class of starlette.routing.Route +""" +from functools import partial, wraps + +from typing import Callable, Coroutine, List, Dict +from types import FunctionType + +from starlette.requests import Request +from starlette.responses import PlainTextResponse +from starlette.routing import Route +from starlette.exceptions import HTTPException + +from .logging import logger + +class HalfRoute(Route): + """ HalfRoute + """ + def __init__(self, path, fct, params, method): + logger.info('HalfRoute creation: %s', params) + super().__init__( + path, + HalfRoute.acl_decorator( + fct, + params + ), + methods=[method]) + + @staticmethod + def acl_decorator(fct: Callable = None, params: List[Dict] = None) -> Coroutine: + """ + Decorator for async functions that calls pre-conditions functions + and appends kwargs to the target function + + + Parameters: + fct (Callable): + The function to decorate + + params List[Dict]: + A list of dicts that have an "acl" key that points to a function + + Returns: + async function + """ + + if not params: + params = [] + + if not fct: + return partial(HalfRoute.acl_decorator, params=params) + + + @wraps(fct) + async def caller(req: Request, *args, **kwargs): + for param in params: + if param.get('acl'): + passed = param['acl'](req, *args, **kwargs) + if isinstance(passed, FunctionType): + passed = param['acl']()(req, *args, **kwargs) + + if not passed: + logger.debug( + 'ACL FAIL for current route (%s - %s)', fct, param.get('acl')) + continue + + logger.debug( + 'ACL OK for current route (%s - %s)', fct, param.get('acl')) + + req.scope['acl_pass'] = param['acl'].__name__ + + if 'args' in param: + req.scope['args'] = param['args'] + logger.debug( + 'Args for current route (%s)', param.get('args')) + + + + if 'check' in req.query_params: + return PlainTextResponse(param['acl'].__name__) + + logger.debug('acl_decorator %s', param) + return await fct( + req, *args, + **{ + **kwargs, + }) + + if 'check' in req.query_params: + return PlainTextResponse('') + + raise HTTPException(401) + + return caller diff --git a/halfapi/lib/routes.py b/halfapi/lib/routes.py index 6b650bd..c7b9045 100644 --- a/halfapi/lib/routes.py +++ b/halfapi/lib/routes.py @@ -3,7 +3,6 @@ Routes module Fonctions : - - route_acl_decorator - gen_domain_routes - gen_starlette_routes - api_routes @@ -13,18 +12,16 @@ Exception : - DomainNotFoundError """ -from datetime import datetime -from functools import partial, wraps -from typing import Callable, Coroutine, List, Dict, Generator, Tuple, Any +import inspect + +from typing import Coroutine, Dict, Generator, Tuple, Any from types import ModuleType, FunctionType -from starlette.exceptions import HTTPException -from starlette.routing import Route -from starlette.requests import Request -from starlette.responses import Response, PlainTextResponse +import yaml -from halfapi.lib.domain import gen_router_routes, domain_acls -from halfapi.lib.responses import ORJSONResponse +from .domain import gen_router_routes, domain_acls, route_decorator +from .responses import ORJSONResponse +from ..half_route import HalfRoute from ..conf import DOMAINSDICT from ..logging import logger @@ -50,66 +47,6 @@ def JSONRoute(data: Any) -> Coroutine: return wrapped -def route_acl_decorator(fct: Callable = None, params: List[Dict] = None) -> Coroutine: - """ - Decorator for async functions that calls pre-conditions functions - and appends kwargs to the target function - - - Parameters: - fct (Callable): - The function to decorate - - params List[Dict]: - A list of dicts that have an "acl" key that points to a function - - Returns: - async function - """ - - if not params: - params = [] - - if not fct: - return partial(route_acl_decorator, params=params) - - - @wraps(fct) - async def caller(req: Request, *args, **kwargs): - for param in params: - if param.get('acl'): - passed = param['acl'](req, *args, **kwargs) - if isinstance(passed, FunctionType): - passed = param['acl']()(req, *args, **kwargs) - - if not passed: - logger.debug( - 'ACL FAIL for current route (%s - %s)', fct, param.get('acl')) - continue - - logger.debug( - 'ACL OK for current route (%s - %s)', fct, param.get('acl')) - - req.scope['acl_pass'] = param['acl'].__name__ - if 'args' in param: - req.scope['args'] = param['args'] - - if 'check' in req.query_params: - return PlainTextResponse(param['acl'].__name__) - - return await fct( - req, *args, - **{ - **kwargs, - **param - }) - - if 'check' in req.query_params: - return PlainTextResponse('') - - raise HTTPException(401) - - return caller def gen_domain_routes(m_domain: ModuleType): @@ -120,17 +57,26 @@ def gen_domain_routes(m_domain: ModuleType): m_domains: ModuleType Returns: - Generator(Route) + Generator(HalfRoute) """ - for path, verb, m_router, fct, params in gen_router_routes(m_domain, []): - yield ( - Route(f'/{path}', - route_acl_decorator( - fct, - params - ), - methods=[verb]) - ) + for path, method, m_router, fct, params in gen_router_routes(m_domain, []): + yield HalfRoute(f'/{path}', fct, params, method) + + +def gen_schema_routes(schema: Dict): + """ + Yields the Route objects according to a given schema + """ + for path, methods in schema.items(): + for verb, definition in methods.items(): + fct = definition.pop('fct') + acls = definition.pop('acls') + # TODO: Check what to do with gen_routes, it is almost the same function + if not inspect.iscoroutinefunction(fct): + yield HalfRoute(path, route_decorator(fct), acls, verb) + else: + yield HalfRoute(path, fct, acls, verb) + def gen_starlette_routes(d_domains: Dict[str, ModuleType]) -> Generator: """ diff --git a/tests/test_acl.py b/tests/test_acl.py index c16878e..52a6272 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -1,7 +1,7 @@ import pytest from starlette.responses import PlainTextResponse from starlette.testclient import TestClient -from halfapi.lib.routes import route_acl_decorator +from halfapi.half_route import HalfRoute from halfapi.lib import acl def test_acl_Check(dummy_app, token_debug_false_builder): @@ -9,7 +9,7 @@ def test_acl_Check(dummy_app, token_debug_false_builder): A request with ?check should always return a 200 status code """ - @route_acl_decorator(params=[{'acl':acl.public}]) + @HalfRoute.acl_decorator(params=[{'acl':acl.public}]) async def test_route_public(request, **kwargs): raise Exception('Should not raise') return PlainTextResponse('ok') @@ -20,7 +20,7 @@ def test_acl_Check(dummy_app, token_debug_false_builder): resp = test_client.get('/test_public?check') assert resp.status_code == 200 - @route_acl_decorator(params=[{'acl':acl.private}]) + @HalfRoute.acl_decorator(params=[{'acl':acl.private}]) async def test_route_private(request, **kwargs): raise Exception('Should not raise') return PlainTextResponse('ok')