diff --git a/halfapi/half_domain.py b/halfapi/half_domain.py index 688470a..ee97026 100644 --- a/halfapi/half_domain.py +++ b/halfapi/half_domain.py @@ -11,7 +11,7 @@ from types import ModuleType, FunctionType from schema import SchemaError from starlette.applications import Starlette -from starlette.routing import Router +from starlette.routing import Router, Route import yaml @@ -19,7 +19,8 @@ import yaml from . import __version__ from .lib.constants import API_SCHEMA_DICT, ROUTER_SCHEMA, VERBS from .half_route import HalfRoute -from .lib import acl +from .lib import acl as lib_acl +from .lib.responses import PlainTextResponse from .lib.routes import JSONRoute from .lib.domain import MissingAclError, PathError, UnknownPathParameterType, \ UndefinedRoute, UndefinedFunction, get_fct_name, route_decorator @@ -88,6 +89,13 @@ class HalfDomain(Starlette): ] ) + @staticmethod + def name(module): + """ Returns the name declared in the 'domain' dict at the root of the package + """ + return module.domain['name'] + + @staticmethod def m_acl(module, acl=None): """ Returns the imported acl module for the domain module @@ -104,9 +112,14 @@ class HalfDomain(Starlette): """ m_acl = HalfDomain.m_acl(module, acl) try: - return getattr(m_acl, 'ACLS') - except AttributeError: - raise Exception(f'Missing acl.ACLS constant in module {m_acl.__package__}') + return [ + lib_acl.ACL(*elt) + for elt in getattr(m_acl, 'ACLS') + ] + except AttributeError as exc: + logger.error(exc) + raise Exception( + f'Missing acl.ACLS constant in module {m_acl.__package__}') from exc @staticmethod def acls_route(domain, module_path=None, acl=None): @@ -118,7 +131,6 @@ class HalfDomain(Starlette): [acl_name]: { callable: fct_reference, docs: fct_docstring, - result: fct_result } } """ @@ -131,18 +143,73 @@ class HalfDomain(Starlette): m_acl = HalfDomain.m_acl(module, acl) - for acl_name, doc, order in HalfDomain.acls( - module, - acl=acl): - fct = getattr(m_acl, acl_name) - d_res[acl_name] = { + for elt in HalfDomain.acls(module, acl=acl): + + fct = getattr(m_acl, elt.name) + + d_res[elt.name] = { 'callable': fct, - 'docs': doc, - 'result': None + 'docs': elt.documentation } + return d_res - # def schema(self): + @staticmethod + def acls_router(domain, module_path=None, acl=None): + """ Router of the acls routes : + + / : Same result as HalfDomain.acls_route + /acl_name : Dummy route protected by the "acl_name" acl + """ + + async def dummy_endpoint(request, *args, **kwargs): + return PlainTextResponse('') + + routes = [] + d_res = {} + + module = importlib.import_module(domain) \ + if module_path is None \ + else importlib.import_module(module_path) + + + m_acl = HalfDomain.m_acl(module, acl) + + for elt in HalfDomain.acls(module, acl=acl): + + fct = getattr(m_acl, elt.name) + + d_res[elt.name] = { + 'callable': fct, + 'docs': elt.documentation, + 'public': elt.public + } + + if elt.public: + routes.append( + Route( + f'/{elt.name}', + HalfRoute.acl_decorator( + dummy_endpoint, + params=[{'acl': fct}] + ), + methods=['GET'] + ) + ) + + d_res_under_domain_name = {} + d_res_under_domain_name[HalfDomain.name(module)] = d_res + + routes.append( + Route( + '/', + JSONRoute(d_res_under_domain_name), + methods=['GET'] + ) + ) + + return Router(routes) + @staticmethod def gen_routes(m_router: ModuleType, @@ -188,7 +255,7 @@ class HalfDomain(Starlette): return route_decorator(fct), params # TODO: Remove when using only sync functions - return acl.args_check(fct), params + return lib_acl.args_check(fct), params @staticmethod @@ -318,7 +385,7 @@ class HalfDomain(Starlette): """ yield HalfRoute('/', JSONRoute([ self.schema() ]), - [{'acl': acl.public}], + [{'acl': lib_acl.public}], 'GET' ) diff --git a/halfapi/halfapi.py b/halfapi/halfapi.py index 1ff3547..736d386 100644 --- a/halfapi/halfapi.py +++ b/halfapi/halfapi.py @@ -19,7 +19,7 @@ from starlette.applications import Starlette from starlette.authentication import UnauthenticatedUser from starlette.exceptions import HTTPException from starlette.middleware import Middleware -from starlette.routing import Route, Mount +from starlette.routing import Router, Route, Mount from starlette.requests import Request from starlette.responses import Response, PlainTextResponse from starlette.middleware.authentication import AuthenticationMiddleware @@ -178,7 +178,7 @@ class HalfAPI(Starlette): yield Route('/whoami', get_user) yield Route('/schema', schema_json) - yield Route('/acls', self.acls_route()) + yield Mount('/acls', self.acls_router()) yield Route('/version', self.version_async) """ Halfapi debug routes definition """ @@ -220,35 +220,26 @@ class HalfAPI(Starlette): time.sleep(1) sys.exit(0) - def acls_route(self): - module = None - res = { - domain: HalfDomain.acls_route( - domain, - module_path=domain_conf.get('module'), - acl=domain_conf.get('acl')) - for domain, domain_conf in self.config.get('domain', {}).items() - if isinstance(domain_conf, dict) and domain_conf.get('enabled', False) - } + def acls_router(self): + mounts = {} - async def wrapped(req, *args, **kwargs): - for domain, domain_acls in res.items(): - for acl_name, d_acl in domain_acls.items(): - fct = d_acl['callable'] - if not callable(fct): - raise Exception( - 'No callable function in acl definition %s', - acl_name) + for domain, domain_conf in self.config.get('domain', {}).items(): + if isinstance(domain_conf, dict) and domain_conf.get('enabled', False): + mounts['domain'] = HalfDomain.acls_router( + domain, + module_path=domain_conf.get('module'), + acl=domain_conf.get('acl') + ) - fct_result = fct(req, *args, **kwargs) - if callable(fct_result): - fct_result = fct()(req, *args, **kwargs) - - d_acl['result'] = fct_result - - return ORJSONResponse(res) - - return wrapped + if len(mounts) > 1: + return Router([ + Mount(f'/{domain}', acls_router) + for domain, acls_router in mounts.items() + ]) + elif len(mounts) == 1: + return Mount('/', mounts.popitem()[1]) + else: + return Router() @property def domains(self): diff --git a/halfapi/lib/acl.py b/halfapi/lib/acl.py index 82ab900..2542851 100644 --- a/halfapi/lib/acl.py +++ b/halfapi/lib/acl.py @@ -2,6 +2,7 @@ """ Base ACL module that contains generic functions for domains ACL """ +from dataclasses import dataclass from functools import wraps from json import JSONDecodeError from starlette.authentication import UnauthenticatedUser @@ -118,7 +119,24 @@ def args_check(fct): # ACLS list for doc and priorities # Write your own constant in your domain or import this one +# Format : (acl_name: str, acl_documentation: str, priority: int, [public=False]) +# +# The 'priority' integer is greater than zero and the lower values means more +# priority. For a route, the order of declaration of the ACLs should respect +# their priority. +# +# When the 'public' boolean value is True, a route protected by this ACL is +# defined on the "/halfapi/acls/acl_name", that returns an empty response and +# the status code 200 or 401. + ACLS = ( - ('private', public.__doc__, 0), - ('public', public.__doc__, 999) + ('private', private.__doc__, 0, True), + ('public', public.__doc__, 999, True) ) + +@dataclass +class ACL(): + name: str + documentation: str + priority: int + public: bool = False diff --git a/halfapi/lib/constants.py b/halfapi/lib/constants.py index b0dd1f6..60daac4 100644 --- a/halfapi/lib/constants.py +++ b/halfapi/lib/constants.py @@ -44,7 +44,7 @@ DOMAIN_SCHEMA = Schema({ Optional('version'): str, Optional('patch_release'): str, Optional('acls'): [ - [str, str, int] + [str, str, int, Optional(bool)] ] }) diff --git a/halfapi/testing/test_domain.py b/halfapi/testing/test_domain.py index ccda4e4..6f702a8 100644 --- a/halfapi/testing/test_domain.py +++ b/halfapi/testing/test_domain.py @@ -132,16 +132,21 @@ class TestDomain(TestCase): assert 'domain' in schema r = self.client.request('get', '/halfapi/acls') - """ assert r.status_code == 200 d_r = r.json() assert isinstance(d_r, dict) - assert self.domain_name in d_r.keys() ACLS = HalfDomain.acls(self.module, self.acl_path) assert len(ACLS) == len(d_r[self.domain_name]) - for acl_name in ACLS: - assert acl_name[0] in d_r[self.domain_name] - """ + for acl_rule in ACLS: + assert len(acl_rule.name) > 0 + assert acl_rule.name in d_r[self.domain_name] + assert len(acl_rule.documentation) > 0 + assert isinstance(acl_rule.priority, int) + assert acl_rule.priority >= 0 + + if acl_rule.public is True: + r = self.client.request('get', f'/halfapi/acls/{acl_rule.name}') + assert r.status_code in [200, 401] diff --git a/tests/dummy_domain/__init__.py b/tests/dummy_domain/__init__.py index 1a60aa6..2ab3c71 100644 --- a/tests/dummy_domain/__init__.py +++ b/tests/dummy_domain/__init__.py @@ -8,5 +8,3 @@ domain = { ('halfapi', '=={}'.format(halfapi_version)), ) } - - diff --git a/tests/dummy_domain/acl.py b/tests/dummy_domain/acl.py index b8f9a8f..5734a4c 100644 --- a/tests/dummy_domain/acl.py +++ b/tests/dummy_domain/acl.py @@ -1,5 +1,5 @@ from halfapi.lib import acl -from halfapi.lib.acl import public, private +from halfapi.lib.acl import public, private, ACLS from random import randint def random(*args): @@ -8,7 +8,6 @@ def random(*args): return randint(0,1) == 1 ACLS = ( - ('public', public.__doc__, 999), - ('random', random.__doc__, 10), - ('private', private.__doc__, 0) + *ACLS, + ('random', random.__doc__, 10) ) diff --git a/tests/test_dummy_domain.py b/tests/test_dummy_domain.py index 8c9b757..8aca05c 100644 --- a/tests/test_dummy_domain.py +++ b/tests/test_dummy_domain.py @@ -1,4 +1,5 @@ import importlib +from halfapi.testing.test_domain import TestDomain def test_dummy_domain(): from . import dummy_domain