mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-22 21:51:10 -05:00
okay fine
This commit is contained in:
363
.venv/lib/python3.12/site-packages/channels/layers.py
Normal file
363
.venv/lib/python3.12/site-packages/channels/layers.py
Normal file
@@ -0,0 +1,363 @@
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.signals import setting_changed
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
from channels import DEFAULT_CHANNEL_LAYER
|
||||
|
||||
from .exceptions import ChannelFull, InvalidChannelLayerError
|
||||
|
||||
|
||||
class ChannelLayerManager:
|
||||
"""
|
||||
Takes a settings dictionary of backends and initialises them on request.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.backends = {}
|
||||
setting_changed.connect(self._reset_backends)
|
||||
|
||||
def _reset_backends(self, setting, **kwargs):
|
||||
"""
|
||||
Removes cached channel layers when the CHANNEL_LAYERS setting changes.
|
||||
"""
|
||||
if setting == "CHANNEL_LAYERS":
|
||||
self.backends = {}
|
||||
|
||||
@property
|
||||
def configs(self):
|
||||
# Lazy load settings so we can be imported
|
||||
return getattr(settings, "CHANNEL_LAYERS", {})
|
||||
|
||||
def make_backend(self, name):
|
||||
"""
|
||||
Instantiate channel layer.
|
||||
"""
|
||||
config = self.configs[name].get("CONFIG", {})
|
||||
return self._make_backend(name, config)
|
||||
|
||||
def make_test_backend(self, name):
|
||||
"""
|
||||
Instantiate channel layer using its test config.
|
||||
"""
|
||||
try:
|
||||
config = self.configs[name]["TEST_CONFIG"]
|
||||
except KeyError:
|
||||
raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name)
|
||||
return self._make_backend(name, config)
|
||||
|
||||
def _make_backend(self, name, config):
|
||||
# Check for old format config
|
||||
if "ROUTING" in self.configs[name]:
|
||||
raise InvalidChannelLayerError(
|
||||
"ROUTING key found for %s - this is no longer needed in Channels 2."
|
||||
% name
|
||||
)
|
||||
# Load the backend class
|
||||
try:
|
||||
backend_class = import_string(self.configs[name]["BACKEND"])
|
||||
except KeyError:
|
||||
raise InvalidChannelLayerError("No BACKEND specified for %s" % name)
|
||||
except ImportError:
|
||||
raise InvalidChannelLayerError(
|
||||
"Cannot import BACKEND %r specified for %s"
|
||||
% (self.configs[name]["BACKEND"], name)
|
||||
)
|
||||
# Initialise and pass config
|
||||
return backend_class(**config)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key not in self.backends:
|
||||
self.backends[key] = self.make_backend(key)
|
||||
return self.backends[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.configs
|
||||
|
||||
def set(self, key, layer):
|
||||
"""
|
||||
Sets an alias to point to a new ChannelLayerWrapper instance, and
|
||||
returns the old one that it replaced. Useful for swapping out the
|
||||
backend during tests.
|
||||
"""
|
||||
old = self.backends.get(key, None)
|
||||
self.backends[key] = layer
|
||||
return old
|
||||
|
||||
|
||||
class BaseChannelLayer:
|
||||
"""
|
||||
Base channel layer class that others can inherit from, with useful
|
||||
common functionality.
|
||||
"""
|
||||
|
||||
MAX_NAME_LENGTH = 100
|
||||
|
||||
def __init__(self, expiry=60, capacity=100, channel_capacity=None):
|
||||
self.expiry = expiry
|
||||
self.capacity = capacity
|
||||
self.channel_capacity = channel_capacity or {}
|
||||
|
||||
def compile_capacities(self, channel_capacity):
|
||||
"""
|
||||
Takes an input channel_capacity dict and returns the compiled list
|
||||
of regexes that get_capacity will look for as self.channel_capacity
|
||||
"""
|
||||
result = []
|
||||
for pattern, value in channel_capacity.items():
|
||||
# If they passed in a precompiled regex, leave it, else interpret
|
||||
# it as a glob.
|
||||
if hasattr(pattern, "match"):
|
||||
result.append((pattern, value))
|
||||
else:
|
||||
result.append((re.compile(fnmatch.translate(pattern)), value))
|
||||
return result
|
||||
|
||||
def get_capacity(self, channel):
|
||||
"""
|
||||
Gets the correct capacity for the given channel; either the default,
|
||||
or a matching result from channel_capacity. Returns the first matching
|
||||
result; if you want to control the order of matches, use an ordered dict
|
||||
as input.
|
||||
"""
|
||||
for pattern, capacity in self.channel_capacity:
|
||||
if pattern.match(channel):
|
||||
return capacity
|
||||
return self.capacity
|
||||
|
||||
def match_type_and_length(self, name):
|
||||
if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH):
|
||||
return True
|
||||
return False
|
||||
|
||||
# Name validation functions
|
||||
|
||||
channel_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+(\![\d\w\-_.]*)?$")
|
||||
group_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+$")
|
||||
invalid_name_error = (
|
||||
"{} name must be a valid unicode string "
|
||||
+ "with length < {} ".format(MAX_NAME_LENGTH)
|
||||
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
|
||||
+ "not {}"
|
||||
)
|
||||
|
||||
def valid_channel_name(self, name, receive=False):
|
||||
if self.match_type_and_length(name):
|
||||
if bool(self.channel_name_regex.match(name)):
|
||||
# Check cases for special channels
|
||||
if "!" in name and not name.endswith("!") and receive:
|
||||
raise TypeError(
|
||||
"Specific channel names in receive() must end at the !"
|
||||
)
|
||||
return True
|
||||
raise TypeError(self.invalid_name_error.format("Channel", name))
|
||||
|
||||
def valid_group_name(self, name):
|
||||
if self.match_type_and_length(name):
|
||||
if bool(self.group_name_regex.match(name)):
|
||||
return True
|
||||
raise TypeError(self.invalid_name_error.format("Group", name))
|
||||
|
||||
def valid_channel_names(self, names, receive=False):
|
||||
_non_empty_list = True if names else False
|
||||
_names_type = isinstance(names, list)
|
||||
assert _non_empty_list and _names_type, "names must be a non-empty list"
|
||||
|
||||
assert all(
|
||||
self.valid_channel_name(channel, receive=receive) for channel in names
|
||||
)
|
||||
return True
|
||||
|
||||
def non_local_name(self, name):
|
||||
"""
|
||||
Given a channel name, returns the "non-local" part. If the channel name
|
||||
is a process-specific channel (contains !) this means the part up to
|
||||
and including the !; if it is anything else, this means the full name.
|
||||
"""
|
||||
if "!" in name:
|
||||
return name[: name.find("!") + 1]
|
||||
else:
|
||||
return name
|
||||
|
||||
|
||||
class InMemoryChannelLayer(BaseChannelLayer):
|
||||
"""
|
||||
In-memory channel layer implementation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
expiry=60,
|
||||
group_expiry=86400,
|
||||
capacity=100,
|
||||
channel_capacity=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
expiry=expiry,
|
||||
capacity=capacity,
|
||||
channel_capacity=channel_capacity,
|
||||
**kwargs
|
||||
)
|
||||
self.channels = {}
|
||||
self.groups = {}
|
||||
self.group_expiry = group_expiry
|
||||
|
||||
# Channel layer API
|
||||
|
||||
extensions = ["groups", "flush"]
|
||||
|
||||
async def send(self, channel, message):
|
||||
"""
|
||||
Send a message onto a (general or specific) channel.
|
||||
"""
|
||||
# Typecheck
|
||||
assert isinstance(message, dict), "message is not a dict"
|
||||
assert self.valid_channel_name(channel), "Channel name not valid"
|
||||
# If it's a process-local channel, strip off local part and stick full
|
||||
# name in message
|
||||
assert "__asgi_channel__" not in message
|
||||
|
||||
queue = self.channels.setdefault(channel, asyncio.Queue())
|
||||
# Are we full
|
||||
if queue.qsize() >= self.capacity:
|
||||
raise ChannelFull(channel)
|
||||
|
||||
# Add message
|
||||
await queue.put((time.time() + self.expiry, deepcopy(message)))
|
||||
|
||||
async def receive(self, channel):
|
||||
"""
|
||||
Receive the first message that arrives on the channel.
|
||||
If more than one coroutine waits on the same channel, a random one
|
||||
of the waiting coroutines will get the result.
|
||||
"""
|
||||
assert self.valid_channel_name(channel)
|
||||
self._clean_expired()
|
||||
|
||||
queue = self.channels.setdefault(channel, asyncio.Queue())
|
||||
|
||||
# Do a plain direct receive
|
||||
try:
|
||||
_, message = await queue.get()
|
||||
finally:
|
||||
if queue.empty():
|
||||
del self.channels[channel]
|
||||
|
||||
return message
|
||||
|
||||
async def new_channel(self, prefix="specific."):
|
||||
"""
|
||||
Returns a new channel name that can be used by something in our
|
||||
process as a specific channel.
|
||||
"""
|
||||
return "%s.inmemory!%s" % (
|
||||
prefix,
|
||||
"".join(random.choice(string.ascii_letters) for i in range(12)),
|
||||
)
|
||||
|
||||
# Expire cleanup
|
||||
|
||||
def _clean_expired(self):
|
||||
"""
|
||||
Goes through all messages and groups and removes those that are expired.
|
||||
Any channel with an expired message is removed from all groups.
|
||||
"""
|
||||
# Channel cleanup
|
||||
for channel, queue in list(self.channels.items()):
|
||||
# See if it's expired
|
||||
while not queue.empty() and queue._queue[0][0] < time.time():
|
||||
queue.get_nowait()
|
||||
# Any removal prompts group discard
|
||||
self._remove_from_groups(channel)
|
||||
# Is the channel now empty and needs deleting?
|
||||
if queue.empty():
|
||||
del self.channels[channel]
|
||||
|
||||
# Group Expiration
|
||||
timeout = int(time.time()) - self.group_expiry
|
||||
for group in self.groups:
|
||||
for channel in list(self.groups.get(group, set())):
|
||||
# If join time is older than group_expiry end the group membership
|
||||
if (
|
||||
self.groups[group][channel]
|
||||
and int(self.groups[group][channel]) < timeout
|
||||
):
|
||||
# Delete from group
|
||||
del self.groups[group][channel]
|
||||
|
||||
# Flush extension
|
||||
|
||||
async def flush(self):
|
||||
self.channels = {}
|
||||
self.groups = {}
|
||||
|
||||
async def close(self):
|
||||
# Nothing to go
|
||||
pass
|
||||
|
||||
def _remove_from_groups(self, channel):
|
||||
"""
|
||||
Removes a channel from all groups. Used when a message on it expires.
|
||||
"""
|
||||
for channels in self.groups.values():
|
||||
if channel in channels:
|
||||
del channels[channel]
|
||||
|
||||
# Groups extension
|
||||
|
||||
async def group_add(self, group, channel):
|
||||
"""
|
||||
Adds the channel name to a group.
|
||||
"""
|
||||
# Check the inputs
|
||||
assert self.valid_group_name(group), "Group name not valid"
|
||||
assert self.valid_channel_name(channel), "Channel name not valid"
|
||||
# Add to group dict
|
||||
self.groups.setdefault(group, {})
|
||||
self.groups[group][channel] = time.time()
|
||||
|
||||
async def group_discard(self, group, channel):
|
||||
# Both should be text and valid
|
||||
assert self.valid_channel_name(channel), "Invalid channel name"
|
||||
assert self.valid_group_name(group), "Invalid group name"
|
||||
# Remove from group set
|
||||
if group in self.groups:
|
||||
if channel in self.groups[group]:
|
||||
del self.groups[group][channel]
|
||||
if not self.groups[group]:
|
||||
del self.groups[group]
|
||||
|
||||
async def group_send(self, group, message):
|
||||
# Check types
|
||||
assert isinstance(message, dict), "Message is not a dict"
|
||||
assert self.valid_group_name(group), "Invalid group name"
|
||||
# Run clean
|
||||
self._clean_expired()
|
||||
# Send to each channel
|
||||
for channel in self.groups.get(group, set()):
|
||||
try:
|
||||
await self.send(channel, message)
|
||||
except ChannelFull:
|
||||
pass
|
||||
|
||||
|
||||
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
|
||||
"""
|
||||
Returns a channel layer by alias, or None if it is not configured.
|
||||
"""
|
||||
try:
|
||||
return channel_layers[alias]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
# Default global instance of the channel layer manager
|
||||
channel_layers = ChannelLayerManager()
|
||||
Reference in New Issue
Block a user