halfapi/halfapi/lib/domain_middleware.py

101 lines
3.1 KiB
Python

"""
DomainMiddleware
"""
import configparser
import logging
from starlette.datastructures import URL
from starlette.middleware.base import (BaseHTTPMiddleware,
RequestResponseEndpoint)
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Scope, Send, Receive
from .routes import api_routes
from .domain import d_domains
# Module-level logger registered under the 'uvicorn.asgi' name.
# NOTE(review): presumably chosen so records go through uvicorn's logging
# configuration -- confirm against the deployment's logging setup.
logger = logging.getLogger('uvicorn.asgi')
class DomainMiddleware(BaseHTTPMiddleware):
    """
    DomainMiddleware adds the api routes and acls to the following scope keys :
    - domains
    - api
    - acl
    - config

    The keys are written to a copy of the incoming ASGI scope; that copy is
    the scope handed to the downstream application through the Request.
    """

    def __init__(self, app, config):
        """
        Store the wrapped application and the configuration.

        Parameters:
            app: the ASGI application wrapped by this middleware
            config: configuration object; it must provide ``get`` (used to
                read the ``'domains'`` entry) and ``copy``
        """
        super().__init__(app)
        self.config = config
        self.domains = {}
        self.api = {}
        self.acl = {}
        # Kept for backward compatibility with code reading this attribute.
        # It is overwritten by every request, so it must not be relied upon
        # while several requests are in flight; dispatch() does not use it.
        self.request = None

    @staticmethod
    def _first_path_segment(path: str) -> str:
        """
        Return the first segment of *path*, ignoring one leading slash.

        An empty path or ``'/'`` yields an empty string instead of raising.
        """
        if path.startswith('/'):
            path = path[1:]
        return path.split('/')[0]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Scans routes and acls of the domain in the first part of the path
        and forwards the request with the enriched scope.

        Non-HTTP scopes (lifespan, websocket) are passed to the wrapped
        application untouched: they either carry no ``path`` key or cannot be
        wrapped in a Request.
        """
        if scope['type'] != 'http':
            await self.app(scope, receive, send)
            return

        domain = self._first_path_segment(scope['path'])

        self.domains = self.config.get('domains', {})

        if not domain:
            # No domain in the path: scan the routes of every known domain.
            for name in self.domains:
                self.api[name], self.acl[name] = api_routes(self.domains[name])
        elif domain in self.domains:
            self.api[domain], self.acl[domain] = api_routes(self.domains[domain])
        else:
            logger.error('domain not in self.domains %s / %s',
                scope['path'],
                self.domains)

        # Work on a copy so the caller's scope mapping is left unmodified.
        scope_ = scope.copy()
        scope_['domains'] = self.domains
        scope_['api'] = self.api
        scope_['acl'] = self.acl

        # Only used in the debug message below.
        current_domain = self._first_path_segment(URL(scope=scope).path)

        try:
            scope_['config'] = self.config.copy()
        except configparser.NoSectionError:
            logger.debug(
                'No specific configuration for domain **%s**', current_domain)
            scope_['config'] = {}

        request = Request(scope_, receive)
        self.request = request

        # NOTE(review): relies on the base class exposing ``call_next`` as a
        # method -- confirm against the installed Starlette version.
        response = await self.dispatch(request, self.call_next)
        await response(scope_, receive, send)

    async def dispatch(self, request: Request,
        call_next: RequestResponseEndpoint) -> Response:
        """
        Call of the route function (decorated or not), then copy the acl and
        argument information found in the request scope to response headers.

        The scope is read from *request* (not from instance state) so that
        concurrent requests cannot see each other's values.

        Parameters:
            request: the request built from the enriched scope
            call_next: callable forwarding the request to the next handler

        Returns:
            The downstream response, with the headers ``x-acl``,
            ``x-args-required`` and ``x-args-optional`` set when the
            matching keys are present in the request scope.
        """
        response = await call_next(request)
        req_scope = request.scope

        if 'acl_pass' in req_scope:
            # Set the http header "x-acl" if an acl was used on the route
            response.headers['x-acl'] = req_scope['acl_pass']

        if 'args' in req_scope:
            # Set the http headers "x-args-required" and "x-args-optional"
            args = req_scope['args']

            if 'required' in args:
                response.headers['x-args-required'] = ','.join(args['required'])
            if 'optional' in args:
                response.headers['x-args-optional'] = ','.join(args['optional'])

        return response