Fix #21 by simplifying DomainMiddleware
Tests are passing, but we loose the by-domain configuration (#19) Squashed commit of the following: commit d75fafcb9a043ac2540b2ac135704721b002d3c0 Author: Maxime Alves LIRMM <maxime.alves@lirmm.fr> Date: Thu Sep 2 14:40:05 2021 +0200 fix #21 commit 38c59e4ea3b40bd230f2add2bb0e05772913c097 Author: Maxime Alves LIRMM <maxime.alves@lirmm.fr> Date: Thu Sep 2 01:13:51 2021 +0200 [deps] starlette 0.15 (breaks tests) FAILED tests/test_debug_routes.py::test_current_user - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_debug_routes.py::test_log - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_debug_routes.py::test_error - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_dummy_project_router.py::test_get_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_dummy_project_router.py::test_delete_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_lib_schemas.py::test_get_api_routes - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_lib_schemas.py::test_get_schema_route - AttributeError: 'DomainMiddleware' object has no attribute 'call_next' FAILED tests/test_lib_schemas.py::test_get_api_dummy_domain_routes - AttributeError: 'DomainMiddleware' object has no attribute 'call_next'
This commit is contained in:
parent
865a4dffd1
commit
bc556854ac
2
Pipfile
2
Pipfile
|
@ -12,7 +12,7 @@ build = "*"
|
||||||
|
|
||||||
[packages]
|
[packages]
|
||||||
click = ">=7.1,<8"
|
click = ">=7.1,<8"
|
||||||
starlette = ">=0.14,<0.15"
|
starlette = ">=0.15,<0.16"
|
||||||
uvicorn = ">=0.13,<1"
|
uvicorn = ">=0.13,<1"
|
||||||
orjson = ">=3.4.7,<4"
|
orjson = ">=3.4.7,<4"
|
||||||
pyjwt = ">=2.0.1,<3"
|
pyjwt = ">=2.0.1,<3"
|
||||||
|
|
|
@ -51,12 +51,12 @@ class HalfAPI:
|
||||||
from halfapi.conf import CONFIG, SECRET, PRODUCTION, DOMAINS
|
from halfapi.conf import CONFIG, SECRET, PRODUCTION, DOMAINS
|
||||||
|
|
||||||
|
|
||||||
routes = [ Route('/', get_api_routes) ]
|
routes = [ Route('/', get_api_routes(DOMAINS)) ]
|
||||||
|
|
||||||
|
|
||||||
routes += [
|
routes += [
|
||||||
Route('/halfapi/schema', schema_json),
|
Route('/halfapi/schema', schema_json),
|
||||||
Route('/halfapi/acls', get_acls)
|
Route('/halfapi/acls', get_acls),
|
||||||
]
|
]
|
||||||
|
|
||||||
routes += Route('/halfapi/current_user', lambda request, *args, **kwargs:
|
routes += Route('/halfapi/current_user', lambda request, *args, **kwargs:
|
||||||
|
@ -74,11 +74,11 @@ class HalfAPI:
|
||||||
for route in gen_starlette_routes(DOMAINS):
|
for route in gen_starlette_routes(DOMAINS):
|
||||||
routes.append(route)
|
routes.append(route)
|
||||||
|
|
||||||
for domain in DOMAINS:
|
for domain, m_domain in DOMAINS.items():
|
||||||
routes.append(
|
routes.append(
|
||||||
Route(
|
Route(
|
||||||
f'/{domain}',
|
f'/{domain}',
|
||||||
get_api_domain_routes(domain)
|
get_api_domain_routes(m_domain)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -29,53 +29,9 @@ class DomainMiddleware(BaseHTTPMiddleware):
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.domains = {}
|
self.domains = {}
|
||||||
self.api = {}
|
|
||||||
self.acl = {}
|
|
||||||
self.request = None
|
self.request = None
|
||||||
|
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
||||||
"""
|
|
||||||
Scans routes and acls of the domain in the first part of the path
|
|
||||||
"""
|
|
||||||
domain = scope['path'].split('/')[1]
|
|
||||||
|
|
||||||
self.domains = self.config.get('domains', {})
|
|
||||||
|
|
||||||
if len(domain) == 0 or domain == 'halfapi':
|
|
||||||
for domain in self.domains:
|
|
||||||
self.api[domain], self.acl[domain] = api_routes(self.domains[domain])
|
|
||||||
elif domain in self.domains:
|
|
||||||
self.api[domain], self.acl[domain] = api_routes(self.domains[domain])
|
|
||||||
else:
|
|
||||||
logger.error('domain not in self.domains %s / %s',
|
|
||||||
scope['path'],
|
|
||||||
self.domains)
|
|
||||||
|
|
||||||
scope_ = scope.copy()
|
|
||||||
scope_['domains'] = self.domains
|
|
||||||
scope_['api'] = self.api
|
|
||||||
scope_['acl'] = self.acl
|
|
||||||
|
|
||||||
cur_path = URL(scope=scope).path
|
|
||||||
if cur_path[0] == '/':
|
|
||||||
current_domain = cur_path[1:].split('/')[0]
|
|
||||||
else:
|
|
||||||
current_domain = cur_path.split('/')[0]
|
|
||||||
|
|
||||||
try:
|
|
||||||
scope_['config'] = self.config.copy()
|
|
||||||
except configparser.NoSectionError:
|
|
||||||
logger.debug(
|
|
||||||
'No specific configuration for domain **%s**', current_domain)
|
|
||||||
scope_['config'] = {}
|
|
||||||
|
|
||||||
|
|
||||||
self.request = Request(scope_, receive)
|
|
||||||
response = await self.dispatch(self.request, self.call_next)
|
|
||||||
await response(scope_, receive, send)
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request,
|
async def dispatch(self, request: Request,
|
||||||
call_next: RequestResponseEndpoint) -> Response:
|
call_next: RequestResponseEndpoint) -> Response:
|
||||||
"""
|
"""
|
||||||
|
@ -84,17 +40,19 @@ class DomainMiddleware(BaseHTTPMiddleware):
|
||||||
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
||||||
if 'acl_pass' in self.request.scope:
|
if 'acl_pass' in request.scope:
|
||||||
# Set the http header "x-acl" if an acl was used on the route
|
# Set the http header "x-acl" if an acl was used on the route
|
||||||
response.headers['x-acl'] = self.request.scope['acl_pass']
|
response.headers['x-acl'] = request.scope['acl_pass']
|
||||||
|
|
||||||
if 'args' in self.request.scope:
|
if 'args' in request.scope:
|
||||||
# Set the http headers "x-args-required" and "x-args-optional"
|
# Set the http headers "x-args-required" and "x-args-optional"
|
||||||
|
|
||||||
if 'required' in self.request.scope['args']:
|
if 'required' in request.scope['args']:
|
||||||
response.headers['x-args-required'] = \
|
response.headers['x-args-required'] = \
|
||||||
','.join(self.request.scope['args']['required'])
|
','.join(request.scope['args']['required'])
|
||||||
if 'optional' in self.request.scope['args']:
|
if 'optional' in request.scope['args']:
|
||||||
response.headers['x-args-optional'] = \
|
response.headers['x-args-optional'] = \
|
||||||
','.join(self.request.scope['args']['optional'])
|
','.join(request.scope['args']['optional'])
|
||||||
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
|
@ -10,14 +10,15 @@ Constant :
|
||||||
SCHEMAS (starlette.schemas.SchemaGenerator)
|
SCHEMAS (starlette.schemas.SchemaGenerator)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict
|
from typing import Dict, Coroutine
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
from starlette.schemas import SchemaGenerator
|
from starlette.schemas import SchemaGenerator
|
||||||
from starlette.exceptions import HTTPException
|
|
||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .routes import gen_starlette_routes, api_acls
|
from .routes import gen_starlette_routes, api_routes, api_acls
|
||||||
from .responses import ORJSONResponse
|
from .responses import ORJSONResponse
|
||||||
|
|
||||||
logger = logging.getLogger('uvicorn.asgi')
|
logger = logging.getLogger('uvicorn.asgi')
|
||||||
|
@ -25,7 +26,7 @@ SCHEMAS = SchemaGenerator(
|
||||||
{"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": __version__}}
|
{"openapi": "3.0.0", "info": {"title": "HalfAPI", "version": __version__}}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_api_routes(request, *args, **kwargs):
|
def get_api_routes(domains: Dict[str, ModuleType]) -> Coroutine:
|
||||||
"""
|
"""
|
||||||
description: Returns the current API routes dictionary
|
description: Returns the current API routes dictionary
|
||||||
as a JSON object
|
as a JSON object
|
||||||
|
@ -63,18 +64,26 @@ async def get_api_routes(request, *args, **kwargs):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
return ORJSONResponse(request.scope['api'])
|
routes = {
|
||||||
|
domain: api_routes(m_domain)[0]
|
||||||
|
for domain, m_domain in domains.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def wrapped(request, *args, **kwargs):
|
||||||
|
return ORJSONResponse(routes)
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def get_api_domain_routes(m_domain: ModuleType) -> Coroutine:
|
||||||
|
routes, _ = api_routes(m_domain)
|
||||||
|
|
||||||
def get_api_domain_routes(domain):
|
|
||||||
async def wrapped(request, *args, **kwargs):
|
async def wrapped(request, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
description: Returns the current API routes dictionary for a specific
|
description: Returns the current API routes dictionary for a specific
|
||||||
domain as a JSON object
|
domain as a JSON object
|
||||||
"""
|
"""
|
||||||
if domain in request.scope['api']:
|
return ORJSONResponse(routes)
|
||||||
return ORJSONResponse(request.scope['api'][domain])
|
|
||||||
else:
|
|
||||||
raise HTTPException(404)
|
|
||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -44,7 +44,7 @@ setup(
|
||||||
python_requires=">=3.8",
|
python_requires=">=3.8",
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"PyJWT>=2.0.1",
|
"PyJWT>=2.0.1",
|
||||||
"starlette>=0.14,<0.15",
|
"starlette>=0.15,<0.16",
|
||||||
"click>=7.1,<8",
|
"click>=7.1,<8",
|
||||||
"uvicorn>=0.13,<1",
|
"uvicorn>=0.13,<1",
|
||||||
"orjson>=3.4.7,<4",
|
"orjson>=3.4.7,<4",
|
||||||
|
|
Loading…
Reference in New Issue