mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-22 11:11:10 -05:00
okay fine
This commit is contained in:
4
.venv/lib/python3.12/site-packages/channels/__init__.py
Normal file
4
.venv/lib/python3.12/site-packages/channels/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
__version__ = "4.0.0"
|
||||
|
||||
|
||||
DEFAULT_CHANNEL_LAYER = "default"
|
||||
7
.venv/lib/python3.12/site-packages/channels/apps.py
Normal file
7
.venv/lib/python3.12/site-packages/channels/apps.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class ChannelsConfig(AppConfig):
|
||||
|
||||
name = "channels"
|
||||
verbose_name = "Channels"
|
||||
190
.venv/lib/python3.12/site-packages/channels/auth.py
Normal file
190
.venv/lib/python3.12/site-packages/channels/auth.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import (
|
||||
BACKEND_SESSION_KEY,
|
||||
HASH_SESSION_KEY,
|
||||
SESSION_KEY,
|
||||
_get_backends,
|
||||
get_user_model,
|
||||
load_backend,
|
||||
user_logged_in,
|
||||
user_logged_out,
|
||||
)
|
||||
from django.utils.crypto import constant_time_compare
|
||||
from django.utils.functional import LazyObject
|
||||
|
||||
from channels.db import database_sync_to_async
|
||||
from channels.middleware import BaseMiddleware
|
||||
from channels.sessions import CookieMiddleware, SessionMiddleware
|
||||
|
||||
|
||||
@database_sync_to_async
|
||||
def get_user(scope):
|
||||
"""
|
||||
Return the user model instance associated with the given scope.
|
||||
If no user is retrieved, return an instance of `AnonymousUser`.
|
||||
"""
|
||||
# postpone model import to avoid ImproperlyConfigured error before Django
|
||||
# setup is complete.
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
if "session" not in scope:
|
||||
raise ValueError(
|
||||
"Cannot find session in scope. You should wrap your consumer in "
|
||||
"SessionMiddleware."
|
||||
)
|
||||
session = scope["session"]
|
||||
user = None
|
||||
try:
|
||||
user_id = _get_user_session_key(session)
|
||||
backend_path = session[BACKEND_SESSION_KEY]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if backend_path in settings.AUTHENTICATION_BACKENDS:
|
||||
backend = load_backend(backend_path)
|
||||
user = backend.get_user(user_id)
|
||||
# Verify the session
|
||||
if hasattr(user, "get_session_auth_hash"):
|
||||
session_hash = session.get(HASH_SESSION_KEY)
|
||||
session_hash_verified = session_hash and constant_time_compare(
|
||||
session_hash, user.get_session_auth_hash()
|
||||
)
|
||||
if not session_hash_verified:
|
||||
session.flush()
|
||||
user = None
|
||||
return user or AnonymousUser()
|
||||
|
||||
|
||||
@database_sync_to_async
|
||||
def login(scope, user, backend=None):
|
||||
"""
|
||||
Persist a user id and a backend in the request.
|
||||
This way a user doesn't have to re-authenticate on every request.
|
||||
Note that data set during the anonymous session is retained when the user
|
||||
logs in.
|
||||
"""
|
||||
if "session" not in scope:
|
||||
raise ValueError(
|
||||
"Cannot find session in scope. You should wrap your consumer in "
|
||||
"SessionMiddleware."
|
||||
)
|
||||
session = scope["session"]
|
||||
session_auth_hash = ""
|
||||
if user is None:
|
||||
user = scope.get("user", None)
|
||||
if user is None:
|
||||
raise ValueError(
|
||||
"User must be passed as an argument or must be present in the scope."
|
||||
)
|
||||
if hasattr(user, "get_session_auth_hash"):
|
||||
session_auth_hash = user.get_session_auth_hash()
|
||||
if SESSION_KEY in session:
|
||||
if _get_user_session_key(session) != user.pk or (
|
||||
session_auth_hash
|
||||
and not constant_time_compare(
|
||||
session.get(HASH_SESSION_KEY, ""), session_auth_hash
|
||||
)
|
||||
):
|
||||
# To avoid reusing another user's session, create a new, empty
|
||||
# session if the existing session corresponds to a different
|
||||
# authenticated user.
|
||||
session.flush()
|
||||
else:
|
||||
session.cycle_key()
|
||||
try:
|
||||
backend = backend or user.backend
|
||||
except AttributeError:
|
||||
backends = _get_backends(return_tuples=True)
|
||||
if len(backends) == 1:
|
||||
_, backend = backends[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have multiple authentication backends configured and "
|
||||
"therefore must provide the `backend` "
|
||||
"argument or set the `backend` attribute on the user."
|
||||
)
|
||||
session[SESSION_KEY] = user._meta.pk.value_to_string(user)
|
||||
session[BACKEND_SESSION_KEY] = backend
|
||||
session[HASH_SESSION_KEY] = session_auth_hash
|
||||
scope["user"] = user
|
||||
# note this does not reset the CSRF_COOKIE/Token
|
||||
user_logged_in.send(sender=user.__class__, request=None, user=user)
|
||||
|
||||
|
||||
@database_sync_to_async
|
||||
def logout(scope):
|
||||
"""
|
||||
Remove the authenticated user's ID from the request and flush their session
|
||||
data.
|
||||
"""
|
||||
# postpone model import to avoid ImproperlyConfigured error before Django
|
||||
# setup is complete.
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
if "session" not in scope:
|
||||
raise ValueError(
|
||||
"Login cannot find session in scope. You should wrap your "
|
||||
"consumer in SessionMiddleware."
|
||||
)
|
||||
session = scope["session"]
|
||||
# Dispatch the signal before the user is logged out so the receivers have a
|
||||
# chance to find out *who* logged out.
|
||||
user = scope.get("user", None)
|
||||
if hasattr(user, "is_authenticated") and not user.is_authenticated:
|
||||
user = None
|
||||
if user is not None:
|
||||
user_logged_out.send(sender=user.__class__, request=None, user=user)
|
||||
session.flush()
|
||||
if "user" in scope:
|
||||
scope["user"] = AnonymousUser()
|
||||
|
||||
|
||||
def _get_user_session_key(session):
|
||||
# This value in the session is always serialized to a string, so we need
|
||||
# to convert it back to Python whenever we access it.
|
||||
return get_user_model()._meta.pk.to_python(session[SESSION_KEY])
|
||||
|
||||
|
||||
class UserLazyObject(LazyObject):
|
||||
"""
|
||||
Throw a more useful error message when scope['user'] is accessed before
|
||||
it's resolved
|
||||
"""
|
||||
|
||||
def _setup(self):
|
||||
raise ValueError("Accessing scope user before it is ready.")
|
||||
|
||||
|
||||
class AuthMiddleware(BaseMiddleware):
|
||||
"""
|
||||
Middleware which populates scope["user"] from a Django session.
|
||||
Requires SessionMiddleware to function.
|
||||
"""
|
||||
|
||||
def populate_scope(self, scope):
|
||||
# Make sure we have a session
|
||||
if "session" not in scope:
|
||||
raise ValueError(
|
||||
"AuthMiddleware cannot find session in scope. "
|
||||
"SessionMiddleware must be above it."
|
||||
)
|
||||
# Add it to the scope if it's not there already
|
||||
if "user" not in scope:
|
||||
scope["user"] = UserLazyObject()
|
||||
|
||||
async def resolve_scope(self, scope):
|
||||
scope["user"]._wrapped = await get_user(scope)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
scope = dict(scope)
|
||||
# Scope injection/mutation per this middleware's needs.
|
||||
self.populate_scope(scope)
|
||||
# Grab the finalized/resolved scope
|
||||
await self.resolve_scope(scope)
|
||||
|
||||
return await super().__call__(scope, receive, send)
|
||||
|
||||
|
||||
# Handy shortcut for applying all three layers at once
|
||||
def AuthMiddlewareStack(inner):
|
||||
return CookieMiddleware(SessionMiddleware(AuthMiddleware(inner)))
|
||||
133
.venv/lib/python3.12/site-packages/channels/consumer.py
Normal file
133
.venv/lib/python3.12/site-packages/channels/consumer.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import functools
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
|
||||
from . import DEFAULT_CHANNEL_LAYER
|
||||
from .db import database_sync_to_async
|
||||
from .exceptions import StopConsumer
|
||||
from .layers import get_channel_layer
|
||||
from .utils import await_many_dispatch
|
||||
|
||||
|
||||
def get_handler_name(message):
|
||||
"""
|
||||
Looks at a message, checks it has a sensible type, and returns the
|
||||
handler name for that type.
|
||||
"""
|
||||
# Check message looks OK
|
||||
if "type" not in message:
|
||||
raise ValueError("Incoming message has no 'type' attribute")
|
||||
# Extract type and replace . with _
|
||||
handler_name = message["type"].replace(".", "_")
|
||||
if handler_name.startswith("_"):
|
||||
raise ValueError("Malformed type in message (leading underscore)")
|
||||
return handler_name
|
||||
|
||||
|
||||
class AsyncConsumer:
|
||||
"""
|
||||
Base consumer class. Implements the ASGI application spec, and adds on
|
||||
channel layer management and routing of events to named methods based
|
||||
on their type.
|
||||
"""
|
||||
|
||||
_sync = False
|
||||
channel_layer_alias = DEFAULT_CHANNEL_LAYER
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""
|
||||
Dispatches incoming messages to type-based handlers asynchronously.
|
||||
"""
|
||||
self.scope = scope
|
||||
|
||||
# Initialize channel layer
|
||||
self.channel_layer = get_channel_layer(self.channel_layer_alias)
|
||||
if self.channel_layer is not None:
|
||||
self.channel_name = await self.channel_layer.new_channel()
|
||||
self.channel_receive = functools.partial(
|
||||
self.channel_layer.receive, self.channel_name
|
||||
)
|
||||
# Store send function
|
||||
if self._sync:
|
||||
self.base_send = async_to_sync(send)
|
||||
else:
|
||||
self.base_send = send
|
||||
# Pass messages in from channel layer or client to dispatch method
|
||||
try:
|
||||
if self.channel_layer is not None:
|
||||
await await_many_dispatch(
|
||||
[receive, self.channel_receive], self.dispatch
|
||||
)
|
||||
else:
|
||||
await await_many_dispatch([receive], self.dispatch)
|
||||
except StopConsumer:
|
||||
# Exit cleanly
|
||||
pass
|
||||
|
||||
async def dispatch(self, message):
|
||||
"""
|
||||
Works out what to do with a message.
|
||||
"""
|
||||
handler = getattr(self, get_handler_name(message), None)
|
||||
if handler:
|
||||
await handler(message)
|
||||
else:
|
||||
raise ValueError("No handler for message type %s" % message["type"])
|
||||
|
||||
async def send(self, message):
|
||||
"""
|
||||
Overrideable/callable-by-subclasses send method.
|
||||
"""
|
||||
await self.base_send(message)
|
||||
|
||||
@classmethod
|
||||
def as_asgi(cls, **initkwargs):
|
||||
"""
|
||||
Return an ASGI v3 single callable that instantiates a consumer instance
|
||||
per scope. Similar in purpose to Django's as_view().
|
||||
|
||||
initkwargs will be used to instantiate the consumer instance.
|
||||
"""
|
||||
|
||||
async def app(scope, receive, send):
|
||||
consumer = cls(**initkwargs)
|
||||
return await consumer(scope, receive, send)
|
||||
|
||||
app.consumer_class = cls
|
||||
app.consumer_initkwargs = initkwargs
|
||||
|
||||
# take name and docstring from class
|
||||
functools.update_wrapper(app, cls, updated=())
|
||||
return app
|
||||
|
||||
|
||||
class SyncConsumer(AsyncConsumer):
|
||||
"""
|
||||
Synchronous version of the consumer, which is what we write most of the
|
||||
generic consumers against (for now). Calls handlers in a threadpool and
|
||||
uses CallBouncer to get the send method out to the main event loop.
|
||||
|
||||
It would have been possible to have "mixed" consumers and auto-detect
|
||||
if a handler was awaitable or not, but that would have made the API
|
||||
for user-called methods very confusing as there'd be two types of each.
|
||||
"""
|
||||
|
||||
_sync = True
|
||||
|
||||
@database_sync_to_async
|
||||
def dispatch(self, message):
|
||||
"""
|
||||
Dispatches incoming messages to type-based handlers asynchronously.
|
||||
"""
|
||||
# Get and execute the handler
|
||||
handler = getattr(self, get_handler_name(message), None)
|
||||
if handler:
|
||||
handler(message)
|
||||
else:
|
||||
raise ValueError("No handler for message type %s" % message["type"])
|
||||
|
||||
def send(self, message):
|
||||
"""
|
||||
Overrideable/callable-by-subclasses send method.
|
||||
"""
|
||||
self.base_send(message)
|
||||
19
.venv/lib/python3.12/site-packages/channels/db.py
Normal file
19
.venv/lib/python3.12/site-packages/channels/db.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from asgiref.sync import SyncToAsync
|
||||
from django.db import close_old_connections
|
||||
|
||||
|
||||
class DatabaseSyncToAsync(SyncToAsync):
|
||||
"""
|
||||
SyncToAsync version that cleans up old database connections when it exits.
|
||||
"""
|
||||
|
||||
def thread_handler(self, loop, *args, **kwargs):
|
||||
close_old_connections()
|
||||
try:
|
||||
return super().thread_handler(loop, *args, **kwargs)
|
||||
finally:
|
||||
close_old_connections()
|
||||
|
||||
|
||||
# The class is TitleCased, but we want to encourage use as a callable/decorator
|
||||
database_sync_to_async = DatabaseSyncToAsync
|
||||
65
.venv/lib/python3.12/site-packages/channels/exceptions.py
Normal file
65
.venv/lib/python3.12/site-packages/channels/exceptions.py
Normal file
@@ -0,0 +1,65 @@
|
||||
class RequestAborted(Exception):
|
||||
"""
|
||||
Raised when the incoming request tells us it's aborted partway through
|
||||
reading the body.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RequestTimeout(RequestAborted):
|
||||
"""
|
||||
Aborted specifically due to timeout.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidChannelLayerError(ValueError):
|
||||
"""
|
||||
Raised when a channel layer is configured incorrectly.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AcceptConnection(Exception):
|
||||
"""
|
||||
Raised during a websocket.connect (or other supported connection) handler
|
||||
to accept the connection.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DenyConnection(Exception):
|
||||
"""
|
||||
Raised during a websocket.connect (or other supported connection) handler
|
||||
to deny the connection.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChannelFull(Exception):
|
||||
"""
|
||||
Raised when a channel cannot be sent to as it is over capacity.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MessageTooLarge(Exception):
|
||||
"""
|
||||
Raised when a message cannot be sent as it's too big.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StopConsumer(Exception):
|
||||
"""
|
||||
Raised when a consumer wants to stop and close down its application instance.
|
||||
"""
|
||||
|
||||
pass
|
||||
91
.venv/lib/python3.12/site-packages/channels/generic/http.py
Normal file
91
.venv/lib/python3.12/site-packages/channels/generic/http.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from channels.consumer import AsyncConsumer
|
||||
|
||||
from ..exceptions import StopConsumer
|
||||
|
||||
|
||||
class AsyncHttpConsumer(AsyncConsumer):
|
||||
"""
|
||||
Async HTTP consumer. Provides basic primitives for building asynchronous
|
||||
HTTP endpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.body = []
|
||||
|
||||
async def send_headers(self, *, status=200, headers=None):
|
||||
"""
|
||||
Sets the HTTP response status and headers. Headers may be provided as
|
||||
a list of tuples or as a dictionary.
|
||||
|
||||
Note that the ASGI spec requires that the protocol server only starts
|
||||
sending the response to the client after ``self.send_body`` has been
|
||||
called the first time.
|
||||
"""
|
||||
if headers is None:
|
||||
headers = []
|
||||
elif isinstance(headers, dict):
|
||||
headers = list(headers.items())
|
||||
|
||||
await self.send(
|
||||
{"type": "http.response.start", "status": status, "headers": headers}
|
||||
)
|
||||
|
||||
async def send_body(self, body, *, more_body=False):
|
||||
"""
|
||||
Sends a response body to the client. The method expects a bytestring.
|
||||
|
||||
Set ``more_body=True`` if you want to send more body content later.
|
||||
The default behavior closes the response, and further messages on
|
||||
the channel will be ignored.
|
||||
"""
|
||||
assert isinstance(body, bytes), "Body is not bytes"
|
||||
await self.send(
|
||||
{"type": "http.response.body", "body": body, "more_body": more_body}
|
||||
)
|
||||
|
||||
async def send_response(self, status, body, **kwargs):
|
||||
"""
|
||||
Sends a response to the client. This is a thin wrapper over
|
||||
``self.send_headers`` and ``self.send_body``, and everything said
|
||||
above applies here as well. This method may only be called once.
|
||||
"""
|
||||
await self.send_headers(status=status, **kwargs)
|
||||
await self.send_body(body)
|
||||
|
||||
async def handle(self, body):
|
||||
"""
|
||||
Receives the request body as a bytestring. Response may be composed
|
||||
using the ``self.send*`` methods; the return value of this method is
|
||||
thrown away.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Subclasses of AsyncHttpConsumer must provide a handle() method."
|
||||
)
|
||||
|
||||
async def disconnect(self):
|
||||
"""
|
||||
Overrideable place to run disconnect handling. Do not send anything
|
||||
from here.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def http_request(self, message):
|
||||
"""
|
||||
Async entrypoint - concatenates body fragments and hands off control
|
||||
to ``self.handle`` when the body has been completely received.
|
||||
"""
|
||||
if "body" in message:
|
||||
self.body.append(message["body"])
|
||||
if not message.get("more_body"):
|
||||
try:
|
||||
await self.handle(b"".join(self.body))
|
||||
finally:
|
||||
await self.disconnect()
|
||||
raise StopConsumer()
|
||||
|
||||
async def http_disconnect(self, message):
|
||||
"""
|
||||
Let the user do their cleanup and close the consumer.
|
||||
"""
|
||||
await self.disconnect()
|
||||
raise StopConsumer()
|
||||
279
.venv/lib/python3.12/site-packages/channels/generic/websocket.py
Normal file
279
.venv/lib/python3.12/site-packages/channels/generic/websocket.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import json
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
|
||||
from ..consumer import AsyncConsumer, SyncConsumer
|
||||
from ..exceptions import (
|
||||
AcceptConnection,
|
||||
DenyConnection,
|
||||
InvalidChannelLayerError,
|
||||
StopConsumer,
|
||||
)
|
||||
|
||||
|
||||
class WebsocketConsumer(SyncConsumer):
|
||||
"""
|
||||
Base WebSocket consumer. Provides a general encapsulation for the
|
||||
WebSocket handling model that other applications can build on.
|
||||
"""
|
||||
|
||||
groups = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if self.groups is None:
|
||||
self.groups = []
|
||||
|
||||
def websocket_connect(self, message):
|
||||
"""
|
||||
Called when a WebSocket connection is opened.
|
||||
"""
|
||||
try:
|
||||
for group in self.groups:
|
||||
async_to_sync(self.channel_layer.group_add)(group, self.channel_name)
|
||||
except AttributeError:
|
||||
raise InvalidChannelLayerError(
|
||||
"BACKEND is unconfigured or doesn't support groups"
|
||||
)
|
||||
try:
|
||||
self.connect()
|
||||
except AcceptConnection:
|
||||
self.accept()
|
||||
except DenyConnection:
|
||||
self.close()
|
||||
|
||||
def connect(self):
|
||||
self.accept()
|
||||
|
||||
def accept(self, subprotocol=None):
|
||||
"""
|
||||
Accepts an incoming socket
|
||||
"""
|
||||
super().send({"type": "websocket.accept", "subprotocol": subprotocol})
|
||||
|
||||
def websocket_receive(self, message):
|
||||
"""
|
||||
Called when a WebSocket frame is received. Decodes it and passes it
|
||||
to receive().
|
||||
"""
|
||||
if "text" in message:
|
||||
self.receive(text_data=message["text"])
|
||||
else:
|
||||
self.receive(bytes_data=message["bytes"])
|
||||
|
||||
def receive(self, text_data=None, bytes_data=None):
|
||||
"""
|
||||
Called with a decoded WebSocket frame.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send(self, text_data=None, bytes_data=None, close=False):
|
||||
"""
|
||||
Sends a reply back down the WebSocket
|
||||
"""
|
||||
if text_data is not None:
|
||||
super().send({"type": "websocket.send", "text": text_data})
|
||||
elif bytes_data is not None:
|
||||
super().send({"type": "websocket.send", "bytes": bytes_data})
|
||||
else:
|
||||
raise ValueError("You must pass one of bytes_data or text_data")
|
||||
if close:
|
||||
self.close(close)
|
||||
|
||||
def close(self, code=None):
|
||||
"""
|
||||
Closes the WebSocket from the server end
|
||||
"""
|
||||
if code is not None and code is not True:
|
||||
super().send({"type": "websocket.close", "code": code})
|
||||
else:
|
||||
super().send({"type": "websocket.close"})
|
||||
|
||||
def websocket_disconnect(self, message):
|
||||
"""
|
||||
Called when a WebSocket connection is closed. Base level so you don't
|
||||
need to call super() all the time.
|
||||
"""
|
||||
try:
|
||||
for group in self.groups:
|
||||
async_to_sync(self.channel_layer.group_discard)(
|
||||
group, self.channel_name
|
||||
)
|
||||
except AttributeError:
|
||||
raise InvalidChannelLayerError(
|
||||
"BACKEND is unconfigured or doesn't support groups"
|
||||
)
|
||||
self.disconnect(message["code"])
|
||||
raise StopConsumer()
|
||||
|
||||
def disconnect(self, code):
|
||||
"""
|
||||
Called when a WebSocket connection is closed.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class JsonWebsocketConsumer(WebsocketConsumer):
|
||||
"""
|
||||
Variant of WebsocketConsumer that automatically JSON-encodes and decodes
|
||||
messages as they come in and go out. Expects everything to be text; will
|
||||
error on binary data.
|
||||
"""
|
||||
|
||||
def receive(self, text_data=None, bytes_data=None, **kwargs):
|
||||
if text_data:
|
||||
self.receive_json(self.decode_json(text_data), **kwargs)
|
||||
else:
|
||||
raise ValueError("No text section for incoming WebSocket frame!")
|
||||
|
||||
def receive_json(self, content, **kwargs):
|
||||
"""
|
||||
Called with decoded JSON content.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send_json(self, content, close=False):
|
||||
"""
|
||||
Encode the given content as JSON and send it to the client.
|
||||
"""
|
||||
super().send(text_data=self.encode_json(content), close=close)
|
||||
|
||||
@classmethod
|
||||
def decode_json(cls, text_data):
|
||||
return json.loads(text_data)
|
||||
|
||||
@classmethod
|
||||
def encode_json(cls, content):
|
||||
return json.dumps(content)
|
||||
|
||||
|
||||
class AsyncWebsocketConsumer(AsyncConsumer):
|
||||
"""
|
||||
Base WebSocket consumer, async version. Provides a general encapsulation
|
||||
for the WebSocket handling model that other applications can build on.
|
||||
"""
|
||||
|
||||
groups = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if self.groups is None:
|
||||
self.groups = []
|
||||
|
||||
async def websocket_connect(self, message):
|
||||
"""
|
||||
Called when a WebSocket connection is opened.
|
||||
"""
|
||||
try:
|
||||
for group in self.groups:
|
||||
await self.channel_layer.group_add(group, self.channel_name)
|
||||
except AttributeError:
|
||||
raise InvalidChannelLayerError(
|
||||
"BACKEND is unconfigured or doesn't support groups"
|
||||
)
|
||||
try:
|
||||
await self.connect()
|
||||
except AcceptConnection:
|
||||
await self.accept()
|
||||
except DenyConnection:
|
||||
await self.close()
|
||||
|
||||
async def connect(self):
|
||||
await self.accept()
|
||||
|
||||
async def accept(self, subprotocol=None):
|
||||
"""
|
||||
Accepts an incoming socket
|
||||
"""
|
||||
await super().send({"type": "websocket.accept", "subprotocol": subprotocol})
|
||||
|
||||
async def websocket_receive(self, message):
|
||||
"""
|
||||
Called when a WebSocket frame is received. Decodes it and passes it
|
||||
to receive().
|
||||
"""
|
||||
if "text" in message:
|
||||
await self.receive(text_data=message["text"])
|
||||
else:
|
||||
await self.receive(bytes_data=message["bytes"])
|
||||
|
||||
async def receive(self, text_data=None, bytes_data=None):
|
||||
"""
|
||||
Called with a decoded WebSocket frame.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send(self, text_data=None, bytes_data=None, close=False):
|
||||
"""
|
||||
Sends a reply back down the WebSocket
|
||||
"""
|
||||
if text_data is not None:
|
||||
await super().send({"type": "websocket.send", "text": text_data})
|
||||
elif bytes_data is not None:
|
||||
await super().send({"type": "websocket.send", "bytes": bytes_data})
|
||||
else:
|
||||
raise ValueError("You must pass one of bytes_data or text_data")
|
||||
if close:
|
||||
await self.close(close)
|
||||
|
||||
async def close(self, code=None):
|
||||
"""
|
||||
Closes the WebSocket from the server end
|
||||
"""
|
||||
if code is not None and code is not True:
|
||||
await super().send({"type": "websocket.close", "code": code})
|
||||
else:
|
||||
await super().send({"type": "websocket.close"})
|
||||
|
||||
async def websocket_disconnect(self, message):
|
||||
"""
|
||||
Called when a WebSocket connection is closed. Base level so you don't
|
||||
need to call super() all the time.
|
||||
"""
|
||||
try:
|
||||
for group in self.groups:
|
||||
await self.channel_layer.group_discard(group, self.channel_name)
|
||||
except AttributeError:
|
||||
raise InvalidChannelLayerError(
|
||||
"BACKEND is unconfigured or doesn't support groups"
|
||||
)
|
||||
await self.disconnect(message["code"])
|
||||
raise StopConsumer()
|
||||
|
||||
async def disconnect(self, code):
|
||||
"""
|
||||
Called when a WebSocket connection is closed.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncJsonWebsocketConsumer(AsyncWebsocketConsumer):
|
||||
"""
|
||||
Variant of AsyncWebsocketConsumer that automatically JSON-encodes and decodes
|
||||
messages as they come in and go out. Expects everything to be text; will
|
||||
error on binary data.
|
||||
"""
|
||||
|
||||
async def receive(self, text_data=None, bytes_data=None, **kwargs):
|
||||
if text_data:
|
||||
await self.receive_json(await self.decode_json(text_data), **kwargs)
|
||||
else:
|
||||
raise ValueError("No text section for incoming WebSocket frame!")
|
||||
|
||||
async def receive_json(self, content, **kwargs):
|
||||
"""
|
||||
Called with decoded JSON content.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_json(self, content, close=False):
|
||||
"""
|
||||
Encode the given content as JSON and send it to the client.
|
||||
"""
|
||||
await super().send(text_data=await self.encode_json(content), close=close)
|
||||
|
||||
@classmethod
|
||||
async def decode_json(cls, text_data):
|
||||
return json.loads(text_data)
|
||||
|
||||
@classmethod
|
||||
async def encode_json(cls, content):
|
||||
return json.dumps(content)
|
||||
363
.venv/lib/python3.12/site-packages/channels/layers.py
Normal file
363
.venv/lib/python3.12/site-packages/channels/layers.py
Normal file
@@ -0,0 +1,363 @@
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.signals import setting_changed
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
from channels import DEFAULT_CHANNEL_LAYER
|
||||
|
||||
from .exceptions import ChannelFull, InvalidChannelLayerError
|
||||
|
||||
|
||||
class ChannelLayerManager:
|
||||
"""
|
||||
Takes a settings dictionary of backends and initialises them on request.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.backends = {}
|
||||
setting_changed.connect(self._reset_backends)
|
||||
|
||||
def _reset_backends(self, setting, **kwargs):
|
||||
"""
|
||||
Removes cached channel layers when the CHANNEL_LAYERS setting changes.
|
||||
"""
|
||||
if setting == "CHANNEL_LAYERS":
|
||||
self.backends = {}
|
||||
|
||||
@property
|
||||
def configs(self):
|
||||
# Lazy load settings so we can be imported
|
||||
return getattr(settings, "CHANNEL_LAYERS", {})
|
||||
|
||||
def make_backend(self, name):
|
||||
"""
|
||||
Instantiate channel layer.
|
||||
"""
|
||||
config = self.configs[name].get("CONFIG", {})
|
||||
return self._make_backend(name, config)
|
||||
|
||||
def make_test_backend(self, name):
|
||||
"""
|
||||
Instantiate channel layer using its test config.
|
||||
"""
|
||||
try:
|
||||
config = self.configs[name]["TEST_CONFIG"]
|
||||
except KeyError:
|
||||
raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name)
|
||||
return self._make_backend(name, config)
|
||||
|
||||
def _make_backend(self, name, config):
|
||||
# Check for old format config
|
||||
if "ROUTING" in self.configs[name]:
|
||||
raise InvalidChannelLayerError(
|
||||
"ROUTING key found for %s - this is no longer needed in Channels 2."
|
||||
% name
|
||||
)
|
||||
# Load the backend class
|
||||
try:
|
||||
backend_class = import_string(self.configs[name]["BACKEND"])
|
||||
except KeyError:
|
||||
raise InvalidChannelLayerError("No BACKEND specified for %s" % name)
|
||||
except ImportError:
|
||||
raise InvalidChannelLayerError(
|
||||
"Cannot import BACKEND %r specified for %s"
|
||||
% (self.configs[name]["BACKEND"], name)
|
||||
)
|
||||
# Initialise and pass config
|
||||
return backend_class(**config)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key not in self.backends:
|
||||
self.backends[key] = self.make_backend(key)
|
||||
return self.backends[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.configs
|
||||
|
||||
def set(self, key, layer):
|
||||
"""
|
||||
Sets an alias to point to a new ChannelLayerWrapper instance, and
|
||||
returns the old one that it replaced. Useful for swapping out the
|
||||
backend during tests.
|
||||
"""
|
||||
old = self.backends.get(key, None)
|
||||
self.backends[key] = layer
|
||||
return old
|
||||
|
||||
|
||||
class BaseChannelLayer:
|
||||
"""
|
||||
Base channel layer class that others can inherit from, with useful
|
||||
common functionality.
|
||||
"""
|
||||
|
||||
MAX_NAME_LENGTH = 100
|
||||
|
||||
def __init__(self, expiry=60, capacity=100, channel_capacity=None):
|
||||
self.expiry = expiry
|
||||
self.capacity = capacity
|
||||
self.channel_capacity = channel_capacity or {}
|
||||
|
||||
def compile_capacities(self, channel_capacity):
|
||||
"""
|
||||
Takes an input channel_capacity dict and returns the compiled list
|
||||
of regexes that get_capacity will look for as self.channel_capacity
|
||||
"""
|
||||
result = []
|
||||
for pattern, value in channel_capacity.items():
|
||||
# If they passed in a precompiled regex, leave it, else interpret
|
||||
# it as a glob.
|
||||
if hasattr(pattern, "match"):
|
||||
result.append((pattern, value))
|
||||
else:
|
||||
result.append((re.compile(fnmatch.translate(pattern)), value))
|
||||
return result
|
||||
|
||||
def get_capacity(self, channel):
|
||||
"""
|
||||
Gets the correct capacity for the given channel; either the default,
|
||||
or a matching result from channel_capacity. Returns the first matching
|
||||
result; if you want to control the order of matches, use an ordered dict
|
||||
as input.
|
||||
"""
|
||||
for pattern, capacity in self.channel_capacity:
|
||||
if pattern.match(channel):
|
||||
return capacity
|
||||
return self.capacity
|
||||
|
||||
def match_type_and_length(self, name):
|
||||
if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH):
|
||||
return True
|
||||
return False
|
||||
|
||||
# Name validation functions
|
||||
|
||||
channel_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+(\![\d\w\-_.]*)?$")
|
||||
group_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+$")
|
||||
invalid_name_error = (
|
||||
"{} name must be a valid unicode string "
|
||||
+ "with length < {} ".format(MAX_NAME_LENGTH)
|
||||
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
|
||||
+ "not {}"
|
||||
)
|
||||
|
||||
def valid_channel_name(self, name, receive=False):
|
||||
if self.match_type_and_length(name):
|
||||
if bool(self.channel_name_regex.match(name)):
|
||||
# Check cases for special channels
|
||||
if "!" in name and not name.endswith("!") and receive:
|
||||
raise TypeError(
|
||||
"Specific channel names in receive() must end at the !"
|
||||
)
|
||||
return True
|
||||
raise TypeError(self.invalid_name_error.format("Channel", name))
|
||||
|
||||
def valid_group_name(self, name):
|
||||
if self.match_type_and_length(name):
|
||||
if bool(self.group_name_regex.match(name)):
|
||||
return True
|
||||
raise TypeError(self.invalid_name_error.format("Group", name))
|
||||
|
||||
def valid_channel_names(self, names, receive=False):
|
||||
_non_empty_list = True if names else False
|
||||
_names_type = isinstance(names, list)
|
||||
assert _non_empty_list and _names_type, "names must be a non-empty list"
|
||||
|
||||
assert all(
|
||||
self.valid_channel_name(channel, receive=receive) for channel in names
|
||||
)
|
||||
return True
|
||||
|
||||
def non_local_name(self, name):
|
||||
"""
|
||||
Given a channel name, returns the "non-local" part. If the channel name
|
||||
is a process-specific channel (contains !) this means the part up to
|
||||
and including the !; if it is anything else, this means the full name.
|
||||
"""
|
||||
if "!" in name:
|
||||
return name[: name.find("!") + 1]
|
||||
else:
|
||||
return name
|
||||
|
||||
|
||||
class InMemoryChannelLayer(BaseChannelLayer):
|
||||
"""
|
||||
In-memory channel layer implementation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
expiry=60,
|
||||
group_expiry=86400,
|
||||
capacity=100,
|
||||
channel_capacity=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
expiry=expiry,
|
||||
capacity=capacity,
|
||||
channel_capacity=channel_capacity,
|
||||
**kwargs
|
||||
)
|
||||
self.channels = {}
|
||||
self.groups = {}
|
||||
self.group_expiry = group_expiry
|
||||
|
||||
# Channel layer API
|
||||
|
||||
extensions = ["groups", "flush"]
|
||||
|
||||
async def send(self, channel, message):
|
||||
"""
|
||||
Send a message onto a (general or specific) channel.
|
||||
"""
|
||||
# Typecheck
|
||||
assert isinstance(message, dict), "message is not a dict"
|
||||
assert self.valid_channel_name(channel), "Channel name not valid"
|
||||
# If it's a process-local channel, strip off local part and stick full
|
||||
# name in message
|
||||
assert "__asgi_channel__" not in message
|
||||
|
||||
queue = self.channels.setdefault(channel, asyncio.Queue())
|
||||
# Are we full
|
||||
if queue.qsize() >= self.capacity:
|
||||
raise ChannelFull(channel)
|
||||
|
||||
# Add message
|
||||
await queue.put((time.time() + self.expiry, deepcopy(message)))
|
||||
|
||||
async def receive(self, channel):
|
||||
"""
|
||||
Receive the first message that arrives on the channel.
|
||||
If more than one coroutine waits on the same channel, a random one
|
||||
of the waiting coroutines will get the result.
|
||||
"""
|
||||
assert self.valid_channel_name(channel)
|
||||
self._clean_expired()
|
||||
|
||||
queue = self.channels.setdefault(channel, asyncio.Queue())
|
||||
|
||||
# Do a plain direct receive
|
||||
try:
|
||||
_, message = await queue.get()
|
||||
finally:
|
||||
if queue.empty():
|
||||
del self.channels[channel]
|
||||
|
||||
return message
|
||||
|
||||
async def new_channel(self, prefix="specific."):
|
||||
"""
|
||||
Returns a new channel name that can be used by something in our
|
||||
process as a specific channel.
|
||||
"""
|
||||
return "%s.inmemory!%s" % (
|
||||
prefix,
|
||||
"".join(random.choice(string.ascii_letters) for i in range(12)),
|
||||
)
|
||||
|
||||
# Expire cleanup
|
||||
|
||||
def _clean_expired(self):
|
||||
"""
|
||||
Goes through all messages and groups and removes those that are expired.
|
||||
Any channel with an expired message is removed from all groups.
|
||||
"""
|
||||
# Channel cleanup
|
||||
for channel, queue in list(self.channels.items()):
|
||||
# See if it's expired
|
||||
while not queue.empty() and queue._queue[0][0] < time.time():
|
||||
queue.get_nowait()
|
||||
# Any removal prompts group discard
|
||||
self._remove_from_groups(channel)
|
||||
# Is the channel now empty and needs deleting?
|
||||
if queue.empty():
|
||||
del self.channels[channel]
|
||||
|
||||
# Group Expiration
|
||||
timeout = int(time.time()) - self.group_expiry
|
||||
for group in self.groups:
|
||||
for channel in list(self.groups.get(group, set())):
|
||||
# If join time is older than group_expiry end the group membership
|
||||
if (
|
||||
self.groups[group][channel]
|
||||
and int(self.groups[group][channel]) < timeout
|
||||
):
|
||||
# Delete from group
|
||||
del self.groups[group][channel]
|
||||
|
||||
# Flush extension
|
||||
|
||||
async def flush(self):
|
||||
self.channels = {}
|
||||
self.groups = {}
|
||||
|
||||
async def close(self):
|
||||
# Nothing to go
|
||||
pass
|
||||
|
||||
def _remove_from_groups(self, channel):
|
||||
"""
|
||||
Removes a channel from all groups. Used when a message on it expires.
|
||||
"""
|
||||
for channels in self.groups.values():
|
||||
if channel in channels:
|
||||
del channels[channel]
|
||||
|
||||
# Groups extension
|
||||
|
||||
async def group_add(self, group, channel):
|
||||
"""
|
||||
Adds the channel name to a group.
|
||||
"""
|
||||
# Check the inputs
|
||||
assert self.valid_group_name(group), "Group name not valid"
|
||||
assert self.valid_channel_name(channel), "Channel name not valid"
|
||||
# Add to group dict
|
||||
self.groups.setdefault(group, {})
|
||||
self.groups[group][channel] = time.time()
|
||||
|
||||
async def group_discard(self, group, channel):
|
||||
# Both should be text and valid
|
||||
assert self.valid_channel_name(channel), "Invalid channel name"
|
||||
assert self.valid_group_name(group), "Invalid group name"
|
||||
# Remove from group set
|
||||
if group in self.groups:
|
||||
if channel in self.groups[group]:
|
||||
del self.groups[group][channel]
|
||||
if not self.groups[group]:
|
||||
del self.groups[group]
|
||||
|
||||
async def group_send(self, group, message):
|
||||
# Check types
|
||||
assert isinstance(message, dict), "Message is not a dict"
|
||||
assert self.valid_group_name(group), "Invalid group name"
|
||||
# Run clean
|
||||
self._clean_expired()
|
||||
# Send to each channel
|
||||
for channel in self.groups.get(group, set()):
|
||||
try:
|
||||
await self.send(channel, message)
|
||||
except ChannelFull:
|
||||
pass
|
||||
|
||||
|
||||
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
|
||||
"""
|
||||
Returns a channel layer by alias, or None if it is not configured.
|
||||
"""
|
||||
try:
|
||||
return channel_layers[alias]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
# Default global instance of the channel layer manager
|
||||
channel_layers = ChannelLayerManager()
|
||||
@@ -0,0 +1,46 @@
|
||||
import logging
|
||||
|
||||
from django.core.management import BaseCommand, CommandError
|
||||
|
||||
from channels import DEFAULT_CHANNEL_LAYER
|
||||
from channels.layers import get_channel_layer
|
||||
from channels.routing import get_default_application
|
||||
from channels.worker import Worker
|
||||
|
||||
logger = logging.getLogger("django.channels.worker")
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
|
||||
leave_locale_alone = True
|
||||
worker_class = Worker
|
||||
|
||||
def add_arguments(self, parser):
|
||||
super(Command, self).add_arguments(parser)
|
||||
parser.add_argument(
|
||||
"--layer",
|
||||
action="store",
|
||||
dest="layer",
|
||||
default=DEFAULT_CHANNEL_LAYER,
|
||||
help="Channel layer alias to use, if not the default.",
|
||||
)
|
||||
parser.add_argument("channels", nargs="+", help="Channels to listen on.")
|
||||
|
||||
def handle(self, *args, **options):
|
||||
# Get the backend to use
|
||||
self.verbosity = options.get("verbosity", 1)
|
||||
# Get the channel layer they asked for (or see if one isn't configured)
|
||||
if "layer" in options:
|
||||
self.channel_layer = get_channel_layer(options["layer"])
|
||||
else:
|
||||
self.channel_layer = get_channel_layer()
|
||||
if self.channel_layer is None:
|
||||
raise CommandError("You do not have any CHANNEL_LAYERS configured.")
|
||||
# Run the worker
|
||||
logger.info("Running worker for channels %s", options["channels"])
|
||||
worker = self.worker_class(
|
||||
application=get_default_application(),
|
||||
channels=options["channels"],
|
||||
channel_layer=self.channel_layer,
|
||||
)
|
||||
worker.run()
|
||||
24
.venv/lib/python3.12/site-packages/channels/middleware.py
Normal file
24
.venv/lib/python3.12/site-packages/channels/middleware.py
Normal file
@@ -0,0 +1,24 @@
|
||||
class BaseMiddleware:
|
||||
"""
|
||||
Base class for implementing ASGI middleware.
|
||||
|
||||
Note that subclasses of this are not self-safe; don't store state on
|
||||
the instance, as it serves multiple application instances. Instead, use
|
||||
scope.
|
||||
"""
|
||||
|
||||
def __init__(self, inner):
|
||||
"""
|
||||
Middleware constructor - just takes inner application.
|
||||
"""
|
||||
self.inner = inner
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""
|
||||
ASGI application; can insert things into the scope and run asynchronous
|
||||
code.
|
||||
"""
|
||||
# Copy scope to stop changes going upstream
|
||||
scope = dict(scope)
|
||||
# Run the inner application along with the scope
|
||||
return await self.inner(scope, receive, send)
|
||||
158
.venv/lib/python3.12/site-packages/channels/routing.py
Normal file
158
.venv/lib/python3.12/site-packages/channels/routing.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import importlib
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.urls.exceptions import Resolver404
|
||||
from django.urls.resolvers import URLResolver
|
||||
|
||||
"""
|
||||
All Routing instances inside this file are also valid ASGI applications - with
|
||||
new Channels routing, whatever you end up with as the top level object is just
|
||||
served up as the "ASGI application".
|
||||
"""
|
||||
|
||||
|
||||
def get_default_application():
|
||||
"""
|
||||
Gets the default application, set in the ASGI_APPLICATION setting.
|
||||
"""
|
||||
try:
|
||||
path, name = settings.ASGI_APPLICATION.rsplit(".", 1)
|
||||
except (ValueError, AttributeError):
|
||||
raise ImproperlyConfigured("Cannot find ASGI_APPLICATION setting.")
|
||||
try:
|
||||
module = importlib.import_module(path)
|
||||
except ImportError:
|
||||
raise ImproperlyConfigured("Cannot import ASGI_APPLICATION module %r" % path)
|
||||
try:
|
||||
value = getattr(module, name)
|
||||
except AttributeError:
|
||||
raise ImproperlyConfigured(
|
||||
"Cannot find %r in ASGI_APPLICATION module %s" % (name, path)
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
DEPRECATION_MSG = """
|
||||
Using ProtocolTypeRouter without an explicit "http" key is deprecated.
|
||||
Given that you have not passed the "http" you likely should use Django's
|
||||
get_asgi_application():
|
||||
|
||||
from django.core.asgi import get_asgi_application
|
||||
|
||||
application = ProtocolTypeRouter(
|
||||
"http": get_asgi_application()
|
||||
# Other protocols here.
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class ProtocolTypeRouter:
|
||||
"""
|
||||
Takes a mapping of protocol type names to other Application instances,
|
||||
and dispatches to the right one based on protocol name (or raises an error)
|
||||
"""
|
||||
|
||||
def __init__(self, application_mapping):
|
||||
self.application_mapping = application_mapping
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] in self.application_mapping:
|
||||
application = self.application_mapping[scope["type"]]
|
||||
return await application(scope, receive, send)
|
||||
else:
|
||||
raise ValueError(
|
||||
"No application configured for scope type %r" % scope["type"]
|
||||
)
|
||||
|
||||
|
||||
class URLRouter:
|
||||
"""
|
||||
Routes to different applications/consumers based on the URL path.
|
||||
|
||||
Works with anything that has a ``path`` key, but intended for WebSocket
|
||||
and HTTP. Uses Django's django.urls objects for resolution -
|
||||
path() or re_path().
|
||||
"""
|
||||
|
||||
#: This router wants to do routing based on scope[path] or
|
||||
#: scope[path_remaining]. ``path()`` entries in URLRouter should not be
|
||||
#: treated as endpoints (ended with ``$``), but similar to ``include()``.
|
||||
_path_routing = True
|
||||
|
||||
def __init__(self, routes):
|
||||
self.routes = routes
|
||||
|
||||
for route in self.routes:
|
||||
# The inner ASGI app wants to do additional routing, route
|
||||
# must not be an endpoint
|
||||
if getattr(route.callback, "_path_routing", False) is True:
|
||||
route.pattern._is_endpoint = False
|
||||
|
||||
if not route.callback and isinstance(route, URLResolver):
|
||||
raise ImproperlyConfigured(
|
||||
"%s: include() is not supported in URLRouter. Use nested"
|
||||
" URLRouter instances instead." % (route,)
|
||||
)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# Get the path
|
||||
path = scope.get("path_remaining", scope.get("path", None))
|
||||
if path is None:
|
||||
raise ValueError("No 'path' key in connection scope, cannot route URLs")
|
||||
# Remove leading / to match Django's handling
|
||||
path = path.lstrip("/")
|
||||
# Run through the routes we have until one matches
|
||||
for route in self.routes:
|
||||
try:
|
||||
match = route.pattern.match(path)
|
||||
if match:
|
||||
new_path, args, kwargs = match
|
||||
# Add defaults to kwargs from the URL pattern.
|
||||
kwargs.update(route.default_args)
|
||||
# Add args or kwargs into the scope
|
||||
outer = scope.get("url_route", {})
|
||||
application = route.callback
|
||||
return await application(
|
||||
dict(
|
||||
scope,
|
||||
path_remaining=new_path,
|
||||
url_route={
|
||||
"args": outer.get("args", ()) + args,
|
||||
"kwargs": {**outer.get("kwargs", {}), **kwargs},
|
||||
},
|
||||
),
|
||||
receive,
|
||||
send,
|
||||
)
|
||||
except Resolver404:
|
||||
pass
|
||||
else:
|
||||
if "path_remaining" in scope:
|
||||
raise Resolver404("No route found for path %r." % path)
|
||||
# We are the outermost URLRouter
|
||||
raise ValueError("No route found for path %r." % path)
|
||||
|
||||
|
||||
class ChannelNameRouter:
|
||||
"""
|
||||
Maps to different applications based on a "channel" key in the scope
|
||||
(intended for the Channels worker mode)
|
||||
"""
|
||||
|
||||
def __init__(self, application_mapping):
|
||||
self.application_mapping = application_mapping
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if "channel" not in scope:
|
||||
raise ValueError(
|
||||
"ChannelNameRouter got a scope without a 'channel' key. "
|
||||
+ "Did you make sure it's only being used for 'channel' type messages?"
|
||||
)
|
||||
if scope["channel"] in self.application_mapping:
|
||||
application = self.application_mapping[scope["channel"]]
|
||||
return await application(scope, receive, send)
|
||||
else:
|
||||
raise ValueError(
|
||||
"No application configured for channel name %r" % scope["channel"]
|
||||
)
|
||||
@@ -0,0 +1,153 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.http.request import is_same_domain
|
||||
|
||||
from ..generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
|
||||
class OriginValidator:
|
||||
"""
|
||||
Validates that the incoming connection has an Origin header that
|
||||
is in an allowed list.
|
||||
"""
|
||||
|
||||
def __init__(self, application, allowed_origins):
|
||||
self.application = application
|
||||
self.allowed_origins = allowed_origins
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# Make sure the scope is of type websocket
|
||||
if scope["type"] != "websocket":
|
||||
raise ValueError(
|
||||
"You cannot use OriginValidator on a non-WebSocket connection"
|
||||
)
|
||||
# Extract the Origin header
|
||||
parsed_origin = None
|
||||
for header_name, header_value in scope.get("headers", []):
|
||||
if header_name == b"origin":
|
||||
try:
|
||||
# Set ResultParse
|
||||
parsed_origin = urlparse(header_value.decode("latin1"))
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
# Check to see if the origin header is valid
|
||||
if self.valid_origin(parsed_origin):
|
||||
# Pass control to the application
|
||||
return await self.application(scope, receive, send)
|
||||
else:
|
||||
# Deny the connection
|
||||
denier = WebsocketDenier()
|
||||
return await denier(scope, receive, send)
|
||||
|
||||
def valid_origin(self, parsed_origin):
|
||||
"""
|
||||
Checks parsed origin is None.
|
||||
|
||||
Pass control to the validate_origin function.
|
||||
|
||||
Returns ``True`` if validation function was successful, ``False`` otherwise.
|
||||
"""
|
||||
# None is not allowed unless all hosts are allowed
|
||||
if parsed_origin is None and "*" not in self.allowed_origins:
|
||||
return False
|
||||
return self.validate_origin(parsed_origin)
|
||||
|
||||
def validate_origin(self, parsed_origin):
|
||||
"""
|
||||
Validate the given origin for this site.
|
||||
|
||||
Check than the origin looks valid and matches the origin pattern in
|
||||
specified list ``allowed_origins``. Any pattern begins with a scheme.
|
||||
After the scheme there must be a domain. Any domain beginning with a
|
||||
period corresponds to the domain and all its subdomains (for example,
|
||||
``http://.example.com``). After the domain there must be a port,
|
||||
but it can be omitted. ``*`` matches anything and anything
|
||||
else must match exactly.
|
||||
|
||||
Note. This function assumes that the given origin has a schema, domain
|
||||
and port, but port is optional.
|
||||
|
||||
Returns ``True`` for a valid host, ``False`` otherwise.
|
||||
"""
|
||||
return any(
|
||||
pattern == "*" or self.match_allowed_origin(parsed_origin, pattern)
|
||||
for pattern in self.allowed_origins
|
||||
)
|
||||
|
||||
def match_allowed_origin(self, parsed_origin, pattern):
|
||||
"""
|
||||
Returns ``True`` if the origin is either an exact match or a match
|
||||
to the wildcard pattern. Compares scheme, domain, port of origin and pattern.
|
||||
|
||||
Any pattern can be begins with a scheme. After the scheme must be a domain,
|
||||
or just domain without scheme.
|
||||
Any domain beginning with a period corresponds to the domain and all
|
||||
its subdomains (for example, ``.example.com`` ``example.com``
|
||||
and any subdomain). Also with scheme (for example, ``http://.example.com``
|
||||
``http://exapmple.com``). After the domain there must be a port,
|
||||
but it can be omitted.
|
||||
|
||||
Note. This function assumes that the given origin is either None, a
|
||||
schema-domain-port string, or just a domain string
|
||||
"""
|
||||
if parsed_origin is None:
|
||||
return False
|
||||
|
||||
# Get ResultParse object
|
||||
parsed_pattern = urlparse(pattern.lower())
|
||||
if parsed_origin.hostname is None:
|
||||
return False
|
||||
if not parsed_pattern.scheme:
|
||||
pattern_hostname = urlparse("//" + pattern).hostname or pattern
|
||||
return is_same_domain(parsed_origin.hostname, pattern_hostname)
|
||||
# Get origin.port or default ports for origin or None
|
||||
origin_port = self.get_origin_port(parsed_origin)
|
||||
# Get pattern.port or default ports for pattern or None
|
||||
pattern_port = self.get_origin_port(parsed_pattern)
|
||||
# Compares hostname, scheme, ports of pattern and origin
|
||||
if (
|
||||
parsed_pattern.scheme == parsed_origin.scheme
|
||||
and origin_port == pattern_port
|
||||
and is_same_domain(parsed_origin.hostname, parsed_pattern.hostname)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_origin_port(self, origin):
|
||||
"""
|
||||
Returns the origin.port or port for this schema by default.
|
||||
Otherwise, it returns None.
|
||||
"""
|
||||
if origin.port is not None:
|
||||
# Return origin.port
|
||||
return origin.port
|
||||
# if origin.port doesn`t exists
|
||||
if origin.scheme == "http" or origin.scheme == "ws":
|
||||
# Default port return for http, ws
|
||||
return 80
|
||||
elif origin.scheme == "https" or origin.scheme == "wss":
|
||||
# Default port return for https, wss
|
||||
return 443
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def AllowedHostsOriginValidator(application):
|
||||
"""
|
||||
Factory function which returns an OriginValidator configured to use
|
||||
settings.ALLOWED_HOSTS.
|
||||
"""
|
||||
allowed_hosts = settings.ALLOWED_HOSTS
|
||||
if settings.DEBUG and not allowed_hosts:
|
||||
allowed_hosts = ["localhost", "127.0.0.1", "[::1]"]
|
||||
return OriginValidator(application, allowed_hosts)
|
||||
|
||||
|
||||
class WebsocketDenier(AsyncWebsocketConsumer):
|
||||
"""
|
||||
Simple application which denies all requests to it.
|
||||
"""
|
||||
|
||||
async def connect(self):
|
||||
await self.close()
|
||||
268
.venv/lib/python3.12/site-packages/channels/sessions.py
Normal file
268
.venv/lib/python3.12/site-packages/channels/sessions.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import datetime
|
||||
import time
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.sessions.backends.base import UpdateError
|
||||
from django.core.exceptions import SuspiciousOperation
|
||||
from django.http import parse_cookie
|
||||
from django.http.cookie import SimpleCookie
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.functional import LazyObject
|
||||
|
||||
from channels.db import database_sync_to_async
|
||||
|
||||
try:
|
||||
from django.utils.http import http_date
|
||||
except ImportError:
|
||||
from django.utils.http import cookie_date as http_date
|
||||
|
||||
|
||||
class CookieMiddleware:
|
||||
"""
|
||||
Extracts cookies from HTTP or WebSocket-style scopes and adds them as a
|
||||
scope["cookies"] entry with the same format as Django's request.COOKIES.
|
||||
"""
|
||||
|
||||
def __init__(self, inner):
|
||||
self.inner = inner
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# Check this actually has headers. They're a required scope key for HTTP and WS.
|
||||
if "headers" not in scope:
|
||||
raise ValueError(
|
||||
"CookieMiddleware was passed a scope that did not have a headers key "
|
||||
+ "(make sure it is only passed HTTP or WebSocket connections)"
|
||||
)
|
||||
# Go through headers to find the cookie one
|
||||
for name, value in scope.get("headers", []):
|
||||
if name == b"cookie":
|
||||
cookies = parse_cookie(value.decode("latin1"))
|
||||
break
|
||||
else:
|
||||
# No cookie header found - add an empty default.
|
||||
cookies = {}
|
||||
# Return inner application
|
||||
return await self.inner(dict(scope, cookies=cookies), receive, send)
|
||||
|
||||
@classmethod
|
||||
def set_cookie(
|
||||
cls,
|
||||
message,
|
||||
key,
|
||||
value="",
|
||||
max_age=None,
|
||||
expires=None,
|
||||
path="/",
|
||||
domain=None,
|
||||
secure=False,
|
||||
httponly=False,
|
||||
samesite="lax",
|
||||
):
|
||||
"""
|
||||
Sets a cookie in the passed HTTP response message.
|
||||
|
||||
``expires`` can be:
|
||||
- a string in the correct format,
|
||||
- a naive ``datetime.datetime`` object in UTC,
|
||||
- an aware ``datetime.datetime`` object in any time zone.
|
||||
If it is a ``datetime.datetime`` object then ``max_age`` will be calculated.
|
||||
"""
|
||||
value = force_str(value)
|
||||
cookies = SimpleCookie()
|
||||
cookies[key] = value
|
||||
if expires is not None:
|
||||
if isinstance(expires, datetime.datetime):
|
||||
if timezone.is_aware(expires):
|
||||
expires = timezone.make_naive(expires, timezone.utc)
|
||||
delta = expires - expires.utcnow()
|
||||
# Add one second so the date matches exactly (a fraction of
|
||||
# time gets lost between converting to a timedelta and
|
||||
# then the date string).
|
||||
delta = delta + datetime.timedelta(seconds=1)
|
||||
# Just set max_age - the max_age logic will set expires.
|
||||
expires = None
|
||||
max_age = max(0, delta.days * 86400 + delta.seconds)
|
||||
else:
|
||||
cookies[key]["expires"] = expires
|
||||
else:
|
||||
cookies[key]["expires"] = ""
|
||||
if max_age is not None:
|
||||
cookies[key]["max-age"] = max_age
|
||||
# IE requires expires, so set it if hasn't been already.
|
||||
if not expires:
|
||||
cookies[key]["expires"] = http_date(time.time() + max_age)
|
||||
if path is not None:
|
||||
cookies[key]["path"] = path
|
||||
if domain is not None:
|
||||
cookies[key]["domain"] = domain
|
||||
if secure:
|
||||
cookies[key]["secure"] = True
|
||||
if httponly:
|
||||
cookies[key]["httponly"] = True
|
||||
if samesite is not None:
|
||||
assert samesite.lower() in [
|
||||
"strict",
|
||||
"lax",
|
||||
"none",
|
||||
], "samesite must be either 'strict', 'lax' or 'none'"
|
||||
cookies[key]["samesite"] = samesite
|
||||
# Write out the cookies to the response
|
||||
for c in cookies.values():
|
||||
message.setdefault("headers", []).append(
|
||||
(b"Set-Cookie", bytes(c.output(header=""), encoding="utf-8"))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_cookie(cls, message, key, path="/", domain=None):
|
||||
"""
|
||||
Deletes a cookie in a response.
|
||||
"""
|
||||
return cls.set_cookie(
|
||||
message,
|
||||
key,
|
||||
max_age=0,
|
||||
path=path,
|
||||
domain=domain,
|
||||
expires="Thu, 01-Jan-1970 00:00:00 GMT",
|
||||
)
|
||||
|
||||
|
||||
class InstanceSessionWrapper:
|
||||
"""
|
||||
Populates the session in application instance scope, and wraps send to save
|
||||
the session.
|
||||
"""
|
||||
|
||||
# Message types that trigger a session save if it's modified
|
||||
save_message_types = ["http.response.start"]
|
||||
|
||||
# Message types that can carry session cookies back
|
||||
cookie_response_message_types = ["http.response.start"]
|
||||
|
||||
def __init__(self, scope, send):
|
||||
self.cookie_name = settings.SESSION_COOKIE_NAME
|
||||
self.session_store = import_module(settings.SESSION_ENGINE).SessionStore
|
||||
|
||||
self.scope = dict(scope)
|
||||
|
||||
if "session" in self.scope:
|
||||
# There's already session middleware of some kind above us, pass
|
||||
# that through
|
||||
self.activated = False
|
||||
else:
|
||||
# Make sure there are cookies in the scope
|
||||
if "cookies" not in self.scope:
|
||||
raise ValueError(
|
||||
"No cookies in scope - SessionMiddleware needs to run "
|
||||
"inside of CookieMiddleware."
|
||||
)
|
||||
# Parse the headers in the scope into cookies
|
||||
self.scope["session"] = LazyObject()
|
||||
self.activated = True
|
||||
|
||||
# Override send
|
||||
self.real_send = send
|
||||
|
||||
async def resolve_session(self):
|
||||
session_key = self.scope["cookies"].get(self.cookie_name)
|
||||
self.scope["session"]._wrapped = await database_sync_to_async(
|
||||
self.session_store
|
||||
)(session_key)
|
||||
|
||||
async def send(self, message):
|
||||
"""
|
||||
Overridden send that also does session saves/cookies.
|
||||
"""
|
||||
# Only save session if we're the outermost session middleware
|
||||
if self.activated:
|
||||
modified = self.scope["session"].modified
|
||||
empty = self.scope["session"].is_empty()
|
||||
# If this is a message type that we want to save on, and there's
|
||||
# changed data, save it. We also save if it's empty as we might
|
||||
# not be able to send a cookie-delete along with this message.
|
||||
if (
|
||||
message["type"] in self.save_message_types
|
||||
and message.get("status", 200) != 500
|
||||
and (modified or settings.SESSION_SAVE_EVERY_REQUEST)
|
||||
):
|
||||
await database_sync_to_async(self.save_session)()
|
||||
# If this is a message type that can transport cookies back to the
|
||||
# client, then do so.
|
||||
if message["type"] in self.cookie_response_message_types:
|
||||
if empty:
|
||||
# Delete cookie if it's set
|
||||
if settings.SESSION_COOKIE_NAME in self.scope["cookies"]:
|
||||
CookieMiddleware.delete_cookie(
|
||||
message,
|
||||
settings.SESSION_COOKIE_NAME,
|
||||
path=settings.SESSION_COOKIE_PATH,
|
||||
domain=settings.SESSION_COOKIE_DOMAIN,
|
||||
)
|
||||
else:
|
||||
# Get the expiry data
|
||||
if self.scope["session"].get_expire_at_browser_close():
|
||||
max_age = None
|
||||
expires = None
|
||||
else:
|
||||
max_age = self.scope["session"].get_expiry_age()
|
||||
expires_time = time.time() + max_age
|
||||
expires = http_date(expires_time)
|
||||
# Set the cookie
|
||||
CookieMiddleware.set_cookie(
|
||||
message,
|
||||
self.cookie_name,
|
||||
self.scope["session"].session_key,
|
||||
max_age=max_age,
|
||||
expires=expires,
|
||||
domain=settings.SESSION_COOKIE_DOMAIN,
|
||||
path=settings.SESSION_COOKIE_PATH,
|
||||
secure=settings.SESSION_COOKIE_SECURE or None,
|
||||
httponly=settings.SESSION_COOKIE_HTTPONLY or None,
|
||||
samesite=settings.SESSION_COOKIE_SAMESITE,
|
||||
)
|
||||
# Pass up the send
|
||||
return await self.real_send(message)
|
||||
|
||||
def save_session(self):
|
||||
"""
|
||||
Saves the current session.
|
||||
"""
|
||||
try:
|
||||
self.scope["session"].save()
|
||||
except UpdateError:
|
||||
raise SuspiciousOperation(
|
||||
"The request's session was deleted before the "
|
||||
"request completed. The user may have logged "
|
||||
"out in a concurrent request, for example."
|
||||
)
|
||||
|
||||
|
||||
class SessionMiddleware:
|
||||
"""
|
||||
Class that adds Django sessions (from HTTP cookies) to the
|
||||
scope. Works with HTTP or WebSocket protocol types (or anything that
|
||||
provides a "headers" entry in the scope).
|
||||
|
||||
Requires the CookieMiddleware to be higher up in the stack.
|
||||
"""
|
||||
|
||||
def __init__(self, inner):
|
||||
self.inner = inner
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""
|
||||
Instantiate a session wrapper for this scope, resolve the session and
|
||||
call the inner application.
|
||||
"""
|
||||
wrapper = InstanceSessionWrapper(scope, send)
|
||||
|
||||
await wrapper.resolve_session()
|
||||
|
||||
return await self.inner(wrapper.scope, receive, wrapper.send)
|
||||
|
||||
|
||||
# Shortcut to include cookie middleware
|
||||
def SessionMiddlewareStack(inner):
|
||||
return CookieMiddleware(SessionMiddleware(inner))
|
||||
@@ -0,0 +1,12 @@
|
||||
from asgiref.testing import ApplicationCommunicator # noqa
|
||||
|
||||
from .http import HttpCommunicator # noqa
|
||||
from .live import ChannelsLiveServerTestCase # noqa
|
||||
from .websocket import WebsocketCommunicator # noqa
|
||||
|
||||
__all__ = [
|
||||
"ApplicationCommunicator",
|
||||
"HttpCommunicator",
|
||||
"ChannelsLiveServerTestCase",
|
||||
"WebsocketCommunicator",
|
||||
]
|
||||
56
.venv/lib/python3.12/site-packages/channels/testing/http.py
Normal file
56
.venv/lib/python3.12/site-packages/channels/testing/http.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from asgiref.testing import ApplicationCommunicator
|
||||
|
||||
|
||||
class HttpCommunicator(ApplicationCommunicator):
|
||||
"""
|
||||
ApplicationCommunicator subclass that has HTTP shortcut methods.
|
||||
|
||||
It will construct the scope for you, so you need to pass the application
|
||||
(uninstantiated) along with HTTP parameters.
|
||||
|
||||
This does not support full chunking - for that, just use ApplicationCommunicator
|
||||
directly.
|
||||
"""
|
||||
|
||||
def __init__(self, application, method, path, body=b"", headers=None):
|
||||
parsed = urlparse(path)
|
||||
self.scope = {
|
||||
"type": "http",
|
||||
"http_version": "1.1",
|
||||
"method": method.upper(),
|
||||
"path": unquote(parsed.path),
|
||||
"query_string": parsed.query.encode("utf-8"),
|
||||
"headers": headers or [],
|
||||
}
|
||||
assert isinstance(body, bytes)
|
||||
self.body = body
|
||||
self.sent_request = False
|
||||
super().__init__(application, self.scope)
|
||||
|
||||
async def get_response(self, timeout=1):
|
||||
"""
|
||||
Get the application's response. Returns a dict with keys of
|
||||
"body", "headers" and "status".
|
||||
"""
|
||||
# If we've not sent the request yet, do so
|
||||
if not self.sent_request:
|
||||
self.sent_request = True
|
||||
await self.send_input({"type": "http.request", "body": self.body})
|
||||
# Get the response start
|
||||
response_start = await self.receive_output(timeout)
|
||||
assert response_start["type"] == "http.response.start"
|
||||
# Get all body parts
|
||||
response_start["body"] = b""
|
||||
while True:
|
||||
chunk = await self.receive_output(timeout)
|
||||
assert chunk["type"] == "http.response.body"
|
||||
assert isinstance(chunk["body"], bytes)
|
||||
response_start["body"] += chunk["body"]
|
||||
if not chunk.get("more_body", False):
|
||||
break
|
||||
# Return structured info
|
||||
del response_start["type"]
|
||||
response_start.setdefault("headers", [])
|
||||
return response_start
|
||||
76
.venv/lib/python3.12/site-packages/channels/testing/live.py
Normal file
76
.venv/lib/python3.12/site-packages/channels/testing/live.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from functools import partial
|
||||
|
||||
from daphne.testing import DaphneProcess
|
||||
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import connections
|
||||
from django.test.testcases import TransactionTestCase
|
||||
from django.test.utils import modify_settings
|
||||
|
||||
from channels.routing import get_default_application
|
||||
|
||||
|
||||
def make_application(*, static_wrapper):
|
||||
# Module-level function for pickle-ability
|
||||
application = get_default_application()
|
||||
if static_wrapper is not None:
|
||||
application = static_wrapper(application)
|
||||
return application
|
||||
|
||||
|
||||
class ChannelsLiveServerTestCase(TransactionTestCase):
|
||||
"""
|
||||
Does basically the same as TransactionTestCase but also launches a
|
||||
live Daphne server in a separate process, so
|
||||
that the tests may use another test framework, such as Selenium,
|
||||
instead of the built-in dummy client.
|
||||
"""
|
||||
|
||||
host = "localhost"
|
||||
ProtocolServerProcess = DaphneProcess
|
||||
static_wrapper = ASGIStaticFilesHandler
|
||||
serve_static = True
|
||||
|
||||
@property
|
||||
def live_server_url(self):
|
||||
return "http://%s:%s" % (self.host, self._port)
|
||||
|
||||
@property
|
||||
def live_server_ws_url(self):
|
||||
return "ws://%s:%s" % (self.host, self._port)
|
||||
|
||||
def _pre_setup(self):
|
||||
for connection in connections.all():
|
||||
if self._is_in_memory_db(connection):
|
||||
raise ImproperlyConfigured(
|
||||
"ChannelLiveServerTestCase can not be used with in memory databases"
|
||||
)
|
||||
|
||||
super(ChannelsLiveServerTestCase, self)._pre_setup()
|
||||
|
||||
self._live_server_modified_settings = modify_settings(
|
||||
ALLOWED_HOSTS={"append": self.host}
|
||||
)
|
||||
self._live_server_modified_settings.enable()
|
||||
|
||||
get_application = partial(
|
||||
make_application,
|
||||
static_wrapper=self.static_wrapper if self.serve_static else None,
|
||||
)
|
||||
self._server_process = self.ProtocolServerProcess(self.host, get_application)
|
||||
self._server_process.start()
|
||||
self._server_process.ready.wait()
|
||||
self._port = self._server_process.port.value
|
||||
|
||||
def _post_teardown(self):
|
||||
self._server_process.terminate()
|
||||
self._server_process.join()
|
||||
self._live_server_modified_settings.disable()
|
||||
super(ChannelsLiveServerTestCase, self)._post_teardown()
|
||||
|
||||
def _is_in_memory_db(self, connection):
|
||||
"""
|
||||
Check if DatabaseWrapper holds in memory database.
|
||||
"""
|
||||
if connection.vendor == "sqlite":
|
||||
return connection.is_in_memory_db()
|
||||
102
.venv/lib/python3.12/site-packages/channels/testing/websocket.py
Normal file
102
.venv/lib/python3.12/site-packages/channels/testing/websocket.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import json
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from asgiref.testing import ApplicationCommunicator
|
||||
|
||||
|
||||
class WebsocketCommunicator(ApplicationCommunicator):
|
||||
"""
|
||||
ApplicationCommunicator subclass that has WebSocket shortcut methods.
|
||||
|
||||
It will construct the scope for you, so you need to pass the application
|
||||
(uninstantiated) along with the initial connection parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, application, path, headers=None, subprotocols=None):
|
||||
if not isinstance(path, str):
|
||||
raise TypeError("Expected str, got {}".format(type(path)))
|
||||
parsed = urlparse(path)
|
||||
self.scope = {
|
||||
"type": "websocket",
|
||||
"path": unquote(parsed.path),
|
||||
"query_string": parsed.query.encode("utf-8"),
|
||||
"headers": headers or [],
|
||||
"subprotocols": subprotocols or [],
|
||||
}
|
||||
super().__init__(application, self.scope)
|
||||
|
||||
async def connect(self, timeout=1):
|
||||
"""
|
||||
Trigger the connection code.
|
||||
|
||||
On an accepted connection, returns (True, <chosen-subprotocol>)
|
||||
On a rejected connection, returns (False, <close-code>)
|
||||
"""
|
||||
await self.send_input({"type": "websocket.connect"})
|
||||
response = await self.receive_output(timeout)
|
||||
if response["type"] == "websocket.close":
|
||||
return (False, response.get("code", 1000))
|
||||
else:
|
||||
return (True, response.get("subprotocol", None))
|
||||
|
||||
async def send_to(self, text_data=None, bytes_data=None):
|
||||
"""
|
||||
Sends a WebSocket frame to the application.
|
||||
"""
|
||||
# Make sure we have exactly one of the arguments
|
||||
assert bool(text_data) != bool(
|
||||
bytes_data
|
||||
), "You must supply exactly one of text_data or bytes_data"
|
||||
# Send the right kind of event
|
||||
if text_data:
|
||||
assert isinstance(text_data, str), "The text_data argument must be a str"
|
||||
await self.send_input({"type": "websocket.receive", "text": text_data})
|
||||
else:
|
||||
assert isinstance(
|
||||
bytes_data, bytes
|
||||
), "The bytes_data argument must be bytes"
|
||||
await self.send_input({"type": "websocket.receive", "bytes": bytes_data})
|
||||
|
||||
async def send_json_to(self, data):
|
||||
"""
|
||||
Sends JSON data as a text frame
|
||||
"""
|
||||
await self.send_to(text_data=json.dumps(data))
|
||||
|
||||
async def receive_from(self, timeout=1):
|
||||
"""
|
||||
Receives a data frame from the view. Will fail if the connection
|
||||
closes instead. Returns either a bytestring or a unicode string
|
||||
depending on what sort of frame you got.
|
||||
"""
|
||||
response = await self.receive_output(timeout)
|
||||
# Make sure this is a send message
|
||||
assert response["type"] == "websocket.send"
|
||||
# Make sure there's exactly one key in the response
|
||||
assert ("text" in response) != (
|
||||
"bytes" in response
|
||||
), "The response needs exactly one of 'text' or 'bytes'"
|
||||
# Pull out the right key and typecheck it for our users
|
||||
if "text" in response:
|
||||
assert isinstance(response["text"], str), "Text frame payload is not str"
|
||||
return response["text"]
|
||||
else:
|
||||
assert isinstance(
|
||||
response["bytes"], bytes
|
||||
), "Binary frame payload is not bytes"
|
||||
return response["bytes"]
|
||||
|
||||
async def receive_json_from(self, timeout=1):
|
||||
"""
|
||||
Receives a JSON text frame payload and decodes it
|
||||
"""
|
||||
payload = await self.receive_from(timeout)
|
||||
assert isinstance(payload, str), "JSON data is not a text frame"
|
||||
return json.loads(payload)
|
||||
|
||||
async def disconnect(self, code=1000, timeout=1):
|
||||
"""
|
||||
Closes the socket
|
||||
"""
|
||||
await self.send_input({"type": "websocket.disconnect", "code": code})
|
||||
await self.wait(timeout)
|
||||
59
.venv/lib/python3.12/site-packages/channels/utils.py
Normal file
59
.venv/lib/python3.12/site-packages/channels/utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import asyncio
|
||||
import types
|
||||
|
||||
|
||||
def name_that_thing(thing):
|
||||
"""
|
||||
Returns either the function/class path or just the object's repr
|
||||
"""
|
||||
# Instance method
|
||||
if hasattr(thing, "im_class"):
|
||||
# Mocks will recurse im_class forever
|
||||
if hasattr(thing, "mock_calls"):
|
||||
return "<mock>"
|
||||
return name_that_thing(thing.im_class) + "." + thing.im_func.func_name
|
||||
# Other named thing
|
||||
if hasattr(thing, "__name__"):
|
||||
if hasattr(thing, "__class__") and not isinstance(
|
||||
thing, (types.FunctionType, types.MethodType)
|
||||
):
|
||||
if thing.__class__ is not type and not issubclass(thing.__class__, type):
|
||||
return name_that_thing(thing.__class__)
|
||||
if hasattr(thing, "__self__"):
|
||||
return "%s.%s" % (thing.__self__.__module__, thing.__self__.__name__)
|
||||
if hasattr(thing, "__module__"):
|
||||
return "%s.%s" % (thing.__module__, thing.__name__)
|
||||
# Generic instance of a class
|
||||
if hasattr(thing, "__class__"):
|
||||
return name_that_thing(thing.__class__)
|
||||
return repr(thing)
|
||||
|
||||
|
||||
async def await_many_dispatch(consumer_callables, dispatch):
|
||||
"""
|
||||
Given a set of consumer callables, awaits on them all and passes results
|
||||
from them to the dispatch awaitable as they come in.
|
||||
"""
|
||||
# Call all callables, and ensure all return types are Futures
|
||||
tasks = [
|
||||
asyncio.ensure_future(consumer_callable())
|
||||
for consumer_callable in consumer_callables
|
||||
]
|
||||
try:
|
||||
while True:
|
||||
# Wait for any of them to complete
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
# Find the completed one(s), yield results, and replace them
|
||||
for i, task in enumerate(tasks):
|
||||
if task.done():
|
||||
result = task.result()
|
||||
await dispatch(result)
|
||||
tasks[i] = asyncio.ensure_future(consumer_callables[i]())
|
||||
finally:
|
||||
# Make sure we clean up tasks on exit
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
44
.venv/lib/python3.12/site-packages/channels/worker.py
Normal file
44
.venv/lib/python3.12/site-packages/channels/worker.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
|
||||
from asgiref.server import StatelessServer
|
||||
|
||||
|
||||
class Worker(StatelessServer):
|
||||
"""
|
||||
ASGI protocol server that surfaces events sent to specific channels
|
||||
on the channel layer into a single application instance.
|
||||
"""
|
||||
|
||||
def __init__(self, application, channels, channel_layer, max_applications=1000):
|
||||
super().__init__(application, max_applications)
|
||||
self.channels = channels
|
||||
self.channel_layer = channel_layer
|
||||
if self.channel_layer is None:
|
||||
raise ValueError("Channel layer is not valid")
|
||||
|
||||
async def handle(self):
|
||||
"""
|
||||
Listens on all the provided channels and handles the messages.
|
||||
"""
|
||||
# For each channel, launch its own listening coroutine
|
||||
listeners = []
|
||||
for channel in self.channels:
|
||||
listeners.append(asyncio.ensure_future(self.listener(channel)))
|
||||
# Wait for them all to exit
|
||||
await asyncio.wait(listeners)
|
||||
# See if any of the listeners had an error (e.g. channel layer error)
|
||||
[listener.result() for listener in listeners]
|
||||
|
||||
async def listener(self, channel):
|
||||
"""
|
||||
Single-channel listener
|
||||
"""
|
||||
while True:
|
||||
message = await self.channel_layer.receive(channel)
|
||||
if not message.get("type", None):
|
||||
raise ValueError("Worker received message with no type.")
|
||||
# Make a scope and get an application instance for it
|
||||
scope = {"type": "channel", "channel": channel}
|
||||
instance_queue = self.get_or_create_application_instance(channel, scope)
|
||||
# Run the message into the app
|
||||
await instance_queue.put(message)
|
||||
Reference in New Issue
Block a user