diff --git a/halfapi/lib/acl.py b/halfapi/lib/acl.py index 0e6e91d..21a5fe9 100644 --- a/halfapi/lib/acl.py +++ b/halfapi/lib/acl.py @@ -2,10 +2,14 @@ """ Base ACL module that contains generic functions for domains ACL """ +import logging from functools import wraps from starlette.authentication import UnauthenticatedUser +from json import JSONDecodeError + +logger = logging.getLogger('uvicorn.asgi') def public(*args, **kwargs) -> bool: "Unlimited access" @@ -24,3 +28,51 @@ def connected(fct=public): return fct(req, **{**kwargs, **req.path_params}) return caller + +def args_check(fct): + @wraps(fct) + async def caller(req, *args, **kwargs): + if 'check' in req.query_params: + return await fct(req, *args, **kwargs) + + if req.method == 'GET': + data_ = req.query_params + + if req.method == 'POST': + try: + data_ = await req.json() + except JSONDecodeError as exc: + data_ = {} + + def plural(array: list) -> str: + return len(array) > 1 and 's' or '' + def comma_list(array: list) -> str: + return ', '.join(array) + + + args_d = kwargs.get('args', {}) + required = args_d.get('required', set()) + + missing = [] + data = {} + + for key in required: + data[key] = data_.pop(key, None) + if data[key] is None: + missing.append(key) + + if missing: + raise HTTPException( + 400, + f"Missing value{plural(missing)} for: {comma_list(missing)}!") + + optional = args_d.get('optional', set()) + for key in optional: + if key in data_: + data[key] = data_[key] + + kwargs['data'] = data + + return await fct(req, *args, **kwargs) + + return caller