[lib.schemas] add schema_csv_dict and schema_to_csv function

This commit is contained in:
Maxime Alves LIRMM 2021-11-30 00:39:46 +01:00
parent c9639ddbc0
commit 1fda2ab15d
2 changed files with 156 additions and 7 deletions

View File

@ -11,12 +11,14 @@ Constant :
""" """
import os import os
from typing import Dict, Coroutine import importlib
from typing import Dict, Coroutine, List
from types import ModuleType from types import ModuleType
from starlette.schemas import SchemaGenerator from starlette.schemas import SchemaGenerator
from .. import __version__ from .. import __version__
from .domain import gen_router_routes
from ..logging import logger from ..logging import logger
from .routes import gen_starlette_routes, api_routes, api_acls from .routes import gen_starlette_routes, api_routes, api_acls
from .responses import ORJSONResponse from .responses import ORJSONResponse
@ -32,11 +34,14 @@ def get_api_routes(domains: Dict[str, ModuleType]) -> Dict:
example: { example: {
"dummy_domain": { "dummy_domain": {
"abc/alphabet": { "abc/alphabet": {
"GET": [ "GET": {
{ "docs": "",
"acl": "public" "acls": [
} {
] "acl": "public"
}
]
}
}, },
"abc/alphabet/{test:uuid}": { "abc/alphabet/{test:uuid}": {
"GET": [ "GET": [
@ -115,3 +120,131 @@ async def get_acls(request, *args, **kwargs):
result of the acls functions result of the acls functions
""" """
return ORJSONResponse(api_acls(request)) return ORJSONResponse(api_acls(request))
def schema_to_csv(module_name, header=True) -> str:
"""
Returns a string composed where each line is a set of path, verb, function,
acl, required arguments, optional arguments and output variables. Those
lines should be unique in the result string;
"""
# retrieve module
mod = importlib.import_module(module_name)
lines = []
if header:
lines.append([
'path',
'method',
'module:function',
'acl',
'args_required', 'args_optional',
'out'
])
for path, verb, m_router, fct, parameters in gen_router_routes(mod, []):
""" Call route generator (.lib.domain)
"""
for param in parameters:
""" Each parameters row represents rules for a specific ACL
"""
fields = (
f'/{path}',
verb,
f'{m_router.__name__}:{fct.__name__}',
param['acl'].__name__,
','.join((param.get('args', {}).get('required', set()))),
','.join((param.get('args', {}).get('optional', set()))),
','.join((param.get('out', set())))
)
if fields[0:4] in map(lambda elt: elt[0:4], lines):
raise Exception(
'Already defined acl for this route \
(path: {}, verb: {}, acl: {})'.format(
path,
verb,
param['acl'].__name__
)
)
lines.append(fields)
return '\n'.join(
[ ';'.join(fields) for fields in lines ]
)
def schema_csv_dict(csv: List[str]) -> Dict:
package = None
schema_d = {}
modules_d = {}
acl_modules_d = {}
for line in csv:
path, verb, router, acl_fct_name, args_req, args_opt, out = line.strip().split(';')
logger.info('schema_csv_dict %s %s %s', path, args_req, args_opt)
if path not in schema_d:
schema_d[path] = {}
if verb not in schema_d[path]:
mod_str = router.split(':')[0]
fct_str = router.split(':')[1]
if mod_str not in modules_d:
modules_d[mod_str] = importlib.import_module(mod_str)
if not hasattr(modules_d[mod_str], fct_str):
raise Exception(
'Missing function in module. module:{} function:{}'.format(
router, fct_str
)
)
fct = getattr(modules_d[mod_str], fct_str)
schema_d[path][verb] = {
'module': modules_d[mod_str],
'fct': fct,
'acls': []
}
if package and router.split('.')[0] != package:
raise Exception('Multi-domain is not allowed in that mode')
package = router.split('.')[0]
if not len(package):
raise Exception(
'Empty package name (router=%s)'.format(router))
acl_package = f'{package}.acl'
if acl_package not in acl_modules_d:
if acl_package not in modules_d:
modules_d[acl_package] = importlib.import_module(acl_package)
if not hasattr(modules_d[acl_package], acl_fct_name):
raise Exception(
'Missing acl function in module. module:{} acl:{}'.format(
acl_package, acl_fct_name
)
)
acl_modules_d[acl_package] = {}
acl_modules_d[acl_package][acl_fct_name] = getattr(modules_d[acl_package], acl_fct_name)
schema_d[path][verb]['acls'].append({
'acl': acl_modules_d[acl_package][acl_fct_name],
'args': {
'required': set(args_req.split(',')) if len(args_req) else set(),
'optional': set(args_opt.split(',')) if len(args_opt) else set()
}
})
return schema_d

View File

@ -5,7 +5,9 @@ from starlette.authentication import (
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
UnauthenticatedUser) UnauthenticatedUser)
from halfapi.lib.schemas import schema_dict_dom from halfapi.lib.schemas import schema_dict_dom, schema_to_csv, schema_csv_dict
from halfapi.lib.constants import DOMAIN_SCHEMA
from halfapi import __version__ from halfapi import __version__
def test_schemas_dict_dom(): def test_schemas_dict_dom():
@ -17,11 +19,13 @@ def test_schemas_dict_dom():
def test_get_api_routes(project_runner, application_debug): def test_get_api_routes(project_runner, application_debug):
c = TestClient(application_debug) c = TestClient(application_debug)
r = c.get('/') r = c.get('/')
assert isinstance(c, TestClient)
d_r = r.json() d_r = r.json()
assert isinstance(d_r, dict) assert isinstance(d_r, dict)
def test_get_schema_route(project_runner, application_debug): def test_get_schema_route(project_runner, application_debug):
c = TestClient(application_debug) c = TestClient(application_debug)
assert isinstance(c, TestClient)
r = c.get('/halfapi/schema') r = c.get('/halfapi/schema')
d_r = r.json() d_r = r.json()
assert isinstance(d_r, dict) assert isinstance(d_r, dict)
@ -42,3 +46,15 @@ def test_get_api_dummy_domain_routes(application_domain, routers):
assert 'GET' in d_r['abc/alphabet'] assert 'GET' in d_r['abc/alphabet']
assert len(d_r['abc/alphabet']['GET']) > 0 assert len(d_r['abc/alphabet']['GET']) > 0
assert 'acls' in d_r['abc/alphabet']['GET'] assert 'acls' in d_r['abc/alphabet']['GET']
def test_schema_to_csv():
csv = schema_to_csv('dummy_domain.routers', False)
assert isinstance(csv, str)
assert len(csv.split('\n')) > 0
def test_schema_csv_dict():
csv = schema_to_csv('dummy_domain.routers', False)
assert isinstance(csv, str)
schema_d = schema_csv_dict(csv.split('\n'))
assert isinstance(schema_d, dict)