role_based_system/venv/Lib/site-packages/channels/layers.py

366 lines
12 KiB
Python
Raw Normal View History

import asyncio
import fnmatch
import random
import re
import string
import time
from copy import deepcopy
from django.conf import settings
from django.core.signals import setting_changed
from django.utils.module_loading import import_string
from channels import DEFAULT_CHANNEL_LAYER
from .exceptions import ChannelFull, InvalidChannelLayerError
class ChannelLayerManager:
"""
Takes a settings dictionary of backends and initialises them on request.
"""
def __init__(self):
self.backends = {}
setting_changed.connect(self._reset_backends)
def _reset_backends(self, setting, **kwargs):
"""
Removes cached channel layers when the CHANNEL_LAYERS setting changes.
"""
if setting == "CHANNEL_LAYERS":
self.backends = {}
@property
def configs(self):
# Lazy load settings so we can be imported
return getattr(settings, "CHANNEL_LAYERS", {})
def make_backend(self, name):
"""
Instantiate channel layer.
"""
config = self.configs[name].get("CONFIG", {})
return self._make_backend(name, config)
def make_test_backend(self, name):
"""
Instantiate channel layer using its test config.
"""
try:
config = self.configs[name]["TEST_CONFIG"]
except KeyError:
raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name)
return self._make_backend(name, config)
def _make_backend(self, name, config):
# Check for old format config
if "ROUTING" in self.configs[name]:
raise InvalidChannelLayerError(
"ROUTING key found for %s - this is no longer needed in Channels 2."
% name
)
# Load the backend class
try:
backend_class = import_string(self.configs[name]["BACKEND"])
except KeyError:
raise InvalidChannelLayerError("No BACKEND specified for %s" % name)
except ImportError:
raise InvalidChannelLayerError(
"Cannot import BACKEND %r specified for %s"
% (self.configs[name]["BACKEND"], name)
)
# Initialise and pass config
return backend_class(**config)
def __getitem__(self, key):
if key not in self.backends:
self.backends[key] = self.make_backend(key)
return self.backends[key]
def __contains__(self, key):
return key in self.configs
def set(self, key, layer):
"""
Sets an alias to point to a new ChannelLayerWrapper instance, and
returns the old one that it replaced. Useful for swapping out the
backend during tests.
"""
old = self.backends.get(key, None)
self.backends[key] = layer
return old
class BaseChannelLayer:
"""
Base channel layer class that others can inherit from, with useful
common functionality.
"""
def __init__(self, expiry=60, capacity=100, channel_capacity=None):
self.expiry = expiry
self.capacity = capacity
self.channel_capacity = channel_capacity or {}
def compile_capacities(self, channel_capacity):
"""
Takes an input channel_capacity dict and returns the compiled list
of regexes that get_capacity will look for as self.channel_capacity
"""
result = []
for pattern, value in channel_capacity.items():
# If they passed in a precompiled regex, leave it, else interpret
# it as a glob.
if hasattr(pattern, "match"):
result.append((pattern, value))
else:
result.append((re.compile(fnmatch.translate(pattern)), value))
return result
def get_capacity(self, channel):
"""
Gets the correct capacity for the given channel; either the default,
or a matching result from channel_capacity. Returns the first matching
result; if you want to control the order of matches, use an ordered dict
as input.
"""
for pattern, capacity in self.channel_capacity:
if pattern.match(channel):
return capacity
return self.capacity
def match_type_and_length(self, name):
if isinstance(name, str) and (len(name) < 100):
return True
return False
# Name validation functions
channel_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+(\![\d\w\-_.]*)?$")
group_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+$")
invalid_name_error = (
"{} name must be a valid unicode string containing only ASCII "
+ "alphanumerics, hyphens, underscores, or periods."
)
def valid_channel_name(self, name, receive=False):
if self.match_type_and_length(name):
if bool(self.channel_name_regex.match(name)):
# Check cases for special channels
if "!" in name and not name.endswith("!") and receive:
raise TypeError(
"Specific channel names in receive() must end at the !"
)
return True
raise TypeError(
"Channel name must be a valid unicode string containing only ASCII "
+ "alphanumerics, hyphens, or periods, not '{}'.".format(name)
)
def valid_group_name(self, name):
if self.match_type_and_length(name):
if bool(self.group_name_regex.match(name)):
return True
raise TypeError(
"Group name must be a valid unicode string containing only ASCII "
+ "alphanumerics, hyphens, or periods."
)
def valid_channel_names(self, names, receive=False):
_non_empty_list = True if names else False
_names_type = isinstance(names, list)
assert _non_empty_list and _names_type, "names must be a non-empty list"
assert all(
self.valid_channel_name(channel, receive=receive) for channel in names
)
return True
def non_local_name(self, name):
"""
Given a channel name, returns the "non-local" part. If the channel name
is a process-specific channel (contains !) this means the part up to
and including the !; if it is anything else, this means the full name.
"""
if "!" in name:
return name[: name.find("!") + 1]
else:
return name
class InMemoryChannelLayer(BaseChannelLayer):
"""
In-memory channel layer implementation
"""
def __init__(
self,
expiry=60,
group_expiry=86400,
capacity=100,
channel_capacity=None,
**kwargs
):
super().__init__(
expiry=expiry,
capacity=capacity,
channel_capacity=channel_capacity,
**kwargs
)
self.channels = {}
self.groups = {}
self.group_expiry = group_expiry
# Channel layer API
extensions = ["groups", "flush"]
async def send(self, channel, message):
"""
Send a message onto a (general or specific) channel.
"""
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
# If it's a process-local channel, strip off local part and stick full
# name in message
assert "__asgi_channel__" not in message
queue = self.channels.setdefault(channel, asyncio.Queue())
# Are we full
if queue.qsize() >= self.capacity:
raise ChannelFull(channel)
# Add message
await queue.put((time.time() + self.expiry, deepcopy(message)))
async def receive(self, channel):
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
assert self.valid_channel_name(channel)
self._clean_expired()
queue = self.channels.setdefault(channel, asyncio.Queue())
# Do a plain direct receive
try:
_, message = await queue.get()
finally:
if queue.empty():
del self.channels[channel]
return message
async def new_channel(self, prefix="specific."):
"""
Returns a new channel name that can be used by something in our
process as a specific channel.
"""
return "%s.inmemory!%s" % (
prefix,
"".join(random.choice(string.ascii_letters) for i in range(12)),
)
# Expire cleanup
def _clean_expired(self):
"""
Goes through all messages and groups and removes those that are expired.
Any channel with an expired message is removed from all groups.
"""
# Channel cleanup
for channel, queue in list(self.channels.items()):
# See if it's expired
while not queue.empty() and queue._queue[0][0] < time.time():
queue.get_nowait()
# Any removal prompts group discard
self._remove_from_groups(channel)
# Is the channel now empty and needs deleting?
if queue.empty():
del self.channels[channel]
# Group Expiration
timeout = int(time.time()) - self.group_expiry
for group in self.groups:
for channel in list(self.groups.get(group, set())):
# If join time is older than group_expiry end the group membership
if (
self.groups[group][channel]
and int(self.groups[group][channel]) < timeout
):
# Delete from group
del self.groups[group][channel]
# Flush extension
async def flush(self):
self.channels = {}
self.groups = {}
async def close(self):
# Nothing to go
pass
def _remove_from_groups(self, channel):
"""
Removes a channel from all groups. Used when a message on it expires.
"""
for channels in self.groups.values():
if channel in channels:
del channels[channel]
# Groups extension
async def group_add(self, group, channel):
"""
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
# Add to group dict
self.groups.setdefault(group, {})
self.groups[group][channel] = time.time()
async def group_discard(self, group, channel):
# Both should be text and valid
assert self.valid_channel_name(channel), "Invalid channel name"
assert self.valid_group_name(group), "Invalid group name"
# Remove from group set
if group in self.groups:
if channel in self.groups[group]:
del self.groups[group][channel]
if not self.groups[group]:
del self.groups[group]
async def group_send(self, group, message):
# Check types
assert isinstance(message, dict), "Message is not a dict"
assert self.valid_group_name(group), "Invalid group name"
# Run clean
self._clean_expired()
# Send to each channel
for channel in self.groups.get(group, set()):
try:
await self.send(channel, message)
except ChannelFull:
pass
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
"""
Returns a channel layer by alias, or None if it is not configured.
"""
try:
return channel_layers[alias]
except KeyError:
return None
# Default global instance of the channel layer manager
channel_layers = ChannelLayerManager()