add request.scope['config'] when there is a section with the domains name in the project's configuration

This commit is contained in:
Maxime Alves LIRMM@home 2020-11-04 05:01:26 +01:00
parent 61aec6871a
commit 4782764059
3 changed files with 19 additions and 3 deletions

View File

@ -16,12 +16,9 @@ def connected(fct=public):
""" """
@wraps(fct) @wraps(fct)
def caller(req, *args, **kwargs): def caller(req, *args, **kwargs):
print(fct)
print(req.user)
if (not hasattr(req, 'user') if (not hasattr(req, 'user')
or isinstance(req.user, UnauthenticatedUser) or isinstance(req.user, UnauthenticatedUser)
or not hasattr(req.user, 'is_authenticated')): or not hasattr(req.user, 'is_authenticated')):
print('Connected is false')
return False return False
return fct(req, **{**kwargs, **req.path_params}) return fct(req, **{**kwargs, **req.path_params})

View File

@ -158,11 +158,13 @@ def gen_domain_routes(domain: str, m_dom: ModuleType) -> Generator:
If not, it is considered as empty If not, it is considered as empty
""" """
m_router = None
try: try:
m_router = importlib.import_module('.routers', domain) m_router = importlib.import_module('.routers', domain)
except ImportError: except ImportError:
logger.warning('Domain **%s** has no **routers** module', domain) logger.warning('Domain **%s** has no **routers** module', domain)
logger.debug('%s', m_dom) logger.debug('%s', m_dom)
m_router = importlib.import_module('.routers', f'.{domain}')
if m_router: if m_router:
yield from gen_router_routes(m_router, [domain]) yield from gen_router_routes(m_router, [domain])

View File

@ -1,7 +1,9 @@
""" """
DomainMiddleware DomainMiddleware
""" """
import logging
from starlette.datastructures import URL
from starlette.middleware.base import (BaseHTTPMiddleware, from starlette.middleware.base import (BaseHTTPMiddleware,
RequestResponseEndpoint) RequestResponseEndpoint)
from starlette.requests import Request from starlette.requests import Request
@ -10,6 +12,9 @@ from starlette.types import Scope, Send, Receive
from .routes import api_routes from .routes import api_routes
from .domain import d_domains from .domain import d_domains
from ..conf import config_dict
logger = logging.getLogger('uvicorn.asgi')
class DomainMiddleware(BaseHTTPMiddleware): class DomainMiddleware(BaseHTTPMiddleware):
""" """
@ -41,6 +46,18 @@ class DomainMiddleware(BaseHTTPMiddleware):
scope_['domains'] = self.domains scope_['domains'] = self.domains
scope_['api'] = self.api scope_['api'] = self.api
scope_['acl'] = self.acl scope_['acl'] = self.acl
cur_path = URL(scope=scope).path
if cur_path[0] == '/':
current_domain = cur_path[1:].split('/')[0]
else:
current_domain = cur_path.split('/')[0]
if len(current_domain):
config_section = self.config.items(current_domain)
scope_['config'] = dict(config_section)
request = Request(scope_, receive) request = Request(scope_, receive)
response = await self.dispatch(request, self.call_next) response = await self.dispatch(request, self.call_next)
await response(scope_, receive, send) await response(scope_, receive, send)