Fix #21 by simplifying DomainMiddleware
Tests are passing, but we loose the by-domain configuration (#19) Squashed commit of the following: commit d75fafcb9a043ac2540b2ac135704721b002d3c0 Author: Maxime Alves LIRMM <maxime.alves@lirmm.fr> Date: Thu Sep 2 14:40:05 2021 +0200 fix #21 commit 38c59e4ea3b40bd230f2add2bb0e05772913c097 Author: Maxime Alves LIRMM <maxime.alves@lirmm.fr> Date: Thu Sep 2 01:13:51 2021 +0200 [deps] starlette 0.15 (breaks tests) FAILED tests/test_debug_routes.py::test_current_user - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_debug_routes.py::test_log - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_debug_routes.py::test_error - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_dummy_project_router.py::test_get_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_dummy_project_router.py::test_delete_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_lib_schemas.py::test_get_api_routes - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_lib_schemas.py::test_get_schema_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_lib_schemas.py::test_get_api_dummy_domain_routes - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
This commit is contained in:
parent
865a4dffd1
commit
bc556854ac
2
Pipfile
2
Pipfile
|
@ -12,7 +12,7 @@ build = "*"
|
|||
|
||||
[packages]
|
||||
click = ">=7.1,<8"
|
||||
starlette = ">=0.14,<0.15"
|
||||
starlette = ">=0.15,<0.16"
|
||||
uvicorn = ">=0.13,<1"
|
||||
orjson = ">=3.4.7,<4"
|
||||
pyjwt = ">=2.0.1,<3"
|
||||
|
|
|
@ -51,12 +51,12 @@ class HalfAPI:
|
|||
from halfapi.conf import CONFIG, SECRET, PRODUCTION, DOMAINS
|
||||
|
||||
|
||||
routes = [ Route('/', get_api_routes) ]
|
||||
routes = [ Route('/', get_api_routes(DOMAINS)) ]
|
||||
|
||||
|
||||
routes += [
|
||||
Route('/halfapi/schema', schema_json),
|
||||
Route('/halfapi/acls', get_acls)
|
||||
Route('/halfapi/acls', get_acls),
|
||||
]
|
||||
|
||||
routes += Route('/halfapi/current_user', lambda request, *args, **kwargs:
|
||||
|
@ -74,11 +74,11 @@ class HalfAPI:
|
|||
for route in gen_starlette_routes(DOMAINS):
|
||||
routes.append(route)
|
||||
|
||||
for domain in DOMAINS:
|
||||
for domain, m_domain in DOMAINS.items():
|
||||
routes.append(
|
||||
Route(
|
||||
f'/{domain}',
|
||||
get_api_domain_routes(domain)
|
||||
get_api_domain_routes(m_domain)
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -29,53 +29,9 @@ class DomainMiddleware(BaseHTTPMiddleware):
|
|||
super().__init__(app)
|
||||
self.config = config
|
||||
self.domains = {}
|
||||
self.api = {}
|
||||
self.acl = {}
|
||||
self.request = None
|
||||
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""
|
||||
Scans routes and acls of the domain in the first part of the path
|
||||
"""
|
||||
domain = scope['path'].split('/')[1]
|
||||
|
||||
self.domains = self.config.get('domains', {})
|
||||
|
||||
if len(domain) == 0 or domain == 'halfapi':
|
||||
for domain in self.domains:
|
||||
self.api[domain], self.acl[domain] = api_routes(self.domains[domain])
|
||||
elif domain in self.domains:
|
||||
self.api[domain], self.acl[domain] = api_routes(self.domains[domain])
|
||||
else:
|
||||
logger.error('domain not in self.domains %s / %s',
|
||||
scope['path'],
|
||||
self.domains)
|
||||
|
||||
scope_ = scope.copy()
|
||||
scope_['domains'] = self.domains
|
||||
scope_['api'] = self.api
|
||||
scope_['acl'] = self.acl
|
||||
|
||||
cur_path = URL(scope=scope).path
|
||||
if cur_path[0] == '/':
|
||||
current_domain = cur_path[1:].split('/')[0]
|
||||
else:
|
||||
current_domain = cur_path.split('/')[0]
|
||||
|
||||
try:
|
||||
scope_['config'] = self.config.copy()
|
||||
except configparser.NoSectionError:
|
||||
logger.debug(
|
||||
'No specific configuration for domain **%s**', current_domain)
|
||||
scope_['config'] = {}
|
||||
|
||||
|
||||
self.request = Request(scope_, receive)
|
||||
response = await self.dispatch(self.request, self.call_next)
|
||||
await response(scope_, receive, send)
|
||||
|
||||
|
||||
async def dispatch(self, request: Request,
|
||||
call_next: RequestResponseEndpoint) -> Response:
|
||||
"""
|
||||
|
@ -84,17 +40,19 @@ class DomainMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
response = await call_next(request)
|
||||
|
||||
if 'acl_pass' in self.request.scope:
|
||||
if 'acl_pass' in request.scope:
|
||||
# Set the http header "x-acl" if an acl was used on the route
|
||||
response.headers['x-acl'] = self.request.scope['acl_pass']
|
||||
response.headers['x-acl'] = request.scope['acl_pass']
|
||||
|
||||
if 'args' in self.request.scope:
|
||||
if 'args' in request.scope:
|
||||
# Set the http headers "x-args-required" and "x-args-optional"
|
||||
|
||||
if 'required' in self.request.scope['args']:
|
||||
if 'required' in request.scope['args']:
|
||||
response.headers['x-args-required'] = \
|
||||
','.join(self.request.scope['args']['required'])
|
||||
if 'optional' in self.request.scope['args']:
|
||||
','.join(request.scope['args']['required'])
|
||||
if 'optional' in request.scope['args']:
|
||||
response.headers['x-args-optional'] = \
|
||||
','.join(self.request.scope['args']['optional'])
|
||||
','.join(request.scope['args']['optional'])
|
||||
|
||||
|
||||
return response
|
||||
|
|
|
@ -10,14 +10,15 @@ Constant :
|
|||
SCHEMAS (starlette.schemas.SchemaGenerator)
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, Coroutine
|
||||
from types import ModuleType
|
||||
|
||||
from starlette.schemas import SchemaGenerator
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from .. import __version__
|
||||
from .routes import gen_starlette_routes, api_acls
|
||||
from .routes import gen_starlette_routes, api_routes, api_acls
|
||||
from .responses import ORJSONResponse
|
||||
|
||||
logger = logging.getLogger('uvicorn.asgi')
|
||||
|
@ -25,7 +26,7 @@ SCHEMAS = SchemaGenerator(
|
|||
{"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": __version__}}
|
||||
)
|
||||
|
||||
async def get_api_routes(request, *args, **kwargs):
|
||||
def get_api_routes(domains: Dict[str, ModuleType]) -> Coroutine:
|
||||
"""
|
||||
description: Returns the current API routes dictionary
|
||||
as a JSON object
|
||||
|
@ -63,18 +64,26 @@ async def get_api_routes(request, *args, **kwargs):
|
|||
}
|
||||
}
|
||||
"""
|
||||
return ORJSONResponse(request.scope['api'])
|
||||
routes = {
|
||||
domain: api_routes(m_domain)[0]
|
||||
for domain, m_domain in domains.items()
|
||||
}
|
||||
|
||||
async def wrapped(request, *args, **kwargs):
|
||||
return ORJSONResponse(routes)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def get_api_domain_routes(m_domain: ModuleType) -> Coroutine:
|
||||
routes, _ = api_routes(m_domain)
|
||||
|
||||
def get_api_domain_routes(domain):
|
||||
async def wrapped(request, *args, **kwargs):
|
||||
"""
|
||||
description: Returns the current API routes dictionary for a specific
|
||||
description: Returns the current API routes dictionary for a specific
|
||||
domain as a JSON object
|
||||
"""
|
||||
if domain in request.scope['api']:
|
||||
return ORJSONResponse(request.scope['api'][domain])
|
||||
else:
|
||||
raise HTTPException(404)
|
||||
return ORJSONResponse(routes)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
|
Loading…
Reference in New Issue