okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 01c6004a79
commit 27eb239e97
10020 changed files with 1935769 additions and 2364 deletions


@@ -0,0 +1,14 @@
import sys
__version__ = "4.1.0"
# Windows on Python 3.8+ uses ProactorEventLoop, which is not compatible with
# Twisted because it does not implement add_writer/add_reader.
# See https://bugs.python.org/issue37373
# and https://twistedmatrix.com/trac/ticket/9766
PY38_WIN = sys.version_info >= (3, 8) and sys.platform == "win32"
if PY38_WIN:
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
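
The policy switch above can be confirmed with a small check; a minimal sketch, assuming daphne is importable on the interpreter being tested:

# Sketch: confirm the selector-based loop is active once daphne is imported.
import asyncio
import sys

import daphne  # applies WindowsSelectorEventLoopPolicy on Windows / Python 3.8+

if sys.platform == "win32":
    loop = asyncio.new_event_loop()
    # SelectorEventLoop provides the add_reader/add_writer methods Twisted needs.
    assert isinstance(loop, asyncio.SelectorEventLoop)
    loop.close()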


@@ -0,0 +1,3 @@
from daphne.cli import CommandLineInterface
CommandLineInterface.entrypoint()


@@ -0,0 +1,70 @@
import datetime
class AccessLogGenerator:
"""
Object that implements the Daphne "action logger" internal interface in
order to provide an access log in something resembling NCSA format.
"""
def __init__(self, stream):
self.stream = stream
def __call__(self, protocol, action, details):
"""
Called when an action happens; use it to generate log entries.
"""
# HTTP requests
if protocol == "http" and action == "complete":
self.write_entry(
host=details["client"],
date=datetime.datetime.now(),
request="%(method)s %(path)s" % details,
status=details["status"],
length=details["size"],
)
# Websocket requests
elif protocol == "websocket" and action == "connecting":
self.write_entry(
host=details["client"],
date=datetime.datetime.now(),
request="WSCONNECTING %(path)s" % details,
)
elif protocol == "websocket" and action == "rejected":
self.write_entry(
host=details["client"],
date=datetime.datetime.now(),
request="WSREJECT %(path)s" % details,
)
elif protocol == "websocket" and action == "connected":
self.write_entry(
host=details["client"],
date=datetime.datetime.now(),
request="WSCONNECT %(path)s" % details,
)
elif protocol == "websocket" and action == "disconnected":
self.write_entry(
host=details["client"],
date=datetime.datetime.now(),
request="WSDISCONNECT %(path)s" % details,
)
def write_entry(
self, host, date, request, status=None, length=None, ident=None, user=None
):
"""
Writes an NCSA-style entry to the log file (some liberty is taken with
what the entries are for non-HTTP)
"""
self.stream.write(
'%s %s %s [%s] "%s" %s %s\n'
% (
host,
ident or "-",
user or "-",
date.strftime("%d/%b/%Y:%H:%M:%S"),
request,
status or "-",
length or "-",
)
)
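
For illustration, the logger above can be driven by hand; a sketch with a hypothetical details dict shaped like the one the HTTP protocol passes for a completed request:

# Sketch: emit one access-log line to stdout.
import sys

from daphne.access import AccessLogGenerator

log = AccessLogGenerator(sys.stdout)
log(
    "http",
    "complete",
    {
        "client": "127.0.0.1:54321",  # hypothetical client address
        "method": "GET",
        "path": "/",
        "status": 200,
        "size": 1024,
    },
)
# Produces something like:
# 127.0.0.1:54321 - - [03/Nov/2024:17:47:26] "GET /" 200 1024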


@@ -0,0 +1,16 @@
# Import the server here to ensure the reactor is installed very early on in case other
# packages import twisted.internet.reactor (e.g. raven does this).
from django.apps import AppConfig
from django.core import checks
import daphne.server # noqa: F401
from .checks import check_daphne_installed
class DaphneConfig(AppConfig):
name = "daphne"
verbose_name = "Daphne"
def ready(self):
checks.register(check_daphne_installed, checks.Tags.staticfiles)


@@ -0,0 +1,21 @@
# Django system check to ensure daphne app is listed in INSTALLED_APPS before django.contrib.staticfiles.
from django.core.checks import Error, register
@register()
def check_daphne_installed(app_configs, **kwargs):
from django.apps import apps
from django.contrib.staticfiles.apps import StaticFilesConfig
from daphne.apps import DaphneConfig
for app in apps.get_app_configs():
if isinstance(app, DaphneConfig):
return []
if isinstance(app, StaticFilesConfig):
return [
Error(
"Daphne must be listed before django.contrib.staticfiles in INSTALLED_APPS.",
id="daphne.E001",
)
]
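
The check passes as long as "daphne" appears before django.contrib.staticfiles; a minimal settings sketch (the app list is otherwise hypothetical):

# settings.py (sketch): daphne first, so its runserver override and reactor
# setup take effect before staticfiles registers its own runserver command.
INSTALLED_APPS = [
    "daphne",
    "django.contrib.staticfiles",
    # ... your other apps ...
]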


@@ -0,0 +1,285 @@
import argparse
import logging
import sys
from argparse import ArgumentError, Namespace
from asgiref.compatibility import guarantee_single_callable
from .access import AccessLogGenerator
from .endpoints import build_endpoint_description_strings
from .server import Server
from .utils import import_by_path
logger = logging.getLogger(__name__)
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000
class CommandLineInterface:
"""
Acts as the main CLI entry point for running the server.
"""
description = "Django HTTP/WebSocket server"
server_class = Server
def __init__(self):
self.parser = argparse.ArgumentParser(description=self.description)
self.parser.add_argument(
"-p", "--port", type=int, help="Port number to listen on", default=None
)
self.parser.add_argument(
"-b",
"--bind",
dest="host",
help="The host/address to bind to",
default=None,
)
self.parser.add_argument(
"--websocket_timeout",
type=int,
help="Maximum time to allow a websocket to be connected. -1 for infinite.",
default=86400,
)
self.parser.add_argument(
"--websocket_connect_timeout",
type=int,
help="Maximum time to allow a connection to handshake. -1 for infinite",
default=5,
)
self.parser.add_argument(
"-u",
"--unix-socket",
dest="unix_socket",
help="Bind to a UNIX socket rather than a TCP host/port",
default=None,
)
self.parser.add_argument(
"--fd",
type=int,
dest="file_descriptor",
help="Bind to a file descriptor rather than a TCP host/port or named unix socket",
default=None,
)
self.parser.add_argument(
"-e",
"--endpoint",
dest="socket_strings",
action="append",
help="Use raw server strings passed directly to twisted",
default=[],
)
self.parser.add_argument(
"-v",
"--verbosity",
type=int,
help="How verbose to make the output",
default=1,
)
self.parser.add_argument(
"-t",
"--http-timeout",
type=int,
help="How long to wait for worker before timing out HTTP connections",
default=None,
)
self.parser.add_argument(
"--access-log",
help="Where to write the access log (- for stdout, the default for verbosity=1)",
default=None,
)
self.parser.add_argument(
"--log-fmt",
help="Log format to use",
default="%(asctime)-15s %(levelname)-8s %(message)s",
)
self.parser.add_argument(
"--ping-interval",
type=int,
help="The number of seconds a WebSocket must be idle before a keepalive ping is sent",
default=20,
)
self.parser.add_argument(
"--ping-timeout",
type=int,
help="The number of seconds before a WebSocket is closed if no response to a keepalive ping",
default=30,
)
self.parser.add_argument(
"--application-close-timeout",
type=int,
help="The number of seconds an ASGI application has to exit after client disconnect before it is killed",
default=10,
)
self.parser.add_argument(
"--root-path",
dest="root_path",
help="The setting for the ASGI root_path variable",
default="",
)
self.parser.add_argument(
"--proxy-headers",
dest="proxy_headers",
help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the "
"client address",
default=False,
action="store_true",
)
self.arg_proxy_host = self.parser.add_argument(
"--proxy-headers-host",
dest="proxy_headers_host",
help="Specify which header will be used for getting the host "
"part. Can be omitted, requires --proxy-headers to be specified "
'when passed. "X-Real-IP" (when passed by your webserver) is a '
"good candidate for this.",
default=False,
action="store",
)
self.arg_proxy_port = self.parser.add_argument(
"--proxy-headers-port",
dest="proxy_headers_port",
help="Specify which header will be used for getting the port "
"part. Can be omitted, requires --proxy-headers to be specified "
"when passed.",
default=False,
action="store",
)
self.parser.add_argument(
"application",
help="The application to dispatch to as path.to.module:instance.path",
)
self.parser.add_argument(
"-s",
"--server-name",
dest="server_name",
help="specify which value should be passed to response header Server attribute",
default="daphne",
)
self.parser.add_argument(
"--no-server-name", dest="server_name", action="store_const", const=""
)
self.server = None
@classmethod
def entrypoint(cls):
"""
Main entrypoint for external starts.
"""
cls().run(sys.argv[1:])
def _check_proxy_headers_passed(self, argument: str, args: Namespace):
"""Raise if the `--proxy-headers` weren't specified."""
if args.proxy_headers:
return
raise ArgumentError(
argument=argument,
message="--proxy-headers has to be passed for this parameter.",
)
def _get_forwarded_host(self, args: Namespace):
"""
Return the default host header from which the remote hostname/ip
will be extracted.
"""
if args.proxy_headers_host:
self._check_proxy_headers_passed(argument=self.arg_proxy_host, args=args)
return args.proxy_headers_host
if args.proxy_headers:
return "X-Forwarded-For"
def _get_forwarded_port(self, args: Namespace):
"""
        Return the default port header from which the remote port
        will be extracted.
"""
if args.proxy_headers_port:
self._check_proxy_headers_passed(argument=self.arg_proxy_port, args=args)
return args.proxy_headers_port
if args.proxy_headers:
return "X-Forwarded-Port"
def run(self, args):
"""
Pass in raw argument list and it will decode them
and run the server.
"""
# Decode args
args = self.parser.parse_args(args)
# Set up logging
logging.basicConfig(
level={
0: logging.WARN,
1: logging.INFO,
2: logging.DEBUG,
3: logging.DEBUG, # Also turns on asyncio debug
}[args.verbosity],
format=args.log_fmt,
)
# If verbosity is 1 or greater, or they told us explicitly, set up access log
access_log_stream = None
if args.access_log:
if args.access_log == "-":
access_log_stream = sys.stdout
else:
access_log_stream = open(args.access_log, "a", 1)
elif args.verbosity >= 1:
access_log_stream = sys.stdout
# Import application
sys.path.insert(0, ".")
application = import_by_path(args.application)
application = guarantee_single_callable(application)
# Set up port/host bindings
if not any(
[
args.host,
args.port is not None,
args.unix_socket,
args.file_descriptor is not None,
args.socket_strings,
]
):
# no advanced binding options passed, patch in defaults
args.host = DEFAULT_HOST
args.port = DEFAULT_PORT
elif args.host and args.port is None:
args.port = DEFAULT_PORT
elif args.port is not None and not args.host:
args.host = DEFAULT_HOST
# Build endpoint description strings from (optional) cli arguments
endpoints = build_endpoint_description_strings(
host=args.host,
port=args.port,
unix_socket=args.unix_socket,
file_descriptor=args.file_descriptor,
)
endpoints = sorted(args.socket_strings + endpoints)
# Start the server
logger.info("Starting server at {}".format(", ".join(endpoints)))
self.server = self.server_class(
application=application,
endpoints=endpoints,
http_timeout=args.http_timeout,
ping_interval=args.ping_interval,
ping_timeout=args.ping_timeout,
websocket_timeout=args.websocket_timeout,
websocket_connect_timeout=args.websocket_connect_timeout,
websocket_handshake_timeout=args.websocket_connect_timeout,
application_close_timeout=args.application_close_timeout,
action_logger=AccessLogGenerator(access_log_stream)
if access_log_stream
else None,
root_path=args.root_path,
verbosity=args.verbosity,
proxy_forwarded_address_header=self._get_forwarded_host(args=args),
proxy_forwarded_port_header=self._get_forwarded_port(args=args),
proxy_forwarded_proto_header="X-Forwarded-Proto"
if args.proxy_headers
else None,
server_name=args.server_name,
)
self.server.run()
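
The same entrypoint can be exercised programmatically; a sketch, where the application path "myproject.asgi:application" is a hypothetical example:

# Sketch: equivalent to `daphne -b 127.0.0.1 -p 8001 myproject.asgi:application`
from daphne.cli import CommandLineInterface

CommandLineInterface().run(
    ["-b", "127.0.0.1", "-p", "8001", "myproject.asgi:application"]
)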


@@ -0,0 +1,22 @@
def build_endpoint_description_strings(
host=None, port=None, unix_socket=None, file_descriptor=None
):
"""
Build a list of twisted endpoint description strings that the server will listen on.
    This is to streamline the generation of twisted endpoint description strings from
    easier-to-use command line args such as host, port, unix sockets, etc.
"""
socket_descriptions = []
if host and port is not None:
host = host.strip("[]").replace(":", r"\:")
socket_descriptions.append("tcp:port=%d:interface=%s" % (int(port), host))
elif any([host, port]):
raise ValueError("TCP binding requires both port and host kwargs.")
if unix_socket:
socket_descriptions.append("unix:%s" % unix_socket)
if file_descriptor is not None:
socket_descriptions.append("fd:fileno=%d" % int(file_descriptor))
return socket_descriptions
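
A few illustrative inputs and the endpoint strings they produce, as later consumed by twisted's serverFromString (a sketch of expected outputs):

from daphne.endpoints import build_endpoint_description_strings

assert build_endpoint_description_strings(host="127.0.0.1", port=8000) == [
    "tcp:port=8000:interface=127.0.0.1"
]
# IPv6 hosts have their brackets stripped and colons escaped for endpoint syntax.
assert build_endpoint_description_strings(host="[::1]", port=8000) == [
    r"tcp:port=8000:interface=\:\:1"
]
assert build_endpoint_description_strings(unix_socket="/tmp/daphne.sock") == [
    "unix:/tmp/daphne.sock"
]
assert build_endpoint_description_strings(file_descriptor=3) == ["fd:fileno=3"]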


@@ -0,0 +1,414 @@
import logging
import time
import traceback
from urllib.parse import unquote
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.interfaces import IProtocolNegotiationFactory
from twisted.protocols.policies import ProtocolWrapper
from twisted.web import http
from zope.interface import implementer
from .utils import HEADER_NAME_RE, parse_x_forwarded_for
logger = logging.getLogger(__name__)
class WebRequest(http.Request):
"""
Request that either hands off information to channels, or offloads
to a WebSocket class.
Does some extra processing over the normal Twisted Web request to separate
GET and POST out.
"""
error_template = (
"""
<html>
<head>
<title>%(title)s</title>
<style>
body { font-family: sans-serif; margin: 0; padding: 0; }
h1 { padding: 0.6em 0 0.2em 20px; color: #896868; margin: 0; }
p { padding: 0 0 0.3em 20px; margin: 0; }
footer { padding: 1em 0 0.3em 20px; color: #999; font-size: 80%%; font-style: italic; }
</style>
</head>
<body>
<h1>%(title)s</h1>
<p>%(body)s</p>
<footer>Daphne</footer>
</body>
</html>
""".replace(
"\n", ""
)
.replace(" ", " ")
.replace(" ", " ")
.replace(" ", " ")
) # Shorten it a bit, bytes wise
def __init__(self, *args, **kwargs):
self.client_addr = None
self.server_addr = None
try:
http.Request.__init__(self, *args, **kwargs)
# Easy server link
self.server = self.channel.factory.server
self.application_queue = None
self._response_started = False
self.server.protocol_connected(self)
except Exception:
logger.error(traceback.format_exc())
raise
### Twisted progress callbacks
@inlineCallbacks
def process(self):
try:
self.request_start = time.time()
# Validate header names.
for name, _ in self.requestHeaders.getAllRawHeaders():
if not HEADER_NAME_RE.fullmatch(name):
self.basic_error(400, b"Bad Request", "Invalid header name")
return
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if possible
if hasattr(self.client, "host") and hasattr(self.client, "port"):
# client.host and host.host are byte strings in Python 2, but spec
# requires unicode string.
self.client_addr = [str(self.client.host), self.client.port]
self.server_addr = [str(self.host.host), self.host.port]
self.client_scheme = "https" if self.isSecure() else "http"
# See if we need to get the address from a proxy header instead
if self.server.proxy_forwarded_address_header:
self.client_addr, self.client_scheme = parse_x_forwarded_for(
self.requestHeaders,
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme,
)
# Check for unicodeish path (or it'll crash when trying to parse)
try:
self.path.decode("ascii")
except UnicodeDecodeError:
self.path = b"/"
self.basic_error(400, b"Bad Request", "Invalid characters in path")
return
# Calculate query string
self.query_string = b""
if b"?" in self.uri:
self.query_string = self.uri.split(b"?", 1)[1]
try:
self.query_string.decode("ascii")
except UnicodeDecodeError:
self.basic_error(400, b"Bad Request", "Invalid query string")
return
# Is it WebSocket? IS IT?!
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
protocol = self.server.ws_factory.buildProtocol(
self.transport.getPeer()
)
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
                    logger.warning("Could not make WebSocket protocol")
self.finish()
# Give it the raw query string
protocol._raw_query_string = self.query_string
# Port across transport
transport, self.transport = self.transport, None
if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol
transport.wrappedProtocol = protocol
else:
transport.protocol = protocol
protocol.makeConnection(transport)
# Re-inject request
data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a"
for h in self.requestHeaders.getAllRawHeaders():
data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a"
data += b"\x0d\x0a"
data += self.content.read()
protocol.dataReceived(data)
# Remove our HTTP reply channel association
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
self.server.protocol_disconnected(self)
# Resume the producer so we keep getting data, if it's available as a method
self.channel._networkProducer.resumeProducing()
# Boring old HTTP.
else:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
self.root_path = self.server.root_path
for name, values in self.requestHeaders.getAllRawHeaders():
# Prevent CVE-2015-0219
if b"_" in name:
continue
for value in values:
if name.lower() == b"daphne-root-path":
self.root_path = unquote(value.decode("ascii"))
else:
self.clean_headers.append((name.lower(), value))
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
self.content.seek(0, 0)
# Work out the application scope and create application
self.application_queue = yield maybeDeferred(
self.server.create_application,
self,
{
"type": "http",
# TODO: Correctly say if it's 1.1 or 1.0
"http_version": self.clientproto.split(b"/")[-1].decode(
"ascii"
),
"method": self.method.decode("ascii"),
"path": unquote(self.path.decode("ascii")),
"raw_path": self.path,
"root_path": self.root_path,
"scheme": self.client_scheme,
"query_string": self.query_string,
"headers": self.clean_headers,
"client": self.client_addr,
"server": self.server_addr,
},
)
# Check they didn't close an unfinished request
if self.application_queue is None or self.content.closed:
# Not much we can do, the request is prematurely abandoned.
return
# Run application against request
buffer_size = self.server.request_buffer_size
while True:
chunk = self.content.read(buffer_size)
more_body = not (len(chunk) < buffer_size)
payload = {
"type": "http.request",
"body": chunk,
"more_body": more_body,
}
self.application_queue.put_nowait(payload)
if not more_body:
break
except Exception:
logger.error(traceback.format_exc())
self.basic_error(
500, b"Internal Server Error", "Daphne HTTP processing error"
)
def connectionLost(self, reason):
"""
Cleans up reply channel on close.
"""
if self.application_queue:
self.send_disconnect()
logger.debug("HTTP disconnect for %s", self.client_addr)
http.Request.connectionLost(self, reason)
self.server.protocol_disconnected(self)
def finish(self):
"""
Cleans up reply channel on close.
"""
if self.application_queue:
self.send_disconnect()
logger.debug("HTTP close for %s", self.client_addr)
http.Request.finish(self)
self.server.protocol_disconnected(self)
### Server reply callbacks
def handle_reply(self, message):
"""
        Handles a reply message sent back by the application.
"""
# Handle connections that are already closed
if self.finished or self.channel is None:
return
# Check message validity
if "type" not in message:
raise ValueError("Message has no type defined")
# Handle message
if message["type"] == "http.response.start":
if self._response_started:
raise ValueError("HTTP response has already been started")
self._response_started = True
if "status" not in message:
raise ValueError(
"Specifying a status code is required for a Response message."
)
# Set HTTP status code
self.setResponseCode(message["status"])
# Write headers
for header, value in message.get("headers", {}):
self.responseHeaders.addRawHeader(header, value)
if self.server.server_name and not self.responseHeaders.hasHeader("server"):
self.setHeader(b"server", self.server.server_name.encode())
logger.debug(
"HTTP %s response started for %s", message["status"], self.client_addr
)
elif message["type"] == "http.response.body":
if not self._response_started:
raise ValueError(
"HTTP response has not yet been started but got %s"
% message["type"]
)
# Write out body
http.Request.write(self, message.get("body", b""))
# End if there's no more content
if not message.get("more_body", False):
self.finish()
logger.debug("HTTP response complete for %s", self.client_addr)
try:
uri = self.uri.decode("ascii")
except UnicodeDecodeError:
# The path is malformed somehow - do our best to log something
uri = repr(self.uri)
try:
self.server.log_action(
"http",
"complete",
{
"path": uri,
"status": self.code,
"method": self.method.decode("ascii", "replace"),
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
"time_taken": self.duration(),
"size": self.sentLength,
},
)
except Exception:
logger.error(traceback.format_exc())
else:
logger.debug("HTTP response chunk for %s", self.client_addr)
else:
raise ValueError("Cannot handle message type %s!" % message["type"])
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
self.basic_error(500, b"Internal Server Error", "Exception inside application.")
def check_timeouts(self):
"""
Called periodically to see if we should timeout something
"""
# Web timeout checking
if self.server.http_timeout and self.duration() > self.server.http_timeout:
if self._response_started:
logger.warning("Application timed out while sending response")
self.finish()
else:
self.basic_error(
503,
b"Service Unavailable",
"Application failed to respond within time limit.",
)
### Utility functions
def send_disconnect(self):
"""
        Sends an http.disconnect message.
        Really only useful for long-polling.
"""
# If we don't yet have a path, then don't send as we never opened.
if self.path:
self.application_queue.put_nowait({"type": "http.disconnect"})
def duration(self):
"""
Returns the time since the start of the request.
"""
if not hasattr(self, "request_start"):
return 0
return time.time() - self.request_start
def basic_error(self, status, status_text, body):
"""
Responds with a server-level error page (very basic)
"""
self.handle_reply(
{
"type": "http.response.start",
"status": status,
"headers": [(b"Content-Type", b"text/html; charset=utf-8")],
}
)
self.handle_reply(
{
"type": "http.response.body",
"body": (
self.error_template
% {
"title": str(status) + " " + status_text.decode("ascii"),
"body": body,
}
).encode("utf8"),
}
)
def __hash__(self):
return hash(id(self))
def __eq__(self, other):
return id(self) == id(other)
@implementer(IProtocolNegotiationFactory)
class HTTPFactory(http.HTTPFactory):
"""
Factory which takes care of tracking which protocol
instances or request instances are responsible for which
named response channels, so incoming messages can be
routed appropriately.
"""
def __init__(self, server):
http.HTTPFactory.__init__(self)
self.server = server
def buildProtocol(self, addr):
"""
Builds protocol instances. This override is used to ensure we use our
own Request object instead of the default.
"""
try:
protocol = http.HTTPFactory.buildProtocol(self, addr)
protocol.requestFactory = WebRequest
return protocol
except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise
# IProtocolNegotiationFactory
def acceptableProtocols(self):
"""
Protocols this server can speak after ALPN negotiation. Currently that
is HTTP/1.1 and optionally HTTP/2. Websockets cannot be negotiated
using ALPN, so that doesn't go here: anyone wanting websockets will
negotiate HTTP/1.1 and then do the upgrade dance.
"""
baseProtocols = [b"http/1.1"]
if http.H2_ENABLED:
baseProtocols.insert(0, b"h2")
return baseProtocols
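
For context, a minimal ASGI application sketch that exercises the handle_reply() path above: one http.response.start message followed by a single http.response.body chunk.

async def hello_app(scope, receive, send):
    # WebRequest.handle_reply() turns these messages into the actual HTTP response.
    assert scope["type"] == "http"
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-type", b"text/plain")],
        }
    )
    await send({"type": "http.response.body", "body": b"Hello from Daphne"})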


@@ -0,0 +1,191 @@
import datetime
import importlib
import logging
import sys
from django.apps import apps
from django.conf import settings
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.exceptions import ImproperlyConfigured
from django.core.management import CommandError
from django.core.management.commands.runserver import Command as RunserverCommand
from daphne import __version__
from daphne.endpoints import build_endpoint_description_strings
from daphne.server import Server
logger = logging.getLogger("django.channels.server")
def get_default_application():
"""
Gets the default application, set in the ASGI_APPLICATION setting.
"""
try:
path, name = settings.ASGI_APPLICATION.rsplit(".", 1)
except (ValueError, AttributeError):
raise ImproperlyConfigured("Cannot find ASGI_APPLICATION setting.")
try:
module = importlib.import_module(path)
except ImportError:
raise ImproperlyConfigured("Cannot import ASGI_APPLICATION module %r" % path)
try:
value = getattr(module, name)
except AttributeError:
raise ImproperlyConfigured(
f"Cannot find {name!r} in ASGI_APPLICATION module {path}"
)
return value
class Command(RunserverCommand):
protocol = "http"
server_cls = Server
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--noasgi",
action="store_false",
dest="use_asgi",
default=True,
help="Run the old WSGI-based runserver rather than the ASGI-based one",
)
parser.add_argument(
"--http_timeout",
action="store",
dest="http_timeout",
type=int,
default=None,
help=(
"Specify the daphne http_timeout interval in seconds "
"(default: no timeout)"
),
)
parser.add_argument(
"--websocket_handshake_timeout",
action="store",
dest="websocket_handshake_timeout",
type=int,
default=5,
help=(
"Specify the daphne websocket_handshake_timeout interval in "
"seconds (default: 5)"
),
)
def handle(self, *args, **options):
self.http_timeout = options.get("http_timeout", None)
self.websocket_handshake_timeout = options.get("websocket_handshake_timeout", 5)
# Check Channels is installed right
if options["use_asgi"] and not hasattr(settings, "ASGI_APPLICATION"):
raise CommandError(
"You have not set ASGI_APPLICATION, which is needed to run the server."
)
# Dispatch upward
super().handle(*args, **options)
def inner_run(self, *args, **options):
# Maybe they want the wsgi one?
if not options.get("use_asgi", True):
if hasattr(RunserverCommand, "server_cls"):
self.server_cls = RunserverCommand.server_cls
return RunserverCommand.inner_run(self, *args, **options)
# Run checks
self.stdout.write("Performing system checks...\n\n")
self.check(display_num_errors=True)
self.check_migrations()
# Print helpful text
quit_command = "CTRL-BREAK" if sys.platform == "win32" else "CONTROL-C"
now = datetime.datetime.now().strftime("%B %d, %Y - %X")
self.stdout.write(now)
self.stdout.write(
(
"Django version %(version)s, using settings %(settings)r\n"
"Starting ASGI/Daphne version %(daphne_version)s development server"
" at %(protocol)s://%(addr)s:%(port)s/\n"
"Quit the server with %(quit_command)s.\n"
)
% {
"version": self.get_version(),
"daphne_version": __version__,
"settings": settings.SETTINGS_MODULE,
"protocol": self.protocol,
"addr": "[%s]" % self.addr if self._raw_ipv6 else self.addr,
"port": self.port,
"quit_command": quit_command,
}
)
# Launch server in 'main' thread. Signals are disabled as it's still
# actually a subthread under the autoreloader.
logger.debug("Daphne running, listening on %s:%s", self.addr, self.port)
# build the endpoint description string from host/port options
endpoints = build_endpoint_description_strings(host=self.addr, port=self.port)
try:
self.server_cls(
application=self.get_application(options),
endpoints=endpoints,
signal_handlers=not options["use_reloader"],
action_logger=self.log_action,
http_timeout=self.http_timeout,
root_path=getattr(settings, "FORCE_SCRIPT_NAME", "") or "",
websocket_handshake_timeout=self.websocket_handshake_timeout,
).run()
logger.debug("Daphne exited")
except KeyboardInterrupt:
shutdown_message = options.get("shutdown_message", "")
if shutdown_message:
self.stdout.write(shutdown_message)
return
def get_application(self, options):
"""
Returns the static files serving application wrapping the default application,
if static files should be served. Otherwise just returns the default
handler.
"""
staticfiles_installed = apps.is_installed("django.contrib.staticfiles")
use_static_handler = options.get("use_static_handler", staticfiles_installed)
insecure_serving = options.get("insecure_serving", False)
if use_static_handler and (settings.DEBUG or insecure_serving):
return ASGIStaticFilesHandler(get_default_application())
else:
return get_default_application()
def log_action(self, protocol, action, details):
"""
Logs various different kinds of requests to the console.
"""
# HTTP requests
if protocol == "http" and action == "complete":
msg = "HTTP %(method)s %(path)s %(status)s [%(time_taken).2f, %(client)s]"
# Utilize terminal colors, if available
if 200 <= details["status"] < 300:
# Put 2XX first, since it should be the common case
logger.info(self.style.HTTP_SUCCESS(msg), details)
elif 100 <= details["status"] < 200:
logger.info(self.style.HTTP_INFO(msg), details)
elif details["status"] == 304:
logger.info(self.style.HTTP_NOT_MODIFIED(msg), details)
elif 300 <= details["status"] < 400:
logger.info(self.style.HTTP_REDIRECT(msg), details)
elif details["status"] == 404:
logger.warning(self.style.HTTP_NOT_FOUND(msg), details)
elif 400 <= details["status"] < 500:
logger.warning(self.style.HTTP_BAD_REQUEST(msg), details)
else:
# Any 5XX, or any other response
logger.error(self.style.HTTP_SERVER_ERROR(msg), details)
# Websocket requests
elif protocol == "websocket" and action == "connected":
logger.info("WebSocket CONNECT %(path)s [%(client)s]", details)
elif protocol == "websocket" and action == "disconnected":
logger.info("WebSocket DISCONNECT %(path)s [%(client)s]", details)
elif protocol == "websocket" and action == "connecting":
logger.info("WebSocket HANDSHAKING %(path)s [%(client)s]", details)
elif protocol == "websocket" and action == "rejected":
logger.info("WebSocket REJECT %(path)s [%(client)s]", details)


@@ -0,0 +1,342 @@
# This has to be done first as Twisted is import-order-sensitive with reactors
import asyncio # isort:skip
import os # isort:skip
import sys # isort:skip
import warnings # isort:skip
from concurrent.futures import ThreadPoolExecutor # isort:skip
from twisted.internet import asyncioreactor # isort:skip
twisted_loop = asyncio.new_event_loop()
if "ASGI_THREADS" in os***REMOVED***iron:
twisted_loop.set_default_executor(
ThreadPoolExecutor(max_workers=int(os***REMOVED***iron["ASGI_THREADS"]))
)
current_reactor = sys.modules.get("twisted.internet.reactor", None)
if current_reactor is not None:
if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor):
warnings.warn(
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; "
+ "you can fix this warning by importing daphne.server early in your codebase or "
+ "finding the package that imports Twisted and importing it later on.",
UserWarning,
stacklevel=2,
)
del sys.modules["twisted.internet.reactor"]
asyncioreactor.install(twisted_loop)
else:
asyncioreactor.install(twisted_loop)
import logging
import time
from concurrent.futures import CancelledError
from functools import partial
from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString
from twisted.logger import STDLibLogObserver, globalLogBeginner
from twisted.web import http
from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory
logger = logging.getLogger(__name__)
class Server:
def __init__(
self,
application,
endpoints=None,
signal_handlers=True,
action_logger=None,
http_timeout=None,
request_buffer_size=8192,
websocket_timeout=86400,
websocket_connect_timeout=20,
ping_interval=20,
ping_timeout=30,
root_path="",
proxy_forwarded_address_header=None,
proxy_forwarded_port_header=None,
proxy_forwarded_proto_header=None,
verbosity=1,
websocket_handshake_timeout=5,
application_close_timeout=10,
ready_callable=None,
server_name="daphne",
):
self.application = application
self.endpoints = endpoints or []
self.listeners = []
self.listening_addresses = []
self.signal_handlers = signal_handlers
self.action_logger = action_logger
self.http_timeout = http_timeout
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.request_buffer_size = request_buffer_size
self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
self.websocket_timeout = websocket_timeout
self.websocket_connect_timeout = websocket_connect_timeout
self.websocket_handshake_timeout = websocket_handshake_timeout
self.application_close_timeout = application_close_timeout
self.root_path = root_path
self.verbosity = verbosity
self.abort_start = False
self.ready_callable = ready_callable
self.server_name = server_name
# Check our construction is actually sensible
if not self.endpoints:
logger.error("No endpoints. This server will not listen on anything.")
sys.exit(1)
def run(self):
# A dict of protocol: {"application_instance":, "connected":, "disconnected":} dicts
self.connections = {}
# Make the factory
self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, server=self.server_name)
self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout,
)
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
globalLogBeginner.beginLoggingTo(
[lambda _: None], redirectStandardIO=False, discardBuffer=True
)
else:
globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
# Detect what Twisted features are enabled
if http.H2_ENABLED:
logger.info("HTTP/2 support enabled")
else:
logger.info(
"HTTP/2 support not enabled (install the http2 and tls Twisted extras)"
)
# Kick off the timeout loop
reactor.callLater(1, self.application_checker)
reactor.callLater(2, self.timeout_checker)
for socket_description in self.endpoints:
logger.info("Configuring endpoint %s", socket_description)
ep = serverFromString(reactor, str(socket_description))
listener = ep.listen(self.http_factory)
listener.addCallback(self.listen_success)
listener.addErrback(self.listen_error)
self.listeners.append(listener)
# Set the asyncio reactor's event loop as global
# TODO: Should we instead pass the global one into the reactor?
asyncio.set_event_loop(reactor._asyncioEventloop)
# Verbosity 3 turns on asyncio debug to find those blocking yields
if self.verbosity >= 3:
asyncio.get_event_loop().set_debug(True)
reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
if not self.abort_start:
# Trigger the ready flag if we had one
if self.ready_callable:
self.ready_callable()
# Run the reactor
reactor.run(installSignalHandlers=self.signal_handlers)
def listen_success(self, port):
"""
Called when a listen succeeds so we can store port details (if there are any)
"""
if hasattr(port, "getHost"):
host = port.getHost()
if hasattr(host, "host") and hasattr(host, "port"):
self.listening_addresses.append((host.host, host.port))
logger.info(
"Listening on TCP address %s:%s",
port.getHost().host,
port.getHost().port,
)
def listen_error(self, failure):
logger.critical("Listen failure: %s", failure.getErrorMessage())
self.stop()
def stop(self):
"""
Force-stops the server.
"""
if reactor.running:
reactor.stop()
else:
self.abort_start = True
### Protocol handling
def protocol_connected(self, protocol):
"""
Adds a protocol as a current connection.
"""
if protocol in self.connections:
raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
self.connections[protocol] = {"connected": time.time()}
def protocol_disconnected(self, protocol):
# Set its disconnected time (the loops will come and clean it up)
# Do not set it if it is already set. Overwriting it might
# cause it to never be cleaned up.
# See https://github.com/django/channels/issues/1181
if "disconnected" not in self.connections[protocol]:
self.connections[protocol]["disconnected"] = time.time()
### Internal event/message handling
def create_application(self, protocol, scope):
"""
Creates a new application instance that fronts a Protocol instance
for one of our supported protocols. Pass it the protocol,
and it will work out the type, supply appropriate callables, and
return you the application's input queue
"""
# Make sure the protocol has not had another application made for it
assert "application_instance" not in self.connections[protocol]
# Make an instance of the application
input_queue = asyncio.Queue()
scope.setdefault("asgi", {"version": "3.0"})
application_instance = self.application(
scope=scope,
receive=input_queue.get,
send=partial(self.handle_reply, protocol),
)
# Run it, and stash the future for later checking
if protocol not in self.connections:
return None
self.connections[protocol]["application_instance"] = asyncio.ensure_future(
application_instance,
loop=asyncio.get_event_loop(),
)
return input_queue
async def handle_reply(self, protocol, message):
"""
Coroutine that jumps the reply message from asyncio to Twisted
"""
# Don't do anything if the connection is closed or does not exist
if protocol not in self.connections or self.connections[protocol].get(
"disconnected", None
):
return
try:
self.check_headers_type(message)
except ValueError:
# Ensure to send SOME reply.
protocol.basic_error(500, b"Server Error", "Server Error")
raise
# Let the protocol handle it
protocol.handle_reply(message)
@staticmethod
def check_headers_type(message):
if not message["type"] == "http.response.start":
return
for k, v in message.get("headers", []):
if not isinstance(k, bytes):
raise ValueError(
"Header name '{}' expected to be `bytes`, but got `{}`".format(
k, type(k)
)
)
if not isinstance(v, bytes):
raise ValueError(
"Header value '{}' expected to be `bytes`, but got `{}`".format(
v, type(v)
)
)
### Utility
def application_checker(self):
"""
Goes through the set of current application Futures and cleans up
any that are done/prints exceptions for any that errored.
"""
for protocol, details in list(self.connections.items()):
disconnected = details.get("disconnected", None)
application_instance = details.get("application_instance", None)
# First, see if the protocol disconnected and the app has taken
# too long to close up
if (
disconnected
and time.time() - disconnected > self.application_close_timeout
):
if application_instance and not application_instance.done():
logger.warning(
"Application instance %r for connection %s took too long to shut down and was killed.",
application_instance,
repr(protocol),
)
application_instance.cancel()
# Then see if the app is done and we should reap it
if application_instance and application_instance.done():
try:
exception = application_instance.exception()
except (CancelledError, asyncio.CancelledError):
# Future cancellation. We can ignore this.
pass
else:
if exception:
if isinstance(exception, KeyboardInterrupt):
# Protocol is asking the server to exit (likely during test)
self.stop()
else:
logger.error(
"Exception inside application: %s",
exception,
exc_info=exception,
)
if not disconnected:
protocol.handle_exception(exception)
del self.connections[protocol]["application_instance"]
application_instance = None
# Check to see if protocol is closed and app is closed so we can remove it
if not application_instance and disconnected:
del self.connections[protocol]
reactor.callLater(1, self.application_checker)
def kill_all_applications(self):
"""
Kills all application coroutines before reactor exit.
"""
# Send cancel to all coroutines
wait_for = []
for details in self.connections.values():
application_instance = details["application_instance"]
if not application_instance.done():
application_instance.cancel()
wait_for.append(application_instance)
logger.info("Killed %i pending application instances", len(wait_for))
# Make Twisted wait until they're all dead
wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
wait_deferred.addErrback(lambda x: None)
return wait_deferred
def timeout_checker(self):
"""
Called periodically to enforce timeout rules on all connections.
Also checks pings at the same time.
"""
for protocol in list(self.connections.keys()):
protocol.check_timeouts()
reactor.callLater(2, self.timeout_checker)
def log_action(self, protocol, action, details):
"""
Dispatches to any registered action logger, if there is one.
"""
if self.action_logger:
self.action_logger(protocol, action, details)
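
The Server class can also be embedded directly, as the testing helpers later in this commit do; a minimal sketch with a hypothetical ASGI app:

from daphne.endpoints import build_endpoint_description_strings
from daphne.server import Server


async def app(scope, receive, send):
    await send({"type": "http.response.start", "status": 204, "headers": []})
    await send({"type": "http.response.body", "body": b""})


Server(
    application=app,
    endpoints=build_endpoint_description_strings(host="127.0.0.1", port=8000),
).run()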


@@ -0,0 +1,309 @@
import logging
import multiprocessing
import os
import pickle
import tempfile
import traceback
from concurrent.futures import CancelledError
class BaseDaphneTestingInstance:
"""
Launches an instance of Daphne in a subprocess, with a host and port
attribute allowing you to call it.
Works as a context manager.
"""
startup_timeout = 2
def __init__(
self, xff=False, http_timeout=None, request_buffer_size=None, *, application
):
self.xff = xff
self.http_timeout = http_timeout
self.host = "127.0.0.1"
self.request_buffer_size = request_buffer_size
self.application = application
def get_application(self):
return self.application
def __enter__(self):
        # Optional Daphne features
kwargs = {}
if self.request_buffer_size:
kwargs["request_buffer_size"] = self.request_buffer_size
# Optionally enable X-Forwarded-For support.
if self.xff:
kwargs["proxy_forwarded_address_header"] = "X-Forwarded-For"
kwargs["proxy_forwarded_port_header"] = "X-Forwarded-Port"
kwargs["proxy_forwarded_proto_header"] = "X-Forwarded-Proto"
if self.http_timeout:
kwargs["http_timeout"] = self.http_timeout
# Start up process
self.process = DaphneProcess(
host=self.host,
get_application=self.get_application,
kwargs=kwargs,
setup=self.process_setup,
teardown=self.process_teardown,
)
self.process.start()
# Wait for the port
if self.process.ready.wait(self.startup_timeout):
self.port = self.process.port.value
return self
else:
if self.process.errors.empty():
raise RuntimeError("Daphne did not start up, no error caught")
else:
error, traceback = self.process.errors.get(False)
raise RuntimeError("Daphne did not start up:\n%s" % traceback)
def __exit__(self, exc_type, exc_value, traceback):
# Shut down the process
self.process.terminate()
del self.process
def process_setup(self):
"""
Called by the process just before it starts serving.
"""
pass
def process_teardown(self):
"""
Called by the process just after it stops serving
"""
pass
def get_received(self):
pass
class DaphneTestingInstance(BaseDaphneTestingInstance):
def __init__(self, *args, **kwargs):
self.lock = multiprocessing.Lock()
super().__init__(*args, **kwargs, application=TestApplication(lock=self.lock))
def __enter__(self):
# Clear result storage
TestApplication.delete_setup()
TestApplication.delete_result()
return super().__enter__()
def get_received(self):
"""
Returns the scope and messages the test application has received
so far. Note you'll get all messages since scope start, not just any
new ones since the last call.
Also checks for any exceptions in the application. If there are,
raises them.
"""
try:
with self.lock:
inner_result = TestApplication.load_result()
except FileNotFoundError:
raise ValueError("No results available yet.")
# Check for exception
if "exception" in inner_result:
raise inner_result["exception"]
return inner_result["scope"], inner_result["messages"]
def add_send_messages(self, messages):
"""
Adds messages for the application to send back.
The next time it receives an incoming message, it will reply with these.
"""
TestApplication.save_setup(response_messages=messages)
class DaphneProcess(multiprocessing.Process):
"""
Process subclass that launches and runs a Daphne instance, communicating the
port it ends up listening on back to the parent process.
"""
def __init__(self, host, get_application, kwargs=None, setup=None, teardown=None):
super().__init__()
self.host = host
self.get_application = get_application
self.kwargs = kwargs or {}
self.setup = setup
self.teardown = teardown
self.port = multiprocessing.Value("i")
self.ready = multiprocessing.Event()
self.errors = multiprocessing.Queue()
def run(self):
# OK, now we are in a forked child process, and want to use the reactor.
        # However, kqueue-based systems such as macOS do not carry the kqueue
        # that asyncio (and hence asyncioreactor) is built on across a fork.
# Therefore, we should uninstall the broken reactor and install a new one.
_reinstall_reactor()
from twisted.internet import reactor
from .endpoints import build_endpoint_description_strings
from .server import Server
application = self.get_application()
try:
# Create the server class
endpoints = build_endpoint_description_strings(host=self.host, port=0)
self.server = Server(
application=application,
endpoints=endpoints,
signal_handlers=False,
**self.kwargs
)
# Set up a poller to look for the port
reactor.callLater(0.1, self.resolve_port)
# Run with setup/teardown
if self.setup is not None:
self.setup()
try:
self.server.run()
finally:
if self.teardown is not None:
self.teardown()
except BaseException as e:
# Put the error on our queue so the parent gets it
self.errors.put((e, traceback.format_exc()))
def resolve_port(self):
from twisted.internet import reactor
if self.server.listening_addresses:
self.port.value = self.server.listening_addresses[0][1]
self.ready.set()
else:
reactor.callLater(0.1, self.resolve_port)
class TestApplication:
"""
An application that receives one or more messages, sends a response,
and then quits the server. For testing.
"""
setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio")
result_storage = os.path.join(tempfile.gettempdir(), "result.testio")
def __init__(self, lock):
self.lock = lock
self.messages = []
async def __call__(self, scope, receive, send):
self.scope = scope
# Receive input and send output
logging.debug("test app coroutine alive")
try:
while True:
# Receive a message and save it into the result store
self.messages.append(await receive())
self.lock.acquire()
logging.debug("test app received %r", self.messages[-1])
self.save_result(self.scope, self.messages)
self.lock.release()
# See if there are any messages to send back
setup = self.load_setup()
self.delete_setup()
for message in setup["response_messages"]:
await send(message)
logging.debug("test app sent %r", message)
except Exception as e:
if isinstance(e, CancelledError):
# Don't catch task-cancelled errors!
raise
else:
self.save_exception(e)
@classmethod
def save_setup(cls, response_messages):
"""
Stores setup information.
"""
with open(cls.setup_storage, "wb") as fh:
pickle.dump({"response_messages": response_messages}, fh)
@classmethod
def load_setup(cls):
"""
Returns setup details.
"""
try:
with open(cls.setup_storage, "rb") as fh:
return pickle.load(fh)
except FileNotFoundError:
return {"response_messages": []}
@classmethod
def save_result(cls, scope, messages):
"""
Saves details of what happened to the result storage.
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(cls.result_storage, "wb") as fh:
pickle.dump({"scope": scope, "messages": messages}, fh)
@classmethod
def save_exception(cls, exception):
"""
Saves details of what happened to the result storage.
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(cls.result_storage, "wb") as fh:
pickle.dump({"exception": exception}, fh)
@classmethod
def load_result(cls):
"""
Returns result details.
"""
with open(cls.result_storage, "rb") as fh:
return pickle.load(fh)
@classmethod
def delete_setup(cls):
"""
Clears setup storage files.
"""
try:
os.unlink(cls.setup_storage)
except OSError:
pass
@classmethod
def delete_result(cls):
"""
Clears result storage files.
"""
try:
os.unlink(cls.result_storage)
except OSError:
pass
def _reinstall_reactor():
import asyncio
import sys
from twisted.internet import asyncioreactor
# Uninstall the reactor.
if "twisted.internet.reactor" in sys.modules:
del sys.modules["twisted.internet.reactor"]
# The daphne.server module may have already installed the reactor.
    # If so, using that module would pick up the now-uninstalled reactor, so we
    # reimport it too.
if "daphne.server" in sys.modules:
del sys.modules["daphne.server"]
event_loop = asyncio.new_event_loop()
asyncioreactor.install(event_loop)
asyncio.set_event_loop(event_loop)
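
A sketch of how this harness is typically used from a test (the HTTP request itself is only indicated; any client able to reach 127.0.0.1 on the reported port works):

from daphne.testing import DaphneTestingInstance

with DaphneTestingInstance() as test_app:
    # Queue the messages the in-process TestApplication should send back.
    test_app.add_send_messages(
        [
            {"type": "http.response.start", "status": 200, "headers": []},
            {"type": "http.response.body", "body": b"ok"},
        ]
    )
    # ... perform a request against ("127.0.0.1", test_app.port) here ...
    scope, messages = test_app.get_received()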


@@ -0,0 +1,24 @@
import socket
from twisted.internet import endpoints
from twisted.internet.interfaces import IStreamServerEndpointStringParser
from twisted.plugin import IPlugin
from zope.interface import implementer
@implementer(IPlugin, IStreamServerEndpointStringParser)
class _FDParser:
prefix = "fd"
def _parseServer(self, reactor, fileno, domain=socket.AF_INET):
fileno = int(fileno)
return endpoints.AdoptedStreamServerEndpoint(reactor, fileno, domain)
def parseStreamServer(self, reactor, *args, **kwargs):
# Delegate to another function with a sane signature. This function has
# an insane signature to trick zope.interface into believing the
# interface is correctly implemented.
return self._parseServer(reactor, *args, **kwargs)
parser = _FDParser()


@@ -0,0 +1,89 @@
import importlib
import re
from twisted.web.http_headers import Headers
# Header name regex as per h11.
# https://github.com/python-[AWS-SECRET-REMOVED]9d98002e4a4ed27/h11/_abnf.py#L10-L21
HEADER_NAME_RE = re.compile(rb"[-!#$%&'*+.^_`|~0-9a-zA-Z]+")
def import_by_path(path):
"""
Given a dotted/colon path, like project.module:ClassName.callable,
returns the object at the end of the path.
"""
module_path, object_path = path.split(":", 1)
target = importlib.import_module(module_path)
for bit in object_path.split("."):
target = getattr(target, bit)
return target
def header_value(headers, header_name):
value = headers[header_name]
if isinstance(value, list):
value = value[0]
return value.decode("utf-8")
def parse_x_forwarded_for(
headers,
address_header_name="X-Forwarded-For",
port_header_name="X-Forwarded-Port",
proto_header_name="X-Forwarded-Proto",
original_addr=None,
original_scheme=None,
):
"""
Parses an X-Forwarded-For header and returns a host/port pair as a list.
@param headers: The twisted-style object containing a request's headers
@param address_header_name: The name of the expected host header
@param port_header_name: The name of the expected port header
@param proto_header_name: The name of the expected proto header
@param original_addr: A host/port pair that should be returned if the headers are not in the request
@param original_scheme: A scheme that should be returned if the headers are not in the request
@return: A list containing a host (string) as the first entry and a port (int) as the second.
"""
if not address_header_name:
return original_addr, original_scheme
# Convert twisted-style headers into dicts
if isinstance(headers, Headers):
headers = dict(headers.getAllRawHeaders())
# Lowercase all header names in the dict
headers = {name.lower(): values for name, values in headers.items()}
# Make sure header names are bytes (values are checked in header_value)
assert all(isinstance(name, bytes) for name in headers.keys())
address_header_name = address_header_name.lower().encode("utf-8")
result_addr = original_addr
result_scheme = original_scheme
if address_header_name in headers:
address_value = header_value(headers, address_header_name)
if "," in address_value:
address_value = address_value.split(",")[0].strip()
result_addr = [address_value, 0]
if port_header_name:
# We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For
# header to avoid inconsistent results.
port_header_name = port_header_name.lower().encode("utf-8")
if port_header_name in headers:
port_value = header_value(headers, port_header_name)
try:
result_addr[1] = int(port_value)
except ValueError:
pass
if proto_header_name:
proto_header_name = proto_header_name.lower().encode("utf-8")
if proto_header_name in headers:
result_scheme = header_value(headers, proto_header_name)
return result_addr, result_scheme
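
Illustrative calls against the two helpers above (a sketch; header names follow the bytes convention the parser asserts):

import os

from daphne.utils import import_by_path, parse_x_forwarded_for

assert import_by_path("os.path:join") is os.path.join

headers = {
    b"x-forwarded-for": [b"10.1.2.3, 192.168.0.1"],
    b"x-forwarded-port": [b"443"],
    b"x-forwarded-proto": [b"https"],
}
addr, scheme = parse_x_forwarded_for(
    headers, original_addr=["127.0.0.1", 0], original_scheme="http"
)
assert addr == ["10.1.2.3", 443]
assert scheme == "https"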


@@ -0,0 +1,331 @@
import logging
import time
import traceback
from urllib.parse import unquote
from autobahn.twisted.websocket import (
ConnectionDeny,
WebSocketServerFactory,
WebSocketServerProtocol,
)
from twisted.internet import defer
from .utils import parse_x_forwarded_for
logger = logging.getLogger(__name__)
class WebSocketProtocol(WebSocketServerProtocol):
"""
Protocol which supports WebSockets and forwards incoming messages to
the websocket channels.
"""
application_type = "websocket"
# If we should send no more messages (e.g. we error-closed the socket)
muted = False
def onConnect(self, request):
self.server = self.factory.server_class
self.server.protocol_connected(self)
self.request = request
self.protocol_to_accept = None
self.root_path = self.server.root_path
self.socket_opened = time.time()
self.last_ping = time.time()
try:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
for name, value in request.headers.items():
name = name.encode("ascii")
# Prevent CVE-2015-0219
if b"_" in name:
continue
if name.lower() == b"daphne-root-path":
self.root_path = unquote(value)
else:
self.clean_headers.append((name.lower(), value.encode("latin1")))
# Get client address if possible
peer = self.transport.getPeer()
host = self.transport.getHost()
if hasattr(peer, "host") and hasattr(peer, "port"):
self.client_addr = [str(peer.host), peer.port]
self.server_addr = [str(host.host), host.port]
else:
self.client_addr = None
self.server_addr = None
if self.server.proxy_forwarded_address_header:
self.client_addr, self.client_scheme = parse_x_forwarded_for(
dict(self.clean_headers),
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
)
# Decode websocket subprotocol options
subprotocols = []
for header, value in self.clean_headers:
if header == b"sec-websocket-protocol":
subprotocols = [
x.strip() for x in unquote(value.decode("ascii")).split(",")
]
# Make new application instance with scope
self.path = request.path.encode("ascii")
self.application_deferred = defer.maybeDeferred(
self.server.create_application,
self,
{
"type": "websocket",
"path": unquote(self.path.decode("ascii")),
"raw_path": self.path,
"root_path": self.root_path,
"headers": self.clean_headers,
"query_string": self._raw_query_string, # Passed by HTTP protocol
"client": self.client_addr,
"server": self.server_addr,
"subprotocols": subprotocols,
},
)
if self.application_deferred is not None:
self.application_deferred.addCallback(self.applicationCreateWorked)
self.application_deferred.addErrback(self.applicationCreateFailed)
except Exception:
# Exceptions here are not displayed right, just 500.
# Turn them into an ERROR log.
logger.error(traceback.format_exc())
raise
# Make a deferred and return it - we'll either call it or err it later on
self.handshake_deferred = defer.Deferred()
return self.handshake_deferred
def applicationCreateWorked(self, application_queue):
"""
Called when the background thread has successfully made the application
instance.
"""
# Store the application's queue
self.application_queue = application_queue
# Send over the connect message
self.application_queue.put_nowait({"type": "websocket.connect"})
self.server.log_action(
"websocket",
"connecting",
{
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def applicationCreateFailed(self, failure):
"""
Called when application creation fails.
"""
logger.error(failure)
return failure
### Twisted event handling
def onOpen(self):
# Send news that this channel is open
logger.debug("WebSocket %s open and established", self.client_addr)
self.server.log_action(
"websocket",
"connected",
{
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def onMessage(self, payload, isBinary):
# If we're muted, do nothing.
if self.muted:
logger.debug("Muting incoming frame on %s", self.client_addr)
return
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time()
if isBinary:
self.application_queue.put_nowait(
{"type": "websocket.receive", "bytes": payload}
)
else:
self.application_queue.put_nowait(
{"type": "websocket.receive", "text": payload.decode("utf8")}
)
def onClose(self, wasClean, code, reason):
"""
Called when Twisted closes the socket.
"""
self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted and hasattr(self, "application_queue"):
self.application_queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
)
self.server.log_action(
"websocket",
"disconnected",
{
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
### Internal event handling
def handle_reply(self, message):
if "type" not in message:
raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept":
self.serverAccept(message.get("subprotocol", None))
elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING:
self.serverReject()
else:
self.serverClose(code=message.get("code", None))
elif message["type"] == "websocket.send":
if self.state == self.STATE_CONNECTING:
raise ValueError("Socket has not been accepted, so cannot send over it")
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
% (message,)
)
if message.get("bytes", None):
self.serverSend(message["bytes"], True)
if message.get("text", None):
self.serverSend(message["text"], False)
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one.
self.handshake_deferred.errback(
ConnectionDeny(code=500, reason="Internal server error")
)
else:
self.sendCloseFrame(code=1011)
def serverAccept(self, subprotocol=None):
"""
Called when we get a message saying to accept the connection.
"""
self.handshake_deferred.callback(subprotocol)
del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self):
"""
Called when we get a message saying to reject the connection.
"""
self.handshake_deferred.errback(
ConnectionDeny(code=403, reason="Access denied")
)
del self.handshake_deferred
self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action(
"websocket",
"rejected",
{
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def serverSend(self, content, binary=False):
"""
Server-side channel message to send a message.
"""
if self.state == self.STATE_CONNECTING:
self.serverAccept()
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
if binary:
self.sendMessage(content, binary)
else:
self.sendMessage(content.encode("utf8"), binary)
def serverClose(self, code=None):
"""
Server-side channel message to close the socket
"""
code = 1000 if code is None else code
self.sendClose(code=code)
### Utils
def duration(self):
"""
Returns the time since the socket was opened
"""
return time.time() - self.socket_opened
def check_timeouts(self):
"""
Called periodically to see if we should timeout something
"""
# Web timeout checking
if (
self.duration() > self.server.websocket_timeout
and self.server.websocket_timeout >= 0
):
self.serverClose()
# Ping check
# If we're still connecting, deny the connection
if self.state == self.STATE_CONNECTING:
if self.duration() > self.server.websocket_connect_timeout:
self.serverReject()
elif self.state == self.STATE_OPEN:
if (time.time() - self.last_ping) > self.server.ping_interval:
self._sendAutoPing()
self.last_ping = time.time()
def __hash__(self):
return hash(id(self))
def __eq__(self, other):
return id(self) == id(other)
def __repr__(self):
return f"<WebSocketProtocol client={self.client_addr!r} path={self.path!r}>"
class WebSocketFactory(WebSocketServerFactory):
"""
Factory subclass that remembers what the "main"
factory is, so WebSocket protocols can access it
to get reply ID info.
"""
protocol = WebSocketProtocol
def __init__(self, server_class, *args, **kwargs):
self.server_class = server_class
WebSocketServerFactory.__init__(self, *args, **kwargs)
def buildProtocol(self, addr):
"""
Builds protocol instances. We use this to inject the factory object into the protocol.
"""
try:
protocol = super().buildProtocol(addr)
protocol.factory = self
return protocol
except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise
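
Finally, a minimal ASGI websocket application sketch showing the reply messages WebSocketProtocol.handle_reply() above understands (websocket.accept, websocket.send, websocket.close):

async def ws_echo(scope, receive, send):
    assert scope["type"] == "websocket"
    while True:
        message = await receive()
        if message["type"] == "websocket.connect":
            await send({"type": "websocket.accept"})
        elif message["type"] == "websocket.receive":
            # Echo text frames back; binary frames are ignored in this sketch.
            if "text" in message:
                await send({"type": "websocket.send", "text": message["text"]})
        elif message["type"] == "websocket.disconnect":
            break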