[responses] use a wrapper function for exception handling (fix starlette 0.20)

This commit is contained in:
Maxime Alves LIRMM 2023-01-14 10:26:31 +01:00
parent d0876e45da
commit 717d3f8bd6
2 changed files with 14 additions and 6 deletions

View File

@ -35,7 +35,7 @@ from .lib.timing import HTimingClient
from .lib.jwt_middleware import JWTAuthenticationBackend from .lib.jwt_middleware import JWTAuthenticationBackend
from .lib.responses import (ORJSONResponse, UnauthorizedResponse, from .lib.responses import (ORJSONResponse, UnauthorizedResponse,
NotFoundResponse, InternalServerErrorResponse, NotImplementedResponse, NotFoundResponse, InternalServerErrorResponse, NotImplementedResponse,
ServiceUnavailableResponse) ServiceUnavailableResponse, gen_exception_route)
from .lib.domain import NoDomainsException from .lib.domain import NoDomainsException
from .lib.routes import gen_schema_routes, JSONRoute from .lib.routes import gen_schema_routes, JSONRoute
from .lib.schemas import schema_json from .lib.schemas import schema_json
@ -90,11 +90,11 @@ class HalfAPI(Starlette):
debug=not PRODUCTION, debug=not PRODUCTION,
routes=routes, routes=routes,
exception_handlers={ exception_handlers={
401: UnauthorizedResponse, 401: gen_exception_route(UnauthorizedResponse),
404: NotFoundResponse, 404: gen_exception_route(NotFoundResponse),
500: HalfAPI.exception, 500: gen_exception_route(HalfAPI.exception),
501: NotImplementedResponse, 501: gen_exception_route(NotImplementedResponse),
503: ServiceUnavailableResponse 503: gen_exception_route(ServiceUnavailableResponse)
}, },
on_startup=startup_fcts on_startup=startup_fcts
) )

View File

@ -24,6 +24,8 @@ import orjson
# asgi framework # asgi framework
from starlette.responses import PlainTextResponse, Response, JSONResponse, HTMLResponse from starlette.responses import PlainTextResponse, Response, JSONResponse, HTMLResponse
from starlette.requests import Request
from starlette.exceptions import HTTPException
from .user import JWTUser, Nobody from .user import JWTUser, Nobody
from ..logging import logger from ..logging import logger
@ -157,3 +159,9 @@ class ODSResponse(Response):
class XLSXResponse(ODSResponse): class XLSXResponse(ODSResponse):
file_type = 'xlsx' file_type = 'xlsx'
def gen_exception_route(response_cls):
async def exception_route(req: Request, exc: HTTPException):
return response_cls()
return exception_route