from collections.abc import Awaitable, Callable, Collection, Iterable

from aiohttp import web
from aiohttp.typedefs import Middleware
from aiohttp.web_middlewares import middleware
from aiohttp.web_request import Request
from aiohttp.web_response import StreamResponse

Handler = Callable[[Request], Awaitable[StreamResponse]]


def cors(
    *,
    allow_headers: Iterable[str],
    allow_origins: Collection[str],
    expose_headers: Iterable[str],
    allow_methods: Iterable[str] = ("OPTIONS", "POST"),
) -> Middleware:
    """Build an aiohttp middleware that applies a simple CORS policy.

    Behavior of the returned middleware, per request:

    * No ``Origin`` header (or an empty one): the request is passed to the
      handler and its response is returned unchanged.
    * ``Origin`` not contained in ``allow_origins``: a 403 response with the
      text ``"CORS origin is not allowed"`` is returned; the handler is not
      called.
    * ``OPTIONS`` request carrying ``Access-Control-Request-Method``
      (a preflight): an empty ``StreamResponse`` is returned without calling
      the handler.
    * Anything else: the handler is called and its response is decorated.

    Decorated responses always get ``Access-Control-Allow-Origin`` set to the
    request's origin, ``Access-Control-Expose-Headers`` when at least one
    exposed header was configured, and, for every ``OPTIONS`` request
    (preflight or not), ``Access-Control-Allow-Headers`` and
    ``Access-Control-Allow-Methods``.

    Note: only responses *returned* by the handler are decorated; an
    exception raised by the handler propagates untouched.

    Args:
        allow_headers: Header names advertised in
            ``Access-Control-Allow-Headers``.
        allow_origins: Origins that may make cross-origin requests; checked
            with ``in`` on every request, so any container semantics the
            caller supplies are honored.
        expose_headers: Header names advertised in
            ``Access-Control-Expose-Headers``; the header is omitted when
            this is empty.
        allow_methods: Method names advertised in
            ``Access-Control-Allow-Methods``. Defaults to
            ``("OPTIONS", "POST")``.

    Returns:
        The configured middleware.
    """
    # Materialize the header values once, up front. The parameters are typed
    # as Iterable, so a caller may pass a one-shot iterator; joining it inside
    # the request path would yield the right value only for the first request
    # and an empty string afterwards. The values are request-invariant, so
    # this also avoids rebuilding the same strings on every request.
    allow_headers_value = ", ".join(allow_headers)
    allow_methods_value = ", ".join(allow_methods)
    exposed_names = tuple(expose_headers)
    expose_headers_value = ", ".join(exposed_names)

    @middleware
    async def impl(request: Request, handler: Handler) -> StreamResponse:
        origin = request.headers.get("Origin")
        if not origin:
            # Not a cross-origin request: stay out of the way entirely.
            return await handler(request)

        if origin not in allow_origins:
            return web.Response(status=403, text="CORS origin is not allowed")

        is_options = request.method == "OPTIONS"
        is_preflight = (
            is_options and "Access-Control-Request-Method" in request.headers
        )

        resp: StreamResponse
        if is_preflight:
            # Preflights are answered here; the application handler is
            # never invoked for them.
            resp = StreamResponse()
        else:
            resp = await handler(request)

        resp.headers["Access-Control-Allow-Origin"] = origin
        if exposed_names:
            resp.headers["Access-Control-Expose-Headers"] = expose_headers_value
        if is_options:
            # Sent for any OPTIONS request, not only for preflights.
            resp.headers["Access-Control-Allow-Headers"] = allow_headers_value
            resp.headers["Access-Control-Allow-Methods"] = allow_methods_value
        return resp

    return impl