Fix #21 by simplifying DomainMiddleware

Tests are passing, but we loose the by-domain configuration (#19)

Squashed commit of the following:

commit d75fafcb9a043ac2540b2ac135704721b002d3c0
Author: Maxime Alves LIRMM <maxime.alves@lirmm.fr>
Date:   Thu Sep 2 14:40:05 2021 +0200

    fix #21

commit 38c59e4ea3b40bd230f2add2bb0e05772913c097
Author: Maxime Alves LIRMM <maxime.alves@lirmm.fr>
Date:   Thu Sep 2 01:13:51 2021 +0200

    [deps] starlette 0.15 (breaks tests)

    FAILED tests/test_debug_routes.py::test_current_user - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
    FAILED tests/test_debug_routes.py::test_log - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
    FAILED tests/test_debug_routes.py::test_error - AttributeError:
    'DomainMiddleware' object has no attribute 'call_next'
    FAILED tests/test_dummy_project_router.py::test_get_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
    FAILED tests/test_dummy_project_router.py::test_delete_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
    FAILED tests/test_lib_schemas.py::test_get_api_routes - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
    FAILED tests/test_lib_schemas.py::test_get_schema_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
    FAILED tests/test_lib_schemas.py::test_get_api_dummy_domain_routes - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
This commit is contained in:
Maxime Alves LIRMM 2021-09-02 14:45:06 +02:00
parent 865a4dffd1
commit bc556854ac
5 changed files with 35 additions and 68 deletions

View File

@ -12,7 +12,7 @@ build = "*"
[packages]
click = ">=7.1,<8"
starlette = ">=0.14,<0.15"
starlette = ">=0.15,<0.16"
uvicorn = ">=0.13,<1"
orjson = ">=3.4.7,<4"
pyjwt = ">=2.0.1,<3"

View File

@ -51,12 +51,12 @@ class HalfAPI:
from halfapi.conf import CONFIG, SECRET, PRODUCTION, DOMAINS
routes = [ Route('/', get_api_routes) ]
routes = [ Route('/', get_api_routes(DOMAINS)) ]
routes += [
Route('/halfapi/schema', schema_json),
Route('/halfapi/acls', get_acls)
Route('/halfapi/acls', get_acls),
]
routes += Route('/halfapi/current_user', lambda request, *args, **kwargs:
@ -74,11 +74,11 @@ class HalfAPI:
for route in gen_starlette_routes(DOMAINS):
routes.append(route)
for domain in DOMAINS:
for domain, m_domain in DOMAINS.items():
routes.append(
Route(
f'/{domain}',
get_api_domain_routes(domain)
get_api_domain_routes(m_domain)
)
)

View File

@ -29,53 +29,9 @@ class DomainMiddleware(BaseHTTPMiddleware):
super().__init__(app)
self.config = config
self.domains = {}
self.api = {}
self.acl = {}
self.request = None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Scans routes and acls of the domain in the first part of the path
"""
domain = scope['path'].split('/')[1]
self.domains = self.config.get('domains', {})
if len(domain) == 0 or domain == 'halfapi':
for domain in self.domains:
self.api[domain], self.acl[domain] = api_routes(self.domains[domain])
elif domain in self.domains:
self.api[domain], self.acl[domain] = api_routes(self.domains[domain])
else:
logger.error('domain not in self.domains %s / %s',
scope['path'],
self.domains)
scope_ = scope.copy()
scope_['domains'] = self.domains
scope_['api'] = self.api
scope_['acl'] = self.acl
cur_path = URL(scope=scope).path
if cur_path[0] == '/':
current_domain = cur_path[1:].split('/')[0]
else:
current_domain = cur_path.split('/')[0]
try:
scope_['config'] = self.config.copy()
except configparser.NoSectionError:
logger.debug(
'No specific configuration for domain **%s**', current_domain)
scope_['config'] = {}
self.request = Request(scope_, receive)
response = await self.dispatch(self.request, self.call_next)
await response(scope_, receive, send)
async def dispatch(self, request: Request,
call_next: RequestResponseEndpoint) -> Response:
"""
@ -84,17 +40,19 @@ class DomainMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
if 'acl_pass' in self.request.scope:
if 'acl_pass' in request.scope:
# Set the http header "x-acl" if an acl was used on the route
response.headers['x-acl'] = self.request.scope['acl_pass']
response.headers['x-acl'] = request.scope['acl_pass']
if 'args' in self.request.scope:
if 'args' in request.scope:
# Set the http headers "x-args-required" and "x-args-optional"
if 'required' in self.request.scope['args']:
if 'required' in request.scope['args']:
response.headers['x-args-required'] = \
','.join(self.request.scope['args']['required'])
if 'optional' in self.request.scope['args']:
','.join(request.scope['args']['required'])
if 'optional' in request.scope['args']:
response.headers['x-args-optional'] = \
','.join(self.request.scope['args']['optional'])
','.join(request.scope['args']['optional'])
return response

View File

@ -10,14 +10,15 @@ Constant :
SCHEMAS (starlette.schemas.SchemaGenerator)
"""
import os
import logging
from typing import Dict
from typing import Dict, Coroutine
from types import ModuleType
from starlette.schemas import SchemaGenerator
from starlette.exceptions import HTTPException
from .. import __version__
from .routes import gen_starlette_routes, api_acls
from .routes import gen_starlette_routes, api_routes, api_acls
from .responses import ORJSONResponse
logger = logging.getLogger('uvicorn.asgi')
@ -25,7 +26,7 @@ SCHEMAS = SchemaGenerator(
{"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": __version__}}
)
async def get_api_routes(request, *args, **kwargs):
def get_api_routes(domains: Dict[str, ModuleType]) -> Coroutine:
"""
description: Returns the current API routes dictionary
as a JSON object
@ -63,18 +64,26 @@ async def get_api_routes(request, *args, **kwargs):
}
}
"""
return ORJSONResponse(request.scope['api'])
routes = {
domain: api_routes(m_domain)[0]
for domain, m_domain in domains.items()
}
async def wrapped(request, *args, **kwargs):
return ORJSONResponse(routes)
return wrapped
def get_api_domain_routes(m_domain: ModuleType) -> Coroutine:
routes, _ = api_routes(m_domain)
def get_api_domain_routes(domain):
async def wrapped(request, *args, **kwargs):
"""
description: Returns the current API routes dictionary for a specific
domain as a JSON object
"""
if domain in request.scope['api']:
return ORJSONResponse(request.scope['api'][domain])
else:
raise HTTPException(404)
return ORJSONResponse(routes)
return wrapped

View File

@ -44,7 +44,7 @@ setup(
python_requires=">=3.8",
install_requires=[
"PyJWT>=2.0.1",
"starlette>=0.14,<0.15",
"starlette>=0.15,<0.16",
"click>=7.1,<8",
"uvicorn>=0.13,<1",
"orjson>=3.4.7,<4",