[feature] changes in the ACLs result availability

This commit is contained in:
Maxime Alves LIRMM@home 2023-02-23 02:51:28 +01:00 committed by Maxime Alves LIRMM
parent 0a385661b9
commit 8d254bafa0
8 changed files with 138 additions and 59 deletions

View File

@ -11,7 +11,7 @@ from types import ModuleType, FunctionType
from schema import SchemaError from schema import SchemaError
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Router from starlette.routing import Router, Route
import yaml import yaml
@ -19,7 +19,8 @@ import yaml
from . import __version__ from . import __version__
from .lib.constants import API_SCHEMA_DICT, ROUTER_SCHEMA, VERBS from .lib.constants import API_SCHEMA_DICT, ROUTER_SCHEMA, VERBS
from .half_route import HalfRoute from .half_route import HalfRoute
from .lib import acl from .lib import acl as lib_acl
from .lib.responses import PlainTextResponse
from .lib.routes import JSONRoute from .lib.routes import JSONRoute
from .lib.domain import MissingAclError, PathError, UnknownPathParameterType, \ from .lib.domain import MissingAclError, PathError, UnknownPathParameterType, \
UndefinedRoute, UndefinedFunction, get_fct_name, route_decorator UndefinedRoute, UndefinedFunction, get_fct_name, route_decorator
@ -88,6 +89,13 @@ class HalfDomain(Starlette):
] ]
) )
@staticmethod
def name(module):
""" Returns the name declared in the 'domain' dict at the root of the package
"""
return module.domain['name']
@staticmethod @staticmethod
def m_acl(module, acl=None): def m_acl(module, acl=None):
""" Returns the imported acl module for the domain module """ Returns the imported acl module for the domain module
@ -104,9 +112,14 @@ class HalfDomain(Starlette):
""" """
m_acl = HalfDomain.m_acl(module, acl) m_acl = HalfDomain.m_acl(module, acl)
try: try:
return getattr(m_acl, 'ACLS') return [
except AttributeError: lib_acl.ACL(*elt)
raise Exception(f'Missing acl.ACLS constant in module {m_acl.__package__}') for elt in getattr(m_acl, 'ACLS')
]
except AttributeError as exc:
logger.error(exc)
raise Exception(
f'Missing acl.ACLS constant in module {m_acl.__package__}') from exc
@staticmethod @staticmethod
def acls_route(domain, module_path=None, acl=None): def acls_route(domain, module_path=None, acl=None):
@ -118,7 +131,6 @@ class HalfDomain(Starlette):
[acl_name]: { [acl_name]: {
callable: fct_reference, callable: fct_reference,
docs: fct_docstring, docs: fct_docstring,
result: fct_result
} }
} }
""" """
@ -131,18 +143,73 @@ class HalfDomain(Starlette):
m_acl = HalfDomain.m_acl(module, acl) m_acl = HalfDomain.m_acl(module, acl)
for acl_name, doc, order in HalfDomain.acls( for elt in HalfDomain.acls(module, acl=acl):
module,
acl=acl): fct = getattr(m_acl, elt.name)
fct = getattr(m_acl, acl_name)
d_res[acl_name] = { d_res[elt.name] = {
'callable': fct, 'callable': fct,
'docs': doc, 'docs': elt.documentation
'result': None
} }
return d_res return d_res
# def schema(self): @staticmethod
def acls_router(domain, module_path=None, acl=None):
""" Router of the acls routes :
/ : Same result as HalfDomain.acls_route
/acl_name : Dummy route protected by the "acl_name" acl
"""
async def dummy_endpoint(request, *args, **kwargs):
return PlainTextResponse('')
routes = []
d_res = {}
module = importlib.import_module(domain) \
if module_path is None \
else importlib.import_module(module_path)
m_acl = HalfDomain.m_acl(module, acl)
for elt in HalfDomain.acls(module, acl=acl):
fct = getattr(m_acl, elt.name)
d_res[elt.name] = {
'callable': fct,
'docs': elt.documentation,
'public': elt.public
}
if elt.public:
routes.append(
Route(
f'/{elt.name}',
HalfRoute.acl_decorator(
dummy_endpoint,
params=[{'acl': fct}]
),
methods=['GET']
)
)
d_res_under_domain_name = {}
d_res_under_domain_name[HalfDomain.name(module)] = d_res
routes.append(
Route(
'/',
JSONRoute(d_res_under_domain_name),
methods=['GET']
)
)
return Router(routes)
@staticmethod @staticmethod
def gen_routes(m_router: ModuleType, def gen_routes(m_router: ModuleType,
@ -188,7 +255,7 @@ class HalfDomain(Starlette):
return route_decorator(fct), params return route_decorator(fct), params
# TODO: Remove when using only sync functions # TODO: Remove when using only sync functions
return acl.args_check(fct), params return lib_acl.args_check(fct), params
@staticmethod @staticmethod
@ -318,7 +385,7 @@ class HalfDomain(Starlette):
""" """
yield HalfRoute('/', yield HalfRoute('/',
JSONRoute([ self.schema() ]), JSONRoute([ self.schema() ]),
[{'acl': acl.public}], [{'acl': lib_acl.public}],
'GET' 'GET'
) )

View File

@ -19,7 +19,7 @@ from starlette.applications import Starlette
from starlette.authentication import UnauthenticatedUser from starlette.authentication import UnauthenticatedUser
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.routing import Route, Mount from starlette.routing import Router, Route, Mount
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse from starlette.responses import Response, PlainTextResponse
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
@ -178,7 +178,7 @@ class HalfAPI(Starlette):
yield Route('/whoami', get_user) yield Route('/whoami', get_user)
yield Route('/schema', schema_json) yield Route('/schema', schema_json)
yield Route('/acls', self.acls_route()) yield Mount('/acls', self.acls_router())
yield Route('/version', self.version_async) yield Route('/version', self.version_async)
""" Halfapi debug routes definition """ Halfapi debug routes definition
""" """
@ -220,35 +220,26 @@ class HalfAPI(Starlette):
time.sleep(1) time.sleep(1)
sys.exit(0) sys.exit(0)
def acls_route(self): def acls_router(self):
module = None mounts = {}
res = {
domain: HalfDomain.acls_route(
domain,
module_path=domain_conf.get('module'),
acl=domain_conf.get('acl'))
for domain, domain_conf in self.config.get('domain', {}).items()
if isinstance(domain_conf, dict) and domain_conf.get('enabled', False)
}
async def wrapped(req, *args, **kwargs): for domain, domain_conf in self.config.get('domain', {}).items():
for domain, domain_acls in res.items(): if isinstance(domain_conf, dict) and domain_conf.get('enabled', False):
for acl_name, d_acl in domain_acls.items(): mounts['domain'] = HalfDomain.acls_router(
fct = d_acl['callable'] domain,
if not callable(fct): module_path=domain_conf.get('module'),
raise Exception( acl=domain_conf.get('acl')
'No callable function in acl definition %s', )
acl_name)
fct_result = fct(req, *args, **kwargs) if len(mounts) > 1:
if callable(fct_result): return Router([
fct_result = fct()(req, *args, **kwargs) Mount(f'/{domain}', acls_router)
for domain, acls_router in mounts.items()
d_acl['result'] = fct_result ])
elif len(mounts) == 1:
return ORJSONResponse(res) return Mount('/', mounts.popitem()[1])
else:
return wrapped return Router()
@property @property
def domains(self): def domains(self):

View File

@ -2,6 +2,7 @@
""" """
Base ACL module that contains generic functions for domains ACL Base ACL module that contains generic functions for domains ACL
""" """
from dataclasses import dataclass
from functools import wraps from functools import wraps
from json import JSONDecodeError from json import JSONDecodeError
from starlette.authentication import UnauthenticatedUser from starlette.authentication import UnauthenticatedUser
@ -118,7 +119,24 @@ def args_check(fct):
# ACLS list for doc and priorities # ACLS list for doc and priorities
# Write your own constant in your domain or import this one # Write your own constant in your domain or import this one
# Format : (acl_name: str, acl_documentation: str, priority: int, [public=False])
#
# The 'priority' integer is greater than zero and the lower values means more
# priority. For a route, the order of declaration of the ACLs should respect
# their priority.
#
# When the 'public' boolean value is True, a route protected by this ACL is
# defined on the "/halfapi/acls/acl_name", that returns an empty response and
# the status code 200 or 401.
ACLS = ( ACLS = (
('private', public.__doc__, 0), ('private', private.__doc__, 0, True),
('public', public.__doc__, 999) ('public', public.__doc__, 999, True)
) )
@dataclass
class ACL():
name: str
documentation: str
priority: int
public: bool = False

View File

@ -44,7 +44,7 @@ DOMAIN_SCHEMA = Schema({
Optional('version'): str, Optional('version'): str,
Optional('patch_release'): str, Optional('patch_release'): str,
Optional('acls'): [ Optional('acls'): [
[str, str, int] [str, str, int, Optional(bool)]
] ]
}) })

View File

@ -132,16 +132,21 @@ class TestDomain(TestCase):
assert 'domain' in schema assert 'domain' in schema
r = self.client.request('get', '/halfapi/acls') r = self.client.request('get', '/halfapi/acls')
"""
assert r.status_code == 200 assert r.status_code == 200
d_r = r.json() d_r = r.json()
assert isinstance(d_r, dict) assert isinstance(d_r, dict)
assert self.domain_name in d_r.keys() assert self.domain_name in d_r.keys()
ACLS = HalfDomain.acls(self.module, self.acl_path) ACLS = HalfDomain.acls(self.module, self.acl_path)
assert len(ACLS) == len(d_r[self.domain_name]) assert len(ACLS) == len(d_r[self.domain_name])
for acl_name in ACLS: for acl_rule in ACLS:
assert acl_name[0] in d_r[self.domain_name] assert len(acl_rule.name) > 0
""" assert acl_rule.name in d_r[self.domain_name]
assert len(acl_rule.documentation) > 0
assert isinstance(acl_rule.priority, int)
assert acl_rule.priority >= 0
if acl_rule.public is True:
r = self.client.request('get', f'/halfapi/acls/{acl_rule.name}')
assert r.status_code in [200, 401]

View File

@ -8,5 +8,3 @@ domain = {
('halfapi', '=={}'.format(halfapi_version)), ('halfapi', '=={}'.format(halfapi_version)),
) )
} }

View File

@ -1,5 +1,5 @@
from halfapi.lib import acl from halfapi.lib import acl
from halfapi.lib.acl import public, private from halfapi.lib.acl import public, private, ACLS
from random import randint from random import randint
def random(*args): def random(*args):
@ -8,7 +8,6 @@ def random(*args):
return randint(0,1) == 1 return randint(0,1) == 1
ACLS = ( ACLS = (
('public', public.__doc__, 999), *ACLS,
('random', random.__doc__, 10), ('random', random.__doc__, 10)
('private', private.__doc__, 0)
) )

View File

@ -1,4 +1,5 @@
import importlib import importlib
from halfapi.testing.test_domain import TestDomain
def test_dummy_domain(): def test_dummy_domain():
from . import dummy_domain from . import dummy_domain