170 lines
6.0 KiB
Python
170 lines
6.0 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from collections.abc import Awaitable
|
|
from typing import Callable
|
|
from urllib.parse import SplitResult
|
|
from urllib.parse import urlsplit
|
|
|
|
from asgiref.sync import iscoroutinefunction
|
|
from asgiref.sync import markcoroutinefunction
|
|
from django.http import HttpRequest
|
|
from django.http import HttpResponse
|
|
from django.http.response import HttpResponseBase
|
|
from django.utils.cache import patch_vary_headers
|
|
|
|
from corsheaders.conf import conf
|
|
from corsheaders.signals import check_request_enabled
|
|
|
|
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
|
|
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
|
|
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
|
|
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
|
|
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
|
|
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
|
|
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
|
|
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
|
|
|
|
|
|
class CorsMiddleware:
|
|
sync_capable = True
|
|
async_capable = True
|
|
|
|
def __init__(
|
|
self,
|
|
get_response: (
|
|
Callable[[HttpRequest], HttpResponseBase]
|
|
| Callable[[HttpRequest], Awaitable[HttpResponseBase]]
|
|
),
|
|
) -> None:
|
|
self.get_response = get_response
|
|
self.async_mode = iscoroutinefunction(self.get_response)
|
|
|
|
if self.async_mode:
|
|
# Mark the class as async-capable, but do the actual switch
|
|
|
|
# inside __call__ to avoid swapping out dunder methods
|
|
markcoroutinefunction(self)
|
|
|
|
def __call__(
|
|
self, request: HttpRequest
|
|
) -> HttpResponseBase | Awaitable[HttpResponseBase]:
|
|
if self.async_mode:
|
|
return self.__acall__(request)
|
|
response: HttpResponseBase | None = self.check_preflight(request)
|
|
if response is None:
|
|
result = self.get_response(request)
|
|
assert isinstance(result, HttpResponseBase)
|
|
response = result
|
|
self.add_response_headers(request, response)
|
|
return response
|
|
|
|
async def __acall__(self, request: HttpRequest) -> HttpResponseBase:
|
|
response = self.check_preflight(request)
|
|
if response is None:
|
|
result = self.get_response(request)
|
|
assert not isinstance(result, HttpResponseBase)
|
|
response = await result
|
|
self.add_response_headers(request, response)
|
|
return response
|
|
|
|
def check_preflight(self, request: HttpRequest) -> HttpResponseBase | None:
|
|
"""
|
|
Generate a response for CORS preflight requests.
|
|
"""
|
|
request._cors_enabled = self.is_enabled(request) # type: ignore [attr-defined]
|
|
if (
|
|
request._cors_enabled # type: ignore [attr-defined]
|
|
and request.method == "OPTIONS"
|
|
and "access-control-request-method" in request.headers
|
|
):
|
|
return HttpResponse(headers={"content-length": "0"})
|
|
return None
|
|
|
|
def add_response_headers(
|
|
self, request: HttpRequest, response: HttpResponseBase
|
|
) -> HttpResponseBase:
|
|
"""
|
|
Add the respective CORS headers
|
|
"""
|
|
enabled = getattr(request, "_cors_enabled", None)
|
|
if enabled is None:
|
|
enabled = self.is_enabled(request)
|
|
|
|
if not enabled:
|
|
return response
|
|
|
|
patch_vary_headers(response, ("origin",))
|
|
|
|
origin = request.headers.get("origin")
|
|
if not origin:
|
|
return response
|
|
|
|
try:
|
|
url = urlsplit(origin)
|
|
except ValueError:
|
|
return response
|
|
|
|
if (
|
|
not conf.CORS_ALLOW_ALL_ORIGINS
|
|
and not self.origin_found_in_white_lists(origin, url)
|
|
and not self.check_signal(request)
|
|
):
|
|
return response
|
|
|
|
if conf.CORS_ALLOW_ALL_ORIGINS and not conf.CORS_ALLOW_CREDENTIALS:
|
|
response[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
|
|
else:
|
|
response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
|
|
|
|
if conf.CORS_ALLOW_CREDENTIALS:
|
|
response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
|
|
|
|
if len(conf.CORS_EXPOSE_HEADERS):
|
|
response[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
|
|
conf.CORS_EXPOSE_HEADERS
|
|
)
|
|
|
|
if request.method == "OPTIONS":
|
|
response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(conf.CORS_ALLOW_HEADERS)
|
|
response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(conf.CORS_ALLOW_METHODS)
|
|
if conf.CORS_PREFLIGHT_MAX_AGE:
|
|
response[ACCESS_CONTROL_MAX_AGE] = str(conf.CORS_PREFLIGHT_MAX_AGE)
|
|
|
|
if (
|
|
conf.CORS_ALLOW_PRIVATE_NETWORK
|
|
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
|
|
):
|
|
response[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
|
|
|
|
return response
|
|
|
|
def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool:
|
|
return (
|
|
(origin == "null" and origin in conf.CORS_ALLOWED_ORIGINS)
|
|
or self._url_in_whitelist(url)
|
|
or self.regex_domain_match(origin)
|
|
)
|
|
|
|
def regex_domain_match(self, origin: str) -> bool:
|
|
return any(
|
|
re.match(domain_pattern, origin)
|
|
for domain_pattern in conf.CORS_ALLOWED_ORIGIN_REGEXES
|
|
)
|
|
|
|
def is_enabled(self, request: HttpRequest) -> bool:
|
|
return bool(
|
|
re.match(conf.CORS_URLS_REGEX, request.path_info)
|
|
) or self.check_signal(request)
|
|
|
|
def check_signal(self, request: HttpRequest) -> bool:
|
|
signal_responses = check_request_enabled.send(sender=None, request=request)
|
|
return any(return_value for function, return_value in signal_responses)
|
|
|
|
def _url_in_whitelist(self, url: SplitResult) -> bool:
|
|
origins = [urlsplit(o) for o in conf.CORS_ALLOWED_ORIGINS]
|
|
return any(
|
|
origin.scheme == url.scheme and origin.netloc == url.netloc
|
|
for origin in origins
|
|
)
|