diff --git a/Pipfile b/Pipfile index c68c670..f3d7049 100644 --- a/Pipfile +++ b/Pipfile @@ -12,7 +12,7 @@ build = "*" [packages] click = ">=7.1,<8" -starlette = ">=0.14,<0.15" +starlette = ">=0.15,<0.16" uvicorn = ">=0.13,<1" orjson = ">=3.4.7,<4" pyjwt = ">=2.0.1,<3" diff --git a/halfapi/app.py b/halfapi/app.py index bd7025f..90c963a 100644 --- a/halfapi/app.py +++ b/halfapi/app.py @@ -51,12 +51,12 @@ class HalfAPI: from halfapi.conf import CONFIG, SECRET, PRODUCTION, DOMAINS - routes = [ Route('/', get_api_routes) ] + routes = [ Route('/', get_api_routes(DOMAINS)) ] routes += [ Route('/halfapi/schema', schema_json), - Route('/halfapi/acls', get_acls) + Route('/halfapi/acls', get_acls), ] routes += Route('/halfapi/current_user', lambda request, *args, **kwargs: @@ -74,11 +74,11 @@ class HalfAPI: for route in gen_starlette_routes(DOMAINS): routes.append(route) - for domain in DOMAINS: + for domain, m_domain in DOMAINS.items(): routes.append( Route( f'/{domain}', - get_api_domain_routes(domain) + get_api_domain_routes(m_domain) ) ) diff --git a/halfapi/lib/domain_middleware.py b/halfapi/lib/domain_middleware.py index 9179ee8..260cfe7 100644 --- a/halfapi/lib/domain_middleware.py +++ b/halfapi/lib/domain_middleware.py @@ -29,53 +29,9 @@ class DomainMiddleware(BaseHTTPMiddleware): super().__init__(app) self.config = config self.domains = {} - self.api = {} - self.acl = {} self.request = None - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """ - Scans routes and acls of the domain in the first part of the path - """ - domain = scope['path'].split('/')[1] - - self.domains = self.config.get('domains', {}) - - if len(domain) == 0 or domain == 'halfapi': - for domain in self.domains: - self.api[domain], self.acl[domain] = api_routes(self.domains[domain]) - elif domain in self.domains: - self.api[domain], self.acl[domain] = api_routes(self.domains[domain]) - else: - logger.error('domain not in self.domains %s / %s', - scope['path'], - self.domains) - - scope_ = scope.copy() - scope_['domains'] = self.domains - scope_['api'] = self.api - scope_['acl'] = self.acl - - cur_path = URL(scope=scope).path - if cur_path[0] == '/': - current_domain = cur_path[1:].split('/')[0] - else: - current_domain = cur_path.split('/')[0] - - try: - scope_['config'] = self.config.copy() - except configparser.NoSectionError: - logger.debug( - 'No specific configuration for domain **%s**', current_domain) - scope_['config'] = {} - - - self.request = Request(scope_, receive) - response = await self.dispatch(self.request, self.call_next) - await response(scope_, receive, send) - - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: """ @@ -84,17 +40,19 @@ class DomainMiddleware(BaseHTTPMiddleware): response = await call_next(request) - if 'acl_pass' in self.request.scope: + if 'acl_pass' in request.scope: # Set the http header "x-acl" if an acl was used on the route - response.headers['x-acl'] = self.request.scope['acl_pass'] + response.headers['x-acl'] = request.scope['acl_pass'] - if 'args' in self.request.scope: + if 'args' in request.scope: # Set the http headers "x-args-required" and "x-args-optional" - if 'required' in self.request.scope['args']: + if 'required' in request.scope['args']: response.headers['x-args-required'] = \ - ','.join(self.request.scope['args']['required']) - if 'optional' in self.request.scope['args']: + ','.join(request.scope['args']['required']) + if 'optional' in request.scope['args']: response.headers['x-args-optional'] = \ - ','.join(self.request.scope['args']['optional']) + ','.join(request.scope['args']['optional']) + + return response diff --git a/halfapi/lib/schemas.py b/halfapi/lib/schemas.py index 82abc29..7cbf405 100644 --- a/halfapi/lib/schemas.py +++ b/halfapi/lib/schemas.py @@ -10,14 +10,15 @@ Constant : SCHEMAS (starlette.schemas.SchemaGenerator) """ +import os import logging -from typing import Dict +from typing import Dict, Coroutine +from types import ModuleType from starlette.schemas import SchemaGenerator -from starlette.exceptions import HTTPException from .. import __version__ -from .routes import gen_starlette_routes, api_acls +from .routes import gen_starlette_routes, api_routes, api_acls from .responses import ORJSONResponse logger = logging.getLogger('uvicorn.asgi') @@ -25,7 +26,7 @@ SCHEMAS = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": __version__}} ) -async def get_api_routes(request, *args, **kwargs): +def get_api_routes(domains: Dict[str, ModuleType]) -> Coroutine: """ description: Returns the current API routes dictionary as a JSON object @@ -63,18 +64,26 @@ async def get_api_routes(request, *args, **kwargs): } } """ - return ORJSONResponse(request.scope['api']) + routes = { + domain: api_routes(m_domain)[0] + for domain, m_domain in domains.items() + } + + async def wrapped(request, *args, **kwargs): + return ORJSONResponse(routes) + + return wrapped + + +def get_api_domain_routes(m_domain: ModuleType) -> Coroutine: + routes, _ = api_routes(m_domain) -def get_api_domain_routes(domain): async def wrapped(request, *args, **kwargs): """ - description: Returns the current API routes dictionary for a specific + description: Returns the current API routes dictionary for a specific domain as a JSON object """ - if domain in request.scope['api']: - return ORJSONResponse(request.scope['api'][domain]) - else: - raise HTTPException(404) + return ORJSONResponse(routes) return wrapped diff --git a/setup.py b/setup.py index 146ec16..9ea8de4 100755 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ setup( python_requires=">=3.8", install_requires=[ "PyJWT>=2.0.1", - "starlette>=0.14,<0.15", + "starlette>=0.15,<0.16", "click>=7.1,<8", "uvicorn>=0.13,<1", "orjson>=3.4.7,<4",