okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 01c6004a79
commit 27eb239e97
10020 changed files with 1935769 additions and 2364 deletions

View File

@@ -0,0 +1,4 @@
# Package version; bumped on each release.
__version__ = "4.0.0"

# Alias used to look up the default entry in the CHANNEL_LAYERS setting.
DEFAULT_CHANNEL_LAYER = "default"

View File

@@ -0,0 +1,7 @@
from django.apps import AppConfig
class ChannelsConfig(AppConfig):
    """Django AppConfig registering the Channels package as an installed app."""

    name = "channels"
    verbose_name = "Channels"

View File

@@ -0,0 +1,190 @@
from django.conf import settings
from django.contrib.auth import (
BACKEND_SESSION_KEY,
HASH_SESSION_KEY,
SESSION_KEY,
_get_backends,
get_user_model,
load_backend,
user_logged_in,
user_logged_out,
)
from django.utils.crypto import constant_time_compare
from django.utils.functional import LazyObject
from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from channels.sessions import CookieMiddleware, SessionMiddleware
@database_sync_to_async
def get_user(scope):
    """
    Return the user model instance associated with the given scope.
    If no user is retrieved, return an instance of `AnonymousUser`.

    Runs in a worker thread (database_sync_to_async) because it touches
    the ORM. Raises ValueError if the scope carries no session.
    """
    # postpone model import to avoid ImproperlyConfigured error before Django
    # setup is complete.
    from django.contrib.auth.models import AnonymousUser

    if "session" not in scope:
        raise ValueError(
            "Cannot find session in scope. You should wrap your consumer in "
            "SessionMiddleware."
        )
    session = scope["session"]
    user = None
    try:
        user_id = _get_user_session_key(session)
        backend_path = session[BACKEND_SESSION_KEY]
    except KeyError:
        # No logged-in user recorded in this session; fall through to
        # AnonymousUser below.
        pass
    else:
        if backend_path in settings.AUTHENTICATION_BACKENDS:
            backend = load_backend(backend_path)
            user = backend.get_user(user_id)
            # Verify the session
            if hasattr(user, "get_session_auth_hash"):
                session_hash = session.get(HASH_SESSION_KEY)
                session_hash_verified = session_hash and constant_time_compare(
                    session_hash, user.get_session_auth_hash()
                )
                if not session_hash_verified:
                    # Hash mismatch (e.g. password changed): wipe the session
                    # and treat the requester as anonymous.
                    session.flush()
                    user = None
    return user or AnonymousUser()
@database_sync_to_async
def login(scope, user, backend=None):
    """
    Persist a user id and a backend in the request.
    This way a user doesn't have to re-authenticate on every request.
    Note that data set during the anonymous session is retained when the user
    logs in.

    Raises ValueError if the scope has no session, if no user can be
    determined, or if the backend is ambiguous.
    """
    if "session" not in scope:
        raise ValueError(
            "Cannot find session in scope. You should wrap your consumer in "
            "SessionMiddleware."
        )
    session = scope["session"]
    session_auth_hash = ""
    if user is None:
        # Fall back to a user already resolved into the scope
        # (e.g. by AuthMiddleware).
        user = scope.get("user", None)
    if user is None:
        raise ValueError(
            "User must be passed as an argument or must be present in the scope."
        )
    if hasattr(user, "get_session_auth_hash"):
        session_auth_hash = user.get_session_auth_hash()
    if SESSION_KEY in session:
        if _get_user_session_key(session) != user.pk or (
            session_auth_hash
            and not constant_time_compare(
                session.get(HASH_SESSION_KEY, ""), session_auth_hash
            )
        ):
            # To avoid reusing another user's session, create a new, empty
            # session if the existing session corresponds to a different
            # authenticated user.
            session.flush()
    else:
        # Fresh login into an anonymous session: rotate the session key.
        session.cycle_key()
    try:
        backend = backend or user.backend
    except AttributeError:
        backends = _get_backends(return_tuples=True)
        if len(backends) == 1:
            # Only one backend configured, so the choice is unambiguous.
            _, backend = backends[0]
        else:
            raise ValueError(
                "You have multiple authentication backends configured and "
                "therefore must provide the `backend` "
                "argument or set the `backend` attribute on the user."
            )
    session[SESSION_KEY] = user._meta.pk.value_to_string(user)
    session[BACKEND_SESSION_KEY] = backend
    session[HASH_SESSION_KEY] = session_auth_hash
    scope["user"] = user
    # note this does not reset the CSRF_COOKIE/Token
    user_logged_in.send(sender=user.__class__, request=None, user=user)
@database_sync_to_async
def logout(scope):
    """
    Remove the authenticated user's ID from the request and flush their session
    data.

    Raises ValueError if the scope carries no session.
    """
    # postpone model import to avoid ImproperlyConfigured error before Django
    # setup is complete.
    from django.contrib.auth.models import AnonymousUser

    if "session" not in scope:
        raise ValueError(
            "Login cannot find session in scope. You should wrap your "
            "consumer in SessionMiddleware."
        )
    session = scope["session"]
    # Dispatch the signal before the user is logged out so the receivers have a
    # chance to find out *who* logged out.
    user = scope.get("user", None)
    if hasattr(user, "is_authenticated") and not user.is_authenticated:
        # Anonymous users don't get the logged-out signal.
        user = None
    if user is not None:
        user_logged_out.send(sender=user.__class__, request=None, user=user)
    session.flush()
    if "user" in scope:
        scope["user"] = AnonymousUser()
def _get_user_session_key(session):
    """Return the session's stored user primary key, deserialized."""
    # This value in the session is always serialized to a string, so we need
    # to convert it back to Python whenever we access it.
    return get_user_model()._meta.pk.to_python(session[SESSION_KEY])
class UserLazyObject(LazyObject):
    """
    Throw a more useful error message when scope['user'] is accessed before
    it's resolved
    """

    def _setup(self):
        # LazyObject calls _setup on first attribute access; AuthMiddleware
        # assigns _wrapped directly in resolve_scope, so reaching this method
        # means the user was read before resolution completed.
        raise ValueError("Accessing scope user before it is ready.")
class AuthMiddleware(BaseMiddleware):
    """
    Middleware which populates scope["user"] from a Django session.
    Requires SessionMiddleware to function.
    """

    def populate_scope(self, scope):
        # Make sure we have a session
        if "session" not in scope:
            raise ValueError(
                "AuthMiddleware cannot find session in scope. "
                "SessionMiddleware must be above it."
            )
        # Add it to the scope if it's not there already
        if "user" not in scope:
            # Placeholder that raises if read before resolve_scope runs.
            scope["user"] = UserLazyObject()

    async def resolve_scope(self, scope):
        # Replace the lazy placeholder's payload with the real user.
        scope["user"]._wrapped = await get_user(scope)

    async def __call__(self, scope, receive, send):
        # Copy the scope so our additions don't leak upstream.
        scope = dict(scope)
        # Scope injection/mutation per this middleware's needs.
        self.populate_scope(scope)
        # Grab the finalized/resolved scope
        await self.resolve_scope(scope)
        return await super().__call__(scope, receive, send)
# Handy shortcut for applying all three layers at once
def AuthMiddlewareStack(inner):
    """Wrap *inner* in cookie, session, and auth middleware (outermost first)."""
    return CookieMiddleware(SessionMiddleware(AuthMiddleware(inner)))

View File

@@ -0,0 +1,133 @@
import functools
from asgiref.sync import async_to_sync
from . import DEFAULT_CHANNEL_LAYER
from .db import database_sync_to_async
from .exceptions import StopConsumer
from .layers import get_channel_layer
from .utils import await_many_dispatch
def get_handler_name(message):
    """
    Derive the consumer handler-method name for an incoming message.

    The message's "type" value has its dots converted to underscores so it
    maps onto a Python method name; names that would begin with an
    underscore are rejected so private methods can never be targeted.
    """
    try:
        message_type = message["type"]
    except KeyError:
        raise ValueError("Incoming message has no 'type' attribute")
    handler_name = message_type.replace(".", "_")
    if handler_name.startswith("_"):
        raise ValueError("Malformed type in message (leading underscore)")
    return handler_name
class AsyncConsumer:
    """
    Base consumer class. Implements the ASGI application spec, and adds on
    channel layer management and routing of events to named methods based
    on their type.
    """

    # When True (see SyncConsumer), `send` is wrapped with async_to_sync so
    # synchronous handlers can call it directly.
    _sync = False
    # Alias into the CHANNEL_LAYERS setting used for this consumer.
    channel_layer_alias = DEFAULT_CHANNEL_LAYER

    async def __call__(self, scope, receive, send):
        """
        Dispatches incoming messages to type-based handlers asynchronously.
        """
        self.scope = scope
        # Initialize channel layer
        self.channel_layer = get_channel_layer(self.channel_layer_alias)
        if self.channel_layer is not None:
            # Dedicated channel for messages addressed to this consumer.
            self.channel_name = await self.channel_layer.new_channel()
            self.channel_receive = functools.partial(
                self.channel_layer.receive, self.channel_name
            )
        # Store send function
        if self._sync:
            self.base_send = async_to_sync(send)
        else:
            self.base_send = send
        # Pass messages in from channel layer or client to dispatch method
        try:
            if self.channel_layer is not None:
                await await_many_dispatch(
                    [receive, self.channel_receive], self.dispatch
                )
            else:
                await await_many_dispatch([receive], self.dispatch)
        except StopConsumer:
            # Exit cleanly
            pass

    async def dispatch(self, message):
        """
        Works out what to do with a message.

        Resolves the handler method via get_handler_name and awaits it;
        raises ValueError if no matching method exists.
        """
        handler = getattr(self, get_handler_name(message), None)
        if handler:
            await handler(message)
        else:
            raise ValueError("No handler for message type %s" % message["type"])

    async def send(self, message):
        """
        Overrideable/callable-by-subclasses send method.
        """
        await self.base_send(message)

    @classmethod
    def as_asgi(cls, **initkwargs):
        """
        Return an ASGI v3 single callable that instantiates a consumer instance
        per scope. Similar in purpose to Django's as_view().

        initkwargs will be used to instantiate the consumer instance.
        """

        async def app(scope, receive, send):
            consumer = cls(**initkwargs)
            return await consumer(scope, receive, send)

        app.consumer_class = cls
        app.consumer_initkwargs = initkwargs
        # take name and docstring from class
        functools.update_wrapper(app, cls, updated=())
        return app
class SyncConsumer(AsyncConsumer):
    """
    Synchronous version of the consumer, which is what we write most of the
    generic consumers against (for now). Calls handlers in a threadpool and
    uses CallBouncer to get the send method out to the main event loop.

    It would have been possible to have "mixed" consumers and auto-detect
    if a handler was awaitable or not, but that would have made the API
    for user-called methods very confusing as there'd be two types of each.
    """

    # Makes AsyncConsumer.__call__ wrap `send` with async_to_sync.
    _sync = True

    @database_sync_to_async
    def dispatch(self, message):
        """
        Dispatches incoming messages to type-based handlers, running each
        handler synchronously in a worker thread (database_sync_to_async
        also cleans up old database connections around the call).
        """
        # Get and execute the handler
        handler = getattr(self, get_handler_name(message), None)
        if handler:
            handler(message)
        else:
            raise ValueError("No handler for message type %s" % message["type"])

    def send(self, message):
        """
        Overrideable/callable-by-subclasses send method.
        """
        self.base_send(message)

View File

@@ -0,0 +1,19 @@
from asgiref.sync import SyncToAsync
from django.db import close_old_connections
class DatabaseSyncToAsync(SyncToAsync):
    """
    SyncToAsync version that cleans up old database connections when it exits.
    """

    def thread_handler(self, loop, *args, **kwargs):
        # Close stale/broken connections both before and after running the
        # wrapped synchronous function in its worker thread.
        close_old_connections()
        try:
            return super().thread_handler(loop, *args, **kwargs)
        finally:
            close_old_connections()


# The class is TitleCased, but we want to encourage use as a callable/decorator
database_sync_to_async = DatabaseSyncToAsync

View File

@@ -0,0 +1,65 @@
class RequestAborted(Exception):
    """
    Raised when the incoming request tells us it's aborted partway through
    reading the body.
    """

    # Also the base class for RequestTimeout.
    pass
class RequestTimeout(RequestAborted):
    """
    Aborted specifically due to timeout. Subclass of RequestAborted so
    generic abort handling also catches timeouts.
    """

    pass
class InvalidChannelLayerError(ValueError):
    """
    Raised when a channel layer is configured incorrectly.
    """

    pass
class AcceptConnection(Exception):
    """
    Raised during a websocket.connect (or other supported connection) handler
    to accept the connection.
    """

    pass
class DenyConnection(Exception):
    """
    Raised during a websocket.connect (or other supported connection) handler
    to deny the connection.
    """

    pass
class ChannelFull(Exception):
    """
    Raised when a channel cannot be sent to as it is over capacity.
    """

    pass
class MessageTooLarge(Exception):
    """
    Raised when a message cannot be sent as it's too big.
    """

    pass
class StopConsumer(Exception):
    """
    Raised when a consumer wants to stop and close down its application instance.
    """

    pass

View File

@@ -0,0 +1,91 @@
from channels.consumer import AsyncConsumer
from ..exceptions import StopConsumer
class AsyncHttpConsumer(AsyncConsumer):
    """
    Async HTTP consumer. Provides basic primitives for building asynchronous
    HTTP endpoints.
    """

    def __init__(self, *args, **kwargs):
        # Buffer of received body fragments; joined before calling handle().
        # NOTE(review): super().__init__ is not called here; AsyncConsumer
        # defines no __init__ of its own, so nothing is lost at present —
        # confirm if that ever changes.
        self.body = []

    async def send_headers(self, *, status=200, headers=None):
        """
        Sets the HTTP response status and headers. Headers may be provided as
        a list of tuples or as a dictionary.

        Note that the ASGI spec requires that the protocol server only starts
        sending the response to the client after ``self.send_body`` has been
        called the first time.
        """
        if headers is None:
            headers = []
        elif isinstance(headers, dict):
            # Normalise the dict form to the list of pairs the message needs.
            headers = list(headers.items())
        await self.send(
            {"type": "http.response.start", "status": status, "headers": headers}
        )

    async def send_body(self, body, *, more_body=False):
        """
        Sends a response body to the client. The method expects a bytestring.

        Set ``more_body=True`` if you want to send more body content later.
        The default behavior closes the response, and further messages on
        the channel will be ignored.
        """
        assert isinstance(body, bytes), "Body is not bytes"
        await self.send(
            {"type": "http.response.body", "body": body, "more_body": more_body}
        )

    async def send_response(self, status, body, **kwargs):
        """
        Sends a response to the client. This is a thin wrapper over
        ``self.send_headers`` and ``self.send_body``, and everything said
        above applies here as well. This method may only be called once.
        """
        await self.send_headers(status=status, **kwargs)
        await self.send_body(body)

    async def handle(self, body):
        """
        Receives the request body as a bytestring. Response may be composed
        using the ``self.send*`` methods; the return value of this method is
        thrown away.
        """
        raise NotImplementedError(
            "Subclasses of AsyncHttpConsumer must provide a handle() method."
        )

    async def disconnect(self):
        """
        Overrideable place to run disconnect handling. Do not send anything
        from here.
        """
        pass

    async def http_request(self, message):
        """
        Async entrypoint - concatenates body fragments and hands off control
        to ``self.handle`` when the body has been completely received.
        """
        if "body" in message:
            self.body.append(message["body"])
        if not message.get("more_body"):
            try:
                await self.handle(b"".join(self.body))
            finally:
                # Always give the consumer a chance to clean up, then stop.
                await self.disconnect()
            raise StopConsumer()

    async def http_disconnect(self, message):
        """
        Let the user do their cleanup and close the consumer.
        """
        await self.disconnect()
        raise StopConsumer()

View File

@@ -0,0 +1,279 @@
import json
from asgiref.sync import async_to_sync
from ..consumer import AsyncConsumer, SyncConsumer
from ..exceptions import (
AcceptConnection,
DenyConnection,
InvalidChannelLayerError,
StopConsumer,
)
class WebsocketConsumer(SyncConsumer):
    """
    Base WebSocket consumer. Provides a general encapsulation for the
    WebSocket handling model that other applications can build on.
    """

    # Iterable of group names joined on connect and left on disconnect.
    groups = None

    def __init__(self, *args, **kwargs):
        if self.groups is None:
            self.groups = []

    def websocket_connect(self, message):
        """
        Called when a WebSocket connection is opened.
        """
        try:
            for group in self.groups:
                async_to_sync(self.channel_layer.group_add)(group, self.channel_name)
        except AttributeError:
            # channel_layer is None (unconfigured) or lacks group methods.
            raise InvalidChannelLayerError(
                "BACKEND is unconfigured or doesn't support groups"
            )
        try:
            self.connect()
        except AcceptConnection:
            self.accept()
        except DenyConnection:
            self.close()

    def connect(self):
        # Default behaviour accepts every connection; override to customise.
        self.accept()

    def accept(self, subprotocol=None):
        """
        Accepts an incoming socket
        """
        super().send({"type": "websocket.accept", "subprotocol": subprotocol})

    def websocket_receive(self, message):
        """
        Called when a WebSocket frame is received. Decodes it and passes it
        to receive().
        """
        if "text" in message:
            self.receive(text_data=message["text"])
        else:
            self.receive(bytes_data=message["bytes"])

    def receive(self, text_data=None, bytes_data=None):
        """
        Called with a decoded WebSocket frame.
        """
        pass

    def send(self, text_data=None, bytes_data=None, close=False):
        """
        Sends a reply back down the WebSocket
        """
        if text_data is not None:
            super().send({"type": "websocket.send", "text": text_data})
        elif bytes_data is not None:
            super().send({"type": "websocket.send", "bytes": bytes_data})
        else:
            raise ValueError("You must pass one of bytes_data or text_data")
        if close:
            # close may be True (default close) or an integer close code.
            self.close(close)

    def close(self, code=None):
        """
        Closes the WebSocket from the server end
        """
        if code is not None and code is not True:
            super().send({"type": "websocket.close", "code": code})
        else:
            super().send({"type": "websocket.close"})

    def websocket_disconnect(self, message):
        """
        Called when a WebSocket connection is closed. Base level so you don't
        need to call super() all the time.
        """
        try:
            for group in self.groups:
                async_to_sync(self.channel_layer.group_discard)(
                    group, self.channel_name
                )
        except AttributeError:
            raise InvalidChannelLayerError(
                "BACKEND is unconfigured or doesn't support groups"
            )
        self.disconnect(message["code"])
        # Shut this consumer instance down cleanly.
        raise StopConsumer()

    def disconnect(self, code):
        """
        Called when a WebSocket connection is closed.
        """
        pass
class JsonWebsocketConsumer(WebsocketConsumer):
    """
    Variant of WebsocketConsumer that automatically JSON-encodes and decodes
    messages as they come in and go out. Expects everything to be text; will
    error on binary data.
    """

    def receive(self, text_data=None, bytes_data=None, **kwargs):
        if text_data:
            self.receive_json(self.decode_json(text_data), **kwargs)
        else:
            # Binary frames (or empty text) are not supported here.
            raise ValueError("No text section for incoming WebSocket frame!")

    def receive_json(self, content, **kwargs):
        """
        Called with decoded JSON content.
        """
        pass

    def send_json(self, content, close=False):
        """
        Encode the given content as JSON and send it to the client.
        """
        super().send(text_data=self.encode_json(content), close=close)

    @classmethod
    def decode_json(cls, text_data):
        # Override to customise deserialization.
        return json.loads(text_data)

    @classmethod
    def encode_json(cls, content):
        # Override to customise serialization.
        return json.dumps(content)
class AsyncWebsocketConsumer(AsyncConsumer):
    """
    Base WebSocket consumer, async version. Provides a general encapsulation
    for the WebSocket handling model that other applications can build on.
    """

    # Iterable of group names joined on connect and left on disconnect.
    groups = None

    def __init__(self, *args, **kwargs):
        if self.groups is None:
            self.groups = []

    async def websocket_connect(self, message):
        """
        Called when a WebSocket connection is opened.
        """
        try:
            for group in self.groups:
                await self.channel_layer.group_add(group, self.channel_name)
        except AttributeError:
            # channel_layer is None (unconfigured) or lacks group methods.
            raise InvalidChannelLayerError(
                "BACKEND is unconfigured or doesn't support groups"
            )
        try:
            await self.connect()
        except AcceptConnection:
            await self.accept()
        except DenyConnection:
            await self.close()

    async def connect(self):
        # Default behaviour accepts every connection; override to customise.
        await self.accept()

    async def accept(self, subprotocol=None):
        """
        Accepts an incoming socket
        """
        await super().send({"type": "websocket.accept", "subprotocol": subprotocol})

    async def websocket_receive(self, message):
        """
        Called when a WebSocket frame is received. Decodes it and passes it
        to receive().
        """
        if "text" in message:
            await self.receive(text_data=message["text"])
        else:
            await self.receive(bytes_data=message["bytes"])

    async def receive(self, text_data=None, bytes_data=None):
        """
        Called with a decoded WebSocket frame.
        """
        pass

    async def send(self, text_data=None, bytes_data=None, close=False):
        """
        Sends a reply back down the WebSocket
        """
        if text_data is not None:
            await super().send({"type": "websocket.send", "text": text_data})
        elif bytes_data is not None:
            await super().send({"type": "websocket.send", "bytes": bytes_data})
        else:
            raise ValueError("You must pass one of bytes_data or text_data")
        if close:
            # close may be True (default close) or an integer close code.
            await self.close(close)

    async def close(self, code=None):
        """
        Closes the WebSocket from the server end
        """
        if code is not None and code is not True:
            await super().send({"type": "websocket.close", "code": code})
        else:
            await super().send({"type": "websocket.close"})

    async def websocket_disconnect(self, message):
        """
        Called when a WebSocket connection is closed. Base level so you don't
        need to call super() all the time.
        """
        try:
            for group in self.groups:
                await self.channel_layer.group_discard(group, self.channel_name)
        except AttributeError:
            raise InvalidChannelLayerError(
                "BACKEND is unconfigured or doesn't support groups"
            )
        await self.disconnect(message["code"])
        # Shut this consumer instance down cleanly.
        raise StopConsumer()

    async def disconnect(self, code):
        """
        Called when a WebSocket connection is closed.
        """
        pass
class AsyncJsonWebsocketConsumer(AsyncWebsocketConsumer):
    """
    Variant of AsyncWebsocketConsumer that automatically JSON-encodes and decodes
    messages as they come in and go out. Expects everything to be text; will
    error on binary data.
    """

    async def receive(self, text_data=None, bytes_data=None, **kwargs):
        if text_data:
            self.receive_json(await self.decode_json(text_data), **kwargs) if False else await self.receive_json(await self.decode_json(text_data), **kwargs)
        else:
            # Binary frames (or empty text) are not supported here.
            raise ValueError("No text section for incoming WebSocket frame!")

    async def receive_json(self, content, **kwargs):
        """
        Called with decoded JSON content.
        """
        pass

    async def send_json(self, content, close=False):
        """
        Encode the given content as JSON and send it to the client.
        """
        await super().send(text_data=await self.encode_json(content), close=close)

    @classmethod
    async def decode_json(cls, text_data):
        # Override to customise deserialization.
        return json.loads(text_data)

    @classmethod
    async def encode_json(cls, content):
        # Override to customise serialization.
        return json.dumps(content)

View File

@@ -0,0 +1,363 @@
import asyncio
import fnmatch
import random
import re
import string
import time
from copy import deepcopy
from django.conf import settings
from django.core.signals import setting_changed
from django.utils.module_loading import import_string
from channels import DEFAULT_CHANNEL_LAYER
from .exceptions import ChannelFull, InvalidChannelLayerError
class ChannelLayerManager:
    """
    Takes a settings dictionary of backends and initialises them on request.
    """

    def __init__(self):
        # Cache of alias -> instantiated backend.
        self.backends = {}
        setting_changed.connect(self._reset_backends)

    def _reset_backends(self, setting, **kwargs):
        """
        Removes cached channel layers when the CHANNEL_LAYERS setting changes.
        """
        if setting == "CHANNEL_LAYERS":
            self.backends = {}

    @property
    def configs(self):
        # Lazy load settings so we can be imported
        return getattr(settings, "CHANNEL_LAYERS", {})

    def make_backend(self, name):
        """
        Instantiate channel layer.
        """
        config = self.configs[name].get("CONFIG", {})
        return self._make_backend(name, config)

    def make_test_backend(self, name):
        """
        Instantiate channel layer using its test config.
        """
        try:
            config = self.configs[name]["TEST_CONFIG"]
        except KeyError:
            raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name)
        return self._make_backend(name, config)

    def _make_backend(self, name, config):
        # Check for old format config
        if "ROUTING" in self.configs[name]:
            raise InvalidChannelLayerError(
                "ROUTING key found for %s - this is no longer needed in Channels 2."
                % name
            )
        # Load the backend class
        try:
            backend_class = import_string(self.configs[name]["BACKEND"])
        except KeyError:
            raise InvalidChannelLayerError("No BACKEND specified for %s" % name)
        except ImportError:
            raise InvalidChannelLayerError(
                "Cannot import BACKEND %r specified for %s"
                % (self.configs[name]["BACKEND"], name)
            )
        # Initialise and pass config
        return backend_class(**config)

    def __getitem__(self, key):
        # Instantiate lazily on first access, then serve from the cache.
        if key not in self.backends:
            self.backends[key] = self.make_backend(key)
        return self.backends[key]

    def __contains__(self, key):
        return key in self.configs

    def set(self, key, layer):
        """
        Sets an alias to point to a new ChannelLayerWrapper instance, and
        returns the old one that it replaced. Useful for swapping out the
        backend during tests.
        """
        old = self.backends.get(key, None)
        self.backends[key] = layer
        return old
class BaseChannelLayer:
    """
    Base channel layer class that others can inherit from, with useful
    common functionality.
    """

    # Channel and group names must be strictly shorter than this.
    MAX_NAME_LENGTH = 100

    def __init__(self, expiry=60, capacity=100, channel_capacity=None):
        """
        expiry: seconds a message may wait in a channel before expiring.
        capacity: default maximum number of queued messages per channel.
        channel_capacity: optional mapping of channel-name glob pattern
            (or precompiled regex) to a per-channel capacity override.
        """
        self.expiry = expiry
        self.capacity = capacity
        # Fix: compile the mapping up front. get_capacity() iterates
        # (compiled regex, capacity) pairs, and compile_capacities() is
        # documented to produce exactly the list "that get_capacity will
        # look for as self.channel_capacity" — previously the raw dict was
        # stored, which made get_capacity() crash (unpacking dict keys)
        # whenever a non-empty mapping was configured.
        self.channel_capacity = self.compile_capacities(channel_capacity or {})

    def compile_capacities(self, channel_capacity):
        """
        Takes an input channel_capacity dict and returns the compiled list
        of regexes that get_capacity will look for as self.channel_capacity
        """
        result = []
        for pattern, value in channel_capacity.items():
            # If they passed in a precompiled regex, leave it, else interpret
            # it as a glob.
            if hasattr(pattern, "match"):
                result.append((pattern, value))
            else:
                result.append((re.compile(fnmatch.translate(pattern)), value))
        return result

    def get_capacity(self, channel):
        """
        Gets the correct capacity for the given channel; either the default,
        or a matching result from channel_capacity. Returns the first matching
        result; if you want to control the order of matches, use an ordered dict
        as input.
        """
        for pattern, capacity in self.channel_capacity:
            if pattern.match(channel):
                return capacity
        return self.capacity

    def match_type_and_length(self, name):
        # Names must be strings strictly shorter than MAX_NAME_LENGTH.
        if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH):
            return True
        return False

    # Name validation functions
    channel_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+(\![\d\w\-_.]*)?$")
    group_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+$")
    invalid_name_error = (
        "{} name must be a valid unicode string "
        + "with length < {} ".format(MAX_NAME_LENGTH)
        + "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
        + "not {}"
    )

    def valid_channel_name(self, name, receive=False):
        """
        Return True if *name* is a valid channel name; raise TypeError if not.

        With receive=True, process-specific names (containing "!") must end
        at the "!" separator.
        """
        if self.match_type_and_length(name):
            if bool(self.channel_name_regex.match(name)):
                # Check cases for special channels
                if "!" in name and not name.endswith("!") and receive:
                    raise TypeError(
                        "Specific channel names in receive() must end at the !"
                    )
                return True
        raise TypeError(self.invalid_name_error.format("Channel", name))

    def valid_group_name(self, name):
        """
        Return True if *name* is a valid group name; raise TypeError if not.
        """
        if self.match_type_and_length(name):
            if bool(self.group_name_regex.match(name)):
                return True
        raise TypeError(self.invalid_name_error.format("Group", name))

    def valid_channel_names(self, names, receive=False):
        """
        Validate a non-empty list of channel names; return True if all pass.
        """
        _non_empty_list = True if names else False
        _names_type = isinstance(names, list)
        assert _non_empty_list and _names_type, "names must be a non-empty list"
        assert all(
            self.valid_channel_name(channel, receive=receive) for channel in names
        )
        return True

    def non_local_name(self, name):
        """
        Given a channel name, returns the "non-local" part. If the channel name
        is a process-specific channel (contains !) this means the part up to
        and including the !; if it is anything else, this means the full name.
        """
        if "!" in name:
            return name[: name.find("!") + 1]
        else:
            return name
class InMemoryChannelLayer(BaseChannelLayer):
    """
    In-memory channel layer implementation

    Single-process only: state lives in plain dicts on this instance.
    """

    def __init__(
        self,
        expiry=60,
        group_expiry=86400,
        capacity=100,
        channel_capacity=None,
        **kwargs
    ):
        super().__init__(
            expiry=expiry,
            capacity=capacity,
            channel_capacity=channel_capacity,
            **kwargs
        )
        # channel name -> asyncio.Queue of (expiry timestamp, message).
        self.channels = {}
        # group name -> {channel name: join timestamp}.
        self.groups = {}
        # Seconds a group membership lasts before being expired.
        self.group_expiry = group_expiry

    # Channel layer API

    # Capabilities this layer advertises.
    extensions = ["groups", "flush"]

    async def send(self, channel, message):
        """
        Send a message onto a (general or specific) channel.

        Raises ChannelFull if the channel already holds `capacity` messages.
        """
        # Typecheck
        assert isinstance(message, dict), "message is not a dict"
        assert self.valid_channel_name(channel), "Channel name not valid"
        # If it's a process-local channel, strip off local part and stick full
        # name in message
        assert "__asgi_channel__" not in message
        queue = self.channels.setdefault(channel, asyncio.Queue())
        # Are we full
        if queue.qsize() >= self.capacity:
            raise ChannelFull(channel)
        # Add message
        # deepcopy so later mutation by the sender can't affect the receiver.
        await queue.put((time.time() + self.expiry, deepcopy(message)))

    async def receive(self, channel):
        """
        Receive the first message that arrives on the channel.
        If more than one coroutine waits on the same channel, a random one
        of the waiting coroutines will get the result.
        """
        assert self.valid_channel_name(channel)
        self._clean_expired()
        queue = self.channels.setdefault(channel, asyncio.Queue())
        # Do a plain direct receive
        try:
            _, message = await queue.get()
        finally:
            # Drop the queue mapping again once it's drained.
            if queue.empty():
                del self.channels[channel]
        return message

    async def new_channel(self, prefix="specific."):
        """
        Returns a new channel name that can be used by something in our
        process as a specific channel.
        """
        return "%s.inmemory!%s" % (
            prefix,
            "".join(random.choice(string.ascii_letters) for i in range(12)),
        )

    # Expire cleanup

    def _clean_expired(self):
        """
        Goes through all messages and groups and removes those that are expired.
        Any channel with an expired message is removed from all groups.
        """
        # Channel cleanup
        for channel, queue in list(self.channels.items()):
            # See if it's expired
            # NOTE: peeks at asyncio.Queue's private _queue deque to read the
            # head item's expiry without consuming a live message.
            while not queue.empty() and queue._queue[0][0] < time.time():
                queue.get_nowait()
                # Any removal prompts group discard
                self._remove_from_groups(channel)
            # Is the channel now empty and needs deleting?
            if queue.empty():
                del self.channels[channel]
        # Group Expiration
        timeout = int(time.time()) - self.group_expiry
        for group in self.groups:
            for channel in list(self.groups.get(group, set())):
                # If join time is older than group_expiry end the group membership
                if (
                    self.groups[group][channel]
                    and int(self.groups[group][channel]) < timeout
                ):
                    # Delete from group
                    del self.groups[group][channel]

    # Flush extension

    async def flush(self):
        # Drop all pending messages and all group memberships.
        self.channels = {}
        self.groups = {}

    async def close(self):
        # Nothing to go
        pass

    def _remove_from_groups(self, channel):
        """
        Removes a channel from all groups. Used when a message on it expires.
        """
        for channels in self.groups.values():
            if channel in channels:
                del channels[channel]

    # Groups extension

    async def group_add(self, group, channel):
        """
        Adds the channel name to a group.
        """
        # Check the inputs
        assert self.valid_group_name(group), "Group name not valid"
        assert self.valid_channel_name(channel), "Channel name not valid"
        # Add to group dict
        self.groups.setdefault(group, {})
        # Store the join time so _clean_expired can expire the membership.
        self.groups[group][channel] = time.time()

    async def group_discard(self, group, channel):
        """
        Removes the channel from a group, deleting the group if it empties.
        """
        # Both should be text and valid
        assert self.valid_channel_name(channel), "Invalid channel name"
        assert self.valid_group_name(group), "Invalid group name"
        # Remove from group set
        if group in self.groups:
            if channel in self.groups[group]:
                del self.groups[group][channel]
            if not self.groups[group]:
                del self.groups[group]

    async def group_send(self, group, message):
        """
        Sends the message to every channel in the group (best-effort).
        """
        # Check types
        assert isinstance(message, dict), "Message is not a dict"
        assert self.valid_group_name(group), "Invalid group name"
        # Run clean
        self._clean_expired()
        # Send to each channel
        for channel in self.groups.get(group, set()):
            try:
                await self.send(channel, message)
            except ChannelFull:
                # Best-effort delivery: silently skip channels at capacity.
                pass
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
    """
    Returns a channel layer by alias, or None if it is not configured.
    """
    try:
        return channel_layers[alias]
    except KeyError:
        # Raised by the manager when the alias has no CHANNEL_LAYERS entry.
        return None


# Default global instance of the channel layer manager
channel_layers = ChannelLayerManager()

View File

@@ -0,0 +1,46 @@
import logging
from django.core.management import BaseCommand, CommandError
from channels import DEFAULT_CHANNEL_LAYER
from channels.layers import get_channel_layer
from channels.routing import get_default_application
from channels.worker import Worker
logger = logging.getLogger("django.channels.worker")
class Command(BaseCommand):
    """
    Management command that runs a Channels worker listening on the given
    channel names and dispatching received messages into the default
    application.
    """

    # Don't force locale handling while the worker runs.
    leave_locale_alone = True
    # Overridable worker implementation.
    worker_class = Worker

    def add_arguments(self, parser):
        super(Command, self).add_arguments(parser)
        parser.add_argument(
            "--layer",
            action="store",
            dest="layer",
            default=DEFAULT_CHANNEL_LAYER,
            help="Channel layer alias to use, if not the default.",
        )
        parser.add_argument("channels", nargs="+", help="Channels to listen on.")

    def handle(self, *args, **options):
        # Get the backend to use
        self.verbosity = options.get("verbosity", 1)
        # Get the channel layer they asked for (or see if one isn't configured)
        if "layer" in options:
            self.channel_layer = get_channel_layer(options["layer"])
        else:
            self.channel_layer = get_channel_layer()
        if self.channel_layer is None:
            raise CommandError("You do not have any CHANNEL_LAYERS configured.")
        # Run the worker
        logger.info("Running worker for channels %s", options["channels"])
        worker = self.worker_class(
            application=get_default_application(),
            channels=options["channels"],
            channel_layer=self.channel_layer,
        )
        worker.run()

View File

@@ -0,0 +1,24 @@
class BaseMiddleware:
    """
    Base class for implementing ASGI middleware.

    Note that subclasses of this are not self-safe; don't store state on
    the instance, as it serves multiple application instances. Instead, use
    scope.
    """

    def __init__(self, inner):
        """Store the wrapped inner ASGI application."""
        self.inner = inner

    async def __call__(self, scope, receive, send):
        """
        ASGI entry point: shallow-copy the scope so mutations made further
        down the stack don't leak upstream, then delegate to the wrapped
        application.
        """
        return await self.inner({**scope}, receive, send)

View File

@@ -0,0 +1,158 @@
import importlib
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.urls.exceptions import Resolver404
from django.urls.resolvers import URLResolver
"""
All Routing instances inside this file are also valid ASGI applications - with
new Channels routing, whatever you end up with as the top level object is just
served up as the "ASGI application".
"""
def get_default_application():
    """
    Load and return the application object named by the ASGI_APPLICATION
    setting (a dotted "module.attribute" path).

    Raises ImproperlyConfigured when the setting is missing, the module
    cannot be imported, or the attribute does not exist.
    """
    try:
        module_path, object_name = settings.ASGI_APPLICATION.rsplit(".", 1)
    except (ValueError, AttributeError):
        # Setting absent, or not a dotted path at all.
        raise ImproperlyConfigured("Cannot find ASGI_APPLICATION setting.")
    try:
        module = importlib.import_module(module_path)
    except ImportError:
        raise ImproperlyConfigured(
            "Cannot import ASGI_APPLICATION module %r" % module_path
        )
    try:
        return getattr(module, object_name)
    except AttributeError:
        raise ImproperlyConfigured(
            "Cannot find %r in ASGI_APPLICATION module %s" % (object_name, module_path)
        )
# Guidance printed when ProtocolTypeRouter is used without an "http" entry.
# The embedded example must be valid Python so users can copy it verbatim:
# the original text was missing the dict braces around the mapping.
DEPRECATION_MSG = """
Using ProtocolTypeRouter without an explicit "http" key is deprecated.
Given that you have not passed the "http" you likely should use Django's
get_asgi_application():

    from django.core.asgi import get_asgi_application

    application = ProtocolTypeRouter({
        "http": get_asgi_application(),
        # Other protocols here.
    })
"""
class ProtocolTypeRouter:
    """
    Takes a mapping of protocol type names to other Application instances,
    and dispatches to the right one based on protocol name (or raises an error)
    """

    def __init__(self, application_mapping):
        self.application_mapping = application_mapping

    async def __call__(self, scope, receive, send):
        # Guard clause: unknown protocol types are a hard error.
        scope_type = scope["type"]
        if scope_type not in self.application_mapping:
            raise ValueError(
                "No application configured for scope type %r" % scope_type
            )
        return await self.application_mapping[scope_type](scope, receive, send)
class URLRouter:
    """
    Routes to different applications/consumers based on the URL path.

    Works with anything that has a ``path`` key, but intended for WebSocket
    and HTTP. Uses Django's django.urls objects for resolution -
    path() or re_path().
    """

    #: This router wants to do routing based on scope[path] or
    #: scope[path_remaining]. ``path()`` entries in URLRouter should not be
    #: treated as endpoints (ended with ``$``), but similar to ``include()``.
    _path_routing = True

    def __init__(self, routes):
        self.routes = routes
        for route in self.routes:
            # The inner ASGI app wants to do additional routing, route
            # must not be an endpoint
            if getattr(route.callback, "_path_routing", False) is True:
                route.pattern._is_endpoint = False
            # A plain include() yields an URLResolver without a callback;
            # that cannot be dispatched to, so reject it up front.
            if not route.callback and isinstance(route, URLResolver):
                raise ImproperlyConfigured(
                    "%s: include() is not supported in URLRouter. Use nested"
                    " URLRouter instances instead." % (route,)
                )

    async def __call__(self, scope, receive, send):
        """
        Match the scope's path against the routes and dispatch to the first
        matching callback, threading any captured args/kwargs into
        scope["url_route"].
        """
        # Get the path. "path_remaining" is set by an outer URLRouter when
        # this router is nested; fall back to the raw connection path.
        path = scope.get("path_remaining", scope.get("path", None))
        if path is None:
            raise ValueError("No 'path' key in connection scope, cannot route URLs")
        # Remove leading / to match Django's handling
        path = path.lstrip("/")
        # Run through the routes we have until one matches
        for route in self.routes:
            try:
                match = route.pattern.match(path)
                if match:
                    new_path, args, kwargs = match
                    # Add defaults to kwargs from the URL pattern.
                    kwargs.update(route.default_args)
                    # Add args or kwargs into the scope, merging with any
                    # captures accumulated by outer routers.
                    outer = scope.get("url_route", {})
                    application = route.callback
                    return await application(
                        dict(
                            scope,
                            path_remaining=new_path,
                            url_route={
                                "args": outer.get("args", ()) + args,
                                "kwargs": {**outer.get("kwargs", {}), **kwargs},
                            },
                        ),
                        receive,
                        send,
                    )
            except Resolver404:
                pass
        else:
            # for/else: no route matched. A nested router raises Resolver404
            # so the outer router can keep trying its remaining routes.
            if "path_remaining" in scope:
                raise Resolver404("No route found for path %r." % path)
            # We are the outermost URLRouter
            raise ValueError("No route found for path %r." % path)
class ChannelNameRouter:
    """
    Maps to different applications based on a "channel" key in the scope
    (intended for the Channels worker mode)
    """

    def __init__(self, application_mapping):
        self.application_mapping = application_mapping

    async def __call__(self, scope, receive, send):
        # Guard clauses: require a channel key, and a registered handler.
        if "channel" not in scope:
            raise ValueError(
                "ChannelNameRouter got a scope without a 'channel' key. "
                "Did you make sure it's only being used for 'channel' type messages?"
            )
        channel = scope["channel"]
        if channel not in self.application_mapping:
            raise ValueError(
                "No application configured for channel name %r" % channel
            )
        return await self.application_mapping[channel](scope, receive, send)

View File

@@ -0,0 +1,153 @@
from urllib.parse import urlparse
from django.conf import settings
from django.http.request import is_same_domain
from ..generic.websocket import AsyncWebsocketConsumer
class OriginValidator:
    """
    Validates that the incoming connection has an Origin header that
    is in an allowed list.
    """

    def __init__(self, application, allowed_origins):
        # Wrapped ASGI application and the list of allowed origin patterns.
        self.application = application
        self.allowed_origins = allowed_origins

    async def __call__(self, scope, receive, send):
        """
        Check the connection's Origin header; forward to the wrapped
        application when allowed, otherwise close via WebsocketDenier.
        """
        # Make sure the scope is of type websocket
        if scope["type"] != "websocket":
            raise ValueError(
                "You cannot use OriginValidator on a non-WebSocket connection"
            )
        # Extract the Origin header. Note: no break, so if multiple origin
        # headers are present the last one wins.
        parsed_origin = None
        for header_name, header_value in scope.get("headers", []):
            if header_name == b"origin":
                try:
                    # Parse the header into a urllib ParseResult; a value
                    # that is not valid latin-1 leaves parsed_origin unset.
                    parsed_origin = urlparse(header_value.decode("latin1"))
                except UnicodeDecodeError:
                    pass
        # Check to see if the origin header is valid
        if self.valid_origin(parsed_origin):
            # Pass control to the application
            return await self.application(scope, receive, send)
        else:
            # Deny the connection
            denier = WebsocketDenier()
            return await denier(scope, receive, send)

    def valid_origin(self, parsed_origin):
        """
        Check whether the (possibly missing) parsed origin may connect.

        A missing/unparseable Origin header is only tolerated when "*" is in
        ``allowed_origins``; otherwise defer to ``validate_origin``.

        Returns ``True`` if the validation was successful, ``False`` otherwise.
        """
        # None is not allowed unless all hosts are allowed
        if parsed_origin is None and "*" not in self.allowed_origins:
            return False
        return self.validate_origin(parsed_origin)

    def validate_origin(self, parsed_origin):
        """
        Validate the given origin for this site.

        Check that the origin looks valid and matches an origin pattern in
        the ``allowed_origins`` list. A pattern may begin with a scheme.
        After the scheme there must be a domain. A domain beginning with a
        period corresponds to the domain and all its subdomains (for example,
        ``http://.example.com``). After the domain there may be a port, but
        it can be omitted. ``*`` matches anything; anything else must match
        exactly.

        Note: this function assumes that the given origin has a scheme,
        domain and port, where the port is optional.

        Returns ``True`` for a valid host, ``False`` otherwise.
        """
        return any(
            pattern == "*" or self.match_allowed_origin(parsed_origin, pattern)
            for pattern in self.allowed_origins
        )

    def match_allowed_origin(self, parsed_origin, pattern):
        """
        Returns ``True`` if the origin is either an exact match or a match
        to the wildcard pattern. Compares scheme, domain and port of origin
        and pattern.

        A pattern may begin with a scheme, followed by a domain, or be just
        a domain without a scheme. A domain beginning with a period matches
        the domain and all its subdomains (for example, ``.example.com``
        matches ``example.com`` and any subdomain; likewise with a scheme,
        ``http://.example.com`` matches ``http://example.com``). After the
        domain there may be a port, but it can be omitted.

        Note: this function assumes that the given origin is either None, a
        scheme-domain-port string, or just a domain string.
        """
        if parsed_origin is None:
            return False
        # Parse the pattern (lower-cased) into a ParseResult too.
        parsed_pattern = urlparse(pattern.lower())
        if parsed_origin.hostname is None:
            return False
        if not parsed_pattern.scheme:
            # Scheme-less pattern: compare hostnames only. The "//" prefix
            # makes urlparse treat the pattern as a netloc.
            pattern_hostname = urlparse("//" + pattern).hostname or pattern
            return is_same_domain(parsed_origin.hostname, pattern_hostname)
        # Get origin.port or default ports for origin or None
        origin_port = self.get_origin_port(parsed_origin)
        # Get pattern.port or default ports for pattern or None
        pattern_port = self.get_origin_port(parsed_pattern)
        # Compares hostname, scheme, ports of pattern and origin
        if (
            parsed_pattern.scheme == parsed_origin.scheme
            and origin_port == pattern_port
            and is_same_domain(parsed_origin.hostname, parsed_pattern.hostname)
        ):
            return True
        return False

    def get_origin_port(self, origin):
        """
        Returns origin.port, or the default port for the origin's scheme,
        or None when neither is known.
        """
        if origin.port is not None:
            # Return origin.port
            return origin.port
        # if origin.port doesn`t exists
        if origin.scheme == "http" or origin.scheme == "ws":
            # Default port return for http, ws
            return 80
        elif origin.scheme == "https" or origin.scheme == "wss":
            # Default port return for https, wss
            return 443
        else:
            return None
def AllowedHostsOriginValidator(application):
    """
    Factory function which returns an OriginValidator configured to use
    settings.ALLOWED_HOSTS.
    """
    # Mirror Django's runserver behaviour: with DEBUG on and no hosts
    # configured, permit the usual local development hosts.
    hosts = settings.ALLOWED_HOSTS
    if settings.DEBUG and not hosts:
        hosts = ["localhost", "127.0.0.1", "[::1]"]
    return OriginValidator(application, hosts)
class WebsocketDenier(AsyncWebsocketConsumer):
    """
    Simple application which denies all requests to it.
    """

    async def connect(self):
        # Immediately close the handshake instead of accepting it.
        await self.close()

View File

@@ -0,0 +1,268 @@
import datetime
import time
from importlib import import_module
from django.conf import settings
from django.contrib.sessions.backends.base import UpdateError
from django.core.exceptions import SuspiciousOperation
from django.http import parse_cookie
from django.http.cookie import SimpleCookie
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.functional import LazyObject
from channels.db import database_sync_to_async
try:
from django.utils.http import http_date
except ImportError:
from django.utils.http import cookie_date as http_date
class CookieMiddleware:
    """
    Extracts cookies from HTTP or WebSocket-style scopes and adds them as a
    scope["cookies"] entry with the same format as Django's request.COOKIES.
    """

    def __init__(self, inner):
        # The wrapped ASGI application.
        self.inner = inner

    async def __call__(self, scope, receive, send):
        """
        Parse the Cookie header (if present) into scope["cookies"] and call
        the inner application with the augmented scope.
        """
        # Check this actually has headers. They're a required scope key for HTTP and WS.
        if "headers" not in scope:
            raise ValueError(
                "CookieMiddleware was passed a scope that did not have a headers key "
                + "(make sure it is only passed HTTP or WebSocket connections)"
            )
        # Go through headers to find the cookie one
        for name, value in scope.get("headers", []):
            if name == b"cookie":
                cookies = parse_cookie(value.decode("latin1"))
                break
        else:
            # No cookie header found - add an empty default.
            cookies = {}
        # Return inner application
        return await self.inner(dict(scope, cookies=cookies), receive, send)

    @classmethod
    def set_cookie(
        cls,
        message,
        key,
        value="",
        max_age=None,
        expires=None,
        path="/",
        domain=None,
        secure=False,
        httponly=False,
        samesite="lax",
    ):
        """
        Sets a cookie in the passed HTTP response message.

        ``expires`` can be:
        - a string in the correct format,
        - a naive ``datetime.datetime`` object in UTC,
        - an aware ``datetime.datetime`` object in any time zone.
        If it is a ``datetime.datetime`` object then ``max_age`` will be calculated.

        The resulting Set-Cookie header(s) are appended to
        ``message["headers"]``.
        """
        value = force_str(value)
        cookies = SimpleCookie()
        cookies[key] = value
        if expires is not None:
            if isinstance(expires, datetime.datetime):
                # Normalise aware datetimes to naive UTC before arithmetic.
                if timezone.is_aware(expires):
                    expires = timezone.make_naive(expires, timezone.utc)
                # expires.utcnow() is datetime.datetime.utcnow() called via
                # the instance: delta is the time remaining until expiry.
                delta = expires - expires.utcnow()
                # Add one second so the date matches exactly (a fraction of
                # time gets lost between converting to a timedelta and
                # then the date string).
                delta = delta + datetime.timedelta(seconds=1)
                # Just set max_age - the max_age logic will set expires.
                expires = None
                max_age = max(0, delta.days * 86400 + delta.seconds)
            else:
                cookies[key]["expires"] = expires
        else:
            cookies[key]["expires"] = ""
        if max_age is not None:
            cookies[key]["max-age"] = max_age
            # IE requires expires, so set it if hasn't been already.
            if not expires:
                cookies[key]["expires"] = http_date(time.time() + max_age)
        if path is not None:
            cookies[key]["path"] = path
        if domain is not None:
            cookies[key]["domain"] = domain
        if secure:
            cookies[key]["secure"] = True
        if httponly:
            cookies[key]["httponly"] = True
        if samesite is not None:
            assert samesite.lower() in [
                "strict",
                "lax",
                "none",
            ], "samesite must be either 'strict', 'lax' or 'none'"
            cookies[key]["samesite"] = samesite
        # Write out the cookies to the response
        for c in cookies.values():
            message.setdefault("headers", []).append(
                (b"Set-Cookie", bytes(c.output(header=""), encoding="utf-8"))
            )

    @classmethod
    def delete_cookie(cls, message, key, path="/", domain=None):
        """
        Deletes a cookie in a response.
        """
        # Max-age 0 plus an already-passed expiry date forces removal.
        return cls.set_cookie(
            message,
            key,
            max_age=0,
            path=path,
            domain=domain,
            expires="Thu, 01-Jan-1970 00:00:00 GMT",
        )
class InstanceSessionWrapper:
    """
    Populates the session in application instance scope, and wraps send to save
    the session.
    """

    # Message types that trigger a session save if it's modified
    save_message_types = ["http.response.start"]

    # Message types that can carry session cookies back
    cookie_response_message_types = ["http.response.start"]

    def __init__(self, scope, send):
        self.cookie_name = settings.SESSION_COOKIE_NAME
        self.session_store = import_module(settings.SESSION_ENGINE).SessionStore
        # Work on a copy so changes don't leak to the caller's scope.
        self.scope = dict(scope)
        if "session" in self.scope:
            # There's already session middleware of some kind above us, pass
            # that through
            self.activated = False
        else:
            # Make sure there are cookies in the scope
            if "cookies" not in self.scope:
                raise ValueError(
                    "No cookies in scope - SessionMiddleware needs to run "
                    "inside of CookieMiddleware."
                )
            # Placeholder; resolve_session() fills in the real session later.
            self.scope["session"] = LazyObject()
            self.activated = True
        # Override send
        self.real_send = send

    async def resolve_session(self):
        """
        Load the session for the cookie's session key (if any) and attach it
        to the scope, running the blocking store construction off-loop.
        """
        session_key = self.scope["cookies"].get(self.cookie_name)
        self.scope["session"]._wrapped = await database_sync_to_async(
            self.session_store
        )(session_key)

    async def send(self, message):
        """
        Overridden send that also does session saves/cookies.
        """
        # Only save session if we're the outermost session middleware
        if self.activated:
            modified = self.scope["session"].modified
            empty = self.scope["session"].is_empty()
            # If this is a message type that we want to save on, and there's
            # changed data, save it. We also save if it's empty as we might
            # not be able to send a cookie-delete along with this message.
            if (
                message["type"] in self.save_message_types
                and message.get("status", 200) != 500
                and (modified or settings.SESSION_SAVE_EVERY_REQUEST)
            ):
                await database_sync_to_async(self.save_session)()
                # If this is a message type that can transport cookies back to the
                # client, then do so.
                if message["type"] in self.cookie_response_message_types:
                    if empty:
                        # Delete cookie if it's set
                        if settings.SESSION_COOKIE_NAME in self.scope["cookies"]:
                            CookieMiddleware.delete_cookie(
                                message,
                                settings.SESSION_COOKIE_NAME,
                                path=settings.SESSION_COOKIE_PATH,
                                domain=settings.SESSION_COOKIE_DOMAIN,
                            )
                    else:
                        # Get the expiry data
                        if self.scope["session"].get_expire_at_browser_close():
                            max_age = None
                            expires = None
                        else:
                            max_age = self.scope["session"].get_expiry_age()
                            expires_time = time.time() + max_age
                            expires = http_date(expires_time)
                        # Set the cookie
                        CookieMiddleware.set_cookie(
                            message,
                            self.cookie_name,
                            self.scope["session"].session_key,
                            max_age=max_age,
                            expires=expires,
                            domain=settings.SESSION_COOKIE_DOMAIN,
                            path=settings.SESSION_COOKIE_PATH,
                            secure=settings.SESSION_COOKIE_SECURE or None,
                            httponly=settings.SESSION_COOKIE_HTTPONLY or None,
                            samesite=settings.SESSION_COOKIE_SAMESITE,
                        )
        # Pass up the send
        return await self.real_send(message)

    def save_session(self):
        """
        Saves the current session.
        """
        try:
            self.scope["session"].save()
        except UpdateError:
            raise SuspiciousOperation(
                "The request's session was deleted before the "
                "request completed. The user may have logged "
                "out in a concurrent request, for example."
            )
class SessionMiddleware:
    """
    Class that adds Django sessions (from HTTP cookies) to the
    scope. Works with HTTP or WebSocket protocol types (or anything that
    provides a "headers" entry in the scope).

    Requires the CookieMiddleware to be higher up in the stack.
    """

    def __init__(self, inner):
        self.inner = inner

    async def __call__(self, scope, receive, send):
        """
        Instantiate a session wrapper for this scope, resolve the session and
        call the inner application.
        """
        session_wrapper = InstanceSessionWrapper(scope, send)
        await session_wrapper.resolve_session()
        # The wrapper supplies both the session-augmented scope and a send
        # callable that persists the session on response.
        return await self.inner(
            session_wrapper.scope, receive, session_wrapper.send
        )
def SessionMiddlewareStack(inner):
    """
    Shortcut: layer CookieMiddleware outside SessionMiddleware, since
    SessionMiddleware needs cookies already parsed into the scope.
    """
    session_layer = SessionMiddleware(inner)
    return CookieMiddleware(session_layer)

View File

@@ -0,0 +1,12 @@
from asgiref.testing import ApplicationCommunicator # noqa
from .http import HttpCommunicator # noqa
from .live import ChannelsLiveServerTestCase # noqa
from .websocket import WebsocketCommunicator # noqa
__all__ = [
"ApplicationCommunicator",
"HttpCommunicator",
"ChannelsLiveServerTestCase",
"WebsocketCommunicator",
]

View File

@@ -0,0 +1,56 @@
from urllib.parse import unquote, urlparse
from asgiref.testing import ApplicationCommunicator
class HttpCommunicator(ApplicationCommunicator):
    """
    ApplicationCommunicator subclass that has HTTP shortcut methods.

    It will construct the scope for you, so you need to pass the application
    (uninstantiated) along with HTTP parameters.

    This does not support full chunking - for that, just use ApplicationCommunicator
    directly.
    """

    def __init__(self, application, method, path, body=b"", headers=None):
        # Split the path into path + query string for the ASGI scope.
        parsed = urlparse(path)
        self.scope = {
            "type": "http",
            "http_version": "1.1",
            "method": method.upper(),
            "path": unquote(parsed.path),
            "query_string": parsed.query.encode("utf-8"),
            "headers": headers or [],
        }
        assert isinstance(body, bytes)
        self.body = body
        self.sent_request = False
        super().__init__(application, self.scope)

    async def get_response(self, timeout=1):
        """
        Get the application's response. Returns a dict with keys of
        "body", "headers" and "status".
        """
        # Send the request once, on first call.
        if not self.sent_request:
            self.sent_request = True
            await self.send_input({"type": "http.request", "body": self.body})
        # The first output event carries status/headers.
        response_start = await self.receive_output(timeout)
        assert response_start["type"] == "http.response.start"
        # Collect body chunks until the app signals no more are coming.
        chunks = []
        more_body = True
        while more_body:
            chunk = await self.receive_output(timeout)
            assert chunk["type"] == "http.response.body"
            assert isinstance(chunk["body"], bytes)
            chunks.append(chunk["body"])
            more_body = chunk.get("more_body", False)
        response_start["body"] = b"".join(chunks)
        # Return structured info
        del response_start["type"]
        response_start.setdefault("headers", [])
        return response_start

View File

@@ -0,0 +1,76 @@
from functools import partial
from daphne.testing import DaphneProcess
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.exceptions import ImproperlyConfigured
from django.db import connections
from django.test.testcases import TransactionTestCase
from django.test.utils import modify_settings
from channels.routing import get_default_application
def make_application(*, static_wrapper):
    """
    Build the ASGI application served by the live-server subprocess,
    optionally wrapped for static-file serving.
    """
    # Module-level function (not a lambda/closure) so it stays picklable
    # when handed to the Daphne server process.
    application = get_default_application()
    return application if static_wrapper is None else static_wrapper(application)
class ChannelsLiveServerTestCase(TransactionTestCase):
    """
    Does basically the same as TransactionTestCase but also launches a
    live Daphne server in a separate process, so
    that the tests may use another test framework, such as Selenium,
    instead of the built-in dummy client.
    """

    # Host the live server binds to; appended to ALLOWED_HOSTS during setup.
    host = "localhost"
    # Server-process class; override to customise how the server is launched.
    ProtocolServerProcess = DaphneProcess
    # Wrapper applied around the application when serve_static is True.
    static_wrapper = ASGIStaticFilesHandler
    # Whether to serve static files via static_wrapper.
    serve_static = True

    @property
    def live_server_url(self):
        # HTTP base URL of the running live server.
        return "http://%s:%s" % (self.host, self._port)

    @property
    def live_server_ws_url(self):
        # WebSocket base URL of the running live server.
        return "ws://%s:%s" % (self.host, self._port)

    def _pre_setup(self):
        # An in-memory SQLite database cannot be shared with the server
        # subprocess, so refuse to start in that configuration.
        for connection in connections.all():
            if self._is_in_memory_db(connection):
                raise ImproperlyConfigured(
                    "ChannelLiveServerTestCase can not be used with in memory databases"
                )
        super(ChannelsLiveServerTestCase, self)._pre_setup()
        self._live_server_modified_settings = modify_settings(
            ALLOWED_HOSTS={"append": self.host}
        )
        self._live_server_modified_settings.enable()
        # partial() keeps make_application picklable for the child process.
        get_application = partial(
            make_application,
            static_wrapper=self.static_wrapper if self.serve_static else None,
        )
        self._server_process = self.ProtocolServerProcess(self.host, get_application)
        self._server_process.start()
        # Block until the subprocess reports it is listening, then read the
        # port it actually bound.
        self._server_process.ready.wait()
        self._port = self._server_process.port.value

    def _post_teardown(self):
        self._server_process.terminate()
        self._server_process.join()
        self._live_server_modified_settings.disable()
        super(ChannelsLiveServerTestCase, self)._post_teardown()

    def _is_in_memory_db(self, connection):
        """
        Check if DatabaseWrapper holds in memory database.
        """
        # NOTE(review): implicitly returns None (falsy) for non-sqlite
        # vendors; an explicit `return False` would be clearer.
        if connection.vendor == "sqlite":
            return connection.is_in_memory_db()

View File

@@ -0,0 +1,102 @@
import json
from urllib.parse import unquote, urlparse
from asgiref.testing import ApplicationCommunicator
class WebsocketCommunicator(ApplicationCommunicator):
    """
    ApplicationCommunicator subclass that has WebSocket shortcut methods.

    It will construct the scope for you, so you need to pass the application
    (uninstantiated) along with the initial connection parameters.
    """

    def __init__(self, application, path, headers=None, subprotocols=None):
        if not isinstance(path, str):
            raise TypeError("Expected str, got {}".format(type(path)))
        parsed = urlparse(path)
        # Build a minimal ASGI websocket scope from the URL and kwargs.
        self.scope = {
            "type": "websocket",
            "path": unquote(parsed.path),
            "query_string": parsed.query.encode("utf-8"),
            "headers": headers or [],
            "subprotocols": subprotocols or [],
        }
        super().__init__(application, self.scope)

    async def connect(self, timeout=1):
        """
        Trigger the connection code.

        On an accepted connection, returns (True, <chosen-subprotocol>)
        On a rejected connection, returns (False, <close-code>)
        """
        await self.send_input({"type": "websocket.connect"})
        response = await self.receive_output(timeout)
        if response["type"] == "websocket.close":
            return (False, response.get("code", 1000))
        else:
            return (True, response.get("subprotocol", None))

    async def send_to(self, text_data=None, bytes_data=None):
        """
        Sends a WebSocket frame to the application.
        """
        # Make sure we have exactly one of the arguments
        # NOTE(review): this uses truthiness, not `is None`, so empty
        # payloads ("" or b"") are also rejected — confirm that's intended.
        assert bool(text_data) != bool(
            bytes_data
        ), "You must supply exactly one of text_data or bytes_data"
        # Send the right kind of event
        if text_data:
            assert isinstance(text_data, str), "The text_data argument must be a str"
            await self.send_input({"type": "websocket.receive", "text": text_data})
        else:
            assert isinstance(
                bytes_data, bytes
            ), "The bytes_data argument must be bytes"
            await self.send_input({"type": "websocket.receive", "bytes": bytes_data})

    async def send_json_to(self, data):
        """
        Sends JSON data as a text frame
        """
        await self.send_to(text_data=json.dumps(data))

    async def receive_from(self, timeout=1):
        """
        Receives a data frame from the view. Will fail if the connection
        closes instead. Returns either a bytestring or a unicode string
        depending on what sort of frame you got.
        """
        response = await self.receive_output(timeout)
        # Make sure this is a send message
        assert response["type"] == "websocket.send"
        # Make sure there's exactly one key in the response
        assert ("text" in response) != (
            "bytes" in response
        ), "The response needs exactly one of 'text' or 'bytes'"
        # Pull out the right key and typecheck it for our users
        if "text" in response:
            assert isinstance(response["text"], str), "Text frame payload is not str"
            return response["text"]
        else:
            assert isinstance(
                response["bytes"], bytes
            ), "Binary frame payload is not bytes"
            return response["bytes"]

    async def receive_json_from(self, timeout=1):
        """
        Receives a JSON text frame payload and decodes it
        """
        payload = await self.receive_from(timeout)
        assert isinstance(payload, str), "JSON data is not a text frame"
        return json.loads(payload)

    async def disconnect(self, code=1000, timeout=1):
        """
        Closes the socket
        """
        await self.send_input({"type": "websocket.disconnect", "code": code})
        # Wait for the application instance to finish shutting down.
        await self.wait(timeout)

View File

@@ -0,0 +1,59 @@
import asyncio
import types
def name_that_thing(thing):
    """
    Return a dotted-path name for *thing* (function/class path), falling
    back to the object's repr.
    """
    # Python 2 style bound methods expose im_class; kept for compatibility.
    if hasattr(thing, "im_class"):
        # Mocks will recurse im_class forever
        if hasattr(thing, "mock_calls"):
            return "<mock>"
        return name_that_thing(thing.im_class) + "." + thing.im_func.func_name
    # Other named thing
    if hasattr(thing, "__name__"):
        plain_callable = isinstance(thing, (types.FunctionType, types.MethodType))
        if hasattr(thing, "__class__") and not plain_callable:
            cls = thing.__class__
            # Named instances of non-metaclass types: name the class instead.
            if cls is not type and not issubclass(cls, type):
                return name_that_thing(cls)
        if hasattr(thing, "__self__"):
            return "%s.%s" % (thing.__self__.__module__, thing.__self__.__name__)
        if hasattr(thing, "__module__"):
            return "%s.%s" % (thing.__module__, thing.__name__)
    # Generic instance of a class
    if hasattr(thing, "__class__"):
        return name_that_thing(thing.__class__)
    return repr(thing)
async def await_many_dispatch(consumer_callables, dispatch):
    """
    Given a set of consumer callables, awaits on them all and passes results
    from them to the dispatch awaitable as they come in.
    """
    # Wrap each callable's coroutine in a Future so completion can be polled.
    tasks = [
        asyncio.ensure_future(consumer_callable())
        for consumer_callable in consumer_callables
    ]
    try:
        while True:
            # Block until at least one consumer has produced a result.
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for index, task in enumerate(tasks):
                if not task.done():
                    continue
                # Hand the result to dispatch, then restart that consumer
                # so it keeps listening.
                await dispatch(task.result())
                tasks[index] = asyncio.ensure_future(consumer_callables[index]())
    finally:
        # Make sure we clean up tasks on exit
        for task in tasks:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

View File

@@ -0,0 +1,44 @@
import asyncio
from asgiref.server import StatelessServer
class Worker(StatelessServer):
    """
    ASGI protocol server that surfaces events sent to specific channels
    on the channel layer into a single application instance.
    """

    def __init__(self, application, channels, channel_layer, max_applications=1000):
        """
        application: ASGI application messages are dispatched into.
        channels: iterable of channel names to listen on.
        channel_layer: configured channel layer instance (must not be None).
        max_applications: cap on concurrent application instances
            (passed through to StatelessServer).
        """
        super().__init__(application, max_applications)
        self.channels = channels
        self.channel_layer = channel_layer
        if self.channel_layer is None:
            raise ValueError("Channel layer is not valid")

    async def handle(self):
        """
        Listens on all the provided channels and handles the messages.
        """
        # For each channel, launch its own listening coroutine
        listeners = []
        for channel in self.channels:
            listeners.append(asyncio.ensure_future(self.listener(channel)))
        # Wait for them all to exit
        await asyncio.wait(listeners)
        # See if any of the listeners had an error (e.g. channel layer error)
        # (result() re-raises any exception stored on a finished future)
        [listener.result() for listener in listeners]

    async def listener(self, channel):
        """
        Single-channel listener
        """
        while True:
            message = await self.channel_layer.receive(channel)
            # Every worker message must carry a non-empty "type".
            if not message.get("type", None):
                raise ValueError("Worker received message with no type.")
            # Make a scope and get an application instance for it
            scope = {"type": "channel", "channel": channel}
            instance_queue = self.get_or_create_application_instance(channel, scope)
            # Run the message into the app
            await instance_queue.put(message)