# File: daren_project/venv/Lib/site-packages/channels/security/websocket.py
# Listing metadata from the original page header: 154 lines, 5.7 KiB, Python.

from urllib.parse import urlparse
from django.conf import settings
from django.http.request import is_same_domain
from ..generic.websocket import AsyncWebsocketConsumer
class OriginValidator:
    """
    Validates that the incoming connection has an Origin header that
    is in an allowed list.

    Wraps another ASGI application: connections whose Origin header matches
    one of ``allowed_origins`` are forwarded to ``application``; all others
    are handed to ``WebsocketDenier``.
    """

    def __init__(self, application, allowed_origins):
        """
        Store the wrapped application and the allowed origin patterns.

        ``allowed_origins`` is an iterable of pattern strings; the special
        pattern ``*`` allows every origin (including a missing one).
        """
        self.application = application
        self.allowed_origins = allowed_origins

    async def __call__(self, scope, receive, send):
        """
        Check the Origin header of a WebSocket scope and dispatch accordingly.

        Raises ``ValueError`` if the scope is not of type ``websocket``.
        Returns whatever the chosen application (the wrapped one or the
        denier) returns.
        """
        # Make sure the scope is of type websocket
        if scope["type"] != "websocket":
            raise ValueError(
                "You cannot use OriginValidator on a non-WebSocket connection"
            )
        # Extract the Origin header; if it is repeated, the last one wins.
        parsed_origin = None
        for header_name, header_value in scope.get("headers", []):
            if header_name == b"origin":
                try:
                    parsed_origin = urlparse(header_value.decode("latin1"))
                except ValueError:
                    # latin1 decoding cannot fail, but urlparse() raises
                    # ValueError for malformed values (for example unmatched
                    # square brackets). UnicodeDecodeError is a ValueError
                    # subclass, so it is still covered. A malformed header is
                    # treated exactly like a missing one, and it must not
                    # leave an earlier header's value in place.
                    parsed_origin = None
        # Check to see if the origin header is valid
        if self.valid_origin(parsed_origin):
            # Pass control to the application
            return await self.application(scope, receive, send)
        # Deny the connection
        denier = WebsocketDenier()
        return await denier(scope, receive, send)

    def valid_origin(self, parsed_origin):
        """
        Reject a missing origin unless every origin is allowed, otherwise
        delegate to ``validate_origin``.

        Returns ``True`` if validation was successful, ``False`` otherwise.
        """
        # None is not allowed unless all hosts are allowed
        if parsed_origin is None and "*" not in self.allowed_origins:
            return False
        return self.validate_origin(parsed_origin)

    def validate_origin(self, parsed_origin):
        """
        Validate the given origin for this site.

        Check that the origin looks valid and matches one of the patterns in
        ``allowed_origins``. A pattern may begin with a scheme, followed by a
        domain. Any domain beginning with a period corresponds to the domain
        and all its subdomains (for example, ``http://.example.com``). After
        the domain there may be a port, but it can be omitted. ``*`` matches
        anything and anything else must match exactly.

        ``parsed_origin`` is the result of ``urlparse`` on the Origin header,
        or ``None``.

        Returns ``True`` for a valid origin, ``False`` otherwise.
        """
        return any(
            pattern == "*" or self.match_allowed_origin(parsed_origin, pattern)
            for pattern in self.allowed_origins
        )

    def match_allowed_origin(self, parsed_origin, pattern):
        """
        Return ``True`` if the origin matches ``pattern``, ``False`` otherwise.

        A pattern is either a bare domain or a scheme followed by a domain.
        Any domain beginning with a period corresponds to the domain and all
        its subdomains (for example, ``.example.com`` matches ``example.com``
        and any subdomain; likewise ``http://.example.com`` with a scheme).
        After the domain there may be a port, but it can be omitted.

        For a pattern without a scheme only the hostname is compared. For a
        pattern with a scheme, the scheme, the port (explicit or the scheme's
        default) and the hostname must all match.

        ``parsed_origin`` is the result of ``urlparse`` on the Origin header,
        or ``None``; ``None`` and origins without a hostname never match.
        """
        if parsed_origin is None:
            return False

        parsed_pattern = urlparse(pattern.lower())
        if parsed_origin.hostname is None:
            return False
        if not parsed_pattern.scheme:
            # Scheme-less pattern: prefix "//" so urlparse treats it as a
            # network location, and fall back to the raw pattern otherwise.
            pattern_hostname = urlparse("//" + pattern).hostname or pattern
            return is_same_domain(parsed_origin.hostname, pattern_hostname)

        # The origin comes from the client, and reading its port raises
        # ValueError when the port is non-numeric or out of range. Such an
        # origin cannot match any pattern, so reject it instead of crashing.
        try:
            origin_port = self.get_origin_port(parsed_origin)
        except ValueError:
            return False
        # Get pattern.port or default ports for pattern or None
        pattern_port = self.get_origin_port(parsed_pattern)
        # Compares hostname, scheme, ports of pattern and origin
        if (
            parsed_pattern.scheme == parsed_origin.scheme
            and origin_port == pattern_port
            and is_same_domain(parsed_origin.hostname, parsed_pattern.hostname)
        ):
            return True
        return False

    def get_origin_port(self, origin):
        """
        Return ``origin.port``, or the default port for its scheme.

        The defaults are 80 for ``http``/``ws`` and 443 for ``https``/``wss``;
        for any other scheme without an explicit port, ``None`` is returned.
        Reading ``origin.port`` may raise ``ValueError`` when the port in the
        parsed URL is invalid.
        """
        if origin.port is not None:
            return origin.port
        # No explicit port: fall back to the scheme's default.
        if origin.scheme in ("http", "ws"):
            return 80
        elif origin.scheme in ("https", "wss"):
            return 443
        else:
            return None
def AllowedHostsOriginValidator(application):
    """
    Build an OriginValidator around ``application`` whose allowed origins
    come from ``settings.ALLOWED_HOSTS``.

    When ``settings.DEBUG`` is truthy and ``ALLOWED_HOSTS`` is empty, the
    local hosts ``localhost``, ``127.0.0.1`` and ``[::1]`` are used instead.
    """
    configured_hosts = settings.ALLOWED_HOSTS
    use_debug_fallback = settings.DEBUG and not configured_hosts
    origins = (
        ["localhost", "127.0.0.1", "[::1]"]
        if use_debug_fallback
        else configured_hosts
    )
    return OriginValidator(application, origins)
class WebsocketDenier(AsyncWebsocketConsumer):
    """
    Consumer that refuses every connection it is asked to handle.
    """

    async def connect(self):
        # Close instead of accepting, so the connection is never established.
        await self.close()