diff --git a/halfapi/half_domain.py b/halfapi/half_domain.py index ee97026..19c09c1 100644 --- a/halfapi/half_domain.py +++ b/halfapi/half_domain.py @@ -13,6 +13,8 @@ from schema import SchemaError from starlette.applications import Starlette from starlette.routing import Router, Route +from .lib.acl import AclRoute + import yaml @@ -156,15 +158,12 @@ class HalfDomain(Starlette): @staticmethod def acls_router(domain, module_path=None, acl=None): - """ Router of the acls routes : + """ Returns a Router object with the following routes : - / : Same result as HalfDomain.acls_route - /acl_name : Dummy route protected by the "acl_name" acl + / : The "acls" field of the API metadatas + /{acl_name} : If the ACL is defined as public, a route that returns either status code 200 or 401 on HEAD/GET request """ - async def dummy_endpoint(request, *args, **kwargs): - return PlainTextResponse('') - routes = [] d_res = {} @@ -187,14 +186,7 @@ class HalfDomain(Starlette): if elt.public: routes.append( - Route( - f'/{elt.name}', - HalfRoute.acl_decorator( - dummy_endpoint, - params=[{'acl': fct}] - ), - methods=['GET'] - ) + AclRoute(f'/{elt.name}', fct, elt) ) d_res_under_domain_name = {} diff --git a/halfapi/lib/acl.py b/halfapi/lib/acl.py index 2542851..9f35fdc 100644 --- a/halfapi/lib/acl.py +++ b/halfapi/lib/acl.py @@ -5,8 +5,11 @@ Base ACL module that contains generic functions for domains ACL from dataclasses import dataclass from functools import wraps from json import JSONDecodeError +import yaml from starlette.authentication import UnauthenticatedUser from starlette.exceptions import HTTPException +from starlette.routing import Route +from starlette.responses import Response from ..logging import logger @@ -140,3 +143,36 @@ class ACL(): documentation: str priority: int public: bool = False + + +class AclRoute(Route): + def __init__(self, path, acl_fct, acl: ACL): + self.acl_fct = acl_fct + self.name = acl.name + self.description = acl.documentation + + self.docstring = yaml.dump({ + 'description': f'{self.name}: {self.description}', + 'responses': { + '200': { + 'description': 'ACL OK' + }, + '401': { + 'description': 'ACL FAIL' + } + } + }) + + async def endpoint(request, *args, **kwargs): + if request.method == 'GET': + logger.warning('Deprecated since 0.6.28, use HEAD method since now') + + if self.acl_fct(request, *args, **kwargs) is True: + return Response(status_code=200) + + return Response(status_code=401) + + endpoint.__doc__ = self.docstring + + return super().__init__(path, methods=['HEAD', 'GET'], endpoint=endpoint) +