okay fine

pacnpal
2024-11-03 17:47:26 +00:00
parent 01c6004a79
commit 27eb239e97
10020 changed files with 1935769 additions and 2364 deletions


@@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Protocols: A collection of internet protocol implementations.
"""

File diff suppressed because it is too large.


@@ -0,0 +1,912 @@
# -*- test-case-name: twisted.protocols.test.test_basic -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Basic protocols, such as line-oriented, netstring, and int prefixed strings.
"""
# System imports
import math
import re
from io import BytesIO
from struct import calcsize, pack, unpack
from zope.interface import implementer
# Twisted imports
from twisted.internet import defer, interfaces, protocol
from twisted.python import log
# Unfortunately we cannot use regular string formatting on Python 3; see
# http://bugs.python.org/issue3982 for details.
def _formatNetstring(data):
return b"".join([str(len(data)).encode("ascii"), b":", data, b","])
_formatNetstring.__doc__ = """
Convert some C{bytes} into netstring format.
@param data: C{bytes} that will be reformatted.
"""
DEBUG = 0
class NetstringParseError(ValueError):
"""
The incoming data is not in valid Netstring format.
"""
class IncompleteNetstring(Exception):
"""
Not enough data to complete a netstring.
"""
class NetstringReceiver(protocol.Protocol):
"""
A protocol that sends and receives netstrings.
See U{http://cr.yp.to/proto/netstrings.txt} for the specification of
netstrings. Every netstring starts with digits that specify the length
of the data. This length specification is separated from the data by
a colon. The data is terminated with a comma.
Override L{stringReceived} to handle received netstrings. This
method is called with the netstring payload as a single argument
whenever a complete netstring is received.
Security features:
1. Messages are limited in size, useful if you don't want
someone sending you a 500MB netstring (change C{self.MAX_LENGTH}
to the maximum length you wish to accept).
2. The connection is lost if an illegal message is received.
@ivar MAX_LENGTH: Defines the maximum length of netstrings that can be
received.
@type MAX_LENGTH: C{int}
@ivar _LENGTH: A pattern describing all strings that contain a netstring
length specification. Examples for length specifications are C{b'0:'},
C{b'12:'}, and C{b'179:'}. C{b'007:'} is not a valid length
specification, since leading zeros are not allowed.
@type _LENGTH: C{re.Pattern}
@ivar _LENGTH_PREFIX: A pattern describing all strings that contain
the first part of a netstring length specification (without the
trailing colon). Examples are '0', '12', and '179'. '007' does not
start a netstring length specification, since leading zeros are
not allowed.
@type _LENGTH_PREFIX: C{re.Pattern}
@ivar _PARSING_LENGTH: Indicates that the C{NetstringReceiver} is in
the state of parsing the length portion of a netstring.
@type _PARSING_LENGTH: C{int}
@ivar _PARSING_PAYLOAD: Indicates that the C{NetstringReceiver} is in
the state of parsing the payload portion (data and trailing comma)
of a netstring.
@type _PARSING_PAYLOAD: C{int}
@ivar brokenPeer: Indicates if the connection is still functional
@type brokenPeer: C{int}
@ivar _state: Indicates if the protocol is consuming the length portion
(C{_PARSING_LENGTH}) or the payload (C{_PARSING_PAYLOAD}) of a netstring
@type _state: C{int}
@ivar _remainingData: Holds the chunk of data that has not yet been consumed
@type _remainingData: C{bytes}
@ivar _payload: Holds the payload portion of a netstring including the
trailing comma
@type _payload: C{BytesIO}
@ivar _expectedPayloadSize: Holds the payload size plus one for the trailing
comma.
@type _expectedPayloadSize: C{int}
"""
MAX_LENGTH = 99999
_LENGTH = re.compile(rb"(0|[1-9]\d*)(:)")
_LENGTH_PREFIX = re.compile(rb"(0|[1-9]\d*)$")
# Some error information for NetstringParseError instances.
_MISSING_LENGTH = (
"The received netstring does not start with a " "length specification."
)
_OVERFLOW = (
"The length specification of the received netstring "
"cannot be represented in Python - it causes an "
"OverflowError!"
)
_TOO_LONG = (
"The received netstring is longer than the maximum %s "
"specified by self.MAX_LENGTH"
)
_MISSING_COMMA = "The received netstring is not terminated by a comma."
# The following constants are used for determining if the NetstringReceiver
# is parsing the length portion of a netstring, or the payload.
_PARSING_LENGTH, _PARSING_PAYLOAD = range(2)
def makeConnection(self, transport):
"""
Initializes the protocol.
"""
protocol.Protocol.makeConnection(self, transport)
self._remainingData = b""
self._currentPayloadSize = 0
self._payload = BytesIO()
self._state = self._PARSING_LENGTH
self._expectedPayloadSize = 0
self.brokenPeer = 0
def sendString(self, string):
"""
Sends a netstring.
Wraps up C{string} by adding length information and a
trailing comma; writes the result to the transport.
@param string: The string to send. The necessary framing (length
prefix, etc) will be added.
@type string: C{bytes}
"""
self.transport.write(_formatNetstring(string))
def dataReceived(self, data):
"""
Receives some characters of a netstring.
Whenever a complete netstring is received, this method extracts
its payload and calls L{stringReceived} to process it.
@param data: A chunk of data representing a (possibly partial)
netstring
@type data: C{bytes}
"""
self._remainingData += data
while self._remainingData:
try:
self._consumeData()
except IncompleteNetstring:
break
except NetstringParseError:
self._handleParseError()
break
def stringReceived(self, string):
"""
Override this for notification when each complete string is received.
@param string: The complete string which was received with all
framing (length prefix, etc) removed.
@type string: C{bytes}
@raise NotImplementedError: because the method has to be implemented
by the child class.
"""
raise NotImplementedError()
def _maxLengthSize(self):
"""
Calculate and return the string size of C{self.MAX_LENGTH}.
@return: The size of the string representation for C{self.MAX_LENGTH}
@rtype: C{int}
"""
return math.ceil(math.log10(self.MAX_LENGTH)) + 1
def _consumeData(self):
"""
Consumes the content of C{self._remainingData}.
@raise IncompleteNetstring: if C{self._remainingData} does not
contain enough data to complete the current netstring.
@raise NetstringParseError: if the received data do not
form a valid netstring.
"""
if self._state == self._PARSING_LENGTH:
self._consumeLength()
self._prepareForPayloadConsumption()
if self._state == self._PARSING_PAYLOAD:
self._consumePayload()
def _consumeLength(self):
"""
Consumes the length portion of C{self._remainingData}.
@raise IncompleteNetstring: if C{self._remainingData} contains
a partial length specification (digits without trailing
colon).
@raise NetstringParseError: if the received data do not form a valid
netstring.
"""
lengthMatch = self._LENGTH.match(self._remainingData)
if not lengthMatch:
self._checkPartialLengthSpecification()
raise IncompleteNetstring()
self._processLength(lengthMatch)
def _checkPartialLengthSpecification(self):
"""
Makes sure that the received data represents a valid number.
Checks if C{self._remainingData} represents a number smaller than or
equal to C{self.MAX_LENGTH}.
@raise NetstringParseError: if C{self._remainingData} is no
number or is too big (checked by L{_extractLength}).
"""
partialLengthMatch = self._LENGTH_PREFIX.match(self._remainingData)
if not partialLengthMatch:
raise NetstringParseError(self._MISSING_LENGTH)
lengthSpecification = partialLengthMatch.group(1)
self._extractLength(lengthSpecification)
def _processLength(self, lengthMatch):
"""
Processes the length definition of a netstring.
Extracts and stores in C{self._expectedPayloadSize} the number
representing the netstring size. Removes the prefix
representing the length specification from
C{self._remainingData}.
@raise NetstringParseError: if the received netstring does not
start with a number or the number is bigger than
C{self.MAX_LENGTH}.
@param lengthMatch: A regular expression match object matching
a netstring length specification
@type lengthMatch: C{re.Match}
"""
endOfNumber = lengthMatch.end(1)
startOfData = lengthMatch.end(2)
lengthString = self._remainingData[:endOfNumber]
# Expect payload plus trailing comma:
self._expectedPayloadSize = self._extractLength(lengthString) + 1
self._remainingData = self._remainingData[startOfData:]
def _extractLength(self, lengthAsString):
"""
Attempts to extract the length information of a netstring.
@raise NetstringParseError: if the number is bigger than
C{self.MAX_LENGTH}.
@param lengthAsString: A chunk of data starting with a length
specification
@type lengthAsString: C{bytes}
@return: The length of the netstring
@rtype: C{int}
"""
self._checkStringSize(lengthAsString)
length = int(lengthAsString)
if length > self.MAX_LENGTH:
raise NetstringParseError(self._TOO_LONG % (self.MAX_LENGTH,))
return length
def _checkStringSize(self, lengthAsString):
"""
Checks the sanity of lengthAsString.
Checks if the size of the length specification exceeds the
size of the string representing self.MAX_LENGTH. If this is
not the case, the number represented by lengthAsString is
certainly bigger than self.MAX_LENGTH, and a
NetstringParseError can be raised.
This method should make sure that netstrings with extremely
long length specifications are refused before even attempting
to convert them to an integer (which might trigger a
MemoryError).
"""
if len(lengthAsString) > self._maxLengthSize():
raise NetstringParseError(self._TOO_LONG % (self.MAX_LENGTH,))
def _prepareForPayloadConsumption(self):
"""
Sets up variables necessary for consuming the payload of a netstring.
"""
self._state = self._PARSING_PAYLOAD
self._currentPayloadSize = 0
self._payload.seek(0)
self._payload.truncate()
def _consumePayload(self):
"""
Consumes the payload portion of C{self._remainingData}.
If the payload is complete, checks for the trailing comma and
processes the payload. If not, raises an L{IncompleteNetstring}
exception.
@raise IncompleteNetstring: if the payload received so far
contains fewer characters than expected.
@raise NetstringParseError: if the payload does not end with a
comma.
"""
self._extractPayload()
if self._currentPayloadSize < self._expectedPayloadSize:
raise IncompleteNetstring()
self._checkForTrailingComma()
self._state = self._PARSING_LENGTH
self._processPayload()
def _extractPayload(self):
"""
Extracts payload information from C{self._remainingData}.
Splits C{self._remainingData} at the end of the netstring. The
first part becomes C{self._payload}, the second part is stored
in C{self._remainingData}.
If the netstring is not yet complete, the whole content of
C{self._remainingData} is moved to C{self._payload}.
"""
if self._payloadComplete():
remainingPayloadSize = self._expectedPayloadSize - self._currentPayloadSize
self._payload.write(self._remainingData[:remainingPayloadSize])
self._remainingData = self._remainingData[remainingPayloadSize:]
self._currentPayloadSize = self._expectedPayloadSize
else:
self._payload.write(self._remainingData)
self._currentPayloadSize += len(self._remainingData)
self._remainingData = b""
def _payloadComplete(self):
"""
Checks if enough data have been received to complete the netstring.
@return: C{True} iff the received data contain at least as many
characters as specified in the length section of the
netstring
@rtype: C{bool}
"""
return (
len(self._remainingData) + self._currentPayloadSize
>= self._expectedPayloadSize
)
def _processPayload(self):
"""
Processes the actual payload with L{stringReceived}.
Strips C{self._payload} of the trailing comma and calls
L{stringReceived} with the result.
"""
self.stringReceived(self._payload.getvalue()[:-1])
def _checkForTrailingComma(self):
"""
Checks if the netstring has a trailing comma at the expected position.
@raise NetstringParseError: if the last payload character is
anything but a comma.
"""
if self._payload.getvalue()[-1:] != b",":
raise NetstringParseError(self._MISSING_COMMA)
def _handleParseError(self):
"""
Terminates the connection and sets the flag C{self.brokenPeer}.
"""
self.transport.loseConnection()
self.brokenPeer = 1
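# A minimal usage sketch (the class name is illustrative, not part of this
# module): an echo service built on NetstringReceiver.
class _NetstringEchoExample(NetstringReceiver):
    def stringReceived(self, string):
        # Reframe the received payload and send it back.
        self.sendString(string)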
class LineOnlyReceiver(protocol.Protocol):
"""
A protocol that receives only lines.
This is purely a speed optimisation over LineReceiver, for the
cases where raw mode is known to be unnecessary.
@cvar delimiter: The line-ending delimiter to use. By default this is
C{b'\\r\\n'}.
@cvar MAX_LENGTH: The maximum length of a line to allow (If a
sent line is longer than this, the connection is dropped).
Default is 16384.
"""
_buffer = b""
delimiter = b"\r\n"
MAX_LENGTH = 16384
def dataReceived(self, data):
"""
Translates bytes into lines, and calls lineReceived.
"""
lines = (self._buffer + data).split(self.delimiter)
self._buffer = lines.pop(-1)
for line in lines:
if self.transport.disconnecting:
# this is necessary because the transport may be told to lose
# the connection by a line within a larger packet, and it is
# important to disregard all the lines in that packet following
# the one that told it to close.
return
if len(line) > self.MAX_LENGTH:
return self.lineLengthExceeded(line)
else:
self.lineReceived(line)
if len(self._buffer) > self.MAX_LENGTH:
return self.lineLengthExceeded(self._buffer)
def lineReceived(self, line):
"""
Override this for when each line is received.
@param line: The line which was received with the delimiter removed.
@type line: C{bytes}
"""
raise NotImplementedError
def sendLine(self, line):
"""
Sends a line to the other end of the connection.
@param line: The line to send, not including the delimiter.
@type line: C{bytes}
"""
return self.transport.writeSequence((line, self.delimiter))
def lineLengthExceeded(self, line):
"""
Called when the maximum line length has been reached.
Override if it needs to be dealt with in some special way.
"""
return self.transport.loseConnection()
class _PauseableMixin:
paused = False
def pauseProducing(self):
self.paused = True
self.transport.pauseProducing()
def resumeProducing(self):
self.paused = False
self.transport.resumeProducing()
self.dataReceived(b"")
def stopProducing(self):
self.paused = True
self.transport.stopProducing()
class LineReceiver(protocol.Protocol, _PauseableMixin):
"""
A protocol that receives lines and/or raw data, depending on mode.
In line mode, each line that's received becomes a callback to
L{lineReceived}. In raw data mode, each chunk of raw data becomes a
callback to L{LineReceiver.rawDataReceived}.
The L{setLineMode} and L{setRawMode} methods switch between the two modes.
This is useful for line-oriented protocols such as IRC, HTTP, POP, etc.
@cvar delimiter: The line-ending delimiter to use. By default this is
C{b'\\r\\n'}.
@cvar MAX_LENGTH: The maximum length of a line to allow (If a
sent line is longer than this, the connection is dropped).
Default is 16384.
"""
line_mode = 1
_buffer = b""
_busyReceiving = False
delimiter = b"\r\n"
MAX_LENGTH = 16384
def clearLineBuffer(self):
"""
Clear buffered data.
@return: All of the cleared buffered data.
@rtype: C{bytes}
"""
b, self._buffer = self._buffer, b""
return b
def dataReceived(self, data):
"""
Protocol.dataReceived.
Translates bytes into lines, and calls lineReceived (or
rawDataReceived, depending on mode.)
"""
if self._busyReceiving:
self._buffer += data
return
try:
self._busyReceiving = True
self._buffer += data
while self._buffer and not self.paused:
if self.line_mode:
try:
line, self._buffer = self._buffer.split(self.delimiter, 1)
except ValueError:
if len(self._buffer) >= (self.MAX_LENGTH + len(self.delimiter)):
line, self._buffer = self._buffer, b""
return self.lineLengthExceeded(line)
return
else:
lineLength = len(line)
if lineLength > self.MAX_LENGTH:
exceeded = line + self.delimiter + self._buffer
self._buffer = b""
return self.lineLengthExceeded(exceeded)
why = self.lineReceived(line)
if why or self.transport and self.transport.disconnecting:
return why
else:
data = self._buffer
self._buffer = b""
why = self.rawDataReceived(data)
if why:
return why
finally:
self._busyReceiving = False
def setLineMode(self, extra=b""):
"""
Sets the line-mode of this receiver.
If you are calling this from a rawDataReceived callback,
you can pass in extra unhandled data, and that data will
be parsed for lines. Further data received will be sent
to lineReceived rather than rawDataReceived.
Do not pass extra data if calling this function from
within a lineReceived callback.
"""
self.line_mode = 1
if extra:
return self.dataReceived(extra)
def setRawMode(self):
"""
Sets the raw mode of this receiver.
Further data received will be sent to rawDataReceived rather
than lineReceived.
"""
self.line_mode = 0
def rawDataReceived(self, data):
"""
Override this for when raw data is received.
"""
raise NotImplementedError
def lineReceived(self, line):
"""
Override this for when each line is received.
@param line: The line which was received with the delimiter removed.
@type line: C{bytes}
"""
raise NotImplementedError
def sendLine(self, line):
"""
Sends a line to the other end of the connection.
@param line: The line to send, not including the delimiter.
@type line: C{bytes}
"""
return self.transport.write(line + self.delimiter)
def lineLengthExceeded(self, line):
"""
Called when the maximum line length has been reached.
Override if it needs to be dealt with in some special way.
The argument 'line' contains the remainder of the buffer, starting
with (at least some part) of the line which is too long. This may
be more than one line, or may be only the initial portion of the
line.
"""
return self.transport.loseConnection()
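# A minimal usage sketch (the class name is illustrative, not part of this
# module): a LineReceiver subclass that answers each line and can switch
# into raw mode on request.
class _LineModeExample(LineReceiver):
    def lineReceived(self, line):
        if line == b"RAW":
            # Until setLineMode() is called, further bytes arrive via
            # rawDataReceived instead of lineReceived.
            self.setRawMode()
        else:
            self.sendLine(b"You said: " + line)

    def rawDataReceived(self, data):
        self.transport.write(data)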
class StringTooLongError(AssertionError):
"""
Raised when trying to send a string too long for a length prefixed
protocol.
"""
class _RecvdCompatHack:
"""
Emulates the to-be-deprecated C{IntNStringReceiver.recvd} attribute.
The C{recvd} attribute was where the working buffer for buffering and
parsing netstrings was kept. It was updated each time new data arrived and
each time some of that data was parsed and delivered to application code.
The piecemeal updates to its string value were expensive and have been
removed from C{IntNStringReceiver} in the normal case. However, for
applications directly reading this attribute, this descriptor restores that
behavior. It only copies the working buffer when necessary (ie, when
accessed). This avoids the cost for applications not using the data.
This is a custom descriptor rather than a property, because we still need
the default __set__ behavior in both new-style and old-style subclasses.
"""
def __get__(self, oself, type=None):
return oself._unprocessed[oself._compatibilityOffset :]
class IntNStringReceiver(protocol.Protocol, _PauseableMixin):
"""
Generic class for length prefixed protocols.
@ivar _unprocessed: bytes received, but not yet broken up into messages /
sent to stringReceived. _compatibilityOffset must be updated when this
value is updated so that the C{recvd} attribute can be generated
correctly.
@type _unprocessed: C{bytes}
@ivar structFormat: format used for struct packing/unpacking. Define it in
subclass.
@type structFormat: C{str}
@ivar prefixLength: length of the prefix, in bytes. Define it in subclass,
using C{struct.calcsize(structFormat)}
@type prefixLength: C{int}
@ivar _compatibilityOffset: the offset within C{_unprocessed} to the next
message to be parsed. (used to generate the recvd attribute)
@type _compatibilityOffset: C{int}
"""
MAX_LENGTH = 99999
_unprocessed = b""
_compatibilityOffset = 0
# Backwards compatibility support for applications which directly touch the
# "internal" parse buffer.
recvd = _RecvdCompatHack()
def stringReceived(self, string):
"""
Override this for notification when each complete string is received.
@param string: The complete string which was received with all
framing (length prefix, etc) removed.
@type string: C{bytes}
"""
raise NotImplementedError
def lengthLimitExceeded(self, length):
"""
Callback invoked when a length prefix greater than C{MAX_LENGTH} is
received. The default implementation disconnects the transport.
Override this.
@param length: The length prefix which was received.
@type length: C{int}
"""
self.transport.loseConnection()
def dataReceived(self, data):
"""
Convert int prefixed strings into calls to stringReceived.
"""
# Try to minimize string copying (via slices) by keeping one buffer
# containing all the data we have so far and a separate offset into that
# buffer.
alldata = self._unprocessed + data
currentOffset = 0
prefixLength = self.prefixLength
fmt = self.structFormat
self._unprocessed = alldata
while len(alldata) >= (currentOffset + prefixLength) and not self.paused:
messageStart = currentOffset + prefixLength
(length,) = unpack(fmt, alldata[currentOffset:messageStart])
if length > self.MAX_LENGTH:
self._unprocessed = alldata
self._compatibilityOffset = currentOffset
self.lengthLimitExceeded(length)
return
messageEnd = messageStart + length
if len(alldata) < messageEnd:
break
# Here we have to slice the working buffer so we can send just the
# netstring into the stringReceived callback.
packet = alldata[messageStart:messageEnd]
currentOffset = messageEnd
self._compatibilityOffset = currentOffset
self.stringReceived(packet)
# Check to see if the backwards compat "recvd" attribute got written
# to by application code. If so, drop the current data buffer and
# switch to the new buffer given by that attribute's value.
if "recvd" in self.__dict__:
alldata = self.__dict__.pop("recvd")
self._unprocessed = alldata
self._compatibilityOffset = currentOffset = 0
if alldata:
continue
return
# Slice off all the data that has been processed, avoiding holding onto
# memory to store it, and update the compatibility attributes to reflect
# that change.
self._unprocessed = alldata[currentOffset:]
self._compatibilityOffset = 0
def sendString(self, string):
"""
Send a prefixed string to the other end of the connection.
@param string: The string to send. The necessary framing (length
prefix, etc) will be added.
@type string: C{bytes}
"""
if len(string) >= 2 ** (8 * self.prefixLength):
raise StringTooLongError(
"Try to send %s bytes whereas maximum is %s"
% (len(string), 2 ** (8 * self.prefixLength))
)
self.transport.write(pack(self.structFormat, len(string)) + string)
class Int32StringReceiver(IntNStringReceiver):
"""
A receiver for int32-prefixed strings.
An int32 string is a string prefixed by 4 bytes, the 32-bit length of
the string encoded in network byte order.
This class publishes the same interface as NetstringReceiver.
"""
structFormat = "!I"
prefixLength = calcsize(structFormat)
class Int16StringReceiver(IntNStringReceiver):
"""
A receiver for int16-prefixed strings.
An int16 string is a string prefixed by 2 bytes, the 16-bit length of
the string encoded in network byte order.
This class publishes the same interface as NetstringReceiver.
"""
structFormat = "!H"
prefixLength = calcsize(structFormat)
class Int8StringReceiver(IntNStringReceiver):
"""
A receiver for int8-prefixed strings.
An int8 string is a string prefixed by 1 byte, the 8-bit length of
the string.
This class publishes the same interface as NetstringReceiver.
"""
structFormat = "!B"
prefixLength = calcsize(structFormat)
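# A minimal usage sketch (the class name is illustrative, not part of this
# module): each message is a 2-byte network-order length prefix followed
# by the payload; stringReceived gets the payload with the prefix stripped.
class _Int16EchoExample(Int16StringReceiver):
    def stringReceived(self, string):
        self.sendString(string)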
class StatefulStringProtocol:
"""
A stateful string protocol.
This is a mixin for string protocols (L{Int32StringReceiver},
L{NetstringReceiver}) which translates L{stringReceived} into a callback
(prefixed with C{'proto_'}) depending on state.
The state C{'done'} is special; if a C{proto_*} method returns it, the
connection will be closed immediately.
@ivar state: Current state of the protocol. Defaults to C{'init'}.
@type state: C{str}
"""
state = "init"
def stringReceived(self, string):
"""
Choose a protocol phase function and call it.
Call back to the appropriate protocol phase; this begins with
the function C{proto_init} and moves on to C{proto_*} depending on
what each C{proto_*} function returns. (For example, if
C{self.proto_init} returns 'foo', then C{self.proto_foo} will be the
next function called when a protocol message is received.
"""
try:
pto = "proto_" + self.state
statehandler = getattr(self, pto)
except AttributeError:
log.msg("callback", self.state, "not found")
else:
self.state = statehandler(string)
if self.state == "done":
self.transport.loseConnection()
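# A minimal usage sketch (the class name is illustrative, not part of this
# module): a countdown protocol. proto_init handles the first message;
# each proto_* return value names the next state, and 'done' closes the
# connection.
class _CountdownExample(StatefulStringProtocol, Int32StringReceiver):
    def proto_init(self, string):
        self.count = int(string)
        return "count" if self.count else "done"

    def proto_count(self, string):
        self.count -= 1
        return "count" if self.count else "done"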
@implementer(interfaces.IProducer)
class FileSender:
"""
A producer that sends the contents of a file to a consumer.
This is a helper for protocols that, at some point, will take a
file-like object, read its contents, and write them out to the network,
optionally performing some transformation on the bytes in between.
"""
CHUNK_SIZE = 2**14
lastSent = ""
deferred = None
def beginFileTransfer(self, file, consumer, transform=None):
"""
Begin transferring a file
@type file: Any file-like object
@param file: The file object to read data from
@type consumer: Any implementor of IConsumer
@param consumer: The object to write data to
@param transform: A callable taking one string argument and returning
the same. All bytes read from the file are passed through this before
being written to the consumer.
@rtype: C{Deferred}
@return: A deferred whose callback will be invoked when the file has
been completely written to the consumer. The last byte written to the
consumer is passed to the callback.
"""
self.file = file
self.consumer = consumer
self.transform = transform
self.deferred = deferred = defer.Deferred()
self.consumer.registerProducer(self, False)
return deferred
def resumeProducing(self):
chunk = ""
if self.file:
chunk = self.file.read(self.CHUNK_SIZE)
if not chunk:
self.file = None
self.consumer.unregisterProducer()
if self.deferred:
self.deferred.callback(self.lastSent)
self.deferred = None
return
if self.transform:
chunk = self.transform(chunk)
self.consumer.write(chunk)
self.lastSent = chunk[-1:]
def pauseProducing(self):
pass
def stopProducing(self):
if self.deferred:
self.deferred.errback(Exception("Consumer asked us to stop producing"))
self.deferred = None
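# A minimal usage sketch (illustrative, not part of this module), run from
# inside a connected protocol whose transport implements IConsumer:
#
#     sender = FileSender()
#     d = sender.beginFileTransfer(open("data.bin", "rb"), self.transport)
#     d.addCallback(lambda lastByte: self.transport.loseConnection())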


@@ -0,0 +1,42 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""The Finger User Information Protocol (RFC 1288)"""
from twisted.protocols import basic
class Finger(basic.LineReceiver):
def lineReceived(self, line):
parts = line.split()
if not parts:
parts = [b""]
if len(parts) == 1:
slash_w = 0
else:
slash_w = 1
user = parts[-1]
if b"@" in user:
hostPlace = user.rfind(b"@")
host = user[hostPlace + 1 :]
user = user[:hostPlace]
return self.forwardQuery(slash_w, user, host)
if user:
return self.getUser(slash_w, user)
else:
return self.getDomain(slash_w)
def _refuseMessage(self, message):
self.transport.write(message + b"\n")
self.transport.loseConnection()
def forwardQuery(self, slash_w, user, host):
self._refuseMessage(b"Finger forwarding service denied")
def getDomain(self, slash_w):
self._refuseMessage(b"Finger online list denied")
def getUser(self, slash_w, user):
self.transport.write(b"Login: " + user + b"\n")
self._refuseMessage(b"No such user")

File diff suppressed because it is too large.


@@ -0,0 +1,10 @@
# -*- test-case-name: twisted.protocols.haproxy.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HAProxy PROXY protocol implementations.
"""
__all__ = ["proxyEndpoint"]
from ._wrapper import proxyEndpoint


@@ -0,0 +1,49 @@
# -*- test-case-name: twisted.protocols.haproxy.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HAProxy specific exceptions.
"""
import contextlib
from typing import Callable, Generator, Type
class InvalidProxyHeader(Exception):
"""
The provided PROXY protocol header is invalid.
"""
class InvalidNetworkProtocol(InvalidProxyHeader):
"""
The network protocol was not one of TCP4, TCP6, or UNKNOWN.
"""
class MissingAddressData(InvalidProxyHeader):
"""
The address data is missing or incomplete.
"""
@contextlib.contextmanager
def convertError(
sourceType: Type[BaseException], targetType: Callable[[], BaseException]
) -> Generator[None, None, None]:
"""
Convert an error into a different error type.
@param sourceType: The type of exception that should be caught and
converted.
@type sourceType: L{BaseException}
@param targetType: The type of exception to which the original should be
converted.
@type targetType: L{BaseException}
"""
try:
yield
except sourceType as e:
raise targetType().with_traceback(e.__traceback__)
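# A minimal usage sketch: any ValueError raised inside the block is
# re-raised as an InvalidProxyHeader carrying the original traceback.
#
#     with convertError(ValueError, InvalidProxyHeader):
#         proxyStr, rest = header.split(b" ", 1)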


@@ -0,0 +1,34 @@
# -*- test-case-name: twisted.protocols.haproxy.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
IProxyInfo implementation.
"""
from typing import Optional
from zope.interface import implementer
import attr
from twisted.internet.interfaces import IAddress
from ._interfaces import IProxyInfo
@implementer(IProxyInfo)
@attr.s(frozen=True, slots=True, auto_attribs=True)
class ProxyInfo:
"""
A data container for parsed PROXY protocol information.
@ivar header: The raw header bytes extracted from the connection.
@type header: C{bytes}
@ivar source: The connection source address.
@type source: L{twisted.internet.interfaces.IAddress}
@ivar destination: The connection destination address.
@type destination: L{twisted.internet.interfaces.IAddress}
"""
header: bytes
source: Optional[IAddress]
destination: Optional[IAddress]


@@ -0,0 +1,63 @@
# -*- test-case-name: twisted.protocols.haproxy.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Interfaces used by the PROXY protocol modules.
"""
from typing import Tuple, Union
import zope.interface
class IProxyInfo(zope.interface.Interface):
"""
Data container for PROXY protocol header data.
"""
header = zope.interface.Attribute(
"The raw byestring that represents the PROXY protocol header.",
)
source = zope.interface.Attribute(
"An L{twisted.internet.interfaces.IAddress} representing the "
"connection source."
)
destination = zope.interface.Attribute(
"An L{twisted.internet.interfaces.IAddress} representing the "
"connection destination."
)
class IProxyParser(zope.interface.Interface):
"""
Streaming parser that handles PROXY protocol headers.
"""
def feed(data: bytes) -> Union[Tuple[IProxyInfo, bytes], Tuple[None, None]]:
"""
Consume a chunk of data and attempt to parse it.
@param data: A bytestring.
@type data: bytes
@return: A two-tuple containing, in order, an L{IProxyInfo} and any
bytes fed to the parser that followed the end of the header. Both
of these values are None until a complete header is parsed.
@raises InvalidProxyHeader: If the bytes fed to the parser create an
invalid PROXY header.
"""
def parse(line: bytes) -> IProxyInfo:
"""
Parse a bytestring as a full PROXY protocol header line.
@param line: A bytestring that represents a valid HAProxy PROXY
protocol header line.
@type line: bytes
@return: An L{IProxyInfo} containing the parsed data.
@raises InvalidProxyHeader: If the bytestring does not represent a
valid PROXY header.
"""


@@ -0,0 +1,75 @@
# -*- test-case-name: twisted.protocols.haproxy.test.test_parser -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Parser for 'haproxy:' string endpoint.
"""
from typing import Mapping, Tuple
from zope.interface import implementer
from twisted.internet import interfaces
from twisted.internet.endpoints import (
IStreamServerEndpointStringParser,
_WrapperServerEndpoint,
quoteStringArgument,
serverFromString,
)
from twisted.plugin import IPlugin
from . import proxyEndpoint
def unparseEndpoint(args: Tuple[object, ...], kwargs: Mapping[str, object]) -> str:
"""
Un-parse the already-parsed args and kwargs back into endpoint syntax.
@param args: C{:}-separated arguments
@param kwargs: C{:} and then C{=}-separated keyword arguments
@return: a string equivalent to the original format which this was parsed
as.
"""
description = ":".join(
[quoteStringArgument(str(arg)) for arg in args]
+ sorted(
"{}={}".format(
quoteStringArgument(str(key)), quoteStringArgument(str(value))
)
for key, value in kwargs.items()
)
)
return description
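# For example, unparseEndpoint(("tcp", "8080"), {"backlog": "50"}) returns
# "tcp:8080:backlog=50".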
@implementer(IPlugin, IStreamServerEndpointStringParser)
class HAProxyServerParser:
"""
Stream server endpoint string parser for the HAProxyServerEndpoint type.
@ivar prefix: See L{IStreamServerEndpointStringParser.prefix}.
"""
prefix = "haproxy"
def parseStreamServer(
self, reactor: interfaces.IReactorCore, *args: object, **kwargs: object
) -> _WrapperServerEndpoint:
"""
Parse a stream server endpoint from a reactor and string-only arguments
and keyword arguments.
@param reactor: The reactor.
@param args: The parsed string arguments.
@param kwargs: The parsed keyword arguments.
@return: a stream server endpoint
@rtype: L{IStreamServerEndpoint}
"""
subdescription = unparseEndpoint(args, kwargs)
wrappedEndpoint = serverFromString(reactor, subdescription)
return proxyEndpoint(wrappedEndpoint)


@@ -0,0 +1,142 @@
# -*- test-case-name: twisted.protocols.haproxy.test.test_v1parser -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
IProxyParser implementation for version one of the PROXY protocol.
"""
from typing import Tuple, Union
from zope.interface import implementer
from twisted.internet import address
from . import _info, _interfaces
from ._exceptions import (
InvalidNetworkProtocol,
InvalidProxyHeader,
MissingAddressData,
convertError,
)
@implementer(_interfaces.IProxyParser)
class V1Parser:
"""
PROXY protocol version one header parser.
Version one of the PROXY protocol is a human-readable format represented
by a single, newline-delimited bytestring that contains all of the
relevant source and destination data.
"""
PROXYSTR = b"PROXY"
UNKNOWN_PROTO = b"UNKNOWN"
TCP4_PROTO = b"TCP4"
TCP6_PROTO = b"TCP6"
ALLOWED_NET_PROTOS = (
TCP4_PROTO,
TCP6_PROTO,
UNKNOWN_PROTO,
)
NEWLINE = b"\r\n"
def __init__(self) -> None:
self.buffer = b""
def feed(
self, data: bytes
) -> Union[Tuple[_info.ProxyInfo, bytes], Tuple[None, None]]:
"""
Consume a chunk of data and attempt to parse it.
@param data: A bytestring.
@type data: L{bytes}
@return: A two-tuple containing, in order, a
L{_interfaces.IProxyInfo} and any bytes fed to the
parser that followed the end of the header. Both of these values
are None until a complete header is parsed.
@raises InvalidProxyHeader: If the bytes fed to the parser create an
invalid PROXY header.
"""
self.buffer += data
if len(self.buffer) > 107 and self.NEWLINE not in self.buffer:
raise InvalidProxyHeader()
lines = (self.buffer).split(self.NEWLINE, 1)
if not len(lines) > 1:
return (None, None)
self.buffer = b""
remaining = lines.pop()
header = lines.pop()
info = self.parse(header)
return (info, remaining)
@classmethod
def parse(cls, line: bytes) -> _info.ProxyInfo:
"""
Parse a bytestring as a full PROXY protocol header line.
@param line: A bytestring that represents a valid HAProxy PROXY
protocol header line.
@type line: bytes
@return: A L{_interfaces.IProxyInfo} containing the parsed data.
@raises InvalidProxyHeader: If the bytestring does not represent a
valid PROXY header.
@raises InvalidNetworkProtocol: When no protocol can be parsed or is
not one of the allowed values.
@raises MissingAddressData: When the protocol is TCP* but the header
does not contain a complete set of addresses and ports.
"""
originalLine = line
proxyStr = None
networkProtocol = None
sourceAddr = None
sourcePort = None
destAddr = None
destPort = None
with convertError(ValueError, InvalidProxyHeader):
proxyStr, line = line.split(b" ", 1)
if proxyStr != cls.PROXYSTR:
raise InvalidProxyHeader()
with convertError(ValueError, InvalidNetworkProtocol):
networkProtocol, line = line.split(b" ", 1)
if networkProtocol not in cls.ALLOWED_NET_PROTOS:
raise InvalidNetworkProtocol()
if networkProtocol == cls.UNKNOWN_PROTO:
return _info.ProxyInfo(originalLine, None, None)
with convertError(ValueError, MissingAddressData):
sourceAddr, line = line.split(b" ", 1)
with convertError(ValueError, MissingAddressData):
destAddr, line = line.split(b" ", 1)
with convertError(ValueError, MissingAddressData):
sourcePort, line = line.split(b" ", 1)
with convertError(ValueError, MissingAddressData):
destPort = line.split(b" ")[0]
if networkProtocol == cls.TCP4_PROTO:
return _info.ProxyInfo(
originalLine,
address.IPv4Address("TCP", sourceAddr.decode(), int(sourcePort)),
address.IPv4Address("TCP", destAddr.decode(), int(destPort)),
)
return _info.ProxyInfo(
originalLine,
address.IPv6Address("TCP", sourceAddr.decode(), int(sourcePort)),
address.IPv6Address("TCP", destAddr.decode(), int(destPort)),
)


@@ -0,0 +1,217 @@
# -*- test-case-name: twisted.protocols.haproxy.test.test_v2parser -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
IProxyParser implementation for version two of the PROXY protocol.
"""
import binascii
import struct
from typing import Callable, Tuple, Type, Union
from zope.interface import implementer
from constantly import ValueConstant, Values
from typing_extensions import Literal
from twisted.internet import address
from twisted.python import compat
from . import _info, _interfaces
from ._exceptions import (
InvalidNetworkProtocol,
InvalidProxyHeader,
MissingAddressData,
convertError,
)
class NetFamily(Values):
"""
Values for the 'family' field.
"""
UNSPEC = ValueConstant(0x00)
INET = ValueConstant(0x10)
INET6 = ValueConstant(0x20)
UNIX = ValueConstant(0x30)
class NetProtocol(Values):
"""
Values for 'protocol' field.
"""
UNSPEC = ValueConstant(0)
STREAM = ValueConstant(1)
DGRAM = ValueConstant(2)
_HIGH = 0b11110000
_LOW = 0b00001111
_LOCALCOMMAND = "LOCAL"
_PROXYCOMMAND = "PROXY"
@implementer(_interfaces.IProxyParser)
class V2Parser:
"""
PROXY protocol version two header parser.
Version two of the PROXY protocol is a binary format.
"""
PREFIX = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
VERSIONS = [32]
COMMANDS = {0: _LOCALCOMMAND, 1: _PROXYCOMMAND}
ADDRESSFORMATS = {
# TCP4
17: "!4s4s2H",
18: "!4s4s2H",
# TCP6
33: "!16s16s2H",
34: "!16s16s2H",
# UNIX
49: "!108s108s",
50: "!108s108s",
}
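# The fixed part of a version 2 header is the 12-byte signature, one
# version/command byte, one family/protocol byte, and a 2-byte
# network-order payload length; the address payload described by
# ADDRESSFORMATS follows immediately after.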
def __init__(self) -> None:
self.buffer = b""
def feed(
self, data: bytes
) -> Union[Tuple[_info.ProxyInfo, bytes], Tuple[None, None]]:
"""
Consume a chunk of data and attempt to parse it.
@param data: A bytestring.
@type data: bytes
@return: A two-tuple containing, in order, a L{_interfaces.IProxyInfo}
and any bytes fed to the parser that followed the end of the
header. Both of these values are None until a complete header is
parsed.
@raises InvalidProxyHeader: If the bytes fed to the parser create an
invalid PROXY header.
"""
self.buffer += data
if len(self.buffer) < 16:
raise InvalidProxyHeader()
size = struct.unpack("!H", self.buffer[14:16])[0] + 16
if len(self.buffer) < size:
return (None, None)
header, remaining = self.buffer[:size], self.buffer[size:]
self.buffer = b""
info = self.parse(header)
return (info, remaining)
@staticmethod
def _bytesToIPv4(bytestring: bytes) -> bytes:
"""
Convert packed 32-bit IPv4 address bytes into a dotted-quad ASCII bytes
representation of that address.
@param bytestring: 4 octets representing an IPv4 address.
@type bytestring: L{bytes}
@return: a dotted-quad notation IPv4 address.
@rtype: L{bytes}
"""
return b".".join(
("%i" % (ord(b),)).encode("ascii") for b in compat.iterbytes(bytestring)
)
@staticmethod
def _bytesToIPv6(bytestring: bytes) -> bytes:
"""
Convert packed 128-bit IPv6 address bytes into a colon-separated ASCII
bytes representation of that address.
@param bytestring: 16 octets representing an IPv6 address.
@type bytestring: L{bytes}
@return: a colon-separated hexadecimal notation IPv6 address.
@rtype: L{bytes}
"""
hexString = binascii.b2a_hex(bytestring)
return b":".join(
(f"{int(hexString[b : b + 4], 16):x}").encode("ascii")
for b in range(0, 32, 4)
)
@classmethod
def parse(cls, line: bytes) -> _info.ProxyInfo:
"""
Parse a bytestring as a full PROXY protocol header.
@param line: A bytestring that represents a valid HAProxy PROXY
protocol version 2 header.
@type line: bytes
@return: A L{_interfaces.IProxyInfo} containing the
parsed data.
@raises InvalidProxyHeader: If the bytestring does not represent a
valid PROXY header.
"""
prefix = line[:12]
addrInfo = None
with convertError(IndexError, InvalidProxyHeader):
# Use single value slices to ensure bytestring values are returned
# instead of int in PY3.
versionCommand = ord(line[12:13])
familyProto = ord(line[13:14])
if prefix != cls.PREFIX:
raise InvalidProxyHeader()
version, command = versionCommand & _HIGH, versionCommand & _LOW
if version not in cls.VERSIONS or command not in cls.COMMANDS:
raise InvalidProxyHeader()
if cls.COMMANDS[command] == _LOCALCOMMAND:
return _info.ProxyInfo(line, None, None)
family, netproto = familyProto & _HIGH, familyProto & _LOW
with convertError(ValueError, InvalidNetworkProtocol):
family = NetFamily.lookupByValue(family)
netproto = NetProtocol.lookupByValue(netproto)
if family is NetFamily.UNSPEC or netproto is NetProtocol.UNSPEC:
return _info.ProxyInfo(line, None, None)
addressFormat = cls.ADDRESSFORMATS[familyProto]
addrInfo = line[16 : 16 + struct.calcsize(addressFormat)]
if family is NetFamily.UNIX:
with convertError(struct.error, MissingAddressData):
source, dest = struct.unpack(addressFormat, addrInfo)
return _info.ProxyInfo(
line,
address.UNIXAddress(source.rstrip(b"\x00")),
address.UNIXAddress(dest.rstrip(b"\x00")),
)
addrType: Union[Literal["TCP"], Literal["UDP"]] = "TCP"
if netproto is NetProtocol.DGRAM:
addrType = "UDP"
addrCls: Union[
Type[address.IPv4Address], Type[address.IPv6Address]
] = address.IPv4Address
addrParser: Callable[[bytes], bytes] = cls._bytesToIPv4
if family is NetFamily.INET6:
addrCls = address.IPv6Address
addrParser = cls._bytesToIPv6
with convertError(struct.error, MissingAddressData):
info = struct.unpack(addressFormat, addrInfo)
source, dest, sPort, dPort = info
return _info.ProxyInfo(
line,
addrCls(addrType, addrParser(source).decode(), sPort),
addrCls(addrType, addrParser(dest).decode(), dPort),
)


@@ -0,0 +1,109 @@
# -*- test-case-name: twisted.protocols.haproxy.test.test_wrapper -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Protocol wrapper that provides HAProxy PROXY protocol support.
"""
from typing import Optional, Union
from twisted.internet import interfaces
from twisted.internet.endpoints import _WrapperServerEndpoint
from twisted.protocols import policies
from . import _info
from ._exceptions import InvalidProxyHeader
from ._v1parser import V1Parser
from ._v2parser import V2Parser
class HAProxyProtocolWrapper(policies.ProtocolWrapper):
"""
A Protocol wrapper that provides HAProxy support.
This protocol reads the PROXY stream header, v1 or v2, parses the provided
connection data, and modifies the behavior of getPeer and getHost to return
the data provided by the PROXY header.
"""
def __init__(
self, factory: policies.WrappingFactory, wrappedProtocol: interfaces.IProtocol
):
super().__init__(factory, wrappedProtocol)
self._proxyInfo: Optional[_info.ProxyInfo] = None
self._parser: Union[V2Parser, V1Parser, None] = None
def dataReceived(self, data: bytes) -> None:
if self._proxyInfo is not None:
return self.wrappedProtocol.dataReceived(data)
parser = self._parser
if parser is None:
if (
len(data) >= 16
and data[:12] == V2Parser.PREFIX
and ord(data[12:13]) & 0b11110000 == 0x20
):
self._parser = parser = V2Parser()
elif len(data) >= 8 and data[:5] == V1Parser.PROXYSTR:
self._parser = parser = V1Parser()
else:
self.loseConnection()
return None
try:
self._proxyInfo, remaining = parser.feed(data)
if remaining:
self.wrappedProtocol.dataReceived(remaining)
except InvalidProxyHeader:
self.loseConnection()
def getPeer(self) -> interfaces.IAddress:
if self._proxyInfo and self._proxyInfo.source:
return self._proxyInfo.source
assert self.transport
return self.transport.getPeer()
def getHost(self) -> interfaces.IAddress:
if self._proxyInfo and self._proxyInfo.destination:
return self._proxyInfo.destination
assert self.transport
return self.transport.getHost()
class HAProxyWrappingFactory(policies.WrappingFactory):
"""
A Factory wrapper that adds PROXY protocol support to connections.
"""
protocol = HAProxyProtocolWrapper
def logPrefix(self) -> str:
"""
Annotate the wrapped factory's log prefix with some text indicating
the PROXY protocol is in use.
@rtype: C{str}
"""
if interfaces.ILoggingContext.providedBy(self.wrappedFactory):
logPrefix = self.wrappedFactory.logPrefix()
else:
logPrefix = self.wrappedFactory.__class__.__name__
return f"{logPrefix} (PROXY)"
def proxyEndpoint(
wrappedEndpoint: interfaces.IStreamServerEndpoint,
) -> _WrapperServerEndpoint:
"""
Wrap an endpoint with PROXY protocol support, so that the transport's
C{getHost} and C{getPeer} methods reflect the attributes of the proxied
connection rather than the underlying connection.
@param wrappedEndpoint: The underlying listening endpoint.
@type wrappedEndpoint: L{IStreamServerEndpoint}
@return: a new listening endpoint that speaks the PROXY protocol.
@rtype: L{IStreamServerEndpoint}
"""
return _WrapperServerEndpoint(wrappedEndpoint, HAProxyWrappingFactory)
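# A minimal usage sketch (names are illustrative, not part of this module):
#
#     from twisted.internet.endpoints import TCP4ServerEndpoint
#     endpoint = proxyEndpoint(TCP4ServerEndpoint(reactor, 8080))
#     endpoint.listen(someProtocolFactory)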


@@ -0,0 +1,7 @@
# -*- test-case-name: twisted.protocols.haproxy.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Unit tests for L{twisted.protocols.haproxy}.
"""


@@ -0,0 +1,133 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.protocols.haproxy._parser}.
"""
from typing import Type, Union
from twisted.internet.endpoints import (
TCP4ServerEndpoint,
TCP6ServerEndpoint,
UNIXServerEndpoint,
_parse as parseEndpoint,
_WrapperServerEndpoint,
serverFromString,
)
from twisted.internet.testing import MemoryReactor
from twisted.trial.unittest import SynchronousTestCase as TestCase
from .._parser import unparseEndpoint
from .._wrapper import HAProxyWrappingFactory
class UnparseEndpointTests(TestCase):
"""
Tests to ensure that un-parsing an endpoint string round trips through
escaping properly.
"""
def check(self, input: str) -> None:
"""
Check that the input unparses into the output, raising an assertion
error if it doesn't.
@param input: an input in endpoint-string-description format. (To
ensure determinism, keyword arguments should be in alphabetical
order.)
@type input: native L{str}
"""
self.assertEqual(unparseEndpoint(*parseEndpoint(input)), input)
def test_basicUnparse(self) -> None:
"""
An individual word.
"""
self.check("word")
def test_multipleArguments(self) -> None:
"""
Multiple arguments.
"""
self.check("one:two")
def test_keywords(self) -> None:
"""
Keyword arguments.
"""
self.check("aleph=one:bet=two")
def test_colonInArgument(self) -> None:
"""
Escaped ":".
"""
self.check("hello\\:colon\\:world")
def test_colonInKeywordValue(self) -> None:
"""
Escaped ":" in keyword value.
"""
self.check("hello=\\:")
def test_colonInKeywordName(self) -> None:
"""
Escaped ":" in keyword name.
"""
self.check("\\:=hello")
class HAProxyServerParserTests(TestCase):
"""
Tests that the parser generates the correct endpoints.
"""
def onePrefix(
self,
description: str,
expectedClass: Union[
Type[TCP4ServerEndpoint],
Type[TCP6ServerEndpoint],
Type[UNIXServerEndpoint],
],
) -> _WrapperServerEndpoint:
"""
Test the C{haproxy} endpoint prefix against one sub-endpoint type.
@param description: A string endpoint description beginning with
C{haproxy}.
@type description: native L{str}
@param expectedClass: the expected sub-endpoint class given the
description.
@type expectedClass: L{type}
@return: the parsed endpoint
@rtype: L{IStreamServerEndpoint}
@raise twisted.trial.unittest.FailTest: if the parsed endpoint doesn't
match expectations.
"""
reactor = MemoryReactor()
endpoint = serverFromString(reactor, description)
self.assertIsInstance(endpoint, _WrapperServerEndpoint)
assert isinstance(endpoint, _WrapperServerEndpoint)
self.assertIsInstance(endpoint._wrappedEndpoint, expectedClass)
self.assertIs(endpoint._wrapperFactory, HAProxyWrappingFactory)
return endpoint
def test_tcp4(self) -> None:
"""
Test if the parser generates a wrapped TCP4 endpoint.
"""
self.onePrefix("haproxy:tcp:8080", TCP4ServerEndpoint)
def test_tcp6(self) -> None:
"""
Test if the parser generates a wrapped TCP6 endpoint.
"""
self.onePrefix("haproxy:tcp6:8080", TCP6ServerEndpoint)
def test_unix(self) -> None:
"""
Test if the parser generates a wrapped UNIX endpoint.
"""
self.onePrefix("haproxy:unix:address=/tmp/socket", UNIXServerEndpoint)


@@ -0,0 +1,149 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{twisted.protocols.haproxy.V1Parser}.
"""
from twisted.internet import address
from twisted.trial import unittest
from .. import _v1parser
from .._exceptions import InvalidNetworkProtocol, InvalidProxyHeader, MissingAddressData
class V1ParserTests(unittest.TestCase):
"""
Test L{twisted.protocols.haproxy.V1Parser} behaviour.
"""
def test_missingPROXYHeaderValue(self) -> None:
"""
Test that an exception is raised when the PROXY header is missing.
"""
self.assertRaises(
InvalidProxyHeader,
_v1parser.V1Parser.parse,
b"NOTPROXY ",
)
def test_invalidNetworkProtocol(self) -> None:
"""
Test that an exception is raised when the proto is not TCP or UNKNOWN.
"""
self.assertRaises(
InvalidNetworkProtocol,
_v1parser.V1Parser.parse,
b"PROXY WUTPROTO ",
)
def test_missingSourceData(self) -> None:
"""
Test that an exception is raised when the proto has no source data.
"""
self.assertRaises(
MissingAddressData,
_v1parser.V1Parser.parse,
b"PROXY TCP4 ",
)
def test_missingDestData(self) -> None:
"""
Test that an exception is raised when the proto has no destination.
"""
self.assertRaises(
MissingAddressData,
_v1parser.V1Parser.parse,
b"PROXY TCP4 127.0.0.1 8080 8888",
)
def test_fullParsingSuccess(self) -> None:
"""
Test that parsing is successful for a PROXY header.
"""
info = _v1parser.V1Parser.parse(
b"PROXY TCP4 127.0.0.1 127.0.0.1 8080 8888",
)
self.assertIsInstance(info.source, address.IPv4Address)
assert isinstance(info.source, address.IPv4Address)
assert isinstance(info.destination, address.IPv4Address) # type: ignore[unreachable]
self.assertEqual(info.source.host, "127.0.0.1")
self.assertEqual(info.source.port, 8080)
self.assertEqual(info.destination.host, "127.0.0.1")
self.assertEqual(info.destination.port, 8888)
def test_fullParsingSuccess_IPv6(self) -> None:
"""
Test that parsing is successful for an IPv6 PROXY header.
"""
info = _v1parser.V1Parser.parse(
b"PROXY TCP6 ::1 ::1 8080 8888",
)
self.assertIsInstance(info.source, address.IPv6Address)
assert isinstance(info.source, address.IPv6Address)
assert isinstance(info.destination, address.IPv6Address) # type: ignore[unreachable]
self.assertEqual(info.source.host, "::1")
self.assertEqual(info.source.port, 8080)
self.assertEqual(info.destination.host, "::1")
self.assertEqual(info.destination.port, 8888)
def test_fullParsingSuccess_UNKNOWN(self) -> None:
"""
Test that parsing is successful for an UNKNOWN PROXY header.
"""
info = _v1parser.V1Parser.parse(
b"PROXY UNKNOWN anything could go here",
)
self.assertIsNone(info.source)
self.assertIsNone(info.destination)
def test_feedParsing(self) -> None:
"""
Test that parsing happens when fed a complete line.
"""
parser = _v1parser.V1Parser()
info, remaining = parser.feed(b"PROXY TCP4 127.0.0.1 127.0.0.1 ")
self.assertFalse(info)
self.assertFalse(remaining)
info, remaining = parser.feed(b"8080 8888")
self.assertFalse(info)
self.assertFalse(remaining)
info, remaining = parser.feed(b"\r\n")
self.assertFalse(remaining)
assert remaining is not None
assert info is not None
self.assertIsInstance(info.source, address.IPv4Address)
assert isinstance(info.source, address.IPv4Address)
assert isinstance(info.destination, address.IPv4Address) # type: ignore[unreachable]
self.assertEqual(info.source.host, "127.0.0.1")
self.assertEqual(info.source.port, 8080)
self.assertEqual(info.destination.host, "127.0.0.1")
self.assertEqual(info.destination.port, 8888)
def test_feedParsingTooLong(self) -> None:
"""
Test that parsing fails if no newline is found in 108 bytes.
"""
parser = _v1parser.V1Parser()
info, remaining = parser.feed(b"PROXY TCP4 127.0.0.1 127.0.0.1 ")
self.assertFalse(info)
self.assertFalse(remaining)
info, remaining = parser.feed(b"8080 8888")
self.assertFalse(info)
self.assertFalse(remaining)
self.assertRaises(
InvalidProxyHeader,
parser.feed,
b" " * 100,
)
def test_feedParsingOverflow(self) -> None:
"""
Test that parsing leaves overflow bytes in the buffer.
"""
parser = _v1parser.V1Parser()
info, remaining = parser.feed(
b"PROXY TCP4 127.0.0.1 127.0.0.1 8080 8888\r\nHTTP/1.1 GET /\r\n",
)
self.assertTrue(info)
self.assertEqual(remaining, b"HTTP/1.1 GET /\r\n")
self.assertFalse(parser.buffer)


@@ -0,0 +1,368 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{twisted.protocols.haproxy.V2Parser}.
"""
from twisted.internet import address
from twisted.trial import unittest
from .. import _v2parser
from .._exceptions import InvalidProxyHeader
V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
def _makeHeaderIPv6(
sig: bytes = V2_SIGNATURE,
verCom: bytes = b"\x21",
famProto: bytes = b"\x21",
addrLength: bytes = b"\x00\x24",
addrs: bytes = ((b"\x00" * 15) + b"\x01") * 2,
ports: bytes = b"\x1F\x90\x22\xB8",
) -> bytes:
"""
Construct a version 2 IPv6 header with custom bytes.
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
@type sig: L{bytes}
@param verCom: Protocol version and command. Defaults to V2 PROXY.
@type verCom: L{bytes}
@param famProto: Address family and protocol. Defaults to AF_INET6/STREAM.
@type famProto: L{bytes}
@param addrLength: Network-endian byte length of the payload. Defaults
to the combined length of the default addrs and ports.
@type addrLength: L{bytes}
@param addrs: Address payload. Defaults to C{::1} for source and
destination.
@type addrs: L{bytes}
@param ports: Source and destination ports. Defaults to 8080 for source
and 8888 for destination.
@type ports: L{bytes}
@return: A packet with header, addresses, and ports.
@rtype: L{bytes}
"""
return sig + verCom + famProto + addrLength + addrs + ports
def _makeHeaderIPv4(
sig: bytes = V2_SIGNATURE,
verCom: bytes = b"\x21",
famProto: bytes = b"\x11",
addrLength: bytes = b"\x00\x0C",
addrs: bytes = b"\x7F\x00\x00\x01\x7F\x00\x00\x01",
ports: bytes = b"\x1F\x90\x22\xB8",
) -> bytes:
"""
Construct a version 2 IPv4 header with custom bytes.
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
@type sig: L{bytes}
@param verCom: Protocol version and command. Defaults to V2 PROXY.
@type verCom: L{bytes}
@param famProto: Address family and protocol. Defaults to AF_INET/STREAM.
@type famProto: L{bytes}
@param addrLength: Network-endian byte length of the payload. Defaults
to the combined length of the default addrs and ports.
@type addrLength: L{bytes}
@param addrs: Address payload. Defaults to 127.0.0.1 for source and
destination.
@type addrs: L{bytes}
@param ports: Source and destination ports. Defaults to 8080 for source
and 8888 for destination.
@type ports: L{bytes}
@return: A packet with header, addresses, and ports.
@rtype: L{bytes}
"""
return sig + verCom + famProto + addrLength + addrs + ports
def _makeHeaderUnix(
sig: bytes = V2_SIGNATURE,
verCom: bytes = b"\x21",
famProto: bytes = b"\x31",
addrLength: bytes = b"\x00\xD8",
addrs: bytes = (
b"\x2F\x68\x6F\x6D\x65\x2F\x74\x65\x73\x74\x73\x2F"
b"\x6D\x79\x73\x6F\x63\x6B\x65\x74\x73\x2F\x73\x6F"
b"\x63\x6B" + (b"\x00" * 82)
)
* 2,
) -> bytes:
"""
Construct a version 2 UNIX header with custom bytes.
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
@type sig: L{bytes}
@param verCom: Protocol version and command. Defaults to V2 PROXY.
@type verCom: L{bytes}
@param famProto: Address family and protocol. Defaults to AF_UNIX/STREAM.
@type famProto: L{bytes}
@param addrLength: Network-endian byte length of the payload. Defaults
to 216 bytes for two null-padded 108-byte paths.
@type addrLength: L{bytes}
@param addrs: Address payload. Defaults to C{/home/tests/mysockets/sock}
for source and destination paths.
@type addrs: L{bytes}
@return: A packet with header and addresses.
@rtype: L{bytes}
"""
return sig + verCom + famProto + addrLength + addrs
class V2ParserTests(unittest.TestCase):
"""
Test L{twisted.protocols.haproxy.V2Parser} behaviour.
"""
def test_happyPathIPv4(self) -> None:
"""
Test if a well formed IPv4 header is parsed without error.
"""
header = _makeHeaderIPv4()
self.assertTrue(_v2parser.V2Parser.parse(header))
def test_happyPathIPv6(self) -> None:
"""
Test if a well formed IPv6 header is parsed without error.
"""
header = _makeHeaderIPv6()
self.assertTrue(_v2parser.V2Parser.parse(header))
def test_happyPathUnix(self) -> None:
"""
Test if a well formed UNIX header is parsed without error.
"""
header = _makeHeaderUnix()
self.assertTrue(_v2parser.V2Parser.parse(header))
def test_invalidSignature(self) -> None:
"""
Test if an invalid signature block raises InvalidProxyHeader.
"""
header = _makeHeaderIPv4(sig=b"\x00" * 12)
self.assertRaises(
InvalidProxyHeader,
_v2parser.V2Parser.parse,
header,
)
def test_invalidVersion(self) -> None:
"""
Test if an invalid version raises InvalidProxyHeader.
"""
header = _makeHeaderIPv4(verCom=b"\x11")
self.assertRaises(
InvalidProxyHeader,
_v2parser.V2Parser.parse,
header,
)
def test_invalidCommand(self) -> None:
"""
Test if an invalid command raises InvalidProxyHeader.
"""
header = _makeHeaderIPv4(verCom=b"\x23")
self.assertRaises(
InvalidProxyHeader,
_v2parser.V2Parser.parse,
header,
)
def test_invalidFamily(self) -> None:
"""
Test if an invalid family raises InvalidProxyHeader.
"""
header = _makeHeaderIPv4(famProto=b"\x40")
self.assertRaises(
InvalidProxyHeader,
_v2parser.V2Parser.parse,
header,
)
def test_invalidProto(self) -> None:
"""
Test if an invalid protocol raises InvalidProxyHeader.
"""
header = _makeHeaderIPv4(famProto=b"\x24")
self.assertRaises(
InvalidProxyHeader,
_v2parser.V2Parser.parse,
header,
)
def test_localCommandIpv4(self) -> None:
"""
Test that local does not return endpoint data for IPv4 connections.
"""
header = _makeHeaderIPv4(verCom=b"\x20")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_localCommandIpv6(self) -> None:
"""
Test that local does not return endpoint data for IPv6 connections.
"""
header = _makeHeaderIPv6(verCom=b"\x20")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_localCommandUnix(self) -> None:
"""
Test that local does not return endpoint data for UNIX connections.
"""
header = _makeHeaderUnix(verCom=b"\x20")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_proxyCommandIpv4(self) -> None:
"""
Test that proxy returns endpoint data for IPv4 connections.
"""
header = _makeHeaderIPv4(verCom=b"\x21")
info = _v2parser.V2Parser.parse(header)
self.assertTrue(info.source)
self.assertIsInstance(info.source, address.IPv4Address)
self.assertTrue(info.destination)
self.assertIsInstance(info.destination, address.IPv4Address)
def test_proxyCommandIpv6(self) -> None:
"""
Test that proxy returns endpoint data for IPv6 connections.
"""
header = _makeHeaderIPv6(verCom=b"\x21")
info = _v2parser.V2Parser.parse(header)
self.assertTrue(info.source)
self.assertIsInstance(info.source, address.IPv6Address)
self.assertTrue(info.destination)
self.assertIsInstance(info.destination, address.IPv6Address)
def test_proxyCommandUnix(self) -> None:
"""
Test that proxy returns endpoint data for UNIX connections.
"""
header = _makeHeaderUnix(verCom=b"\x21")
info = _v2parser.V2Parser.parse(header)
self.assertTrue(info.source)
self.assertIsInstance(info.source, address.UNIXAddress)
self.assertTrue(info.destination)
self.assertIsInstance(info.destination, address.UNIXAddress)
def test_unspecFamilyIpv4(self) -> None:
"""
Test that UNSPEC does not return endpoint data for IPv4 connections.
"""
header = _makeHeaderIPv4(famProto=b"\x01")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_unspecFamilyIpv6(self) -> None:
"""
Test that UNSPEC does not return endpoint data for IPv6 connections.
"""
header = _makeHeaderIPv6(famProto=b"\x01")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_unspecFamilyUnix(self) -> None:
"""
Test that UNSPEC does not return endpoint data for UNIX connections.
"""
header = _makeHeaderUnix(famProto=b"\x01")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_unspecProtoIpv4(self) -> None:
"""
Test that UNSPEC does not return endpoint data for IPv4 connections.
"""
header = _makeHeaderIPv4(famProto=b"\x10")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_unspecProtoIpv6(self) -> None:
"""
Test that UNSPEC does not return endpoint data for IPv6 connections.
"""
header = _makeHeaderIPv6(famProto=b"\x20")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_unspecProtoUnix(self) -> None:
"""
Test that UNSPEC does not return endpoint data for UNIX connections.
"""
header = _makeHeaderUnix(famProto=b"\x30")
info = _v2parser.V2Parser.parse(header)
self.assertFalse(info.source)
self.assertFalse(info.destination)
def test_overflowIpv4(self) -> None:
"""
Test that overflow bits are preserved during feed parsing for IPv4.
"""
testValue = b"TEST DATA\r\n\r\nTEST DATA"
header = _makeHeaderIPv4() + testValue
parser = _v2parser.V2Parser()
info, overflow = parser.feed(header)
self.assertTrue(info)
self.assertEqual(overflow, testValue)
def test_overflowIpv6(self) -> None:
"""
Test that overflow bits are preserved during feed parsing for IPv6.
"""
testValue = b"TEST DATA\r\n\r\nTEST DATA"
header = _makeHeaderIPv6() + testValue
parser = _v2parser.V2Parser()
info, overflow = parser.feed(header)
self.assertTrue(info)
self.assertEqual(overflow, testValue)
def test_overflowUnix(self) -> None:
"""
Test that overflow bits are preserved during feed parsing for Unix.
"""
testValue = b"TEST DATA\r\n\r\nTEST DATA"
header = _makeHeaderUnix() + testValue
parser = _v2parser.V2Parser()
info, overflow = parser.feed(header)
self.assertTrue(info)
self.assertEqual(overflow, testValue)
def test_segmentTooSmall(self) -> None:
"""
Test that an initial payload of less than 16 bytes fails.
"""
testValue = b"NEEDMOREDATA"
parser = _v2parser.V2Parser()
self.assertRaises(
InvalidProxyHeader,
parser.feed,
testValue,
)

View File

@@ -0,0 +1,375 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{twisted.protocols.haproxy.HAProxyProtocol}.
"""
from typing import Optional
from unittest import mock
from twisted.internet import address
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.testing import StringTransportWithDisconnection
from twisted.trial import unittest
from .._wrapper import HAProxyWrappingFactory
class StaticProtocol(Protocol):
"""
Protocol stand-in that maintains test state.
"""
def __init__(self) -> None:
self.source: Optional[address.IAddress] = None
self.destination: Optional[address.IAddress] = None
self.data = b""
self.disconnected = False
def dataReceived(self, data: bytes) -> None:
assert self.transport
self.source = self.transport.getPeer()
self.destination = self.transport.getHost()
self.data += data
class HAProxyWrappingFactoryV1Tests(unittest.TestCase):
"""
Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v1 PROXY
headers.
"""
def test_invalidHeaderDisconnects(self) -> None:
"""
Test if invalid headers result in connectionLost events.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv4Address("TCP", "127.1.1.1", 8080),
)
transport = StringTransportWithDisconnection()
transport.protocol = proto
proto.makeConnection(transport)
proto.dataReceived(b"NOTPROXY anything can go here\r\n")
self.assertFalse(transport.connected)
def test_invalidPartialHeaderDisconnects(self) -> None:
"""
Test if invalid headers result in connectionLost events.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv4Address("TCP", "127.1.1.1", 8080),
)
transport = StringTransportWithDisconnection()
transport.protocol = proto
proto.makeConnection(transport)
proto.dataReceived(b"PROXY TCP4 1.1.1.1\r\n")
proto.dataReceived(b"2.2.2.2 8080\r\n")
self.assertFalse(transport.connected)
def test_preDataReceived_getPeerHost(self) -> None:
"""
Before any data is received the HAProxy protocol will return the same peer
and host as the IP connection.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv4Address("TCP", "127.0.0.1", 8080),
)
transport = StringTransportWithDisconnection(
hostAddress=mock.sentinel.host_address,
peerAddress=mock.sentinel.peer_address,
)
proto.makeConnection(transport)
self.assertEqual(proto.getHost(), mock.sentinel.host_address)
self.assertEqual(proto.getPeer(), mock.sentinel.peer_address)
def test_validIPv4HeaderResolves_getPeerHost(self) -> None:
"""
Test if IPv4 headers result in the correct host and peer data.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv4Address("TCP", "127.0.0.1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(b"PROXY TCP4 1.1.1.1 2.2.2.2 8080 8888\r\n")
self.assertEqual(proto.getPeer().host, "1.1.1.1")
self.assertEqual(proto.getPeer().port, 8080)
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().host,
"1.1.1.1",
)
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().port,
8080,
)
self.assertEqual(proto.getHost().host, "2.2.2.2")
self.assertEqual(proto.getHost().port, 8888)
self.assertEqual(
proto.wrappedProtocol.transport.getHost().host,
"2.2.2.2",
)
self.assertEqual(
proto.wrappedProtocol.transport.getHost().port,
8888,
)
def test_validIPv6HeaderResolves_getPeerHost(self) -> None:
"""
Test if IPv6 headers result in the correct host and peer data.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv6Address("TCP", "::1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(b"PROXY TCP6 ::1 ::2 8080 8888\r\n")
self.assertEqual(proto.getPeer().host, "::1")
self.assertEqual(proto.getPeer().port, 8080)
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().host,
"::1",
)
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().port,
8080,
)
self.assertEqual(proto.getHost().host, "::2")
self.assertEqual(proto.getHost().port, 8888)
self.assertEqual(
proto.wrappedProtocol.transport.getHost().host,
"::2",
)
self.assertEqual(
proto.wrappedProtocol.transport.getHost().port,
8888,
)
def test_overflowBytesSentToWrappedProtocol(self) -> None:
"""
Test if non-header bytes are passed to the wrapped protocol.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv6Address("TCP", "::1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(b"PROXY TCP6 ::1 ::2 8080 8888\r\nHTTP/1.1 / GET")
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")
def test_overflowBytesSentToWrappedProtocolChunks(self) -> None:
"""
Test if header streaming passes extra data appropriately.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv6Address("TCP", "::1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(b"PROXY TCP6 ::1 ::2 ")
proto.dataReceived(b"8080 8888\r\nHTTP/1.1 / GET")
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")
def test_overflowBytesSentToWrappedProtocolAfter(self) -> None:
"""
Test if wrapper writes all data to wrapped protocol after parsing.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv6Address("TCP", "::1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(b"PROXY TCP6 ::1 ::2 ")
proto.dataReceived(b"8080 8888\r\nHTTP/1.1 / GET")
proto.dataReceived(b"\r\n\r\n")
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET\r\n\r\n")
class HAProxyWrappingFactoryV2Tests(unittest.TestCase):
"""
Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v2 PROXY
headers.
"""
IPV4HEADER = (
# V2 Signature
b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
# V2 PROXY command
b"\x21"
# AF_INET/STREAM
b"\x11"
# 12 bytes for 2 IPv4 addresses and two ports
b"\x00\x0C"
# 127.0.0.1 for source and destination
b"\x7F\x00\x00\x01\x7F\x00\x00\x01"
# 8080 for source, 8888 for destination
b"\x1F\x90\x22\xB8"
)
IPV6HEADER = (
# V2 Signature
b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
# V2 PROXY command
b"\x21"
# AF_INET6/STREAM
b"\x21"
# 36 bytes for 2 IPv6 addresses and two ports
b"\x00\x24"
# ::1 for source and destination
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
# 8080 for source, 8888 for destination
b"\x1F\x90\x22\xB8"
)
_SOCK_PATH = (
b"\x2F\x68\x6F\x6D\x65\x2F\x74\x65\x73\x74\x73\x2F\x6D\x79\x73\x6F"
b"\x63\x6B\x65\x74\x73\x2F\x73\x6F\x63\x6B" + (b"\x00" * 82)
)
UNIXHEADER = (
(
# V2 Signature
b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
# V2 PROXY command
b"\x21"
# AF_UNIX/STREAM
b"\x31"
# 216 bytes for two null-padded 108-byte paths
b"\x00\xD8"
# /home/tests/mysockets/sock for source and destination paths
)
+ _SOCK_PATH
+ _SOCK_PATH
)
def test_invalidHeaderDisconnects(self) -> None:
"""
Test if invalid headers result in connectionLost events.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv6Address("TCP", "::1", 8080),
)
transport = StringTransportWithDisconnection()
transport.protocol = proto
proto.makeConnection(transport)
proto.dataReceived(b"\x00" + self.IPV4HEADER[1:])
self.assertFalse(transport.connected)
def test_validIPv4HeaderResolves_getPeerHost(self) -> None:
"""
Test if IPv4 headers result in the correct host and peer data.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv4Address("TCP", "127.0.0.1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(self.IPV4HEADER)
self.assertEqual(proto.getPeer().host, "127.0.0.1")
self.assertEqual(proto.getPeer().port, 8080)
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().host,
"127.0.0.1",
)
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().port,
8080,
)
self.assertEqual(proto.getHost().host, "127.0.0.1")
self.assertEqual(proto.getHost().port, 8888)
self.assertEqual(
proto.wrappedProtocol.transport.getHost().host,
"127.0.0.1",
)
self.assertEqual(
proto.wrappedProtocol.transport.getHost().port,
8888,
)
def test_validIPv6HeaderResolves_getPeerHost(self) -> None:
"""
Test if IPv6 headers result in the correct host and peer data.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv4Address("TCP", "::1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(self.IPV6HEADER)
self.assertEqual(proto.getPeer().host, "0:0:0:0:0:0:0:1")
self.assertEqual(proto.getPeer().port, 8080)
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().host,
"0:0:0:0:0:0:0:1",
)
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().port,
8080,
)
self.assertEqual(proto.getHost().host, "0:0:0:0:0:0:0:1")
self.assertEqual(proto.getHost().port, 8888)
self.assertEqual(
proto.wrappedProtocol.transport.getHost().host,
"0:0:0:0:0:0:0:1",
)
self.assertEqual(
proto.wrappedProtocol.transport.getHost().port,
8888,
)
def test_validUNIXHeaderResolves_getPeerHost(self) -> None:
"""
Test if UNIX headers result in the correct host and peer data.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.UNIXAddress(b"/home/test/sockets/server.sock"),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(self.UNIXHEADER)
self.assertEqual(proto.getPeer().name, b"/home/tests/mysockets/sock")
self.assertEqual(
proto.wrappedProtocol.transport.getPeer().name,
b"/home/tests/mysockets/sock",
)
self.assertEqual(proto.getHost().name, b"/home/tests/mysockets/sock")
self.assertEqual(
proto.wrappedProtocol.transport.getHost().name,
b"/home/tests/mysockets/sock",
)
def test_overflowBytesSentToWrappedProtocol(self) -> None:
"""
Test if non-header bytes are passed to the wrapped protocol.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv6Address("TCP", "::1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(self.IPV6HEADER + b"HTTP/1.1 / GET")
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")
def test_overflowBytesSentToWrappedProtocolChunks(self) -> None:
"""
Test if header streaming passes extra data appropriately.
"""
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
proto = factory.buildProtocol(
address.IPv6Address("TCP", "::1", 8080),
)
transport = StringTransportWithDisconnection()
proto.makeConnection(transport)
proto.dataReceived(self.IPV6HEADER[:18])
proto.dataReceived(self.IPV6HEADER[18:] + b"HTTP/1.1 / GET")
self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")

View File

@@ -0,0 +1,306 @@
# -*- test-case-name: twisted.test.test_htb -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Hierarchical Token Bucket traffic shaping.
Patterned after U{Martin Devera's Hierarchical Token Bucket traffic
shaper for the Linux kernel<http://luxik.cdi.cz/~devik/qos/htb/>}.
@seealso: U{HTB Linux queuing discipline manual - user guide
<http://luxik.cdi.cz/~devik/qos/htb/manual/userg.htm>}
@seealso: U{Token Bucket Filter in Linux Advanced Routing & Traffic Control
HOWTO<http://lartc.org/howto/lartc.qdisc.classless.html#AEN682>}
"""
# TODO: Investigate whether we should be using os.times()[-1] instead of
# time.time. time.time, it has been pointed out, can go backwards. Is
# the same true of os.times?
from time import time
from typing import Optional
from zope.interface import Interface, implementer
from twisted.protocols import pcp
class Bucket:
"""
Implementation of a Token bucket.
A bucket can hold a certain number of tokens and it drains over time.
@cvar maxburst: The maximum number of tokens that the bucket can
hold at any given time. If this is L{None}, the bucket has
an infinite size.
@type maxburst: C{int}
@cvar rate: The rate at which the bucket drains, in number
of tokens per second. If the rate is L{None}, the bucket
drains instantaneously.
@type rate: C{int}
"""
maxburst: Optional[int] = None
rate: Optional[int] = None
_refcount = 0
def __init__(self, parentBucket=None):
"""
Create a L{Bucket} that may have a parent L{Bucket}.
@param parentBucket: If a parent Bucket is specified,
all L{add} and L{drip} operations on this L{Bucket}
will be applied on the parent L{Bucket} as well.
@type parentBucket: L{Bucket}
"""
self.content = 0
self.parentBucket = parentBucket
self.lastDrip = time()
def add(self, amount):
"""
Adds tokens to the L{Bucket} and its C{parentBucket}.
This will add as many of the C{amount} tokens as will fit into both
this L{Bucket} and its C{parentBucket}.
@param amount: The number of tokens to try to add.
@type amount: C{int}
@returns: The number of tokens that actually fit.
@returntype: C{int}
"""
self.drip()
if self.maxburst is None:
allowable = amount
else:
allowable = min(amount, self.maxburst - self.content)
if self.parentBucket is not None:
allowable = self.parentBucket.add(allowable)
self.content += allowable
return allowable
def drip(self):
"""
Let some of the bucket drain.
The L{Bucket} drains at the rate specified by the class
variable C{rate}.
@returns: C{True} if the bucket is empty after this drip.
@returntype: C{bool}
"""
if self.parentBucket is not None:
self.parentBucket.drip()
if self.rate is None:
self.content = 0
else:
now = time()
deltaTime = now - self.lastDrip
deltaTokens = deltaTime * self.rate
self.content = max(0, self.content - deltaTokens)
self.lastDrip = now
return self.content == 0
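def _exampleBucketUsage():
    """
    Illustrative sketch (editor's addition, not part of the original
    module): a bucket holding at most 20000 tokens and draining at 5000
    tokens per second. The numbers are arbitrary.
    """
    bucket = Bucket()
    # The class docstring documents these as class variables; setting them
    # on an instance works the same way for a one-off bucket.
    bucket.maxburst = 20000
    bucket.rate = 5000
    accepted = bucket.add(1500)  # how many of the 1500 tokens fit
    drained = bucket.drip()  # True once the bucket has fully emptied
    return accepted, drained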
class IBucketFilter(Interface):
def getBucketFor(*somethings, **some_kw):
"""
Return a L{Bucket} corresponding to the provided parameters.
@returntype: L{Bucket}
"""
@implementer(IBucketFilter)
class HierarchicalBucketFilter:
"""
Filter things into buckets that can be nested.
@cvar bucketFactory: Class of buckets to make.
@type bucketFactory: L{Bucket}
@cvar sweepInterval: Seconds between sweeping out the bucket cache.
@type sweepInterval: C{int}
"""
bucketFactory = Bucket
sweepInterval: Optional[int] = None
def __init__(self, parentFilter=None):
self.buckets = {}
self.parentFilter = parentFilter
self.lastSweep = time()
def getBucketFor(self, *a, **kw):
"""
Find or create a L{Bucket} corresponding to the provided parameters.
Any parameters are passed on to L{getBucketKey}, from them it
decides which bucket you get.
@returntype: L{Bucket}
"""
if (self.sweepInterval is not None) and (
(time() - self.lastSweep) > self.sweepInterval
):
self.sweep()
if self.parentFilter:
parentBucket = self.parentFilter.getBucketFor(self, *a, **kw)
else:
parentBucket = None
key = self.getBucketKey(*a, **kw)
bucket = self.buckets.get(key)
if bucket is None:
bucket = self.bucketFactory(parentBucket)
self.buckets[key] = bucket
return bucket
def getBucketKey(self, *a, **kw):
"""
Construct a key based on the input parameters to choose a L{Bucket}.
The default implementation returns the same key for all
arguments. Override this method to provide L{Bucket} selection.
@returns: Something to be used as a key in the bucket cache.
"""
return None
def sweep(self):
"""
Remove empty buckets.
"""
# Iterate over a snapshot: deleting entries from the dict while
# iterating over a live view raises RuntimeError on Python 3.
for key, bucket in list(self.buckets.items()):
bucket_is_empty = bucket.drip()
if (bucket._refcount == 0) and bucket_is_empty:
del self.buckets[key]
self.lastSweep = time()
class FilterByHost(HierarchicalBucketFilter):
"""
A Hierarchical Bucket filter with a L{Bucket} for each host.
"""
sweepInterval = 60 * 20
def getBucketKey(self, transport):
return transport.getPeer()[1]
class FilterByServer(HierarchicalBucketFilter):
"""
A Hierarchical Bucket filter with a L{Bucket} for each service.
"""
sweepInterval = None
def getBucketKey(self, transport):
return transport.getHost()[2]
class ShapedConsumer(pcp.ProducerConsumerProxy):
"""
Wraps a C{Consumer} and shapes the rate at which it receives data.
"""
# Providing a Pull interface means I don't have to try to schedule
# traffic with callLaters.
iAmStreaming = False
def __init__(self, consumer, bucket):
pcp.ProducerConsumerProxy.__init__(self, consumer)
self.bucket = bucket
self.bucket._refcount += 1
def _writeSomeData(self, data):
# In practice, this actually results in obscene amounts of
# overhead, as a result of generating lots and lots of packets
# with twelve-byte payloads. We may need to do a version of
# this with scheduled writes after all.
amount = self.bucket.add(len(data))
return pcp.ProducerConsumerProxy._writeSomeData(self, data[:amount])
def stopProducing(self):
pcp.ProducerConsumerProxy.stopProducing(self)
self.bucket._refcount -= 1
class ShapedTransport(ShapedConsumer):
"""
Wraps a C{Transport} and shapes the rate at which it receives data.
This is a L{ShapedConsumer} with a little bit of magic to provide for
the case where the consumer it wraps is also a C{Transport} and people
will be attempting to access attributes this does not proxy as a
C{Consumer} (e.g. C{loseConnection}).
"""
# Ugh. We only wanted to filter IConsumer, not ITransport.
iAmStreaming = False
def __getattr__(self, name):
# Because people will be doing things like .getPeer and
# .loseConnection on me.
return getattr(self.consumer, name)
class ShapedProtocolFactory:
"""
Dispense C{Protocols} with traffic shaping on their transports.
Usage::
myserver = SomeFactory()
myserver.protocol = ShapedProtocolFactory(myserver.protocol,
bucketFilter)
Where C{SomeFactory} is a L{twisted.internet.protocol.Factory}, and
C{bucketFilter} is an instance of L{HierarchicalBucketFilter}.
"""
def __init__(self, protoClass, bucketFilter):
"""
Tell me what to wrap and where to get buckets.
@param protoClass: The class of C{Protocol} this will generate
wrapped instances of.
@type protoClass: L{Protocol<twisted.internet.interfaces.IProtocol>}
class
@param bucketFilter: The filter which will determine how
traffic is shaped.
@type bucketFilter: L{HierarchicalBucketFilter}.
"""
# More precisely, protoClass can be any callable that will return
# instances of something that implements IProtocol.
self.protocol = protoClass
self.bucketFilter = bucketFilter
def __call__(self, *a, **kw):
"""
Make a C{Protocol} instance with a shaped transport.
Any parameters will be passed on to the protocol's initializer.
@returns: A C{Protocol} instance with a L{ShapedTransport}.
"""
proto = self.protocol(*a, **kw)
origMakeConnection = proto.makeConnection
def makeConnection(transport):
bucket = self.bucketFilter.getBucketFor(transport)
shapedTransport = ShapedTransport(transport, bucket)
return origMakeConnection(shapedTransport)
proto.makeConnection = makeConnection
return proto
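def _exampleShapedEchoFactory():
    """
    Illustrative wiring sketch (editor's addition, not part of the original
    module): shape each client host with its own bucket, following the
    usage shown in the L{ShapedProtocolFactory} docstring. The echo
    protocol and the rate numbers are assumptions for the example.
    """
    from twisted.internet import protocol

    class Echo(protocol.Protocol):
        def dataReceived(self, data):
            self.transport.write(data)

    class SlowBucket(Bucket):
        maxburst = 20000  # allow at most a 20000-token burst
        rate = 5000  # drain 5000 tokens (bytes) per second

    bucketFilter = FilterByHost()
    bucketFilter.bucketFactory = SlowBucket
    factory = protocol.ServerFactory()
    factory.protocol = ShapedProtocolFactory(Echo, bucketFilter)
    return factory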

View File

@@ -0,0 +1,253 @@
# -*- test-case-name: twisted.test.test_ident -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Ident protocol implementation.
"""
import struct
from twisted.internet import defer
from twisted.protocols import basic
from twisted.python import failure, log
_MIN_PORT = 1
_MAX_PORT = 2**16 - 1
class IdentError(Exception):
"""
Can't determine connection owner; reason unknown.
"""
identDescription = "UNKNOWN-ERROR"
def __str__(self) -> str:
return self.identDescription
class NoUser(IdentError):
"""
The connection specified by the port pair is not currently in use or
currently not owned by an identifiable entity.
"""
identDescription = "NO-USER"
class InvalidPort(IdentError):
"""
Either the local or foreign port was improperly specified. This should
be returned if either or both of the port ids were out of range (TCP
port numbers are from 1-65535), negative integers, reals or in any
fashion not recognized as a non-negative integer.
"""
identDescription = "INVALID-PORT"
class HiddenUser(IdentError):
"""
The server was able to identify the user of this port, but the
information was not returned at the request of the user.
"""
identDescription = "HIDDEN-USER"
class IdentServer(basic.LineOnlyReceiver):
"""
The Identification Protocol (a.k.a., "ident", a.k.a., "the Ident
Protocol") provides a means to determine the identity of a user of a
particular TCP connection. Given a TCP port number pair, it returns a
character string which identifies the owner of that connection on the
server's system.
Server authors should subclass this class and override the lookup method.
The default implementation returns an UNKNOWN-ERROR response for every
query.
"""
def lineReceived(self, line):
parts = line.split(",")
if len(parts) != 2:
self.invalidQuery()
else:
try:
portOnServer, portOnClient = map(int, parts)
except ValueError:
self.invalidQuery()
else:
if (
_MIN_PORT <= portOnServer <= _MAX_PORT
and _MIN_PORT <= portOnClient <= _MAX_PORT
):
self.validQuery(portOnServer, portOnClient)
else:
self._ebLookup(
failure.Failure(InvalidPort()), portOnServer, portOnClient
)
def invalidQuery(self):
self.transport.loseConnection()
def validQuery(self, portOnServer, portOnClient):
"""
Called when a valid query is received to look up and deliver the
response.
@param portOnServer: The server port from the query.
@param portOnClient: The client port from the query.
"""
serverAddr = self.transport.getHost().host, portOnServer
clientAddr = self.transport.getPeer().host, portOnClient
defer.maybeDeferred(self.lookup, serverAddr, clientAddr).addCallback(
self._cbLookup, portOnServer, portOnClient
).addErrback(self._ebLookup, portOnServer, portOnClient)
def _cbLookup(self, result, sport, cport):
(sysName, userId) = result
self.sendLine("%d, %d : USERID : %s : %s" % (sport, cport, sysName, userId))
def _ebLookup(self, failure, sport, cport):
if failure.check(IdentError):
self.sendLine("%d, %d : ERROR : %s" % (sport, cport, failure.value))
else:
log.err(failure)
self.sendLine(
"%d, %d : ERROR : %s" % (sport, cport, IdentError(failure.value))
)
def lookup(self, serverAddress, clientAddress):
"""
Lookup user information about the specified address pair.
Return value should be a two-tuple of system name and username.
Acceptable values for the system name may be found online at::
U{http://www.iana.org/assignments/operating-system-names}
This method may also raise any IdentError subclass (or IdentError
itself) to indicate user information will not be provided for the
given query.
A Deferred may also be returned.
@param serverAddress: A two-tuple representing the server endpoint
of the address being queried. The first element is a string holding
a dotted-quad IP address. The second element is an integer
representing the port.
@param clientAddress: Like I{serverAddress}, but represents the
client endpoint of the address being queried.
"""
raise IdentError()
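class _ExampleStaticIdentServer(IdentServer):
    """
    Illustrative sketch (editor's addition, not part of the original
    module): a minimal L{IdentServer} subclass overriding C{lookup}, as
    the class docstring suggests. The system name and username returned
    here are made up for the example.
    """

    def lookup(self, serverAddress, clientAddress):
        if clientAddress[1] < 1024:
            # Pretend users of privileged client ports asked to stay hidden.
            raise HiddenUser()
        return ("UNIX", "nobody")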
class ProcServerMixin:
"""Implements lookup() to grab entries for responses from /proc/net/tcp"""
SYSTEM_NAME = "LINUX"
try:
from pwd import getpwuid # type:ignore[misc]
def getUsername(self, uid, getpwuid=getpwuid):
return getpwuid(uid)[0]
del getpwuid
except ImportError:
def getUsername(self, uid, getpwuid=None):
raise IdentError()
def entries(self):
with open("/proc/net/tcp") as f:
f.readline()
for L in f:
yield L.strip()
def dottedQuadFromHexString(self, hexstr):
return ".".join(
map(str, struct.unpack("4B", struct.pack("=L", int(hexstr, 16))))
)
def unpackAddress(self, packed):
addr, port = packed.split(":")
addr = self.dottedQuadFromHexString(addr)
port = int(port, 16)
return addr, port
def parseLine(self, line):
parts = line.strip().split()
localAddr, localPort = self.unpackAddress(parts[1])
remoteAddr, remotePort = self.unpackAddress(parts[2])
uid = int(parts[7])
return (localAddr, localPort), (remoteAddr, remotePort), uid
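# For illustration (editor's addition): a /proc/net/tcp entry looks roughly
# like
#     0: 0100007F:1F90 00000000:0000 0A 00000000:00000000 00:00000000 00000000  1000 ...
# so after split(), parts[1] is the local address, parts[2] the remote
# address, and parts[7] the owning uid. The hex address is native-endian;
# on a little-endian machine "0100007F" unpacks to 127.0.0.1 and "1F90"
# to port 8080.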
def lookup(self, serverAddress, clientAddress):
for ent in self.entries():
localAddr, remoteAddr, uid = self.parseLine(ent)
if remoteAddr == clientAddress and localAddr[1] == serverAddress[1]:
return (self.SYSTEM_NAME, self.getUsername(uid))
raise NoUser()
class IdentClient(basic.LineOnlyReceiver):
errorTypes = (IdentError, NoUser, InvalidPort, HiddenUser)
def __init__(self):
self.queries = []
def lookup(self, portOnServer, portOnClient):
"""
Lookup user information about the specified address pair.
"""
self.queries.append((defer.Deferred(), portOnServer, portOnClient))
if len(self.queries) > 1:
return self.queries[-1][0]
self.sendLine("%d, %d" % (portOnServer, portOnClient))
return self.queries[-1][0]
def lineReceived(self, line):
if not self.queries:
log.msg(f"Unexpected server response: {line!r}")
else:
d, _, _ = self.queries.pop(0)
self.parseResponse(d, line)
if self.queries:
self.sendLine("%d, %d" % (self.queries[0][1], self.queries[0][2]))
def connectionLost(self, reason):
for q in self.queries:
q[0].errback(IdentError(reason))
self.queries = []
def parseResponse(self, deferred, line):
parts = line.split(":", 2)
if len(parts) != 3:
deferred.errback(IdentError(line))
else:
ports, type, addInfo = map(str.strip, parts)
if type == "ERROR":
for et in self.errorTypes:
if et.identDescription == addInfo:
deferred.errback(et(line))
return
deferred.errback(IdentError(line))
else:
deferred.callback((type, addInfo))
__all__ = [
"IdentError",
"NoUser",
"InvalidPort",
"HiddenUser",
"IdentServer",
"IdentClient",
"ProcServerMixin",
]

View File

@@ -0,0 +1,387 @@
# -*- test-case-name: twisted.test.test_loopback -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Testing support for protocols -- loopback between client and server.
"""
# system imports
import tempfile
from zope.interface import implementer
from twisted.internet import defer, interfaces, main, protocol
from twisted.internet.interfaces import IAddress
from twisted.internet.task import deferLater
# Twisted Imports
from twisted.protocols import policies
from twisted.python import failure
class _LoopbackQueue:
"""
Trivial wrapper around a list to give it an interface like a queue, with
the addition of also sending notifications by way of a Deferred whenever
the list has something added to it.
"""
_notificationDeferred = None
disconnect = False
def __init__(self):
self._queue = []
def put(self, v):
self._queue.append(v)
if self._notificationDeferred is not None:
d, self._notificationDeferred = self._notificationDeferred, None
d.callback(None)
def __nonzero__(self):
return bool(self._queue)
__bool__ = __nonzero__
def get(self):
return self._queue.pop(0)
@implementer(IAddress)
class _LoopbackAddress:
pass
@implementer(interfaces.ITransport, interfaces.IConsumer)
class _LoopbackTransport:
disconnecting = False
producer = None
# ITransport
def __init__(self, q):
self.q = q
def write(self, data):
if not isinstance(data, bytes):
raise TypeError("Can only write bytes to ITransport")
self.q.put(data)
def writeSequence(self, iovec):
self.q.put(b"".join(iovec))
def loseConnection(self):
self.q.disconnect = True
self.q.put(None)
def abortConnection(self):
"""
Abort the connection. Same as L{loseConnection}.
"""
self.loseConnection()
def getPeer(self):
return _LoopbackAddress()
def getHost(self):
return _LoopbackAddress()
# IConsumer
def registerProducer(self, producer, streaming):
assert self.producer is None
self.producer = producer
self.streamingProducer = streaming
self._pollProducer()
def unregisterProducer(self):
assert self.producer is not None
self.producer = None
def _pollProducer(self):
if self.producer is not None and not self.streamingProducer:
self.producer.resumeProducing()
def identityPumpPolicy(queue, target):
"""
L{identityPumpPolicy} is a policy which delivers each chunk of data written
to the given queue as-is to the target.
This isn't a particularly realistic policy.
@see: L{loopbackAsync}
"""
while queue:
bytes = queue.get()
if bytes is None:
break
target.dataReceived(bytes)
def collapsingPumpPolicy(queue, target):
"""
L{collapsingPumpPolicy} is a policy which collapses all outstanding chunks
into a single string and delivers it to the target.
@see: L{loopbackAsync}
"""
bytes = []
while queue:
chunk = queue.get()
if chunk is None:
break
bytes.append(chunk)
if bytes:
target.dataReceived(b"".join(bytes))
def loopbackAsync(server, client, pumpPolicy=identityPumpPolicy):
"""
Establish a connection between C{server} and C{client} then transfer data
between them until the connection is closed. This is often useful for
testing a protocol.
@param server: The protocol instance representing the server-side of this
connection.
@param client: The protocol instance representing the client-side of this
connection.
@param pumpPolicy: When either C{server} or C{client} writes to its
transport, the string passed in is added to a queue of data for the
other protocol. Eventually, C{pumpPolicy} will be called with one such
queue and the corresponding protocol object. The pump policy callable
is responsible for emptying the queue and passing the strings it
contains to the given protocol's C{dataReceived} method. The signature
of C{pumpPolicy} is C{(queue, protocol)}. C{queue} is an object with a
C{get} method which will return the next string written to the
transport, or L{None} if the transport has been disconnected, and which
evaluates to C{True} if and only if there are more items to be
retrieved via C{get}.
@return: A L{Deferred} which fires when the connection has been closed and
both sides have received notification of this.
"""
serverToClient = _LoopbackQueue()
clientToServer = _LoopbackQueue()
server.makeConnection(_LoopbackTransport(serverToClient))
client.makeConnection(_LoopbackTransport(clientToServer))
return _loopbackAsyncBody(
server, serverToClient, client, clientToServer, pumpPolicy
)
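def _exampleLoopback():
    """
    Illustrative sketch (editor's addition, not part of the original
    module): exercise a server protocol against a client protocol with
    L{loopbackAsync}. The two trivial protocols are assumptions for the
    example.
    """
    from twisted.internet.protocol import Protocol

    class EchoOnce(Protocol):
        def dataReceived(self, data):
            self.transport.write(data)
            self.transport.loseConnection()

    class Pinger(Protocol):
        def connectionMade(self):
            self.transport.write(b"ping")

    # The returned Deferred fires once both sides have been notified of
    # the disconnection.
    return loopbackAsync(EchoOnce(), Pinger())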
def _loopbackAsyncBody(server, serverToClient, client, clientToServer, pumpPolicy):
"""
Transfer bytes from the output queue of each protocol to the input of the other.
@param server: The protocol instance representing the server-side of this
connection.
@param serverToClient: The L{_LoopbackQueue} holding the server's output.
@param client: The protocol instance representing the client-side of this
connection.
@param clientToServer: The L{_LoopbackQueue} holding the client's output.
@param pumpPolicy: See L{loopbackAsync}.
@return: A L{Deferred} which fires when the connection has been closed and
both sides have received notification of this.
"""
def pump(source, q, target):
sent = False
if q:
pumpPolicy(q, target)
sent = True
if sent and not q:
# A write buffer has now been emptied. Give any producer on that
# side an opportunity to produce more data.
source.transport._pollProducer()
return sent
while 1:
disconnect = clientSent = serverSent = False
# Deliver the data which has been written.
serverSent = pump(server, serverToClient, client)
clientSent = pump(client, clientToServer, server)
if not clientSent and not serverSent:
# Neither side wrote any data. Wait for some new data to be added
# before trying to do anything further.
d = defer.Deferred()
clientToServer._notificationDeferred = d
serverToClient._notificationDeferred = d
d.addCallback(
_loopbackAsyncContinue,
server,
serverToClient,
client,
clientToServer,
pumpPolicy,
)
return d
if serverToClient.disconnect:
# The server wants to drop the connection. Flush any remaining
# data it has.
disconnect = True
pump(server, serverToClient, client)
elif clientToServer.disconnect:
# The client wants to drop the connection. Flush any remaining
# data it has.
disconnect = True
pump(client, clientToServer, server)
if disconnect:
# Someone wanted to disconnect, so okay, the connection is gone.
server.connectionLost(failure.Failure(main.CONNECTION_DONE))
client.connectionLost(failure.Failure(main.CONNECTION_DONE))
return defer.succeed(None)
def _loopbackAsyncContinue(
ignored, server, serverToClient, client, clientToServer, pumpPolicy
):
# Clear the Deferred from each message queue, since it has already fired
# and cannot be used again.
clientToServer._notificationDeferred = None
serverToClient._notificationDeferred = None
# Schedule some more byte-pushing to happen. This isn't done
# synchronously because no actual transport can re-enter dataReceived as
# a result of calling write, and doing this synchronously could result
# in that.
from twisted.internet import reactor
return deferLater(
reactor,
0,
_loopbackAsyncBody,
server,
serverToClient,
client,
clientToServer,
pumpPolicy,
)
@implementer(interfaces.ITransport, interfaces.IConsumer)
class LoopbackRelay:
buffer = b""
shouldLose = 0
disconnecting = 0
producer = None
def __init__(self, target, logFile=None):
self.target = target
self.logFile = logFile
def write(self, data):
self.buffer = self.buffer + data
if self.logFile:
self.logFile.write("loopback writing %s\n" % repr(data))
def writeSequence(self, iovec):
self.write(b"".join(iovec))
def clearBuffer(self):
if self.shouldLose == -1:
return
if self.producer:
self.producer.resumeProducing()
if self.buffer:
if self.logFile:
self.logFile.write("loopback receiving %s\n" % repr(self.buffer))
buffer = self.buffer
self.buffer = b""
self.target.dataReceived(buffer)
if self.shouldLose == 1:
self.shouldLose = -1
self.target.connectionLost(failure.Failure(main.CONNECTION_DONE))
def loseConnection(self):
if self.shouldLose != -1:
self.shouldLose = 1
def getHost(self):
return "loopback"
def getPeer(self):
return "loopback"
def registerProducer(self, producer, streaming):
self.producer = producer
def unregisterProducer(self):
self.producer = None
def logPrefix(self):
return f"Loopback({self.target.__class__.__name__!r})"
class LoopbackClientFactory(protocol.ClientFactory):
def __init__(self, protocol):
self.disconnected = 0
self.deferred = defer.Deferred()
self.protocol = protocol
def buildProtocol(self, addr):
return self.protocol
def clientConnectionLost(self, connector, reason):
self.disconnected = 1
self.deferred.callback(None)
class _FireOnClose(policies.ProtocolWrapper):
def __init__(self, protocol, factory):
policies.ProtocolWrapper.__init__(self, protocol, factory)
self.deferred = defer.Deferred()
def connectionLost(self, reason):
policies.ProtocolWrapper.connectionLost(self, reason)
self.deferred.callback(None)
def loopbackTCP(server, client, port=0, noisy=True):
"""Run session between server and client protocol instances over TCP."""
from twisted.internet import reactor
f = policies.WrappingFactory(protocol.Factory())
serverWrapper = _FireOnClose(f, server)
f.noisy = noisy
f.buildProtocol = lambda addr: serverWrapper
serverPort = reactor.listenTCP(port, f, interface="127.0.0.1")
clientF = LoopbackClientFactory(client)
clientF.noisy = noisy
reactor.connectTCP("127.0.0.1", serverPort.getHost().port, clientF)
d = clientF.deferred
d.addCallback(lambda x: serverWrapper.deferred)
d.addCallback(lambda x: serverPort.stopListening())
return d
def loopbackUNIX(server, client, noisy=True):
"""Run session between server and client protocol instances over UNIX socket."""
path = tempfile.mktemp()
from twisted.internet import reactor
f = policies.WrappingFactory(protocol.Factory())
serverWrapper = _FireOnClose(f, server)
f.noisy = noisy
f.buildProtocol = lambda addr: serverWrapper
serverPort = reactor.listenUNIX(path, f)
clientF = LoopbackClientFactory(client)
clientF.noisy = noisy
reactor.connectUNIX(path, clientF)
d = clientF.deferred
d.addCallback(lambda x: serverWrapper.deferred)
d.addCallback(lambda x: serverPort.stopListening())
return d

View File

@@ -0,0 +1,733 @@
# -*- test-case-name: twisted.test.test_memcache -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Memcache client protocol. Memcached is a caching server, storing data in the
form of key/value pairs, and memcache is the protocol used to talk to it.
To connect to a server, create a factory for L{MemCacheProtocol}::
from twisted.internet import reactor, protocol
from twisted.protocols.memcache import MemCacheProtocol, DEFAULT_PORT
d = protocol.ClientCreator(reactor, MemCacheProtocol
).connectTCP("localhost", DEFAULT_PORT)
def doSomething(proto):
# Here you call the memcache operations
return proto.set("mykey", "a lot of data")
d.addCallback(doSomething)
reactor.run()
All the operations of the memcache protocol are present, but
L{MemCacheProtocol.set} and L{MemCacheProtocol.get} are the most important.
See U{http://code.sixapart.com/svn/memcached/trunk/server/doc/protocol.txt} for
more information about the protocol.
"""
from collections import deque
from twisted.internet.defer import Deferred, TimeoutError, fail
from twisted.protocols.basic import LineReceiver
from twisted.protocols.policies import TimeoutMixin
from twisted.python import log
from twisted.python.compat import nativeString, networkString
DEFAULT_PORT = 11211
class NoSuchCommand(Exception):
"""
Exception raised when a non-existent command is called.
"""
class ClientError(Exception):
"""
Error caused by an invalid client call.
"""
class ServerError(Exception):
"""
Problem happening on the server.
"""
class Command:
"""
Wrap a client action into an object that holds the values used in the
protocol.
@ivar _deferred: the L{Deferred} object that will be fired when the result
arrives.
@type _deferred: L{Deferred}
@ivar command: name of the command sent to the server.
@type command: L{bytes}
"""
def __init__(self, command, **kwargs):
"""
Create a command.
@param command: the name of the command.
@type command: L{bytes}
@param kwargs: these values will be stored as attributes of the object
for future use.
"""
self.command = command
self._deferred = Deferred()
for k, v in kwargs.items():
setattr(self, k, v)
def success(self, value):
"""
Shortcut method to fire the underlying deferred.
"""
self._deferred.callback(value)
def fail(self, error):
"""
Make the underlying deferred fail.
"""
self._deferred.errback(error)
class MemCacheProtocol(LineReceiver, TimeoutMixin):
"""
MemCache protocol: connect to a memcached server to store/retrieve values.
@ivar persistentTimeOut: the timeout period used to wait for a response.
@type persistentTimeOut: L{int}
@ivar _current: current list of requests waiting for an answer from the
server.
@type _current: L{deque} of L{Command}
@ivar _lenExpected: amount of data expected in raw mode, when reading for
a value.
@type _lenExpected: L{int}
@ivar _getBuffer: current buffer of data, used to store temporary data
when reading in raw mode.
@type _getBuffer: L{list}
@ivar _bufferLength: the total amount of bytes in C{_getBuffer}.
@type _bufferLength: L{int}
@ivar _disconnected: indicate if the connectionLost has been called or not.
@type _disconnected: L{bool}
"""
MAX_KEY_LENGTH = 250
_disconnected = False
def __init__(self, timeOut=60):
"""
Create the protocol.
@param timeOut: the timeout to wait before detecting that the
connection is dead and closing it. It's expressed in seconds.
@type timeOut: L{int}
"""
self._current = deque()
self._lenExpected = None
self._getBuffer = None
self._bufferLength = None
self.persistentTimeOut = self.timeOut = timeOut
def _cancelCommands(self, reason):
"""
Cancel all the outstanding commands, making them fail with C{reason}.
"""
while self._current:
cmd = self._current.popleft()
cmd.fail(reason)
def timeoutConnection(self):
"""
Close the connection in case of timeout.
"""
self._cancelCommands(TimeoutError("Connection timeout"))
self.transport.loseConnection()
def connectionLost(self, reason):
"""
Cause any outstanding commands to fail.
"""
self._disconnected = True
self._cancelCommands(reason)
LineReceiver.connectionLost(self, reason)
def sendLine(self, line):
"""
Override sendLine to add a timeout to response.
"""
if not self._current:
self.setTimeout(self.persistentTimeOut)
LineReceiver.sendLine(self, line)
def rawDataReceived(self, data):
"""
Collect data for a get.
"""
self.resetTimeout()
self._getBuffer.append(data)
self._bufferLength += len(data)
if self._bufferLength >= self._lenExpected + 2:
data = b"".join(self._getBuffer)
buf = data[: self._lenExpected]
rem = data[self._lenExpected + 2 :]
val = buf
self._lenExpected = None
self._getBuffer = None
self._bufferLength = None
cmd = self._current[0]
if cmd.multiple:
flags, cas = cmd.values[cmd.currentKey]
cmd.values[cmd.currentKey] = (flags, cas, val)
else:
cmd.value = val
self.setLineMode(rem)
def cmd_STORED(self):
"""
Manage a success response to a set operation.
"""
self._current.popleft().success(True)
def cmd_NOT_STORED(self):
"""
Manage a specific 'not stored' response to a set operation: this is not
an error, but some condition wasn't met.
"""
self._current.popleft().success(False)
def cmd_END(self):
"""
This is the end token of a get or a stat operation.
"""
cmd = self._current.popleft()
if cmd.command == b"get":
if cmd.multiple:
values = {key: val[::2] for key, val in cmd.values.items()}
cmd.success(values)
else:
cmd.success((cmd.flags, cmd.value))
elif cmd.command == b"gets":
if cmd.multiple:
cmd.success(cmd.values)
else:
cmd.success((cmd.flags, cmd.cas, cmd.value))
elif cmd.command == b"stats":
cmd.success(cmd.values)
else:
raise RuntimeError(
"Unexpected END response to {} command".format(
nativeString(cmd.command)
)
)
def cmd_NOT_FOUND(self):
"""
Manage error response for incr/decr/delete.
"""
self._current.popleft().success(False)
def cmd_VALUE(self, line):
"""
Prepare to read a value after a get.
"""
cmd = self._current[0]
if cmd.command == b"get":
key, flags, length = line.split()
cas = b""
else:
key, flags, length, cas = line.split()
self._lenExpected = int(length)
self._getBuffer = []
self._bufferLength = 0
if cmd.multiple:
if key not in cmd.keys:
raise RuntimeError("Unexpected commands answer.")
cmd.currentKey = key
cmd.values[key] = [int(flags), cas]
else:
if cmd.key != key:
raise RuntimeError("Unexpected commands answer.")
cmd.flags = int(flags)
cmd.cas = cas
self.setRawMode()
def cmd_STAT(self, line):
"""
Reception of one stat line.
"""
cmd = self._current[0]
key, val = line.split(b" ", 1)
cmd.values[key] = val
def cmd_VERSION(self, versionData):
"""
Read version token.
"""
self._current.popleft().success(versionData)
def cmd_ERROR(self):
"""
A non-existent command has been sent.
"""
log.err("Non-existent command sent.")
cmd = self._current.popleft()
cmd.fail(NoSuchCommand())
def cmd_CLIENT_ERROR(self, errText):
"""
An invalid input has been sent.
"""
errText = repr(errText)
log.err("Invalid input: " + errText)
cmd = self._current.popleft()
cmd.fail(ClientError(errText))
def cmd_SERVER_ERROR(self, errText):
"""
An error has happened server-side.
"""
errText = repr(errText)
log.err("Server error: " + errText)
cmd = self._current.popleft()
cmd.fail(ServerError(errText))
def cmd_DELETED(self):
"""
A delete command has completed successfully.
"""
self._current.popleft().success(True)
def cmd_OK(self):
"""
The last command has been completed.
"""
self._current.popleft().success(True)
def cmd_EXISTS(self):
"""
A C{checkAndSet} update has failed.
"""
self._current.popleft().success(False)
def lineReceived(self, line):
"""
Receive line commands from the server.
"""
self.resetTimeout()
token = line.split(b" ", 1)[0]
# First manage standard commands without space
cmd = getattr(self, "cmd_" + nativeString(token), None)
if cmd is not None:
args = line.split(b" ", 1)[1:]
if args:
cmd(args[0])
else:
cmd()
else:
# Then manage commands with space in it
line = line.replace(b" ", b"_")
cmd = getattr(self, "cmd_" + nativeString(line), None)
if cmd is not None:
cmd()
else:
# Increment/Decrement response
cmd = self._current.popleft()
val = int(line)
cmd.success(val)
if not self._current:
# No pending request, remove timeout
self.setTimeout(None)
def increment(self, key, val=1):
"""
Increment the value of C{key} by the given value (default is 1).
The value of C{key} must be consistent with an int. Return the new value.
@param key: the key to modify.
@type key: L{bytes}
@param val: the value to increment.
@type val: L{int}
@return: a deferred which will be called back with the new value
associated with the key (after the increment).
@rtype: L{Deferred}
"""
return self._incrdecr(b"incr", key, val)
def decrement(self, key, val=1):
"""
Decrement the value of C{key} by the given value (default is 1).
The value of C{key} must be consistent with an int. Return the new value,
coerced to
0 if negative.
@param key: the key to modify.
@type key: L{bytes}
@param val: the value to decrement.
@type val: L{int}
@return: a deferred which will be called back with the new value
associated with the key (after the decrement).
@rtype: L{Deferred}
"""
return self._incrdecr(b"decr", key, val)
def _incrdecr(self, cmd, key, val):
"""
Internal wrapper for incr/decr.
"""
if self._disconnected:
return fail(RuntimeError("not connected"))
if not isinstance(key, bytes):
return fail(
ClientError(f"Invalid type for key: {type(key)}, expecting bytes")
)
if len(key) > self.MAX_KEY_LENGTH:
return fail(ClientError("Key too long"))
fullcmd = b" ".join([cmd, key, b"%d" % (int(val),)])
self.sendLine(fullcmd)
cmdObj = Command(cmd, key=key)
self._current.append(cmdObj)
return cmdObj._deferred
def replace(self, key, val, flags=0, expireTime=0):
"""
Replace the given C{key}. It must already exist in the server.
@param key: the key to replace.
@type key: L{bytes}
@param val: the new value associated with the key.
@type val: L{bytes}
@param flags: the flags to store with the key.
@type flags: L{int}
@param expireTime: if different from 0, the relative time in seconds
when the key will be deleted from the store.
@type expireTime: L{int}
@return: a deferred that will fire with C{True} if the operation has
succeeded, and C{False} if the key didn't previously exist.
@rtype: L{Deferred}
"""
return self._set(b"replace", key, val, flags, expireTime, b"")
def add(self, key, val, flags=0, expireTime=0):
"""
Add the given C{key}. It must not exist in the server.
@param key: the key to add.
@type key: L{bytes}
@param val: the value associated with the key.
@type val: L{bytes}
@param flags: the flags to store with the key.
@type flags: L{int}
@param expireTime: if different from 0, the relative time in seconds
when the key will be deleted from the store.
@type expireTime: L{int}
@return: a deferred that will fire with C{True} if the operation has
succeeded, and C{False} if the key already exists.
@rtype: L{Deferred}
"""
return self._set(b"add", key, val, flags, expireTime, b"")
def set(self, key, val, flags=0, expireTime=0):
"""
Set the given C{key}.
@param key: the key to set.
@type key: L{bytes}
@param val: the value associated with the key.
@type val: L{bytes}
@param flags: the flags to store with the key.
@type flags: L{int}
@param expireTime: if different from 0, the relative time in seconds
when the key will be deleted from the store.
@type expireTime: L{int}
@return: a deferred that will fire with C{True} if the operation has
succeeded.
@rtype: L{Deferred}
"""
return self._set(b"set", key, val, flags, expireTime, b"")
def checkAndSet(self, key, val, cas, flags=0, expireTime=0):
"""
Change the content of C{key} only if the C{cas} value matches the
current one associated with the key. Use this to store a value which
hasn't been modified since last time you fetched it.
@param key: The key to set.
@type key: L{bytes}
@param val: The value associated with the key.
@type val: L{bytes}
@param cas: Unique 64-bit value returned by previous call of C{get}.
@type cas: L{bytes}
@param flags: The flags to store with the key.
@type flags: L{int}
@param expireTime: If different from 0, the relative time in seconds
when the key will be deleted from the store.
@type expireTime: L{int}
@return: A deferred that will fire with C{True} if the operation has
succeeded, C{False} otherwise.
@rtype: L{Deferred}
"""
return self._set(b"cas", key, val, flags, expireTime, cas)
def _set(self, cmd, key, val, flags, expireTime, cas):
"""
Internal wrapper for setting values.
"""
if self._disconnected:
return fail(RuntimeError("not connected"))
if not isinstance(key, bytes):
return fail(
ClientError(f"Invalid type for key: {type(key)}, expecting bytes")
)
if len(key) > self.MAX_KEY_LENGTH:
return fail(ClientError("Key too long"))
if not isinstance(val, bytes):
return fail(
ClientError(f"Invalid type for value: {type(val)}, expecting bytes")
)
if cas:
cas = b" " + cas
length = len(val)
fullcmd = (
b" ".join(
[cmd, key, networkString("%d %d %d" % (flags, expireTime, length))]
)
+ cas
)
self.sendLine(fullcmd)
self.sendLine(val)
cmdObj = Command(cmd, key=key, flags=flags, length=length)
self._current.append(cmdObj)
return cmdObj._deferred
def append(self, key, val):
"""
Append given data to the value of an existing key.
@param key: The key to modify.
@type key: L{bytes}
@param val: The value to append to the current value associated with
the key.
@type val: L{bytes}
@return: A deferred that will fire with C{True} if the operation has
succeeded, C{False} otherwise.
@rtype: L{Deferred}
"""
# Even if flags and expTime values are ignored, we have to pass them
return self._set(b"append", key, val, 0, 0, b"")
def prepend(self, key, val):
"""
Prepend given data to the value of an existing key.
@param key: The key to modify.
@type key: L{bytes}
@param val: The value to prepend to the current value associated with
the key.
@type val: L{bytes}
@return: A deferred that will fire with C{True} if the operation has
succeeded, C{False} otherwise.
@rtype: L{Deferred}
"""
# Even if flags and expTime values are ignored, we have to pass them
return self._set(b"prepend", key, val, 0, 0, b"")
def get(self, key, withIdentifier=False):
"""
Get the given C{key}. It doesn't support multiple keys. If
C{withIdentifier} is set to C{True}, the command issued is a C{gets},
which will return the current identifier associated with the value. This
identifier has to be used when issuing a C{checkAndSet} update later,
using the corresponding method.
@param key: The key to retrieve.
@type key: L{bytes}
@param withIdentifier: If set to C{True}, retrieve the current
identifier along with the value and the flags.
@type withIdentifier: L{bool}
@return: A deferred that will fire with the tuple (flags, value) if
C{withIdentifier} is C{False}, or (flags, cas identifier, value)
if C{True}. If the server indicates there is no value
associated with C{key}, the returned value will be L{None} and
the returned flags will be C{0}.
@rtype: L{Deferred}
"""
return self._get([key], withIdentifier, False)
def getMultiple(self, keys, withIdentifier=False):
"""
Get the given list of C{keys}. If C{withIdentifier} is set to C{True},
the command issued is a C{gets}, which will return the identifiers
associated with each value. This identifier has to be used when
issuing a C{checkAndSet} update later, using the corresponding method.
@param keys: The keys to retrieve.
@type keys: L{list} of L{bytes}
@param withIdentifier: If set to C{True}, retrieve the identifiers
along with the values and the flags.
@type withIdentifier: L{bool}
@return: A deferred that will fire with a dictionary with the elements
of C{keys} as keys and the tuples (flags, value) as values if
C{withIdentifier} is C{False}, or (flags, cas identifier, value) if
C{True}. If the server indicates there is no value associated with
C{key}, the returned values will be L{None} and the returned flags
will be C{0}.
@rtype: L{Deferred}
@since: 9.0
"""
return self._get(keys, withIdentifier, True)
def _get(self, keys, withIdentifier, multiple):
"""
Helper method for C{get} and C{getMultiple}.
"""
keys = list(keys)
if self._disconnected:
return fail(RuntimeError("not connected"))
for key in keys:
if not isinstance(key, bytes):
return fail(
ClientError(f"Invalid type for key: {type(key)}, expecting bytes")
)
if len(key) > self.MAX_KEY_LENGTH:
return fail(ClientError("Key too long"))
if withIdentifier:
cmd = b"gets"
else:
cmd = b"get"
fullcmd = b" ".join([cmd] + keys)
self.sendLine(fullcmd)
if multiple:
values = {key: (0, b"", None) for key in keys}
cmdObj = Command(cmd, keys=keys, values=values, multiple=True)
else:
cmdObj = Command(
cmd, key=keys[0], value=None, flags=0, cas=b"", multiple=False
)
self._current.append(cmdObj)
return cmdObj._deferred
def stats(self, arg=None):
"""
Get some stats from the server. They will be available as a L{dict}.
@param arg: An optional additional string which will be sent along
with the I{stats} command. The interpretation of this value by
the server is left undefined by the memcache protocol
specification.
@type arg: L{None} or L{bytes}
@return: a deferred that will fire with a L{dict} of the available
statistics.
@rtype: L{Deferred}
"""
if arg:
cmd = b"stats " + arg
else:
cmd = b"stats"
if self._disconnected:
return fail(RuntimeError("not connected"))
self.sendLine(cmd)
cmdObj = Command(b"stats", values={})
self._current.append(cmdObj)
return cmdObj._deferred
def version(self):
"""
Get the version of the server.
@return: a deferred that will fire with the string value of the
version.
@rtype: L{Deferred}
"""
if self._disconnected:
return fail(RuntimeError("not connected"))
self.sendLine(b"version")
cmdObj = Command(b"version")
self._current.append(cmdObj)
return cmdObj._deferred
def delete(self, key):
"""
Delete an existing C{key}.
@param key: the key to delete.
@type key: L{bytes}
@return: a deferred that will be called back with C{True} if the key
was successfully deleted, or C{False} if not.
@rtype: L{Deferred}
"""
if self._disconnected:
return fail(RuntimeError("not connected"))
if not isinstance(key, bytes):
return fail(
ClientError(f"Invalid type for key: {type(key)}, expecting bytes")
)
self.sendLine(b"delete " + key)
cmdObj = Command(b"delete", key=key)
self._current.append(cmdObj)
return cmdObj._deferred
def flushAll(self):
"""
Flush all cached values.
@return: a deferred that will be called back with C{True} when the
operation has succeeded.
@rtype: L{Deferred}
"""
if self._disconnected:
return fail(RuntimeError("not connected"))
self.sendLine(b"flush_all")
cmdObj = Command(b"flush_all")
self._current.append(cmdObj)
return cmdObj._deferred
__all__ = [
"MemCacheProtocol",
"DEFAULT_PORT",
"NoSuchCommand",
"ClientError",
"ServerError",
]
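
A minimal usage sketch of the memcache client above, assuming a running
reactor and a memcached server on localhost at the default port (the key
and value are illustrative):

from twisted.internet import reactor
from twisted.internet.protocol import ClientCreator
from twisted.protocols.memcache import DEFAULT_PORT, MemCacheProtocol

def connected(proto):
    # set() and get() both return Deferreds; get() fires with a
    # (flags, value) tuple.
    d = proto.set(b"greeting", b"hello world")
    d.addCallback(lambda ok: proto.get(b"greeting"))
    d.addCallback(lambda flags_value: print(flags_value))
    d.addBoth(lambda ignored: reactor.stop())
    return d

ClientCreator(reactor, MemCacheProtocol).connectTCP(
    "localhost", DEFAULT_PORT
).addCallback(connected)
reactor.run()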

View File

@@ -0,0 +1,211 @@
# -*- test-case-name: twisted.test.test_pcp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Producer-Consumer Proxy.
"""
from zope.interface import implementer
from twisted.internet import interfaces
@implementer(interfaces.IProducer, interfaces.IConsumer)
class BasicProducerConsumerProxy:
"""
I can act as a man in the middle between any Producer and Consumer.
@ivar producer: the Producer I subscribe to.
@type producer: L{IProducer<interfaces.IProducer>}
@ivar consumer: the Consumer I publish to.
@type consumer: L{IConsumer<interfaces.IConsumer>}
@ivar paused: As a Producer, am I paused?
@type paused: bool
"""
consumer = None
producer = None
producerIsStreaming = None
iAmStreaming = True
outstandingPull = False
paused = False
stopped = False
def __init__(self, consumer):
self._buffer = []
if consumer is not None:
self.consumer = consumer
consumer.registerProducer(self, self.iAmStreaming)
# Producer methods:
def pauseProducing(self):
self.paused = True
if self.producer:
self.producer.pauseProducing()
def resumeProducing(self):
self.paused = False
if self._buffer:
# TODO: Check to see if consumer supports writeSequence.
self.consumer.write("".join(self._buffer))
self._buffer[:] = []
else:
if not self.iAmStreaming:
self.outstandingPull = True
if self.producer is not None:
self.producer.resumeProducing()
def stopProducing(self):
if self.producer is not None:
self.producer.stopProducing()
if self.consumer is not None:
del self.consumer
# Consumer methods:
def write(self, data):
if self.paused or (not self.iAmStreaming and not self.outstandingPull):
# We could use that fifo queue here.
self._buffer.append(data)
elif self.consumer is not None:
self.consumer.write(data)
self.outstandingPull = False
def finish(self):
if self.consumer is not None:
self.consumer.finish()
self.unregisterProducer()
def registerProducer(self, producer, streaming):
self.producer = producer
self.producerIsStreaming = streaming
def unregisterProducer(self):
if self.producer is not None:
del self.producer
del self.producerIsStreaming
if self.consumer:
self.consumer.unregisterProducer()
def __repr__(self) -> str:
return f"<{self.__class__}@{id(self):x} around {self.consumer}>"
class ProducerConsumerProxy(BasicProducerConsumerProxy):
"""ProducerConsumerProxy with a finite buffer.
When my buffer fills up, I have my parent Producer pause until my buffer
has room in it again.
"""
# Copies much from abstract.FileDescriptor
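# Note: ** is right-associative, so 2**2**2**2 == 2**(2**(2**2)) == 2**16 (65536).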
bufferSize = 2**2**2**2
producerPaused = False
unregistered = False
def pauseProducing(self):
# Does *not* call up to ProducerConsumerProxy to relay the pause
# message through to my parent Producer.
self.paused = True
def resumeProducing(self):
self.paused = False
if self._buffer:
data = "".join(self._buffer)
bytesSent = self._writeSomeData(data)
if bytesSent < len(data):
unsent = data[bytesSent:]
assert (
not self.iAmStreaming
), "Streaming producer did not write all its data."
self._buffer[:] = [unsent]
else:
self._buffer[:] = []
else:
bytesSent = 0
if (
self.unregistered
and bytesSent
and not self._buffer
and self.consumer is not None
):
self.consumer.unregisterProducer()
if not self.iAmStreaming:
self.outstandingPull = not bytesSent
if self.producer is not None:
bytesBuffered = sum(len(s) for s in self._buffer)
# TODO: You can see here the potential for high and low
# watermarks, where bufferSize would be the high mark when we
# ask the upstream producer to pause, and we wouldn't have
# it resume again until it hit the low mark. Or if producer
# is Pull, maybe we'd like to pull from it as much as necessary
# to keep our buffer full to the low mark, so we're never caught
# without something to send.
if self.producerPaused and (bytesBuffered < self.bufferSize):
# Now that our buffer is empty,
self.producerPaused = False
self.producer.resumeProducing()
elif self.outstandingPull:
# I did not have any data to write in response to a pull,
# so I'd better pull some myself.
self.producer.resumeProducing()
def write(self, data):
if self.paused or (not self.iAmStreaming and not self.outstandingPull):
# We could use that fifo queue here.
self._buffer.append(data)
elif self.consumer is not None:
assert (
not self._buffer
), "Writing fresh data to consumer before my buffer is empty!"
# I'm going to use _writeSomeData here so that there is only one
# path to self.consumer.write. But it doesn't actually make sense,
# if I am streaming, for some data to not be all data. But maybe I
# am not streaming, but I am writing here anyway, because there was
# an earlier request for data which was not answered.
bytesSent = self._writeSomeData(data)
self.outstandingPull = False
if bytesSent != len(data):
assert (
not self.iAmStreaming
), "Streaming producer did not write all its data."
self._buffer.append(data[bytesSent:])
if (self.producer is not None) and self.producerIsStreaming:
bytesBuffered = sum(len(s) for s in self._buffer)
if bytesBuffered >= self.bufferSize:
self.producer.pauseProducing()
self.producerPaused = True
def registerProducer(self, producer, streaming):
self.unregistered = False
BasicProducerConsumerProxy.registerProducer(self, producer, streaming)
if not streaming:
producer.resumeProducing()
def unregisterProducer(self):
if self.producer is not None:
del self.producer
del self.producerIsStreaming
self.unregistered = True
if self.consumer and not self._buffer:
self.consumer.unregisterProducer()
def _writeSomeData(self, data):
"""Write as much of this data as possible.
@returns: The number of bytes written.
"""
if self.consumer is None:
return 0
self.consumer.write(data)
return len(data)
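
A rough sketch of how the proxy sits between a producer and a consumer;
the toy consumer below is a stand-in for any IConsumer provider, not part
of the module:

from twisted.protocols.pcp import ProducerConsumerProxy

class ListConsumer:
    # A toy consumer that records what is written to it.
    def __init__(self):
        self.data = []
    def registerProducer(self, producer, streaming):
        self.producer = producer
    def unregisterProducer(self):
        self.producer = None
    def write(self, data):
        self.data.append(data)

consumer = ListConsumer()
proxy = ProducerConsumerProxy(consumer)  # registers itself as the producer
proxy.write("hello ")    # passed straight through to the consumer
proxy.pauseProducing()
proxy.write("world")     # buffered while the proxy is paused
proxy.resumeProducing()  # flushes the buffer
print("".join(consumer.data))  # prints "hello world"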

View File

@@ -0,0 +1,696 @@
# -*- test-case-name: twisted.test.test_policies -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Resource limiting policies.
@seealso: See also L{twisted.protocols.htb} for rate limiting.
"""
# system imports
import sys
from typing import Optional, Type
from zope.interface import directlyProvides, providedBy
from twisted.internet import error, interfaces
from twisted.internet.interfaces import ILoggingContext
# twisted imports
from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory
from twisted.python import log
def _wrappedLogPrefix(wrapper, wrapped):
"""
Compute a log prefix for a wrapper and the object it wraps.
@rtype: C{str}
"""
if ILoggingContext.providedBy(wrapped):
logPrefix = wrapped.logPrefix()
else:
logPrefix = wrapped.__class__.__name__
return f"{logPrefix} ({wrapper.__class__.__name__})"
class ProtocolWrapper(Protocol):
"""
Wraps protocol instances and acts as their transport as well.
@ivar wrappedProtocol: An L{IProtocol<twisted.internet.interfaces.IProtocol>}
provider to which L{IProtocol<twisted.internet.interfaces.IProtocol>}
method calls on this L{ProtocolWrapper} will be proxied.
@ivar factory: The L{WrappingFactory} which created this
L{ProtocolWrapper}.
"""
disconnecting = 0
def __init__(
self, factory: "WrappingFactory", wrappedProtocol: interfaces.IProtocol
):
self.wrappedProtocol = wrappedProtocol
self.factory = factory
def logPrefix(self):
"""
Use a customized log prefix mentioning both the wrapped protocol and
the current one.
"""
return _wrappedLogPrefix(self, self.wrappedProtocol)
def makeConnection(self, transport):
"""
When a connection is made, register this wrapper with its factory,
save the real transport, and connect the wrapped protocol to this
L{ProtocolWrapper} to intercept any transport calls it makes.
"""
directlyProvides(self, providedBy(transport))
Protocol.makeConnection(self, transport)
self.factory.registerProtocol(self)
self.wrappedProtocol.makeConnection(self)
# Transport relaying
def write(self, data):
self.transport.write(data)
def writeSequence(self, data):
self.transport.writeSequence(data)
def loseConnection(self):
self.disconnecting = 1
self.transport.loseConnection()
def getPeer(self):
return self.transport.getPeer()
def getHost(self):
return self.transport.getHost()
def registerProducer(self, producer, streaming):
self.transport.registerProducer(producer, streaming)
def unregisterProducer(self):
self.transport.unregisterProducer()
def stopConsuming(self):
self.transport.stopConsuming()
def __getattr__(self, name):
return getattr(self.transport, name)
# Protocol relaying
def dataReceived(self, data):
self.wrappedProtocol.dataReceived(data)
def connectionLost(self, reason):
self.factory.unregisterProtocol(self)
self.wrappedProtocol.connectionLost(reason)
# Breaking reference cycle between self and wrappedProtocol.
self.wrappedProtocol = None
class WrappingFactory(ClientFactory):
"""
Wraps a factory and its protocols, and keeps track of them.
"""
protocol: Type[Protocol] = ProtocolWrapper
def __init__(self, wrappedFactory):
self.wrappedFactory = wrappedFactory
self.protocols = {}
def logPrefix(self):
"""
Generate a log prefix mentioning both the wrapped factory and this one.
"""
return _wrappedLogPrefix(self, self.wrappedFactory)
def doStart(self):
self.wrappedFactory.doStart()
ClientFactory.doStart(self)
def doStop(self):
self.wrappedFactory.doStop()
ClientFactory.doStop(self)
def startedConnecting(self, connector):
self.wrappedFactory.startedConnecting(connector)
def clientConnectionFailed(self, connector, reason):
self.wrappedFactory.clientConnectionFailed(connector, reason)
def clientConnectionLost(self, connector, reason):
self.wrappedFactory.clientConnectionLost(connector, reason)
def buildProtocol(self, addr):
return self.protocol(self, self.wrappedFactory.buildProtocol(addr))
def registerProtocol(self, p):
"""
Called by protocol to register itself.
"""
self.protocols[p] = 1
def unregisterProtocol(self, p):
"""
Called by protocols when they go away.
"""
del self.protocols[p]
class ThrottlingProtocol(ProtocolWrapper):
"""
Protocol for L{ThrottlingFactory}.
"""
# wrap API for tracking bandwidth
def write(self, data):
self.factory.registerWritten(len(data))
ProtocolWrapper.write(self, data)
def writeSequence(self, seq):
self.factory.registerWritten(sum(map(len, seq)))
ProtocolWrapper.writeSequence(self, seq)
def dataReceived(self, data):
self.factory.registerRead(len(data))
ProtocolWrapper.dataReceived(self, data)
def registerProducer(self, producer, streaming):
self.producer = producer
ProtocolWrapper.registerProducer(self, producer, streaming)
def unregisterProducer(self):
del self.producer
ProtocolWrapper.unregisterProducer(self)
def throttleReads(self):
self.transport.pauseProducing()
def unthrottleReads(self):
self.transport.resumeProducing()
def throttleWrites(self):
if hasattr(self, "producer"):
self.producer.pauseProducing()
def unthrottleWrites(self):
if hasattr(self, "producer"):
self.producer.resumeProducing()
class ThrottlingFactory(WrappingFactory):
"""
Throttles bandwidth and number of connections.
Write bandwidth will only be throttled if there is a producer
registered.
"""
protocol = ThrottlingProtocol
def __init__(
self,
wrappedFactory,
maxConnectionCount=sys.maxsize,
readLimit=None,
writeLimit=None,
):
WrappingFactory.__init__(self, wrappedFactory)
self.connectionCount = 0
self.maxConnectionCount = maxConnectionCount
self.readLimit = readLimit # max bytes we should read per second
self.writeLimit = writeLimit # max bytes we should write per second
self.readThisSecond = 0
self.writtenThisSecond = 0
self.unthrottleReadsID = None
self.checkReadBandwidthID = None
self.unthrottleWritesID = None
self.checkWriteBandwidthID = None
def callLater(self, period, func):
"""
Wrapper around
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
for testing purposes.
"""
from twisted.internet import reactor
return reactor.callLater(period, func)
def registerWritten(self, length):
"""
Called by protocol to tell us more bytes were written.
"""
self.writtenThisSecond += length
def registerRead(self, length):
"""
Called by protocol to tell us more bytes were read.
"""
self.readThisSecond += length
def checkReadBandwidth(self):
"""
Checks if we've passed bandwidth limits.
"""
if self.readThisSecond > self.readLimit:
self.throttleReads()
throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
self.unthrottleReadsID = self.callLater(throttleTime, self.unthrottleReads)
self.readThisSecond = 0
self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)
def checkWriteBandwidth(self):
if self.writtenThisSecond > self.writeLimit:
self.throttleWrites()
throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
self.unthrottleWritesID = self.callLater(
throttleTime, self.unthrottleWrites
)
# reset for next round
self.writtenThisSecond = 0
self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)
def throttleReads(self):
"""
Throttle reads on all protocols.
"""
log.msg("Throttling reads on %s" % self)
for p in self.protocols.keys():
p.throttleReads()
def unthrottleReads(self):
"""
Stop throttling reads on all protocols.
"""
self.unthrottleReadsID = None
log.msg("Stopped throttling reads on %s" % self)
for p in self.protocols.keys():
p.unthrottleReads()
def throttleWrites(self):
"""
Throttle writes on all protocols.
"""
log.msg("Throttling writes on %s" % self)
for p in self.protocols.keys():
p.throttleWrites()
def unthrottleWrites(self):
"""
Stop throttling writes on all protocols.
"""
self.unthrottleWritesID = None
log.msg("Stopped throttling writes on %s" % self)
for p in self.protocols.keys():
p.unthrottleWrites()
def buildProtocol(self, addr):
if self.connectionCount == 0:
if self.readLimit is not None:
self.checkReadBandwidth()
if self.writeLimit is not None:
self.checkWriteBandwidth()
if self.connectionCount < self.maxConnectionCount:
self.connectionCount += 1
return WrappingFactory.buildProtocol(self, addr)
else:
log.msg("Max connection count reached!")
return None
def unregisterProtocol(self, p):
WrappingFactory.unregisterProtocol(self, p)
self.connectionCount -= 1
if self.connectionCount == 0:
if self.unthrottleReadsID is not None:
self.unthrottleReadsID.cancel()
if self.checkReadBandwidthID is not None:
self.checkReadBandwidthID.cancel()
if self.unthrottleWritesID is not None:
self.unthrottleWritesID.cancel()
if self.checkWriteBandwidthID is not None:
self.checkWriteBandwidthID.cancel()
class SpewingProtocol(ProtocolWrapper):
def dataReceived(self, data):
log.msg("Received: %r" % data)
ProtocolWrapper.dataReceived(self, data)
def write(self, data):
log.msg("Sending: %r" % data)
ProtocolWrapper.write(self, data)
class SpewingFactory(WrappingFactory):
protocol = SpewingProtocol
class LimitConnectionsByPeer(WrappingFactory):
maxConnectionsPerPeer = 5
def startFactory(self):
self.peerConnections = {}
def buildProtocol(self, addr):
peerHost = addr.host  # addr is an IAddress; key connections by host
connectionCount = self.peerConnections.get(peerHost, 0)
if connectionCount >= self.maxConnectionsPerPeer:
return None
self.peerConnections[peerHost] = connectionCount + 1
return WrappingFactory.buildProtocol(self, addr)
def unregisterProtocol(self, p):
peerHost = p.getPeer().host  # must match the key used in buildProtocol
self.peerConnections[peerHost] -= 1
if self.peerConnections[peerHost] == 0:
del self.peerConnections[peerHost]
class LimitTotalConnectionsFactory(ServerFactory):
"""
Factory that limits the number of simultaneous connections.
@type connectionCount: C{int}
@ivar connectionCount: number of current connections.
@type connectionLimit: C{int} or L{None}
@cvar connectionLimit: maximum number of connections.
@type overflowProtocol: L{Protocol} or L{None}
@cvar overflowProtocol: Protocol to use for new connections when
connectionLimit is exceeded. If L{None} (the default value), excess
connections will be closed immediately.
"""
connectionCount = 0
connectionLimit = None
overflowProtocol: Optional[Type[Protocol]] = None
def buildProtocol(self, addr):
if self.connectionLimit is None or self.connectionCount < self.connectionLimit:
# Build the normal protocol
wrappedProtocol = self.protocol()
elif self.overflowProtocol is None:
# Just drop the connection
return None
else:
# Too many connections, so build the overflow protocol
wrappedProtocol = self.overflowProtocol()
wrappedProtocol.factory = self
protocol = ProtocolWrapper(self, wrappedProtocol)
self.connectionCount += 1
return protocol
def registerProtocol(self, p):
pass
def unregisterProtocol(self, p):
self.connectionCount -= 1
class TimeoutProtocol(ProtocolWrapper):
"""
Protocol that automatically disconnects when the connection is idle.
"""
def __init__(self, factory, wrappedProtocol, timeoutPeriod):
"""
Constructor.
@param factory: An L{TimeoutFactory}.
@param wrappedProtocol: A L{Protocol} to wrap.
@param timeoutPeriod: Number of seconds to wait for activity before
timing out.
"""
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
self.timeoutCall = None
self.timeoutPeriod = None
self.setTimeout(timeoutPeriod)
def setTimeout(self, timeoutPeriod=None):
"""
Set a timeout.
This will cancel any existing timeouts.
@param timeoutPeriod: If not L{None}, change the timeout period.
Otherwise, use the existing value.
"""
self.cancelTimeout()
self.timeoutPeriod = timeoutPeriod
if timeoutPeriod is not None:
self.timeoutCall = self.factory.callLater(
self.timeoutPeriod, self.timeoutFunc
)
def cancelTimeout(self):
"""
Cancel the timeout.
If the timeout was already cancelled, this does nothing.
"""
self.timeoutPeriod = None
if self.timeoutCall:
try:
self.timeoutCall.cancel()
except (error.AlreadyCalled, error.AlreadyCancelled):
pass
self.timeoutCall = None
def resetTimeout(self):
"""
Reset the timeout, usually because some activity just happened.
"""
if self.timeoutCall:
self.timeoutCall.reset(self.timeoutPeriod)
def write(self, data):
self.resetTimeout()
ProtocolWrapper.write(self, data)
def writeSequence(self, seq):
self.resetTimeout()
ProtocolWrapper.writeSequence(self, seq)
def dataReceived(self, data):
self.resetTimeout()
ProtocolWrapper.dataReceived(self, data)
def connectionLost(self, reason):
self.cancelTimeout()
ProtocolWrapper.connectionLost(self, reason)
def timeoutFunc(self):
"""
This method is called when the timeout is triggered.
By default it calls I{loseConnection}. Override this if you want
something else to happen.
"""
self.loseConnection()
class TimeoutFactory(WrappingFactory):
"""
Factory for TimeoutWrapper.
"""
protocol = TimeoutProtocol
def __init__(self, wrappedFactory, timeoutPeriod=30 * 60):
self.timeoutPeriod = timeoutPeriod
WrappingFactory.__init__(self, wrappedFactory)
def buildProtocol(self, addr):
return self.protocol(
self,
self.wrappedFactory.buildProtocol(addr),
timeoutPeriod=self.timeoutPeriod,
)
def callLater(self, period, func):
"""
Wrapper around
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
for testing purposes.
"""
from twisted.internet import reactor
return reactor.callLater(period, func)
class TrafficLoggingProtocol(ProtocolWrapper):
def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None, number=0):
"""
@param factory: factory which created this protocol.
@type factory: L{protocol.Factory}.
@param wrappedProtocol: the underlying protocol.
@type wrappedProtocol: C{protocol.Protocol}.
@param logfile: file opened for writing used to write log messages.
@type logfile: C{file}
@param lengthLimit: maximum size of a logged data chunk; longer chunks are elided.
@type lengthLimit: C{int}
@param number: identifier of the connection.
@type number: C{int}.
"""
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
self.logfile = logfile
self.lengthLimit = lengthLimit
self._number = number
def _log(self, line):
self.logfile.write(line + "\n")
self.logfile.flush()
def _mungeData(self, data):
if self.lengthLimit and len(data) > self.lengthLimit:
data = data[: self.lengthLimit - 12] + "<... elided>"
return data
# IProtocol
def connectionMade(self):
self._log("*")
return ProtocolWrapper.connectionMade(self)
def dataReceived(self, data):
self._log("C %d: %r" % (self._number, self._mungeData(data)))
return ProtocolWrapper.dataReceived(self, data)
def connectionLost(self, reason):
self._log("C %d: %r" % (self._number, reason))
return ProtocolWrapper.connectionLost(self, reason)
# ITransport
def write(self, data):
self._log("S %d: %r" % (self._number, self._mungeData(data)))
return ProtocolWrapper.write(self, data)
def writeSequence(self, iovec):
self._log("SV %d: %r" % (self._number, [self._mungeData(d) for d in iovec]))
return ProtocolWrapper.writeSequence(self, iovec)
def loseConnection(self):
self._log("S %d: *" % (self._number,))
return ProtocolWrapper.loseConnection(self)
class TrafficLoggingFactory(WrappingFactory):
protocol = TrafficLoggingProtocol
_counter = 0
def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
self.logfilePrefix = logfilePrefix
self.lengthLimit = lengthLimit
WrappingFactory.__init__(self, wrappedFactory)
def open(self, name):
return open(name, "w")
def buildProtocol(self, addr):
self._counter += 1
logfile = self.open(self.logfilePrefix + "-" + str(self._counter))
return self.protocol(
self,
self.wrappedFactory.buildProtocol(addr),
logfile,
self.lengthLimit,
self._counter,
)
def resetCounter(self):
"""
Reset the value of the counter used to identify connections.
"""
self._counter = 0
class TimeoutMixin:
"""
Mixin for protocols which wish to timeout connections.
Protocols that mix this in have a single timeout, set using L{setTimeout}.
When the timeout is hit, L{timeoutConnection} is called, which, by
default, closes the connection.
@cvar timeOut: The number of seconds after which to timeout the connection.
"""
timeOut: Optional[int] = None
__timeoutCall = None
def callLater(self, period, func):
"""
Wrapper around
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
for testing purposes.
"""
from twisted.internet import reactor
return reactor.callLater(period, func)
def resetTimeout(self):
"""
Reset the timeout count down.
If the connection has already timed out, then do nothing. If the
timeout has been cancelled (probably using C{setTimeout(None)}), also
do nothing.
It's often a good idea to call this when the protocol has received
some meaningful input from the other end of the connection. "I've got
some data, they're still there, reset the timeout".
"""
if self.__timeoutCall is not None and self.timeOut is not None:
self.__timeoutCall.reset(self.timeOut)
def setTimeout(self, period):
"""
Change the timeout period
@type period: C{int} or L{None}
@param period: The period, in seconds, to change the timeout to, or
L{None} to disable the timeout.
"""
prev = self.timeOut
self.timeOut = period
if self.__timeoutCall is not None:
if period is None:
try:
self.__timeoutCall.cancel()
except (error.AlreadyCancelled, error.AlreadyCalled):
# Do nothing if the call was already consumed.
pass
self.__timeoutCall = None
else:
self.__timeoutCall.reset(period)
elif period is not None:
self.__timeoutCall = self.callLater(period, self.__timedOut)
return prev
def __timedOut(self):
self.__timeoutCall = None
self.timeoutConnection()
def timeoutConnection(self):
"""
Called when the connection times out.
Override to define behavior other than dropping the connection.
"""
self.transport.loseConnection()
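
A sketch of composing the wrappers above around an ordinary factory; the
port number and the limits are placeholders:

from twisted.internet import reactor
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.policies import ThrottlingFactory, TimeoutFactory

class Echo(Protocol):
    def dataReceived(self, data):
        self.transport.write(data)

factory = Factory.forProtocol(Echo)
# Cap concurrency and bandwidth (bytes per second in each direction)...
throttled = ThrottlingFactory(
    factory, maxConnectionCount=100, readLimit=10 * 1024, writeLimit=10 * 1024
)
# ...and drop connections that have been idle for 60 seconds.
reactor.listenTCP(8007, TimeoutFactory(throttled, timeoutPeriod=60))
reactor.run()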

View File

@@ -0,0 +1,90 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A simple port forwarder.
"""
# Twisted imports
from twisted.internet import protocol
from twisted.python import log
class Proxy(protocol.Protocol):
noisy = True
peer = None
def setPeer(self, peer):
self.peer = peer
def connectionLost(self, reason):
if self.peer is not None:
self.peer.transport.loseConnection()
self.peer = None
elif self.noisy:
log.msg(f"Unable to connect to peer: {reason}")
def dataReceived(self, data):
self.peer.transport.write(data)
class ProxyClient(Proxy):
def connectionMade(self):
self.peer.setPeer(self)
# Wire this and the peer transport together to enable
# flow control (this stops connections from filling
# this proxy's memory when one side produces data at a
# higher rate than the other can consume).
self.transport.registerProducer(self.peer.transport, True)
self.peer.transport.registerProducer(self.transport, True)
# We're connected, everybody can read to their heart's content.
self.peer.transport.resumeProducing()
class ProxyClientFactory(protocol.ClientFactory):
protocol = ProxyClient
def setServer(self, server):
self.server = server
def buildProtocol(self, *args, **kw):
prot = protocol.ClientFactory.buildProtocol(self, *args, **kw)
prot.setPeer(self.server)
return prot
def clientConnectionFailed(self, connector, reason):
self.server.transport.loseConnection()
class ProxyServer(Proxy):
clientProtocolFactory = ProxyClientFactory
reactor = None
def connectionMade(self):
# Don't read anything from the connecting client until we have
# somewhere to send it to.
self.transport.pauseProducing()
client = self.clientProtocolFactory()
client.setServer(self)
if self.reactor is None:
from twisted.internet import reactor
self.reactor = reactor
self.reactor.connectTCP(self.factory.host, self.factory.port, client)
class ProxyFactory(protocol.Factory):
"""
Factory for port forwarder.
"""
protocol = ProxyServer
def __init__(self, host, port):
self.host = host
self.port = port
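
A sketch of running the forwarder above; the local port and the target
host/port are placeholders:

from twisted.internet import reactor
from twisted.protocols.portforward import ProxyFactory

# Forward connections made to local port 8080 on to example.com:80.
reactor.listenTCP(8080, ProxyFactory("example.com", 80))
reactor.run()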

View File

@@ -0,0 +1,137 @@
# -*- test-case-name: twisted.test.test_postfix -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Postfix mail transport agent related protocols.
"""
from __future__ import annotations
import sys
from collections import UserDict
from typing import TYPE_CHECKING, Union
from urllib.parse import quote as _quote, unquote as _unquote
from twisted.internet import defer, protocol
from twisted.protocols import basic, policies
from twisted.python import log
# urllib's quote functions just happen to match
# the postfix semantics.
def quote(s):
quoted = _quote(s)
if isinstance(quoted, str):
quoted = quoted.encode("ascii")
return quoted
def unquote(s):
if isinstance(s, bytes):
s = s.decode("ascii")
quoted = _unquote(s)
return quoted.encode("ascii")
class PostfixTCPMapServer(basic.LineReceiver, policies.TimeoutMixin):
"""
Postfix mail transport agent TCP map protocol implementation.
Receives requests for data matching a given key via lineReceived,
asks its factory for the data with self.factory.get(key), and
returns the data to the requester. None means no entry was found.
You can use postfix's postmap to test the map service::
/usr/sbin/postmap -q KEY tcp:localhost:4242
"""
timeout = 600
delimiter = b"\n"
def connectionMade(self):
self.setTimeout(self.timeout)
def sendCode(self, code, message=b""):
"""
Send an SMTP-like code with a message.
"""
self.sendLine(str(code).encode("ascii") + b" " + message)
def lineReceived(self, line):
self.resetTimeout()
try:
request, params = line.split(None, 1)
except ValueError:
request = line
params = None
try:
f = getattr(self, "do_" + request.decode("ascii"))
except AttributeError:
self.sendCode(400, b"unknown command")
else:
try:
f(params)
except BaseException:
excInfo = str(sys.exc_info()[1]).encode("ascii")
self.sendCode(400, b"Command " + request + b" failed: " + excInfo)
def do_get(self, key):
if key is None:
self.sendCode(400, b"Command 'get' takes 1 parameters.")
else:
d = defer.maybeDeferred(self.factory.get, key)
d.addCallbacks(self._cbGot, self._cbNot)
d.addErrback(log.err)
def _cbNot(self, fail):
msg = fail.getErrorMessage().encode("ascii")
self.sendCode(400, msg)
def _cbGot(self, value):
if value is None:
self.sendCode(500)
else:
self.sendCode(200, quote(value))
def do_put(self, keyAndValue):
if keyAndValue is None:
self.sendCode(400, b"Command 'put' takes 2 parameters.")
else:
try:
key, value = keyAndValue.split(None, 1)
except ValueError:
self.sendCode(400, b"Command 'put' takes 2 parameters.")
else:
self.sendCode(500, b"put is not implemented yet.")
if TYPE_CHECKING or sys.version_info >= (3, 9):
_PostfixTCPMapDict = UserDict[bytes, Union[str, bytes]]
else:
_PostfixTCPMapDict = UserDict
class PostfixTCPMapDictServerFactory(_PostfixTCPMapDict, protocol.ServerFactory):
"""
An in-memory dictionary factory for PostfixTCPMapServer.
"""
protocol = PostfixTCPMapServer
class PostfixTCPMapDeferringDictServerFactory(protocol.ServerFactory):
"""
An in-memory dictionary factory for PostfixTCPMapServer.
"""
protocol = PostfixTCPMapServer
def __init__(self, data=None):
self.data = {}
if data is not None:
self.data.update(data)
def get(self, key):
return defer.succeed(self.data.get(key))
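
A sketch of serving a small static map with the deferring factory above,
using a placeholder key/value; it can then be queried as shown in the
class docstring (postmap -q example.com tcp:localhost:4242):

from twisted.internet import reactor
from twisted.protocols.postfix import PostfixTCPMapDeferringDictServerFactory

factory = PostfixTCPMapDeferringDictServerFactory({b"example.com": b"OK"})
reactor.listenTCP(4242, factory)
reactor.run()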

View File

@@ -0,0 +1,111 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Chop up shoutcast stream into MP3s and metadata, if available.
"""
from twisted import copyright
from twisted.web import http
class ShoutcastClient(http.HTTPClient):
"""
Shoutcast HTTP stream.
Modes can be 'length', 'meta' and 'mp3'.
See U{http://www.smackfu.com/stuff/programming/shoutcast.html}
for details on the protocol.
"""
userAgent = "Twisted Shoutcast client " + copyright.version
def __init__(self, path="/"):
self.path = path
self.got_metadata = False
self.metaint = None
self.metamode = "mp3"
self.databuffer = ""
def connectionMade(self):
self.sendCommand("GET", self.path)
self.sendHeader("User-Agent", self.userAgent)
self.sendHeader("Icy-MetaData", "1")
self.endHeaders()
def lineReceived(self, line):
# fix shoutcast crappiness
if not self.firstLine and line:
if len(line.split(": ", 1)) == 1:
line = line.replace(":", ": ", 1)
http.HTTPClient.lineReceived(self, line)
def handleHeader(self, key, value):
if key.lower() == "icy-metaint":
self.metaint = int(value)
self.got_metadata = True
def handleEndHeaders(self):
# Let's check if we got metadata, and set the
# appropriate handleResponsePart method.
if self.got_metadata:
# if we have metadata, then it has to be parsed out of the data stream
self.handleResponsePart = self.handleResponsePart_with_metadata
else:
# otherwise, all the data is MP3 data
self.handleResponsePart = self.gotMP3Data
def handleResponsePart_with_metadata(self, data):
self.databuffer += data
while self.databuffer:
stop = getattr(self, "handle_%s" % self.metamode)()
if stop:
return
def handle_length(self):
self.remaining = ord(self.databuffer[0]) * 16
self.databuffer = self.databuffer[1:]
self.metamode = "meta"
def handle_mp3(self):
if len(self.databuffer) > self.metaint:
self.gotMP3Data(self.databuffer[: self.metaint])
self.databuffer = self.databuffer[self.metaint :]
self.metamode = "length"
else:
return 1
def handle_meta(self):
if len(self.databuffer) >= self.remaining:
if self.remaining:
data = self.databuffer[: self.remaining]
self.gotMetaData(self.parseMetadata(data))
self.databuffer = self.databuffer[self.remaining :]
self.metamode = "mp3"
else:
return 1
def parseMetadata(self, data):
meta = []
for chunk in data.split(";"):
chunk = chunk.strip().replace("\x00", "")
if not chunk:
continue
key, value = chunk.split("=", 1)
if value.startswith("'") and value.endswith("'"):
value = value[1:-1]
meta.append((key, value))
return meta
def gotMetaData(self, metadata):
"""Called with a list of (key, value) pairs of metadata,
if metadata is available on the server.
Will only be called on non-empty metadata.
"""
raise NotImplementedError("implement in subclass")
def gotMP3Data(self, data):
"""Called with chunk of MP3 data."""
raise NotImplementedError("implement in subclass")
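
A sketch of a subclass that captures the stream; note the parser above
still uses text buffers from its Python 2 origins, so the byte/str
handling here is defensive, and the host, port, and filename are
placeholders:

from twisted.internet import reactor
from twisted.internet.protocol import ClientCreator
from twisted.protocols.shoutcast import ShoutcastClient

class StreamDumper(ShoutcastClient):
    def __init__(self, path="/"):
        ShoutcastClient.__init__(self, path)
        self.out = open("stream.mp3", "wb")

    def gotMP3Data(self, data):
        if isinstance(data, str):  # the parser above buffers text
            data = data.encode("latin-1")
        self.out.write(data)

    def gotMetaData(self, metadata):
        print(dict(metadata).get("StreamTitle"))

ClientCreator(reactor, StreamDumper).connectTCP("stream.example.com", 8000)
reactor.run()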

File diff suppressed because it is too large

View File

@@ -0,0 +1,249 @@
# -*- test-case-name: twisted.test.test_socks -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of the SOCKSv4 protocol.
"""
import socket
# python imports
import struct
import time
# twisted imports
from twisted.internet import defer, protocol, reactor
from twisted.python import log
class SOCKSv4Outgoing(protocol.Protocol):
def __init__(self, socks):
self.socks = socks
def connectionMade(self):
peer = self.transport.getPeer()
self.socks.makeReply(90, 0, port=peer.port, ip=peer.host)
self.socks.otherConn = self
def connectionLost(self, reason):
self.socks.transport.loseConnection()
def dataReceived(self, data):
self.socks.write(data)
def write(self, data):
self.socks.log(self, data)
self.transport.write(data)
class SOCKSv4Incoming(protocol.Protocol):
def __init__(self, socks):
self.socks = socks
self.socks.otherConn = self
def connectionLost(self, reason):
self.socks.transport.loseConnection()
def dataReceived(self, data):
self.socks.write(data)
def write(self, data):
self.socks.log(self, data)
self.transport.write(data)
class SOCKSv4(protocol.Protocol):
"""
An implementation of the SOCKSv4 protocol.
@type logging: L{str} or L{None}
@ivar logging: If not L{None}, the name of the logfile to which connection
information will be written.
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
@ivar reactor: The reactor used to create connections.
@type buf: L{bytes}
@ivar buf: Part of a SOCKSv4 connection request.
@type otherConn: C{SOCKSv4Incoming}, C{SOCKSv4Outgoing} or L{None}
@ivar otherConn: Until the connection has been established, C{otherConn} is
L{None}. After that, it is the proxy-to-destination protocol instance
along which the client's connection is being forwarded.
"""
def __init__(self, logging=None, reactor=reactor):
self.logging = logging
self.reactor = reactor
def connectionMade(self):
self.buf = b""
self.otherConn = None
def dataReceived(self, data):
"""
Called whenever data is received.
@type data: L{bytes}
@param data: Part or all of a SOCKSv4 packet.
"""
if self.otherConn:
self.otherConn.write(data)
return
self.buf = self.buf + data
completeBuffer = self.buf
if b"\000" in self.buf[8:]:
head, self.buf = self.buf[:8], self.buf[8:]
version, code, port = struct.unpack("!BBH", head[:4])
user, self.buf = self.buf.split(b"\000", 1)
if head[4:7] == b"\000\000\000" and head[7:8] != b"\000":
# An IP address of the form 0.0.0.X, where X is non-zero,
# signifies that this is a SOCKSv4a packet.
# If the complete packet hasn't been received, restore the
# buffer and wait for it.
if b"\000" not in self.buf:
self.buf = completeBuffer
return
server, self.buf = self.buf.split(b"\000", 1)
d = self.reactor.resolve(server)
d.addCallback(self._dataReceived2, user, version, code, port)
d.addErrback(lambda result, self=self: self.makeReply(91))
return
else:
server = socket.inet_ntoa(head[4:8])
self._dataReceived2(server, user, version, code, port)
def _dataReceived2(self, server, user, version, code, port):
"""
The second half of the SOCKS connection setup. For a SOCKSv4 packet this
is after the server address has been extracted from the header. For a
SOCKSv4a packet this is after the host name has been resolved.
@type server: L{str}
@param server: The IP address of the destination, represented as a
dotted quad.
@type user: L{str}
@param user: The username associated with the connection.
@type version: L{int}
@param version: The SOCKS protocol version number.
@type code: L{int}
@param code: The command code. 1 means establish a TCP/IP stream
connection, and 2 means establish a TCP/IP port binding.
@type port: L{int}
@param port: The port number associated with the connection.
"""
assert version == 4, "Bad version code: %s" % version
if not self.authorize(code, server, port, user):
self.makeReply(91)
return
if code == 1: # CONNECT
d = self.connectClass(server, port, SOCKSv4Outgoing, self)
d.addErrback(lambda result, self=self: self.makeReply(91))
elif code == 2: # BIND
d = self.listenClass(0, SOCKSv4IncomingFactory, self, server)
d.addCallback(lambda x, self=self: self.makeReply(90, 0, x[1], x[0]))
else:
raise RuntimeError(f"Bad Connect Code: {code}")
assert self.buf == b"", "hmm, still stuff in buffer... %s" % repr(self.buf)
def connectionLost(self, reason):
if self.otherConn:
self.otherConn.transport.loseConnection()
def authorize(self, code, server, port, user):
log.msg(
"code %s connection to %s:%s (user %s) authorized"
% (code, server, port, user)
)
return 1
def connectClass(self, host, port, klass, *args):
return protocol.ClientCreator(reactor, klass, *args).connectTCP(host, port)
def listenClass(self, port, klass, *args):
serv = reactor.listenTCP(port, klass(*args))
# getHost() returns an IAddress object; hand back a (host, port) pair.
address = serv.getHost()
return defer.succeed((address.host, address.port))
def makeReply(self, reply, version=0, port=0, ip="0.0.0.0"):
self.transport.write(
struct.pack("!BBH", version, reply, port) + socket.inet_aton(ip)
)
if reply != 90:
self.transport.loseConnection()
def write(self, data):
self.log(self, data)
self.transport.write(data)
def log(self, proto, data):
if not self.logging:
return
peer = self.transport.getPeer()
their_peer = self.otherConn.transport.getPeer()
f = open(self.logging, "a")
f.write(
"%s\t%s:%d %s %s:%d\n"
% (
time.ctime(),
peer.host,
peer.port,
((proto == self and "<") or ">"),
their_peer.host,
their_peer.port,
)
)
while data:
p, data = data[:16], data[16:]
# Iterating bytes yields ints; format them directly (string.join and
# ord-on-str were Python 2 idioms).
f.write(" ".join("%02X" % (octet,) for octet in p) + " ")
f.write((16 - len(p)) * 3 * " ")
for octet in p:
if 32 <= octet < 127:
f.write(chr(octet))
else:
f.write(".")
f.write("\n")
f.write("\n")
f.close()
class SOCKSv4Factory(protocol.Factory):
"""
A factory for a SOCKSv4 proxy.
Constructor accepts one argument, a log file name.
"""
def __init__(self, log):
self.logging = log
def buildProtocol(self, addr):
return SOCKSv4(self.logging, reactor)
class SOCKSv4IncomingFactory(protocol.Factory):
"""
A utility class for building protocols for incoming connections.
"""
def __init__(self, socks, ip):
self.socks = socks
self.ip = ip
def buildProtocol(self, addr):
if addr[0] == self.ip:
self.ip = ""
self.socks.makeReply(90, 0)
return SOCKSv4Incoming(self.socks)
elif self.ip == "":
return None
else:
self.socks.makeReply(91, 0)
self.ip = ""
return None
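
A sketch of running the proxy above on the conventional SOCKS port, with
a placeholder log file name (pass None to disable traffic logging):

from twisted.internet import reactor
from twisted.protocols.socks import SOCKSv4Factory

reactor.listenTCP(1080, SOCKSv4Factory("socks.log"))
reactor.run()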

View File

@@ -0,0 +1,52 @@
# -*- test-case-name: twisted.test.test_stateful -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from io import BytesIO
from twisted.internet import protocol
class StatefulProtocol(protocol.Protocol):
"""A Protocol that stores state for you.
state is a pair (function, num_bytes). When num_bytes bytes of data arrive
from the network, function is called. It is expected to return the next
state, or None to keep the same state. The initial state is returned by
getInitialState (override it).
"""
_sful_data = None, None, 0
def makeConnection(self, transport):
protocol.Protocol.makeConnection(self, transport)
self._sful_data = self.getInitialState(), BytesIO(), 0
def getInitialState(self):
raise NotImplementedError
def dataReceived(self, data):
state, buffer, offset = self._sful_data
buffer.seek(0, 2)
buffer.write(data)
blen = buffer.tell() # how many bytes total is in the buffer
buffer.seek(offset)
while blen - offset >= state[1]:
d = buffer.read(state[1])
offset += state[1]
next = state[0](d)
if (
self.transport.disconnecting
): # XXX: argh stupid hack borrowed right from LineReceiver
return # dataReceived won't be called again, so who cares about consistent state
if next:
state = next
if offset != 0:
b = buffer.read()
buffer.seek(0)
buffer.truncate()
buffer.write(b)
offset = 0
self._sful_data = state, buffer, offset
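
A sketch of a StatefulProtocol subclass that reads a 4-byte big-endian
length prefix followed by that many bytes of payload; the method names
are illustrative:

import struct

from twisted.protocols.stateful import StatefulProtocol

class LengthPrefixed(StatefulProtocol):
    def getInitialState(self):
        # Wait for the 4-byte length header first.
        return self._gotHeader, 4

    def _gotHeader(self, data):
        (length,) = struct.unpack("!I", data)
        # Next, collect `length` bytes of payload.
        return self._gotPayload, length

    def _gotPayload(self, data):
        self.messageReceived(data)
        # Go back to waiting for the next header.
        return self.getInitialState()

    def messageReceived(self, message):
        print("got message:", message)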

View File

@@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Unit tests for L{twisted.protocols}.
"""

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,936 @@
# -*- test-case-name: twisted.protocols.test.test_tls,twisted.internet.test.test_tls,twisted.test.test_sslverify -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of a TLS transport (L{ISSLTransport}) as an
L{IProtocol<twisted.internet.interfaces.IProtocol>} layered on top of any
L{ITransport<twisted.internet.interfaces.ITransport>} implementation, based on
U{OpenSSL<http://www.openssl.org>}'s memory BIO features.
L{TLSMemoryBIOFactory} is a L{WrappingFactory} which wraps protocols created by
the factory it wraps with L{TLSMemoryBIOProtocol}. L{TLSMemoryBIOProtocol}
intercedes between the underlying transport and the wrapped protocol to
implement SSL and TLS. Typical usage of this module looks like this::
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.internet.protocol import ServerFactory
from twisted.internet.ssl import PrivateCertificate
from twisted.internet import reactor
from someapplication import ApplicationProtocol
serverFactory = ServerFactory()
serverFactory.protocol = ApplicationProtocol
certificate = PrivateCertificate.loadPEM(certPEMData)
contextFactory = certificate.options()
tlsFactory = TLSMemoryBIOFactory(contextFactory, False, serverFactory)
reactor.listenTCP(12345, tlsFactory)
reactor.run()
This API offers somewhat more flexibility than
L{twisted.internet.interfaces.IReactorSSL}; for example, a
L{TLSMemoryBIOProtocol} instance can use another instance of
L{TLSMemoryBIOProtocol} as its transport, yielding TLS over TLS - useful to
implement onion routing. It can also be used to run TLS over unusual
transports, such as UNIX sockets and stdio.
"""
from __future__ import annotations
from typing import Callable, Iterable, Optional, cast
from zope.interface import directlyProvides, implementer, providedBy
from OpenSSL.SSL import Connection, Error, SysCallError, WantReadError, ZeroReturnError
from twisted.internet._producer_helpers import _PullToPush
from twisted.internet._sslverify import _setAcceptableProtocols
from twisted.internet.interfaces import (
IDelayedCall,
IHandshakeListener,
ILoggingContext,
INegotiated,
IOpenSSLClientConnectionCreator,
IOpenSSLServerConnectionCreator,
IProtocol,
IProtocolNegotiationFactory,
IPushProducer,
IReactorTime,
ISystemHandle,
ITransport,
)
from twisted.internet.main import CONNECTION_LOST
from twisted.internet.protocol import Protocol
from twisted.protocols.policies import ProtocolWrapper, WrappingFactory
from twisted.python.failure import Failure
@implementer(IPushProducer)
class _ProducerMembrane:
"""
Stand-in for producer registered with a L{TLSMemoryBIOProtocol} transport.
Ensures that producer pause/resume events from the underlying transport are
coordinated with pause/resume events from the TLS layer.
@ivar _producer: The application-layer producer.
"""
_producerPaused = False
def __init__(self, producer):
self._producer = producer
def pauseProducing(self):
"""
C{pauseProducing} the underlying producer, if it's not paused.
"""
if self._producerPaused:
return
self._producerPaused = True
self._producer.pauseProducing()
def resumeProducing(self):
"""
C{resumeProducing} the underlying producer, if it's paused.
"""
if not self._producerPaused:
return
self._producerPaused = False
self._producer.resumeProducing()
def stopProducing(self):
"""
C{stopProducing} the underlying producer.
There is only a single source for this event, so it's simply passed
on.
"""
self._producer.stopProducing()
def _representsEOF(exceptionObject: Error) -> bool:
"""
Does the given OpenSSL.SSL.Error represent an end-of-file?
"""
reasonString: str
if isinstance(exceptionObject, SysCallError):
_, reasonString = exceptionObject.args
else:
errorQueue = exceptionObject.args[0]
_, _, reasonString = errorQueue[-1]
return reasonString.casefold().startswith("unexpected eof")
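
Complementing the server example in the module docstring, a minimal
client-side sketch; the hostname, port, and request are placeholders:

from twisted.internet import reactor
from twisted.internet.protocol import ClientFactory, Protocol
from twisted.internet.ssl import optionsForClientTLS
from twisted.protocols.tls import TLSMemoryBIOFactory

class ShowResponse(Protocol):
    def connectionMade(self):
        self.transport.write(b"HEAD / HTTP/1.0\r\nHost: example.com\r\n\r\n")

    def dataReceived(self, data):
        print(data.decode("ascii", "replace"))

clientFactory = ClientFactory.forProtocol(ShowResponse)
contextFactory = optionsForClientTLS("example.com")
tlsFactory = TLSMemoryBIOFactory(contextFactory, True, clientFactory)
reactor.connectTCP("example.com", 443, tlsFactory)
reactor.run()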
@implementer(ISystemHandle, INegotiated, ITransport)
class TLSMemoryBIOProtocol(ProtocolWrapper):
"""
L{TLSMemoryBIOProtocol} is a protocol wrapper which uses OpenSSL via a
memory BIO to encrypt bytes written to it before sending them on to the
underlying transport and decrypts bytes received from the underlying
transport before delivering them to the wrapped protocol.
In addition to producer events from the underlying transport, the need to
wait for reads before a write can proceed means the L{TLSMemoryBIOProtocol}
may also want to pause a producer. Pause/resume events are therefore
merged using the L{_ProducerMembrane} wrapper. Non-streaming (pull)
producers are supported by wrapping them with L{_PullToPush}.
Because TLS may need to wait for reads before writing, some writes may be
buffered until a read occurs.
@ivar _tlsConnection: The L{OpenSSL.SSL.Connection} instance which is
encrypting and decrypting this connection.
@ivar _lostTLSConnection: A flag indicating whether connection loss has
already been dealt with (C{True}) or not (C{False}). TLS disconnection
is distinct from the underlying connection being lost.
@ivar _appSendBuffer: application-level (cleartext) data that is waiting to
be transferred to the TLS buffer, but can't be because the TLS
connection is handshaking.
@type _appSendBuffer: L{list} of L{bytes}
@ivar _connectWrapped: A flag indicating whether or not to call
C{makeConnection} on the wrapped protocol. This is for the reactor's
L{twisted.internet.interfaces.ITLSTransport.startTLS} implementation,
since it has a protocol which it has already called C{makeConnection}
on, and which has no interest in a new transport. See #3821.
@ivar _handshakeDone: A flag indicating whether or not the handshake is
known to have completed successfully (C{True}) or not (C{False}). This
is used to control error reporting behavior. If the handshake has not
completed, the underlying L{OpenSSL.SSL.Error} will be passed to the
application's C{connectionLost} method. If it has completed, any
unexpected L{OpenSSL.SSL.Error} will be turned into a
L{ConnectionLost}. This is weird; however, it is simply an attempt at
a faithful re-implementation of the behavior provided by
L{twisted.internet.ssl}.
@ivar _reason: If an unexpected L{OpenSSL.SSL.Error} occurs which causes
the connection to be lost, it is saved here. If appropriate, this may
be used as the reason passed to the application protocol's
C{connectionLost} method.
@ivar _producer: The current producer registered via C{registerProducer},
or L{None} if no producer has been registered or a previous one was
unregistered.
@ivar _aborted: C{abortConnection} has been called. No further data will
be received to the wrapped protocol's C{dataReceived}.
@type _aborted: L{bool}
"""
_reason = None
_handshakeDone = False
_lostTLSConnection = False
_producer = None
_aborted = False
def __init__(self, factory, wrappedProtocol, _connectWrapped=True):
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
self._connectWrapped = _connectWrapped
def getHandle(self):
"""
Return the L{OpenSSL.SSL.Connection} object being used to encrypt and
decrypt this connection.
This is done for the benefit of L{twisted.internet.ssl.Certificate}'s
C{peerFromTransport} and C{hostFromTransport} methods only. A
different system handle may be returned by future versions of this
method.
"""
return self._tlsConnection
def makeConnection(self, transport):
"""
Connect this wrapper to the given transport and initialize the
necessary L{OpenSSL.SSL.Connection} with a memory BIO.
"""
self._tlsConnection = self.factory._createConnection(self)
self._appSendBuffer = []
# Add interfaces provided by the transport we are wrapping:
for interface in providedBy(transport):
directlyProvides(self, interface)
# Intentionally skip ProtocolWrapper.makeConnection - it might call
# wrappedProtocol.makeConnection, which we want to make conditional.
Protocol.makeConnection(self, transport)
self.factory.registerProtocol(self)
if self._connectWrapped:
# Now that the TLS layer is initialized, notify the application of
# the connection.
ProtocolWrapper.makeConnection(self, transport)
# Now that we ourselves have a transport (initialized by the
# ProtocolWrapper.makeConnection call above), kick off the TLS
# handshake.
self._checkHandshakeStatus()
def _checkHandshakeStatus(self):
"""
Ask OpenSSL to proceed with a handshake in progress.
Initially, this just sends the ClientHello; after some bytes have been
stuffed into the C{Connection} object by C{dataReceived}, it will then
respond to any C{Certificate} or C{KeyExchange} messages.
"""
# The connection might already be aborted (eg. by a callback during
# connection setup), so don't even bother trying to handshake in that
# case.
if self._aborted:
return
try:
self._tlsConnection.do_handshake()
except WantReadError:
self._flushSendBIO()
except Error:
self._tlsShutdownFinished(Failure())
else:
self._handshakeDone = True
if IHandshakeListener.providedBy(self.wrappedProtocol):
self.wrappedProtocol.handshakeCompleted()
def _flushSendBIO(self):
"""
Read any bytes out of the send BIO and write them to the underlying
transport.
"""
try:
bytes = self._tlsConnection.bio_read(2**15)
except WantReadError:
# There may be nothing in the send BIO right now.
pass
else:
self.transport.write(bytes)
def _flushReceiveBIO(self):
"""
Try to receive any application-level bytes which are now available
because of a previous write into the receive BIO. This will take
care of delivering any application-level bytes which are received to
the protocol, as well as handling of the various exceptions which
can come from trying to get such bytes.
"""
# Keep trying this until an error indicates we should stop or we
# close the connection. Looping is necessary to make sure we
# process all of the data which was put into the receive BIO, as
# there is no guarantee that a single recv call will do it all.
while not self._lostTLSConnection:
try:
bytes = self._tlsConnection.recv(2**15)
except WantReadError:
# The newly received bytes might not have been enough to produce
# any application data.
break
except ZeroReturnError:
# TLS has shut down and no more TLS data will be received over
# this connection.
self._shutdownTLS()
# Passing in None means the user protocol's connectionLost
# will get called with reason from underlying transport:
self._tlsShutdownFinished(None)
except Error:
# Something went pretty wrong. For example, this might be a
# handshake failure during renegotiation (because there were no
# shared ciphers, because a certificate failed to verify, etc).
# TLS can no longer proceed.
failure = Failure()
self._tlsShutdownFinished(failure)
else:
if not self._aborted:
ProtocolWrapper.dataReceived(self, bytes)
# The received bytes might have generated a response which needs to be
# sent now. For example, the handshake involves several round-trip
# exchanges without ever producing application-bytes.
self._flushSendBIO()
def dataReceived(self, bytes):
"""
Deliver any received bytes to the receive BIO and then read and deliver
to the application any application-level data which becomes available
as a result of this.
"""
# Let OpenSSL know some bytes were just received.
self._tlsConnection.bio_write(bytes)
# If we are still waiting for the handshake to complete, try to
# complete the handshake with the bytes we just received.
if not self._handshakeDone:
self._checkHandshakeStatus()
# If the handshake still isn't finished, then we've nothing left to
# do.
if not self._handshakeDone:
return
# If we've any pending writes, this read may have un-blocked them, so
# attempt to unbuffer them into the OpenSSL layer.
if self._appSendBuffer:
self._unbufferPendingWrites()
# Since the handshake is complete, the wire-level bytes we just
# processed might turn into some application-level bytes; try to pull
# those out.
self._flushReceiveBIO()
def _shutdownTLS(self):
"""
Initiate, or reply to, the shutdown handshake of the TLS layer.
"""
try:
shutdownSuccess = self._tlsConnection.shutdown()
except Error:
# Mid-handshake, a call to shutdown() can result in a
# WantReadError, or rather an SSL_ERR_WANT_READ; but pyOpenSSL
# doesn't allow us to get at the error. See:
# https://github.com/pyca/pyopenssl/issues/91
shutdownSuccess = False
self._flushSendBIO()
if shutdownSuccess:
# Both sides have shutdown, so we can start closing lower-level
# transport. This will also happen if we haven't started
# negotiation at all yet, in which case shutdown succeeds
# immediately.
self.transport.loseConnection()
def _tlsShutdownFinished(self, reason):
"""
Called when TLS connection has gone away; tell underlying transport to
disconnect.
@param reason: a L{Failure} whose value is an L{Exception} if we want to
report that failure through to the wrapped protocol's
C{connectionLost}, or L{None} if the C{reason} that
C{connectionLost} should receive should be coming from the
underlying transport.
@type reason: L{Failure} or L{None}
"""
if reason is not None:
# Squash an EOF in violation of the TLS protocol into
# ConnectionLost, so that applications which might run over
# multiple protocols can recognize its type.
if _representsEOF(reason.value):
reason = Failure(CONNECTION_LOST)
if self._reason is None:
self._reason = reason
self._lostTLSConnection = True
# We may need to send a TLS alert regarding the nature of the shutdown
# here (for example, why a handshake failed), so always flush our send
# buffer before telling our lower-level transport to go away.
self._flushSendBIO()
# Using loseConnection causes the application protocol's
# connectionLost method to be invoked non-reentrantly, which is always
# a nice feature. However, for error cases (reason != None) we might
# want to use abortConnection when it becomes available. The
# loseConnection call is basically tested by test_handshakeFailure.
# At least one side will need to do it or the test never finishes.
self.transport.loseConnection()
def connectionLost(self, reason):
"""
Handle the possible repetition of calls to this method (due to either
the underlying transport going away or due to an error at the TLS
layer) and make sure the base implementation only gets invoked once.
"""
if not self._lostTLSConnection:
# Tell the TLS connection that it's not going to get any more data
# and give it a chance to finish reading.
self._tlsConnection.bio_shutdown()
self._flushReceiveBIO()
self._lostTLSConnection = True
reason = self._reason or reason
self._reason = None
self.connected = False
ProtocolWrapper.connectionLost(self, reason)
# Breaking reference cycle between self._tlsConnection and self.
self._tlsConnection = None
def loseConnection(self):
"""
Send a TLS close alert and close the underlying connection.
"""
if self.disconnecting or not self.connected:
return
# If connection setup has not finished, OpenSSL 1.0.2f+ will not shut
# down the connection until we write some data to the connection which
# allows the handshake to complete. However, since no data should be
# written after loseConnection, this means we'll be stuck forever
# waiting for shutdown to complete. Instead, we simply abort the
# connection without trying to shut down cleanly:
if not self._handshakeDone and not self._appSendBuffer:
self.abortConnection()
self.disconnecting = True
if not self._appSendBuffer and self._producer is None:
self._shutdownTLS()
def abortConnection(self):
"""
Tear down TLS state so that if the connection is aborted mid-handshake
we don't deliver any further data from the application.
"""
self._aborted = True
self.disconnecting = True
self._shutdownTLS()
self.transport.abortConnection()
def failVerification(self, reason):
"""
Abort the connection during connection setup, giving a reason that
certificate verification failed.
@param reason: The reason that the verification failed; reported to the
application protocol's C{connectionLost} method.
@type reason: L{Failure}
"""
self._reason = reason
self.abortConnection()
def write(self, bytes):
"""
Process the given application bytes and send any resulting TLS traffic
which arrives in the send BIO.
If C{loseConnection} was called, subsequent calls to C{write} will
drop the bytes on the floor.
"""
# Writes after loseConnection are not supported, unless a producer has
# been registered, in which case writes can happen until the producer
# is unregistered:
if self.disconnecting and self._producer is None:
return
self._write(bytes)
def _bufferedWrite(self, octets):
"""
Put the given octets into L{TLSMemoryBIOProtocol._appSendBuffer}, and
tell any listening producer that it should pause because we are now
buffering.
"""
self._appSendBuffer.append(octets)
if self._producer is not None:
self._producer.pauseProducing()
def _unbufferPendingWrites(self):
"""
Un-buffer all waiting writes in L{TLSMemoryBIOProtocol._appSendBuffer}.
"""
pendingWrites, self._appSendBuffer = self._appSendBuffer, []
for eachWrite in pendingWrites:
self._write(eachWrite)
if self._appSendBuffer:
# If OpenSSL ran out of buffer space in the Connection on our way
# through the loop earlier and re-buffered any of our outgoing
# writes, then we're done; don't consider any future work.
return
if self._producer is not None:
# If we have a registered producer, let it know that we have some
# more buffer space.
self._producer.resumeProducing()
return
if self.disconnecting:
# Finally, if we have no further buffered data, no producer wants
# to send us more data in the future, and the application told us
# to end the stream, initiate a TLS shutdown.
self._shutdownTLS()
def _write(self, bytes):
"""
Process the given application bytes and send any resulting TLS traffic
which arrives in the send BIO.
This may be called by C{dataReceived} with bytes that were buffered
before C{loseConnection} was called, which is why this function
doesn't check for disconnection but accepts the bytes regardless.
"""
if self._lostTLSConnection:
return
# A TLS payload is 16kB max
bufferSize = 2**14
# How far into the input we've gotten so far
alreadySent = 0
while alreadySent < len(bytes):
toSend = bytes[alreadySent : alreadySent + bufferSize]
try:
sent = self._tlsConnection.send(toSend)
except WantReadError:
self._bufferedWrite(bytes[alreadySent:])
break
except Error:
# Pretend TLS connection disconnected, which will trigger
# disconnect of underlying transport. The error will be passed
# to the application protocol's connectionLost method. The
# other SSL implementation doesn't, but losing helpful
# debugging information is a bad idea.
self._tlsShutdownFinished(Failure())
break
else:
# We've successfully handed off the bytes to the OpenSSL
# Connection object.
alreadySent += sent
# See if OpenSSL wants to hand any bytes off to the underlying
# transport as a result.
self._flushSendBIO()
def writeSequence(self, iovec):
"""
Write a sequence of application bytes by joining them into one string
and passing them to L{write}.
"""
self.write(b"".join(iovec))
def getPeerCertificate(self):
return self._tlsConnection.get_peer_certificate()
@property
def negotiatedProtocol(self):
"""
@see: L{INegotiated.negotiatedProtocol}
"""
protocolName = None
try:
# If ALPN is not implemented that's ok, NPN might be.
protocolName = self._tlsConnection.get_alpn_proto_negotiated()
except (NotImplementedError, AttributeError):
pass
if protocolName not in (b"", None):
# A protocol was selected using ALPN.
return protocolName
try:
protocolName = self._tlsConnection.get_next_proto_negotiated()
except (NotImplementedError, AttributeError):
pass
if protocolName != b"":
return protocolName
return None
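    # A minimal sketch of how an application protocol might consult this
    # property once the handshake completes (the protocol names here are
    # illustrative):
    #
    #     if tlsTransport.negotiatedProtocol == b"h2":
    #         ...  # speak HTTP/2
    #     else:
    #         ...  # fall back to HTTP/1.1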
def registerProducer(self, producer, streaming):
# If we've already disconnected, nothing to do here:
if self._lostTLSConnection:
producer.stopProducing()
return
# If we received a non-streaming producer, wrap it so it becomes a
# streaming producer:
if not streaming:
producer = streamingProducer = _PullToPush(producer, self)
producer = _ProducerMembrane(producer)
# This will raise an exception if a producer is already registered:
self.transport.registerProducer(producer, True)
self._producer = producer
# If we received a non-streaming producer, we need to start the
# streaming wrapper:
if not streaming:
streamingProducer.startStreaming()
def unregisterProducer(self):
# If we have no producer, we don't need to do anything here.
if self._producer is None:
return
# If we received a non-streaming producer, we need to stop the
# streaming wrapper:
if isinstance(self._producer._producer, _PullToPush):
self._producer._producer.stopStreaming()
self._producer = None
self._producerPaused = False
self.transport.unregisterProducer()
if self.disconnecting and not self._appSendBuffer:
self._shutdownTLS()
@implementer(IOpenSSLClientConnectionCreator, IOpenSSLServerConnectionCreator)
class _ContextFactoryToConnectionFactory:
"""
Adapter wrapping a L{twisted.internet.interfaces.IOpenSSLContextFactory}
into a L{IOpenSSLClientConnectionCreator} or
L{IOpenSSLServerConnectionCreator}.
See U{https://twistedmatrix.com/trac/ticket/7215} for work that should make
this unnecessary.
"""
def __init__(self, oldStyleContextFactory):
"""
Construct a L{_ContextFactoryToConnectionFactory} with a
L{twisted.internet.interfaces.IOpenSSLContextFactory}.
Immediately call C{getContext} on C{oldStyleContextFactory} in order to
force advance parameter checking, since old-style context factories
don't actually check that their arguments to L{OpenSSL} are correct.
@param oldStyleContextFactory: A factory that can produce contexts.
@type oldStyleContextFactory:
L{twisted.internet.interfaces.IOpenSSLContextFactory}
"""
oldStyleContextFactory.getContext()
self._oldStyleContextFactory = oldStyleContextFactory
def _connectionForTLS(self, protocol):
"""
Create an L{OpenSSL.SSL.Connection} object.
@param protocol: The protocol initiating a TLS connection.
@type protocol: L{TLSMemoryBIOProtocol}
@return: a connection
@rtype: L{OpenSSL.SSL.Connection}
"""
context = self._oldStyleContextFactory.getContext()
return Connection(context, None)
def serverConnectionForTLS(self, protocol):
"""
Construct an OpenSSL server connection from the wrapped old-style
context factory.
@note: Since old-style context factories don't distinguish between
clients and servers, this is exactly the same as
L{_ContextFactoryToConnectionFactory.clientConnectionForTLS}.
@param protocol: The protocol initiating a TLS connection.
@type protocol: L{TLSMemoryBIOProtocol}
@return: a connection
@rtype: L{OpenSSL.SSL.Connection}
"""
return self._connectionForTLS(protocol)
def clientConnectionForTLS(self, protocol):
"""
        Construct an OpenSSL client connection from the wrapped old-style
        context factory.
@note: Since old-style context factories don't distinguish between
clients and servers, this is exactly the same as
L{_ContextFactoryToConnectionFactory.serverConnectionForTLS}.
@param protocol: The protocol initiating a TLS connection.
@type protocol: L{TLSMemoryBIOProtocol}
@return: a connection
@rtype: L{OpenSSL.SSL.Connection}
"""
return self._connectionForTLS(protocol)
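# A minimal usage sketch of the adapter above, assuming an old-style context
# factory; the file names and ``tlsProtocol`` are illustrative:
#
#     from twisted.internet.ssl import DefaultOpenSSLContextFactory
#     oldStyle = DefaultOpenSSLContextFactory("key.pem", "cert.pem")
#     creator = _ContextFactoryToConnectionFactory(oldStyle)
#     connection = creator.serverConnectionForTLS(tlsProtocol)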
class _AggregateSmallWrites:
"""
Aggregate small writes so they get written in large batches.
    If this is used as part of a transport, the transport needs to call
    ``flush()`` immediately when ``loseConnection()`` is called; otherwise any
    buffered writes will never get written.
@cvar MAX_BUFFER_SIZE: The maximum amount of bytes to buffer before writing
them out.
"""
MAX_BUFFER_SIZE = 64_000
def __init__(self, write: Callable[[bytes], object], clock: IReactorTime):
self._write = write
self._clock = clock
self._buffer: list[bytes] = []
self._bufferLeft = self.MAX_BUFFER_SIZE
self._scheduled: Optional[IDelayedCall] = None
def write(self, data: bytes) -> None:
"""
Buffer the data, or write it immediately if we've accumulated enough to
make it worth it.
Accumulating too much data can result in higher memory usage.
"""
self._buffer.append(data)
self._bufferLeft -= len(data)
if self._bufferLeft < 0:
# We've accumulated enough we should just write it out. No need to
# schedule a flush, since we just flushed everything.
self.flush()
return
if self._scheduled:
# We already have a scheduled send, so with the data in the buffer,
# there is nothing more to do here.
return
# Schedule the write of the accumulated buffer for the next reactor
# iteration.
self._scheduled = self._clock.callLater(0, self._scheduledFlush)
def _scheduledFlush(self) -> None:
"""Called in next reactor iteration."""
self._scheduled = None
self.flush()
def flush(self) -> None:
"""Flush any buffered writes."""
if self._buffer:
self._bufferLeft = self.MAX_BUFFER_SIZE
self._write(b"".join(self._buffer))
del self._buffer[:]
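# A minimal usage sketch of the aggregator on its own; ``collected`` is an
# illustrative stand-in for a transport's write method, and any IReactorTime
# provider can serve as the clock:
#
#     from twisted.internet import reactor
#     collected = []
#     aggregator = _AggregateSmallWrites(collected.append, reactor)
#     aggregator.write(b"small")   # buffered; a flush is scheduled
#     aggregator.write(b"chunks")  # still buffered
#     aggregator.flush()           # collected == [b"smallchunks"]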
def _get_default_clock() -> IReactorTime:
"""
Return the default reactor.
This is a function so it can be monkey-patched in tests, specifically
L{twisted.web.test.test_agent}.
"""
from twisted.internet import reactor
return cast(IReactorTime, reactor)
class BufferingTLSTransport(TLSMemoryBIOProtocol):
"""
A TLS transport implemented by wrapping buffering around a
L{TLSMemoryBIOProtocol}.
Doing many small writes directly to a L{OpenSSL.SSL.Connection}, as
implemented in L{TLSMemoryBIOProtocol}, can add significant CPU and
bandwidth overhead. Thus, even when writing is possible, small writes will
get aggregated and written as a single write at the next reactor iteration.
"""
# Implementation Note: An implementation based on composition would be
# nicer, but there's close integration between L{ProtocolWrapper}
# subclasses like L{TLSMemoryBIOProtocol} and the corresponding factory. An
    # attempt to implement this with composition broke things like
# L{TLSMemoryBIOFactory.protocols} having the correct instances, whereas
# subclassing makes that work.
def __init__(
self,
factory: TLSMemoryBIOFactory,
wrappedProtocol: IProtocol,
_connectWrapped: bool = True,
):
super().__init__(factory, wrappedProtocol, _connectWrapped)
actual_write = super().write
self._aggregator = _AggregateSmallWrites(actual_write, factory._clock)
# This is kinda ugly, but speeds things up a lot in a hot path with
# lots of small TLS writes. May become unnecessary in Python 3.13 or
# later if JIT and/or inlining becomes a thing.
self.write = self._aggregator.write # type: ignore[method-assign]
def writeSequence(self, sequence: Iterable[bytes]) -> None:
self._aggregator.write(b"".join(sequence))
def loseConnection(self) -> None:
self._aggregator.flush()
super().loseConnection()
class TLSMemoryBIOFactory(WrappingFactory):
"""
L{TLSMemoryBIOFactory} adds TLS to connections.
@ivar _creatorInterface: the interface which L{_connectionCreator} is
expected to implement.
@type _creatorInterface: L{zope.interface.interfaces.IInterface}
@ivar _connectionCreator: a callable which creates an OpenSSL Connection
object.
@type _connectionCreator: 1-argument callable taking
L{TLSMemoryBIOProtocol} and returning L{OpenSSL.SSL.Connection}.
"""
protocol = BufferingTLSTransport
noisy = False # disable unnecessary logging.
def __init__(
self,
contextFactory,
isClient,
wrappedFactory,
clock=None,
):
"""
Create a L{TLSMemoryBIOFactory}.
@param contextFactory: Configuration parameters used to create an
OpenSSL connection. In order of preference, what you should pass
here should be:
1. L{twisted.internet.ssl.CertificateOptions} (if you're
writing a server) or the result of
L{twisted.internet.ssl.optionsForClientTLS} (if you're
writing a client). If you want security you should really
use one of these.
2. If you really want to implement something yourself, supply a
provider of L{IOpenSSLClientConnectionCreator} or
L{IOpenSSLServerConnectionCreator}.
3. If you really have to, supply a
L{twisted.internet.ssl.ContextFactory}. This will likely be
deprecated at some point so please upgrade to the new
interfaces.
@type contextFactory: L{IOpenSSLClientConnectionCreator} or
L{IOpenSSLServerConnectionCreator}, or, for compatibility with
older code, anything implementing
L{twisted.internet.interfaces.IOpenSSLContextFactory}. See
U{https://twistedmatrix.com/trac/ticket/7215} for information on
the upcoming deprecation of passing a
L{twisted.internet.ssl.ContextFactory} here.
@param isClient: Is this a factory for TLS client connections; in other
words, those that will send a C{ClientHello} greeting? L{True} if
so, L{False} otherwise. This flag determines what interface is
expected of C{contextFactory}. If L{True}, C{contextFactory}
should provide L{IOpenSSLClientConnectionCreator}; otherwise it
should provide L{IOpenSSLServerConnectionCreator}.
@type isClient: L{bool}
@param wrappedFactory: A factory which will create the
application-level protocol.
@type wrappedFactory: L{twisted.internet.interfaces.IProtocolFactory}
"""
WrappingFactory.__init__(self, wrappedFactory)
if isClient:
creatorInterface = IOpenSSLClientConnectionCreator
else:
creatorInterface = IOpenSSLServerConnectionCreator
self._creatorInterface = creatorInterface
if not creatorInterface.providedBy(contextFactory):
contextFactory = _ContextFactoryToConnectionFactory(contextFactory)
self._connectionCreator = contextFactory
if clock is None:
clock = _get_default_clock()
self._clock = clock
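    # A minimal server-side usage sketch; ``serverOptions`` stands in for a
    # twisted.internet.ssl.CertificateOptions instance and ``appFactory`` for
    # an ordinary application-level protocol factory:
    #
    #     tlsFactory = TLSMemoryBIOFactory(serverOptions, False, appFactory)
    #     reactor.listenTCP(8443, tlsFactory)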
def logPrefix(self):
"""
Annotate the wrapped factory's log prefix with some text indicating TLS
is in use.
@rtype: C{str}
"""
if ILoggingContext.providedBy(self.wrappedFactory):
logPrefix = self.wrappedFactory.logPrefix()
else:
logPrefix = self.wrappedFactory.__class__.__name__
return f"{logPrefix} (TLS)"
def _applyProtocolNegotiation(self, connection):
"""
        Applies ALPN/NPN protocol negotiation to the connection, if the factory
supports it.
@param connection: The OpenSSL connection object to have ALPN/NPN added
to it.
@type connection: L{OpenSSL.SSL.Connection}
@return: Nothing
@rtype: L{None}
"""
if IProtocolNegotiationFactory.providedBy(self.wrappedFactory):
protocols = self.wrappedFactory.acceptableProtocols()
context = connection.get_context()
_setAcceptableProtocols(context, protocols)
return
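    # Sketch of a wrapped factory opting in to protocol negotiation; the class
    # and protocol names are illustrative, and the factory only needs to
    # provide IProtocolNegotiationFactory:
    #
    #     @implementer(IProtocolNegotiationFactory)
    #     class MyHTTPFactory(Factory):
    #         def acceptableProtocols(self):
    #             return [b"h2", b"http/1.1"]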
def _createConnection(self, tlsProtocol):
"""
        Create an OpenSSL connection and set it up properly.
@param tlsProtocol: The protocol which is establishing the connection.
@type tlsProtocol: L{TLSMemoryBIOProtocol}
@return: an OpenSSL connection object for C{tlsProtocol} to use
@rtype: L{OpenSSL.SSL.Connection}
"""
connectionCreator = self._connectionCreator
if self._creatorInterface is IOpenSSLClientConnectionCreator:
connection = connectionCreator.clientConnectionForTLS(tlsProtocol)
self._applyProtocolNegotiation(connection)
connection.set_connect_state()
else:
connection = connectionCreator.serverConnectionForTLS(tlsProtocol)
self._applyProtocolNegotiation(connection)
connection.set_accept_state()
return connection

View File

@@ -0,0 +1,112 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""Implement standard (and unused) TCP protocols.
These protocols are either provided by inetd, or are not provided at all.
"""
import struct
import time
from zope.interface import implementer
from twisted.internet import interfaces, protocol
class Echo(protocol.Protocol):
"""
As soon as any data is received, write it back (RFC 862).
"""
def dataReceived(self, data):
self.transport.write(data)
class Discard(protocol.Protocol):
"""
Discard any received data (RFC 863).
"""
def dataReceived(self, data):
# I'm ignoring you, nyah-nyah
pass
@implementer(interfaces.IProducer)
class Chargen(protocol.Protocol):
"""
Generate repeating noise (RFC 864).
"""
noise = rb'@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ !"#$%&?'
def connectionMade(self):
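        # Register as a non-streaming (pull) producer: the transport will call
        # resumeProducing() each time its write buffer drains, and we answer
        # with another chunk of noise.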
self.transport.registerProducer(self, 0)
def resumeProducing(self):
self.transport.write(self.noise)
def pauseProducing(self):
pass
def stopProducing(self):
pass
class QOTD(protocol.Protocol):
"""
Return a quote of the day (RFC 865).
"""
def connectionMade(self):
self.transport.write(self.getQuote())
self.transport.loseConnection()
def getQuote(self):
"""
        Return a quote. May be overridden in subclasses.
"""
return b"An apple a day keeps the doctor away.\r\n"
class Who(protocol.Protocol):
"""
    Return a list of active users (RFC 866).
"""
def connectionMade(self):
self.transport.write(self.getUsers())
self.transport.loseConnection()
def getUsers(self):
"""
Return active users. Override in subclasses.
"""
return b"root\r\n"
class Daytime(protocol.Protocol):
"""
Send back the daytime in ASCII form (RFC 867).
"""
def connectionMade(self):
        # time.asctime() returns str, so encode it before appending the
        # bytes line ending.
        self.transport.write(
            time.asctime(time.gmtime(time.time())).encode("ascii") + b"\r\n"
        )
self.transport.loseConnection()
class Time(protocol.Protocol):
"""
Send back the time in machine readable form (RFC 868).
"""
def connectionMade(self):
        # RFC 868 time is an unsigned 32-bit count of seconds since
        # 1900-01-01 00:00 UTC; the Unix epoch begins 2208988800 seconds
        # later (70 years, 17 of them leap years, at 86400 seconds each day).
        result = struct.pack("!I", int(time.time()) + 2208988800)
self.transport.write(result)
self.transport.loseConnection()
__all__ = ["Echo", "Discard", "Chargen", "QOTD", "Who", "Daytime", "Time"]
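# A minimal usage sketch: serving Echo with a stock factory. The port number
# is illustrative; RFC 862 assigns port 7, which requires privileges to bind.
#
#     from twisted.internet import protocol, reactor
#     factory = protocol.ServerFactory()
#     factory.protocol = Echo
#     reactor.listenTCP(8007, factory)
#     reactor.run()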