mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-22 16:31:09 -05:00
okay fine
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Twisted Protocols: A collection of internet protocol implementations.
|
||||
"""
|
||||
2856
.venv/lib/python3.12/site-packages/twisted/protocols/amp.py
Normal file
2856
.venv/lib/python3.12/site-packages/twisted/protocols/amp.py
Normal file
File diff suppressed because it is too large
Load Diff
912
.venv/lib/python3.12/site-packages/twisted/protocols/basic.py
Normal file
912
.venv/lib/python3.12/site-packages/twisted/protocols/basic.py
Normal file
@@ -0,0 +1,912 @@
|
||||
# -*- test-case-name: twisted.protocols.test.test_basic -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Basic protocols, such as line-oriented, netstring, and int prefixed strings.
|
||||
"""
|
||||
|
||||
|
||||
import math
|
||||
|
||||
# System imports
|
||||
import re
|
||||
from io import BytesIO
|
||||
from struct import calcsize, pack, unpack
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
# Twisted imports
|
||||
from twisted.internet import defer, interfaces, protocol
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
# Unfortunately we cannot use regular string formatting on Python 3; see
|
||||
# http://bugs.python.org/issue3982 for details.
|
||||
def _formatNetstring(data):
|
||||
return b"".join([str(len(data)).encode("ascii"), b":", data, b","])
|
||||
|
||||
|
||||
_formatNetstring.__doc__ = """
|
||||
Convert some C{bytes} into netstring format.
|
||||
|
||||
@param data: C{bytes} that will be reformatted.
|
||||
"""
|
||||
|
||||
|
||||
DEBUG = 0
|
||||
|
||||
|
||||
class NetstringParseError(ValueError):
|
||||
"""
|
||||
The incoming data is not in valid Netstring format.
|
||||
"""
|
||||
|
||||
|
||||
class IncompleteNetstring(Exception):
|
||||
"""
|
||||
Not enough data to complete a netstring.
|
||||
"""
|
||||
|
||||
|
||||
class NetstringReceiver(protocol.Protocol):
|
||||
"""
|
||||
A protocol that sends and receives netstrings.
|
||||
|
||||
See U{http://cr.yp.to/proto/netstrings.txt} for the specification of
|
||||
netstrings. Every netstring starts with digits that specify the length
|
||||
of the data. This length specification is separated from the data by
|
||||
a colon. The data is terminated with a comma.
|
||||
|
||||
Override L{stringReceived} to handle received netstrings. This
|
||||
method is called with the netstring payload as a single argument
|
||||
whenever a complete netstring is received.
|
||||
|
||||
Security features:
|
||||
1. Messages are limited in size, useful if you don't want
|
||||
someone sending you a 500MB netstring (change C{self.MAX_LENGTH}
|
||||
to the maximum length you wish to accept).
|
||||
2. The connection is lost if an illegal message is received.
|
||||
|
||||
@ivar MAX_LENGTH: Defines the maximum length of netstrings that can be
|
||||
received.
|
||||
@type MAX_LENGTH: C{int}
|
||||
|
||||
@ivar _LENGTH: A pattern describing all strings that contain a netstring
|
||||
length specification. Examples for length specifications are C{b'0:'},
|
||||
C{b'12:'}, and C{b'179:'}. C{b'007:'} is not a valid length
|
||||
specification, since leading zeros are not allowed.
|
||||
@type _LENGTH: C{re.Match}
|
||||
|
||||
@ivar _LENGTH_PREFIX: A pattern describing all strings that contain
|
||||
the first part of a netstring length specification (without the
|
||||
trailing comma). Examples are '0', '12', and '179'. '007' does not
|
||||
start a netstring length specification, since leading zeros are
|
||||
not allowed.
|
||||
@type _LENGTH_PREFIX: C{re.Match}
|
||||
|
||||
@ivar _PARSING_LENGTH: Indicates that the C{NetstringReceiver} is in
|
||||
the state of parsing the length portion of a netstring.
|
||||
@type _PARSING_LENGTH: C{int}
|
||||
|
||||
@ivar _PARSING_PAYLOAD: Indicates that the C{NetstringReceiver} is in
|
||||
the state of parsing the payload portion (data and trailing comma)
|
||||
of a netstring.
|
||||
@type _PARSING_PAYLOAD: C{int}
|
||||
|
||||
@ivar brokenPeer: Indicates if the connection is still functional
|
||||
@type brokenPeer: C{int}
|
||||
|
||||
@ivar _state: Indicates if the protocol is consuming the length portion
|
||||
(C{PARSING_LENGTH}) or the payload (C{PARSING_PAYLOAD}) of a netstring
|
||||
@type _state: C{int}
|
||||
|
||||
@ivar _remainingData: Holds the chunk of data that has not yet been consumed
|
||||
@type _remainingData: C{string}
|
||||
|
||||
@ivar _payload: Holds the payload portion of a netstring including the
|
||||
trailing comma
|
||||
@type _payload: C{BytesIO}
|
||||
|
||||
@ivar _expectedPayloadSize: Holds the payload size plus one for the trailing
|
||||
comma.
|
||||
@type _expectedPayloadSize: C{int}
|
||||
"""
|
||||
|
||||
MAX_LENGTH = 99999
|
||||
_LENGTH = re.compile(rb"(0|[1-9]\d*)(:)")
|
||||
|
||||
_LENGTH_PREFIX = re.compile(rb"(0|[1-9]\d*)$")
|
||||
|
||||
# Some error information for NetstringParseError instances.
|
||||
_MISSING_LENGTH = (
|
||||
"The received netstring does not start with a " "length specification."
|
||||
)
|
||||
_OVERFLOW = (
|
||||
"The length specification of the received netstring "
|
||||
"cannot be represented in Python - it causes an "
|
||||
"OverflowError!"
|
||||
)
|
||||
_TOO_LONG = (
|
||||
"The received netstring is longer than the maximum %s "
|
||||
"specified by self.MAX_LENGTH"
|
||||
)
|
||||
_MISSING_COMMA = "The received netstring is not terminated by a comma."
|
||||
|
||||
# The following constants are used for determining if the NetstringReceiver
|
||||
# is parsing the length portion of a netstring, or the payload.
|
||||
_PARSING_LENGTH, _PARSING_PAYLOAD = range(2)
|
||||
|
||||
def makeConnection(self, transport):
|
||||
"""
|
||||
Initializes the protocol.
|
||||
"""
|
||||
protocol.Protocol.makeConnection(self, transport)
|
||||
self._remainingData = b""
|
||||
self._currentPayloadSize = 0
|
||||
self._payload = BytesIO()
|
||||
self._state = self._PARSING_LENGTH
|
||||
self._expectedPayloadSize = 0
|
||||
self.brokenPeer = 0
|
||||
|
||||
def sendString(self, string):
|
||||
"""
|
||||
Sends a netstring.
|
||||
|
||||
Wraps up C{string} by adding length information and a
|
||||
trailing comma; writes the result to the transport.
|
||||
|
||||
@param string: The string to send. The necessary framing (length
|
||||
prefix, etc) will be added.
|
||||
@type string: C{bytes}
|
||||
"""
|
||||
self.transport.write(_formatNetstring(string))
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Receives some characters of a netstring.
|
||||
|
||||
Whenever a complete netstring is received, this method extracts
|
||||
its payload and calls L{stringReceived} to process it.
|
||||
|
||||
@param data: A chunk of data representing a (possibly partial)
|
||||
netstring
|
||||
@type data: C{bytes}
|
||||
"""
|
||||
self._remainingData += data
|
||||
while self._remainingData:
|
||||
try:
|
||||
self._consumeData()
|
||||
except IncompleteNetstring:
|
||||
break
|
||||
except NetstringParseError:
|
||||
self._handleParseError()
|
||||
break
|
||||
|
||||
def stringReceived(self, string):
|
||||
"""
|
||||
Override this for notification when each complete string is received.
|
||||
|
||||
@param string: The complete string which was received with all
|
||||
framing (length prefix, etc) removed.
|
||||
@type string: C{bytes}
|
||||
|
||||
@raise NotImplementedError: because the method has to be implemented
|
||||
by the child class.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _maxLengthSize(self):
|
||||
"""
|
||||
Calculate and return the string size of C{self.MAX_LENGTH}.
|
||||
|
||||
@return: The size of the string representation for C{self.MAX_LENGTH}
|
||||
@rtype: C{float}
|
||||
"""
|
||||
return math.ceil(math.log10(self.MAX_LENGTH)) + 1
|
||||
|
||||
def _consumeData(self):
|
||||
"""
|
||||
Consumes the content of C{self._remainingData}.
|
||||
|
||||
@raise IncompleteNetstring: if C{self._remainingData} does not
|
||||
contain enough data to complete the current netstring.
|
||||
@raise NetstringParseError: if the received data do not
|
||||
form a valid netstring.
|
||||
"""
|
||||
if self._state == self._PARSING_LENGTH:
|
||||
self._consumeLength()
|
||||
self._prepareForPayloadConsumption()
|
||||
if self._state == self._PARSING_PAYLOAD:
|
||||
self._consumePayload()
|
||||
|
||||
def _consumeLength(self):
|
||||
"""
|
||||
Consumes the length portion of C{self._remainingData}.
|
||||
|
||||
@raise IncompleteNetstring: if C{self._remainingData} contains
|
||||
a partial length specification (digits without trailing
|
||||
comma).
|
||||
@raise NetstringParseError: if the received data do not form a valid
|
||||
netstring.
|
||||
"""
|
||||
lengthMatch = self._LENGTH.match(self._remainingData)
|
||||
if not lengthMatch:
|
||||
self._checkPartialLengthSpecification()
|
||||
raise IncompleteNetstring()
|
||||
self._processLength(lengthMatch)
|
||||
|
||||
def _checkPartialLengthSpecification(self):
|
||||
"""
|
||||
Makes sure that the received data represents a valid number.
|
||||
|
||||
Checks if C{self._remainingData} represents a number smaller or
|
||||
equal to C{self.MAX_LENGTH}.
|
||||
|
||||
@raise NetstringParseError: if C{self._remainingData} is no
|
||||
number or is too big (checked by L{_extractLength}).
|
||||
"""
|
||||
partialLengthMatch = self._LENGTH_PREFIX.match(self._remainingData)
|
||||
if not partialLengthMatch:
|
||||
raise NetstringParseError(self._MISSING_LENGTH)
|
||||
lengthSpecification = partialLengthMatch.group(1)
|
||||
self._extractLength(lengthSpecification)
|
||||
|
||||
def _processLength(self, lengthMatch):
|
||||
"""
|
||||
Processes the length definition of a netstring.
|
||||
|
||||
Extracts and stores in C{self._expectedPayloadSize} the number
|
||||
representing the netstring size. Removes the prefix
|
||||
representing the length specification from
|
||||
C{self._remainingData}.
|
||||
|
||||
@raise NetstringParseError: if the received netstring does not
|
||||
start with a number or the number is bigger than
|
||||
C{self.MAX_LENGTH}.
|
||||
@param lengthMatch: A regular expression match object matching
|
||||
a netstring length specification
|
||||
@type lengthMatch: C{re.Match}
|
||||
"""
|
||||
endOfNumber = lengthMatch.end(1)
|
||||
startOfData = lengthMatch.end(2)
|
||||
lengthString = self._remainingData[:endOfNumber]
|
||||
# Expect payload plus trailing comma:
|
||||
self._expectedPayloadSize = self._extractLength(lengthString) + 1
|
||||
self._remainingData = self._remainingData[startOfData:]
|
||||
|
||||
def _extractLength(self, lengthAsString):
|
||||
"""
|
||||
Attempts to extract the length information of a netstring.
|
||||
|
||||
@raise NetstringParseError: if the number is bigger than
|
||||
C{self.MAX_LENGTH}.
|
||||
@param lengthAsString: A chunk of data starting with a length
|
||||
specification
|
||||
@type lengthAsString: C{bytes}
|
||||
@return: The length of the netstring
|
||||
@rtype: C{int}
|
||||
"""
|
||||
self._checkStringSize(lengthAsString)
|
||||
length = int(lengthAsString)
|
||||
if length > self.MAX_LENGTH:
|
||||
raise NetstringParseError(self._TOO_LONG % (self.MAX_LENGTH,))
|
||||
return length
|
||||
|
||||
def _checkStringSize(self, lengthAsString):
|
||||
"""
|
||||
Checks the sanity of lengthAsString.
|
||||
|
||||
Checks if the size of the length specification exceeds the
|
||||
size of the string representing self.MAX_LENGTH. If this is
|
||||
not the case, the number represented by lengthAsString is
|
||||
certainly bigger than self.MAX_LENGTH, and a
|
||||
NetstringParseError can be raised.
|
||||
|
||||
This method should make sure that netstrings with extremely
|
||||
long length specifications are refused before even attempting
|
||||
to convert them to an integer (which might trigger a
|
||||
MemoryError).
|
||||
"""
|
||||
if len(lengthAsString) > self._maxLengthSize():
|
||||
raise NetstringParseError(self._TOO_LONG % (self.MAX_LENGTH,))
|
||||
|
||||
def _prepareForPayloadConsumption(self):
|
||||
"""
|
||||
Sets up variables necessary for consuming the payload of a netstring.
|
||||
"""
|
||||
self._state = self._PARSING_PAYLOAD
|
||||
self._currentPayloadSize = 0
|
||||
self._payload.seek(0)
|
||||
self._payload.truncate()
|
||||
|
||||
def _consumePayload(self):
|
||||
"""
|
||||
Consumes the payload portion of C{self._remainingData}.
|
||||
|
||||
If the payload is complete, checks for the trailing comma and
|
||||
processes the payload. If not, raises an L{IncompleteNetstring}
|
||||
exception.
|
||||
|
||||
@raise IncompleteNetstring: if the payload received so far
|
||||
contains fewer characters than expected.
|
||||
@raise NetstringParseError: if the payload does not end with a
|
||||
comma.
|
||||
"""
|
||||
self._extractPayload()
|
||||
if self._currentPayloadSize < self._expectedPayloadSize:
|
||||
raise IncompleteNetstring()
|
||||
self._checkForTrailingComma()
|
||||
self._state = self._PARSING_LENGTH
|
||||
self._processPayload()
|
||||
|
||||
def _extractPayload(self):
|
||||
"""
|
||||
Extracts payload information from C{self._remainingData}.
|
||||
|
||||
Splits C{self._remainingData} at the end of the netstring. The
|
||||
first part becomes C{self._payload}, the second part is stored
|
||||
in C{self._remainingData}.
|
||||
|
||||
If the netstring is not yet complete, the whole content of
|
||||
C{self._remainingData} is moved to C{self._payload}.
|
||||
"""
|
||||
if self._payloadComplete():
|
||||
remainingPayloadSize = self._expectedPayloadSize - self._currentPayloadSize
|
||||
self._payload.write(self._remainingData[:remainingPayloadSize])
|
||||
self._remainingData = self._remainingData[remainingPayloadSize:]
|
||||
self._currentPayloadSize = self._expectedPayloadSize
|
||||
else:
|
||||
self._payload.write(self._remainingData)
|
||||
self._currentPayloadSize += len(self._remainingData)
|
||||
self._remainingData = b""
|
||||
|
||||
def _payloadComplete(self):
|
||||
"""
|
||||
Checks if enough data have been received to complete the netstring.
|
||||
|
||||
@return: C{True} iff the received data contain at least as many
|
||||
characters as specified in the length section of the
|
||||
netstring
|
||||
@rtype: C{bool}
|
||||
"""
|
||||
return (
|
||||
len(self._remainingData) + self._currentPayloadSize
|
||||
>= self._expectedPayloadSize
|
||||
)
|
||||
|
||||
def _processPayload(self):
|
||||
"""
|
||||
Processes the actual payload with L{stringReceived}.
|
||||
|
||||
Strips C{self._payload} of the trailing comma and calls
|
||||
L{stringReceived} with the result.
|
||||
"""
|
||||
self.stringReceived(self._payload.getvalue()[:-1])
|
||||
|
||||
def _checkForTrailingComma(self):
|
||||
"""
|
||||
Checks if the netstring has a trailing comma at the expected position.
|
||||
|
||||
@raise NetstringParseError: if the last payload character is
|
||||
anything but a comma.
|
||||
"""
|
||||
if self._payload.getvalue()[-1:] != b",":
|
||||
raise NetstringParseError(self._MISSING_COMMA)
|
||||
|
||||
def _handleParseError(self):
|
||||
"""
|
||||
Terminates the connection and sets the flag C{self.brokenPeer}.
|
||||
"""
|
||||
self.transport.loseConnection()
|
||||
self.brokenPeer = 1
|
||||
|
||||
|
||||
class LineOnlyReceiver(protocol.Protocol):
|
||||
"""
|
||||
A protocol that receives only lines.
|
||||
|
||||
This is purely a speed optimisation over LineReceiver, for the
|
||||
cases that raw mode is known to be unnecessary.
|
||||
|
||||
@cvar delimiter: The line-ending delimiter to use. By default this is
|
||||
C{b'\\r\\n'}.
|
||||
@cvar MAX_LENGTH: The maximum length of a line to allow (If a
|
||||
sent line is longer than this, the connection is dropped).
|
||||
Default is 16384.
|
||||
"""
|
||||
|
||||
_buffer = b""
|
||||
delimiter = b"\r\n"
|
||||
MAX_LENGTH = 16384
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Translates bytes into lines, and calls lineReceived.
|
||||
"""
|
||||
lines = (self._buffer + data).split(self.delimiter)
|
||||
self._buffer = lines.pop(-1)
|
||||
for line in lines:
|
||||
if self.transport.disconnecting:
|
||||
# this is necessary because the transport may be told to lose
|
||||
# the connection by a line within a larger packet, and it is
|
||||
# important to disregard all the lines in that packet following
|
||||
# the one that told it to close.
|
||||
return
|
||||
if len(line) > self.MAX_LENGTH:
|
||||
return self.lineLengthExceeded(line)
|
||||
else:
|
||||
self.lineReceived(line)
|
||||
if len(self._buffer) > self.MAX_LENGTH:
|
||||
return self.lineLengthExceeded(self._buffer)
|
||||
|
||||
def lineReceived(self, line):
|
||||
"""
|
||||
Override this for when each line is received.
|
||||
|
||||
@param line: The line which was received with the delimiter removed.
|
||||
@type line: C{bytes}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sendLine(self, line):
|
||||
"""
|
||||
Sends a line to the other end of the connection.
|
||||
|
||||
@param line: The line to send, not including the delimiter.
|
||||
@type line: C{bytes}
|
||||
"""
|
||||
return self.transport.writeSequence((line, self.delimiter))
|
||||
|
||||
def lineLengthExceeded(self, line):
|
||||
"""
|
||||
Called when the maximum line length has been reached.
|
||||
Override if it needs to be dealt with in some special way.
|
||||
"""
|
||||
return self.transport.loseConnection()
|
||||
|
||||
|
||||
class _PauseableMixin:
|
||||
paused = False
|
||||
|
||||
def pauseProducing(self):
|
||||
self.paused = True
|
||||
self.transport.pauseProducing()
|
||||
|
||||
def resumeProducing(self):
|
||||
self.paused = False
|
||||
self.transport.resumeProducing()
|
||||
self.dataReceived(b"")
|
||||
|
||||
def stopProducing(self):
|
||||
self.paused = True
|
||||
self.transport.stopProducing()
|
||||
|
||||
|
||||
class LineReceiver(protocol.Protocol, _PauseableMixin):
|
||||
"""
|
||||
A protocol that receives lines and/or raw data, depending on mode.
|
||||
|
||||
In line mode, each line that's received becomes a callback to
|
||||
L{lineReceived}. In raw data mode, each chunk of raw data becomes a
|
||||
callback to L{LineReceiver.rawDataReceived}.
|
||||
The L{setLineMode} and L{setRawMode} methods switch between the two modes.
|
||||
|
||||
This is useful for line-oriented protocols such as IRC, HTTP, POP, etc.
|
||||
|
||||
@cvar delimiter: The line-ending delimiter to use. By default this is
|
||||
C{b'\\r\\n'}.
|
||||
@cvar MAX_LENGTH: The maximum length of a line to allow (If a
|
||||
sent line is longer than this, the connection is dropped).
|
||||
Default is 16384.
|
||||
"""
|
||||
|
||||
line_mode = 1
|
||||
_buffer = b""
|
||||
_busyReceiving = False
|
||||
delimiter = b"\r\n"
|
||||
MAX_LENGTH = 16384
|
||||
|
||||
def clearLineBuffer(self):
|
||||
"""
|
||||
Clear buffered data.
|
||||
|
||||
@return: All of the cleared buffered data.
|
||||
@rtype: C{bytes}
|
||||
"""
|
||||
b, self._buffer = self._buffer, b""
|
||||
return b
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Protocol.dataReceived.
|
||||
Translates bytes into lines, and calls lineReceived (or
|
||||
rawDataReceived, depending on mode.)
|
||||
"""
|
||||
if self._busyReceiving:
|
||||
self._buffer += data
|
||||
return
|
||||
|
||||
try:
|
||||
self._busyReceiving = True
|
||||
self._buffer += data
|
||||
while self._buffer and not self.paused:
|
||||
if self.line_mode:
|
||||
try:
|
||||
line, self._buffer = self._buffer.split(self.delimiter, 1)
|
||||
except ValueError:
|
||||
if len(self._buffer) >= (self.MAX_LENGTH + len(self.delimiter)):
|
||||
line, self._buffer = self._buffer, b""
|
||||
return self.lineLengthExceeded(line)
|
||||
return
|
||||
else:
|
||||
lineLength = len(line)
|
||||
if lineLength > self.MAX_LENGTH:
|
||||
exceeded = line + self.delimiter + self._buffer
|
||||
self._buffer = b""
|
||||
return self.lineLengthExceeded(exceeded)
|
||||
why = self.lineReceived(line)
|
||||
if why or self.transport and self.transport.disconnecting:
|
||||
return why
|
||||
else:
|
||||
data = self._buffer
|
||||
self._buffer = b""
|
||||
why = self.rawDataReceived(data)
|
||||
if why:
|
||||
return why
|
||||
finally:
|
||||
self._busyReceiving = False
|
||||
|
||||
def setLineMode(self, extra=b""):
|
||||
"""
|
||||
Sets the line-mode of this receiver.
|
||||
|
||||
If you are calling this from a rawDataReceived callback,
|
||||
you can pass in extra unhandled data, and that data will
|
||||
be parsed for lines. Further data received will be sent
|
||||
to lineReceived rather than rawDataReceived.
|
||||
|
||||
Do not pass extra data if calling this function from
|
||||
within a lineReceived callback.
|
||||
"""
|
||||
self.line_mode = 1
|
||||
if extra:
|
||||
return self.dataReceived(extra)
|
||||
|
||||
def setRawMode(self):
|
||||
"""
|
||||
Sets the raw mode of this receiver.
|
||||
Further data received will be sent to rawDataReceived rather
|
||||
than lineReceived.
|
||||
"""
|
||||
self.line_mode = 0
|
||||
|
||||
def rawDataReceived(self, data):
|
||||
"""
|
||||
Override this for when raw data is received.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def lineReceived(self, line):
|
||||
"""
|
||||
Override this for when each line is received.
|
||||
|
||||
@param line: The line which was received with the delimiter removed.
|
||||
@type line: C{bytes}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sendLine(self, line):
|
||||
"""
|
||||
Sends a line to the other end of the connection.
|
||||
|
||||
@param line: The line to send, not including the delimiter.
|
||||
@type line: C{bytes}
|
||||
"""
|
||||
return self.transport.write(line + self.delimiter)
|
||||
|
||||
def lineLengthExceeded(self, line):
|
||||
"""
|
||||
Called when the maximum line length has been reached.
|
||||
Override if it needs to be dealt with in some special way.
|
||||
|
||||
The argument 'line' contains the remainder of the buffer, starting
|
||||
with (at least some part) of the line which is too long. This may
|
||||
be more than one line, or may be only the initial portion of the
|
||||
line.
|
||||
"""
|
||||
return self.transport.loseConnection()
|
||||
|
||||
|
||||
class StringTooLongError(AssertionError):
|
||||
"""
|
||||
Raised when trying to send a string too long for a length prefixed
|
||||
protocol.
|
||||
"""
|
||||
|
||||
|
||||
class _RecvdCompatHack:
|
||||
"""
|
||||
Emulates the to-be-deprecated C{IntNStringReceiver.recvd} attribute.
|
||||
|
||||
The C{recvd} attribute was where the working buffer for buffering and
|
||||
parsing netstrings was kept. It was updated each time new data arrived and
|
||||
each time some of that data was parsed and delivered to application code.
|
||||
The piecemeal updates to its string value were expensive and have been
|
||||
removed from C{IntNStringReceiver} in the normal case. However, for
|
||||
applications directly reading this attribute, this descriptor restores that
|
||||
behavior. It only copies the working buffer when necessary (ie, when
|
||||
accessed). This avoids the cost for applications not using the data.
|
||||
|
||||
This is a custom descriptor rather than a property, because we still need
|
||||
the default __set__ behavior in both new-style and old-style subclasses.
|
||||
"""
|
||||
|
||||
def __get__(self, oself, type=None):
|
||||
return oself._unprocessed[oself._compatibilityOffset :]
|
||||
|
||||
|
||||
class IntNStringReceiver(protocol.Protocol, _PauseableMixin):
|
||||
"""
|
||||
Generic class for length prefixed protocols.
|
||||
|
||||
@ivar _unprocessed: bytes received, but not yet broken up into messages /
|
||||
sent to stringReceived. _compatibilityOffset must be updated when this
|
||||
value is updated so that the C{recvd} attribute can be generated
|
||||
correctly.
|
||||
@type _unprocessed: C{bytes}
|
||||
|
||||
@ivar structFormat: format used for struct packing/unpacking. Define it in
|
||||
subclass.
|
||||
@type structFormat: C{str}
|
||||
|
||||
@ivar prefixLength: length of the prefix, in bytes. Define it in subclass,
|
||||
using C{struct.calcsize(structFormat)}
|
||||
@type prefixLength: C{int}
|
||||
|
||||
@ivar _compatibilityOffset: the offset within C{_unprocessed} to the next
|
||||
message to be parsed. (used to generate the recvd attribute)
|
||||
@type _compatibilityOffset: C{int}
|
||||
"""
|
||||
|
||||
MAX_LENGTH = 99999
|
||||
_unprocessed = b""
|
||||
_compatibilityOffset = 0
|
||||
|
||||
# Backwards compatibility support for applications which directly touch the
|
||||
# "internal" parse buffer.
|
||||
recvd = _RecvdCompatHack()
|
||||
|
||||
def stringReceived(self, string):
|
||||
"""
|
||||
Override this for notification when each complete string is received.
|
||||
|
||||
@param string: The complete string which was received with all
|
||||
framing (length prefix, etc) removed.
|
||||
@type string: C{bytes}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def lengthLimitExceeded(self, length):
|
||||
"""
|
||||
Callback invoked when a length prefix greater than C{MAX_LENGTH} is
|
||||
received. The default implementation disconnects the transport.
|
||||
Override this.
|
||||
|
||||
@param length: The length prefix which was received.
|
||||
@type length: C{int}
|
||||
"""
|
||||
self.transport.loseConnection()
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Convert int prefixed strings into calls to stringReceived.
|
||||
"""
|
||||
# Try to minimize string copying (via slices) by keeping one buffer
|
||||
# containing all the data we have so far and a separate offset into that
|
||||
# buffer.
|
||||
alldata = self._unprocessed + data
|
||||
currentOffset = 0
|
||||
prefixLength = self.prefixLength
|
||||
fmt = self.structFormat
|
||||
self._unprocessed = alldata
|
||||
|
||||
while len(alldata) >= (currentOffset + prefixLength) and not self.paused:
|
||||
messageStart = currentOffset + prefixLength
|
||||
(length,) = unpack(fmt, alldata[currentOffset:messageStart])
|
||||
if length > self.MAX_LENGTH:
|
||||
self._unprocessed = alldata
|
||||
self._compatibilityOffset = currentOffset
|
||||
self.lengthLimitExceeded(length)
|
||||
return
|
||||
messageEnd = messageStart + length
|
||||
if len(alldata) < messageEnd:
|
||||
break
|
||||
|
||||
# Here we have to slice the working buffer so we can send just the
|
||||
# netstring into the stringReceived callback.
|
||||
packet = alldata[messageStart:messageEnd]
|
||||
currentOffset = messageEnd
|
||||
self._compatibilityOffset = currentOffset
|
||||
self.stringReceived(packet)
|
||||
|
||||
# Check to see if the backwards compat "recvd" attribute got written
|
||||
# to by application code. If so, drop the current data buffer and
|
||||
# switch to the new buffer given by that attribute's value.
|
||||
if "recvd" in self.__dict__:
|
||||
alldata = self.__dict__.pop("recvd")
|
||||
self._unprocessed = alldata
|
||||
self._compatibilityOffset = currentOffset = 0
|
||||
if alldata:
|
||||
continue
|
||||
return
|
||||
|
||||
# Slice off all the data that has been processed, avoiding holding onto
|
||||
# memory to store it, and update the compatibility attributes to reflect
|
||||
# that change.
|
||||
self._unprocessed = alldata[currentOffset:]
|
||||
self._compatibilityOffset = 0
|
||||
|
||||
def sendString(self, string):
|
||||
"""
|
||||
Send a prefixed string to the other end of the connection.
|
||||
|
||||
@param string: The string to send. The necessary framing (length
|
||||
prefix, etc) will be added.
|
||||
@type string: C{bytes}
|
||||
"""
|
||||
if len(string) >= 2 ** (8 * self.prefixLength):
|
||||
raise StringTooLongError(
|
||||
"Try to send %s bytes whereas maximum is %s"
|
||||
% (len(string), 2 ** (8 * self.prefixLength))
|
||||
)
|
||||
self.transport.write(pack(self.structFormat, len(string)) + string)
|
||||
|
||||
|
||||
class Int32StringReceiver(IntNStringReceiver):
|
||||
"""
|
||||
A receiver for int32-prefixed strings.
|
||||
|
||||
An int32 string is a string prefixed by 4 bytes, the 32-bit length of
|
||||
the string encoded in network byte order.
|
||||
|
||||
This class publishes the same interface as NetstringReceiver.
|
||||
"""
|
||||
|
||||
structFormat = "!I"
|
||||
prefixLength = calcsize(structFormat)
|
||||
|
||||
|
||||
class Int16StringReceiver(IntNStringReceiver):
|
||||
"""
|
||||
A receiver for int16-prefixed strings.
|
||||
|
||||
An int16 string is a string prefixed by 2 bytes, the 16-bit length of
|
||||
the string encoded in network byte order.
|
||||
|
||||
This class publishes the same interface as NetstringReceiver.
|
||||
"""
|
||||
|
||||
structFormat = "!H"
|
||||
prefixLength = calcsize(structFormat)
|
||||
|
||||
|
||||
class Int8StringReceiver(IntNStringReceiver):
|
||||
"""
|
||||
A receiver for int8-prefixed strings.
|
||||
|
||||
An int8 string is a string prefixed by 1 byte, the 8-bit length of
|
||||
the string.
|
||||
|
||||
This class publishes the same interface as NetstringReceiver.
|
||||
"""
|
||||
|
||||
structFormat = "!B"
|
||||
prefixLength = calcsize(structFormat)
|
||||
|
||||
|
||||
class StatefulStringProtocol:
|
||||
"""
|
||||
A stateful string protocol.
|
||||
|
||||
This is a mixin for string protocols (L{Int32StringReceiver},
|
||||
L{NetstringReceiver}) which translates L{stringReceived} into a callback
|
||||
(prefixed with C{'proto_'}) depending on state.
|
||||
|
||||
The state C{'done'} is special; if a C{proto_*} method returns it, the
|
||||
connection will be closed immediately.
|
||||
|
||||
@ivar state: Current state of the protocol. Defaults to C{'init'}.
|
||||
@type state: C{str}
|
||||
"""
|
||||
|
||||
state = "init"
|
||||
|
||||
def stringReceived(self, string):
|
||||
"""
|
||||
Choose a protocol phase function and call it.
|
||||
|
||||
Call back to the appropriate protocol phase; this begins with
|
||||
the function C{proto_init} and moves on to C{proto_*} depending on
|
||||
what each C{proto_*} function returns. (For example, if
|
||||
C{self.proto_init} returns 'foo', then C{self.proto_foo} will be the
|
||||
next function called when a protocol message is received.
|
||||
"""
|
||||
try:
|
||||
pto = "proto_" + self.state
|
||||
statehandler = getattr(self, pto)
|
||||
except AttributeError:
|
||||
log.msg("callback", self.state, "not found")
|
||||
else:
|
||||
self.state = statehandler(string)
|
||||
if self.state == "done":
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
@implementer(interfaces.IProducer)
|
||||
class FileSender:
|
||||
"""
|
||||
A producer that sends the contents of a file to a consumer.
|
||||
|
||||
This is a helper for protocols that, at some point, will take a
|
||||
file-like object, read its contents, and write them out to the network,
|
||||
optionally performing some transformation on the bytes in between.
|
||||
"""
|
||||
|
||||
CHUNK_SIZE = 2**14
|
||||
|
||||
lastSent = ""
|
||||
deferred = None
|
||||
|
||||
def beginFileTransfer(self, file, consumer, transform=None):
|
||||
"""
|
||||
Begin transferring a file
|
||||
|
||||
@type file: Any file-like object
|
||||
@param file: The file object to read data from
|
||||
|
||||
@type consumer: Any implementor of IConsumer
|
||||
@param consumer: The object to write data to
|
||||
|
||||
@param transform: A callable taking one string argument and returning
|
||||
the same. All bytes read from the file are passed through this before
|
||||
being written to the consumer.
|
||||
|
||||
@rtype: C{Deferred}
|
||||
@return: A deferred whose callback will be invoked when the file has
|
||||
been completely written to the consumer. The last byte written to the
|
||||
consumer is passed to the callback.
|
||||
"""
|
||||
self.file = file
|
||||
self.consumer = consumer
|
||||
self.transform = transform
|
||||
|
||||
self.deferred = deferred = defer.Deferred()
|
||||
self.consumer.registerProducer(self, False)
|
||||
return deferred
|
||||
|
||||
def resumeProducing(self):
|
||||
chunk = ""
|
||||
if self.file:
|
||||
chunk = self.file.read(self.CHUNK_SIZE)
|
||||
if not chunk:
|
||||
self.file = None
|
||||
self.consumer.unregisterProducer()
|
||||
if self.deferred:
|
||||
self.deferred.callback(self.lastSent)
|
||||
self.deferred = None
|
||||
return
|
||||
|
||||
if self.transform:
|
||||
chunk = self.transform(chunk)
|
||||
self.consumer.write(chunk)
|
||||
self.lastSent = chunk[-1:]
|
||||
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
|
||||
def stopProducing(self):
|
||||
if self.deferred:
|
||||
self.deferred.errback(Exception("Consumer asked us to stop producing"))
|
||||
self.deferred = None
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""The Finger User Information Protocol (RFC 1288)"""
|
||||
|
||||
from twisted.protocols import basic
|
||||
|
||||
|
||||
class Finger(basic.LineReceiver):
|
||||
def lineReceived(self, line):
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
parts = [b""]
|
||||
if len(parts) == 1:
|
||||
slash_w = 0
|
||||
else:
|
||||
slash_w = 1
|
||||
user = parts[-1]
|
||||
if b"@" in user:
|
||||
hostPlace = user.rfind(b"@")
|
||||
user = user[:hostPlace]
|
||||
host = user[hostPlace + 1 :]
|
||||
return self.forwardQuery(slash_w, user, host)
|
||||
if user:
|
||||
return self.getUser(slash_w, user)
|
||||
else:
|
||||
return self.getDomain(slash_w)
|
||||
|
||||
def _refuseMessage(self, message):
|
||||
self.transport.write(message + b"\n")
|
||||
self.transport.loseConnection()
|
||||
|
||||
def forwardQuery(self, slash_w, user, host):
|
||||
self._refuseMessage(b"Finger forwarding service denied")
|
||||
|
||||
def getDomain(self, slash_w):
|
||||
self._refuseMessage(b"Finger online list denied")
|
||||
|
||||
def getUser(self, slash_w, user):
|
||||
self.transport.write(b"Login: " + user + b"\n")
|
||||
self._refuseMessage(b"No such user")
|
||||
3443
.venv/lib/python3.12/site-packages/twisted/protocols/ftp.py
Normal file
3443
.venv/lib/python3.12/site-packages/twisted/protocols/ftp.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,10 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
HAProxy PROXY protocol implementations.
|
||||
"""
|
||||
__all__ = ["proxyEndpoint"]
|
||||
|
||||
from ._wrapper import proxyEndpoint
|
||||
@@ -0,0 +1,49 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
HAProxy specific exceptions.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from typing import Callable, Generator, Type
|
||||
|
||||
|
||||
class InvalidProxyHeader(Exception):
|
||||
"""
|
||||
The provided PROXY protocol header is invalid.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidNetworkProtocol(InvalidProxyHeader):
|
||||
"""
|
||||
The network protocol was not one of TCP4 TCP6 or UNKNOWN.
|
||||
"""
|
||||
|
||||
|
||||
class MissingAddressData(InvalidProxyHeader):
|
||||
"""
|
||||
The address data is missing or incomplete.
|
||||
"""
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def convertError(
|
||||
sourceType: Type[BaseException], targetType: Callable[[], BaseException]
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Convert an error into a different error type.
|
||||
|
||||
@param sourceType: The type of exception that should be caught and
|
||||
converted.
|
||||
@type sourceType: L{BaseException}
|
||||
|
||||
@param targetType: The type of exception to which the original should be
|
||||
converted.
|
||||
@type targetType: L{BaseException}
|
||||
"""
|
||||
try:
|
||||
yield
|
||||
except sourceType as e:
|
||||
raise targetType().with_traceback(e.__traceback__)
|
||||
@@ -0,0 +1,34 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
IProxyInfo implementation.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from ._interfaces import IProxyInfo
|
||||
|
||||
|
||||
@implementer(IProxyInfo)
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class ProxyInfo:
|
||||
"""
|
||||
A data container for parsed PROXY protocol information.
|
||||
|
||||
@ivar header: The raw header bytes extracted from the connection.
|
||||
@type header: C{bytes}
|
||||
@ivar source: The connection source address.
|
||||
@type source: L{twisted.internet.interfaces.IAddress}
|
||||
@ivar destination: The connection destination address.
|
||||
@type destination: L{twisted.internet.interfaces.IAddress}
|
||||
"""
|
||||
|
||||
header: bytes
|
||||
source: Optional[IAddress]
|
||||
destination: Optional[IAddress]
|
||||
@@ -0,0 +1,63 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Interfaces used by the PROXY protocol modules.
|
||||
"""
|
||||
from typing import Tuple, Union
|
||||
|
||||
import zope.interface
|
||||
|
||||
|
||||
class IProxyInfo(zope.interface.Interface):
|
||||
"""
|
||||
Data container for PROXY protocol header data.
|
||||
"""
|
||||
|
||||
header = zope.interface.Attribute(
|
||||
"The raw byestring that represents the PROXY protocol header.",
|
||||
)
|
||||
source = zope.interface.Attribute(
|
||||
"An L{twisted.internet.interfaces.IAddress} representing the "
|
||||
"connection source."
|
||||
)
|
||||
destination = zope.interface.Attribute(
|
||||
"An L{twisted.internet.interfaces.IAddress} representing the "
|
||||
"connection destination."
|
||||
)
|
||||
|
||||
|
||||
class IProxyParser(zope.interface.Interface):
|
||||
"""
|
||||
Streaming parser that handles PROXY protocol headers.
|
||||
"""
|
||||
|
||||
def feed(data: bytes) -> Union[Tuple[IProxyInfo, bytes], Tuple[None, None]]:
|
||||
"""
|
||||
Consume a chunk of data and attempt to parse it.
|
||||
|
||||
@param data: A bytestring.
|
||||
@type data: bytes
|
||||
|
||||
@return: A two-tuple containing, in order, an L{IProxyInfo} and any
|
||||
bytes fed to the parser that followed the end of the header. Both
|
||||
of these values are None until a complete header is parsed.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytes fed to the parser create an
|
||||
invalid PROXY header.
|
||||
"""
|
||||
|
||||
def parse(line: bytes) -> IProxyInfo:
|
||||
"""
|
||||
Parse a bytestring as a full PROXY protocol header line.
|
||||
|
||||
@param line: A bytestring that represents a valid HAProxy PROXY
|
||||
protocol header line.
|
||||
@type line: bytes
|
||||
|
||||
@return: An L{IProxyInfo} containing the parsed data.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytestring does not represent a
|
||||
valid PROXY header.
|
||||
"""
|
||||
@@ -0,0 +1,75 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test.test_parser -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Parser for 'haproxy:' string endpoint.
|
||||
"""
|
||||
from typing import Mapping, Tuple
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import interfaces
|
||||
from twisted.internet.endpoints import (
|
||||
IStreamServerEndpointStringParser,
|
||||
_WrapperServerEndpoint,
|
||||
quoteStringArgument,
|
||||
serverFromString,
|
||||
)
|
||||
from twisted.plugin import IPlugin
|
||||
from . import proxyEndpoint
|
||||
|
||||
|
||||
def unparseEndpoint(args: Tuple[object, ...], kwargs: Mapping[str, object]) -> str:
|
||||
"""
|
||||
Un-parse the already-parsed args and kwargs back into endpoint syntax.
|
||||
|
||||
@param args: C{:}-separated arguments
|
||||
|
||||
@param kwargs: C{:} and then C{=}-separated keyword arguments
|
||||
|
||||
@return: a string equivalent to the original format which this was parsed
|
||||
as.
|
||||
"""
|
||||
|
||||
description = ":".join(
|
||||
[quoteStringArgument(str(arg)) for arg in args]
|
||||
+ sorted(
|
||||
"{}={}".format(
|
||||
quoteStringArgument(str(key)), quoteStringArgument(str(value))
|
||||
)
|
||||
for key, value in kwargs.items()
|
||||
)
|
||||
)
|
||||
return description
|
||||
|
||||
|
||||
@implementer(IPlugin, IStreamServerEndpointStringParser)
|
||||
class HAProxyServerParser:
|
||||
"""
|
||||
Stream server endpoint string parser for the HAProxyServerEndpoint type.
|
||||
|
||||
@ivar prefix: See L{IStreamServerEndpointStringParser.prefix}.
|
||||
"""
|
||||
|
||||
prefix = "haproxy"
|
||||
|
||||
def parseStreamServer(
|
||||
self, reactor: interfaces.IReactorCore, *args: object, **kwargs: object
|
||||
) -> _WrapperServerEndpoint:
|
||||
"""
|
||||
Parse a stream server endpoint from a reactor and string-only arguments
|
||||
and keyword arguments.
|
||||
|
||||
@param reactor: The reactor.
|
||||
|
||||
@param args: The parsed string arguments.
|
||||
|
||||
@param kwargs: The parsed keyword arguments.
|
||||
|
||||
@return: a stream server endpoint
|
||||
@rtype: L{IStreamServerEndpoint}
|
||||
"""
|
||||
subdescription = unparseEndpoint(args, kwargs)
|
||||
wrappedEndpoint = serverFromString(reactor, subdescription)
|
||||
return proxyEndpoint(wrappedEndpoint)
|
||||
@@ -0,0 +1,142 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test.test_v1parser -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
IProxyParser implementation for version one of the PROXY protocol.
|
||||
"""
|
||||
from typing import Tuple, Union
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import address
|
||||
from . import _info, _interfaces
|
||||
from ._exceptions import (
|
||||
InvalidNetworkProtocol,
|
||||
InvalidProxyHeader,
|
||||
MissingAddressData,
|
||||
convertError,
|
||||
)
|
||||
|
||||
|
||||
@implementer(_interfaces.IProxyParser)
|
||||
class V1Parser:
|
||||
"""
|
||||
PROXY protocol version one header parser.
|
||||
|
||||
Version one of the PROXY protocol is a human readable format represented
|
||||
by a single, newline delimited binary string that contains all of the
|
||||
relevant source and destination data.
|
||||
"""
|
||||
|
||||
PROXYSTR = b"PROXY"
|
||||
UNKNOWN_PROTO = b"UNKNOWN"
|
||||
TCP4_PROTO = b"TCP4"
|
||||
TCP6_PROTO = b"TCP6"
|
||||
ALLOWED_NET_PROTOS = (
|
||||
TCP4_PROTO,
|
||||
TCP6_PROTO,
|
||||
UNKNOWN_PROTO,
|
||||
)
|
||||
NEWLINE = b"\r\n"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer = b""
|
||||
|
||||
def feed(
|
||||
self, data: bytes
|
||||
) -> Union[Tuple[_info.ProxyInfo, bytes], Tuple[None, None]]:
|
||||
"""
|
||||
Consume a chunk of data and attempt to parse it.
|
||||
|
||||
@param data: A bytestring.
|
||||
@type data: L{bytes}
|
||||
|
||||
@return: A two-tuple containing, in order, a
|
||||
L{_interfaces.IProxyInfo} and any bytes fed to the
|
||||
parser that followed the end of the header. Both of these values
|
||||
are None until a complete header is parsed.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytes fed to the parser create an
|
||||
invalid PROXY header.
|
||||
"""
|
||||
self.buffer += data
|
||||
if len(self.buffer) > 107 and self.NEWLINE not in self.buffer:
|
||||
raise InvalidProxyHeader()
|
||||
lines = (self.buffer).split(self.NEWLINE, 1)
|
||||
if not len(lines) > 1:
|
||||
return (None, None)
|
||||
self.buffer = b""
|
||||
remaining = lines.pop()
|
||||
header = lines.pop()
|
||||
info = self.parse(header)
|
||||
return (info, remaining)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, line: bytes) -> _info.ProxyInfo:
|
||||
"""
|
||||
Parse a bytestring as a full PROXY protocol header line.
|
||||
|
||||
@param line: A bytestring that represents a valid HAProxy PROXY
|
||||
protocol header line.
|
||||
@type line: bytes
|
||||
|
||||
@return: A L{_interfaces.IProxyInfo} containing the parsed data.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytestring does not represent a
|
||||
valid PROXY header.
|
||||
|
||||
@raises InvalidNetworkProtocol: When no protocol can be parsed or is
|
||||
not one of the allowed values.
|
||||
|
||||
@raises MissingAddressData: When the protocol is TCP* but the header
|
||||
does not contain a complete set of addresses and ports.
|
||||
"""
|
||||
originalLine = line
|
||||
proxyStr = None
|
||||
networkProtocol = None
|
||||
sourceAddr = None
|
||||
sourcePort = None
|
||||
destAddr = None
|
||||
destPort = None
|
||||
|
||||
with convertError(ValueError, InvalidProxyHeader):
|
||||
proxyStr, line = line.split(b" ", 1)
|
||||
|
||||
if proxyStr != cls.PROXYSTR:
|
||||
raise InvalidProxyHeader()
|
||||
|
||||
with convertError(ValueError, InvalidNetworkProtocol):
|
||||
networkProtocol, line = line.split(b" ", 1)
|
||||
|
||||
if networkProtocol not in cls.ALLOWED_NET_PROTOS:
|
||||
raise InvalidNetworkProtocol()
|
||||
|
||||
if networkProtocol == cls.UNKNOWN_PROTO:
|
||||
return _info.ProxyInfo(originalLine, None, None)
|
||||
|
||||
with convertError(ValueError, MissingAddressData):
|
||||
sourceAddr, line = line.split(b" ", 1)
|
||||
|
||||
with convertError(ValueError, MissingAddressData):
|
||||
destAddr, line = line.split(b" ", 1)
|
||||
|
||||
with convertError(ValueError, MissingAddressData):
|
||||
sourcePort, line = line.split(b" ", 1)
|
||||
|
||||
with convertError(ValueError, MissingAddressData):
|
||||
destPort = line.split(b" ")[0]
|
||||
|
||||
if networkProtocol == cls.TCP4_PROTO:
|
||||
return _info.ProxyInfo(
|
||||
originalLine,
|
||||
address.IPv4Address("TCP", sourceAddr.decode(), int(sourcePort)),
|
||||
address.IPv4Address("TCP", destAddr.decode(), int(destPort)),
|
||||
)
|
||||
|
||||
return _info.ProxyInfo(
|
||||
originalLine,
|
||||
address.IPv6Address("TCP", sourceAddr.decode(), int(sourcePort)),
|
||||
address.IPv6Address("TCP", destAddr.decode(), int(destPort)),
|
||||
)
|
||||
@@ -0,0 +1,217 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test.test_v2parser -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
IProxyParser implementation for version two of the PROXY protocol.
|
||||
"""
|
||||
|
||||
import binascii
|
||||
import struct
|
||||
from typing import Callable, Tuple, Type, Union
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from constantly import ValueConstant, Values
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.internet import address
|
||||
from twisted.python import compat
|
||||
from . import _info, _interfaces
|
||||
from ._exceptions import (
|
||||
InvalidNetworkProtocol,
|
||||
InvalidProxyHeader,
|
||||
MissingAddressData,
|
||||
convertError,
|
||||
)
|
||||
|
||||
|
||||
class NetFamily(Values):
|
||||
"""
|
||||
Values for the 'family' field.
|
||||
"""
|
||||
|
||||
UNSPEC = ValueConstant(0x00)
|
||||
INET = ValueConstant(0x10)
|
||||
INET6 = ValueConstant(0x20)
|
||||
UNIX = ValueConstant(0x30)
|
||||
|
||||
|
||||
class NetProtocol(Values):
|
||||
"""
|
||||
Values for 'protocol' field.
|
||||
"""
|
||||
|
||||
UNSPEC = ValueConstant(0)
|
||||
STREAM = ValueConstant(1)
|
||||
DGRAM = ValueConstant(2)
|
||||
|
||||
|
||||
_HIGH = 0b11110000
|
||||
_LOW = 0b00001111
|
||||
_LOCALCOMMAND = "LOCAL"
|
||||
_PROXYCOMMAND = "PROXY"
|
||||
|
||||
|
||||
@implementer(_interfaces.IProxyParser)
|
||||
class V2Parser:
|
||||
"""
|
||||
PROXY protocol version two header parser.
|
||||
|
||||
Version two of the PROXY protocol is a binary format.
|
||||
"""
|
||||
|
||||
PREFIX = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
|
||||
VERSIONS = [32]
|
||||
COMMANDS = {0: _LOCALCOMMAND, 1: _PROXYCOMMAND}
|
||||
ADDRESSFORMATS = {
|
||||
# TCP4
|
||||
17: "!4s4s2H",
|
||||
18: "!4s4s2H",
|
||||
# TCP6
|
||||
33: "!16s16s2H",
|
||||
34: "!16s16s2H",
|
||||
# UNIX
|
||||
49: "!108s108s",
|
||||
50: "!108s108s",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer = b""
|
||||
|
||||
def feed(
|
||||
self, data: bytes
|
||||
) -> Union[Tuple[_info.ProxyInfo, bytes], Tuple[None, None]]:
|
||||
"""
|
||||
Consume a chunk of data and attempt to parse it.
|
||||
|
||||
@param data: A bytestring.
|
||||
@type data: bytes
|
||||
|
||||
@return: A two-tuple containing, in order, a L{_interfaces.IProxyInfo}
|
||||
and any bytes fed to the parser that followed the end of the
|
||||
header. Both of these values are None until a complete header is
|
||||
parsed.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytes fed to the parser create an
|
||||
invalid PROXY header.
|
||||
"""
|
||||
self.buffer += data
|
||||
if len(self.buffer) < 16:
|
||||
raise InvalidProxyHeader()
|
||||
|
||||
size = struct.unpack("!H", self.buffer[14:16])[0] + 16
|
||||
if len(self.buffer) < size:
|
||||
return (None, None)
|
||||
|
||||
header, remaining = self.buffer[:size], self.buffer[size:]
|
||||
self.buffer = b""
|
||||
info = self.parse(header)
|
||||
return (info, remaining)
|
||||
|
||||
@staticmethod
|
||||
def _bytesToIPv4(bytestring: bytes) -> bytes:
|
||||
"""
|
||||
Convert packed 32-bit IPv4 address bytes into a dotted-quad ASCII bytes
|
||||
representation of that address.
|
||||
|
||||
@param bytestring: 4 octets representing an IPv4 address.
|
||||
@type bytestring: L{bytes}
|
||||
|
||||
@return: a dotted-quad notation IPv4 address.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return b".".join(
|
||||
("%i" % (ord(b),)).encode("ascii") for b in compat.iterbytes(bytestring)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _bytesToIPv6(bytestring: bytes) -> bytes:
|
||||
"""
|
||||
Convert packed 128-bit IPv6 address bytes into a colon-separated ASCII
|
||||
bytes representation of that address.
|
||||
|
||||
@param bytestring: 16 octets representing an IPv6 address.
|
||||
@type bytestring: L{bytes}
|
||||
|
||||
@return: a dotted-quad notation IPv6 address.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
hexString = binascii.b2a_hex(bytestring)
|
||||
return b":".join(
|
||||
(f"{int(hexString[b : b + 4], 16):x}").encode("ascii")
|
||||
for b in range(0, 32, 4)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, line: bytes) -> _info.ProxyInfo:
|
||||
"""
|
||||
Parse a bytestring as a full PROXY protocol header.
|
||||
|
||||
@param line: A bytestring that represents a valid HAProxy PROXY
|
||||
protocol version 2 header.
|
||||
@type line: bytes
|
||||
|
||||
@return: A L{_interfaces.IProxyInfo} containing the
|
||||
parsed data.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytestring does not represent a
|
||||
valid PROXY header.
|
||||
"""
|
||||
prefix = line[:12]
|
||||
addrInfo = None
|
||||
with convertError(IndexError, InvalidProxyHeader):
|
||||
# Use single value slices to ensure bytestring values are returned
|
||||
# instead of int in PY3.
|
||||
versionCommand = ord(line[12:13])
|
||||
familyProto = ord(line[13:14])
|
||||
|
||||
if prefix != cls.PREFIX:
|
||||
raise InvalidProxyHeader()
|
||||
|
||||
version, command = versionCommand & _HIGH, versionCommand & _LOW
|
||||
if version not in cls.VERSIONS or command not in cls.COMMANDS:
|
||||
raise InvalidProxyHeader()
|
||||
|
||||
if cls.COMMANDS[command] == _LOCALCOMMAND:
|
||||
return _info.ProxyInfo(line, None, None)
|
||||
|
||||
family, netproto = familyProto & _HIGH, familyProto & _LOW
|
||||
with convertError(ValueError, InvalidNetworkProtocol):
|
||||
family = NetFamily.lookupByValue(family)
|
||||
netproto = NetProtocol.lookupByValue(netproto)
|
||||
if family is NetFamily.UNSPEC or netproto is NetProtocol.UNSPEC:
|
||||
return _info.ProxyInfo(line, None, None)
|
||||
|
||||
addressFormat = cls.ADDRESSFORMATS[familyProto]
|
||||
addrInfo = line[16 : 16 + struct.calcsize(addressFormat)]
|
||||
if family is NetFamily.UNIX:
|
||||
with convertError(struct.error, MissingAddressData):
|
||||
source, dest = struct.unpack(addressFormat, addrInfo)
|
||||
return _info.ProxyInfo(
|
||||
line,
|
||||
address.UNIXAddress(source.rstrip(b"\x00")),
|
||||
address.UNIXAddress(dest.rstrip(b"\x00")),
|
||||
)
|
||||
|
||||
addrType: Union[Literal["TCP"], Literal["UDP"]] = "TCP"
|
||||
if netproto is NetProtocol.DGRAM:
|
||||
addrType = "UDP"
|
||||
addrCls: Union[
|
||||
Type[address.IPv4Address], Type[address.IPv6Address]
|
||||
] = address.IPv4Address
|
||||
addrParser: Callable[[bytes], bytes] = cls._bytesToIPv4
|
||||
if family is NetFamily.INET6:
|
||||
addrCls = address.IPv6Address
|
||||
addrParser = cls._bytesToIPv6
|
||||
|
||||
with convertError(struct.error, MissingAddressData):
|
||||
info = struct.unpack(addressFormat, addrInfo)
|
||||
source, dest, sPort, dPort = info
|
||||
|
||||
return _info.ProxyInfo(
|
||||
line,
|
||||
addrCls(addrType, addrParser(source).decode(), sPort),
|
||||
addrCls(addrType, addrParser(dest).decode(), dPort),
|
||||
)
|
||||
@@ -0,0 +1,109 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test.test_wrapper -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Protocol wrapper that provides HAProxy PROXY protocol support.
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
from twisted.internet import interfaces
|
||||
from twisted.internet.endpoints import _WrapperServerEndpoint
|
||||
from twisted.protocols import policies
|
||||
from . import _info
|
||||
from ._exceptions import InvalidProxyHeader
|
||||
from ._v1parser import V1Parser
|
||||
from ._v2parser import V2Parser
|
||||
|
||||
|
||||
class HAProxyProtocolWrapper(policies.ProtocolWrapper):
|
||||
"""
|
||||
A Protocol wrapper that provides HAProxy support.
|
||||
|
||||
This protocol reads the PROXY stream header, v1 or v2, parses the provided
|
||||
connection data, and modifies the behavior of getPeer and getHost to return
|
||||
the data provided by the PROXY header.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, factory: policies.WrappingFactory, wrappedProtocol: interfaces.IProtocol
|
||||
):
|
||||
super().__init__(factory, wrappedProtocol)
|
||||
self._proxyInfo: Optional[_info.ProxyInfo] = None
|
||||
self._parser: Union[V2Parser, V1Parser, None] = None
|
||||
|
||||
def dataReceived(self, data: bytes) -> None:
|
||||
if self._proxyInfo is not None:
|
||||
return self.wrappedProtocol.dataReceived(data)
|
||||
|
||||
parser = self._parser
|
||||
if parser is None:
|
||||
if (
|
||||
len(data) >= 16
|
||||
and data[:12] == V2Parser.PREFIX
|
||||
and ord(data[12:13]) & 0b11110000 == 0x20
|
||||
):
|
||||
self._parser = parser = V2Parser()
|
||||
elif len(data) >= 8 and data[:5] == V1Parser.PROXYSTR:
|
||||
self._parser = parser = V1Parser()
|
||||
else:
|
||||
self.loseConnection()
|
||||
return None
|
||||
|
||||
try:
|
||||
self._proxyInfo, remaining = parser.feed(data)
|
||||
if remaining:
|
||||
self.wrappedProtocol.dataReceived(remaining)
|
||||
except InvalidProxyHeader:
|
||||
self.loseConnection()
|
||||
|
||||
def getPeer(self) -> interfaces.IAddress:
|
||||
if self._proxyInfo and self._proxyInfo.source:
|
||||
return self._proxyInfo.source
|
||||
assert self.transport
|
||||
return self.transport.getPeer()
|
||||
|
||||
def getHost(self) -> interfaces.IAddress:
|
||||
if self._proxyInfo and self._proxyInfo.destination:
|
||||
return self._proxyInfo.destination
|
||||
assert self.transport
|
||||
return self.transport.getHost()
|
||||
|
||||
|
||||
class HAProxyWrappingFactory(policies.WrappingFactory):
|
||||
"""
|
||||
A Factory wrapper that adds PROXY protocol support to connections.
|
||||
"""
|
||||
|
||||
protocol = HAProxyProtocolWrapper
|
||||
|
||||
def logPrefix(self) -> str:
|
||||
"""
|
||||
Annotate the wrapped factory's log prefix with some text indicating
|
||||
the PROXY protocol is in use.
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
if interfaces.ILoggingContext.providedBy(self.wrappedFactory):
|
||||
logPrefix = self.wrappedFactory.logPrefix()
|
||||
else:
|
||||
logPrefix = self.wrappedFactory.__class__.__name__
|
||||
return f"{logPrefix} (PROXY)"
|
||||
|
||||
|
||||
def proxyEndpoint(
|
||||
wrappedEndpoint: interfaces.IStreamServerEndpoint,
|
||||
) -> _WrapperServerEndpoint:
|
||||
"""
|
||||
Wrap an endpoint with PROXY protocol support, so that the transport's
|
||||
C{getHost} and C{getPeer} methods reflect the attributes of the proxied
|
||||
connection rather than the underlying connection.
|
||||
|
||||
@param wrappedEndpoint: The underlying listening endpoint.
|
||||
@type wrappedEndpoint: L{IStreamServerEndpoint}
|
||||
|
||||
@return: a new listening endpoint that speaks the PROXY protocol.
|
||||
@rtype: L{IStreamServerEndpoint}
|
||||
"""
|
||||
return _WrapperServerEndpoint(wrappedEndpoint, HAProxyWrappingFactory)
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Unit tests for L{twisted.protocols.haproxy}.
|
||||
"""
|
||||
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.protocols.haproxy._parser}.
|
||||
"""
|
||||
from typing import Type, Union
|
||||
|
||||
from twisted.internet.endpoints import (
|
||||
TCP4ServerEndpoint,
|
||||
TCP6ServerEndpoint,
|
||||
UNIXServerEndpoint,
|
||||
_parse as parseEndpoint,
|
||||
_WrapperServerEndpoint,
|
||||
serverFromString,
|
||||
)
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
from twisted.trial.unittest import SynchronousTestCase as TestCase
|
||||
from .._parser import unparseEndpoint
|
||||
from .._wrapper import HAProxyWrappingFactory
|
||||
|
||||
|
||||
class UnparseEndpointTests(TestCase):
|
||||
"""
|
||||
Tests to ensure that un-parsing an endpoint string round trips through
|
||||
escaping properly.
|
||||
"""
|
||||
|
||||
def check(self, input: str) -> None:
|
||||
"""
|
||||
Check that the input unparses into the output, raising an assertion
|
||||
error if it doesn't.
|
||||
|
||||
@param input: an input in endpoint-string-description format. (To
|
||||
ensure determinism, keyword arguments should be in alphabetical
|
||||
order.)
|
||||
@type input: native L{str}
|
||||
"""
|
||||
self.assertEqual(unparseEndpoint(*parseEndpoint(input)), input)
|
||||
|
||||
def test_basicUnparse(self) -> None:
|
||||
"""
|
||||
An individual word.
|
||||
"""
|
||||
self.check("word")
|
||||
|
||||
def test_multipleArguments(self) -> None:
|
||||
"""
|
||||
Multiple arguments.
|
||||
"""
|
||||
self.check("one:two")
|
||||
|
||||
def test_keywords(self) -> None:
|
||||
"""
|
||||
Keyword arguments.
|
||||
"""
|
||||
self.check("aleph=one:bet=two")
|
||||
|
||||
def test_colonInArgument(self) -> None:
|
||||
"""
|
||||
Escaped ":".
|
||||
"""
|
||||
self.check("hello\\:colon\\:world")
|
||||
|
||||
def test_colonInKeywordValue(self) -> None:
|
||||
"""
|
||||
Escaped ":" in keyword value.
|
||||
"""
|
||||
self.check("hello=\\:")
|
||||
|
||||
def test_colonInKeywordName(self) -> None:
|
||||
"""
|
||||
Escaped ":" in keyword name.
|
||||
"""
|
||||
self.check("\\:=hello")
|
||||
|
||||
|
||||
class HAProxyServerParserTests(TestCase):
|
||||
"""
|
||||
Tests that the parser generates the correct endpoints.
|
||||
"""
|
||||
|
||||
def onePrefix(
|
||||
self,
|
||||
description: str,
|
||||
expectedClass: Union[
|
||||
Type[TCP4ServerEndpoint],
|
||||
Type[TCP6ServerEndpoint],
|
||||
Type[UNIXServerEndpoint],
|
||||
],
|
||||
) -> _WrapperServerEndpoint:
|
||||
"""
|
||||
Test the C{haproxy} enpdoint prefix against one sub-endpoint type.
|
||||
|
||||
@param description: A string endpoint description beginning with
|
||||
C{haproxy}.
|
||||
@type description: native L{str}
|
||||
|
||||
@param expectedClass: the expected sub-endpoint class given the
|
||||
description.
|
||||
@type expectedClass: L{type}
|
||||
|
||||
@return: the parsed endpoint
|
||||
@rtype: L{IStreamServerEndpoint}
|
||||
|
||||
@raise twisted.trial.unittest.Failtest: if the parsed endpoint doesn't
|
||||
match expectations.
|
||||
"""
|
||||
reactor = MemoryReactor()
|
||||
endpoint = serverFromString(reactor, description)
|
||||
self.assertIsInstance(endpoint, _WrapperServerEndpoint)
|
||||
assert isinstance(endpoint, _WrapperServerEndpoint)
|
||||
self.assertIsInstance(endpoint._wrappedEndpoint, expectedClass)
|
||||
self.assertIs(endpoint._wrapperFactory, HAProxyWrappingFactory)
|
||||
return endpoint
|
||||
|
||||
def test_tcp4(self) -> None:
|
||||
"""
|
||||
Test if the parser generates a wrapped TCP4 endpoint.
|
||||
"""
|
||||
self.onePrefix("haproxy:tcp:8080", TCP4ServerEndpoint)
|
||||
|
||||
def test_tcp6(self) -> None:
|
||||
"""
|
||||
Test if the parser generates a wrapped TCP6 endpoint.
|
||||
"""
|
||||
self.onePrefix("haproxy:tcp6:8080", TCP6ServerEndpoint)
|
||||
|
||||
def test_unix(self) -> None:
|
||||
"""
|
||||
Test if the parser generates a wrapped UNIX endpoint.
|
||||
"""
|
||||
self.onePrefix("haproxy:unix:address=/tmp/socket", UNIXServerEndpoint)
|
||||
@@ -0,0 +1,149 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{twisted.protocols.haproxy.V1Parser}.
|
||||
"""
|
||||
|
||||
from twisted.internet import address
|
||||
from twisted.trial import unittest
|
||||
from .. import _v1parser
|
||||
from .._exceptions import InvalidNetworkProtocol, InvalidProxyHeader, MissingAddressData
|
||||
|
||||
|
||||
class V1ParserTests(unittest.TestCase):
|
||||
"""
|
||||
Test L{twisted.protocols.haproxy.V1Parser} behaviour.
|
||||
"""
|
||||
|
||||
def test_missingPROXYHeaderValue(self) -> None:
|
||||
"""
|
||||
Test that an exception is raised when the PROXY header is missing.
|
||||
"""
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v1parser.V1Parser.parse,
|
||||
b"NOTPROXY ",
|
||||
)
|
||||
|
||||
def test_invalidNetworkProtocol(self) -> None:
|
||||
"""
|
||||
Test that an exception is raised when the proto is not TCP or UNKNOWN.
|
||||
"""
|
||||
self.assertRaises(
|
||||
InvalidNetworkProtocol,
|
||||
_v1parser.V1Parser.parse,
|
||||
b"PROXY WUTPROTO ",
|
||||
)
|
||||
|
||||
def test_missingSourceData(self) -> None:
|
||||
"""
|
||||
Test that an exception is raised when the proto has no source data.
|
||||
"""
|
||||
self.assertRaises(
|
||||
MissingAddressData,
|
||||
_v1parser.V1Parser.parse,
|
||||
b"PROXY TCP4 ",
|
||||
)
|
||||
|
||||
def test_missingDestData(self) -> None:
|
||||
"""
|
||||
Test that an exception is raised when the proto has no destination.
|
||||
"""
|
||||
self.assertRaises(
|
||||
MissingAddressData,
|
||||
_v1parser.V1Parser.parse,
|
||||
b"PROXY TCP4 127.0.0.1 8080 8888",
|
||||
)
|
||||
|
||||
def test_fullParsingSuccess(self) -> None:
|
||||
"""
|
||||
Test that parsing is successful for a PROXY header.
|
||||
"""
|
||||
info = _v1parser.V1Parser.parse(
|
||||
b"PROXY TCP4 127.0.0.1 127.0.0.1 8080 8888",
|
||||
)
|
||||
self.assertIsInstance(info.source, address.IPv4Address)
|
||||
assert isinstance(info.source, address.IPv4Address)
|
||||
assert isinstance(info.destination, address.IPv4Address) # type: ignore[unreachable]
|
||||
self.assertEqual(info.source.host, "127.0.0.1")
|
||||
self.assertEqual(info.source.port, 8080)
|
||||
self.assertEqual(info.destination.host, "127.0.0.1")
|
||||
self.assertEqual(info.destination.port, 8888)
|
||||
|
||||
def test_fullParsingSuccess_IPv6(self) -> None:
|
||||
"""
|
||||
Test that parsing is successful for an IPv6 PROXY header.
|
||||
"""
|
||||
info = _v1parser.V1Parser.parse(
|
||||
b"PROXY TCP6 ::1 ::1 8080 8888",
|
||||
)
|
||||
self.assertIsInstance(info.source, address.IPv6Address)
|
||||
assert isinstance(info.source, address.IPv6Address)
|
||||
assert isinstance(info.destination, address.IPv6Address) # type: ignore[unreachable]
|
||||
self.assertEqual(info.source.host, "::1")
|
||||
self.assertEqual(info.source.port, 8080)
|
||||
self.assertEqual(info.destination.host, "::1")
|
||||
self.assertEqual(info.destination.port, 8888)
|
||||
|
||||
def test_fullParsingSuccess_UNKNOWN(self) -> None:
|
||||
"""
|
||||
Test that parsing is successful for a UNKNOWN PROXY header.
|
||||
"""
|
||||
info = _v1parser.V1Parser.parse(
|
||||
b"PROXY UNKNOWN anything could go here",
|
||||
)
|
||||
self.assertIsNone(info.source)
|
||||
self.assertIsNone(info.destination)
|
||||
|
||||
def test_feedParsing(self) -> None:
|
||||
"""
|
||||
Test that parsing happens when fed a complete line.
|
||||
"""
|
||||
parser = _v1parser.V1Parser()
|
||||
info, remaining = parser.feed(b"PROXY TCP4 127.0.0.1 127.0.0.1 ")
|
||||
self.assertFalse(info)
|
||||
self.assertFalse(remaining)
|
||||
info, remaining = parser.feed(b"8080 8888")
|
||||
self.assertFalse(info)
|
||||
self.assertFalse(remaining)
|
||||
info, remaining = parser.feed(b"\r\n")
|
||||
self.assertFalse(remaining)
|
||||
assert remaining is not None
|
||||
assert info is not None
|
||||
self.assertIsInstance(info.source, address.IPv4Address)
|
||||
assert isinstance(info.source, address.IPv4Address)
|
||||
assert isinstance(info.destination, address.IPv4Address) # type: ignore[unreachable]
|
||||
self.assertEqual(info.source.host, "127.0.0.1")
|
||||
self.assertEqual(info.source.port, 8080)
|
||||
self.assertEqual(info.destination.host, "127.0.0.1")
|
||||
self.assertEqual(info.destination.port, 8888)
|
||||
|
||||
def test_feedParsingTooLong(self) -> None:
|
||||
"""
|
||||
Test that parsing fails if no newline is found in 108 bytes.
|
||||
"""
|
||||
parser = _v1parser.V1Parser()
|
||||
info, remaining = parser.feed(b"PROXY TCP4 127.0.0.1 127.0.0.1 ")
|
||||
self.assertFalse(info)
|
||||
self.assertFalse(remaining)
|
||||
info, remaining = parser.feed(b"8080 8888")
|
||||
self.assertFalse(info)
|
||||
self.assertFalse(remaining)
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
parser.feed,
|
||||
b" " * 100,
|
||||
)
|
||||
|
||||
def test_feedParsingOverflow(self) -> None:
|
||||
"""
|
||||
Test that parsing leaves overflow bytes in the buffer.
|
||||
"""
|
||||
parser = _v1parser.V1Parser()
|
||||
info, remaining = parser.feed(
|
||||
b"PROXY TCP4 127.0.0.1 127.0.0.1 8080 8888\r\nHTTP/1.1 GET /\r\n",
|
||||
)
|
||||
self.assertTrue(info)
|
||||
self.assertEqual(remaining, b"HTTP/1.1 GET /\r\n")
|
||||
self.assertFalse(parser.buffer)
|
||||
@@ -0,0 +1,368 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{twisted.protocols.haproxy.V2Parser}.
|
||||
"""
|
||||
|
||||
from twisted.internet import address
|
||||
from twisted.trial import unittest
|
||||
from .. import _v2parser
|
||||
from .._exceptions import InvalidProxyHeader
|
||||
|
||||
V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
|
||||
|
||||
|
||||
def _makeHeaderIPv6(
|
||||
sig: bytes = V2_SIGNATURE,
|
||||
verCom: bytes = b"\x21",
|
||||
famProto: bytes = b"\x21",
|
||||
addrLength: bytes = b"\x00\x24",
|
||||
addrs: bytes = ((b"\x00" * 15) + b"\x01") * 2,
|
||||
ports: bytes = b"\x1F\x90\x22\xB8",
|
||||
) -> bytes:
|
||||
"""
|
||||
Construct a version 2 IPv6 header with custom bytes.
|
||||
|
||||
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
|
||||
@type sig: L{bytes}
|
||||
|
||||
@param verCom: Protocol version and command. Defaults to V2 PROXY.
|
||||
@type verCom: L{bytes}
|
||||
|
||||
@param famProto: Address family and protocol. Defaults to AF_INET6/STREAM.
|
||||
@type famProto: L{bytes}
|
||||
|
||||
@param addrLength: Network-endian byte length of payload. Defaults to
|
||||
description of default addrs/ports.
|
||||
@type addrLength: L{bytes}
|
||||
|
||||
@param addrs: Address payload. Defaults to C{::1} for source and
|
||||
destination.
|
||||
@type addrs: L{bytes}
|
||||
|
||||
@param ports: Source and destination ports. Defaults to 8080 for source
|
||||
8888 for destination.
|
||||
@type ports: L{bytes}
|
||||
|
||||
@return: A packet with header, addresses, and ports.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return sig + verCom + famProto + addrLength + addrs + ports
|
||||
|
||||
|
||||
def _makeHeaderIPv4(
|
||||
sig: bytes = V2_SIGNATURE,
|
||||
verCom: bytes = b"\x21",
|
||||
famProto: bytes = b"\x11",
|
||||
addrLength: bytes = b"\x00\x0C",
|
||||
addrs: bytes = b"\x7F\x00\x00\x01\x7F\x00\x00\x01",
|
||||
ports: bytes = b"\x1F\x90\x22\xB8",
|
||||
) -> bytes:
|
||||
"""
|
||||
Construct a version 2 IPv4 header with custom bytes.
|
||||
|
||||
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
|
||||
@type sig: L{bytes}
|
||||
|
||||
@param verCom: Protocol version and command. Defaults to V2 PROXY.
|
||||
@type verCom: L{bytes}
|
||||
|
||||
@param famProto: Address family and protocol. Defaults to AF_INET/STREAM.
|
||||
@type famProto: L{bytes}
|
||||
|
||||
@param addrLength: Network-endian byte length of payload. Defaults to
|
||||
description of default addrs/ports.
|
||||
@type addrLength: L{bytes}
|
||||
|
||||
@param addrs: Address payload. Defaults to 127.0.0.1 for source and
|
||||
destination.
|
||||
@type addrs: L{bytes}
|
||||
|
||||
@param ports: Source and destination ports. Defaults to 8080 for source
|
||||
8888 for destination.
|
||||
@type ports: L{bytes}
|
||||
|
||||
@return: A packet with header, addresses, and ports.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return sig + verCom + famProto + addrLength + addrs + ports
|
||||
|
||||
|
||||
def _makeHeaderUnix(
|
||||
sig: bytes = V2_SIGNATURE,
|
||||
verCom: bytes = b"\x21",
|
||||
famProto: bytes = b"\x31",
|
||||
addrLength: bytes = b"\x00\xD8",
|
||||
addrs: bytes = (
|
||||
b"\x2F\x68\x6F\x6D\x65\x2F\x74\x65\x73\x74\x73\x2F"
|
||||
b"\x6D\x79\x73\x6F\x63\x6B\x65\x74\x73\x2F\x73\x6F"
|
||||
b"\x63\x6B" + (b"\x00" * 82)
|
||||
)
|
||||
* 2,
|
||||
) -> bytes:
|
||||
"""
|
||||
Construct a version 2 IPv4 header with custom bytes.
|
||||
|
||||
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
|
||||
@type sig: L{bytes}
|
||||
|
||||
@param verCom: Protocol version and command. Defaults to V2 PROXY.
|
||||
@type verCom: L{bytes}
|
||||
|
||||
@param famProto: Address family and protocol. Defaults to AF_UNIX/STREAM.
|
||||
@type famProto: L{bytes}
|
||||
|
||||
@param addrLength: Network-endian byte length of payload. Defaults to 108
|
||||
bytes for 2 null terminated paths.
|
||||
@type addrLength: L{bytes}
|
||||
|
||||
@param addrs: Address payload. Defaults to C{/home/tests/mysockets/sock}
|
||||
for source and destination paths.
|
||||
@type addrs: L{bytes}
|
||||
|
||||
@return: A packet with header, addresses, and8 ports.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return sig + verCom + famProto + addrLength + addrs
|
||||
|
||||
|
||||
class V2ParserTests(unittest.TestCase):
|
||||
"""
|
||||
Test L{twisted.protocols.haproxy.V2Parser} behaviour.
|
||||
"""
|
||||
|
||||
def test_happyPathIPv4(self) -> None:
|
||||
"""
|
||||
Test if a well formed IPv4 header is parsed without error.
|
||||
"""
|
||||
header = _makeHeaderIPv4()
|
||||
self.assertTrue(_v2parser.V2Parser.parse(header))
|
||||
|
||||
def test_happyPathIPv6(self) -> None:
|
||||
"""
|
||||
Test if a well formed IPv6 header is parsed without error.
|
||||
"""
|
||||
header = _makeHeaderIPv6()
|
||||
self.assertTrue(_v2parser.V2Parser.parse(header))
|
||||
|
||||
def test_happyPathUnix(self) -> None:
|
||||
"""
|
||||
Test if a well formed UNIX header is parsed without error.
|
||||
"""
|
||||
header = _makeHeaderUnix()
|
||||
self.assertTrue(_v2parser.V2Parser.parse(header))
|
||||
|
||||
def test_invalidSignature(self) -> None:
|
||||
"""
|
||||
Test if an invalid signature block raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(sig=b"\x00" * 12)
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
def test_invalidVersion(self) -> None:
|
||||
"""
|
||||
Test if an invalid version raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(verCom=b"\x11")
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
def test_invalidCommand(self) -> None:
|
||||
"""
|
||||
Test if an invalid command raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(verCom=b"\x23")
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
def test_invalidFamily(self) -> None:
|
||||
"""
|
||||
Test if an invalid family raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(famProto=b"\x40")
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
def test_invalidProto(self) -> None:
|
||||
"""
|
||||
Test if an invalid protocol raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(famProto=b"\x24")
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
def test_localCommandIpv4(self) -> None:
|
||||
"""
|
||||
Test that local does not return endpoint data for IPv4 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv4(verCom=b"\x20")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_localCommandIpv6(self) -> None:
|
||||
"""
|
||||
Test that local does not return endpoint data for IPv6 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv6(verCom=b"\x20")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_localCommandUnix(self) -> None:
|
||||
"""
|
||||
Test that local does not return endpoint data for UNIX connections.
|
||||
"""
|
||||
header = _makeHeaderUnix(verCom=b"\x20")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_proxyCommandIpv4(self) -> None:
|
||||
"""
|
||||
Test that proxy returns endpoint data for IPv4 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv4(verCom=b"\x21")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertTrue(info.source)
|
||||
self.assertIsInstance(info.source, address.IPv4Address)
|
||||
self.assertTrue(info.destination)
|
||||
self.assertIsInstance(info.destination, address.IPv4Address)
|
||||
|
||||
def test_proxyCommandIpv6(self) -> None:
|
||||
"""
|
||||
Test that proxy returns endpoint data for IPv6 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv6(verCom=b"\x21")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertTrue(info.source)
|
||||
self.assertIsInstance(info.source, address.IPv6Address)
|
||||
self.assertTrue(info.destination)
|
||||
self.assertIsInstance(info.destination, address.IPv6Address)
|
||||
|
||||
def test_proxyCommandUnix(self) -> None:
|
||||
"""
|
||||
Test that proxy returns endpoint data for UNIX connections.
|
||||
"""
|
||||
header = _makeHeaderUnix(verCom=b"\x21")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertTrue(info.source)
|
||||
self.assertIsInstance(info.source, address.UNIXAddress)
|
||||
self.assertTrue(info.destination)
|
||||
self.assertIsInstance(info.destination, address.UNIXAddress)
|
||||
|
||||
def test_unspecFamilyIpv4(self) -> None:
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for IPv4 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv4(famProto=b"\x01")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_unspecFamilyIpv6(self) -> None:
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for IPv6 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv6(famProto=b"\x01")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_unspecFamilyUnix(self) -> None:
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for UNIX connections.
|
||||
"""
|
||||
header = _makeHeaderUnix(famProto=b"\x01")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_unspecProtoIpv4(self) -> None:
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for IPv4 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv4(famProto=b"\x10")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_unspecProtoIpv6(self) -> None:
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for IPv6 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv6(famProto=b"\x20")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_unspecProtoUnix(self) -> None:
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for UNIX connections.
|
||||
"""
|
||||
header = _makeHeaderUnix(famProto=b"\x30")
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
def test_overflowIpv4(self) -> None:
|
||||
"""
|
||||
Test that overflow bits are preserved during feed parsing for IPv4.
|
||||
"""
|
||||
testValue = b"TEST DATA\r\n\r\nTEST DATA"
|
||||
header = _makeHeaderIPv4() + testValue
|
||||
parser = _v2parser.V2Parser()
|
||||
info, overflow = parser.feed(header)
|
||||
self.assertTrue(info)
|
||||
self.assertEqual(overflow, testValue)
|
||||
|
||||
def test_overflowIpv6(self) -> None:
|
||||
"""
|
||||
Test that overflow bits are preserved during feed parsing for IPv6.
|
||||
"""
|
||||
testValue = b"TEST DATA\r\n\r\nTEST DATA"
|
||||
header = _makeHeaderIPv6() + testValue
|
||||
parser = _v2parser.V2Parser()
|
||||
info, overflow = parser.feed(header)
|
||||
self.assertTrue(info)
|
||||
self.assertEqual(overflow, testValue)
|
||||
|
||||
def test_overflowUnix(self) -> None:
|
||||
"""
|
||||
Test that overflow bits are preserved during feed parsing for Unix.
|
||||
"""
|
||||
testValue = b"TEST DATA\r\n\r\nTEST DATA"
|
||||
header = _makeHeaderUnix() + testValue
|
||||
parser = _v2parser.V2Parser()
|
||||
info, overflow = parser.feed(header)
|
||||
self.assertTrue(info)
|
||||
self.assertEqual(overflow, testValue)
|
||||
|
||||
def test_segmentTooSmall(self) -> None:
|
||||
"""
|
||||
Test that an initial payload of less than 16 bytes fails.
|
||||
"""
|
||||
testValue = b"NEEDMOREDATA"
|
||||
parser = _v2parser.V2Parser()
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
parser.feed,
|
||||
testValue,
|
||||
)
|
||||
@@ -0,0 +1,375 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{twisted.protocols.haproxy.HAProxyProtocol}.
|
||||
"""
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
from twisted.internet import address
|
||||
from twisted.internet.protocol import Factory, Protocol
|
||||
from twisted.internet.testing import StringTransportWithDisconnection
|
||||
from twisted.trial import unittest
|
||||
from .._wrapper import HAProxyWrappingFactory
|
||||
|
||||
|
||||
class StaticProtocol(Protocol):
|
||||
"""
|
||||
Protocol stand-in that maintains test state.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.source: Optional[address.IAddress] = None
|
||||
self.destination: Optional[address.IAddress] = None
|
||||
self.data = b""
|
||||
self.disconnected = False
|
||||
|
||||
def dataReceived(self, data: bytes) -> None:
|
||||
assert self.transport
|
||||
self.source = self.transport.getPeer()
|
||||
self.destination = self.transport.getHost()
|
||||
self.data += data
|
||||
|
||||
|
||||
class HAProxyWrappingFactoryV1Tests(unittest.TestCase):
|
||||
"""
|
||||
Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v1 PROXY
|
||||
headers.
|
||||
"""
|
||||
|
||||
def test_invalidHeaderDisconnects(self) -> None:
|
||||
"""
|
||||
Test if invalid headers result in connectionLost events.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address("TCP", "127.1.1.1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
transport.protocol = proto
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b"NOTPROXY anything can go here\r\n")
|
||||
self.assertFalse(transport.connected)
|
||||
|
||||
def test_invalidPartialHeaderDisconnects(self) -> None:
|
||||
"""
|
||||
Test if invalid headers result in connectionLost events.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address("TCP", "127.1.1.1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
transport.protocol = proto
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b"PROXY TCP4 1.1.1.1\r\n")
|
||||
proto.dataReceived(b"2.2.2.2 8080\r\n")
|
||||
self.assertFalse(transport.connected)
|
||||
|
||||
def test_preDataReceived_getPeerHost(self) -> None:
|
||||
"""
|
||||
Before any data is received the HAProxy protocol will return the same peer
|
||||
and host as the IP connection.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address("TCP", "127.0.0.1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection(
|
||||
hostAddress=mock.sentinel.host_address,
|
||||
peerAddress=mock.sentinel.peer_address,
|
||||
)
|
||||
proto.makeConnection(transport)
|
||||
self.assertEqual(proto.getHost(), mock.sentinel.host_address)
|
||||
self.assertEqual(proto.getPeer(), mock.sentinel.peer_address)
|
||||
|
||||
def test_validIPv4HeaderResolves_getPeerHost(self) -> None:
|
||||
"""
|
||||
Test if IPv4 headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address("TCP", "127.0.0.1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b"PROXY TCP4 1.1.1.1 2.2.2.2 8080 8888\r\n")
|
||||
self.assertEqual(proto.getPeer().host, "1.1.1.1")
|
||||
self.assertEqual(proto.getPeer().port, 8080)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().host,
|
||||
"1.1.1.1",
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().port,
|
||||
8080,
|
||||
)
|
||||
self.assertEqual(proto.getHost().host, "2.2.2.2")
|
||||
self.assertEqual(proto.getHost().port, 8888)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().host,
|
||||
"2.2.2.2",
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().port,
|
||||
8888,
|
||||
)
|
||||
|
||||
def test_validIPv6HeaderResolves_getPeerHost(self) -> None:
|
||||
"""
|
||||
Test if IPv6 headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address("TCP", "::1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b"PROXY TCP6 ::1 ::2 8080 8888\r\n")
|
||||
self.assertEqual(proto.getPeer().host, "::1")
|
||||
self.assertEqual(proto.getPeer().port, 8080)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().host,
|
||||
"::1",
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().port,
|
||||
8080,
|
||||
)
|
||||
self.assertEqual(proto.getHost().host, "::2")
|
||||
self.assertEqual(proto.getHost().port, 8888)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().host,
|
||||
"::2",
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().port,
|
||||
8888,
|
||||
)
|
||||
|
||||
def test_overflowBytesSentToWrappedProtocol(self) -> None:
|
||||
"""
|
||||
Test if non-header bytes are passed to the wrapped protocol.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address("TCP", "::1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b"PROXY TCP6 ::1 ::2 8080 8888\r\nHTTP/1.1 / GET")
|
||||
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")
|
||||
|
||||
def test_[AWS-SECRET-REMOVED](self) -> None:
|
||||
"""
|
||||
Test if header streaming passes extra data appropriately.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address("TCP", "::1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b"PROXY TCP6 ::1 ::2 ")
|
||||
proto.dataReceived(b"8080 8888\r\nHTTP/1.1 / GET")
|
||||
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")
|
||||
|
||||
def test_overflowBytesSentToWrappedProtocolAfter(self) -> None:
|
||||
"""
|
||||
Test if wrapper writes all data to wrapped protocol after parsing.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address("TCP", "::1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b"PROXY TCP6 ::1 ::2 ")
|
||||
proto.dataReceived(b"8080 8888\r\nHTTP/1.1 / GET")
|
||||
proto.dataReceived(b"\r\n\r\n")
|
||||
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET\r\n\r\n")
|
||||
|
||||
|
||||
class HAProxyWrappingFactoryV2Tests(unittest.TestCase):
|
||||
"""
|
||||
Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v2 PROXY
|
||||
headers.
|
||||
"""
|
||||
|
||||
IPV4HEADER = (
|
||||
# V2 Signature
|
||||
b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
|
||||
# V2 PROXY command
|
||||
b"\x21"
|
||||
# AF_INET/STREAM
|
||||
b"\x11"
|
||||
# 12 bytes for 2 IPv4 addresses and two ports
|
||||
b"\x00\x0C"
|
||||
# 127.0.0.1 for source and destination
|
||||
b"\x7F\x00\x00\x01\x7F\x00\x00\x01"
|
||||
# 8080 for source 8888 for destination
|
||||
b"\x1F\x90\x22\xB8"
|
||||
)
|
||||
IPV6HEADER = (
|
||||
# V2 Signature
|
||||
b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
|
||||
# V2 PROXY command
|
||||
b"\x21"
|
||||
# AF_INET6/STREAM
|
||||
b"\x21"
|
||||
# 16 bytes for 2 IPv6 addresses and two ports
|
||||
b"\x00\x24"
|
||||
# ::1 for source and destination
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
# 8080 for source 8888 for destination
|
||||
b"\x1F\x90\x22\xB8"
|
||||
)
|
||||
|
||||
_SOCK_PATH = (
|
||||
b"\x2F\x68\x6F\x6D\x65\x2F\x74\x65\x73\x74\x73\x2F\x6D\x79\x73\x6F"
|
||||
b"\x63\x6B\x65\x74\x73\x2F\x73\x6F\x63\x6B" + (b"\x00" * 82)
|
||||
)
|
||||
UNIXHEADER = (
|
||||
(
|
||||
# V2 Signature
|
||||
b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
|
||||
# V2 PROXY command
|
||||
b"\x21"
|
||||
# AF_UNIX/STREAM
|
||||
b"\x31"
|
||||
# 108 bytes for 2 null terminated paths
|
||||
b"\x00\xD8"
|
||||
# /home/tests/mysockets/sock for source and destination paths
|
||||
)
|
||||
+ _SOCK_PATH
|
||||
+ _SOCK_PATH
|
||||
)
|
||||
|
||||
def test_invalidHeaderDisconnects(self) -> None:
|
||||
"""
|
||||
Test if invalid headers result in connectionLost events.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address("TCP", "::1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
transport.protocol = proto
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b"\x00" + self.IPV4HEADER[1:])
|
||||
self.assertFalse(transport.connected)
|
||||
|
||||
def test_validIPv4HeaderResolves_getPeerHost(self) -> None:
|
||||
"""
|
||||
Test if IPv4 headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address("TCP", "127.0.0.1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.IPV4HEADER)
|
||||
self.assertEqual(proto.getPeer().host, "127.0.0.1")
|
||||
self.assertEqual(proto.getPeer().port, 8080)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().host,
|
||||
"127.0.0.1",
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().port,
|
||||
8080,
|
||||
)
|
||||
self.assertEqual(proto.getHost().host, "127.0.0.1")
|
||||
self.assertEqual(proto.getHost().port, 8888)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().host,
|
||||
"127.0.0.1",
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().port,
|
||||
8888,
|
||||
)
|
||||
|
||||
def test_validIPv6HeaderResolves_getPeerHost(self) -> None:
|
||||
"""
|
||||
Test if IPv6 headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address("TCP", "::1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.IPV6HEADER)
|
||||
self.assertEqual(proto.getPeer().host, "0:0:0:0:0:0:0:1")
|
||||
self.assertEqual(proto.getPeer().port, 8080)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().host,
|
||||
"0:0:0:0:0:0:0:1",
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().port,
|
||||
8080,
|
||||
)
|
||||
self.assertEqual(proto.getHost().host, "0:0:0:0:0:0:0:1")
|
||||
self.assertEqual(proto.getHost().port, 8888)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().host,
|
||||
"0:0:0:0:0:0:0:1",
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().port,
|
||||
8888,
|
||||
)
|
||||
|
||||
def test_validUNIXHeaderResolves_getPeerHost(self) -> None:
|
||||
"""
|
||||
Test if UNIX headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.UNIXAddress(b"/home/test/sockets/server.sock"),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.UNIXHEADER)
|
||||
self.assertEqual(proto.getPeer().name, b"/home/tests/mysockets/sock")
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().name,
|
||||
b"/home/tests/mysockets/sock",
|
||||
)
|
||||
self.assertEqual(proto.getHost().name, b"/home/tests/mysockets/sock")
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().name,
|
||||
b"/home/tests/mysockets/sock",
|
||||
)
|
||||
|
||||
def test_overflowBytesSentToWrappedProtocol(self) -> None:
|
||||
"""
|
||||
Test if non-header bytes are passed to the wrapped protocol.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address("TCP", "::1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.IPV6HEADER + b"HTTP/1.1 / GET")
|
||||
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")
|
||||
|
||||
def test_[AWS-SECRET-REMOVED](self) -> None:
|
||||
"""
|
||||
Test if header streaming passes extra data appropriately.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address("TCP", "::1", 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.IPV6HEADER[:18])
|
||||
proto.dataReceived(self.IPV6HEADER[18:] + b"HTTP/1.1 / GET")
|
||||
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")
|
||||
306
.venv/lib/python3.12/site-packages/twisted/protocols/htb.py
Normal file
306
.venv/lib/python3.12/site-packages/twisted/protocols/htb.py
Normal file
@@ -0,0 +1,306 @@
|
||||
# -*- test-case-name: twisted.test.test_htb -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Hierarchical Token Bucket traffic shaping.
|
||||
|
||||
Patterned after U{Martin Devera's Hierarchical Token Bucket traffic
|
||||
shaper for the Linux kernel<http://luxik.cdi.cz/~devik/qos/htb/>}.
|
||||
|
||||
@seealso: U{HTB Linux queuing discipline manual - user guide
|
||||
<http://luxik.cdi.cz/~devik/qos/htb/manual/userg.htm>}
|
||||
@seealso: U{Token Bucket Filter in Linux Advanced Routing & Traffic Control
|
||||
HOWTO<http://lartc.org/howto/lartc.qdisc.classless.html#AEN682>}
|
||||
"""
|
||||
|
||||
|
||||
# TODO: Investigate whether we should be using os.times()[-1] instead of
|
||||
# time.time. time.time, it has been pointed out, can go backwards. Is
|
||||
# the same true of os.times?
|
||||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
from zope.interface import Interface, implementer
|
||||
|
||||
from twisted.protocols import pcp
|
||||
|
||||
|
||||
class Bucket:
|
||||
"""
|
||||
Implementation of a Token bucket.
|
||||
|
||||
A bucket can hold a certain number of tokens and it drains over time.
|
||||
|
||||
@cvar maxburst: The maximum number of tokens that the bucket can
|
||||
hold at any given time. If this is L{None}, the bucket has
|
||||
an infinite size.
|
||||
@type maxburst: C{int}
|
||||
@cvar rate: The rate at which the bucket drains, in number
|
||||
of tokens per second. If the rate is L{None}, the bucket
|
||||
drains instantaneously.
|
||||
@type rate: C{int}
|
||||
"""
|
||||
|
||||
maxburst: Optional[int] = None
|
||||
rate: Optional[int] = None
|
||||
|
||||
_refcount = 0
|
||||
|
||||
def __init__(self, parentBucket=None):
|
||||
"""
|
||||
Create a L{Bucket} that may have a parent L{Bucket}.
|
||||
|
||||
@param parentBucket: If a parent Bucket is specified,
|
||||
all L{add} and L{drip} operations on this L{Bucket}
|
||||
will be applied on the parent L{Bucket} as well.
|
||||
@type parentBucket: L{Bucket}
|
||||
"""
|
||||
self.content = 0
|
||||
self.parentBucket = parentBucket
|
||||
self.lastDrip = time()
|
||||
|
||||
def add(self, amount):
|
||||
"""
|
||||
Adds tokens to the L{Bucket} and its C{parentBucket}.
|
||||
|
||||
This will add as many of the C{amount} tokens as will fit into both
|
||||
this L{Bucket} and its C{parentBucket}.
|
||||
|
||||
@param amount: The number of tokens to try to add.
|
||||
@type amount: C{int}
|
||||
|
||||
@returns: The number of tokens that actually fit.
|
||||
@returntype: C{int}
|
||||
"""
|
||||
self.drip()
|
||||
if self.maxburst is None:
|
||||
allowable = amount
|
||||
else:
|
||||
allowable = min(amount, self.maxburst - self.content)
|
||||
|
||||
if self.parentBucket is not None:
|
||||
allowable = self.parentBucket.add(allowable)
|
||||
self.content += allowable
|
||||
return allowable
|
||||
|
||||
def drip(self):
|
||||
"""
|
||||
Let some of the bucket drain.
|
||||
|
||||
The L{Bucket} drains at the rate specified by the class
|
||||
variable C{rate}.
|
||||
|
||||
@returns: C{True} if the bucket is empty after this drip.
|
||||
@returntype: C{bool}
|
||||
"""
|
||||
if self.parentBucket is not None:
|
||||
self.parentBucket.drip()
|
||||
|
||||
if self.rate is None:
|
||||
self.content = 0
|
||||
else:
|
||||
now = time()
|
||||
deltaTime = now - self.lastDrip
|
||||
deltaTokens = deltaTime * self.rate
|
||||
self.content = max(0, self.content - deltaTokens)
|
||||
self.lastDrip = now
|
||||
return self.content == 0
|
||||
|
||||
|
||||
class IBucketFilter(Interface):
|
||||
def getBucketFor(*somethings, **some_kw):
|
||||
"""
|
||||
Return a L{Bucket} corresponding to the provided parameters.
|
||||
|
||||
@returntype: L{Bucket}
|
||||
"""
|
||||
|
||||
|
||||
@implementer(IBucketFilter)
|
||||
class HierarchicalBucketFilter:
|
||||
"""
|
||||
Filter things into buckets that can be nested.
|
||||
|
||||
@cvar bucketFactory: Class of buckets to make.
|
||||
@type bucketFactory: L{Bucket}
|
||||
@cvar sweepInterval: Seconds between sweeping out the bucket cache.
|
||||
@type sweepInterval: C{int}
|
||||
"""
|
||||
|
||||
bucketFactory = Bucket
|
||||
sweepInterval: Optional[int] = None
|
||||
|
||||
def __init__(self, parentFilter=None):
|
||||
self.buckets = {}
|
||||
self.parentFilter = parentFilter
|
||||
self.lastSweep = time()
|
||||
|
||||
def getBucketFor(self, *a, **kw):
|
||||
"""
|
||||
Find or create a L{Bucket} corresponding to the provided parameters.
|
||||
|
||||
Any parameters are passed on to L{getBucketKey}, from them it
|
||||
decides which bucket you get.
|
||||
|
||||
@returntype: L{Bucket}
|
||||
"""
|
||||
if (self.sweepInterval is not None) and (
|
||||
(time() - self.lastSweep) > self.sweepInterval
|
||||
):
|
||||
self.sweep()
|
||||
|
||||
if self.parentFilter:
|
||||
parentBucket = self.parentFilter.getBucketFor(self, *a, **kw)
|
||||
else:
|
||||
parentBucket = None
|
||||
|
||||
key = self.getBucketKey(*a, **kw)
|
||||
bucket = self.buckets.get(key)
|
||||
if bucket is None:
|
||||
bucket = self.bucketFactory(parentBucket)
|
||||
self.buckets[key] = bucket
|
||||
return bucket
|
||||
|
||||
def getBucketKey(self, *a, **kw):
|
||||
"""
|
||||
Construct a key based on the input parameters to choose a L{Bucket}.
|
||||
|
||||
The default implementation returns the same key for all
|
||||
arguments. Override this method to provide L{Bucket} selection.
|
||||
|
||||
@returns: Something to be used as a key in the bucket cache.
|
||||
"""
|
||||
return None
|
||||
|
||||
def sweep(self):
|
||||
"""
|
||||
Remove empty buckets.
|
||||
"""
|
||||
for key, bucket in self.buckets.items():
|
||||
bucket_is_empty = bucket.drip()
|
||||
if (bucket._refcount == 0) and bucket_is_empty:
|
||||
del self.buckets[key]
|
||||
|
||||
self.lastSweep = time()
|
||||
|
||||
|
||||
class FilterByHost(HierarchicalBucketFilter):
|
||||
"""
|
||||
A Hierarchical Bucket filter with a L{Bucket} for each host.
|
||||
"""
|
||||
|
||||
sweepInterval = 60 * 20
|
||||
|
||||
def getBucketKey(self, transport):
|
||||
return transport.getPeer()[1]
|
||||
|
||||
|
||||
class FilterByServer(HierarchicalBucketFilter):
|
||||
"""
|
||||
A Hierarchical Bucket filter with a L{Bucket} for each service.
|
||||
"""
|
||||
|
||||
sweepInterval = None
|
||||
|
||||
def getBucketKey(self, transport):
|
||||
return transport.getHost()[2]
|
||||
|
||||
|
||||
class ShapedConsumer(pcp.ProducerConsumerProxy):
|
||||
"""
|
||||
Wraps a C{Consumer} and shapes the rate at which it receives data.
|
||||
"""
|
||||
|
||||
# Providing a Pull interface means I don't have to try to schedule
|
||||
# traffic with callLaters.
|
||||
iAmStreaming = False
|
||||
|
||||
def __init__(self, consumer, bucket):
|
||||
pcp.ProducerConsumerProxy.__init__(self, consumer)
|
||||
self.bucket = bucket
|
||||
self.bucket._refcount += 1
|
||||
|
||||
def _writeSomeData(self, data):
|
||||
# In practice, this actually results in obscene amounts of
|
||||
# overhead, as a result of generating lots and lots of packets
|
||||
# with twelve-byte payloads. We may need to do a version of
|
||||
# this with scheduled writes after all.
|
||||
amount = self.bucket.add(len(data))
|
||||
return pcp.ProducerConsumerProxy._writeSomeData(self, data[:amount])
|
||||
|
||||
def stopProducing(self):
|
||||
pcp.ProducerConsumerProxy.stopProducing(self)
|
||||
self.bucket._refcount -= 1
|
||||
|
||||
|
||||
class ShapedTransport(ShapedConsumer):
|
||||
"""
|
||||
Wraps a C{Transport} and shapes the rate at which it receives data.
|
||||
|
||||
This is a L{ShapedConsumer} with a little bit of magic to provide for
|
||||
the case where the consumer it wraps is also a C{Transport} and people
|
||||
will be attempting to access attributes this does not proxy as a
|
||||
C{Consumer} (e.g. C{loseConnection}).
|
||||
"""
|
||||
|
||||
# Ugh. We only wanted to filter IConsumer, not ITransport.
|
||||
|
||||
iAmStreaming = False
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Because people will be doing things like .getPeer and
|
||||
# .loseConnection on me.
|
||||
return getattr(self.consumer, name)
|
||||
|
||||
|
||||
class ShapedProtocolFactory:
|
||||
"""
|
||||
Dispense C{Protocols} with traffic shaping on their transports.
|
||||
|
||||
Usage::
|
||||
|
||||
myserver = SomeFactory()
|
||||
myserver.protocol = ShapedProtocolFactory(myserver.protocol,
|
||||
bucketFilter)
|
||||
|
||||
Where C{SomeServerFactory} is a L{twisted.internet.protocol.Factory}, and
|
||||
C{bucketFilter} is an instance of L{HierarchicalBucketFilter}.
|
||||
"""
|
||||
|
||||
def __init__(self, protoClass, bucketFilter):
|
||||
"""
|
||||
Tell me what to wrap and where to get buckets.
|
||||
|
||||
@param protoClass: The class of C{Protocol} this will generate
|
||||
wrapped instances of.
|
||||
@type protoClass: L{Protocol<twisted.internet.interfaces.IProtocol>}
|
||||
class
|
||||
@param bucketFilter: The filter which will determine how
|
||||
traffic is shaped.
|
||||
@type bucketFilter: L{HierarchicalBucketFilter}.
|
||||
"""
|
||||
# More precisely, protoClass can be any callable that will return
|
||||
# instances of something that implements IProtocol.
|
||||
self.protocol = protoClass
|
||||
self.bucketFilter = bucketFilter
|
||||
|
||||
def __call__(self, *a, **kw):
|
||||
"""
|
||||
Make a C{Protocol} instance with a shaped transport.
|
||||
|
||||
Any parameters will be passed on to the protocol's initializer.
|
||||
|
||||
@returns: A C{Protocol} instance with a L{ShapedTransport}.
|
||||
"""
|
||||
proto = self.protocol(*a, **kw)
|
||||
origMakeConnection = proto.makeConnection
|
||||
|
||||
def makeConnection(transport):
|
||||
bucket = self.bucketFilter.getBucketFor(transport)
|
||||
shapedTransport = ShapedTransport(transport, bucket)
|
||||
return origMakeConnection(shapedTransport)
|
||||
|
||||
proto.makeConnection = makeConnection
|
||||
return proto
|
||||
253
.venv/lib/python3.12/site-packages/twisted/protocols/ident.py
Normal file
253
.venv/lib/python3.12/site-packages/twisted/protocols/ident.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# -*- test-case-name: twisted.test.test_ident -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Ident protocol implementation.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols import basic
|
||||
from twisted.python import failure, log
|
||||
|
||||
_MIN_PORT = 1
|
||||
_MAX_PORT = 2**16 - 1
|
||||
|
||||
|
||||
class IdentError(Exception):
|
||||
"""
|
||||
Can't determine connection owner; reason unknown.
|
||||
"""
|
||||
|
||||
identDescription = "UNKNOWN-ERROR"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.identDescription
|
||||
|
||||
|
||||
class NoUser(IdentError):
|
||||
"""
|
||||
The connection specified by the port pair is not currently in use or
|
||||
currently not owned by an identifiable entity.
|
||||
"""
|
||||
|
||||
identDescription = "NO-USER"
|
||||
|
||||
|
||||
class InvalidPort(IdentError):
|
||||
"""
|
||||
Either the local or foreign port was improperly specified. This should
|
||||
be returned if either or both of the port ids were out of range (TCP
|
||||
port numbers are from 1-65535), negative integers, reals or in any
|
||||
fashion not recognized as a non-negative integer.
|
||||
"""
|
||||
|
||||
identDescription = "INVALID-PORT"
|
||||
|
||||
|
||||
class HiddenUser(IdentError):
|
||||
"""
|
||||
The server was able to identify the user of this port, but the
|
||||
information was not returned at the request of the user.
|
||||
"""
|
||||
|
||||
identDescription = "HIDDEN-USER"
|
||||
|
||||
|
||||
class IdentServer(basic.LineOnlyReceiver):
|
||||
"""
|
||||
The Identification Protocol (a.k.a., "ident", a.k.a., "the Ident
|
||||
Protocol") provides a means to determine the identity of a user of a
|
||||
particular TCP connection. Given a TCP port number pair, it returns a
|
||||
character string which identifies the owner of that connection on the
|
||||
server's system.
|
||||
|
||||
Server authors should subclass this class and override the lookup method.
|
||||
The default implementation returns an UNKNOWN-ERROR response for every
|
||||
query.
|
||||
"""
|
||||
|
||||
def lineReceived(self, line):
|
||||
parts = line.split(",")
|
||||
if len(parts) != 2:
|
||||
self.invalidQuery()
|
||||
else:
|
||||
try:
|
||||
portOnServer, portOnClient = map(int, parts)
|
||||
except ValueError:
|
||||
self.invalidQuery()
|
||||
else:
|
||||
if (
|
||||
_MIN_PORT <= portOnServer <= _MAX_PORT
|
||||
and _MIN_PORT <= portOnClient <= _MAX_PORT
|
||||
):
|
||||
self.validQuery(portOnServer, portOnClient)
|
||||
else:
|
||||
self._ebLookup(
|
||||
failure.Failure(InvalidPort()), portOnServer, portOnClient
|
||||
)
|
||||
|
||||
def invalidQuery(self):
|
||||
self.transport.loseConnection()
|
||||
|
||||
def validQuery(self, portOnServer, portOnClient):
|
||||
"""
|
||||
Called when a valid query is received to look up and deliver the
|
||||
response.
|
||||
|
||||
@param portOnServer: The server port from the query.
|
||||
@param portOnClient: The client port from the query.
|
||||
"""
|
||||
serverAddr = self.transport.getHost().host, portOnServer
|
||||
clientAddr = self.transport.getPeer().host, portOnClient
|
||||
defer.maybeDeferred(self.lookup, serverAddr, clientAddr).addCallback(
|
||||
self._cbLookup, portOnServer, portOnClient
|
||||
).addErrback(self._ebLookup, portOnServer, portOnClient)
|
||||
|
||||
def _cbLookup(self, result, sport, cport):
|
||||
(sysName, userId) = result
|
||||
self.sendLine("%d, %d : USERID : %s : %s" % (sport, cport, sysName, userId))
|
||||
|
||||
def _ebLookup(self, failure, sport, cport):
|
||||
if failure.check(IdentError):
|
||||
self.sendLine("%d, %d : ERROR : %s" % (sport, cport, failure.value))
|
||||
else:
|
||||
log.err(failure)
|
||||
self.sendLine(
|
||||
"%d, %d : ERROR : %s" % (sport, cport, IdentError(failure.value))
|
||||
)
|
||||
|
||||
def lookup(self, serverAddress, clientAddress):
|
||||
"""
|
||||
Lookup user information about the specified address pair.
|
||||
|
||||
Return value should be a two-tuple of system name and username.
|
||||
Acceptable values for the system name may be found online at::
|
||||
|
||||
U{http://www.iana.org/assignments/operating-system-names}
|
||||
|
||||
This method may also raise any IdentError subclass (or IdentError
|
||||
itself) to indicate user information will not be provided for the
|
||||
given query.
|
||||
|
||||
A Deferred may also be returned.
|
||||
|
||||
@param serverAddress: A two-tuple representing the server endpoint
|
||||
of the address being queried. The first element is a string holding
|
||||
a dotted-quad IP address. The second element is an integer
|
||||
representing the port.
|
||||
|
||||
@param clientAddress: Like I{serverAddress}, but represents the
|
||||
client endpoint of the address being queried.
|
||||
"""
|
||||
raise IdentError()
|
||||
|
||||
|
||||
class ProcServerMixin:
|
||||
"""Implements lookup() to grab entries for responses from /proc/net/tcp"""
|
||||
|
||||
SYSTEM_NAME = "LINUX"
|
||||
|
||||
try:
|
||||
from pwd import getpwuid # type:ignore[misc]
|
||||
|
||||
def getUsername(self, uid, getpwuid=getpwuid):
|
||||
return getpwuid(uid)[0]
|
||||
|
||||
del getpwuid
|
||||
except ImportError:
|
||||
|
||||
def getUsername(self, uid, getpwuid=None):
|
||||
raise IdentError()
|
||||
|
||||
def entries(self):
|
||||
with open("/proc/net/tcp") as f:
|
||||
f.readline()
|
||||
for L in f:
|
||||
yield L.strip()
|
||||
|
||||
def dottedQuadFromHexString(self, hexstr):
|
||||
return ".".join(
|
||||
map(str, struct.unpack("4B", struct.pack("=L", int(hexstr, 16))))
|
||||
)
|
||||
|
||||
def unpackAddress(self, packed):
|
||||
addr, port = packed.split(":")
|
||||
addr = self.dottedQuadFromHexString(addr)
|
||||
port = int(port, 16)
|
||||
return addr, port
|
||||
|
||||
def parseLine(self, line):
|
||||
parts = line.strip().split()
|
||||
localAddr, localPort = self.unpackAddress(parts[1])
|
||||
remoteAddr, remotePort = self.unpackAddress(parts[2])
|
||||
uid = int(parts[7])
|
||||
return (localAddr, localPort), (remoteAddr, remotePort), uid
|
||||
|
||||
def lookup(self, serverAddress, clientAddress):
|
||||
for ent in self.entries():
|
||||
localAddr, remoteAddr, uid = self.parseLine(ent)
|
||||
if remoteAddr == clientAddress and localAddr[1] == serverAddress[1]:
|
||||
return (self.SYSTEM_NAME, self.getUsername(uid))
|
||||
|
||||
raise NoUser()
|
||||
|
||||
|
||||
class IdentClient(basic.LineOnlyReceiver):
|
||||
errorTypes = (IdentError, NoUser, InvalidPort, HiddenUser)
|
||||
|
||||
def __init__(self):
|
||||
self.queries = []
|
||||
|
||||
def lookup(self, portOnServer, portOnClient):
|
||||
"""
|
||||
Lookup user information about the specified address pair.
|
||||
"""
|
||||
self.queries.append((defer.Deferred(), portOnServer, portOnClient))
|
||||
if len(self.queries) > 1:
|
||||
return self.queries[-1][0]
|
||||
|
||||
self.sendLine("%d, %d" % (portOnServer, portOnClient))
|
||||
return self.queries[-1][0]
|
||||
|
||||
def lineReceived(self, line):
|
||||
if not self.queries:
|
||||
log.msg(f"Unexpected server response: {line!r}")
|
||||
else:
|
||||
d, _, _ = self.queries.pop(0)
|
||||
self.parseResponse(d, line)
|
||||
if self.queries:
|
||||
self.sendLine("%d, %d" % (self.queries[0][1], self.queries[0][2]))
|
||||
|
||||
def connectionLost(self, reason):
|
||||
for q in self.queries:
|
||||
q[0].errback(IdentError(reason))
|
||||
self.queries = []
|
||||
|
||||
def parseResponse(self, deferred, line):
|
||||
parts = line.split(":", 2)
|
||||
if len(parts) != 3:
|
||||
deferred.errback(IdentError(line))
|
||||
else:
|
||||
ports, type, addInfo = map(str.strip, parts)
|
||||
if type == "ERROR":
|
||||
for et in self.errorTypes:
|
||||
if et.identDescription == addInfo:
|
||||
deferred.errback(et(line))
|
||||
return
|
||||
deferred.errback(IdentError(line))
|
||||
else:
|
||||
deferred.callback((type, addInfo))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"IdentError",
|
||||
"NoUser",
|
||||
"InvalidPort",
|
||||
"HiddenUser",
|
||||
"IdentServer",
|
||||
"IdentClient",
|
||||
"ProcServerMixin",
|
||||
]
|
||||
387
.venv/lib/python3.12/site-packages/twisted/protocols/loopback.py
Normal file
387
.venv/lib/python3.12/site-packages/twisted/protocols/loopback.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# -*- test-case-name: twisted.test.test_loopback -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Testing support for protocols -- loopback between client and server.
|
||||
"""
|
||||
|
||||
|
||||
# system imports
|
||||
import tempfile
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, interfaces, main, protocol
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from twisted.internet.task import deferLater
|
||||
|
||||
# Twisted Imports
|
||||
from twisted.protocols import policies
|
||||
from twisted.python import failure
|
||||
|
||||
|
||||
class _LoopbackQueue:
|
||||
"""
|
||||
Trivial wrapper around a list to give it an interface like a queue, which
|
||||
the addition of also sending notifications by way of a Deferred whenever
|
||||
the list has something added to it.
|
||||
"""
|
||||
|
||||
_notificationDeferred = None
|
||||
disconnect = False
|
||||
|
||||
def __init__(self):
|
||||
self._queue = []
|
||||
|
||||
def put(self, v):
|
||||
self._queue.append(v)
|
||||
if self._notificationDeferred is not None:
|
||||
d, self._notificationDeferred = self._notificationDeferred, None
|
||||
d.callback(None)
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self._queue)
|
||||
|
||||
__bool__ = __nonzero__
|
||||
|
||||
def get(self):
|
||||
return self._queue.pop(0)
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
class _LoopbackAddress:
|
||||
pass
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport, interfaces.IConsumer)
|
||||
class _LoopbackTransport:
|
||||
disconnecting = False
|
||||
producer = None
|
||||
|
||||
# ITransport
|
||||
def __init__(self, q):
|
||||
self.q = q
|
||||
|
||||
def write(self, data):
|
||||
if not isinstance(data, bytes):
|
||||
raise TypeError("Can only write bytes to ITransport")
|
||||
self.q.put(data)
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self.q.put(b"".join(iovec))
|
||||
|
||||
def loseConnection(self):
|
||||
self.q.disconnect = True
|
||||
self.q.put(None)
|
||||
|
||||
def abortConnection(self):
|
||||
"""
|
||||
Abort the connection. Same as L{loseConnection}.
|
||||
"""
|
||||
self.loseConnection()
|
||||
|
||||
def getPeer(self):
|
||||
return _LoopbackAddress()
|
||||
|
||||
def getHost(self):
|
||||
return _LoopbackAddress()
|
||||
|
||||
# IConsumer
|
||||
def registerProducer(self, producer, streaming):
|
||||
assert self.producer is None
|
||||
self.producer = producer
|
||||
self.streamingProducer = streaming
|
||||
self._pollProducer()
|
||||
|
||||
def unregisterProducer(self):
|
||||
assert self.producer is not None
|
||||
self.producer = None
|
||||
|
||||
def _pollProducer(self):
|
||||
if self.producer is not None and not self.streamingProducer:
|
||||
self.producer.resumeProducing()
|
||||
|
||||
|
||||
def identityPumpPolicy(queue, target):
|
||||
"""
|
||||
L{identityPumpPolicy} is a policy which delivers each chunk of data written
|
||||
to the given queue as-is to the target.
|
||||
|
||||
This isn't a particularly realistic policy.
|
||||
|
||||
@see: L{loopbackAsync}
|
||||
"""
|
||||
while queue:
|
||||
bytes = queue.get()
|
||||
if bytes is None:
|
||||
break
|
||||
target.dataReceived(bytes)
|
||||
|
||||
|
||||
def collapsingPumpPolicy(queue, target):
|
||||
"""
|
||||
L{collapsingPumpPolicy} is a policy which collapses all outstanding chunks
|
||||
into a single string and delivers it to the target.
|
||||
|
||||
@see: L{loopbackAsync}
|
||||
"""
|
||||
bytes = []
|
||||
while queue:
|
||||
chunk = queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
bytes.append(chunk)
|
||||
if bytes:
|
||||
target.dataReceived(b"".join(bytes))
|
||||
|
||||
|
||||
def loopbackAsync(server, client, pumpPolicy=identityPumpPolicy):
|
||||
"""
|
||||
Establish a connection between C{server} and C{client} then transfer data
|
||||
between them until the connection is closed. This is often useful for
|
||||
testing a protocol.
|
||||
|
||||
@param server: The protocol instance representing the server-side of this
|
||||
connection.
|
||||
|
||||
@param client: The protocol instance representing the client-side of this
|
||||
connection.
|
||||
|
||||
@param pumpPolicy: When either C{server} or C{client} writes to its
|
||||
transport, the string passed in is added to a queue of data for the
|
||||
other protocol. Eventually, C{pumpPolicy} will be called with one such
|
||||
queue and the corresponding protocol object. The pump policy callable
|
||||
is responsible for emptying the queue and passing the strings it
|
||||
contains to the given protocol's C{dataReceived} method. The signature
|
||||
of C{pumpPolicy} is C{(queue, protocol)}. C{queue} is an object with a
|
||||
C{get} method which will return the next string written to the
|
||||
transport, or L{None} if the transport has been disconnected, and which
|
||||
evaluates to C{True} if and only if there are more items to be
|
||||
retrieved via C{get}.
|
||||
|
||||
@return: A L{Deferred} which fires when the connection has been closed and
|
||||
both sides have received notification of this.
|
||||
"""
|
||||
serverToClient = _LoopbackQueue()
|
||||
clientToServer = _LoopbackQueue()
|
||||
|
||||
server.makeConnection(_LoopbackTransport(serverToClient))
|
||||
client.makeConnection(_LoopbackTransport(clientToServer))
|
||||
|
||||
return _loopbackAsyncBody(
|
||||
server, serverToClient, client, clientToServer, pumpPolicy
|
||||
)
|
||||
|
||||
|
||||
def _loopbackAsyncBody(server, serverToClient, client, clientToServer, pumpPolicy):
|
||||
"""
|
||||
Transfer bytes from the output queue of each protocol to the input of the other.
|
||||
|
||||
@param server: The protocol instance representing the server-side of this
|
||||
connection.
|
||||
|
||||
@param serverToClient: The L{_LoopbackQueue} holding the server's output.
|
||||
|
||||
@param client: The protocol instance representing the client-side of this
|
||||
connection.
|
||||
|
||||
@param clientToServer: The L{_LoopbackQueue} holding the client's output.
|
||||
|
||||
@param pumpPolicy: See L{loopbackAsync}.
|
||||
|
||||
@return: A L{Deferred} which fires when the connection has been closed and
|
||||
both sides have received notification of this.
|
||||
"""
|
||||
|
||||
def pump(source, q, target):
|
||||
sent = False
|
||||
if q:
|
||||
pumpPolicy(q, target)
|
||||
sent = True
|
||||
if sent and not q:
|
||||
# A write buffer has now been emptied. Give any producer on that
|
||||
# side an opportunity to produce more data.
|
||||
source.transport._pollProducer()
|
||||
|
||||
return sent
|
||||
|
||||
while 1:
|
||||
disconnect = clientSent = serverSent = False
|
||||
|
||||
# Deliver the data which has been written.
|
||||
serverSent = pump(server, serverToClient, client)
|
||||
clientSent = pump(client, clientToServer, server)
|
||||
|
||||
if not clientSent and not serverSent:
|
||||
# Neither side wrote any data. Wait for some new data to be added
|
||||
# before trying to do anything further.
|
||||
d = defer.Deferred()
|
||||
clientToServer._notificationDeferred = d
|
||||
serverToClient._notificationDeferred = d
|
||||
d.addCallback(
|
||||
_loopbackAsyncContinue,
|
||||
server,
|
||||
serverToClient,
|
||||
client,
|
||||
clientToServer,
|
||||
pumpPolicy,
|
||||
)
|
||||
return d
|
||||
if serverToClient.disconnect:
|
||||
# The server wants to drop the connection. Flush any remaining
|
||||
# data it has.
|
||||
disconnect = True
|
||||
pump(server, serverToClient, client)
|
||||
elif clientToServer.disconnect:
|
||||
# The client wants to drop the connection. Flush any remaining
|
||||
# data it has.
|
||||
disconnect = True
|
||||
pump(client, clientToServer, server)
|
||||
if disconnect:
|
||||
# Someone wanted to disconnect, so okay, the connection is gone.
|
||||
server.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||
client.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
def _loopbackAsyncContinue(
|
||||
ignored, server, serverToClient, client, clientToServer, pumpPolicy
|
||||
):
|
||||
# Clear the Deferred from each message queue, since it has already fired
|
||||
# and cannot be used again.
|
||||
clientToServer._notificationDeferred = None
|
||||
serverToClient._notificationDeferred = None
|
||||
|
||||
# Schedule some more byte-pushing to happen. This isn't done
|
||||
# synchronously because no actual transport can re-enter dataReceived as
|
||||
# a result of calling write, and doing this synchronously could result
|
||||
# in that.
|
||||
from twisted.internet import reactor
|
||||
|
||||
return deferLater(
|
||||
reactor,
|
||||
0,
|
||||
_loopbackAsyncBody,
|
||||
server,
|
||||
serverToClient,
|
||||
client,
|
||||
clientToServer,
|
||||
pumpPolicy,
|
||||
)
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport, interfaces.IConsumer)
|
||||
class LoopbackRelay:
|
||||
buffer = b""
|
||||
shouldLose = 0
|
||||
disconnecting = 0
|
||||
producer = None
|
||||
|
||||
def __init__(self, target, logFile=None):
|
||||
self.target = target
|
||||
self.logFile = logFile
|
||||
|
||||
def write(self, data):
|
||||
self.buffer = self.buffer + data
|
||||
if self.logFile:
|
||||
self.logFile.write("loopback writing %s\n" % repr(data))
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self.write(b"".join(iovec))
|
||||
|
||||
def clearBuffer(self):
|
||||
if self.shouldLose == -1:
|
||||
return
|
||||
|
||||
if self.producer:
|
||||
self.producer.resumeProducing()
|
||||
if self.buffer:
|
||||
if self.logFile:
|
||||
self.logFile.write("loopback receiving %s\n" % repr(self.buffer))
|
||||
buffer = self.buffer
|
||||
self.buffer = b""
|
||||
self.target.dataReceived(buffer)
|
||||
if self.shouldLose == 1:
|
||||
self.shouldLose = -1
|
||||
self.target.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||
|
||||
def loseConnection(self):
|
||||
if self.shouldLose != -1:
|
||||
self.shouldLose = 1
|
||||
|
||||
def getHost(self):
|
||||
return "loopback"
|
||||
|
||||
def getPeer(self):
|
||||
return "loopback"
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = producer
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.producer = None
|
||||
|
||||
def logPrefix(self):
|
||||
return f"Loopback({self.target.__class__.__name__!r})"
|
||||
|
||||
|
||||
class LoopbackClientFactory(protocol.ClientFactory):
|
||||
def __init__(self, protocol):
|
||||
self.disconnected = 0
|
||||
self.deferred = defer.Deferred()
|
||||
self.protocol = protocol
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return self.protocol
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
self.disconnected = 1
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
class _FireOnClose(policies.ProtocolWrapper):
|
||||
def __init__(self, protocol, factory):
|
||||
policies.ProtocolWrapper.__init__(self, protocol, factory)
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
policies.ProtocolWrapper.connectionLost(self, reason)
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
def loopbackTCP(server, client, port=0, noisy=True):
|
||||
"""Run session between server and client protocol instances over TCP."""
|
||||
from twisted.internet import reactor
|
||||
|
||||
f = policies.WrappingFactory(protocol.Factory())
|
||||
serverWrapper = _FireOnClose(f, server)
|
||||
f.noisy = noisy
|
||||
f.buildProtocol = lambda addr: serverWrapper
|
||||
serverPort = reactor.listenTCP(port, f, interface="127.0.0.1")
|
||||
clientF = LoopbackClientFactory(client)
|
||||
clientF.noisy = noisy
|
||||
reactor.connectTCP("127.0.0.1", serverPort.getHost().port, clientF)
|
||||
d = clientF.deferred
|
||||
d.addCallback(lambda x: serverWrapper.deferred)
|
||||
d.addCallback(lambda x: serverPort.stopListening())
|
||||
return d
|
||||
|
||||
|
||||
def loopbackUNIX(server, client, noisy=True):
|
||||
"""Run session between server and client protocol instances over UNIX socket."""
|
||||
path = tempfile.mktemp()
|
||||
from twisted.internet import reactor
|
||||
|
||||
f = policies.WrappingFactory(protocol.Factory())
|
||||
serverWrapper = _FireOnClose(f, server)
|
||||
f.noisy = noisy
|
||||
f.buildProtocol = lambda addr: serverWrapper
|
||||
serverPort = reactor.listenUNIX(path, f)
|
||||
clientF = LoopbackClientFactory(client)
|
||||
clientF.noisy = noisy
|
||||
reactor.connectUNIX(path, clientF)
|
||||
d = clientF.deferred
|
||||
d.addCallback(lambda x: serverWrapper.deferred)
|
||||
d.addCallback(lambda x: serverPort.stopListening())
|
||||
return d
|
||||
733
.venv/lib/python3.12/site-packages/twisted/protocols/memcache.py
Normal file
733
.venv/lib/python3.12/site-packages/twisted/protocols/memcache.py
Normal file
@@ -0,0 +1,733 @@
|
||||
# -*- test-case-name: twisted.test.test_memcache -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Memcache client protocol. Memcached is a caching server, storing data in the
|
||||
form of pairs key/value, and memcache is the protocol to talk with it.
|
||||
|
||||
To connect to a server, create a factory for L{MemCacheProtocol}::
|
||||
|
||||
from twisted.internet import reactor, protocol
|
||||
from twisted.protocols.memcache import MemCacheProtocol, DEFAULT_PORT
|
||||
d = protocol.ClientCreator(reactor, MemCacheProtocol
|
||||
).connectTCP("localhost", DEFAULT_PORT)
|
||||
def doSomething(proto):
|
||||
# Here you call the memcache operations
|
||||
return proto.set("mykey", "a lot of data")
|
||||
d.addCallback(doSomething)
|
||||
reactor.run()
|
||||
|
||||
All the operations of the memcache protocol are present, but
|
||||
L{MemCacheProtocol.set} and L{MemCacheProtocol.get} are the more important.
|
||||
|
||||
See U{http://code.sixapart.[AWS-SECRET-REMOVED]col.txt} for
|
||||
more information about the protocol.
|
||||
"""
|
||||
|
||||
|
||||
from collections import deque
|
||||
|
||||
from twisted.internet.defer import Deferred, TimeoutError, fail
|
||||
from twisted.protocols.basic import LineReceiver
|
||||
from twisted.protocols.policies import TimeoutMixin
|
||||
from twisted.python import log
|
||||
from twisted.python.compat import nativeString, networkString
|
||||
|
||||
DEFAULT_PORT = 11211
|
||||
|
||||
|
||||
class NoSuchCommand(Exception):
|
||||
"""
|
||||
Exception raised when a non existent command is called.
|
||||
"""
|
||||
|
||||
|
||||
class ClientError(Exception):
|
||||
"""
|
||||
Error caused by an invalid client call.
|
||||
"""
|
||||
|
||||
|
||||
class ServerError(Exception):
|
||||
"""
|
||||
Problem happening on the server.
|
||||
"""
|
||||
|
||||
|
||||
class Command:
|
||||
"""
|
||||
Wrap a client action into an object, that holds the values used in the
|
||||
protocol.
|
||||
|
||||
@ivar _deferred: the L{Deferred} object that will be fired when the result
|
||||
arrives.
|
||||
@type _deferred: L{Deferred}
|
||||
|
||||
@ivar command: name of the command sent to the server.
|
||||
@type command: L{bytes}
|
||||
"""
|
||||
|
||||
def __init__(self, command, **kwargs):
|
||||
"""
|
||||
Create a command.
|
||||
|
||||
@param command: the name of the command.
|
||||
@type command: L{bytes}
|
||||
|
||||
@param kwargs: this values will be stored as attributes of the object
|
||||
for future use
|
||||
"""
|
||||
self.command = command
|
||||
self._deferred = Deferred()
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def success(self, value):
|
||||
"""
|
||||
Shortcut method to fire the underlying deferred.
|
||||
"""
|
||||
self._deferred.callback(value)
|
||||
|
||||
def fail(self, error):
|
||||
"""
|
||||
Make the underlying deferred fails.
|
||||
"""
|
||||
self._deferred.errback(error)
|
||||
|
||||
|
||||
class MemCacheProtocol(LineReceiver, TimeoutMixin):
|
||||
"""
|
||||
MemCache protocol: connect to a memcached server to store/retrieve values.
|
||||
|
||||
@ivar persistentTimeOut: the timeout period used to wait for a response.
|
||||
@type persistentTimeOut: L{int}
|
||||
|
||||
@ivar _current: current list of requests waiting for an answer from the
|
||||
server.
|
||||
@type _current: L{deque} of L{Command}
|
||||
|
||||
@ivar _lenExpected: amount of data expected in raw mode, when reading for
|
||||
a value.
|
||||
@type _lenExpected: L{int}
|
||||
|
||||
@ivar _getBuffer: current buffer of data, used to store temporary data
|
||||
when reading in raw mode.
|
||||
@type _getBuffer: L{list}
|
||||
|
||||
@ivar _bufferLength: the total amount of bytes in C{_getBuffer}.
|
||||
@type _bufferLength: L{int}
|
||||
|
||||
@ivar _disconnected: indicate if the connectionLost has been called or not.
|
||||
@type _disconnected: L{bool}
|
||||
"""
|
||||
|
||||
MAX_KEY_LENGTH = 250
|
||||
_disconnected = False
|
||||
|
||||
def __init__(self, timeOut=60):
|
||||
"""
|
||||
Create the protocol.
|
||||
|
||||
@param timeOut: the timeout to wait before detecting that the
|
||||
connection is dead and close it. It's expressed in seconds.
|
||||
@type timeOut: L{int}
|
||||
"""
|
||||
self._current = deque()
|
||||
self._lenExpected = None
|
||||
self._getBuffer = None
|
||||
self._bufferLength = None
|
||||
self.persistentTimeOut = self.timeOut = timeOut
|
||||
|
||||
def _cancelCommands(self, reason):
|
||||
"""
|
||||
Cancel all the outstanding commands, making them fail with C{reason}.
|
||||
"""
|
||||
while self._current:
|
||||
cmd = self._current.popleft()
|
||||
cmd.fail(reason)
|
||||
|
||||
def timeoutConnection(self):
|
||||
"""
|
||||
Close the connection in case of timeout.
|
||||
"""
|
||||
self._cancelCommands(TimeoutError("Connection timeout"))
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Cause any outstanding commands to fail.
|
||||
"""
|
||||
self._disconnected = True
|
||||
self._cancelCommands(reason)
|
||||
LineReceiver.connectionLost(self, reason)
|
||||
|
||||
def sendLine(self, line):
|
||||
"""
|
||||
Override sendLine to add a timeout to response.
|
||||
"""
|
||||
if not self._current:
|
||||
self.setTimeout(self.persistentTimeOut)
|
||||
LineReceiver.sendLine(self, line)
|
||||
|
||||
def rawDataReceived(self, data):
|
||||
"""
|
||||
Collect data for a get.
|
||||
"""
|
||||
self.resetTimeout()
|
||||
self._getBuffer.append(data)
|
||||
self._bufferLength += len(data)
|
||||
if self._bufferLength >= self._lenExpected + 2:
|
||||
data = b"".join(self._getBuffer)
|
||||
buf = data[: self._lenExpected]
|
||||
rem = data[self._lenExpected + 2 :]
|
||||
val = buf
|
||||
self._lenExpected = None
|
||||
self._getBuffer = None
|
||||
self._bufferLength = None
|
||||
cmd = self._current[0]
|
||||
if cmd.multiple:
|
||||
flags, cas = cmd.values[cmd.currentKey]
|
||||
cmd.values[cmd.currentKey] = (flags, cas, val)
|
||||
else:
|
||||
cmd.value = val
|
||||
self.setLineMode(rem)
|
||||
|
||||
def cmd_STORED(self):
|
||||
"""
|
||||
Manage a success response to a set operation.
|
||||
"""
|
||||
self._current.popleft().success(True)
|
||||
|
||||
def cmd_NOT_STORED(self):
|
||||
"""
|
||||
Manage a specific 'not stored' response to a set operation: this is not
|
||||
an error, but some condition wasn't met.
|
||||
"""
|
||||
self._current.popleft().success(False)
|
||||
|
||||
def cmd_END(self):
|
||||
"""
|
||||
This the end token to a get or a stat operation.
|
||||
"""
|
||||
cmd = self._current.popleft()
|
||||
if cmd.command == b"get":
|
||||
if cmd.multiple:
|
||||
values = {key: val[::2] for key, val in cmd.values.items()}
|
||||
cmd.success(values)
|
||||
else:
|
||||
cmd.success((cmd.flags, cmd.value))
|
||||
elif cmd.command == b"gets":
|
||||
if cmd.multiple:
|
||||
cmd.success(cmd.values)
|
||||
else:
|
||||
cmd.success((cmd.flags, cmd.cas, cmd.value))
|
||||
elif cmd.command == b"stats":
|
||||
cmd.success(cmd.values)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unexpected END response to {} command".format(
|
||||
nativeString(cmd.command)
|
||||
)
|
||||
)
|
||||
|
||||
def cmd_NOT_FOUND(self):
|
||||
"""
|
||||
Manage error response for incr/decr/delete.
|
||||
"""
|
||||
self._current.popleft().success(False)
|
||||
|
||||
def cmd_VALUE(self, line):
|
||||
"""
|
||||
Prepare the reading a value after a get.
|
||||
"""
|
||||
cmd = self._current[0]
|
||||
if cmd.command == b"get":
|
||||
key, flags, length = line.split()
|
||||
cas = b""
|
||||
else:
|
||||
key, flags, length, cas = line.split()
|
||||
self._lenExpected = int(length)
|
||||
self._getBuffer = []
|
||||
self._bufferLength = 0
|
||||
if cmd.multiple:
|
||||
if key not in cmd.keys:
|
||||
raise RuntimeError("Unexpected commands answer.")
|
||||
cmd.currentKey = key
|
||||
cmd.values[key] = [int(flags), cas]
|
||||
else:
|
||||
if cmd.key != key:
|
||||
raise RuntimeError("Unexpected commands answer.")
|
||||
cmd.flags = int(flags)
|
||||
cmd.cas = cas
|
||||
self.setRawMode()
|
||||
|
||||
def cmd_STAT(self, line):
|
||||
"""
|
||||
Reception of one stat line.
|
||||
"""
|
||||
cmd = self._current[0]
|
||||
key, val = line.split(b" ", 1)
|
||||
cmd.values[key] = val
|
||||
|
||||
def cmd_VERSION(self, versionData):
|
||||
"""
|
||||
Read version token.
|
||||
"""
|
||||
self._current.popleft().success(versionData)
|
||||
|
||||
def cmd_ERROR(self):
|
||||
"""
|
||||
A non-existent command has been sent.
|
||||
"""
|
||||
log.err("Non-existent command sent.")
|
||||
cmd = self._current.popleft()
|
||||
cmd.fail(NoSuchCommand())
|
||||
|
||||
def cmd_CLIENT_ERROR(self, errText):
|
||||
"""
|
||||
An invalid input as been sent.
|
||||
"""
|
||||
errText = repr(errText)
|
||||
log.err("Invalid input: " + errText)
|
||||
cmd = self._current.popleft()
|
||||
cmd.fail(ClientError(errText))
|
||||
|
||||
def cmd_SERVER_ERROR(self, errText):
|
||||
"""
|
||||
An error has happened server-side.
|
||||
"""
|
||||
errText = repr(errText)
|
||||
log.err("Server error: " + errText)
|
||||
cmd = self._current.popleft()
|
||||
cmd.fail(ServerError(errText))
|
||||
|
||||
def cmd_DELETED(self):
|
||||
"""
|
||||
A delete command has completed successfully.
|
||||
"""
|
||||
self._current.popleft().success(True)
|
||||
|
||||
def cmd_OK(self):
|
||||
"""
|
||||
The last command has been completed.
|
||||
"""
|
||||
self._current.popleft().success(True)
|
||||
|
||||
def cmd_EXISTS(self):
|
||||
"""
|
||||
A C{checkAndSet} update has failed.
|
||||
"""
|
||||
self._current.popleft().success(False)
|
||||
|
||||
def lineReceived(self, line):
|
||||
"""
|
||||
Receive line commands from the server.
|
||||
"""
|
||||
self.resetTimeout()
|
||||
token = line.split(b" ", 1)[0]
|
||||
# First manage standard commands without space
|
||||
cmd = getattr(self, "cmd_" + nativeString(token), None)
|
||||
if cmd is not None:
|
||||
args = line.split(b" ", 1)[1:]
|
||||
if args:
|
||||
cmd(args[0])
|
||||
else:
|
||||
cmd()
|
||||
else:
|
||||
# Then manage commands with space in it
|
||||
line = line.replace(b" ", b"_")
|
||||
cmd = getattr(self, "cmd_" + nativeString(line), None)
|
||||
if cmd is not None:
|
||||
cmd()
|
||||
else:
|
||||
# Increment/Decrement response
|
||||
cmd = self._current.popleft()
|
||||
val = int(line)
|
||||
cmd.success(val)
|
||||
if not self._current:
|
||||
# No pending request, remove timeout
|
||||
self.setTimeout(None)
|
||||
|
||||
def increment(self, key, val=1):
|
||||
"""
|
||||
Increment the value of C{key} by given value (default to 1).
|
||||
C{key} must be consistent with an int. Return the new value.
|
||||
|
||||
@param key: the key to modify.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the value to increment.
|
||||
@type val: L{int}
|
||||
|
||||
@return: a deferred with will be called back with the new value
|
||||
associated with the key (after the increment).
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._incrdecr(b"incr", key, val)
|
||||
|
||||
def decrement(self, key, val=1):
|
||||
"""
|
||||
Decrement the value of C{key} by given value (default to 1).
|
||||
C{key} must be consistent with an int. Return the new value, coerced to
|
||||
0 if negative.
|
||||
|
||||
@param key: the key to modify.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the value to decrement.
|
||||
@type val: L{int}
|
||||
|
||||
@return: a deferred with will be called back with the new value
|
||||
associated with the key (after the decrement).
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._incrdecr(b"decr", key, val)
|
||||
|
||||
def _incrdecr(self, cmd, key, val):
|
||||
"""
|
||||
Internal wrapper for incr/decr.
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
if not isinstance(key, bytes):
|
||||
return fail(
|
||||
ClientError(f"Invalid type for key: {type(key)}, expecting bytes")
|
||||
)
|
||||
if len(key) > self.MAX_KEY_LENGTH:
|
||||
return fail(ClientError("Key too long"))
|
||||
fullcmd = b" ".join([cmd, key, b"%d" % (int(val),)])
|
||||
self.sendLine(fullcmd)
|
||||
cmdObj = Command(cmd, key=key)
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
def replace(self, key, val, flags=0, expireTime=0):
|
||||
"""
|
||||
Replace the given C{key}. It must already exist in the server.
|
||||
|
||||
@param key: the key to replace.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the new value associated with the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@param flags: the flags to store with the key.
|
||||
@type flags: L{int}
|
||||
|
||||
@param expireTime: if different from 0, the relative time in seconds
|
||||
when the key will be deleted from the store.
|
||||
@type expireTime: L{int}
|
||||
|
||||
@return: a deferred that will fire with C{True} if the operation has
|
||||
succeeded, and C{False} with the key didn't previously exist.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._set(b"replace", key, val, flags, expireTime, b"")
|
||||
|
||||
def add(self, key, val, flags=0, expireTime=0):
|
||||
"""
|
||||
Add the given C{key}. It must not exist in the server.
|
||||
|
||||
@param key: the key to add.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the value associated with the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@param flags: the flags to store with the key.
|
||||
@type flags: L{int}
|
||||
|
||||
@param expireTime: if different from 0, the relative time in seconds
|
||||
when the key will be deleted from the store.
|
||||
@type expireTime: L{int}
|
||||
|
||||
@return: a deferred that will fire with C{True} if the operation has
|
||||
succeeded, and C{False} with the key already exists.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._set(b"add", key, val, flags, expireTime, b"")
|
||||
|
||||
def set(self, key, val, flags=0, expireTime=0):
|
||||
"""
|
||||
Set the given C{key}.
|
||||
|
||||
@param key: the key to set.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the value associated with the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@param flags: the flags to store with the key.
|
||||
@type flags: L{int}
|
||||
|
||||
@param expireTime: if different from 0, the relative time in seconds
|
||||
when the key will be deleted from the store.
|
||||
@type expireTime: L{int}
|
||||
|
||||
@return: a deferred that will fire with C{True} if the operation has
|
||||
succeeded.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._set(b"set", key, val, flags, expireTime, b"")
|
||||
|
||||
def checkAndSet(self, key, val, cas, flags=0, expireTime=0):
|
||||
"""
|
||||
Change the content of C{key} only if the C{cas} value matches the
|
||||
current one associated with the key. Use this to store a value which
|
||||
hasn't been modified since last time you fetched it.
|
||||
|
||||
@param key: The key to set.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: The value associated with the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@param cas: Unique 64-bit value returned by previous call of C{get}.
|
||||
@type cas: L{bytes}
|
||||
|
||||
@param flags: The flags to store with the key.
|
||||
@type flags: L{int}
|
||||
|
||||
@param expireTime: If different from 0, the relative time in seconds
|
||||
when the key will be deleted from the store.
|
||||
@type expireTime: L{int}
|
||||
|
||||
@return: A deferred that will fire with C{True} if the operation has
|
||||
succeeded, C{False} otherwise.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._set(b"cas", key, val, flags, expireTime, cas)
|
||||
|
||||
def _set(self, cmd, key, val, flags, expireTime, cas):
|
||||
"""
|
||||
Internal wrapper for setting values.
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
if not isinstance(key, bytes):
|
||||
return fail(
|
||||
ClientError(f"Invalid type for key: {type(key)}, expecting bytes")
|
||||
)
|
||||
if len(key) > self.MAX_KEY_LENGTH:
|
||||
return fail(ClientError("Key too long"))
|
||||
if not isinstance(val, bytes):
|
||||
return fail(
|
||||
ClientError(f"Invalid type for value: {type(val)}, expecting bytes")
|
||||
)
|
||||
if cas:
|
||||
cas = b" " + cas
|
||||
length = len(val)
|
||||
fullcmd = (
|
||||
b" ".join(
|
||||
[cmd, key, networkString("%d %d %d" % (flags, expireTime, length))]
|
||||
)
|
||||
+ cas
|
||||
)
|
||||
self.sendLine(fullcmd)
|
||||
self.sendLine(val)
|
||||
cmdObj = Command(cmd, key=key, flags=flags, length=length)
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
def append(self, key, val):
|
||||
"""
|
||||
Append given data to the value of an existing key.
|
||||
|
||||
@param key: The key to modify.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: The value to append to the current value associated with
|
||||
the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@return: A deferred that will fire with C{True} if the operation has
|
||||
succeeded, C{False} otherwise.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
# Even if flags and expTime values are ignored, we have to pass them
|
||||
return self._set(b"append", key, val, 0, 0, b"")
|
||||
|
||||
def prepend(self, key, val):
|
||||
"""
|
||||
Prepend given data to the value of an existing key.
|
||||
|
||||
@param key: The key to modify.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: The value to prepend to the current value associated with
|
||||
the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@return: A deferred that will fire with C{True} if the operation has
|
||||
succeeded, C{False} otherwise.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
# Even if flags and expTime values are ignored, we have to pass them
|
||||
return self._set(b"prepend", key, val, 0, 0, b"")
|
||||
|
||||
def get(self, key, withIdentifier=False):
|
||||
"""
|
||||
Get the given C{key}. It doesn't support multiple keys. If
|
||||
C{withIdentifier} is set to C{True}, the command issued is a C{gets},
|
||||
that will return the current identifier associated with the value. This
|
||||
identifier has to be used when issuing C{checkAndSet} update later,
|
||||
using the corresponding method.
|
||||
|
||||
@param key: The key to retrieve.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param withIdentifier: If set to C{True}, retrieve the current
|
||||
identifier along with the value and the flags.
|
||||
@type withIdentifier: L{bool}
|
||||
|
||||
@return: A deferred that will fire with the tuple (flags, value) if
|
||||
C{withIdentifier} is C{False}, or (flags, cas identifier, value)
|
||||
if C{True}. If the server indicates there is no value
|
||||
associated with C{key}, the returned value will be L{None} and
|
||||
the returned flags will be C{0}.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._get([key], withIdentifier, False)
|
||||
|
||||
def getMultiple(self, keys, withIdentifier=False):
|
||||
"""
|
||||
Get the given list of C{keys}. If C{withIdentifier} is set to C{True},
|
||||
the command issued is a C{gets}, that will return the identifiers
|
||||
associated with each values. This identifier has to be used when
|
||||
issuing C{checkAndSet} update later, using the corresponding method.
|
||||
|
||||
@param keys: The keys to retrieve.
|
||||
@type keys: L{list} of L{bytes}
|
||||
|
||||
@param withIdentifier: If set to C{True}, retrieve the identifiers
|
||||
along with the values and the flags.
|
||||
@type withIdentifier: L{bool}
|
||||
|
||||
@return: A deferred that will fire with a dictionary with the elements
|
||||
of C{keys} as keys and the tuples (flags, value) as values if
|
||||
C{withIdentifier} is C{False}, or (flags, cas identifier, value) if
|
||||
C{True}. If the server indicates there is no value associated with
|
||||
C{key}, the returned values will be L{None} and the returned flags
|
||||
will be C{0}.
|
||||
@rtype: L{Deferred}
|
||||
|
||||
@since: 9.0
|
||||
"""
|
||||
return self._get(keys, withIdentifier, True)
|
||||
|
||||
def _get(self, keys, withIdentifier, multiple):
|
||||
"""
|
||||
Helper method for C{get} and C{getMultiple}.
|
||||
"""
|
||||
keys = list(keys)
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
for key in keys:
|
||||
if not isinstance(key, bytes):
|
||||
return fail(
|
||||
ClientError(f"Invalid type for key: {type(key)}, expecting bytes")
|
||||
)
|
||||
if len(key) > self.MAX_KEY_LENGTH:
|
||||
return fail(ClientError("Key too long"))
|
||||
if withIdentifier:
|
||||
cmd = b"gets"
|
||||
else:
|
||||
cmd = b"get"
|
||||
fullcmd = b" ".join([cmd] + keys)
|
||||
self.sendLine(fullcmd)
|
||||
if multiple:
|
||||
values = {key: (0, b"", None) for key in keys}
|
||||
cmdObj = Command(cmd, keys=keys, values=values, multiple=True)
|
||||
else:
|
||||
cmdObj = Command(
|
||||
cmd, key=keys[0], value=None, flags=0, cas=b"", multiple=False
|
||||
)
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
def stats(self, arg=None):
|
||||
"""
|
||||
Get some stats from the server. It will be available as a dict.
|
||||
|
||||
@param arg: An optional additional string which will be sent along
|
||||
with the I{stats} command. The interpretation of this value by
|
||||
the server is left undefined by the memcache protocol
|
||||
specification.
|
||||
@type arg: L{None} or L{bytes}
|
||||
|
||||
@return: a deferred that will fire with a L{dict} of the available
|
||||
statistics.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
if arg:
|
||||
cmd = b"stats " + arg
|
||||
else:
|
||||
cmd = b"stats"
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
self.sendLine(cmd)
|
||||
cmdObj = Command(b"stats", values={})
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
def version(self):
|
||||
"""
|
||||
Get the version of the server.
|
||||
|
||||
@return: a deferred that will fire with the string value of the
|
||||
version.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
self.sendLine(b"version")
|
||||
cmdObj = Command(b"version")
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
def delete(self, key):
|
||||
"""
|
||||
Delete an existing C{key}.
|
||||
|
||||
@param key: the key to delete.
|
||||
@type key: L{bytes}
|
||||
|
||||
@return: a deferred that will be called back with C{True} if the key
|
||||
was successfully deleted, or C{False} if not.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
if not isinstance(key, bytes):
|
||||
return fail(
|
||||
ClientError(f"Invalid type for key: {type(key)}, expecting bytes")
|
||||
)
|
||||
self.sendLine(b"delete " + key)
|
||||
cmdObj = Command(b"delete", key=key)
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
def flushAll(self):
|
||||
"""
|
||||
Flush all cached values.
|
||||
|
||||
@return: a deferred that will be called back with C{True} when the
|
||||
operation has succeeded.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
self.sendLine(b"flush_all")
|
||||
cmdObj = Command(b"flush_all")
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemCacheProtocol",
|
||||
"DEFAULT_PORT",
|
||||
"NoSuchCommand",
|
||||
"ClientError",
|
||||
"ServerError",
|
||||
]
|
||||
211
.venv/lib/python3.12/site-packages/twisted/protocols/pcp.py
Normal file
211
.venv/lib/python3.12/site-packages/twisted/protocols/pcp.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# -*- test-case-name: twisted.test.test_pcp -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Producer-Consumer Proxy.
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import interfaces
|
||||
|
||||
|
||||
@implementer(interfaces.IProducer, interfaces.IConsumer)
|
||||
class BasicProducerConsumerProxy:
|
||||
"""
|
||||
I can act as a man in the middle between any Producer and Consumer.
|
||||
|
||||
@ivar producer: the Producer I subscribe to.
|
||||
@type producer: L{IProducer<interfaces.IProducer>}
|
||||
@ivar consumer: the Consumer I publish to.
|
||||
@type consumer: L{IConsumer<interfaces.IConsumer>}
|
||||
@ivar paused: As a Producer, am I paused?
|
||||
@type paused: bool
|
||||
"""
|
||||
|
||||
consumer = None
|
||||
producer = None
|
||||
producerIsStreaming = None
|
||||
iAmStreaming = True
|
||||
outstandingPull = False
|
||||
paused = False
|
||||
stopped = False
|
||||
|
||||
def __init__(self, consumer):
|
||||
self._buffer = []
|
||||
if consumer is not None:
|
||||
self.consumer = consumer
|
||||
consumer.registerProducer(self, self.iAmStreaming)
|
||||
|
||||
# Producer methods:
|
||||
|
||||
def pauseProducing(self):
|
||||
self.paused = True
|
||||
if self.producer:
|
||||
self.producer.pauseProducing()
|
||||
|
||||
def resumeProducing(self):
|
||||
self.paused = False
|
||||
if self._buffer:
|
||||
# TODO: Check to see if consumer supports writeSeq.
|
||||
self.consumer.write("".join(self._buffer))
|
||||
self._buffer[:] = []
|
||||
else:
|
||||
if not self.iAmStreaming:
|
||||
self.outstandingPull = True
|
||||
|
||||
if self.producer is not None:
|
||||
self.producer.resumeProducing()
|
||||
|
||||
def stopProducing(self):
|
||||
if self.producer is not None:
|
||||
self.producer.stopProducing()
|
||||
if self.consumer is not None:
|
||||
del self.consumer
|
||||
|
||||
# Consumer methods:
|
||||
|
||||
def write(self, data):
|
||||
if self.paused or (not self.iAmStreaming and not self.outstandingPull):
|
||||
# We could use that fifo queue here.
|
||||
self._buffer.append(data)
|
||||
|
||||
elif self.consumer is not None:
|
||||
self.consumer.write(data)
|
||||
self.outstandingPull = False
|
||||
|
||||
def finish(self):
|
||||
if self.consumer is not None:
|
||||
self.consumer.finish()
|
||||
self.unregisterProducer()
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = producer
|
||||
self.producerIsStreaming = streaming
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self.producer is not None:
|
||||
del self.producer
|
||||
del self.producerIsStreaming
|
||||
if self.consumer:
|
||||
self.consumer.unregisterProducer()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__}@{id(self):x} around {self.consumer}>"
|
||||
|
||||
|
||||
class ProducerConsumerProxy(BasicProducerConsumerProxy):
|
||||
"""ProducerConsumerProxy with a finite buffer.
|
||||
|
||||
When my buffer fills up, I have my parent Producer pause until my buffer
|
||||
has room in it again.
|
||||
"""
|
||||
|
||||
# Copies much from abstract.FileDescriptor
|
||||
bufferSize = 2**2**2**2
|
||||
|
||||
producerPaused = False
|
||||
unregistered = False
|
||||
|
||||
def pauseProducing(self):
|
||||
# Does *not* call up to ProducerConsumerProxy to relay the pause
|
||||
# message through to my parent Producer.
|
||||
self.paused = True
|
||||
|
||||
def resumeProducing(self):
|
||||
self.paused = False
|
||||
if self._buffer:
|
||||
data = "".join(self._buffer)
|
||||
bytesSent = self._writeSomeData(data)
|
||||
if bytesSent < len(data):
|
||||
unsent = data[bytesSent:]
|
||||
assert (
|
||||
not self.iAmStreaming
|
||||
), "Streaming producer did not write all its data."
|
||||
self._buffer[:] = [unsent]
|
||||
else:
|
||||
self._buffer[:] = []
|
||||
else:
|
||||
bytesSent = 0
|
||||
|
||||
if (
|
||||
self.unregistered
|
||||
and bytesSent
|
||||
and not self._buffer
|
||||
and self.consumer is not None
|
||||
):
|
||||
self.consumer.unregisterProducer()
|
||||
|
||||
if not self.iAmStreaming:
|
||||
self.outstandingPull = not bytesSent
|
||||
|
||||
if self.producer is not None:
|
||||
bytesBuffered = sum(len(s) for s in self._buffer)
|
||||
# TODO: You can see here the potential for high and low
|
||||
# watermarks, where bufferSize would be the high mark when we
|
||||
# ask the upstream producer to pause, and we wouldn't have
|
||||
# it resume again until it hit the low mark. Or if producer
|
||||
# is Pull, maybe we'd like to pull from it as much as necessary
|
||||
# to keep our buffer full to the low mark, so we're never caught
|
||||
# without something to send.
|
||||
if self.producerPaused and (bytesBuffered < self.bufferSize):
|
||||
# Now that our buffer is empty,
|
||||
self.producerPaused = False
|
||||
self.producer.resumeProducing()
|
||||
elif self.outstandingPull:
|
||||
# I did not have any data to write in response to a pull,
|
||||
# so I'd better pull some myself.
|
||||
self.producer.resumeProducing()
|
||||
|
||||
def write(self, data):
|
||||
if self.paused or (not self.iAmStreaming and not self.outstandingPull):
|
||||
# We could use that fifo queue here.
|
||||
self._buffer.append(data)
|
||||
|
||||
elif self.consumer is not None:
|
||||
assert (
|
||||
not self._buffer
|
||||
), "Writing fresh data to consumer before my buffer is empty!"
|
||||
# I'm going to use _writeSomeData here so that there is only one
|
||||
# path to self.consumer.write. But it doesn't actually make sense,
|
||||
# if I am streaming, for some data to not be all data. But maybe I
|
||||
# am not streaming, but I am writing here anyway, because there was
|
||||
# an earlier request for data which was not answered.
|
||||
bytesSent = self._writeSomeData(data)
|
||||
self.outstandingPull = False
|
||||
if not bytesSent == len(data):
|
||||
assert (
|
||||
not self.iAmStreaming
|
||||
), "Streaming producer did not write all its data."
|
||||
self._buffer.append(data[bytesSent:])
|
||||
|
||||
if (self.producer is not None) and self.producerIsStreaming:
|
||||
bytesBuffered = sum(len(s) for s in self._buffer)
|
||||
if bytesBuffered >= self.bufferSize:
|
||||
self.producer.pauseProducing()
|
||||
self.producerPaused = True
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.unregistered = False
|
||||
BasicProducerConsumerProxy.registerProducer(self, producer, streaming)
|
||||
if not streaming:
|
||||
producer.resumeProducing()
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self.producer is not None:
|
||||
del self.producer
|
||||
del self.producerIsStreaming
|
||||
self.unregistered = True
|
||||
if self.consumer and not self._buffer:
|
||||
self.consumer.unregisterProducer()
|
||||
|
||||
def _writeSomeData(self, data):
|
||||
"""Write as much of this data as possible.
|
||||
|
||||
@returns: The number of bytes written.
|
||||
"""
|
||||
if self.consumer is None:
|
||||
return 0
|
||||
self.consumer.write(data)
|
||||
return len(data)
|
||||
696
.venv/lib/python3.12/site-packages/twisted/protocols/policies.py
Normal file
696
.venv/lib/python3.12/site-packages/twisted/protocols/policies.py
Normal file
@@ -0,0 +1,696 @@
|
||||
# -*- test-case-name: twisted.test.test_policies -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Resource limiting policies.
|
||||
|
||||
@seealso: See also L{twisted.protocols.htb} for rate limiting.
|
||||
"""
|
||||
|
||||
|
||||
# system imports
|
||||
import sys
|
||||
from typing import Optional, Type
|
||||
|
||||
from zope.interface import directlyProvides, providedBy
|
||||
|
||||
from twisted.internet import error, interfaces
|
||||
from twisted.internet.interfaces import ILoggingContext
|
||||
|
||||
# twisted imports
|
||||
from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
def _wrappedLogPrefix(wrapper, wrapped):
|
||||
"""
|
||||
Compute a log prefix for a wrapper and the object it wraps.
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
if ILoggingContext.providedBy(wrapped):
|
||||
logPrefix = wrapped.logPrefix()
|
||||
else:
|
||||
logPrefix = wrapped.__class__.__name__
|
||||
return f"{logPrefix} ({wrapper.__class__.__name__})"
|
||||
|
||||
|
||||
class ProtocolWrapper(Protocol):
|
||||
"""
|
||||
Wraps protocol instances and acts as their transport as well.
|
||||
|
||||
@ivar wrappedProtocol: An L{IProtocol<twisted.internet.interfaces.IProtocol>}
|
||||
provider to which L{IProtocol<twisted.internet.interfaces.IProtocol>}
|
||||
method calls onto this L{ProtocolWrapper} will be proxied.
|
||||
|
||||
@ivar factory: The L{WrappingFactory} which created this
|
||||
L{ProtocolWrapper}.
|
||||
"""
|
||||
|
||||
disconnecting = 0
|
||||
|
||||
def __init__(
|
||||
self, factory: "WrappingFactory", wrappedProtocol: interfaces.IProtocol
|
||||
):
|
||||
self.wrappedProtocol = wrappedProtocol
|
||||
self.factory = factory
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Use a customized log prefix mentioning both the wrapped protocol and
|
||||
the current one.
|
||||
"""
|
||||
return _wrappedLogPrefix(self, self.wrappedProtocol)
|
||||
|
||||
def makeConnection(self, transport):
|
||||
"""
|
||||
When a connection is made, register this wrapper with its factory,
|
||||
save the real transport, and connect the wrapped protocol to this
|
||||
L{ProtocolWrapper} to intercept any transport calls it makes.
|
||||
"""
|
||||
directlyProvides(self, providedBy(transport))
|
||||
Protocol.makeConnection(self, transport)
|
||||
self.factory.registerProtocol(self)
|
||||
self.wrappedProtocol.makeConnection(self)
|
||||
|
||||
# Transport relaying
|
||||
|
||||
def write(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
def writeSequence(self, data):
|
||||
self.transport.writeSequence(data)
|
||||
|
||||
def loseConnection(self):
|
||||
self.disconnecting = 1
|
||||
self.transport.loseConnection()
|
||||
|
||||
def getPeer(self):
|
||||
return self.transport.getPeer()
|
||||
|
||||
def getHost(self):
|
||||
return self.transport.getHost()
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.transport.registerProducer(producer, streaming)
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.transport.unregisterProducer()
|
||||
|
||||
def stopConsuming(self):
|
||||
self.transport.stopConsuming()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.transport, name)
|
||||
|
||||
# Protocol relaying
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.wrappedProtocol.dataReceived(data)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.factory.unregisterProtocol(self)
|
||||
self.wrappedProtocol.connectionLost(reason)
|
||||
|
||||
# Breaking reference cycle between self and wrappedProtocol.
|
||||
self.wrappedProtocol = None
|
||||
|
||||
|
||||
class WrappingFactory(ClientFactory):
|
||||
"""
|
||||
Wraps a factory and its protocols, and keeps track of them.
|
||||
"""
|
||||
|
||||
protocol: Type[Protocol] = ProtocolWrapper
|
||||
|
||||
def __init__(self, wrappedFactory):
|
||||
self.wrappedFactory = wrappedFactory
|
||||
self.protocols = {}
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Generate a log prefix mentioning both the wrapped factory and this one.
|
||||
"""
|
||||
return _wrappedLogPrefix(self, self.wrappedFactory)
|
||||
|
||||
def doStart(self):
|
||||
self.wrappedFactory.doStart()
|
||||
ClientFactory.doStart(self)
|
||||
|
||||
def doStop(self):
|
||||
self.wrappedFactory.doStop()
|
||||
ClientFactory.doStop(self)
|
||||
|
||||
def startedConnecting(self, connector):
|
||||
self.wrappedFactory.startedConnecting(connector)
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
self.wrappedFactory.clientConnectionFailed(connector, reason)
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
self.wrappedFactory.clientConnectionLost(connector, reason)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return self.protocol(self, self.wrappedFactory.buildProtocol(addr))
|
||||
|
||||
def registerProtocol(self, p):
|
||||
"""
|
||||
Called by protocol to register itself.
|
||||
"""
|
||||
self.protocols[p] = 1
|
||||
|
||||
def unregisterProtocol(self, p):
|
||||
"""
|
||||
Called by protocols when they go away.
|
||||
"""
|
||||
del self.protocols[p]
|
||||
|
||||
|
||||
class ThrottlingProtocol(ProtocolWrapper):
|
||||
"""
|
||||
Protocol for L{ThrottlingFactory}.
|
||||
"""
|
||||
|
||||
# wrap API for tracking bandwidth
|
||||
|
||||
def write(self, data):
|
||||
self.factory.registerWritten(len(data))
|
||||
ProtocolWrapper.write(self, data)
|
||||
|
||||
def writeSequence(self, seq):
|
||||
self.factory.registerWritten(sum(map(len, seq)))
|
||||
ProtocolWrapper.writeSequence(self, seq)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.factory.registerRead(len(data))
|
||||
ProtocolWrapper.dataReceived(self, data)
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = producer
|
||||
ProtocolWrapper.registerProducer(self, producer, streaming)
|
||||
|
||||
def unregisterProducer(self):
|
||||
del self.producer
|
||||
ProtocolWrapper.unregisterProducer(self)
|
||||
|
||||
def throttleReads(self):
|
||||
self.transport.pauseProducing()
|
||||
|
||||
def unthrottleReads(self):
|
||||
self.transport.resumeProducing()
|
||||
|
||||
def throttleWrites(self):
|
||||
if hasattr(self, "producer"):
|
||||
self.producer.pauseProducing()
|
||||
|
||||
def unthrottleWrites(self):
|
||||
if hasattr(self, "producer"):
|
||||
self.producer.resumeProducing()
|
||||
|
||||
|
||||
class ThrottlingFactory(WrappingFactory):
|
||||
"""
|
||||
Throttles bandwidth and number of connections.
|
||||
|
||||
Write bandwidth will only be throttled if there is a producer
|
||||
registered.
|
||||
"""
|
||||
|
||||
protocol = ThrottlingProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wrappedFactory,
|
||||
maxConnectionCount=sys.maxsize,
|
||||
readLimit=None,
|
||||
writeLimit=None,
|
||||
):
|
||||
WrappingFactory.__init__(self, wrappedFactory)
|
||||
self.connectionCount = 0
|
||||
self.maxConnectionCount = maxConnectionCount
|
||||
self.readLimit = readLimit # max bytes we should read per second
|
||||
self.writeLimit = writeLimit # max bytes we should write per second
|
||||
self.readThisSecond = 0
|
||||
self.writtenThisSecond = 0
|
||||
self.unthrottleReadsID = None
|
||||
self.checkReadBandwidthID = None
|
||||
self.unthrottleWritesID = None
|
||||
self.checkWriteBandwidthID = None
|
||||
|
||||
def callLater(self, period, func):
|
||||
"""
|
||||
Wrapper around
|
||||
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
|
||||
for test purpose.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
|
||||
return reactor.callLater(period, func)
|
||||
|
||||
def registerWritten(self, length):
|
||||
"""
|
||||
Called by protocol to tell us more bytes were written.
|
||||
"""
|
||||
self.writtenThisSecond += length
|
||||
|
||||
def registerRead(self, length):
|
||||
"""
|
||||
Called by protocol to tell us more bytes were read.
|
||||
"""
|
||||
self.readThisSecond += length
|
||||
|
||||
def checkReadBandwidth(self):
|
||||
"""
|
||||
Checks if we've passed bandwidth limits.
|
||||
"""
|
||||
if self.readThisSecond > self.readLimit:
|
||||
self.throttleReads()
|
||||
throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
|
||||
self.unthrottleReadsID = self.callLater(throttleTime, self.unthrottleReads)
|
||||
self.readThisSecond = 0
|
||||
self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)
|
||||
|
||||
def checkWriteBandwidth(self):
|
||||
if self.writtenThisSecond > self.writeLimit:
|
||||
self.throttleWrites()
|
||||
throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
|
||||
self.unthrottleWritesID = self.callLater(
|
||||
throttleTime, self.unthrottleWrites
|
||||
)
|
||||
# reset for next round
|
||||
self.writtenThisSecond = 0
|
||||
self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)
|
||||
|
||||
def throttleReads(self):
|
||||
"""
|
||||
Throttle reads on all protocols.
|
||||
"""
|
||||
log.msg("Throttling reads on %s" % self)
|
||||
for p in self.protocols.keys():
|
||||
p.throttleReads()
|
||||
|
||||
def unthrottleReads(self):
|
||||
"""
|
||||
Stop throttling reads on all protocols.
|
||||
"""
|
||||
self.unthrottleReadsID = None
|
||||
log.msg("Stopped throttling reads on %s" % self)
|
||||
for p in self.protocols.keys():
|
||||
p.unthrottleReads()
|
||||
|
||||
def throttleWrites(self):
|
||||
"""
|
||||
Throttle writes on all protocols.
|
||||
"""
|
||||
log.msg("Throttling writes on %s" % self)
|
||||
for p in self.protocols.keys():
|
||||
p.throttleWrites()
|
||||
|
||||
def unthrottleWrites(self):
|
||||
"""
|
||||
Stop throttling writes on all protocols.
|
||||
"""
|
||||
self.unthrottleWritesID = None
|
||||
log.msg("Stopped throttling writes on %s" % self)
|
||||
for p in self.protocols.keys():
|
||||
p.unthrottleWrites()
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
if self.connectionCount == 0:
|
||||
if self.readLimit is not None:
|
||||
self.checkReadBandwidth()
|
||||
if self.writeLimit is not None:
|
||||
self.checkWriteBandwidth()
|
||||
|
||||
if self.connectionCount < self.maxConnectionCount:
|
||||
self.connectionCount += 1
|
||||
return WrappingFactory.buildProtocol(self, addr)
|
||||
else:
|
||||
log.msg("Max connection count reached!")
|
||||
return None
|
||||
|
||||
def unregisterProtocol(self, p):
|
||||
WrappingFactory.unregisterProtocol(self, p)
|
||||
self.connectionCount -= 1
|
||||
if self.connectionCount == 0:
|
||||
if self.unthrottleReadsID is not None:
|
||||
self.unthrottleReadsID.cancel()
|
||||
if self.checkReadBandwidthID is not None:
|
||||
self.checkReadBandwidthID.cancel()
|
||||
if self.unthrottleWritesID is not None:
|
||||
self.unthrottleWritesID.cancel()
|
||||
if self.checkWriteBandwidthID is not None:
|
||||
self.checkWriteBandwidthID.cancel()
|
||||
|
||||
|
||||
class SpewingProtocol(ProtocolWrapper):
|
||||
def dataReceived(self, data):
|
||||
log.msg("Received: %r" % data)
|
||||
ProtocolWrapper.dataReceived(self, data)
|
||||
|
||||
def write(self, data):
|
||||
log.msg("Sending: %r" % data)
|
||||
ProtocolWrapper.write(self, data)
|
||||
|
||||
|
||||
class SpewingFactory(WrappingFactory):
|
||||
protocol = SpewingProtocol
|
||||
|
||||
|
||||
class LimitConnectionsByPeer(WrappingFactory):
|
||||
maxConnectionsPerPeer = 5
|
||||
|
||||
def startFactory(self):
|
||||
self.peerConnections = {}
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
peerHost = addr[0]
|
||||
connectionCount = self.peerConnections.get(peerHost, 0)
|
||||
if connectionCount >= self.maxConnectionsPerPeer:
|
||||
return None
|
||||
self.peerConnections[peerHost] = connectionCount + 1
|
||||
return WrappingFactory.buildProtocol(self, addr)
|
||||
|
||||
def unregisterProtocol(self, p):
|
||||
peerHost = p.getPeer()[1]
|
||||
self.peerConnections[peerHost] -= 1
|
||||
if self.peerConnections[peerHost] == 0:
|
||||
del self.peerConnections[peerHost]
|
||||
|
||||
|
||||
class LimitTotalConnectionsFactory(ServerFactory):
|
||||
"""
|
||||
Factory that limits the number of simultaneous connections.
|
||||
|
||||
@type connectionCount: C{int}
|
||||
@ivar connectionCount: number of current connections.
|
||||
@type connectionLimit: C{int} or L{None}
|
||||
@cvar connectionLimit: maximum number of connections.
|
||||
@type overflowProtocol: L{Protocol} or L{None}
|
||||
@cvar overflowProtocol: Protocol to use for new connections when
|
||||
connectionLimit is exceeded. If L{None} (the default value), excess
|
||||
connections will be closed immediately.
|
||||
"""
|
||||
|
||||
connectionCount = 0
|
||||
connectionLimit = None
|
||||
overflowProtocol: Optional[Type[Protocol]] = None
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
if self.connectionLimit is None or self.connectionCount < self.connectionLimit:
|
||||
# Build the normal protocol
|
||||
wrappedProtocol = self.protocol()
|
||||
elif self.overflowProtocol is None:
|
||||
# Just drop the connection
|
||||
return None
|
||||
else:
|
||||
# Too many connections, so build the overflow protocol
|
||||
wrappedProtocol = self.overflowProtocol()
|
||||
|
||||
wrappedProtocol.factory = self
|
||||
protocol = ProtocolWrapper(self, wrappedProtocol)
|
||||
self.connectionCount += 1
|
||||
return protocol
|
||||
|
||||
def registerProtocol(self, p):
|
||||
pass
|
||||
|
||||
def unregisterProtocol(self, p):
|
||||
self.connectionCount -= 1
|
||||
|
||||
|
||||
class TimeoutProtocol(ProtocolWrapper):
|
||||
"""
|
||||
Protocol that automatically disconnects when the connection is idle.
|
||||
"""
|
||||
|
||||
def __init__(self, factory, wrappedProtocol, timeoutPeriod):
|
||||
"""
|
||||
Constructor.
|
||||
|
||||
@param factory: An L{TimeoutFactory}.
|
||||
@param wrappedProtocol: A L{Protocol} to wrapp.
|
||||
@param timeoutPeriod: Number of seconds to wait for activity before
|
||||
timing out.
|
||||
"""
|
||||
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
|
||||
self.timeoutCall = None
|
||||
self.timeoutPeriod = None
|
||||
self.setTimeout(timeoutPeriod)
|
||||
|
||||
def setTimeout(self, timeoutPeriod=None):
|
||||
"""
|
||||
Set a timeout.
|
||||
|
||||
This will cancel any existing timeouts.
|
||||
|
||||
@param timeoutPeriod: If not L{None}, change the timeout period.
|
||||
Otherwise, use the existing value.
|
||||
"""
|
||||
self.cancelTimeout()
|
||||
self.timeoutPeriod = timeoutPeriod
|
||||
if timeoutPeriod is not None:
|
||||
self.timeoutCall = self.factory.callLater(
|
||||
self.timeoutPeriod, self.timeoutFunc
|
||||
)
|
||||
|
||||
def cancelTimeout(self):
|
||||
"""
|
||||
Cancel the timeout.
|
||||
|
||||
If the timeout was already cancelled, this does nothing.
|
||||
"""
|
||||
self.timeoutPeriod = None
|
||||
if self.timeoutCall:
|
||||
try:
|
||||
self.timeoutCall.cancel()
|
||||
except (error.AlreadyCalled, error.AlreadyCancelled):
|
||||
pass
|
||||
self.timeoutCall = None
|
||||
|
||||
def resetTimeout(self):
|
||||
"""
|
||||
Reset the timeout, usually because some activity just happened.
|
||||
"""
|
||||
if self.timeoutCall:
|
||||
self.timeoutCall.reset(self.timeoutPeriod)
|
||||
|
||||
def write(self, data):
|
||||
self.resetTimeout()
|
||||
ProtocolWrapper.write(self, data)
|
||||
|
||||
def writeSequence(self, seq):
|
||||
self.resetTimeout()
|
||||
ProtocolWrapper.writeSequence(self, seq)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.resetTimeout()
|
||||
ProtocolWrapper.dataReceived(self, data)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.cancelTimeout()
|
||||
ProtocolWrapper.connectionLost(self, reason)
|
||||
|
||||
def timeoutFunc(self):
|
||||
"""
|
||||
This method is called when the timeout is triggered.
|
||||
|
||||
By default it calls I{loseConnection}. Override this if you want
|
||||
something else to happen.
|
||||
"""
|
||||
self.loseConnection()
|
||||
|
||||
|
||||
class TimeoutFactory(WrappingFactory):
|
||||
"""
|
||||
Factory for TimeoutWrapper.
|
||||
"""
|
||||
|
||||
protocol = TimeoutProtocol
|
||||
|
||||
def __init__(self, wrappedFactory, timeoutPeriod=30 * 60):
|
||||
self.timeoutPeriod = timeoutPeriod
|
||||
WrappingFactory.__init__(self, wrappedFactory)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return self.protocol(
|
||||
self,
|
||||
self.wrappedFactory.buildProtocol(addr),
|
||||
timeoutPeriod=self.timeoutPeriod,
|
||||
)
|
||||
|
||||
def callLater(self, period, func):
|
||||
"""
|
||||
Wrapper around
|
||||
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
|
||||
for test purpose.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
|
||||
return reactor.callLater(period, func)
|
||||
|
||||
|
||||
class TrafficLoggingProtocol(ProtocolWrapper):
|
||||
def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None, number=0):
|
||||
"""
|
||||
@param factory: factory which created this protocol.
|
||||
@type factory: L{protocol.Factory}.
|
||||
@param wrappedProtocol: the underlying protocol.
|
||||
@type wrappedProtocol: C{protocol.Protocol}.
|
||||
@param logfile: file opened for writing used to write log messages.
|
||||
@type logfile: C{file}
|
||||
@param lengthLimit: maximum size of the datareceived logged.
|
||||
@type lengthLimit: C{int}
|
||||
@param number: identifier of the connection.
|
||||
@type number: C{int}.
|
||||
"""
|
||||
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
|
||||
self.logfile = logfile
|
||||
self.lengthLimit = lengthLimit
|
||||
self._number = number
|
||||
|
||||
def _log(self, line):
|
||||
self.logfile.write(line + "\n")
|
||||
self.logfile.flush()
|
||||
|
||||
def _mungeData(self, data):
|
||||
if self.lengthLimit and len(data) > self.lengthLimit:
|
||||
data = data[: self.lengthLimit - 12] + "<... elided>"
|
||||
return data
|
||||
|
||||
# IProtocol
|
||||
def connectionMade(self):
|
||||
self._log("*")
|
||||
return ProtocolWrapper.connectionMade(self)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self._log("C %d: %r" % (self._number, self._mungeData(data)))
|
||||
return ProtocolWrapper.dataReceived(self, data)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self._log("C %d: %r" % (self._number, reason))
|
||||
return ProtocolWrapper.connectionLost(self, reason)
|
||||
|
||||
# ITransport
|
||||
def write(self, data):
|
||||
self._log("S %d: %r" % (self._number, self._mungeData(data)))
|
||||
return ProtocolWrapper.write(self, data)
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self._log("SV %d: %r" % (self._number, [self._mungeData(d) for d in iovec]))
|
||||
return ProtocolWrapper.writeSequence(self, iovec)
|
||||
|
||||
def loseConnection(self):
|
||||
self._log("S %d: *" % (self._number,))
|
||||
return ProtocolWrapper.loseConnection(self)
|
||||
|
||||
|
||||
class TrafficLoggingFactory(WrappingFactory):
|
||||
protocol = TrafficLoggingProtocol
|
||||
|
||||
_counter = 0
|
||||
|
||||
def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
|
||||
self.logfilePrefix = logfilePrefix
|
||||
self.lengthLimit = lengthLimit
|
||||
WrappingFactory.__init__(self, wrappedFactory)
|
||||
|
||||
def open(self, name):
|
||||
return open(name, "w")
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
self._counter += 1
|
||||
logfile = self.open(self.logfilePrefix + "-" + str(self._counter))
|
||||
return self.protocol(
|
||||
self,
|
||||
self.wrappedFactory.buildProtocol(addr),
|
||||
logfile,
|
||||
self.lengthLimit,
|
||||
self._counter,
|
||||
)
|
||||
|
||||
def resetCounter(self):
|
||||
"""
|
||||
Reset the value of the counter used to identify connections.
|
||||
"""
|
||||
self._counter = 0
|
||||
|
||||
|
||||
class TimeoutMixin:
|
||||
"""
|
||||
Mixin for protocols which wish to timeout connections.
|
||||
|
||||
Protocols that mix this in have a single timeout, set using L{setTimeout}.
|
||||
When the timeout is hit, L{timeoutConnection} is called, which, by
|
||||
default, closes the connection.
|
||||
|
||||
@cvar timeOut: The number of seconds after which to timeout the connection.
|
||||
"""
|
||||
|
||||
timeOut: Optional[int] = None
|
||||
|
||||
__timeoutCall = None
|
||||
|
||||
def callLater(self, period, func):
|
||||
"""
|
||||
Wrapper around
|
||||
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
|
||||
for test purpose.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
|
||||
return reactor.callLater(period, func)
|
||||
|
||||
def resetTimeout(self):
|
||||
"""
|
||||
Reset the timeout count down.
|
||||
|
||||
If the connection has already timed out, then do nothing. If the
|
||||
timeout has been cancelled (probably using C{setTimeout(None)}), also
|
||||
do nothing.
|
||||
|
||||
It's often a good idea to call this when the protocol has received
|
||||
some meaningful input from the other end of the connection. "I've got
|
||||
some data, they're still there, reset the timeout".
|
||||
"""
|
||||
if self.__timeoutCall is not None and self.timeOut is not None:
|
||||
self.__timeoutCall.reset(self.timeOut)
|
||||
|
||||
def setTimeout(self, period):
|
||||
"""
|
||||
Change the timeout period
|
||||
|
||||
@type period: C{int} or L{None}
|
||||
@param period: The period, in seconds, to change the timeout to, or
|
||||
L{None} to disable the timeout.
|
||||
"""
|
||||
prev = self.timeOut
|
||||
self.timeOut = period
|
||||
|
||||
if self.__timeoutCall is not None:
|
||||
if period is None:
|
||||
try:
|
||||
self.__timeoutCall.cancel()
|
||||
except (error.AlreadyCancelled, error.AlreadyCalled):
|
||||
# Do nothing if the call was already consumed.
|
||||
pass
|
||||
self.__timeoutCall = None
|
||||
else:
|
||||
self.__timeoutCall.reset(period)
|
||||
elif period is not None:
|
||||
self.__timeoutCall = self.callLater(period, self.__timedOut)
|
||||
|
||||
return prev
|
||||
|
||||
def __timedOut(self):
|
||||
self.__timeoutCall = None
|
||||
self.timeoutConnection()
|
||||
|
||||
def timeoutConnection(self):
|
||||
"""
|
||||
Called when the connection times out.
|
||||
|
||||
Override to define behavior other than dropping the connection.
|
||||
"""
|
||||
self.transport.loseConnection()
|
||||
@@ -0,0 +1,90 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
A simple port forwarder.
|
||||
"""
|
||||
|
||||
# Twisted imports
|
||||
from twisted.internet import protocol
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
class Proxy(protocol.Protocol):
|
||||
noisy = True
|
||||
|
||||
peer = None
|
||||
|
||||
def setPeer(self, peer):
|
||||
self.peer = peer
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.peer is not None:
|
||||
self.peer.transport.loseConnection()
|
||||
self.peer = None
|
||||
elif self.noisy:
|
||||
log.msg(f"Unable to connect to peer: {reason}")
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.peer.transport.write(data)
|
||||
|
||||
|
||||
class ProxyClient(Proxy):
|
||||
def connectionMade(self):
|
||||
self.peer.setPeer(self)
|
||||
|
||||
# Wire this and the peer transport together to enable
|
||||
# flow control (this stops connections from filling
|
||||
# this proxy memory when one side produces data at a
|
||||
# higher rate than the other can consume).
|
||||
self.transport.registerProducer(self.peer.transport, True)
|
||||
self.peer.transport.registerProducer(self.transport, True)
|
||||
|
||||
# We're connected, everybody can read to their hearts content.
|
||||
self.peer.transport.resumeProducing()
|
||||
|
||||
|
||||
class ProxyClientFactory(protocol.ClientFactory):
|
||||
protocol = ProxyClient
|
||||
|
||||
def setServer(self, server):
|
||||
self.server = server
|
||||
|
||||
def buildProtocol(self, *args, **kw):
|
||||
prot = protocol.ClientFactory.buildProtocol(self, *args, **kw)
|
||||
prot.setPeer(self.server)
|
||||
return prot
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
self.server.transport.loseConnection()
|
||||
|
||||
|
||||
class ProxyServer(Proxy):
|
||||
clientProtocolFactory = ProxyClientFactory
|
||||
reactor = None
|
||||
|
||||
def connectionMade(self):
|
||||
# Don't read anything from the connecting client until we have
|
||||
# somewhere to send it to.
|
||||
self.transport.pauseProducing()
|
||||
|
||||
client = self.clientProtocolFactory()
|
||||
client.setServer(self)
|
||||
|
||||
if self.reactor is None:
|
||||
from twisted.internet import reactor
|
||||
|
||||
self.reactor = reactor
|
||||
self.reactor.connectTCP(self.factory.host, self.factory.port, client)
|
||||
|
||||
|
||||
class ProxyFactory(protocol.Factory):
|
||||
"""
|
||||
Factory for port forwarder.
|
||||
"""
|
||||
|
||||
protocol = ProxyServer
|
||||
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
137
.venv/lib/python3.12/site-packages/twisted/protocols/postfix.py
Normal file
137
.venv/lib/python3.12/site-packages/twisted/protocols/postfix.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# -*- test-case-name: twisted.test.test_postfix -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Postfix mail transport agent related protocols.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections import UserDict
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from urllib.parse import quote as _quote, unquote as _unquote
|
||||
|
||||
from twisted.internet import defer, protocol
|
||||
from twisted.protocols import basic, policies
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
# urllib's quote functions just happen to match
|
||||
# the postfix semantics.
|
||||
def quote(s):
|
||||
quoted = _quote(s)
|
||||
if isinstance(quoted, str):
|
||||
quoted = quoted.encode("ascii")
|
||||
return quoted
|
||||
|
||||
|
||||
def unquote(s):
|
||||
if isinstance(s, bytes):
|
||||
s = s.decode("ascii")
|
||||
quoted = _unquote(s)
|
||||
return quoted.encode("ascii")
|
||||
|
||||
|
||||
class PostfixTCPMapServer(basic.LineReceiver, policies.TimeoutMixin):
|
||||
"""
|
||||
Postfix mail transport agent TCP map protocol implementation.
|
||||
|
||||
Receive requests for data matching given key via lineReceived,
|
||||
asks it's factory for the data with self.factory.get(key), and
|
||||
returns the data to the requester. None means no entry found.
|
||||
|
||||
You can use postfix's postmap to test the map service::
|
||||
|
||||
/usr/sbin/postmap -q KEY tcp:localhost:4242
|
||||
|
||||
"""
|
||||
|
||||
timeout = 600
|
||||
delimiter = b"\n"
|
||||
|
||||
def connectionMade(self):
|
||||
self.setTimeout(self.timeout)
|
||||
|
||||
def sendCode(self, code, message=b""):
|
||||
"""
|
||||
Send an SMTP-like code with a message.
|
||||
"""
|
||||
self.sendLine(str(code).encode("ascii") + b" " + message)
|
||||
|
||||
def lineReceived(self, line):
|
||||
self.resetTimeout()
|
||||
try:
|
||||
request, params = line.split(None, 1)
|
||||
except ValueError:
|
||||
request = line
|
||||
params = None
|
||||
try:
|
||||
f = getattr(self, "do_" + request.decode("ascii"))
|
||||
except AttributeError:
|
||||
self.sendCode(400, b"unknown command")
|
||||
else:
|
||||
try:
|
||||
f(params)
|
||||
except BaseException:
|
||||
excInfo = str(sys.exc_info()[1]).encode("ascii")
|
||||
self.sendCode(400, b"Command " + request + b" failed: " + excInfo)
|
||||
|
||||
def do_get(self, key):
|
||||
if key is None:
|
||||
self.sendCode(400, b"Command 'get' takes 1 parameters.")
|
||||
else:
|
||||
d = defer.maybeDeferred(self.factory.get, key)
|
||||
d.addCallbacks(self._cbGot, self._cbNot)
|
||||
d.addErrback(log.err)
|
||||
|
||||
def _cbNot(self, fail):
|
||||
msg = fail.getErrorMessage().encode("ascii")
|
||||
self.sendCode(400, msg)
|
||||
|
||||
def _cbGot(self, value):
|
||||
if value is None:
|
||||
self.sendCode(500)
|
||||
else:
|
||||
self.sendCode(200, quote(value))
|
||||
|
||||
def do_put(self, keyAndValue):
|
||||
if keyAndValue is None:
|
||||
self.sendCode(400, b"Command 'put' takes 2 parameters.")
|
||||
else:
|
||||
try:
|
||||
key, value = keyAndValue.split(None, 1)
|
||||
except ValueError:
|
||||
self.sendCode(400, b"Command 'put' takes 2 parameters.")
|
||||
else:
|
||||
self.sendCode(500, b"put is not implemented yet.")
|
||||
|
||||
|
||||
if TYPE_CHECKING or sys.version_info >= (3, 9):
|
||||
_PostfixTCPMapDict = UserDict[bytes, Union[str, bytes]]
|
||||
else:
|
||||
_PostfixTCPMapDict = UserDict
|
||||
|
||||
|
||||
class PostfixTCPMapDictServerFactory(_PostfixTCPMapDict, protocol.ServerFactory):
|
||||
"""
|
||||
An in-memory dictionary factory for PostfixTCPMapServer.
|
||||
"""
|
||||
|
||||
protocol = PostfixTCPMapServer
|
||||
|
||||
|
||||
class PostfixTCPMapDeferringDictServerFactory(protocol.ServerFactory):
|
||||
"""
|
||||
An in-memory dictionary factory for PostfixTCPMapServer.
|
||||
"""
|
||||
|
||||
protocol = PostfixTCPMapServer
|
||||
|
||||
def __init__(self, data=None):
|
||||
self.data = {}
|
||||
if data is not None:
|
||||
self.data.update(data)
|
||||
|
||||
def get(self, key):
|
||||
return defer.succeed(self.data.get(key))
|
||||
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Chop up shoutcast stream into MP3s and metadata, if available.
|
||||
"""
|
||||
|
||||
from twisted import copyright
|
||||
from twisted.web import http
|
||||
|
||||
|
||||
class ShoutcastClient(http.HTTPClient):
|
||||
"""
|
||||
Shoutcast HTTP stream.
|
||||
|
||||
Modes can be 'length', 'meta' and 'mp3'.
|
||||
|
||||
See U{http://www.smackfu.com/stuff/programming/shoutcast.html}
|
||||
for details on the protocol.
|
||||
"""
|
||||
|
||||
userAgent = "Twisted Shoutcast client " + copyright.version
|
||||
|
||||
def __init__(self, path="/"):
|
||||
self.path = path
|
||||
self.got_metadata = False
|
||||
self.metaint = None
|
||||
self.metamode = "mp3"
|
||||
self.databuffer = ""
|
||||
|
||||
def connectionMade(self):
|
||||
self.sendCommand("GET", self.path)
|
||||
self.sendHeader("User-Agent", self.userAgent)
|
||||
self.sendHeader("Icy-MetaData", "1")
|
||||
self.endHeaders()
|
||||
|
||||
def lineReceived(self, line):
|
||||
# fix shoutcast crappiness
|
||||
if not self.firstLine and line:
|
||||
if len(line.split(": ", 1)) == 1:
|
||||
line = line.replace(":", ": ", 1)
|
||||
http.HTTPClient.lineReceived(self, line)
|
||||
|
||||
def handleHeader(self, key, value):
|
||||
if key.lower() == "icy-metaint":
|
||||
self.metaint = int(value)
|
||||
self.got_metadata = True
|
||||
|
||||
def handleEndHeaders(self):
|
||||
# Lets check if we got metadata, and set the
|
||||
# appropriate handleResponsePart method.
|
||||
if self.got_metadata:
|
||||
# if we have metadata, then it has to be parsed out of the data stream
|
||||
self.handleResponsePart = self.handleResponsePart_with_metadata
|
||||
else:
|
||||
# otherwise, all the data is MP3 data
|
||||
self.handleResponsePart = self.gotMP3Data
|
||||
|
||||
def handleResponsePart_with_metadata(self, data):
|
||||
self.databuffer += data
|
||||
while self.databuffer:
|
||||
stop = getattr(self, "handle_%s" % self.metamode)()
|
||||
if stop:
|
||||
return
|
||||
|
||||
def handle_length(self):
|
||||
self.remaining = ord(self.databuffer[0]) * 16
|
||||
self.databuffer = self.databuffer[1:]
|
||||
self.metamode = "meta"
|
||||
|
||||
def handle_mp3(self):
|
||||
if len(self.databuffer) > self.metaint:
|
||||
self.gotMP3Data(self.databuffer[: self.metaint])
|
||||
self.databuffer = self.databuffer[self.metaint :]
|
||||
self.metamode = "length"
|
||||
else:
|
||||
return 1
|
||||
|
||||
def handle_meta(self):
|
||||
if len(self.databuffer) >= self.remaining:
|
||||
if self.remaining:
|
||||
data = self.databuffer[: self.remaining]
|
||||
self.gotMetaData(self.parseMetadata(data))
|
||||
self.databuffer = self.databuffer[self.remaining :]
|
||||
self.metamode = "mp3"
|
||||
else:
|
||||
return 1
|
||||
|
||||
def parseMetadata(self, data):
|
||||
meta = []
|
||||
for chunk in data.split(";"):
|
||||
chunk = chunk.strip().replace("\x00", "")
|
||||
if not chunk:
|
||||
continue
|
||||
key, value = chunk.split("=", 1)
|
||||
if value.startswith("'") and value.endswith("'"):
|
||||
value = value[1:-1]
|
||||
meta.append((key, value))
|
||||
return meta
|
||||
|
||||
def gotMetaData(self, metadata):
|
||||
"""Called with a list of (key, value) pairs of metadata,
|
||||
if metadata is available on the server.
|
||||
|
||||
Will only be called on non-empty metadata.
|
||||
"""
|
||||
raise NotImplementedError("implement in subclass")
|
||||
|
||||
def gotMP3Data(self, data):
|
||||
"""Called with chunk of MP3 data."""
|
||||
raise NotImplementedError("implement in subclass")
|
||||
1251
.venv/lib/python3.12/site-packages/twisted/protocols/sip.py
Normal file
1251
.venv/lib/python3.12/site-packages/twisted/protocols/sip.py
Normal file
File diff suppressed because it is too large
Load Diff
249
.venv/lib/python3.12/site-packages/twisted/protocols/socks.py
Normal file
249
.venv/lib/python3.12/site-packages/twisted/protocols/socks.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# -*- test-case-name: twisted.test.test_socks -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of the SOCKSv4 protocol.
|
||||
"""
|
||||
|
||||
import socket
|
||||
import string
|
||||
|
||||
# python imports
|
||||
import struct
|
||||
import time
|
||||
|
||||
# twisted imports
|
||||
from twisted.internet import defer, protocol, reactor
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
class SOCKSv4Outgoing(protocol.Protocol):
|
||||
def __init__(self, socks):
|
||||
self.socks = socks
|
||||
|
||||
def connectionMade(self):
|
||||
peer = self.transport.getPeer()
|
||||
self.socks.makeReply(90, 0, port=peer.port, ip=peer.host)
|
||||
self.socks.otherConn = self
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.socks.transport.loseConnection()
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.socks.write(data)
|
||||
|
||||
def write(self, data):
|
||||
self.socks.log(self, data)
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
class SOCKSv4Incoming(protocol.Protocol):
|
||||
def __init__(self, socks):
|
||||
self.socks = socks
|
||||
self.socks.otherConn = self
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.socks.transport.loseConnection()
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.socks.write(data)
|
||||
|
||||
def write(self, data):
|
||||
self.socks.log(self, data)
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
class SOCKSv4(protocol.Protocol):
|
||||
"""
|
||||
An implementation of the SOCKSv4 protocol.
|
||||
|
||||
@type logging: L{str} or L{None}
|
||||
@ivar logging: If not L{None}, the name of the logfile to which connection
|
||||
information will be written.
|
||||
|
||||
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
|
||||
@ivar reactor: The reactor used to create connections.
|
||||
|
||||
@type buf: L{str}
|
||||
@ivar buf: Part of a SOCKSv4 connection request.
|
||||
|
||||
@type otherConn: C{SOCKSv4Incoming}, C{SOCKSv4Outgoing} or L{None}
|
||||
@ivar otherConn: Until the connection has been established, C{otherConn} is
|
||||
L{None}. After that, it is the proxy-to-destination protocol instance
|
||||
along which the client's connection is being forwarded.
|
||||
"""
|
||||
|
||||
def __init__(self, logging=None, reactor=reactor):
|
||||
self.logging = logging
|
||||
self.reactor = reactor
|
||||
|
||||
def connectionMade(self):
|
||||
self.buf = b""
|
||||
self.otherConn = None
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Called whenever data is received.
|
||||
|
||||
@type data: L{bytes}
|
||||
@param data: Part or all of a SOCKSv4 packet.
|
||||
"""
|
||||
if self.otherConn:
|
||||
self.otherConn.write(data)
|
||||
return
|
||||
self.buf = self.buf + data
|
||||
completeBuffer = self.buf
|
||||
if b"\000" in self.buf[8:]:
|
||||
head, self.buf = self.buf[:8], self.buf[8:]
|
||||
version, code, port = struct.unpack("!BBH", head[:4])
|
||||
user, self.buf = self.buf.split(b"\000", 1)
|
||||
if head[4:7] == b"\000\000\000" and head[7:8] != b"\000":
|
||||
# An IP address of the form 0.0.0.X, where X is non-zero,
|
||||
# signifies that this is a SOCKSv4a packet.
|
||||
# If the complete packet hasn't been received, restore the
|
||||
# buffer and wait for it.
|
||||
if b"\000" not in self.buf:
|
||||
self.buf = completeBuffer
|
||||
return
|
||||
server, self.buf = self.buf.split(b"\000", 1)
|
||||
d = self.reactor.resolve(server)
|
||||
d.addCallback(self._dataReceived2, user, version, code, port)
|
||||
d.addErrback(lambda result, self=self: self.makeReply(91))
|
||||
return
|
||||
else:
|
||||
server = socket.inet_ntoa(head[4:8])
|
||||
|
||||
self._dataReceived2(server, user, version, code, port)
|
||||
|
||||
def _dataReceived2(self, server, user, version, code, port):
|
||||
"""
|
||||
The second half of the SOCKS connection setup. For a SOCKSv4 packet this
|
||||
is after the server address has been extracted from the header. For a
|
||||
SOCKSv4a packet this is after the host name has been resolved.
|
||||
|
||||
@type server: L{str}
|
||||
@param server: The IP address of the destination, represented as a
|
||||
dotted quad.
|
||||
|
||||
@type user: L{str}
|
||||
@param user: The username associated with the connection.
|
||||
|
||||
@type version: L{int}
|
||||
@param version: The SOCKS protocol version number.
|
||||
|
||||
@type code: L{int}
|
||||
@param code: The command code. 1 means establish a TCP/IP stream
|
||||
connection, and 2 means establish a TCP/IP port binding.
|
||||
|
||||
@type port: L{int}
|
||||
@param port: The port number associated with the connection.
|
||||
"""
|
||||
assert version == 4, "Bad version code: %s" % version
|
||||
if not self.authorize(code, server, port, user):
|
||||
self.makeReply(91)
|
||||
return
|
||||
if code == 1: # CONNECT
|
||||
d = self.connectClass(server, port, SOCKSv4Outgoing, self)
|
||||
d.addErrback(lambda result, self=self: self.makeReply(91))
|
||||
elif code == 2: # BIND
|
||||
d = self.listenClass(0, SOCKSv4IncomingFactory, self, server)
|
||||
d.addCallback(lambda x, self=self: self.makeReply(90, 0, x[1], x[0]))
|
||||
else:
|
||||
raise RuntimeError(f"Bad Connect Code: {code}")
|
||||
assert self.buf == b"", "hmm, still stuff in buffer... %s" % repr(self.buf)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.otherConn:
|
||||
self.otherConn.transport.loseConnection()
|
||||
|
||||
def authorize(self, code, server, port, user):
|
||||
log.msg(
|
||||
"code %s connection to %s:%s (user %s) authorized"
|
||||
% (code, server, port, user)
|
||||
)
|
||||
return 1
|
||||
|
||||
def connectClass(self, host, port, klass, *args):
|
||||
return protocol.ClientCreator(reactor, klass, *args).connectTCP(host, port)
|
||||
|
||||
def listenClass(self, port, klass, *args):
|
||||
serv = reactor.listenTCP(port, klass(*args))
|
||||
return defer.succeed(serv.getHost()[1:])
|
||||
|
||||
def makeReply(self, reply, version=0, port=0, ip="0.0.0.0"):
|
||||
self.transport.write(
|
||||
struct.pack("!BBH", version, reply, port) + socket.inet_aton(ip)
|
||||
)
|
||||
if reply != 90:
|
||||
self.transport.loseConnection()
|
||||
|
||||
def write(self, data):
|
||||
self.log(self, data)
|
||||
self.transport.write(data)
|
||||
|
||||
def log(self, proto, data):
|
||||
if not self.logging:
|
||||
return
|
||||
peer = self.transport.getPeer()
|
||||
their_peer = self.otherConn.transport.getPeer()
|
||||
f = open(self.logging, "a")
|
||||
f.write(
|
||||
"%s\t%s:%d %s %s:%d\n"
|
||||
% (
|
||||
time.ctime(),
|
||||
peer.host,
|
||||
peer.port,
|
||||
((proto == self and "<") or ">"),
|
||||
their_peer.host,
|
||||
their_peer.port,
|
||||
)
|
||||
)
|
||||
while data:
|
||||
p, data = data[:16], data[16:]
|
||||
f.write(string.join(map(lambda x: "%02X" % ord(x), p), " ") + " ")
|
||||
f.write((16 - len(p)) * 3 * " ")
|
||||
for c in p:
|
||||
if len(repr(c)) > 3:
|
||||
f.write(".")
|
||||
else:
|
||||
f.write(c)
|
||||
f.write("\n")
|
||||
f.write("\n")
|
||||
f.close()
|
||||
|
||||
|
||||
class SOCKSv4Factory(protocol.Factory):
|
||||
"""
|
||||
A factory for a SOCKSv4 proxy.
|
||||
|
||||
Constructor accepts one argument, a log file name.
|
||||
"""
|
||||
|
||||
def __init__(self, log):
|
||||
self.logging = log
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return SOCKSv4(self.logging, reactor)
|
||||
|
||||
|
||||
class SOCKSv4IncomingFactory(protocol.Factory):
|
||||
"""
|
||||
A utility class for building protocols for incoming connections.
|
||||
"""
|
||||
|
||||
def __init__(self, socks, ip):
|
||||
self.socks = socks
|
||||
self.ip = ip
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
if addr[0] == self.ip:
|
||||
self.ip = ""
|
||||
self.socks.makeReply(90, 0)
|
||||
return SOCKSv4Incoming(self.socks)
|
||||
elif self.ip == "":
|
||||
return None
|
||||
else:
|
||||
self.socks.makeReply(91, 0)
|
||||
self.ip = ""
|
||||
return None
|
||||
@@ -0,0 +1,52 @@
|
||||
# -*- test-case-name: twisted.test.test_stateful -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from twisted.internet import protocol
|
||||
|
||||
|
||||
class StatefulProtocol(protocol.Protocol):
|
||||
"""A Protocol that stores state for you.
|
||||
|
||||
state is a pair (function, num_bytes). When num_bytes bytes of data arrives
|
||||
from the network, function is called. It is expected to return the next
|
||||
state or None to keep same state. Initial state is returned by
|
||||
getInitialState (override it).
|
||||
"""
|
||||
|
||||
_sful_data = None, None, 0
|
||||
|
||||
def makeConnection(self, transport):
|
||||
protocol.Protocol.makeConnection(self, transport)
|
||||
self._sful_data = self.getInitialState(), BytesIO(), 0
|
||||
|
||||
def getInitialState(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def dataReceived(self, data):
|
||||
state, buffer, offset = self._sful_data
|
||||
buffer.seek(0, 2)
|
||||
buffer.write(data)
|
||||
blen = buffer.tell() # how many bytes total is in the buffer
|
||||
buffer.seek(offset)
|
||||
while blen - offset >= state[1]:
|
||||
d = buffer.read(state[1])
|
||||
offset += state[1]
|
||||
next = state[0](d)
|
||||
if (
|
||||
self.transport.disconnecting
|
||||
): # XXX: argh stupid hack borrowed right from LineReceiver
|
||||
return # dataReceived won't be called again, so who cares about consistent state
|
||||
if next:
|
||||
state = next
|
||||
if offset != 0:
|
||||
b = buffer.read()
|
||||
buffer.seek(0)
|
||||
buffer.truncate()
|
||||
buffer.write(b)
|
||||
offset = 0
|
||||
self._sful_data = state, buffer, offset
|
||||
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Unit tests for L{twisted.protocols}.
|
||||
"""
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
936
.venv/lib/python3.12/site-packages/twisted/protocols/tls.py
Normal file
936
.venv/lib/python3.12/site-packages/twisted/protocols/tls.py
Normal file
@@ -0,0 +1,936 @@
|
||||
# -*- test-case-name: twisted.protocols.test.test_tls,twisted.internet.test.test_tls,twisted.test.test_sslverify -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of a TLS transport (L{ISSLTransport}) as an
|
||||
L{IProtocol<twisted.internet.interfaces.IProtocol>} layered on top of any
|
||||
L{ITransport<twisted.internet.interfaces.ITransport>} implementation, based on
|
||||
U{OpenSSL<http://www.openssl.org>}'s memory BIO features.
|
||||
|
||||
L{TLSMemoryBIOFactory} is a L{WrappingFactory} which wraps protocols created by
|
||||
the factory it wraps with L{TLSMemoryBIOProtocol}. L{TLSMemoryBIOProtocol}
|
||||
intercedes between the underlying transport and the wrapped protocol to
|
||||
implement SSL and TLS. Typical usage of this module looks like this::
|
||||
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.internet.protocol import ServerFactory
|
||||
from twisted.internet.ssl import PrivateCertificate
|
||||
from twisted.internet import reactor
|
||||
|
||||
from someapplication import ApplicationProtocol
|
||||
|
||||
serverFactory = ServerFactory()
|
||||
serverFactory.protocol = ApplicationProtocol
|
||||
certificate = PrivateCertificate.loadPEM(certPEMData)
|
||||
contextFactory = certificate.options()
|
||||
tlsFactory = TLSMemoryBIOFactory(contextFactory, False, serverFactory)
|
||||
reactor.listenTCP(12345, tlsFactory)
|
||||
reactor.run()
|
||||
|
||||
This API offers somewhat more flexibility than
|
||||
L{twisted.internet.interfaces.IReactorSSL}; for example, a
|
||||
L{TLSMemoryBIOProtocol} instance can use another instance of
|
||||
L{TLSMemoryBIOProtocol} as its transport, yielding TLS over TLS - useful to
|
||||
implement onion routing. It can also be used to run TLS over unusual
|
||||
transports, such as UNIX sockets and stdio.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, Iterable, Optional, cast
|
||||
|
||||
from zope.interface import directlyProvides, implementer, providedBy
|
||||
|
||||
from OpenSSL.SSL import Connection, Error, SysCallError, WantReadError, ZeroReturnError
|
||||
|
||||
from twisted.internet._producer_helpers import _PullToPush
|
||||
from twisted.internet._sslverify import _setAcceptableProtocols
|
||||
from twisted.internet.interfaces import (
|
||||
IDelayedCall,
|
||||
IHandshakeListener,
|
||||
ILoggingContext,
|
||||
INegotiated,
|
||||
IOpenSSLClientConnectionCreator,
|
||||
IOpenSSLServerConnectionCreator,
|
||||
IProtocol,
|
||||
IProtocolNegotiationFactory,
|
||||
IPushProducer,
|
||||
IReactorTime,
|
||||
ISystemHandle,
|
||||
ITransport,
|
||||
)
|
||||
from twisted.internet.main import CONNECTION_LOST
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.protocols.policies import ProtocolWrapper, WrappingFactory
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
|
||||
@implementer(IPushProducer)
|
||||
class _ProducerMembrane:
|
||||
"""
|
||||
Stand-in for producer registered with a L{TLSMemoryBIOProtocol} transport.
|
||||
|
||||
Ensures that producer pause/resume events from the undelying transport are
|
||||
coordinated with pause/resume events from the TLS layer.
|
||||
|
||||
@ivar _producer: The application-layer producer.
|
||||
"""
|
||||
|
||||
_producerPaused = False
|
||||
|
||||
def __init__(self, producer):
|
||||
self._producer = producer
|
||||
|
||||
def pauseProducing(self):
|
||||
"""
|
||||
C{pauseProducing} the underlying producer, if it's not paused.
|
||||
"""
|
||||
if self._producerPaused:
|
||||
return
|
||||
self._producerPaused = True
|
||||
self._producer.pauseProducing()
|
||||
|
||||
def resumeProducing(self):
|
||||
"""
|
||||
C{resumeProducing} the underlying producer, if it's paused.
|
||||
"""
|
||||
if not self._producerPaused:
|
||||
return
|
||||
self._producerPaused = False
|
||||
self._producer.resumeProducing()
|
||||
|
||||
def stopProducing(self):
|
||||
"""
|
||||
C{stopProducing} the underlying producer.
|
||||
|
||||
There is only a single source for this event, so it's simply passed
|
||||
on.
|
||||
"""
|
||||
self._producer.stopProducing()
|
||||
|
||||
|
||||
def _representsEOF(exceptionObject: Error) -> bool:
|
||||
"""
|
||||
Does the given OpenSSL.SSL.Error represent an end-of-file?
|
||||
"""
|
||||
reasonString: str
|
||||
if isinstance(exceptionObject, SysCallError):
|
||||
_, reasonString = exceptionObject.args
|
||||
else:
|
||||
errorQueue = exceptionObject.args[0]
|
||||
_, _, reasonString = errorQueue[-1]
|
||||
return reasonString.casefold().startswith("unexpected eof")
|
||||
|
||||
|
||||
@implementer(ISystemHandle, INegotiated, ITransport)
|
||||
class TLSMemoryBIOProtocol(ProtocolWrapper):
|
||||
"""
|
||||
L{TLSMemoryBIOProtocol} is a protocol wrapper which uses OpenSSL via a
|
||||
memory BIO to encrypt bytes written to it before sending them on to the
|
||||
underlying transport and decrypts bytes received from the underlying
|
||||
transport before delivering them to the wrapped protocol.
|
||||
|
||||
In addition to producer events from the underlying transport, the need to
|
||||
wait for reads before a write can proceed means the L{TLSMemoryBIOProtocol}
|
||||
may also want to pause a producer. Pause/resume events are therefore
|
||||
merged using the L{_ProducerMembrane} wrapper. Non-streaming (pull)
|
||||
producers are supported by wrapping them with L{_PullToPush}.
|
||||
|
||||
Because TLS may need to wait for reads before writing, some writes may be
|
||||
buffered until a read occurs.
|
||||
|
||||
@ivar _tlsConnection: The L{OpenSSL.SSL.Connection} instance which is
|
||||
encrypted and decrypting this connection.
|
||||
|
||||
@ivar _lostTLSConnection: A flag indicating whether connection loss has
|
||||
already been dealt with (C{True}) or not (C{False}). TLS disconnection
|
||||
is distinct from the underlying connection being lost.
|
||||
|
||||
@ivar _appSendBuffer: application-level (cleartext) data that is waiting to
|
||||
be transferred to the TLS buffer, but can't be because the TLS
|
||||
connection is handshaking.
|
||||
@type _appSendBuffer: L{list} of L{bytes}
|
||||
|
||||
@ivar _connectWrapped: A flag indicating whether or not to call
|
||||
C{makeConnection} on the wrapped protocol. This is for the reactor's
|
||||
L{twisted.internet.interfaces.ITLSTransport.startTLS} implementation,
|
||||
since it has a protocol which it has already called C{makeConnection}
|
||||
on, and which has no interest in a new transport. See #3821.
|
||||
|
||||
@ivar _handshakeDone: A flag indicating whether or not the handshake is
|
||||
known to have completed successfully (C{True}) or not (C{False}). This
|
||||
is used to control error reporting behavior. If the handshake has not
|
||||
completed, the underlying L{OpenSSL.SSL.Error} will be passed to the
|
||||
application's C{connectionLost} method. If it has completed, any
|
||||
unexpected L{OpenSSL.SSL.Error} will be turned into a
|
||||
L{ConnectionLost}. This is weird; however, it is simply an attempt at
|
||||
a faithful re-implementation of the behavior provided by
|
||||
L{twisted.internet.ssl}.
|
||||
|
||||
@ivar _reason: If an unexpected L{OpenSSL.SSL.Error} occurs which causes
|
||||
the connection to be lost, it is saved here. If appropriate, this may
|
||||
be used as the reason passed to the application protocol's
|
||||
C{connectionLost} method.
|
||||
|
||||
@ivar _producer: The current producer registered via C{registerProducer},
|
||||
or L{None} if no producer has been registered or a previous one was
|
||||
unregistered.
|
||||
|
||||
@ivar _aborted: C{abortConnection} has been called. No further data will
|
||||
be received to the wrapped protocol's C{dataReceived}.
|
||||
@type _aborted: L{bool}
|
||||
"""
|
||||
|
||||
_reason = None
|
||||
_handshakeDone = False
|
||||
_lostTLSConnection = False
|
||||
_producer = None
|
||||
_aborted = False
|
||||
|
||||
def __init__(self, factory, wrappedProtocol, _connectWrapped=True):
|
||||
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
|
||||
self._connectWrapped = _connectWrapped
|
||||
|
||||
def getHandle(self):
|
||||
"""
|
||||
Return the L{OpenSSL.SSL.Connection} object being used to encrypt and
|
||||
decrypt this connection.
|
||||
|
||||
This is done for the benefit of L{twisted.internet.ssl.Certificate}'s
|
||||
C{peerFromTransport} and C{hostFromTransport} methods only. A
|
||||
different system handle may be returned by future versions of this
|
||||
method.
|
||||
"""
|
||||
return self._tlsConnection
|
||||
|
||||
def makeConnection(self, transport):
|
||||
"""
|
||||
Connect this wrapper to the given transport and initialize the
|
||||
necessary L{OpenSSL.SSL.Connection} with a memory BIO.
|
||||
"""
|
||||
self._tlsConnection = self.factory._createConnection(self)
|
||||
self._appSendBuffer = []
|
||||
|
||||
# Add interfaces provided by the transport we are wrapping:
|
||||
for interface in providedBy(transport):
|
||||
directlyProvides(self, interface)
|
||||
|
||||
# Intentionally skip ProtocolWrapper.makeConnection - it might call
|
||||
# wrappedProtocol.makeConnection, which we want to make conditional.
|
||||
Protocol.makeConnection(self, transport)
|
||||
self.factory.registerProtocol(self)
|
||||
if self._connectWrapped:
|
||||
# Now that the TLS layer is initialized, notify the application of
|
||||
# the connection.
|
||||
ProtocolWrapper.makeConnection(self, transport)
|
||||
|
||||
# Now that we ourselves have a transport (initialized by the
|
||||
# ProtocolWrapper.makeConnection call above), kick off the TLS
|
||||
# handshake.
|
||||
self._checkHandshakeStatus()
|
||||
|
||||
def _checkHandshakeStatus(self):
|
||||
"""
|
||||
Ask OpenSSL to proceed with a handshake in progress.
|
||||
|
||||
Initially, this just sends the ClientHello; after some bytes have been
|
||||
stuffed in to the C{Connection} object by C{dataReceived}, it will then
|
||||
respond to any C{Certificate} or C{KeyExchange} messages.
|
||||
"""
|
||||
# The connection might already be aborted (eg. by a callback during
|
||||
# connection setup), so don't even bother trying to handshake in that
|
||||
# case.
|
||||
if self._aborted:
|
||||
return
|
||||
try:
|
||||
self._tlsConnection.do_handshake()
|
||||
except WantReadError:
|
||||
self._flushSendBIO()
|
||||
except Error:
|
||||
self._tlsShutdownFinished(Failure())
|
||||
else:
|
||||
self._handshakeDone = True
|
||||
if IHandshakeListener.providedBy(self.wrappedProtocol):
|
||||
self.wrappedProtocol.handshakeCompleted()
|
||||
|
||||
def _flushSendBIO(self):
|
||||
"""
|
||||
Read any bytes out of the send BIO and write them to the underlying
|
||||
transport.
|
||||
"""
|
||||
try:
|
||||
bytes = self._tlsConnection.bio_read(2**15)
|
||||
except WantReadError:
|
||||
# There may be nothing in the send BIO right now.
|
||||
pass
|
||||
else:
|
||||
self.transport.write(bytes)
|
||||
|
||||
def _flushReceiveBIO(self):
|
||||
"""
|
||||
Try to receive any application-level bytes which are now available
|
||||
because of a previous write into the receive BIO. This will take
|
||||
care of delivering any application-level bytes which are received to
|
||||
the protocol, as well as handling of the various exceptions which
|
||||
can come from trying to get such bytes.
|
||||
"""
|
||||
# Keep trying this until an error indicates we should stop or we
|
||||
# close the connection. Looping is necessary to make sure we
|
||||
# process all of the data which was put into the receive BIO, as
|
||||
# there is no guarantee that a single recv call will do it all.
|
||||
while not self._lostTLSConnection:
|
||||
try:
|
||||
bytes = self._tlsConnection.recv(2**15)
|
||||
except WantReadError:
|
||||
# The newly received bytes might not have been enough to produce
|
||||
# any application data.
|
||||
break
|
||||
except ZeroReturnError:
|
||||
# TLS has shut down and no more TLS data will be received over
|
||||
# this connection.
|
||||
self._shutdownTLS()
|
||||
# Passing in None means the user protocol's connnectionLost
|
||||
# will get called with reason from underlying transport:
|
||||
self._tlsShutdownFinished(None)
|
||||
except Error:
|
||||
# Something went pretty wrong. For example, this might be a
|
||||
# handshake failure during renegotiation (because there were no
|
||||
# shared ciphers, because a certificate failed to verify, etc).
|
||||
# TLS can no longer proceed.
|
||||
failure = Failure()
|
||||
self._tlsShutdownFinished(failure)
|
||||
else:
|
||||
if not self._aborted:
|
||||
ProtocolWrapper.dataReceived(self, bytes)
|
||||
|
||||
# The received bytes might have generated a response which needs to be
|
||||
# sent now. For example, the handshake involves several round-trip
|
||||
# exchanges without ever producing application-bytes.
|
||||
self._flushSendBIO()
|
||||
|
||||
def dataReceived(self, bytes):
|
||||
"""
|
||||
Deliver any received bytes to the receive BIO and then read and deliver
|
||||
to the application any application-level data which becomes available
|
||||
as a result of this.
|
||||
"""
|
||||
# Let OpenSSL know some bytes were just received.
|
||||
self._tlsConnection.bio_write(bytes)
|
||||
|
||||
# If we are still waiting for the handshake to complete, try to
|
||||
# complete the handshake with the bytes we just received.
|
||||
if not self._handshakeDone:
|
||||
self._checkHandshakeStatus()
|
||||
|
||||
# If the handshake still isn't finished, then we've nothing left to
|
||||
# do.
|
||||
if not self._handshakeDone:
|
||||
return
|
||||
|
||||
# If we've any pending writes, this read may have un-blocked them, so
|
||||
# attempt to unbuffer them into the OpenSSL layer.
|
||||
if self._appSendBuffer:
|
||||
self._unbufferPendingWrites()
|
||||
|
||||
# Since the handshake is complete, the wire-level bytes we just
|
||||
# processed might turn into some application-level bytes; try to pull
|
||||
# those out.
|
||||
self._flushReceiveBIO()
|
||||
|
||||
def _shutdownTLS(self):
|
||||
"""
|
||||
Initiate, or reply to, the shutdown handshake of the TLS layer.
|
||||
"""
|
||||
try:
|
||||
shutdownSuccess = self._tlsConnection.shutdown()
|
||||
except Error:
|
||||
# Mid-handshake, a call to shutdown() can result in a
|
||||
# WantWantReadError, or rather an SSL_ERR_WANT_READ; but pyOpenSSL
|
||||
# doesn't allow us to get at the error. See:
|
||||
# https://github.com/pyca/pyopenssl/issues/91
|
||||
shutdownSuccess = False
|
||||
self._flushSendBIO()
|
||||
if shutdownSuccess:
|
||||
# Both sides have shutdown, so we can start closing lower-level
|
||||
# transport. This will also happen if we haven't started
|
||||
# negotiation at all yet, in which case shutdown succeeds
|
||||
# immediately.
|
||||
self.transport.loseConnection()
|
||||
|
||||
def _tlsShutdownFinished(self, reason):
|
||||
"""
|
||||
Called when TLS connection has gone away; tell underlying transport to
|
||||
disconnect.
|
||||
|
||||
@param reason: a L{Failure} whose value is an L{Exception} if we want to
|
||||
report that failure through to the wrapped protocol's
|
||||
C{connectionLost}, or L{None} if the C{reason} that
|
||||
C{connectionLost} should receive should be coming from the
|
||||
underlying transport.
|
||||
@type reason: L{Failure} or L{None}
|
||||
"""
|
||||
if reason is not None:
|
||||
# Squash an EOF in violation of the TLS protocol into
|
||||
# ConnectionLost, so that applications which might run over
|
||||
# multiple protocols can recognize its type.
|
||||
if _representsEOF(reason.value):
|
||||
reason = Failure(CONNECTION_LOST)
|
||||
if self._reason is None:
|
||||
self._reason = reason
|
||||
self._lostTLSConnection = True
|
||||
# We may need to send a TLS alert regarding the nature of the shutdown
|
||||
# here (for example, why a handshake failed), so always flush our send
|
||||
# buffer before telling our lower-level transport to go away.
|
||||
self._flushSendBIO()
|
||||
# Using loseConnection causes the application protocol's
|
||||
# connectionLost method to be invoked non-reentrantly, which is always
|
||||
# a nice feature. However, for error cases (reason != None) we might
|
||||
# want to use abortConnection when it becomes available. The
|
||||
# loseConnection call is basically tested by test_handshakeFailure.
|
||||
# At least one side will need to do it or the test never finishes.
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Handle the possible repetition of calls to this method (due to either
|
||||
the underlying transport going away or due to an error at the TLS
|
||||
layer) and make sure the base implementation only gets invoked once.
|
||||
"""
|
||||
if not self._lostTLSConnection:
|
||||
# Tell the TLS connection that it's not going to get any more data
|
||||
# and give it a chance to finish reading.
|
||||
self._tlsConnection.bio_shutdown()
|
||||
self._flushReceiveBIO()
|
||||
self._lostTLSConnection = True
|
||||
reason = self._reason or reason
|
||||
self._reason = None
|
||||
self.connected = False
|
||||
ProtocolWrapper.connectionLost(self, reason)
|
||||
|
||||
# Breaking reference cycle between self._tlsConnection and self.
|
||||
self._tlsConnection = None
|
||||
|
||||
def loseConnection(self):
|
||||
"""
|
||||
Send a TLS close alert and close the underlying connection.
|
||||
"""
|
||||
if self.disconnecting or not self.connected:
|
||||
return
|
||||
# If connection setup has not finished, OpenSSL 1.0.2f+ will not shut
|
||||
# down the connection until we write some data to the connection which
|
||||
# allows the handshake to complete. However, since no data should be
|
||||
# written after loseConnection, this means we'll be stuck forever
|
||||
# waiting for shutdown to complete. Instead, we simply abort the
|
||||
# connection without trying to shut down cleanly:
|
||||
if not self._handshakeDone and not self._appSendBuffer:
|
||||
self.abortConnection()
|
||||
self.disconnecting = True
|
||||
if not self._appSendBuffer and self._producer is None:
|
||||
self._shutdownTLS()
|
||||
|
||||
def abortConnection(self):
|
||||
"""
|
||||
Tear down TLS state so that if the connection is aborted mid-handshake
|
||||
we don't deliver any further data from the application.
|
||||
"""
|
||||
self._aborted = True
|
||||
self.disconnecting = True
|
||||
self._shutdownTLS()
|
||||
self.transport.abortConnection()
|
||||
|
||||
def failVerification(self, reason):
|
||||
"""
|
||||
Abort the connection during connection setup, giving a reason that
|
||||
certificate verification failed.
|
||||
|
||||
@param reason: The reason that the verification failed; reported to the
|
||||
application protocol's C{connectionLost} method.
|
||||
@type reason: L{Failure}
|
||||
"""
|
||||
self._reason = reason
|
||||
self.abortConnection()
|
||||
|
||||
def write(self, bytes):
|
||||
"""
|
||||
Process the given application bytes and send any resulting TLS traffic
|
||||
which arrives in the send BIO.
|
||||
|
||||
If C{loseConnection} was called, subsequent calls to C{write} will
|
||||
drop the bytes on the floor.
|
||||
"""
|
||||
# Writes after loseConnection are not supported, unless a producer has
|
||||
# been registered, in which case writes can happen until the producer
|
||||
# is unregistered:
|
||||
if self.disconnecting and self._producer is None:
|
||||
return
|
||||
self._write(bytes)
|
||||
|
||||
def _bufferedWrite(self, octets):
|
||||
"""
|
||||
Put the given octets into L{TLSMemoryBIOProtocol._appSendBuffer}, and
|
||||
tell any listening producer that it should pause because we are now
|
||||
buffering.
|
||||
"""
|
||||
self._appSendBuffer.append(octets)
|
||||
if self._producer is not None:
|
||||
self._producer.pauseProducing()
|
||||
|
||||
def _unbufferPendingWrites(self):
|
||||
"""
|
||||
Un-buffer all waiting writes in L{TLSMemoryBIOProtocol._appSendBuffer}.
|
||||
"""
|
||||
pendingWrites, self._appSendBuffer = self._appSendBuffer, []
|
||||
for eachWrite in pendingWrites:
|
||||
self._write(eachWrite)
|
||||
|
||||
if self._appSendBuffer:
|
||||
# If OpenSSL ran out of buffer space in the Connection on our way
|
||||
# through the loop earlier and re-buffered any of our outgoing
|
||||
# writes, then we're done; don't consider any future work.
|
||||
return
|
||||
|
||||
if self._producer is not None:
|
||||
# If we have a registered producer, let it know that we have some
|
||||
# more buffer space.
|
||||
self._producer.resumeProducing()
|
||||
return
|
||||
|
||||
if self.disconnecting:
|
||||
# Finally, if we have no further buffered data, no producer wants
|
||||
# to send us more data in the future, and the application told us
|
||||
# to end the stream, initiate a TLS shutdown.
|
||||
self._shutdownTLS()
|
||||
|
||||
def _write(self, bytes):
|
||||
"""
|
||||
Process the given application bytes and send any resulting TLS traffic
|
||||
which arrives in the send BIO.
|
||||
|
||||
This may be called by C{dataReceived} with bytes that were buffered
|
||||
before C{loseConnection} was called, which is why this function
|
||||
doesn't check for disconnection but accepts the bytes regardless.
|
||||
"""
|
||||
if self._lostTLSConnection:
|
||||
return
|
||||
|
||||
# A TLS payload is 16kB max
|
||||
bufferSize = 2**14
|
||||
|
||||
# How far into the input we've gotten so far
|
||||
alreadySent = 0
|
||||
|
||||
while alreadySent < len(bytes):
|
||||
toSend = bytes[alreadySent : alreadySent + bufferSize]
|
||||
try:
|
||||
sent = self._tlsConnection.send(toSend)
|
||||
except WantReadError:
|
||||
self._bufferedWrite(bytes[alreadySent:])
|
||||
break
|
||||
except Error:
|
||||
# Pretend TLS connection disconnected, which will trigger
|
||||
# disconnect of underlying transport. The error will be passed
|
||||
# to the application protocol's connectionLost method. The
|
||||
# other SSL implementation doesn't, but losing helpful
|
||||
# debugging information is a bad idea.
|
||||
self._tlsShutdownFinished(Failure())
|
||||
break
|
||||
else:
|
||||
# We've successfully handed off the bytes to the OpenSSL
|
||||
# Connection object.
|
||||
alreadySent += sent
|
||||
# See if OpenSSL wants to hand any bytes off to the underlying
|
||||
# transport as a result.
|
||||
self._flushSendBIO()
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
"""
|
||||
Write a sequence of application bytes by joining them into one string
|
||||
and passing them to L{write}.
|
||||
"""
|
||||
self.write(b"".join(iovec))
|
||||
|
||||
def getPeerCertificate(self):
|
||||
return self._tlsConnection.get_peer_certificate()
|
||||
|
||||
@property
|
||||
def negotiatedProtocol(self):
|
||||
"""
|
||||
@see: L{INegotiated.negotiatedProtocol}
|
||||
"""
|
||||
protocolName = None
|
||||
|
||||
try:
|
||||
# If ALPN is not implemented that's ok, NPN might be.
|
||||
protocolName = self._tlsConnection.get_alpn_proto_negotiated()
|
||||
except (NotImplementedError, AttributeError):
|
||||
pass
|
||||
|
||||
if protocolName not in (b"", None):
|
||||
# A protocol was selected using ALPN.
|
||||
return protocolName
|
||||
|
||||
try:
|
||||
protocolName = self._tlsConnection.get_next_proto_negotiated()
|
||||
except (NotImplementedError, AttributeError):
|
||||
pass
|
||||
|
||||
if protocolName != b"":
|
||||
return protocolName
|
||||
|
||||
return None
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
# If we've already disconnected, nothing to do here:
|
||||
if self._lostTLSConnection:
|
||||
producer.stopProducing()
|
||||
return
|
||||
|
||||
# If we received a non-streaming producer, wrap it so it becomes a
|
||||
# streaming producer:
|
||||
if not streaming:
|
||||
producer = streamingProducer = _PullToPush(producer, self)
|
||||
producer = _ProducerMembrane(producer)
|
||||
# This will raise an exception if a producer is already registered:
|
||||
self.transport.registerProducer(producer, True)
|
||||
self._producer = producer
|
||||
# If we received a non-streaming producer, we need to start the
|
||||
# streaming wrapper:
|
||||
if not streaming:
|
||||
streamingProducer.startStreaming()
|
||||
|
||||
def unregisterProducer(self):
|
||||
# If we have no producer, we don't need to do anything here.
|
||||
if self._producer is None:
|
||||
return
|
||||
|
||||
# If we received a non-streaming producer, we need to stop the
|
||||
# streaming wrapper:
|
||||
if isinstance(self._producer._producer, _PullToPush):
|
||||
self._producer._producer.stopStreaming()
|
||||
self._producer = None
|
||||
self._producerPaused = False
|
||||
self.transport.unregisterProducer()
|
||||
if self.disconnecting and not self._appSendBuffer:
|
||||
self._shutdownTLS()
|
||||
|
||||
|
||||
@implementer(IOpenSSLClientConnectionCreator, IOpenSSLServerConnectionCreator)
|
||||
class _ContextFactoryToConnectionFactory:
|
||||
"""
|
||||
Adapter wrapping a L{twisted.internet.interfaces.IOpenSSLContextFactory}
|
||||
into a L{IOpenSSLClientConnectionCreator} or
|
||||
L{IOpenSSLServerConnectionCreator}.
|
||||
|
||||
See U{https://twistedmatrix.com/trac/ticket/7215} for work that should make
|
||||
this unnecessary.
|
||||
"""
|
||||
|
||||
def __init__(self, oldStyleContextFactory):
|
||||
"""
|
||||
Construct a L{_ContextFactoryToConnectionFactory} with a
|
||||
L{twisted.internet.interfaces.IOpenSSLContextFactory}.
|
||||
|
||||
Immediately call C{getContext} on C{oldStyleContextFactory} in order to
|
||||
force advance parameter checking, since old-style context factories
|
||||
don't actually check that their arguments to L{OpenSSL} are correct.
|
||||
|
||||
@param oldStyleContextFactory: A factory that can produce contexts.
|
||||
@type oldStyleContextFactory:
|
||||
L{twisted.internet.interfaces.IOpenSSLContextFactory}
|
||||
"""
|
||||
oldStyleContextFactory.getContext()
|
||||
self._oldStyleContextFactory = oldStyleContextFactory
|
||||
|
||||
def _connectionForTLS(self, protocol):
|
||||
"""
|
||||
Create an L{OpenSSL.SSL.Connection} object.
|
||||
|
||||
@param protocol: The protocol initiating a TLS connection.
|
||||
@type protocol: L{TLSMemoryBIOProtocol}
|
||||
|
||||
@return: a connection
|
||||
@rtype: L{OpenSSL.SSL.Connection}
|
||||
"""
|
||||
context = self._oldStyleContextFactory.getContext()
|
||||
return Connection(context, None)
|
||||
|
||||
def serverConnectionForTLS(self, protocol):
|
||||
"""
|
||||
Construct an OpenSSL server connection from the wrapped old-style
|
||||
context factory.
|
||||
|
||||
@note: Since old-style context factories don't distinguish between
|
||||
clients and servers, this is exactly the same as
|
||||
L{_ContextFactoryToConnectionFactory.clientConnectionForTLS}.
|
||||
|
||||
@param protocol: The protocol initiating a TLS connection.
|
||||
@type protocol: L{TLSMemoryBIOProtocol}
|
||||
|
||||
@return: a connection
|
||||
@rtype: L{OpenSSL.SSL.Connection}
|
||||
"""
|
||||
return self._connectionForTLS(protocol)
|
||||
|
||||
def clientConnectionForTLS(self, protocol):
|
||||
"""
|
||||
Construct an OpenSSL server connection from the wrapped old-style
|
||||
context factory.
|
||||
|
||||
@note: Since old-style context factories don't distinguish between
|
||||
clients and servers, this is exactly the same as
|
||||
L{_ContextFactoryToConnectionFactory.serverConnectionForTLS}.
|
||||
|
||||
@param protocol: The protocol initiating a TLS connection.
|
||||
@type protocol: L{TLSMemoryBIOProtocol}
|
||||
|
||||
@return: a connection
|
||||
@rtype: L{OpenSSL.SSL.Connection}
|
||||
"""
|
||||
return self._connectionForTLS(protocol)
|
||||
|
||||
|
||||
class _AggregateSmallWrites:
|
||||
"""
|
||||
Aggregate small writes so they get written in large batches.
|
||||
|
||||
If this is used as part of a transport, the transport needs to call
|
||||
``flush()`` immediately when ``loseConnection()`` is called, otherwise any
|
||||
buffered writes will never get written.
|
||||
|
||||
@cvar MAX_BUFFER_SIZE: The maximum amount of bytes to buffer before writing
|
||||
them out.
|
||||
"""
|
||||
|
||||
MAX_BUFFER_SIZE = 64_000
|
||||
|
||||
def __init__(self, write: Callable[[bytes], object], clock: IReactorTime):
|
||||
self._write = write
|
||||
self._clock = clock
|
||||
self._buffer: list[bytes] = []
|
||||
self._bufferLeft = self.MAX_BUFFER_SIZE
|
||||
self._scheduled: Optional[IDelayedCall] = None
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Buffer the data, or write it immediately if we've accumulated enough to
|
||||
make it worth it.
|
||||
|
||||
Accumulating too much data can result in higher memory usage.
|
||||
"""
|
||||
self._buffer.append(data)
|
||||
self._bufferLeft -= len(data)
|
||||
|
||||
if self._bufferLeft < 0:
|
||||
# We've accumulated enough we should just write it out. No need to
|
||||
# schedule a flush, since we just flushed everything.
|
||||
self.flush()
|
||||
return
|
||||
|
||||
if self._scheduled:
|
||||
# We already have a scheduled send, so with the data in the buffer,
|
||||
# there is nothing more to do here.
|
||||
return
|
||||
|
||||
# Schedule the write of the accumulated buffer for the next reactor
|
||||
# iteration.
|
||||
self._scheduled = self._clock.callLater(0, self._scheduledFlush)
|
||||
|
||||
def _scheduledFlush(self) -> None:
|
||||
"""Called in next reactor iteration."""
|
||||
self._scheduled = None
|
||||
self.flush()
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Flush any buffered writes."""
|
||||
if self._buffer:
|
||||
self._bufferLeft = self.MAX_BUFFER_SIZE
|
||||
self._write(b"".join(self._buffer))
|
||||
del self._buffer[:]
|
||||
|
||||
|
||||
def _get_default_clock() -> IReactorTime:
|
||||
"""
|
||||
Return the default reactor.
|
||||
|
||||
This is a function so it can be monkey-patched in tests, specifically
|
||||
L{twisted.web.test.test_agent}.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
|
||||
return cast(IReactorTime, reactor)
|
||||
|
||||
|
||||
class BufferingTLSTransport(TLSMemoryBIOProtocol):
|
||||
"""
|
||||
A TLS transport implemented by wrapping buffering around a
|
||||
L{TLSMemoryBIOProtocol}.
|
||||
|
||||
Doing many small writes directly to a L{OpenSSL.SSL.Connection}, as
|
||||
implemented in L{TLSMemoryBIOProtocol}, can add significant CPU and
|
||||
bandwidth overhead. Thus, even when writing is possible, small writes will
|
||||
get aggregated and written as a single write at the next reactor iteration.
|
||||
"""
|
||||
|
||||
# Implementation Note: An implementation based on composition would be
|
||||
# nicer, but there's close integration between L{ProtocolWrapper}
|
||||
# subclasses like L{TLSMemoryBIOProtocol} and the corresponding factory. An
|
||||
# attempt to implement this with broke things like
|
||||
# L{TLSMemoryBIOFactory.protocols} having the correct instances, whereas
|
||||
# subclassing makes that work.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
factory: TLSMemoryBIOFactory,
|
||||
wrappedProtocol: IProtocol,
|
||||
_connectWrapped: bool = True,
|
||||
):
|
||||
super().__init__(factory, wrappedProtocol, _connectWrapped)
|
||||
actual_write = super().write
|
||||
self._aggregator = _AggregateSmallWrites(actual_write, factory._clock)
|
||||
# This is kinda ugly, but speeds things up a lot in a hot path with
|
||||
# lots of small TLS writes. May become unnecessary in Python 3.13 or
|
||||
# later if JIT and/or inlining becomes a thing.
|
||||
self.write = self._aggregator.write # type: ignore[method-assign]
|
||||
|
||||
def writeSequence(self, sequence: Iterable[bytes]) -> None:
|
||||
self._aggregator.write(b"".join(sequence))
|
||||
|
||||
def loseConnection(self) -> None:
|
||||
self._aggregator.flush()
|
||||
super().loseConnection()
|
||||
|
||||
|
||||
class TLSMemoryBIOFactory(WrappingFactory):
|
||||
"""
|
||||
L{TLSMemoryBIOFactory} adds TLS to connections.
|
||||
|
||||
@ivar _creatorInterface: the interface which L{_connectionCreator} is
|
||||
expected to implement.
|
||||
@type _creatorInterface: L{zope.interface.interfaces.IInterface}
|
||||
|
||||
@ivar _connectionCreator: a callable which creates an OpenSSL Connection
|
||||
object.
|
||||
@type _connectionCreator: 1-argument callable taking
|
||||
L{TLSMemoryBIOProtocol} and returning L{OpenSSL.SSL.Connection}.
|
||||
"""
|
||||
|
||||
protocol = BufferingTLSTransport
|
||||
|
||||
noisy = False # disable unnecessary logging.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
contextFactory,
|
||||
isClient,
|
||||
wrappedFactory,
|
||||
clock=None,
|
||||
):
|
||||
"""
|
||||
Create a L{TLSMemoryBIOFactory}.
|
||||
|
||||
@param contextFactory: Configuration parameters used to create an
|
||||
OpenSSL connection. In order of preference, what you should pass
|
||||
here should be:
|
||||
|
||||
1. L{twisted.internet.ssl.CertificateOptions} (if you're
|
||||
writing a server) or the result of
|
||||
L{twisted.internet.ssl.optionsForClientTLS} (if you're
|
||||
writing a client). If you want security you should really
|
||||
use one of these.
|
||||
|
||||
2. If you really want to implement something yourself, supply a
|
||||
provider of L{IOpenSSLClientConnectionCreator} or
|
||||
L{IOpenSSLServerConnectionCreator}.
|
||||
|
||||
3. If you really have to, supply a
|
||||
L{twisted.internet.ssl.ContextFactory}. This will likely be
|
||||
deprecated at some point so please upgrade to the new
|
||||
interfaces.
|
||||
|
||||
@type contextFactory: L{IOpenSSLClientConnectionCreator} or
|
||||
L{IOpenSSLServerConnectionCreator}, or, for compatibility with
|
||||
older code, anything implementing
|
||||
L{twisted.internet.interfaces.IOpenSSLContextFactory}. See
|
||||
U{https://twistedmatrix.com/trac/ticket/7215} for information on
|
||||
the upcoming deprecation of passing a
|
||||
L{twisted.internet.ssl.ContextFactory} here.
|
||||
|
||||
@param isClient: Is this a factory for TLS client connections; in other
|
||||
words, those that will send a C{ClientHello} greeting? L{True} if
|
||||
so, L{False} otherwise. This flag determines what interface is
|
||||
expected of C{contextFactory}. If L{True}, C{contextFactory}
|
||||
should provide L{IOpenSSLClientConnectionCreator}; otherwise it
|
||||
should provide L{IOpenSSLServerConnectionCreator}.
|
||||
@type isClient: L{bool}
|
||||
|
||||
@param wrappedFactory: A factory which will create the
|
||||
application-level protocol.
|
||||
@type wrappedFactory: L{twisted.internet.interfaces.IProtocolFactory}
|
||||
"""
|
||||
WrappingFactory.__init__(self, wrappedFactory)
|
||||
if isClient:
|
||||
creatorInterface = IOpenSSLClientConnectionCreator
|
||||
else:
|
||||
creatorInterface = IOpenSSLServerConnectionCreator
|
||||
self._creatorInterface = creatorInterface
|
||||
if not creatorInterface.providedBy(contextFactory):
|
||||
contextFactory = _ContextFactoryToConnectionFactory(contextFactory)
|
||||
self._connectionCreator = contextFactory
|
||||
|
||||
if clock is None:
|
||||
clock = _get_default_clock()
|
||||
self._clock = clock
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Annotate the wrapped factory's log prefix with some text indicating TLS
|
||||
is in use.
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
if ILoggingContext.providedBy(self.wrappedFactory):
|
||||
logPrefix = self.wrappedFactory.logPrefix()
|
||||
else:
|
||||
logPrefix = self.wrappedFactory.__class__.__name__
|
||||
return f"{logPrefix} (TLS)"
|
||||
|
||||
def _applyProtocolNegotiation(self, connection):
|
||||
"""
|
||||
Applies ALPN/NPN protocol neogitation to the connection, if the factory
|
||||
supports it.
|
||||
|
||||
@param connection: The OpenSSL connection object to have ALPN/NPN added
|
||||
to it.
|
||||
@type connection: L{OpenSSL.SSL.Connection}
|
||||
|
||||
@return: Nothing
|
||||
@rtype: L{None}
|
||||
"""
|
||||
if IProtocolNegotiationFactory.providedBy(self.wrappedFactory):
|
||||
protocols = self.wrappedFactory.acceptableProtocols()
|
||||
context = connection.get_context()
|
||||
_setAcceptableProtocols(context, protocols)
|
||||
|
||||
return
|
||||
|
||||
def _createConnection(self, tlsProtocol):
|
||||
"""
|
||||
Create an OpenSSL connection and set it up good.
|
||||
|
||||
@param tlsProtocol: The protocol which is establishing the connection.
|
||||
@type tlsProtocol: L{TLSMemoryBIOProtocol}
|
||||
|
||||
@return: an OpenSSL connection object for C{tlsProtocol} to use
|
||||
@rtype: L{OpenSSL.SSL.Connection}
|
||||
"""
|
||||
connectionCreator = self._connectionCreator
|
||||
if self._creatorInterface is IOpenSSLClientConnectionCreator:
|
||||
connection = connectionCreator.clientConnectionForTLS(tlsProtocol)
|
||||
self._applyProtocolNegotiation(connection)
|
||||
connection.set_connect_state()
|
||||
else:
|
||||
connection = connectionCreator.serverConnectionForTLS(tlsProtocol)
|
||||
self._applyProtocolNegotiation(connection)
|
||||
connection.set_accept_state()
|
||||
return connection
|
||||
112
.venv/lib/python3.12/site-packages/twisted/protocols/wire.py
Normal file
112
.venv/lib/python3.12/site-packages/twisted/protocols/wire.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""Implement standard (and unused) TCP protocols.
|
||||
|
||||
These protocols are either provided by inetd, or are not provided at all.
|
||||
"""
|
||||
|
||||
|
||||
import struct
|
||||
import time
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import interfaces, protocol
|
||||
|
||||
|
||||
class Echo(protocol.Protocol):
|
||||
"""
|
||||
As soon as any data is received, write it back (RFC 862).
|
||||
"""
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
class Discard(protocol.Protocol):
|
||||
"""
|
||||
Discard any received data (RFC 863).
|
||||
"""
|
||||
|
||||
def dataReceived(self, data):
|
||||
# I'm ignoring you, nyah-nyah
|
||||
pass
|
||||
|
||||
|
||||
@implementer(interfaces.IProducer)
|
||||
class Chargen(protocol.Protocol):
|
||||
"""
|
||||
Generate repeating noise (RFC 864).
|
||||
"""
|
||||
|
||||
noise = rb'@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ !"#$%&?'
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.registerProducer(self, 0)
|
||||
|
||||
def resumeProducing(self):
|
||||
self.transport.write(self.noise)
|
||||
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
|
||||
def stopProducing(self):
|
||||
pass
|
||||
|
||||
|
||||
class QOTD(protocol.Protocol):
|
||||
"""
|
||||
Return a quote of the day (RFC 865).
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(self.getQuote())
|
||||
self.transport.loseConnection()
|
||||
|
||||
def getQuote(self):
|
||||
"""
|
||||
Return a quote. May be overrriden in subclasses.
|
||||
"""
|
||||
return b"An apple a day keeps the doctor away.\r\n"
|
||||
|
||||
|
||||
class Who(protocol.Protocol):
|
||||
"""
|
||||
Return list of active users (RFC 866)
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(self.getUsers())
|
||||
self.transport.loseConnection()
|
||||
|
||||
def getUsers(self):
|
||||
"""
|
||||
Return active users. Override in subclasses.
|
||||
"""
|
||||
return b"root\r\n"
|
||||
|
||||
|
||||
class Daytime(protocol.Protocol):
|
||||
"""
|
||||
Send back the daytime in ASCII form (RFC 867).
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(time.asctime(time.gmtime(time.time())) + b"\r\n")
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
class Time(protocol.Protocol):
|
||||
"""
|
||||
Send back the time in machine readable form (RFC 868).
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
# is this correct only for 32-bit machines?
|
||||
result = struct.pack("!i", int(time.time()))
|
||||
self.transport.write(result)
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
__all__ = ["Echo", "Discard", "Chargen", "QOTD", "Who", "Daytime", "Time"]
|
||||
Reference in New Issue
Block a user