[doc-schema] In module-based routers, if there is a path parameter, you can specify an OpenAPI documentation for it, or a default will be used

This commit is contained in:
maxime 2023-08-01 19:46:01 +02:00
parent 7949b3206c
commit 20563081f5
9 changed files with 113 additions and 8 deletions

View File

@ -24,6 +24,7 @@ from .half_route import HalfRoute
from .lib import acl as lib_acl from .lib import acl as lib_acl
from .lib.responses import PlainTextResponse from .lib.responses import PlainTextResponse
from .lib.routes import JSONRoute from .lib.routes import JSONRoute
from .lib.schemas import param_docstring_default
from .lib.domain import MissingAclError, PathError, UnknownPathParameterType, \ from .lib.domain import MissingAclError, PathError, UnknownPathParameterType, \
UndefinedRoute, UndefinedFunction, get_fct_name, route_decorator UndefinedRoute, UndefinedFunction, get_fct_name, route_decorator
from .lib.domain_middleware import DomainMiddleware from .lib.domain_middleware import DomainMiddleware
@ -207,7 +208,8 @@ class HalfDomain(Starlette):
def gen_routes(m_router: ModuleType, def gen_routes(m_router: ModuleType,
verb: str, verb: str,
path: List[str], path: List[str],
params: List[Dict]) -> Tuple[FunctionType, Dict]: params: List[Dict],
path_param_docstrings: Dict[str, str] = {}) -> Tuple[FunctionType, Dict]:
""" """
Returns a tuple of the function associatied to the verb and path arguments, Returns a tuple of the function associatied to the verb and path arguments,
and the dictionary of it's acls and the dictionary of it's acls
@ -239,6 +241,13 @@ class HalfDomain(Starlette):
fct_name = get_fct_name(verb, path[-1]) fct_name = get_fct_name(verb, path[-1])
if hasattr(m_router, fct_name): if hasattr(m_router, fct_name):
fct = getattr(m_router, fct_name) fct = getattr(m_router, fct_name)
fct_docstring_obj = yaml.safe_load(fct.__doc__)
if 'parameters' not in fct_docstring_obj and path_param_docstrings:
fct_docstring_obj['parameters'] = list(map(
yaml.safe_load,
path_param_docstrings.values()))
fct.__doc__ = yaml.dump(fct_docstring_obj)
else: else:
raise UndefinedFunction('{}.{}'.format(m_router.__name__, fct_name or '')) raise UndefinedFunction('{}.{}'.format(m_router.__name__, fct_name or ''))
@ -251,7 +260,7 @@ class HalfDomain(Starlette):
@staticmethod @staticmethod
def gen_router_routes(m_router, path: List[str]) -> \ def gen_router_routes(m_router, path: List[str], PATH_PARAMS={}) -> \
Iterator[Tuple[str, str, ModuleType, Coroutine, List]]: Iterator[Tuple[str, str, ModuleType, Coroutine, List]]:
""" """
Recursive generator that parses a router (or a subrouter) Recursive generator that parses a router (or a subrouter)
@ -279,17 +288,32 @@ class HalfDomain(Starlette):
yield ('/'.join(filter(lambda x: len(x) > 0, path)), yield ('/'.join(filter(lambda x: len(x) > 0, path)),
verb, verb,
m_router, m_router,
*HalfDomain.gen_routes(m_router, verb, path, params[verb]) *HalfDomain.gen_routes(m_router, verb, path, params[verb], PATH_PARAMS)
) )
for subroute in params.get('SUBROUTES', []): for subroute in params.get('SUBROUTES', []):
#logger.debug('Processing subroute **%s** - %s', subroute, m_router.__name__) subroute_module = importlib.import_module(f'.{subroute}', m_router.__name__)
param_match = re.fullmatch('^([A-Z_]+)_([a-z]+)$', subroute) param_match = re.fullmatch('^([A-Z_]+)_([a-z]+)$', subroute)
parameter_name = None
if param_match is not None: if param_match is not None:
try: try:
parameter_name = param_match.groups()[0].lower()
if parameter_name in PATH_PARAMS:
raise Exception(f'Duplicate parameter name in same path! {subroute} : {parameter_name}')
parameter_type = param_match.groups()[1]
path.append('{{{}:{}}}'.format( path.append('{{{}:{}}}'.format(
param_match.groups()[0].lower(), parameter_name,
param_match.groups()[1])) parameter_type,
)
)
try:
PATH_PARAMS[parameter_name] = subroute_module.param_docstring
except AttributeError as exc:
PATH_PARAMS[parameter_name] = param_docstring_default(parameter_name, parameter_type)
except AssertionError as exc: except AssertionError as exc:
raise UnknownPathParameterType(subroute) from exc raise UnknownPathParameterType(subroute) from exc
else: else:
@ -297,14 +321,19 @@ class HalfDomain(Starlette):
try: try:
yield from HalfDomain.gen_router_routes( yield from HalfDomain.gen_router_routes(
importlib.import_module(f'.{subroute}', m_router.__name__), subroute_module,
path) path,
PATH_PARAMS
)
except ImportError as exc: except ImportError as exc:
logger.error('Failed to import subroute **{%s}**', subroute) logger.error('Failed to import subroute **{%s}**', subroute)
raise exc raise exc
path.pop() path.pop()
if parameter_name:
PATH_PARAMS.pop(parameter_name)
path.pop() path.pop()

View File

@ -13,6 +13,7 @@ import os
import importlib import importlib
from typing import Dict, Coroutine, List from typing import Dict, Coroutine, List
from types import ModuleType from types import ModuleType
import yaml
from starlette.schemas import SchemaGenerator from starlette.schemas import SchemaGenerator
@ -114,3 +115,23 @@ def schema_csv_dict(csv: List[str], prefix='/') -> Dict:
}) })
return schema_d return schema_d
def param_docstring_default(name, type):
""" Returns a default docstring in OpenAPI format for a path parameter
"""
type_map = {
'str': 'string',
'uuid': 'string',
'path': 'string',
'int': 'number',
'float': 'number'
}
return yaml.dump({
'name': name,
'in': 'path',
'description': f'default description for path parameter {name}',
'required': True,
'schema': {
'type': type_map[type]
}
})

View File

@ -0,0 +1,8 @@
param_docstring = """
name: second
in: path
description: second parameter description test
required: true
schema:
type: string
"""

View File

@ -0,0 +1,20 @@
from uuid import UUID
from halfapi.lib import acl
ACLS = {
'GET': [{'acl': acl.public}]
}
def get(first, second, third):
"""
description: a Test route for path parameters
responses:
200:
description: The test passed!
500:
description: The test did not pass :(
"""
assert isintance(first, str)
assert isintance(second, UUID)
assert isintance(third, int)
return ''

View File

@ -1,6 +1,8 @@
import pytest import pytest
from halfapi.testing.test_domain import TestDomain from halfapi.testing.test_domain import TestDomain
from pprint import pprint from pprint import pprint
import logging
logger = logging.getLogger()
class TestDummyDomain(TestDomain): class TestDummyDomain(TestDomain):
from .dummy_domain import domain from .dummy_domain import domain
@ -77,3 +79,28 @@ class TestDummyDomain(TestDomain):
res = self.client.request('post', '/arguments', json={ **arg_dict, 'z': True}) res = self.client.request('post', '/arguments', json={ **arg_dict, 'z': True})
assert res.json() == {**arg_dict, 'z': True} assert res.json() == {**arg_dict, 'z': True}
def test_schema_path_params(self):
res = self.client.request('get', '/halfapi/schema')
schema = res.json()
logger.debug(schema)
assert len(schema['paths']) > 0
route = schema['paths']['/path_params/{first}/one/{second}/two/{third}']
assert 'parameters' in route['get']
parameters = route['get']['parameters']
assert len(parameters) == 3
param_map = {
elt['name']: elt
for elt in parameters
}
assert param_map['second']['description'] == 'second parameter description test'