mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-23 02:11:08 -05:00
okay fine
This commit is contained in:
19
.venv/lib/python3.12/site-packages/twisted/test/__init__.py
Normal file
19
.venv/lib/python3.12/site-packages/twisted/test/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Twisted's unit tests.
|
||||
"""
|
||||
|
||||
|
||||
from twisted.python.deprecate import deprecatedModuleAttribute
|
||||
from twisted.python.versions import Version
|
||||
from twisted.test import proto_helpers
|
||||
|
||||
for obj in proto_helpers.__all__:
|
||||
deprecatedModuleAttribute(
|
||||
Version("Twisted", 19, 7, 0),
|
||||
f"Please use twisted.internet.testing.{obj} instead.",
|
||||
"twisted.test.proto_helpers",
|
||||
obj,
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEJDCCAwygAwIBAgIUKaSXgzt5gDMt9GbUzLz/A9HEyFEwDQYJKoZIhvcNAQEL
|
||||
BQAwgb0xGDAWBgNVBAMMD0EgSG9zdCwgTG9jYWxseTELMAkGA1UEBhMCVFIxDzAN
|
||||
BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExHDAaBgNVBAoME1R3
|
||||
aXN0ZWQgTWF0cml4IExhYnMxJDAiBgNVBAsMG0F1dG9tYXRlZCBUZXN0aW5nIEF1
|
||||
dGhvcml0eTEpMCcGCSqGSIb3DQEJARYac2VjdXJpdHlAdHdpc3RlZG1hdHJpeC5j
|
||||
b20wIBcNMjMwNjE0MTM0MDI4WhgPMjEyMzA1MjExMzQwMjhaMIG9MRgwFgYDVQQD
|
||||
DA9BIEhvc3QsIExvY2FsbHkxCzAJBgNVBAYTAlRSMQ8wDQYDVQQIDAbDh29ydW0x
|
||||
FDASBgNVBAcMC0JhxZ9tYWvDp8SxMRwwGgYDVQQKDBNUd2lzdGVkIE1hdHJpeCBM
|
||||
YWJzMSQwIgYDVQQLDBtBdXRvbWF0ZWQgVGVzdGluZyBBdXRob3JpdHkxKTAnBgkq
|
||||
hkiG9w0BCQEWGnNlY3VyaXR5QHR3aXN0ZWRtYXRyaXguY29tMIIBIjANBgkqhkiG
|
||||
9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0rT5+hF+1BjE7qXms9PZWHskXZGXLPiYVmiY
|
||||
jsVeJAOtHAYq8igzA49KgR1xR9M4jQ6U46nwPsnGCh4liyxdWkBLw9maxMoE+r6d
|
||||
W1zZ8Tllunbdb/Da6L8P55SKb7QGet4CB1fZ2SqZD4GvTby6xpoR09AqrfjuEIYR
|
||||
8V/y+8dG3mR5W0HqaJ58IWihAwIQSakuc8jTadJY55t7UW6Ebj2X2WTO6Zh7gJ1d
|
||||
yHPMVkUHJF9Jsuj/4F4lx6hWGQzWO8Nf8Q7t364pagE3evUv/BECJLONNYLaFjLt
|
||||
WnsCEJDV9owCjaxu785KuA7OM/f3h3xVIfTBTo2AlHiQnXdyrwIDAQABoxgwFjAU
|
||||
BgNVHREEDTALgglsb2NhbGhvc3QwDQYJKoZIhvcNAQELBQADggEBAEHAErq/Fs8h
|
||||
M+kwGCt5Ochqyu/IzPbwgQ27n5IJehl7kmpoXBxGa/u+ajoxrZaOheg8E2MYVwQi
|
||||
FTKE9wJgaN3uGo4bzCbCYxDm7tflQORo6QOZlumfiQIzXON2RvgJpwFfkLNtq0t9
|
||||
e453kJ7+e11Wah46bc3RAvBZpwswh6hDv2FvFUZ+IUcO0tU8O4kWrLIFPpJbcHQq
|
||||
wezjky773X4CNEtoeuTb8/ws/eED/TGZ2AZO+BWT93OZJgwE2x3iUd3k8HbwxfoY
|
||||
bZ+NHgtM7iKRcL59asB0OMi3Ays0+IOfZ1+3aB82zYlxFBoDyalR7NJjJGdTwNFt
|
||||
3CPGCQ28cDk=
|
||||
-----END CERTIFICATE-----
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from zope.interface import Interface, implementer
|
||||
|
||||
from twisted.python import components
|
||||
|
||||
|
||||
def foo() -> Literal[2]:
|
||||
return 2
|
||||
|
||||
|
||||
class X:
|
||||
def __init__(self, x: str) -> None:
|
||||
self.x = x
|
||||
|
||||
def do(self) -> None:
|
||||
# print 'X',self.x,'doing!'
|
||||
pass
|
||||
|
||||
|
||||
class XComponent(components.Componentized):
|
||||
pass
|
||||
|
||||
|
||||
class IX(Interface):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IX)
|
||||
class XA(components.Adapter):
|
||||
def method(self) -> None:
|
||||
# Kick start :(
|
||||
pass
|
||||
|
||||
|
||||
components.registerAdapter(XA, X, IX)
|
||||
577
.venv/lib/python3.12/site-packages/twisted/test/iosim.py
Normal file
577
.venv/lib/python3.12/site-packages/twisted/test/iosim.py
Normal file
@@ -0,0 +1,577 @@
|
||||
# -*- test-case-name: twisted.test.test_amp,twisted.test.test_iosim -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Utilities and helpers for simulating a network
|
||||
"""
|
||||
|
||||
import itertools
|
||||
|
||||
try:
|
||||
from OpenSSL.SSL import Error as NativeOpenSSLError
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from zope.interface import directlyProvides, implementer
|
||||
|
||||
from twisted.internet import error, interfaces
|
||||
from twisted.internet.endpoints import TCP4ClientEndpoint, TCP4ServerEndpoint
|
||||
from twisted.internet.error import ConnectionRefusedError
|
||||
from twisted.internet.protocol import Factory, Protocol
|
||||
from twisted.internet.testing import MemoryReactorClock
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
|
||||
class TLSNegotiation:
|
||||
def __init__(self, obj, connectState):
|
||||
self.obj = obj
|
||||
self.connectState = connectState
|
||||
self.sent = False
|
||||
self.readyToSend = connectState
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"TLSNegotiation({self.obj!r})"
|
||||
|
||||
def pretendToVerify(self, other, tpt):
|
||||
# Set the transport problems list here? disconnections?
|
||||
# hmmmmm... need some negative path tests.
|
||||
|
||||
if not self.obj.iosimVerify(other.obj):
|
||||
tpt.disconnectReason = NativeOpenSSLError()
|
||||
tpt.loseConnection()
|
||||
|
||||
|
||||
@implementer(interfaces.IAddress)
|
||||
class FakeAddress:
|
||||
"""
|
||||
The default address type for the host and peer of L{FakeTransport}
|
||||
connections.
|
||||
"""
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport, interfaces.ITLSTransport)
|
||||
class FakeTransport:
|
||||
"""
|
||||
A wrapper around a file-like object to make it behave as a Transport.
|
||||
|
||||
This doesn't actually stream the file to the attached protocol,
|
||||
and is thus useful mainly as a utility for debugging protocols.
|
||||
"""
|
||||
|
||||
_nextserial = staticmethod(lambda counter=itertools.count(): int(next(counter)))
|
||||
closed = 0
|
||||
disconnecting = 0
|
||||
disconnected = 0
|
||||
disconnectReason = error.ConnectionDone("Connection done")
|
||||
producer = None
|
||||
streamingProducer = 0
|
||||
tls = None
|
||||
|
||||
def __init__(self, protocol, isServer, hostAddress=None, peerAddress=None):
|
||||
"""
|
||||
@param protocol: This transport will deliver bytes to this protocol.
|
||||
@type protocol: L{IProtocol} provider
|
||||
|
||||
@param isServer: C{True} if this is the accepting side of the
|
||||
connection, C{False} if it is the connecting side.
|
||||
@type isServer: L{bool}
|
||||
|
||||
@param hostAddress: The value to return from C{getHost}. L{None}
|
||||
results in a new L{FakeAddress} being created to use as the value.
|
||||
@type hostAddress: L{IAddress} provider or L{None}
|
||||
|
||||
@param peerAddress: The value to return from C{getPeer}. L{None}
|
||||
results in a new L{FakeAddress} being created to use as the value.
|
||||
@type peerAddress: L{IAddress} provider or L{None}
|
||||
"""
|
||||
self.protocol = protocol
|
||||
self.isServer = isServer
|
||||
self.stream = []
|
||||
self.serial = self._nextserial()
|
||||
if hostAddress is None:
|
||||
hostAddress = FakeAddress()
|
||||
self.hostAddress = hostAddress
|
||||
if peerAddress is None:
|
||||
peerAddress = FakeAddress()
|
||||
self.peerAddress = peerAddress
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "FakeTransport<{},{},{}>".format(
|
||||
self.isServer and "S" or "C",
|
||||
self.serial,
|
||||
self.protocol.__class__.__name__,
|
||||
)
|
||||
|
||||
def write(self, data):
|
||||
# If transport is closed, we should accept writes but drop the data.
|
||||
if self.disconnecting:
|
||||
return
|
||||
|
||||
if self.tls is not None:
|
||||
self.tlsbuf.append(data)
|
||||
else:
|
||||
self.stream.append(data)
|
||||
|
||||
def _checkProducer(self):
|
||||
# Cheating; this is called at "idle" times to allow producers to be
|
||||
# found and dealt with
|
||||
if self.producer and not self.streamingProducer:
|
||||
self.producer.resumeProducing()
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
"""
|
||||
From abstract.FileDescriptor
|
||||
"""
|
||||
self.producer = producer
|
||||
self.streamingProducer = streaming
|
||||
if not streaming:
|
||||
producer.resumeProducing()
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.producer = None
|
||||
|
||||
def stopConsuming(self):
|
||||
self.unregisterProducer()
|
||||
self.loseConnection()
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self.write(b"".join(iovec))
|
||||
|
||||
def loseConnection(self):
|
||||
self.disconnecting = True
|
||||
|
||||
def abortConnection(self):
|
||||
"""
|
||||
For the time being, this is the same as loseConnection; no buffered
|
||||
data will be lost.
|
||||
"""
|
||||
self.disconnecting = True
|
||||
|
||||
def reportDisconnect(self):
|
||||
if self.tls is not None:
|
||||
# We were in the middle of negotiating! Must have been a TLS
|
||||
# problem.
|
||||
err = NativeOpenSSLError()
|
||||
else:
|
||||
err = self.disconnectReason
|
||||
self.protocol.connectionLost(Failure(err))
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Identify this transport/event source to the logging system.
|
||||
"""
|
||||
return "iosim"
|
||||
|
||||
def getPeer(self):
|
||||
return self.peerAddress
|
||||
|
||||
def getHost(self):
|
||||
return self.hostAddress
|
||||
|
||||
def resumeProducing(self):
|
||||
# Never sends data anyways
|
||||
pass
|
||||
|
||||
def pauseProducing(self):
|
||||
# Never sends data anyways
|
||||
pass
|
||||
|
||||
def stopProducing(self):
|
||||
self.loseConnection()
|
||||
|
||||
def startTLS(self, contextFactory, beNormal=True):
|
||||
# Nothing's using this feature yet, but startTLS has an undocumented
|
||||
# second argument which defaults to true; if set to False, servers will
|
||||
# behave like clients and clients will behave like servers.
|
||||
connectState = self.isServer ^ beNormal
|
||||
self.tls = TLSNegotiation(contextFactory, connectState)
|
||||
self.tlsbuf = []
|
||||
|
||||
def getOutBuffer(self):
|
||||
"""
|
||||
Get the pending writes from this transport, clearing them from the
|
||||
pending buffer.
|
||||
|
||||
@return: the bytes written with C{transport.write}
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
S = self.stream
|
||||
if S:
|
||||
self.stream = []
|
||||
return b"".join(S)
|
||||
elif self.tls is not None:
|
||||
if self.tls.readyToSend:
|
||||
# Only _send_ the TLS negotiation "packet" if I'm ready to.
|
||||
self.tls.sent = True
|
||||
return self.tls
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def bufferReceived(self, buf):
|
||||
if isinstance(buf, TLSNegotiation):
|
||||
assert self.tls is not None # By the time you're receiving a
|
||||
# negotiation, you have to have called
|
||||
# startTLS already.
|
||||
if self.tls.sent:
|
||||
self.tls.pretendToVerify(buf, self)
|
||||
self.tls = None # We're done with the handshake if we've gotten
|
||||
# this far... although maybe it failed...?
|
||||
# TLS started! Unbuffer...
|
||||
b, self.tlsbuf = self.tlsbuf, None
|
||||
self.writeSequence(b)
|
||||
directlyProvides(self, interfaces.ISSLTransport)
|
||||
else:
|
||||
# We haven't sent our own TLS negotiation: time to do that!
|
||||
self.tls.readyToSend = True
|
||||
else:
|
||||
self.protocol.dataReceived(buf)
|
||||
|
||||
def getTcpKeepAlive(self):
|
||||
# ITCPTransport.getTcpKeepAlive
|
||||
pass
|
||||
|
||||
def getTcpNoDelay(self):
|
||||
# ITCPTransport.getTcpNoDelay
|
||||
pass
|
||||
|
||||
def loseWriteConnection(self):
|
||||
# ITCPTransport.loseWriteConnection
|
||||
pass
|
||||
|
||||
def setTcpKeepAlive(self, enabled):
|
||||
# ITCPTransport.setTcpKeepAlive
|
||||
pass
|
||||
|
||||
def setTcpNoDelay(self, enabled):
|
||||
# ITCPTransport.setTcpNoDelay
|
||||
pass
|
||||
|
||||
|
||||
def makeFakeClient(clientProtocol):
|
||||
"""
|
||||
Create and return a new in-memory transport hooked up to the given protocol.
|
||||
|
||||
@param clientProtocol: The client protocol to use.
|
||||
@type clientProtocol: L{IProtocol} provider
|
||||
|
||||
@return: The transport.
|
||||
@rtype: L{FakeTransport}
|
||||
"""
|
||||
return FakeTransport(clientProtocol, isServer=False)
|
||||
|
||||
|
||||
def makeFakeServer(serverProtocol):
|
||||
"""
|
||||
Create and return a new in-memory transport hooked up to the given protocol.
|
||||
|
||||
@param serverProtocol: The server protocol to use.
|
||||
@type serverProtocol: L{IProtocol} provider
|
||||
|
||||
@return: The transport.
|
||||
@rtype: L{FakeTransport}
|
||||
"""
|
||||
return FakeTransport(serverProtocol, isServer=True)
|
||||
|
||||
|
||||
class IOPump:
|
||||
"""
|
||||
Utility to pump data between clients and servers for protocol testing.
|
||||
|
||||
Perhaps this is a utility worthy of being in protocol.py?
|
||||
"""
|
||||
|
||||
def __init__(self, client, server, clientIO, serverIO, debug, clock=None):
|
||||
self.client = client
|
||||
self.server = server
|
||||
self.clientIO = clientIO
|
||||
self.serverIO = serverIO
|
||||
self.debug = debug
|
||||
if clock is None:
|
||||
clock = MemoryReactorClock()
|
||||
self.clock = clock
|
||||
|
||||
def flush(self, debug=False, advanceClock=True):
|
||||
"""
|
||||
Pump until there is no more input or output.
|
||||
|
||||
Returns whether any data was moved.
|
||||
"""
|
||||
result = False
|
||||
for _ in range(1000):
|
||||
if self.pump(debug, advanceClock):
|
||||
result = True
|
||||
else:
|
||||
break
|
||||
else:
|
||||
assert 0, "Too long"
|
||||
return result
|
||||
|
||||
def pump(self, debug=False, advanceClock=True):
|
||||
"""
|
||||
Move data back and forth, while also triggering any currently pending
|
||||
scheduled calls (i.e. C{callLater(0, f)}).
|
||||
|
||||
Returns whether any data was moved.
|
||||
"""
|
||||
if advanceClock:
|
||||
self.clock.advance(0)
|
||||
if self.debug or debug:
|
||||
print("-- GLUG --")
|
||||
sData = self.serverIO.getOutBuffer()
|
||||
cData = self.clientIO.getOutBuffer()
|
||||
self.clientIO._checkProducer()
|
||||
self.serverIO._checkProducer()
|
||||
if self.debug or debug:
|
||||
print(".")
|
||||
# XXX slightly buggy in the face of incremental output
|
||||
if cData:
|
||||
print("C: " + repr(cData))
|
||||
if sData:
|
||||
print("S: " + repr(sData))
|
||||
if cData:
|
||||
self.serverIO.bufferReceived(cData)
|
||||
if sData:
|
||||
self.clientIO.bufferReceived(sData)
|
||||
if cData or sData:
|
||||
return True
|
||||
if self.serverIO.disconnecting and not self.serverIO.disconnected:
|
||||
if self.debug or debug:
|
||||
print("* C")
|
||||
self.serverIO.disconnected = True
|
||||
self.clientIO.disconnecting = True
|
||||
self.clientIO.reportDisconnect()
|
||||
return True
|
||||
if self.clientIO.disconnecting and not self.clientIO.disconnected:
|
||||
if self.debug or debug:
|
||||
print("* S")
|
||||
self.clientIO.disconnected = True
|
||||
self.serverIO.disconnecting = True
|
||||
self.serverIO.reportDisconnect()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def connect(
|
||||
serverProtocol,
|
||||
serverTransport,
|
||||
clientProtocol,
|
||||
clientTransport,
|
||||
debug=False,
|
||||
greet=True,
|
||||
clock=None,
|
||||
):
|
||||
"""
|
||||
Create a new L{IOPump} connecting two protocols.
|
||||
|
||||
@param serverProtocol: The protocol to use on the accepting side of the
|
||||
connection.
|
||||
@type serverProtocol: L{IProtocol} provider
|
||||
|
||||
@param serverTransport: The transport to associate with C{serverProtocol}.
|
||||
@type serverTransport: L{FakeTransport}
|
||||
|
||||
@param clientProtocol: The protocol to use on the initiating side of the
|
||||
connection.
|
||||
@type clientProtocol: L{IProtocol} provider
|
||||
|
||||
@param clientTransport: The transport to associate with C{clientProtocol}.
|
||||
@type clientTransport: L{FakeTransport}
|
||||
|
||||
@param debug: A flag indicating whether to log information about what the
|
||||
L{IOPump} is doing.
|
||||
@type debug: L{bool}
|
||||
|
||||
@param greet: Should the L{IOPump} be L{flushed <IOPump.flush>} once before
|
||||
returning to put the protocols into their post-handshake or
|
||||
post-server-greeting state?
|
||||
@type greet: L{bool}
|
||||
|
||||
@param clock: An optional L{Clock}. Pumping the resulting L{IOPump} will
|
||||
also increase clock time by a small increment.
|
||||
|
||||
@return: An L{IOPump} which connects C{serverProtocol} and
|
||||
C{clientProtocol} and delivers bytes between them when it is pumped.
|
||||
@rtype: L{IOPump}
|
||||
"""
|
||||
serverProtocol.makeConnection(serverTransport)
|
||||
clientProtocol.makeConnection(clientTransport)
|
||||
pump = IOPump(
|
||||
clientProtocol,
|
||||
serverProtocol,
|
||||
clientTransport,
|
||||
serverTransport,
|
||||
debug,
|
||||
clock=clock,
|
||||
)
|
||||
if greet:
|
||||
# Kick off server greeting, etc
|
||||
pump.flush()
|
||||
return pump
|
||||
|
||||
|
||||
def connectedServerAndClient(
|
||||
ServerClass,
|
||||
ClientClass,
|
||||
clientTransportFactory=makeFakeClient,
|
||||
serverTransportFactory=makeFakeServer,
|
||||
debug=False,
|
||||
greet=True,
|
||||
clock=None,
|
||||
):
|
||||
"""
|
||||
Connect a given server and client class to each other.
|
||||
|
||||
@param ServerClass: a callable that produces the server-side protocol.
|
||||
@type ServerClass: 0-argument callable returning L{IProtocol} provider.
|
||||
|
||||
@param ClientClass: like C{ServerClass} but for the other side of the
|
||||
connection.
|
||||
@type ClientClass: 0-argument callable returning L{IProtocol} provider.
|
||||
|
||||
@param clientTransportFactory: a callable that produces the transport which
|
||||
will be attached to the protocol returned from C{ClientClass}.
|
||||
@type clientTransportFactory: callable taking (L{IProtocol}) and returning
|
||||
L{FakeTransport}
|
||||
|
||||
@param serverTransportFactory: a callable that produces the transport which
|
||||
will be attached to the protocol returned from C{ServerClass}.
|
||||
@type serverTransportFactory: callable taking (L{IProtocol}) and returning
|
||||
L{FakeTransport}
|
||||
|
||||
@param debug: Should this dump an escaped version of all traffic on this
|
||||
connection to stdout for inspection?
|
||||
@type debug: L{bool}
|
||||
|
||||
@param greet: Should the L{IOPump} be L{flushed <IOPump.flush>} once before
|
||||
returning to put the protocols into their post-handshake or
|
||||
post-server-greeting state?
|
||||
@type greet: L{bool}
|
||||
|
||||
@param clock: An optional L{Clock}. Pumping the resulting L{IOPump} will
|
||||
also increase clock time by a small increment.
|
||||
|
||||
@return: the client protocol, the server protocol, and an L{IOPump} which,
|
||||
when its C{pump} and C{flush} methods are called, will move data
|
||||
between the created client and server protocol instances.
|
||||
@rtype: 3-L{tuple} of L{IProtocol}, L{IProtocol}, L{IOPump}
|
||||
"""
|
||||
c = ClientClass()
|
||||
s = ServerClass()
|
||||
cio = clientTransportFactory(c)
|
||||
sio = serverTransportFactory(s)
|
||||
return c, s, connect(s, sio, c, cio, debug, greet, clock=clock)
|
||||
|
||||
|
||||
def _factoriesShouldConnect(clientInfo, serverInfo):
|
||||
"""
|
||||
Should the client and server described by the arguments be connected to
|
||||
each other, i.e. do their port numbers match?
|
||||
|
||||
@param clientInfo: the args for connectTCP
|
||||
@type clientInfo: L{tuple}
|
||||
|
||||
@param serverInfo: the args for listenTCP
|
||||
@type serverInfo: L{tuple}
|
||||
|
||||
@return: If they do match, return factories for the client and server that
|
||||
should connect; otherwise return L{None}, indicating they shouldn't be
|
||||
connected.
|
||||
@rtype: L{None} or 2-L{tuple} of (L{ClientFactory},
|
||||
L{IProtocolFactory})
|
||||
"""
|
||||
(
|
||||
clientHost,
|
||||
clientPort,
|
||||
clientFactory,
|
||||
clientTimeout,
|
||||
clientBindAddress,
|
||||
) = clientInfo
|
||||
(serverPort, serverFactory, serverBacklog, serverInterface) = serverInfo
|
||||
if serverPort == clientPort:
|
||||
return clientFactory, serverFactory
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ConnectionCompleter:
|
||||
"""
|
||||
A L{ConnectionCompleter} can cause synthetic TCP connections established by
|
||||
L{MemoryReactor.connectTCP} and L{MemoryReactor.listenTCP} to succeed or
|
||||
fail.
|
||||
"""
|
||||
|
||||
def __init__(self, memoryReactor):
|
||||
"""
|
||||
Create a L{ConnectionCompleter} from a L{MemoryReactor}.
|
||||
|
||||
@param memoryReactor: The reactor to attach to.
|
||||
@type memoryReactor: L{MemoryReactor}
|
||||
"""
|
||||
self._reactor = memoryReactor
|
||||
|
||||
def succeedOnce(self, debug=False):
|
||||
"""
|
||||
Complete a single TCP connection established on this
|
||||
L{ConnectionCompleter}'s L{MemoryReactor}.
|
||||
|
||||
@param debug: A flag; whether to dump output from the established
|
||||
connection to stdout.
|
||||
@type debug: L{bool}
|
||||
|
||||
@return: a pump for the connection, or L{None} if no connection could
|
||||
be established.
|
||||
@rtype: L{IOPump} or L{None}
|
||||
"""
|
||||
memoryReactor = self._reactor
|
||||
for clientIdx, clientInfo in enumerate(memoryReactor.tcpClients):
|
||||
for serverInfo in memoryReactor.tcpServers:
|
||||
factories = _factoriesShouldConnect(clientInfo, serverInfo)
|
||||
if factories:
|
||||
memoryReactor.tcpClients.remove(clientInfo)
|
||||
memoryReactor.connectors.pop(clientIdx)
|
||||
clientFactory, serverFactory = factories
|
||||
clientProtocol = clientFactory.buildProtocol(None)
|
||||
serverProtocol = serverFactory.buildProtocol(None)
|
||||
serverTransport = makeFakeServer(serverProtocol)
|
||||
clientTransport = makeFakeClient(clientProtocol)
|
||||
return connect(
|
||||
serverProtocol,
|
||||
serverTransport,
|
||||
clientProtocol,
|
||||
clientTransport,
|
||||
debug,
|
||||
)
|
||||
|
||||
def failOnce(self, reason=Failure(ConnectionRefusedError())):
|
||||
"""
|
||||
Fail a single TCP connection established on this
|
||||
L{ConnectionCompleter}'s L{MemoryReactor}.
|
||||
|
||||
@param reason: the reason to provide that the connection failed.
|
||||
@type reason: L{Failure}
|
||||
"""
|
||||
self._reactor.tcpClients.pop(0)[2].clientConnectionFailed(
|
||||
self._reactor.connectors.pop(0), reason
|
||||
)
|
||||
|
||||
|
||||
def connectableEndpoint(debug=False):
|
||||
"""
|
||||
Create an endpoint that can be fired on demand.
|
||||
|
||||
@param debug: A flag; whether to dump output from the established
|
||||
connection to stdout.
|
||||
@type debug: L{bool}
|
||||
|
||||
@return: A client endpoint, and an object that will cause one of the
|
||||
L{Deferred}s returned by that client endpoint.
|
||||
@rtype: 2-L{tuple} of (L{IStreamClientEndpoint}, L{ConnectionCompleter})
|
||||
"""
|
||||
reactor = MemoryReactorClock()
|
||||
clientEndpoint = TCP4ClientEndpoint(reactor, "0.0.0.0", 4321)
|
||||
serverEndpoint = TCP4ServerEndpoint(reactor, 4321)
|
||||
serverEndpoint.listen(Factory.forProtocol(Protocol))
|
||||
return clientEndpoint, ConnectionCompleter(reactor)
|
||||
@@ -0,0 +1,27 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEogIBAAKCAQEA0rT5+hF+1BjE7qXms9PZWHskXZGXLPiYVmiYjsVeJAOtHAYq
|
||||
8igzA49KgR1xR9M4jQ6U46nwPsnGCh4liyxdWkBLw9maxMoE+r6dW1zZ8Tllunbd
|
||||
b/Da6L8P55SKb7QGet4CB1fZ2SqZD4GvTby6xpoR09AqrfjuEIYR8V/y+8dG3mR5
|
||||
W0HqaJ58IWihAwIQSakuc8jTadJY55t7UW6Ebj2X2WTO6Zh7gJ1dyHPMVkUHJF9J
|
||||
suj/4F4lx6hWGQzWO8Nf8Q7t364pagE3evUv/BECJLONNYLaFjLtWnsCEJDV9owC
|
||||
jaxu785KuA7OM/f3h3xVIfTBTo2AlHiQnXdyrwIDAQABAoH/Ib7aSjKDHXTaFV58
|
||||
lFBZftI6AMJQc+Ncgno99J+ndB0inFpghmfpw6gvRn5wphAt/mlXbx7IW0X1cali
|
||||
WefBC7NAbx1qrBmusnnUuc0lGn0WzcY7sLHiXWQ8J9qiUUGDyCnGKWbofN9VpCYg
|
||||
7VJMl4IVWNb9/t7fQcY3GXFEeQ4mzLo7p+gPxyeUcCLVrhVrHzw1HFTIlA51LjfI
|
||||
xQM+QVeaEWQQ4UsDdPe5iGthDd7ze2F5ciDzMkShrf7URSudS+Us6vr6gDVpKAky
|
||||
eCVyFPJXCfH4qJoa6mB6L6SFzMnN3OPp3RlYQWQ7sK/ELQfhPoyHyRvL1woUIO5C
|
||||
tK0pAoGBAPS6ZSZ26M0guZ2K/2fKMiGq0jZQLcxP3N0jWm8R8ENOnuIjhCl5aKsB
|
||||
DoV0BvPv1C2vWm+VgNArgTece9l8o5f8pcfjbT5r/k8zoqgcj9CmmDofBka4XxZb
|
||||
wxsut+8rBSIoVKIre4Pyqfa9u1IrEnoOzMqvF16xUME2t2EaryUzAoGBANxpb4Jz
|
||||
FjH7nfPc3iejd+cXovX6x2VTJzWaknA6hGsoc+UZ01KTaKyYpq+9q9VxXhWxYsh3
|
||||
TL1JWuIBy6ao5tdt4nPBu07J7tfu5bfr3Imd8waNQxDEfKeFedskxORs+FIUzqBb
|
||||
3nIkQH8sx0Syv620coIdtEn1raVXc9QfRgSVAoGAWNFhLoGPYgsTcnrk0N1QLmnZ
|
||||
mv6kcHc3mEZhZtgi07qv7TCooYi/lPhwNbzzXQrYfbAbaU3gDy0K24z+YeNbWCjI
|
||||
XfBLUJFPHZ2G1e5vv3EG5GkoFPiLAglRmQbumG2LkmcCuEyBqlSinLslRd/997Bx
|
||||
YMoE+EfwH/9ktGhD0oMCgYEAxaSqAFDQ00ssjTM95k94Qjn4wBf7WwmgfDm6HHbs
|
||||
rOZeXk61JzPVxgcwWSB8iG4bDtq8mMQZhRbVLxqrEiwcq4r2aBSNsI305Z5sUWtn
|
||||
m+ONvA9J1yxKFzHiXjbvc2GfnoLX8gXPR4zoZOGzYg/jP5EyqSiXtUZfSodL7yeH
|
||||
8q0CgYEA2OzA59AITJe8jhC5JsVbLs7Rj4kFTjD+iZ8P86FnWBf1iDeuywEZJqvG
|
||||
n6SNK4KczDJ//DBV06w4L6iwe5iOCdf06+V7Hnkbvrjk0ONnXX7VXNgJ3/e7aJTx
|
||||
gE42Ug0qu6lXtEfYqlhQoF2lAtnYq0fty/XWMVfpjVuh1lyd4C4=
|
||||
-----END RSA PRIVATE KEY-----
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This is a mock win32process module.
|
||||
|
||||
The purpose of this module is mock process creation for the PID test.
|
||||
|
||||
CreateProcess(...) will spawn a process, and always return a PID of 42.
|
||||
"""
|
||||
|
||||
import win32process
|
||||
|
||||
GetExitCodeProcess = win32process.GetExitCodeProcess
|
||||
STARTUPINFO = win32process.STARTUPINFO
|
||||
|
||||
STARTF_USESTDHANDLES = win32process.STARTF_USESTDHANDLES
|
||||
|
||||
|
||||
def CreateProcess(
|
||||
appName,
|
||||
cmdline,
|
||||
procSecurity,
|
||||
threadSecurity,
|
||||
inheritHandles,
|
||||
newEnvironment,
|
||||
env,
|
||||
workingDir,
|
||||
startupInfo,
|
||||
):
|
||||
"""
|
||||
This function mocks the generated pid aspect of the win32.CreateProcess
|
||||
function.
|
||||
- the true win32process.CreateProcess is called
|
||||
- return values are harvested in a tuple.
|
||||
- all return values from createProcess are passed back to the calling
|
||||
function except for the pid, the returned pid is hardcoded to 42
|
||||
"""
|
||||
|
||||
hProcess, hThread, dwPid, dwTid = win32process.CreateProcess(
|
||||
appName,
|
||||
cmdline,
|
||||
procSecurity,
|
||||
threadSecurity,
|
||||
inheritHandles,
|
||||
newEnvironment,
|
||||
env,
|
||||
workingDir,
|
||||
startupInfo,
|
||||
)
|
||||
dwPid = 42
|
||||
return (hProcess, hThread, dwPid, dwTid)
|
||||
@@ -0,0 +1,13 @@
|
||||
class A:
|
||||
def a(self) -> str:
|
||||
return "a"
|
||||
|
||||
|
||||
class B(A):
|
||||
def b(self) -> str:
|
||||
return "b"
|
||||
|
||||
|
||||
class Inherit(A):
|
||||
def a(self) -> str:
|
||||
return "c"
|
||||
@@ -0,0 +1,13 @@
|
||||
class A:
|
||||
def a(self) -> str:
|
||||
return "b"
|
||||
|
||||
|
||||
class B(A):
|
||||
def b(self) -> str:
|
||||
return "c"
|
||||
|
||||
|
||||
class Inherit(A):
|
||||
def a(self) -> str:
|
||||
return "d"
|
||||
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# Don't change the docstring, it's part of the tests
|
||||
"""
|
||||
I'm a test drop-in. The plugin system's unit tests use me. No one
|
||||
else should.
|
||||
"""
|
||||
|
||||
from zope.interface import provider
|
||||
|
||||
from twisted.plugin import IPlugin
|
||||
from twisted.test.test_plugin import ITestPlugin, ITestPlugin2
|
||||
|
||||
|
||||
@provider(ITestPlugin, IPlugin)
|
||||
class TestPlugin:
|
||||
"""
|
||||
A plugin used solely for testing purposes.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def test1() -> None:
|
||||
pass
|
||||
|
||||
|
||||
@provider(ITestPlugin2, IPlugin)
|
||||
class AnotherTestPlugin:
|
||||
"""
|
||||
Another plugin used solely for testing purposes.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def test() -> None:
|
||||
pass
|
||||
|
||||
|
||||
@provider(ITestPlugin2, IPlugin)
|
||||
class ThirdTestPlugin:
|
||||
"""
|
||||
Another plugin used solely for testing purposes.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def test() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test plugin used in L{twisted.test.test_plugin}.
|
||||
"""
|
||||
|
||||
from zope.interface import provider
|
||||
|
||||
from twisted.plugin import IPlugin
|
||||
from twisted.test.test_plugin import ITestPlugin
|
||||
|
||||
|
||||
@provider(ITestPlugin, IPlugin)
|
||||
class FourthTestPlugin:
|
||||
@staticmethod
|
||||
def test1() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test plugin used in L{twisted.test.test_plugin}.
|
||||
"""
|
||||
|
||||
from zope.interface import provider
|
||||
|
||||
from twisted.plugin import IPlugin
|
||||
from twisted.test.test_plugin import ITestPlugin
|
||||
|
||||
|
||||
@provider(ITestPlugin, IPlugin)
|
||||
class FourthTestPlugin:
|
||||
@staticmethod
|
||||
def test1() -> None:
|
||||
pass
|
||||
|
||||
|
||||
@provider(ITestPlugin, IPlugin)
|
||||
class FifthTestPlugin:
|
||||
"""
|
||||
More documentation: I hate you.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def test1() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Write to stdout the command line args it received, one per line.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
for x in sys.argv[1:]:
|
||||
print(x)
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Write back all data it receives."""
|
||||
|
||||
import sys
|
||||
|
||||
data = sys.stdin.read(1)
|
||||
while data:
|
||||
sys.stdout.write(data)
|
||||
sys.stdout.flush()
|
||||
data = sys.stdin.read(1)
|
||||
sys.stderr.write("byebye")
|
||||
sys.stderr.flush()
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Write to a handful of file descriptors, to test the childFDs= argument of
|
||||
reactor.spawnProcess()
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
debug = 0
|
||||
|
||||
if debug:
|
||||
stderr = os.fdopen(2, "w")
|
||||
|
||||
if debug:
|
||||
print("this is stderr", file=stderr)
|
||||
|
||||
abcd = os.read(0, 4)
|
||||
if debug:
|
||||
print("read(0):", abcd, file=stderr)
|
||||
if abcd != b"abcd":
|
||||
sys.exit(1)
|
||||
|
||||
if debug:
|
||||
print("os.write(1, righto)", file=stderr)
|
||||
os.write(1, b"righto")
|
||||
|
||||
efgh = os.read(3, 4)
|
||||
if debug:
|
||||
print("read(3):", file=stderr)
|
||||
if efgh != b"efgh":
|
||||
sys.exit(2)
|
||||
|
||||
if debug:
|
||||
print("os.close(4)", file=stderr)
|
||||
os.close(4)
|
||||
|
||||
eof = os.read(5, 4)
|
||||
if debug:
|
||||
print("read(5):", eof, file=stderr)
|
||||
if eof != b"":
|
||||
sys.exit(3)
|
||||
|
||||
if debug:
|
||||
print("os.write(1, closed)", file=stderr)
|
||||
os.write(1, b"closed")
|
||||
|
||||
if debug:
|
||||
print("sys.exit(0)", file=stderr)
|
||||
sys.exit(0)
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Used by L{twisted.test.test_process}.
|
||||
"""
|
||||
|
||||
|
||||
from sys import argv, stdout
|
||||
|
||||
if __name__ == "__main__":
|
||||
stdout.write(chr(0).join(argv))
|
||||
stdout.flush()
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Used by L{twisted.test.test_process}.
|
||||
"""
|
||||
|
||||
from os import environ
|
||||
from sys import stdout
|
||||
|
||||
items = environ.items()
|
||||
stdout.write(chr(0).join([k + chr(0) + v for k, v in items]))
|
||||
stdout.flush()
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Write to a file descriptor and then close it, waiting a few seconds before
|
||||
quitting. This serves to make sure SIGCHLD is actually being noticed.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
print("here is some text")
|
||||
time.sleep(1)
|
||||
print("goodbye")
|
||||
os.close(1)
|
||||
os.close(2)
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
sys.exit(0)
|
||||
@@ -0,0 +1,10 @@
|
||||
"""Script used by test_process.TestTwoProcesses"""
|
||||
|
||||
# run until stdin is closed, then quit
|
||||
|
||||
import sys
|
||||
|
||||
while 1:
|
||||
d = sys.stdin.read()
|
||||
if len(d) == 0:
|
||||
sys.exit(0)
|
||||
@@ -0,0 +1,9 @@
|
||||
import signal
|
||||
import sys
|
||||
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
if getattr(signal, "SIGHUP", None) is not None:
|
||||
signal.signal(signal.SIGHUP, signal.SIG_DFL)
|
||||
print("ok, signal us")
|
||||
sys.stdin.read()
|
||||
sys.exit(1)
|
||||
@@ -0,0 +1,34 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Script used by twisted.test.test_process on win32.
|
||||
"""
|
||||
|
||||
|
||||
import msvcrt
|
||||
import os
|
||||
import sys
|
||||
|
||||
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) # type:ignore[attr-defined]
|
||||
msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY) # type:ignore[attr-defined]
|
||||
|
||||
# We want to write bytes directly to the output, not text, because otherwise
|
||||
# newlines get mangled. Get the underlying buffer.
|
||||
stdout = sys.stdout.buffer
|
||||
stderr = sys.stderr.buffer
|
||||
stdin = sys.stdin.buffer
|
||||
|
||||
stdout.write(b"out\n")
|
||||
stdout.flush()
|
||||
stderr.write(b"err\n")
|
||||
stderr.flush()
|
||||
|
||||
data = stdin.read()
|
||||
|
||||
stdout.write(data)
|
||||
stdout.write(b"\nout\n")
|
||||
stderr.write(b"err\n")
|
||||
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Test program for processes."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
test_file_match = "process_test.log.*"
|
||||
test_file = "process_test.log.%d" % os.getpid()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
f = open(test_file, "wb")
|
||||
|
||||
stdin = sys.stdin.buffer
|
||||
stderr = sys.stderr.buffer
|
||||
stdout = sys.stdout.buffer
|
||||
|
||||
# stage 1
|
||||
b = stdin.read(4)
|
||||
f.write(b"one: " + b + b"\n")
|
||||
|
||||
# stage 2
|
||||
stdout.write(b)
|
||||
stdout.flush()
|
||||
os.close(sys.stdout.fileno())
|
||||
|
||||
# and a one, and a two, and a...
|
||||
b = stdin.read(4)
|
||||
f.write(b"two: " + b + b"\n")
|
||||
|
||||
# stage 3
|
||||
stderr.write(b)
|
||||
stderr.flush()
|
||||
os.close(stderr.fileno())
|
||||
|
||||
# stage 4
|
||||
b = stdin.read(4)
|
||||
f.write(b"three: " + b + b"\n")
|
||||
|
||||
# exit with status code 23
|
||||
sys.exit(23)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Test to make sure we can open /dev/tty"""
|
||||
|
||||
with open("/dev/tty", "rb+", buffering=0) as f:
|
||||
a = f.readline()
|
||||
f.write(a)
|
||||
@@ -0,0 +1,49 @@
|
||||
"""A process that reads from stdin and out using Twisted."""
|
||||
|
||||
|
||||
# Twisted Preamble
|
||||
# This makes sure that users don't have to set up their environment
|
||||
# specially in order to run these programs from bin/.
|
||||
import os
|
||||
import sys
|
||||
|
||||
pos = os.path.abspath(sys.argv[0]).find(os.sep + "Twisted")
|
||||
if pos != -1:
|
||||
sys.path.insert(0, os.path.abspath(sys.argv[0])[: pos + 8])
|
||||
sys.path.insert(0, os.curdir)
|
||||
# end of preamble
|
||||
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import interfaces
|
||||
from twisted.python import log
|
||||
|
||||
log.startLogging(sys.stderr)
|
||||
|
||||
|
||||
from twisted.internet import protocol, reactor, stdio
|
||||
|
||||
|
||||
@implementer(interfaces.IHalfCloseableProtocol)
|
||||
class Echo(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
print("connection made")
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
def readConnectionLost(self):
|
||||
print("readConnectionLost")
|
||||
self.transport.loseConnection()
|
||||
|
||||
def writeConnectionLost(self):
|
||||
print("writeConnectionLost")
|
||||
|
||||
def connectionLost(self, reason):
|
||||
print("connectionLost", reason)
|
||||
reactor.stop()
|
||||
|
||||
|
||||
stdio.StandardIO(Echo())
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Assorted functionality which is commonly useful when writing unit tests.
|
||||
|
||||
This module has been deprecated, please use twisted.internet.testing
|
||||
instead.
|
||||
"""
|
||||
from twisted.internet import testing
|
||||
|
||||
__all__ = [
|
||||
"AccumulatingProtocol",
|
||||
"LineSendingProtocol",
|
||||
"FakeDatagramTransport",
|
||||
"StringTransport",
|
||||
"StringTransportWithDisconnection",
|
||||
"StringIOWithoutClosing",
|
||||
"_FakeConnector",
|
||||
"_FakePort",
|
||||
"MemoryReactor",
|
||||
"MemoryReactorClock",
|
||||
"RaisingMemoryReactor",
|
||||
"NonStreamingProducer",
|
||||
"waitUntilAllDisconnected",
|
||||
"EventLoggingObserver",
|
||||
]
|
||||
|
||||
|
||||
AccumulatingProtocol = testing.AccumulatingProtocol
|
||||
LineSendingProtocol = testing.LineSendingProtocol
|
||||
FakeDatagramTransport = testing.FakeDatagramTransport
|
||||
StringTransport = testing.StringTransport
|
||||
StringTransportWithDisconnection = testing.StringTransportWithDisconnection
|
||||
StringIOWithoutClosing = testing.StringIOWithoutClosing
|
||||
_FakeConnector = testing._FakeConnector
|
||||
_FakePort = testing._FakePort
|
||||
MemoryReactor = testing.MemoryReactor
|
||||
MemoryReactorClock = testing.MemoryReactorClock
|
||||
RaisingMemoryReactor = testing.RaisingMemoryReactor
|
||||
NonStreamingProducer = testing.NonStreamingProducer
|
||||
waitUntilAllDisconnected = testing.waitUntilAllDisconnected
|
||||
EventLoggingObserver = testing.EventLoggingObserver
|
||||
@@ -0,0 +1,3 @@
|
||||
# Helper for a test_reflect test
|
||||
|
||||
__import__("idonotexist")
|
||||
@@ -0,0 +1,3 @@
|
||||
# Helper for a test_reflect test
|
||||
|
||||
raise ValueError("Stuff is broken and things")
|
||||
@@ -0,0 +1,3 @@
|
||||
# Helper module for a test_reflect test
|
||||
|
||||
1 // 0
|
||||
123
.venv/lib/python3.12/site-packages/twisted/test/server.pem
Normal file
123
.venv/lib/python3.12/site-packages/twisted/test/server.pem
Normal file
@@ -0,0 +1,123 @@
|
||||
# coding: utf-8
|
||||
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from inspect import getsource
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key
|
||||
from cryptography.hazmat.primitives.hashes import SHA256
|
||||
from cryptography.hazmat.primitives.serialization import (
|
||||
Encoding,
|
||||
NoEncryption,
|
||||
PrivateFormat,
|
||||
)
|
||||
from cryptography.x509 import (
|
||||
CertificateBuilder,
|
||||
Name,
|
||||
NameAttribute,
|
||||
NameOID,
|
||||
SubjectAlternativeName,
|
||||
DNSName,
|
||||
random_serial_number,
|
||||
)
|
||||
|
||||
pk = generate_private_key(key_size=2048, public_exponent=65537)
|
||||
|
||||
me = Name(
|
||||
[
|
||||
NameAttribute(NameOID.COMMON_NAME, "A Host, Locally"),
|
||||
NameAttribute(NameOID.COUNTRY_NAME, "TR"),
|
||||
NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Çorum"),
|
||||
NameAttribute(NameOID.LOCALITY_NAME, "Başmakçı"),
|
||||
NameAttribute(NameOID.ORGANIZATION_NAME, "Twisted Matrix Labs"),
|
||||
NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Automated Testing Authority"),
|
||||
NameAttribute(NameOID.EMAIL_ADDRESS, "security@twistedmatrix.com"),
|
||||
]
|
||||
)
|
||||
|
||||
certificate_bytes = (
|
||||
CertificateBuilder()
|
||||
.serial_number(random_serial_number())
|
||||
.not_valid_before(datetime.now())
|
||||
.not_valid_after(datetime.now() + timedelta(seconds=60 * 60 * 24 * 365 * 100))
|
||||
.subject_name(me)
|
||||
.add_extension(SubjectAlternativeName([DNSName("localhost")]), critical=False)
|
||||
.issuer_name(me)
|
||||
.public_key(pk.public_key())
|
||||
.sign(pk, algorithm=SHA256())
|
||||
).public_bytes(Encoding.PEM)
|
||||
|
||||
privkey_bytes = pk.private_bytes(
|
||||
Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption()
|
||||
)
|
||||
|
||||
import __main__
|
||||
|
||||
source = getsource(__main__)
|
||||
source = source.split("\n" + "-" * 5)[0].rsplit("\n", 1)[0]
|
||||
with open("server.pem", "w") as fObj:
|
||||
fObj.write(source)
|
||||
fObj.write("\n")
|
||||
fObj.write('"""\n')
|
||||
fObj.write(privkey_bytes.decode("ascii"))
|
||||
fObj.write(certificate_bytes.decode("ascii"))
|
||||
fObj.write('"""\n')
|
||||
with open(b"key.pem.no_trailing_newline", "w") as fObj:
|
||||
fObj.write(privkey_bytes.decode("ascii").rstrip("\n"))
|
||||
with open(b"cert.pem.no_trailing_newline", "w") as fObj:
|
||||
fObj.write(certificate_bytes.decode("ascii").rstrip("\n"))
|
||||
|
||||
"""
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEogIBAAKCAQEA0rT5+hF+1BjE7qXms9PZWHskXZGXLPiYVmiYjsVeJAOtHAYq
|
||||
8igzA49KgR1xR9M4jQ6U46nwPsnGCh4liyxdWkBLw9maxMoE+r6dW1zZ8Tllunbd
|
||||
b/Da6L8P55SKb7QGet4CB1fZ2SqZD4GvTby6xpoR09AqrfjuEIYR8V/y+8dG3mR5
|
||||
W0HqaJ58IWihAwIQSakuc8jTadJY55t7UW6Ebj2X2WTO6Zh7gJ1dyHPMVkUHJF9J
|
||||
suj/4F4lx6hWGQzWO8Nf8Q7t364pagE3evUv/BECJLONNYLaFjLtWnsCEJDV9owC
|
||||
jaxu785KuA7OM/f3h3xVIfTBTo2AlHiQnXdyrwIDAQABAoH/Ib7aSjKDHXTaFV58
|
||||
lFBZftI6AMJQc+Ncgno99J+ndB0inFpghmfpw6gvRn5wphAt/mlXbx7IW0X1cali
|
||||
WefBC7NAbx1qrBmusnnUuc0lGn0WzcY7sLHiXWQ8J9qiUUGDyCnGKWbofN9VpCYg
|
||||
7VJMl4IVWNb9/t7fQcY3GXFEeQ4mzLo7p+gPxyeUcCLVrhVrHzw1HFTIlA51LjfI
|
||||
xQM+QVeaEWQQ4UsDdPe5iGthDd7ze2F5ciDzMkShrf7URSudS+Us6vr6gDVpKAky
|
||||
eCVyFPJXCfH4qJoa6mB6L6SFzMnN3OPp3RlYQWQ7sK/ELQfhPoyHyRvL1woUIO5C
|
||||
tK0pAoGBAPS6ZSZ26M0guZ2K/2fKMiGq0jZQLcxP3N0jWm8R8ENOnuIjhCl5aKsB
|
||||
DoV0BvPv1C2vWm+VgNArgTece9l8o5f8pcfjbT5r/k8zoqgcj9CmmDofBka4XxZb
|
||||
wxsut+8rBSIoVKIre4Pyqfa9u1IrEnoOzMqvF16xUME2t2EaryUzAoGBANxpb4Jz
|
||||
FjH7nfPc3iejd+cXovX6x2VTJzWaknA6hGsoc+UZ01KTaKyYpq+9q9VxXhWxYsh3
|
||||
TL1JWuIBy6ao5tdt4nPBu07J7tfu5bfr3Imd8waNQxDEfKeFedskxORs+FIUzqBb
|
||||
3nIkQH8sx0Syv620coIdtEn1raVXc9QfRgSVAoGAWNFhLoGPYgsTcnrk0N1QLmnZ
|
||||
mv6kcHc3mEZhZtgi07qv7TCooYi/lPhwNbzzXQrYfbAbaU3gDy0K24z+YeNbWCjI
|
||||
XfBLUJFPHZ2G1e5vv3EG5GkoFPiLAglRmQbumG2LkmcCuEyBqlSinLslRd/997Bx
|
||||
YMoE+EfwH/9ktGhD0oMCgYEAxaSqAFDQ00ssjTM95k94Qjn4wBf7WwmgfDm6HHbs
|
||||
rOZeXk61JzPVxgcwWSB8iG4bDtq8mMQZhRbVLxqrEiwcq4r2aBSNsI305Z5sUWtn
|
||||
m+ONvA9J1yxKFzHiXjbvc2GfnoLX8gXPR4zoZOGzYg/jP5EyqSiXtUZfSodL7yeH
|
||||
8q0CgYEA2OzA59AITJe8jhC5JsVbLs7Rj4kFTjD+iZ8P86FnWBf1iDeuywEZJqvG
|
||||
n6SNK4KczDJ//DBV06w4L6iwe5iOCdf06+V7Hnkbvrjk0ONnXX7VXNgJ3/e7aJTx
|
||||
gE42Ug0qu6lXtEfYqlhQoF2lAtnYq0fty/XWMVfpjVuh1lyd4C4=
|
||||
-----END RSA PRIVATE KEY-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEJDCCAwygAwIBAgIUKaSXgzt5gDMt9GbUzLz/A9HEyFEwDQYJKoZIhvcNAQEL
|
||||
BQAwgb0xGDAWBgNVBAMMD0EgSG9zdCwgTG9jYWxseTELMAkGA1UEBhMCVFIxDzAN
|
||||
BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExHDAaBgNVBAoME1R3
|
||||
aXN0ZWQgTWF0cml4IExhYnMxJDAiBgNVBAsMG0F1dG9tYXRlZCBUZXN0aW5nIEF1
|
||||
dGhvcml0eTEpMCcGCSqGSIb3DQEJARYac2VjdXJpdHlAdHdpc3RlZG1hdHJpeC5j
|
||||
b20wIBcNMjMwNjE0MTM0MDI4WhgPMjEyMzA1MjExMzQwMjhaMIG9MRgwFgYDVQQD
|
||||
DA9BIEhvc3QsIExvY2FsbHkxCzAJBgNVBAYTAlRSMQ8wDQYDVQQIDAbDh29ydW0x
|
||||
FDASBgNVBAcMC0JhxZ9tYWvDp8SxMRwwGgYDVQQKDBNUd2lzdGVkIE1hdHJpeCBM
|
||||
YWJzMSQwIgYDVQQLDBtBdXRvbWF0ZWQgVGVzdGluZyBBdXRob3JpdHkxKTAnBgkq
|
||||
hkiG9w0BCQEWGnNlY3VyaXR5QHR3aXN0ZWRtYXRyaXguY29tMIIBIjANBgkqhkiG
|
||||
9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0rT5+hF+1BjE7qXms9PZWHskXZGXLPiYVmiY
|
||||
jsVeJAOtHAYq8igzA49KgR1xR9M4jQ6U46nwPsnGCh4liyxdWkBLw9maxMoE+r6d
|
||||
W1zZ8Tllunbdb/Da6L8P55SKb7QGet4CB1fZ2SqZD4GvTby6xpoR09AqrfjuEIYR
|
||||
8V/y+8dG3mR5W0HqaJ58IWihAwIQSakuc8jTadJY55t7UW6Ebj2X2WTO6Zh7gJ1d
|
||||
yHPMVkUHJF9Jsuj/4F4lx6hWGQzWO8Nf8Q7t364pagE3evUv/BECJLONNYLaFjLt
|
||||
WnsCEJDV9owCjaxu785KuA7OM/f3h3xVIfTBTo2AlHiQnXdyrwIDAQABoxgwFjAU
|
||||
BgNVHREEDTALgglsb2NhbGhvc3QwDQYJKoZIhvcNAQELBQADggEBAEHAErq/Fs8h
|
||||
M+kwGCt5Ochqyu/IzPbwgQ27n5IJehl7kmpoXBxGa/u+ajoxrZaOheg8E2MYVwQi
|
||||
FTKE9wJgaN3uGo4bzCbCYxDm7tflQORo6QOZlumfiQIzXON2RvgJpwFfkLNtq0t9
|
||||
e453kJ7+e11Wah46bc3RAvBZpwswh6hDv2FvFUZ+IUcO0tU8O4kWrLIFPpJbcHQq
|
||||
wezjky773X4CNEtoeuTb8/ws/eED/TGZ2AZO+BWT93OZJgwE2x3iUd3k8HbwxfoY
|
||||
bZ+NHgtM7iKRcL59asB0OMi3Ays0+IOfZ1+3aB82zYlxFBoDyalR7NJjJGdTwNFt
|
||||
3CPGCQ28cDk=
|
||||
-----END CERTIFICATE-----
|
||||
"""
|
||||
@@ -0,0 +1,66 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Helper classes for twisted.test.test_ssl.
|
||||
|
||||
They are in a separate module so they will not prevent test_ssl importing if
|
||||
pyOpenSSL is unavailable.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
from twisted.internet import ssl
|
||||
from twisted.python.compat import nativeString
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
certPath = nativeString(FilePath(__file__.encode("utf-8")).sibling(b"server.pem").path)
|
||||
|
||||
|
||||
class ClientTLSContext(ssl.ClientContextFactory):
|
||||
"""
|
||||
SSL Context Factory for client-side connections.
|
||||
"""
|
||||
|
||||
isClient = 1
|
||||
|
||||
def getContext(self) -> SSL.Context:
|
||||
"""
|
||||
Return an L{SSL.Context} to be use for client-side connections.
|
||||
|
||||
Will not return a cached context.
|
||||
This is done to improve the test coverage as most implementation
|
||||
are caching the context.
|
||||
"""
|
||||
return SSL.Context(SSL.SSLv23_METHOD)
|
||||
|
||||
|
||||
class ServerTLSContext:
|
||||
"""
|
||||
SSL Context Factory for server-side connections.
|
||||
"""
|
||||
|
||||
isClient = 0
|
||||
|
||||
def __init__(
|
||||
self, filename: str | bytes = certPath, method: int | None = None
|
||||
) -> None:
|
||||
self.filename = filename
|
||||
if method is None:
|
||||
method = SSL.SSLv23_METHOD
|
||||
|
||||
self._method = method
|
||||
|
||||
def getContext(self) -> SSL.Context:
|
||||
"""
|
||||
Return an L{SSL.Context} to be use for server-side connections.
|
||||
|
||||
Will not return a cached context.
|
||||
This is done to improve the test coverage as most implementation
|
||||
are caching the context.
|
||||
"""
|
||||
ctx = SSL.Context(self._method)
|
||||
ctx.use_certificate_file(self.filename)
|
||||
ctx.use_privatekey_file(self.filename)
|
||||
return ctx
|
||||
@@ -0,0 +1,44 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_consumer -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_consumer} to test
|
||||
that process transports implement IConsumer properly.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.protocols import basic
|
||||
from twisted.python import log, reflect
|
||||
|
||||
|
||||
def failed(err):
|
||||
log.startLogging(sys.stderr)
|
||||
log.err(err)
|
||||
|
||||
|
||||
class ConsumerChild(protocol.Protocol):
|
||||
def __init__(self, junkPath):
|
||||
self.junkPath = junkPath
|
||||
|
||||
def connectionMade(self):
|
||||
d = basic.FileSender().beginFileTransfer(
|
||||
open(self.junkPath, "rb"), self.transport
|
||||
)
|
||||
d.addErrback(failed)
|
||||
d.addCallback(lambda ign: self.transport.loseConnection())
|
||||
|
||||
def connectionLost(self, reason):
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
stdio.StandardIO(ConsumerChild(sys.argv[2]))
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
@@ -0,0 +1,68 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_readConnectionLost -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_readConnectionLost} to
|
||||
test that IHalfCloseableProtocol.readConnectionLost works for stdio transports.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.internet.interfaces import IHalfCloseableProtocol
|
||||
from twisted.python import log, reflect
|
||||
|
||||
|
||||
@implementer(IHalfCloseableProtocol)
|
||||
class HalfCloseProtocol(protocol.Protocol):
|
||||
"""
|
||||
A protocol to hook up to stdio and observe its transport being
|
||||
half-closed. If all goes as expected, C{exitCode} will be set to C{0};
|
||||
otherwise it will be set to C{1} to indicate failure.
|
||||
"""
|
||||
|
||||
exitCode = None
|
||||
|
||||
def connectionMade(self):
|
||||
"""
|
||||
Signal the parent process that we're ready.
|
||||
"""
|
||||
self.transport.write(b"x")
|
||||
|
||||
def readConnectionLost(self):
|
||||
"""
|
||||
This is the desired event. Once it has happened, stop the reactor so
|
||||
the process will exit.
|
||||
"""
|
||||
self.exitCode = 0
|
||||
reactor.stop()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
This may only be invoked after C{readConnectionLost}. If it happens
|
||||
otherwise, mark it as an error and shut down.
|
||||
"""
|
||||
if self.exitCode is None:
|
||||
self.exitCode = 1
|
||||
log.err(reason, "Unexpected call to connectionLost")
|
||||
reactor.stop()
|
||||
|
||||
def writeConnectionLost(self):
|
||||
# IHalfCloseableProtocol.writeConnectionLost
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
log.startLogging(open(sys.argv[2], "wb"))
|
||||
from twisted.internet import reactor
|
||||
|
||||
halfCloseProtocol = HalfCloseProtocol()
|
||||
stdio.StandardIO(halfCloseProtocol)
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
sys.exit(halfCloseProtocol.exitCode)
|
||||
@@ -0,0 +1,67 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_buggyReadConnectionLost -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_readConnectionLost} to
|
||||
test that IHalfCloseableProtocol.readConnectionLost works for stdio transports.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.internet.interfaces import IHalfCloseableProtocol, IReactorCore, ITransport
|
||||
from twisted.python import log, reflect
|
||||
|
||||
|
||||
@implementer(IHalfCloseableProtocol)
|
||||
class HalfCloseProtocol(protocol.Protocol):
|
||||
"""
|
||||
A protocol to hook up to stdio and observe its transport being
|
||||
half-closed. If all goes as expected, C{exitCode} will be set to C{0};
|
||||
otherwise it will be set to C{1} to indicate failure.
|
||||
"""
|
||||
|
||||
exitCode = None
|
||||
transport: ITransport
|
||||
|
||||
def connectionMade(self) -> None:
|
||||
"""
|
||||
Signal the parent process that we're ready.
|
||||
"""
|
||||
self.transport.write(b"x")
|
||||
|
||||
def readConnectionLost(self) -> None:
|
||||
"""
|
||||
This is the desired event. Once it has happened, stop the reactor so
|
||||
the process will exit.
|
||||
"""
|
||||
raise ValueError("something went wrong")
|
||||
|
||||
def connectionLost(self, reason: object = None) -> None:
|
||||
"""
|
||||
This may only be invoked after C{readConnectionLost}. If it happens
|
||||
otherwise, mark it as an error and shut down.
|
||||
"""
|
||||
self.exitCode = 0
|
||||
reactor.stop()
|
||||
|
||||
def writeConnectionLost(self) -> None:
|
||||
# IHalfCloseableProtocol.writeConnectionLost
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no branch
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
log.startLogging(open(sys.argv[2], "w"))
|
||||
reactor: IReactorCore
|
||||
from twisted.internet import reactor # type:ignore[assignment]
|
||||
|
||||
halfCloseProtocol = HalfCloseProtocol()
|
||||
stdio.StandardIO(halfCloseProtocol)
|
||||
reactor.run()
|
||||
sys.exit(halfCloseProtocol.exitCode)
|
||||
@@ -0,0 +1,69 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_buggyWriteConnectionLost -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_buggyWriteConnectionLost}
|
||||
to test that IHalfCloseableProtocol.writeConnectionLost works for stdio
|
||||
transports.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.internet.interfaces import IHalfCloseableProtocol, IReactorCore, ITransport
|
||||
from twisted.python import log, reflect
|
||||
|
||||
|
||||
@implementer(IHalfCloseableProtocol)
|
||||
class HalfCloseProtocol(protocol.Protocol):
|
||||
"""
|
||||
A protocol to hook up to stdio and observe its transport being
|
||||
half-closed. If all goes as expected, C{exitCode} will be set to C{0};
|
||||
otherwise it will be set to C{1} to indicate failure.
|
||||
"""
|
||||
|
||||
exitCode = 9
|
||||
wasWriteConnectionLost = False
|
||||
transport: ITransport
|
||||
|
||||
def connectionMade(self) -> None:
|
||||
"""
|
||||
Signal the parent process that we're ready.
|
||||
"""
|
||||
self.transport.write(b"x")
|
||||
|
||||
def readConnectionLost(self) -> None:
|
||||
"""
|
||||
This is the desired event. Once it has happened, stop the reactor so
|
||||
the process will exit.
|
||||
"""
|
||||
|
||||
def connectionLost(self, reason: object = None) -> None:
|
||||
"""
|
||||
This may only be invoked after C{readConnectionLost}. If it happens
|
||||
otherwise, mark it as an error and shut down.
|
||||
"""
|
||||
if self.wasWriteConnectionLost: # pragma: no branch
|
||||
self.exitCode = 0
|
||||
reactor.stop()
|
||||
|
||||
def writeConnectionLost(self) -> None:
|
||||
self.wasWriteConnectionLost = True
|
||||
raise ValueError("something went wrong")
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no branch
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
log.startLogging(open(sys.argv[2], "w"))
|
||||
reactor: IReactorCore
|
||||
from twisted.internet import reactor # type:ignore[assignment]
|
||||
|
||||
halfCloseProtocol = HalfCloseProtocol()
|
||||
stdio.StandardIO(halfCloseProtocol)
|
||||
reactor.run()
|
||||
sys.exit(halfCloseProtocol.exitCode)
|
||||
@@ -0,0 +1,39 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_hostAndPeer -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_hostAndPeer} to test
|
||||
that ITransport.getHost() and ITransport.getPeer() work for process transports.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.python import reflect
|
||||
|
||||
|
||||
class HostPeerChild(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
self.transport.write(
|
||||
b"\n".join(
|
||||
[
|
||||
str(self.transport.getHost()).encode("ascii"),
|
||||
str(self.transport.getPeer()).encode("ascii"),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
stdio.StandardIO(HostPeerChild())
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
@@ -0,0 +1,43 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_lastWriteReceived -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_lastWriteReceived}
|
||||
to test that L{os.write} can be reliably used after
|
||||
L{twisted.internet.stdio.StandardIO} has finished.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.stdio import StandardIO
|
||||
from twisted.python.reflect import namedAny
|
||||
|
||||
|
||||
class LastWriteChild(Protocol):
|
||||
def __init__(self, reactor, magicString):
|
||||
self.reactor = reactor
|
||||
self.magicString = magicString
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(self.magicString)
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.reactor.stop()
|
||||
|
||||
|
||||
def main(reactor, magicString):
|
||||
p = LastWriteChild(reactor, magicString.encode("ascii"))
|
||||
StandardIO(p)
|
||||
reactor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
main(reactor, sys.argv[2])
|
||||
@@ -0,0 +1,50 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_loseConnection -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_loseConnection} to
|
||||
test that ITransport.loseConnection() works for process transports.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from twisted.python import log, reflect
|
||||
|
||||
|
||||
class LoseConnChild(protocol.Protocol):
|
||||
exitCode = 0
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Check that C{reason} is a L{Failure} wrapping a L{ConnectionDone}
|
||||
instance and stop the reactor. If C{reason} is wrong for some reason,
|
||||
log something about that in C{self.errorLogFile} and make sure the
|
||||
process exits with a non-zero status.
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
reason.trap(ConnectionDone)
|
||||
except BaseException:
|
||||
log.err(None, "Problem with reason passed to connectionLost")
|
||||
self.exitCode = 1
|
||||
finally:
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
log.startLogging(open(sys.argv[2], "wb"))
|
||||
from twisted.internet import reactor
|
||||
|
||||
protocolLoseConnChild = LoseConnChild()
|
||||
stdio.StandardIO(protocolLoseConnChild)
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
sys.exit(protocolLoseConnChild.exitCode)
|
||||
@@ -0,0 +1,54 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_producer -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_producer} to test
|
||||
that process transports implement IProducer properly.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.python import log, reflect
|
||||
|
||||
|
||||
class ProducerChild(protocol.Protocol):
|
||||
_paused = False
|
||||
buf = b""
|
||||
|
||||
def connectionLost(self, reason):
|
||||
log.msg("*****OVER*****")
|
||||
reactor.callLater(1, reactor.stop)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buf += data
|
||||
if self._paused:
|
||||
log.startLogging(sys.stderr)
|
||||
log.msg("dataReceived while transport paused!")
|
||||
self.transport.loseConnection()
|
||||
else:
|
||||
self.transport.write(data)
|
||||
if self.buf.endswith(b"\n0\n"):
|
||||
self.transport.loseConnection()
|
||||
else:
|
||||
self.pause()
|
||||
|
||||
def pause(self):
|
||||
self._paused = True
|
||||
self.transport.pauseProducing()
|
||||
reactor.callLater(0.01, self.unpause)
|
||||
|
||||
def unpause(self):
|
||||
self._paused = False
|
||||
self.transport.resumeProducing()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
stdio.StandardIO(ProducerChild())
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
@@ -0,0 +1,34 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_write -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_write} to test that
|
||||
ITransport.write() works for process transports.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.python import reflect
|
||||
|
||||
|
||||
class WriteChild(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
self.transport.write(b"o")
|
||||
self.transport.write(b"k")
|
||||
self.transport.write(b"!")
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
stdio.StandardIO(WriteChild())
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
@@ -0,0 +1,32 @@
|
||||
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_writeSequence -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Main program for the child process run by
|
||||
L{twisted.test.test_stdio.StandardInputOutputTests.test_writeSequence} to test
|
||||
that ITransport.writeSequence() works for process transports.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
from twisted.internet import protocol, stdio
|
||||
from twisted.python import reflect
|
||||
|
||||
|
||||
class WriteSequenceChild(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
self.transport.writeSequence([b"o", b"k", b"!"])
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
reactor.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
reflect.namedAny(sys.argv[1]).install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
stdio.StandardIO(WriteSequenceChild())
|
||||
reactor.run() # type: ignore[attr-defined]
|
||||
106
.venv/lib/python3.12/site-packages/twisted/test/test_abstract.py
Normal file
106
.venv/lib/python3.12/site-packages/twisted/test/test_abstract.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for generic file descriptor based reactor support code.
|
||||
"""
|
||||
|
||||
|
||||
from socket import AF_IPX
|
||||
|
||||
from twisted.internet.abstract import isIPAddress
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class AddressTests(TestCase):
|
||||
"""
|
||||
Tests for address-related functionality.
|
||||
"""
|
||||
|
||||
def test_decimalDotted(self) -> None:
|
||||
"""
|
||||
L{isIPAddress} should return C{True} for any decimal dotted
|
||||
representation of an IPv4 address.
|
||||
"""
|
||||
self.assertTrue(isIPAddress("0.1.2.3"))
|
||||
self.assertTrue(isIPAddress("252.253.254.255"))
|
||||
|
||||
def test_shortDecimalDotted(self) -> None:
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for a dotted decimal
|
||||
representation with fewer or more than four octets.
|
||||
"""
|
||||
self.assertFalse(isIPAddress("0"))
|
||||
self.assertFalse(isIPAddress("0.1"))
|
||||
self.assertFalse(isIPAddress("0.1.2"))
|
||||
self.assertFalse(isIPAddress("0.1.2.3.4"))
|
||||
|
||||
def test_invalidLetters(self) -> None:
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for any non-decimal dotted
|
||||
representation including letters.
|
||||
"""
|
||||
self.assertFalse(isIPAddress("a.2.3.4"))
|
||||
self.assertFalse(isIPAddress("1.b.3.4"))
|
||||
|
||||
def test_invalidPunctuation(self) -> None:
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for a string containing
|
||||
strange punctuation.
|
||||
"""
|
||||
self.assertFalse(isIPAddress(","))
|
||||
self.assertFalse(isIPAddress("1,2"))
|
||||
self.assertFalse(isIPAddress("1,2,3"))
|
||||
self.assertFalse(isIPAddress("1.,.3,4"))
|
||||
|
||||
def test_emptyString(self) -> None:
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for the empty string.
|
||||
"""
|
||||
self.assertFalse(isIPAddress(""))
|
||||
|
||||
def test_invalidNegative(self) -> None:
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for negative decimal values.
|
||||
"""
|
||||
self.assertFalse(isIPAddress("-1"))
|
||||
self.assertFalse(isIPAddress("1.-2"))
|
||||
self.assertFalse(isIPAddress("1.2.-3"))
|
||||
self.assertFalse(isIPAddress("1.2.-3.4"))
|
||||
|
||||
def test_invalidPositive(self) -> None:
|
||||
"""
|
||||
L{isIPAddress} should return C{False} for a string containing
|
||||
positive decimal values greater than 255.
|
||||
"""
|
||||
self.assertFalse(isIPAddress("256.0.0.0"))
|
||||
self.assertFalse(isIPAddress("0.256.0.0"))
|
||||
self.assertFalse(isIPAddress("0.0.256.0"))
|
||||
self.assertFalse(isIPAddress("0.0.0.256"))
|
||||
self.assertFalse(isIPAddress("256.256.256.256"))
|
||||
|
||||
def test_unicodeAndBytes(self) -> None:
|
||||
"""
|
||||
L{isIPAddress} evaluates ASCII-encoded bytes as well as text.
|
||||
"""
|
||||
# we test passing bytes but don't support bytes in the type annotation
|
||||
self.assertFalse(isIPAddress(b"256.0.0.0")) # type: ignore[arg-type]
|
||||
self.assertFalse(isIPAddress("256.0.0.0"))
|
||||
self.assertTrue(isIPAddress(b"252.253.254.255")) # type: ignore[arg-type]
|
||||
self.assertTrue(isIPAddress("252.253.254.255"))
|
||||
|
||||
def test_nonIPAddressFamily(self) -> None:
|
||||
"""
|
||||
All address families other than C{AF_INET} and C{AF_INET6} result in a
|
||||
L{ValueError} being raised.
|
||||
"""
|
||||
self.assertRaises(ValueError, isIPAddress, b"anything", AF_IPX)
|
||||
|
||||
def test_nonASCII(self) -> None:
|
||||
"""
|
||||
All IP addresses must be encodable as ASCII; non-ASCII should result in
|
||||
a L{False} result.
|
||||
"""
|
||||
# we test passing bytes but don't support bytes in the type annotation
|
||||
self.assertFalse(isIPAddress(b"\xff.notascii")) # type: ignore[arg-type]
|
||||
self.assertFalse(isIPAddress("\u4321.notascii"))
|
||||
869
.venv/lib/python3.12/site-packages/twisted/test/test_adbapi.py
Normal file
869
.venv/lib/python3.12/site-packages/twisted/test/test_adbapi.py
Normal file
@@ -0,0 +1,869 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for twisted.enterprise.adbapi.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from typing import Dict, Optional
|
||||
|
||||
from twisted.enterprise.adbapi import (
|
||||
Connection,
|
||||
ConnectionLost,
|
||||
ConnectionPool,
|
||||
Transaction,
|
||||
)
|
||||
from twisted.internet import defer, interfaces, reactor
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.python.reflect import requireModule
|
||||
from twisted.trial import unittest
|
||||
|
||||
simple_table_schema = """
|
||||
CREATE TABLE simple (
|
||||
x integer
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class ADBAPITestBase:
|
||||
"""
|
||||
Test the asynchronous DB-API code.
|
||||
"""
|
||||
|
||||
openfun_called: Dict[object, bool] = {}
|
||||
|
||||
if interfaces.IReactorThreads(reactor, None) is None:
|
||||
skip = "ADB-API requires threads, no way to test without them"
|
||||
|
||||
def extraSetUp(self):
|
||||
"""
|
||||
Set up the database and create a connection pool pointing at it.
|
||||
"""
|
||||
self.startDB()
|
||||
self.dbpool = self.makePool(cp_openfun=self.openfun)
|
||||
self.dbpool.start()
|
||||
|
||||
def tearDown(self):
|
||||
d = self.dbpool.runOperation("DROP TABLE simple")
|
||||
d.addCallback(lambda res: self.dbpool.close())
|
||||
d.addCallback(lambda res: self.stopDB())
|
||||
return d
|
||||
|
||||
def openfun(self, conn):
|
||||
self.openfun_called[conn] = True
|
||||
|
||||
def checkOpenfunCalled(self, conn=None):
|
||||
if not conn:
|
||||
self.assertTrue(self.openfun_called)
|
||||
else:
|
||||
self.assertIn(conn, self.openfun_called)
|
||||
|
||||
def test_pool(self):
|
||||
d = self.dbpool.runOperation(simple_table_schema)
|
||||
if self.test_failures:
|
||||
d.addCallback(self._testPool_1_1)
|
||||
d.addCallback(self._testPool_1_2)
|
||||
d.addCallback(self._testPool_1_3)
|
||||
d.addCallback(self._testPool_1_4)
|
||||
d.addCallback(lambda res: self.flushLoggedErrors())
|
||||
d.addCallback(self._testPool_2)
|
||||
d.addCallback(self._testPool_3)
|
||||
d.addCallback(self._testPool_4)
|
||||
d.addCallback(self._testPool_5)
|
||||
d.addCallback(self._testPool_6)
|
||||
d.addCallback(self._testPool_7)
|
||||
d.addCallback(self._testPool_8)
|
||||
d.addCallback(self._testPool_9)
|
||||
return d
|
||||
|
||||
def _testPool_1_1(self, res):
|
||||
d = defer.maybeDeferred(self.dbpool.runQuery, "select * from NOTABLE")
|
||||
d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_1_2(self, res):
|
||||
d = defer.maybeDeferred(self.dbpool.runOperation, "deletexxx from NOTABLE")
|
||||
d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_1_3(self, res):
|
||||
d = defer.maybeDeferred(self.dbpool.runInteraction, self.bad_interaction)
|
||||
d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_1_4(self, res):
|
||||
d = defer.maybeDeferred(self.dbpool.runWithConnection, self.bad_withConnection)
|
||||
d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_2(self, res):
|
||||
# verify simple table is empty
|
||||
sql = "select count(1) from simple"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
|
||||
def _check(row):
|
||||
self.assertTrue(int(row[0][0]) == 0, "Interaction not rolled back")
|
||||
self.checkOpenfunCalled()
|
||||
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def _testPool_3(self, res):
|
||||
sql = "select count(1) from simple"
|
||||
inserts = []
|
||||
# add some rows to simple table (runOperation)
|
||||
for i in range(self.num_iterations):
|
||||
sql = "insert into simple(x) values(%d)" % i
|
||||
inserts.append(self.dbpool.runOperation(sql))
|
||||
d = defer.gatherResults(inserts)
|
||||
|
||||
def _select(res):
|
||||
# make sure they were added (runQuery)
|
||||
sql = "select x from simple order by x"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
return d
|
||||
|
||||
d.addCallback(_select)
|
||||
|
||||
def _check(rows):
|
||||
self.assertTrue(len(rows) == self.num_iterations, "Wrong number of rows")
|
||||
for i in range(self.num_iterations):
|
||||
self.assertTrue(len(rows[i]) == 1, "Wrong size row")
|
||||
self.assertTrue(rows[i][0] == i, "Values not returned.")
|
||||
|
||||
d.addCallback(_check)
|
||||
|
||||
return d
|
||||
|
||||
def _testPool_4(self, res):
|
||||
# runInteraction
|
||||
d = self.dbpool.runInteraction(self.interaction)
|
||||
d.addCallback(lambda res: self.assertEqual(res, "done"))
|
||||
return d
|
||||
|
||||
def _testPool_5(self, res):
|
||||
# withConnection
|
||||
d = self.dbpool.runWithConnection(self.withConnection)
|
||||
d.addCallback(lambda res: self.assertEqual(res, "done"))
|
||||
return d
|
||||
|
||||
def _testPool_6(self, res):
|
||||
# Test a withConnection cannot be closed
|
||||
d = self.dbpool.runWithConnection(self.close_withConnection)
|
||||
return d
|
||||
|
||||
def _testPool_7(self, res):
|
||||
# give the pool a workout
|
||||
ds = []
|
||||
for i in range(self.num_iterations):
|
||||
sql = "select x from simple where x = %d" % i
|
||||
ds.append(self.dbpool.runQuery(sql))
|
||||
dlist = defer.DeferredList(ds, fireOnOneErrback=True)
|
||||
|
||||
def _check(result):
|
||||
for i in range(self.num_iterations):
|
||||
self.assertTrue(result[i][1][0][0] == i, "Value not returned")
|
||||
|
||||
dlist.addCallback(_check)
|
||||
return dlist
|
||||
|
||||
def _testPool_8(self, res):
|
||||
# now delete everything
|
||||
ds = []
|
||||
for i in range(self.num_iterations):
|
||||
sql = "delete from simple where x = %d" % i
|
||||
ds.append(self.dbpool.runOperation(sql))
|
||||
dlist = defer.DeferredList(ds, fireOnOneErrback=True)
|
||||
return dlist
|
||||
|
||||
def _testPool_9(self, res):
|
||||
# verify simple table is empty
|
||||
sql = "select count(1) from simple"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
|
||||
def _check(row):
|
||||
self.assertTrue(
|
||||
int(row[0][0]) == 0, "Didn't successfully delete table contents"
|
||||
)
|
||||
self.checkConnect()
|
||||
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def checkConnect(self):
|
||||
"""Check the connect/disconnect synchronous calls."""
|
||||
conn = self.dbpool.connect()
|
||||
self.checkOpenfunCalled(conn)
|
||||
curs = conn.cursor()
|
||||
curs.execute("insert into simple(x) values(1)")
|
||||
curs.execute("select x from simple")
|
||||
res = curs.fetchall()
|
||||
self.assertEqual(len(res), 1)
|
||||
self.assertEqual(len(res[0]), 1)
|
||||
self.assertEqual(res[0][0], 1)
|
||||
curs.execute("delete from simple")
|
||||
curs.execute("select x from simple")
|
||||
self.assertEqual(len(curs.fetchall()), 0)
|
||||
curs.close()
|
||||
self.dbpool.disconnect(conn)
|
||||
|
||||
def interaction(self, transaction):
|
||||
transaction.execute("select x from simple order by x")
|
||||
for i in range(self.num_iterations):
|
||||
row = transaction.fetchone()
|
||||
self.assertTrue(len(row) == 1, "Wrong size row")
|
||||
self.assertTrue(row[0] == i, "Value not returned.")
|
||||
self.assertIsNone(transaction.fetchone(), "Too many rows")
|
||||
return "done"
|
||||
|
||||
def bad_interaction(self, transaction):
|
||||
if self.can_rollback:
|
||||
transaction.execute("insert into simple(x) values(0)")
|
||||
|
||||
transaction.execute("select * from NOTABLE")
|
||||
|
||||
def withConnection(self, conn):
|
||||
curs = conn.cursor()
|
||||
try:
|
||||
curs.execute("select x from simple order by x")
|
||||
for i in range(self.num_iterations):
|
||||
row = curs.fetchone()
|
||||
self.assertTrue(len(row) == 1, "Wrong size row")
|
||||
self.assertTrue(row[0] == i, "Value not returned.")
|
||||
finally:
|
||||
curs.close()
|
||||
return "done"
|
||||
|
||||
def close_withConnection(self, conn):
|
||||
conn.close()
|
||||
|
||||
def bad_withConnection(self, conn):
|
||||
curs = conn.cursor()
|
||||
try:
|
||||
curs.execute("select * from NOTABLE")
|
||||
finally:
|
||||
curs.close()
|
||||
|
||||
|
||||
class ReconnectTestBase:
|
||||
"""
|
||||
Test the asynchronous DB-API code with reconnect.
|
||||
"""
|
||||
|
||||
if interfaces.IReactorThreads(reactor, None) is None:
|
||||
skip = "ADB-API requires threads, no way to test without them"
|
||||
|
||||
def extraSetUp(self):
|
||||
"""
|
||||
Skip the test if C{good_sql} is unavailable. Otherwise, set up the
|
||||
database, create a connection pool pointed at it, and set up a simple
|
||||
schema in it.
|
||||
"""
|
||||
if self.good_sql is None:
|
||||
raise unittest.SkipTest("no good sql for reconnect test")
|
||||
self.startDB()
|
||||
self.dbpool = self.makePool(
|
||||
cp_max=1, cp_reconnect=True, cp_good_sql=self.good_sql
|
||||
)
|
||||
self.dbpool.start()
|
||||
return self.dbpool.runOperation(simple_table_schema)
|
||||
|
||||
def tearDown(self):
|
||||
d = self.dbpool.runOperation("DROP TABLE simple")
|
||||
d.addCallback(lambda res: self.dbpool.close())
|
||||
d.addCallback(lambda res: self.stopDB())
|
||||
return d
|
||||
|
||||
def test_pool(self):
|
||||
d = defer.succeed(None)
|
||||
d.addCallback(self._testPool_1)
|
||||
d.addCallback(self._testPool_2)
|
||||
if not self.early_reconnect:
|
||||
d.addCallback(self._testPool_3)
|
||||
d.addCallback(self._testPool_4)
|
||||
d.addCallback(self._testPool_5)
|
||||
return d
|
||||
|
||||
def _testPool_1(self, res):
|
||||
sql = "select count(1) from simple"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
|
||||
def _check(row):
|
||||
self.assertTrue(int(row[0][0]) == 0, "Table not empty")
|
||||
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def _testPool_2(self, res):
|
||||
# reach in and close the connection manually
|
||||
list(self.dbpool.connections.values())[0].close()
|
||||
|
||||
def _testPool_3(self, res):
|
||||
sql = "select count(1) from simple"
|
||||
d = defer.maybeDeferred(self.dbpool.runQuery, sql)
|
||||
d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
|
||||
return d
|
||||
|
||||
def _testPool_4(self, res):
|
||||
sql = "select count(1) from simple"
|
||||
d = self.dbpool.runQuery(sql)
|
||||
|
||||
def _check(row):
|
||||
self.assertTrue(int(row[0][0]) == 0, "Table not empty")
|
||||
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def _testPool_5(self, res):
|
||||
self.flushLoggedErrors()
|
||||
sql = "select * from NOTABLE" # bad sql
|
||||
d = defer.maybeDeferred(self.dbpool.runQuery, sql)
|
||||
d.addCallbacks(
|
||||
lambda res: self.fail("no exception"),
|
||||
lambda f: self.assertFalse(f.check(ConnectionLost)),
|
||||
)
|
||||
return d
|
||||
|
||||
|
||||
class DBTestConnector:
|
||||
"""
|
||||
A class which knows how to test for the presence of
|
||||
and establish a connection to a relational database.
|
||||
|
||||
To enable test cases which use a central, system database,
|
||||
you must create a database named DB_NAME with a user DB_USER
|
||||
and password DB_PASS with full access rights to database DB_NAME.
|
||||
"""
|
||||
|
||||
# used for creating new test cases
|
||||
TEST_PREFIX: Optional[str] = None
|
||||
|
||||
DB_NAME = "twisted_test"
|
||||
DB_USER = "twisted_test"
|
||||
DB_PASS = "twisted_test"
|
||||
|
||||
DB_DIR = None # directory for database storage
|
||||
|
||||
nulls_ok = True # nulls supported
|
||||
trailing_spaces_ok = True # trailing spaces in strings preserved
|
||||
can_rollback = True # rollback supported
|
||||
test_failures = True # test bad sql?
|
||||
escape_slashes = True # escape \ in sql?
|
||||
good_sql: Optional[str] = ConnectionPool.good_sql
|
||||
early_reconnect = True # cursor() will fail on closed connection
|
||||
can_clear = True # can try to clear out tables when starting
|
||||
|
||||
# number of iterations for test loop (lower this for slow db's)
|
||||
num_iterations = 50
|
||||
|
||||
def setUp(self):
|
||||
self.DB_DIR = self.mktemp()
|
||||
os.mkdir(self.DB_DIR)
|
||||
if not self.can_connect():
|
||||
raise unittest.SkipTest("%s: Cannot access db" % self.TEST_PREFIX)
|
||||
return self.extraSetUp()
|
||||
|
||||
def can_connect(self):
|
||||
"""Return true if this database is present on the system
|
||||
and can be used in a test."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def startDB(self):
|
||||
"""Take any steps needed to bring database up."""
|
||||
pass
|
||||
|
||||
def stopDB(self):
|
||||
"""Bring database down, if needed."""
|
||||
pass
|
||||
|
||||
def makePool(self, **newkw):
|
||||
"""Create a connection pool with additional keyword arguments."""
|
||||
args, kw = self.getPoolArgs()
|
||||
kw = kw.copy()
|
||||
kw.update(newkw)
|
||||
return ConnectionPool(*args, **kw)
|
||||
|
||||
def getPoolArgs(self):
|
||||
"""Return a tuple (args, kw) of list and keyword arguments
|
||||
that need to be passed to ConnectionPool to create a connection
|
||||
to this database."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SQLite3Connector(DBTestConnector):
|
||||
"""
|
||||
Connector that uses the stdlib SQLite3 database support.
|
||||
"""
|
||||
|
||||
TEST_PREFIX = "SQLite3"
|
||||
escape_slashes = False
|
||||
num_iterations = 1 # slow
|
||||
|
||||
def can_connect(self):
|
||||
if requireModule("sqlite3") is None:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def startDB(self):
|
||||
self.database = os.path.join(self.DB_DIR, self.DB_NAME)
|
||||
if os.path.exists(self.database):
|
||||
os.unlink(self.database)
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ("sqlite3",)
|
||||
kw = {"database": self.database, "cp_max": 1, "check_same_thread": False}
|
||||
return args, kw
|
||||
|
||||
|
||||
class PySQLite2Connector(DBTestConnector):
|
||||
"""
|
||||
Connector that uses pysqlite's SQLite database support.
|
||||
"""
|
||||
|
||||
TEST_PREFIX = "pysqlite2"
|
||||
escape_slashes = False
|
||||
num_iterations = 1 # slow
|
||||
|
||||
def can_connect(self):
|
||||
if requireModule("pysqlite2.dbapi2") is None:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def startDB(self):
|
||||
self.database = os.path.join(self.DB_DIR, self.DB_NAME)
|
||||
if os.path.exists(self.database):
|
||||
os.unlink(self.database)
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ("pysqlite2.dbapi2",)
|
||||
kw = {"database": self.database, "cp_max": 1, "check_same_thread": False}
|
||||
return args, kw
|
||||
|
||||
|
||||
class PyPgSQLConnector(DBTestConnector):
|
||||
TEST_PREFIX = "PyPgSQL"
|
||||
|
||||
def can_connect(self):
|
||||
try:
|
||||
from pyPgSQL import PgSQL
|
||||
except BaseException:
|
||||
return False
|
||||
try:
|
||||
conn = PgSQL.connect(
|
||||
database=self.DB_NAME, user=self.DB_USER, password=self.DB_PASS
|
||||
)
|
||||
conn.close()
|
||||
return True
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ("pyPgSQL.PgSQL",)
|
||||
kw = {
|
||||
"database": self.DB_NAME,
|
||||
"user": self.DB_USER,
|
||||
"password": self.DB_PASS,
|
||||
"cp_min": 0,
|
||||
}
|
||||
return args, kw
|
||||
|
||||
|
||||
class PsycopgConnector(DBTestConnector):
|
||||
TEST_PREFIX = "Psycopg"
|
||||
|
||||
def can_connect(self):
|
||||
try:
|
||||
import psycopg
|
||||
except BaseException:
|
||||
return False
|
||||
try:
|
||||
conn = psycopg.connect(
|
||||
database=self.DB_NAME, user=self.DB_USER, password=self.DB_PASS
|
||||
)
|
||||
conn.close()
|
||||
return True
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ("psycopg",)
|
||||
kw = {
|
||||
"database": self.DB_NAME,
|
||||
"user": self.DB_USER,
|
||||
"password": self.DB_PASS,
|
||||
"cp_min": 0,
|
||||
}
|
||||
return args, kw
|
||||
|
||||
|
||||
class MySQLConnector(DBTestConnector):
|
||||
TEST_PREFIX = "MySQL"
|
||||
|
||||
trailing_spaces_ok = False
|
||||
can_rollback = False
|
||||
early_reconnect = False
|
||||
|
||||
def can_connect(self):
|
||||
try:
|
||||
import MySQLdb
|
||||
except BaseException:
|
||||
return False
|
||||
try:
|
||||
conn = MySQLdb.connect(
|
||||
db=self.DB_NAME, user=self.DB_USER, passwd=self.DB_PASS
|
||||
)
|
||||
conn.close()
|
||||
return True
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ("MySQLdb",)
|
||||
kw = {"db": self.DB_NAME, "user": self.DB_USER, "passwd": self.DB_PASS}
|
||||
return args, kw
|
||||
|
||||
|
||||
class FirebirdConnector(DBTestConnector):
|
||||
TEST_PREFIX = "Firebird"
|
||||
|
||||
test_failures = False # failure testing causes problems
|
||||
escape_slashes = False
|
||||
good_sql = None # firebird doesn't handle failed sql well
|
||||
can_clear = False # firebird is not so good
|
||||
|
||||
num_iterations = 5 # slow
|
||||
|
||||
def can_connect(self):
|
||||
if requireModule("kinterbasdb") is None:
|
||||
return False
|
||||
try:
|
||||
self.startDB()
|
||||
self.stopDB()
|
||||
return True
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
def startDB(self):
|
||||
import kinterbasdb
|
||||
|
||||
self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
|
||||
os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
|
||||
sql = 'create database "%s" user "%s" password "%s"'
|
||||
sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS)
|
||||
conn = kinterbasdb.create_database(sql)
|
||||
conn.close()
|
||||
|
||||
def getPoolArgs(self):
|
||||
args = ("kinterbasdb",)
|
||||
kw = {
|
||||
"database": self.DB_NAME,
|
||||
"host": "127.0.0.1",
|
||||
"user": self.DB_USER,
|
||||
"password": self.DB_PASS,
|
||||
}
|
||||
return args, kw
|
||||
|
||||
def stopDB(self):
|
||||
import kinterbasdb
|
||||
|
||||
conn = kinterbasdb.connect(
|
||||
database=self.DB_NAME,
|
||||
host="127.0.0.1",
|
||||
user=self.DB_USER,
|
||||
password=self.DB_PASS,
|
||||
)
|
||||
conn.drop_database()
|
||||
|
||||
|
||||
def makeSQLTests(base, suffix, globals):
|
||||
"""
|
||||
Make a test case for every db connector which can connect.
|
||||
|
||||
@param base: Base class for test case. Additional base classes
|
||||
will be a DBConnector subclass and unittest.TestCase
|
||||
@param suffix: A suffix used to create test case names. Prefixes
|
||||
are defined in the DBConnector subclasses.
|
||||
"""
|
||||
connectors = [
|
||||
PySQLite2Connector,
|
||||
SQLite3Connector,
|
||||
PyPgSQLConnector,
|
||||
PsycopgConnector,
|
||||
MySQLConnector,
|
||||
FirebirdConnector,
|
||||
]
|
||||
tests = {}
|
||||
for connclass in connectors:
|
||||
name = connclass.TEST_PREFIX + suffix
|
||||
|
||||
class testcase(connclass, base, unittest.TestCase):
|
||||
__module__ = connclass.__module__
|
||||
|
||||
testcase.__name__ = name
|
||||
if hasattr(connclass, "__qualname__"):
|
||||
testcase.__qualname__ = ".".join(
|
||||
connclass.__qualname__.split()[0:-1] + [name]
|
||||
)
|
||||
tests[name] = testcase
|
||||
|
||||
globals.update(tests)
|
||||
|
||||
|
||||
# PySQLite2Connector SQLite3ADBAPITests PyPgSQLADBAPITests
|
||||
# PsycopgADBAPITests MySQLADBAPITests FirebirdADBAPITests
|
||||
makeSQLTests(ADBAPITestBase, "ADBAPITests", globals())
|
||||
|
||||
# PySQLite2Connector SQLite3ReconnectTests PyPgSQLReconnectTests
|
||||
# PsycopgReconnectTests MySQLReconnectTests FirebirdReconnectTests
|
||||
makeSQLTests(ReconnectTestBase, "ReconnectTests", globals())
|
||||
|
||||
|
||||
class FakePool:
|
||||
"""
|
||||
A fake L{ConnectionPool} for tests.
|
||||
|
||||
@ivar connectionFactory: factory for making connections returned by the
|
||||
C{connect} method.
|
||||
@type connectionFactory: any callable
|
||||
"""
|
||||
|
||||
reconnect = True
|
||||
noisy = True
|
||||
|
||||
def __init__(self, connectionFactory):
|
||||
self.connectionFactory = connectionFactory
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
Return an instance of C{self.connectionFactory}.
|
||||
"""
|
||||
return self.connectionFactory()
|
||||
|
||||
def disconnect(self, connection):
|
||||
"""
|
||||
Do nothing.
|
||||
"""
|
||||
|
||||
|
||||
class ConnectionTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the L{Connection} class.
|
||||
"""
|
||||
|
||||
def test_rollbackErrorLogged(self):
|
||||
"""
|
||||
If an error happens during rollback, L{ConnectionLost} is raised but
|
||||
the original error is logged.
|
||||
"""
|
||||
|
||||
class ConnectionRollbackRaise:
|
||||
def rollback(self):
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
pool = FakePool(ConnectionRollbackRaise)
|
||||
connection = Connection(pool)
|
||||
self.assertRaises(ConnectionLost, connection.rollback)
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
|
||||
|
||||
class TransactionTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the L{Transaction} class.
|
||||
"""
|
||||
|
||||
def test_reopenLogErrorIfReconnect(self):
|
||||
"""
|
||||
If the cursor creation raises an error in L{Transaction.reopen}, it
|
||||
reconnects but log the error occurred.
|
||||
"""
|
||||
|
||||
class ConnectionCursorRaise:
|
||||
count = 0
|
||||
|
||||
def reconnect(self):
|
||||
pass
|
||||
|
||||
def cursor(self):
|
||||
if self.count == 0:
|
||||
self.count += 1
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
pool = FakePool(None)
|
||||
transaction = Transaction(pool, ConnectionCursorRaise())
|
||||
transaction.reopen()
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
|
||||
|
||||
class NonThreadPool:
|
||||
def callInThreadWithCallback(self, onResult, f, *a, **kw):
|
||||
success = True
|
||||
try:
|
||||
result = f(*a, **kw)
|
||||
except Exception:
|
||||
success = False
|
||||
result = Failure()
|
||||
onResult(success, result)
|
||||
|
||||
|
||||
class DummyConnectionPool(ConnectionPool):
|
||||
"""
|
||||
A testable L{ConnectionPool};
|
||||
"""
|
||||
|
||||
threadpool = NonThreadPool()
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Don't forward init call.
|
||||
"""
|
||||
self._reactor = reactor
|
||||
|
||||
|
||||
class EventReactor:
|
||||
"""
|
||||
Partial L{IReactorCore} implementation with simple event-related
|
||||
methods.
|
||||
|
||||
@ivar _running: A C{bool} indicating whether the reactor is pretending
|
||||
to have been started already or not.
|
||||
|
||||
@ivar triggers: A C{list} of pending system event triggers.
|
||||
"""
|
||||
|
||||
def __init__(self, running):
|
||||
self._running = running
|
||||
self.triggers = []
|
||||
|
||||
def callWhenRunning(self, function):
|
||||
if self._running:
|
||||
function()
|
||||
else:
|
||||
return self.addSystemEventTrigger("after", "startup", function)
|
||||
|
||||
def addSystemEventTrigger(self, phase, event, trigger):
|
||||
handle = (phase, event, trigger)
|
||||
self.triggers.append(handle)
|
||||
return handle
|
||||
|
||||
def removeSystemEventTrigger(self, handle):
|
||||
self.triggers.remove(handle)
|
||||
|
||||
|
||||
class ConnectionPoolTests(unittest.TestCase):
|
||||
"""
|
||||
Unit tests for L{ConnectionPool}.
|
||||
"""
|
||||
|
||||
def test_runWithConnectionRaiseOriginalError(self):
|
||||
"""
|
||||
If rollback fails, L{ConnectionPool.runWithConnection} raises the
|
||||
original exception and log the error of the rollback.
|
||||
"""
|
||||
|
||||
class ConnectionRollbackRaise:
|
||||
def __init__(self, pool):
|
||||
pass
|
||||
|
||||
def rollback(self):
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
def raisingFunction(connection):
|
||||
raise ValueError("foo")
|
||||
|
||||
pool = DummyConnectionPool()
|
||||
pool.connectionFactory = ConnectionRollbackRaise
|
||||
d = pool.runWithConnection(raisingFunction)
|
||||
d = self.assertFailure(d, ValueError)
|
||||
|
||||
def cbFailed(ignored):
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
|
||||
d.addCallback(cbFailed)
|
||||
return d
|
||||
|
||||
def test_closeLogError(self):
|
||||
"""
|
||||
L{ConnectionPool._close} logs exceptions.
|
||||
"""
|
||||
|
||||
class ConnectionCloseRaise:
|
||||
def close(self):
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
pool = DummyConnectionPool()
|
||||
pool._close(ConnectionCloseRaise())
|
||||
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
|
||||
def test_runWithInteractionRaiseOriginalError(self):
|
||||
"""
|
||||
If rollback fails, L{ConnectionPool.runInteraction} raises the
|
||||
original exception and log the error of the rollback.
|
||||
"""
|
||||
|
||||
class ConnectionRollbackRaise:
|
||||
def __init__(self, pool):
|
||||
pass
|
||||
|
||||
def rollback(self):
|
||||
raise RuntimeError("problem!")
|
||||
|
||||
class DummyTransaction:
|
||||
def __init__(self, pool, connection):
|
||||
pass
|
||||
|
||||
def raisingFunction(transaction):
|
||||
raise ValueError("foo")
|
||||
|
||||
pool = DummyConnectionPool()
|
||||
pool.connectionFactory = ConnectionRollbackRaise
|
||||
pool.transactionFactory = DummyTransaction
|
||||
|
||||
d = pool.runInteraction(raisingFunction)
|
||||
d = self.assertFailure(d, ValueError)
|
||||
|
||||
def cbFailed(ignored):
|
||||
errors = self.flushLoggedErrors(RuntimeError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].value.args[0], "problem!")
|
||||
|
||||
d.addCallback(cbFailed)
|
||||
return d
|
||||
|
||||
def test_unstartedClose(self):
|
||||
"""
|
||||
If L{ConnectionPool.close} is called without L{ConnectionPool.start}
|
||||
having been called, the pool's startup event is cancelled.
|
||||
"""
|
||||
reactor = EventReactor(False)
|
||||
pool = ConnectionPool("twisted.test.test_adbapi", cp_reactor=reactor)
|
||||
# There should be a startup trigger waiting.
|
||||
self.assertEqual(reactor.triggers, [("after", "startup", pool._start)])
|
||||
pool.close()
|
||||
# But not anymore.
|
||||
self.assertFalse(reactor.triggers)
|
||||
|
||||
def test_startedClose(self):
|
||||
"""
|
||||
If L{ConnectionPool.close} is called after it has been started, but
|
||||
not by its shutdown trigger, the shutdown trigger is cancelled.
|
||||
"""
|
||||
reactor = EventReactor(True)
|
||||
pool = ConnectionPool("twisted.test.test_adbapi", cp_reactor=reactor)
|
||||
# There should be a shutdown trigger waiting.
|
||||
self.assertEqual(reactor.triggers, [("during", "shutdown", pool.finalClose)])
|
||||
pool.close()
|
||||
# But not anymore.
|
||||
self.assertFalse(reactor.triggers)
|
||||
3390
.venv/lib/python3.12/site-packages/twisted/test/test_amp.py
Normal file
3390
.venv/lib/python3.12/site-packages/twisted/test/test_amp.py
Normal file
File diff suppressed because it is too large
Load Diff
1018
.venv/lib/python3.12/site-packages/twisted/test/test_application.py
Normal file
1018
.venv/lib/python3.12/site-packages/twisted/test/test_application.py
Normal file
File diff suppressed because it is too large
Load Diff
548
.venv/lib/python3.12/site-packages/twisted/test/test_compat.py
Normal file
548
.venv/lib/python3.12/site-packages/twisted/test/test_compat.py
Normal file
@@ -0,0 +1,548 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.compat}.
|
||||
"""
|
||||
|
||||
|
||||
import codecs
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
from unittest import skipIf
|
||||
|
||||
from twisted.python.compat import (
|
||||
_PYPY,
|
||||
bytesEnviron,
|
||||
cmp,
|
||||
comparable,
|
||||
execfile,
|
||||
intToBytes,
|
||||
ioType,
|
||||
iterbytes,
|
||||
lazyByteSlice,
|
||||
nativeString,
|
||||
networkString,
|
||||
reraise,
|
||||
)
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.trial.unittest import SynchronousTestCase, TestCase
|
||||
|
||||
|
||||
class IOTypeTests(SynchronousTestCase):
|
||||
"""
|
||||
Test cases for determining a file-like object's type.
|
||||
"""
|
||||
|
||||
def test_3StringIO(self):
|
||||
"""
|
||||
An L{io.StringIO} accepts and returns text.
|
||||
"""
|
||||
self.assertEqual(ioType(io.StringIO()), str)
|
||||
|
||||
def test_3BytesIO(self):
|
||||
"""
|
||||
An L{io.BytesIO} accepts and returns bytes.
|
||||
"""
|
||||
self.assertEqual(ioType(io.BytesIO()), bytes)
|
||||
|
||||
def test_3openTextMode(self):
|
||||
"""
|
||||
A file opened via 'io.open' in text mode accepts and returns text.
|
||||
"""
|
||||
with open(self.mktemp(), "w") as f:
|
||||
self.assertEqual(ioType(f), str)
|
||||
|
||||
def test_3openBinaryMode(self):
|
||||
"""
|
||||
A file opened via 'io.open' in binary mode accepts and returns bytes.
|
||||
"""
|
||||
with open(self.mktemp(), "wb") as f:
|
||||
self.assertEqual(ioType(f), bytes)
|
||||
|
||||
def test_codecsOpenBytes(self):
|
||||
"""
|
||||
The L{codecs} module, oddly, returns a file-like object which returns
|
||||
bytes when not passed an 'encoding' argument.
|
||||
"""
|
||||
with codecs.open(self.mktemp(), "wb") as f:
|
||||
self.assertEqual(ioType(f), bytes)
|
||||
|
||||
def test_codecsOpenText(self):
|
||||
"""
|
||||
When passed an encoding, however, the L{codecs} module returns unicode.
|
||||
"""
|
||||
with codecs.open(self.mktemp(), "wb", encoding="utf-8") as f:
|
||||
self.assertEqual(ioType(f), str)
|
||||
|
||||
def test_defaultToText(self):
|
||||
"""
|
||||
When passed an object about which no sensible decision can be made, err
|
||||
on the side of unicode.
|
||||
"""
|
||||
self.assertEqual(ioType(object()), str)
|
||||
|
||||
|
||||
class CompatTests(SynchronousTestCase):
|
||||
"""
|
||||
Various utility functions in C{twisted.python.compat} provide same
|
||||
functionality as modern Python variants.
|
||||
"""
|
||||
|
||||
def test_set(self):
|
||||
"""
|
||||
L{set} should behave like the expected set interface.
|
||||
"""
|
||||
a = set()
|
||||
a.add("b")
|
||||
a.add("c")
|
||||
a.add("a")
|
||||
b = list(a)
|
||||
b.sort()
|
||||
self.assertEqual(b, ["a", "b", "c"])
|
||||
a.remove("b")
|
||||
b = list(a)
|
||||
b.sort()
|
||||
self.assertEqual(b, ["a", "c"])
|
||||
|
||||
a.discard("d")
|
||||
|
||||
b = {"r", "s"}
|
||||
d = a.union(b)
|
||||
b = list(d)
|
||||
b.sort()
|
||||
self.assertEqual(b, ["a", "c", "r", "s"])
|
||||
|
||||
def test_frozenset(self):
|
||||
"""
|
||||
L{frozenset} should behave like the expected frozenset interface.
|
||||
"""
|
||||
a = frozenset(["a", "b"])
|
||||
self.assertRaises(AttributeError, getattr, a, "add")
|
||||
self.assertEqual(sorted(a), ["a", "b"])
|
||||
|
||||
b = frozenset(["r", "s"])
|
||||
d = a.union(b)
|
||||
b = list(d)
|
||||
b.sort()
|
||||
self.assertEqual(b, ["a", "b", "r", "s"])
|
||||
|
||||
|
||||
class ExecfileCompatTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for the Python 3-friendly L{execfile} implementation.
|
||||
"""
|
||||
|
||||
def writeScript(self, content):
|
||||
"""
|
||||
Write L{content} to a new temporary file, returning the L{FilePath}
|
||||
for the new file.
|
||||
"""
|
||||
path = self.mktemp()
|
||||
with open(path, "wb") as f:
|
||||
f.write(content.encode("ascii"))
|
||||
return FilePath(path.encode("utf-8"))
|
||||
|
||||
def test_execfileGlobals(self):
|
||||
"""
|
||||
L{execfile} executes the specified file in the given global namespace.
|
||||
"""
|
||||
script = self.writeScript("foo += 1\n")
|
||||
globalNamespace = {"foo": 1}
|
||||
execfile(script.path, globalNamespace)
|
||||
self.assertEqual(2, globalNamespace["foo"])
|
||||
|
||||
def test_execfileGlobalsAndLocals(self):
|
||||
"""
|
||||
L{execfile} executes the specified file in the given global and local
|
||||
namespaces.
|
||||
"""
|
||||
script = self.writeScript("foo += 1\n")
|
||||
globalNamespace = {"foo": 10}
|
||||
localNamespace = {"foo": 20}
|
||||
execfile(script.path, globalNamespace, localNamespace)
|
||||
self.assertEqual(10, globalNamespace["foo"])
|
||||
self.assertEqual(21, localNamespace["foo"])
|
||||
|
||||
def test_execfileUniversalNewlines(self):
|
||||
"""
|
||||
L{execfile} reads in the specified file using universal newlines so
|
||||
that scripts written on one platform will work on another.
|
||||
"""
|
||||
for lineEnding in "\n", "\r", "\r\n":
|
||||
script = self.writeScript("foo = 'okay'" + lineEnding)
|
||||
globalNamespace = {"foo": None}
|
||||
execfile(script.path, globalNamespace)
|
||||
self.assertEqual("okay", globalNamespace["foo"])
|
||||
|
||||
|
||||
class PYPYTest(SynchronousTestCase):
|
||||
"""
|
||||
Identification of PyPy.
|
||||
"""
|
||||
|
||||
def test_PYPY(self):
|
||||
"""
|
||||
On PyPy, L{_PYPY} is True.
|
||||
"""
|
||||
if "PyPy" in sys.version:
|
||||
self.assertTrue(_PYPY)
|
||||
else:
|
||||
self.assertFalse(_PYPY)
|
||||
|
||||
|
||||
@comparable
|
||||
class Comparable:
|
||||
"""
|
||||
Objects that can be compared to each other, but not others.
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __cmp__(self, other):
|
||||
if not isinstance(other, Comparable):
|
||||
return NotImplemented
|
||||
return cmp(self.value, other.value)
|
||||
|
||||
|
||||
class ComparableTests(SynchronousTestCase):
|
||||
"""
|
||||
L{comparable} decorated classes emulate Python 2's C{__cmp__} semantics.
|
||||
"""
|
||||
|
||||
def test_equality(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
equality comparisons.
|
||||
"""
|
||||
# Make explicitly sure we're using ==:
|
||||
self.assertTrue(Comparable(1) == Comparable(1))
|
||||
self.assertFalse(Comparable(2) == Comparable(1))
|
||||
|
||||
def test_nonEquality(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
inequality comparisons.
|
||||
"""
|
||||
# Make explicitly sure we're using !=:
|
||||
self.assertFalse(Comparable(1) != Comparable(1))
|
||||
self.assertTrue(Comparable(2) != Comparable(1))
|
||||
|
||||
def test_greaterThan(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
greater-than comparisons.
|
||||
"""
|
||||
self.assertTrue(Comparable(2) > Comparable(1))
|
||||
self.assertFalse(Comparable(0) > Comparable(3))
|
||||
|
||||
def test_greaterThanOrEqual(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
greater-than-or-equal comparisons.
|
||||
"""
|
||||
self.assertTrue(Comparable(1) >= Comparable(1))
|
||||
self.assertTrue(Comparable(2) >= Comparable(1))
|
||||
self.assertFalse(Comparable(0) >= Comparable(3))
|
||||
|
||||
def test_lessThan(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
less-than comparisons.
|
||||
"""
|
||||
self.assertTrue(Comparable(0) < Comparable(3))
|
||||
self.assertFalse(Comparable(2) < Comparable(0))
|
||||
|
||||
def test_lessThanOrEqual(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
less-than-or-equal comparisons.
|
||||
"""
|
||||
self.assertTrue(Comparable(3) <= Comparable(3))
|
||||
self.assertTrue(Comparable(0) <= Comparable(3))
|
||||
self.assertFalse(Comparable(2) <= Comparable(0))
|
||||
|
||||
|
||||
class Python3ComparableTests(SynchronousTestCase):
|
||||
"""
|
||||
Python 3-specific functionality of C{comparable}.
|
||||
"""
|
||||
|
||||
def test_notImplementedEquals(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__eq__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__eq__(object()), NotImplemented)
|
||||
|
||||
def test_notImplementedNotEquals(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__ne__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__ne__(object()), NotImplemented)
|
||||
|
||||
def test_notImplementedGreaterThan(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__gt__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__gt__(object()), NotImplemented)
|
||||
|
||||
def test_notImplementedLessThan(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__lt__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__lt__(object()), NotImplemented)
|
||||
|
||||
def test_notImplementedGreaterThanEquals(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__ge__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__ge__(object()), NotImplemented)
|
||||
|
||||
def test_notImplementedLessThanEquals(self):
|
||||
"""
|
||||
Instances of a class that is decorated by C{comparable} support
|
||||
returning C{NotImplemented} from C{__le__} if it is returned by the
|
||||
underlying C{__cmp__} call.
|
||||
"""
|
||||
self.assertEqual(Comparable(1).__le__(object()), NotImplemented)
|
||||
|
||||
|
||||
class CmpTests(SynchronousTestCase):
|
||||
"""
|
||||
L{cmp} should behave like the built-in Python 2 C{cmp}.
|
||||
"""
|
||||
|
||||
def test_equals(self):
|
||||
"""
|
||||
L{cmp} returns 0 for equal objects.
|
||||
"""
|
||||
self.assertEqual(cmp("a", "a"), 0)
|
||||
self.assertEqual(cmp(1, 1), 0)
|
||||
self.assertEqual(cmp([1], [1]), 0)
|
||||
|
||||
def test_greaterThan(self):
|
||||
"""
|
||||
L{cmp} returns 1 if its first argument is bigger than its second.
|
||||
"""
|
||||
self.assertEqual(cmp(4, 0), 1)
|
||||
self.assertEqual(cmp(b"z", b"a"), 1)
|
||||
|
||||
def test_lessThan(self):
|
||||
"""
|
||||
L{cmp} returns -1 if its first argument is smaller than its second.
|
||||
"""
|
||||
self.assertEqual(cmp(0.1, 2.3), -1)
|
||||
self.assertEqual(cmp(b"a", b"d"), -1)
|
||||
|
||||
|
||||
class StringTests(SynchronousTestCase):
|
||||
"""
|
||||
Compatibility functions and types for strings.
|
||||
"""
|
||||
|
||||
def assertNativeString(self, original, expected):
|
||||
"""
|
||||
Raise an exception indicating a failed test if the output of
|
||||
C{nativeString(original)} is unequal to the expected string, or is not
|
||||
a native string.
|
||||
"""
|
||||
self.assertEqual(nativeString(original), expected)
|
||||
self.assertIsInstance(nativeString(original), str)
|
||||
|
||||
def test_nonASCIIBytesToString(self):
|
||||
"""
|
||||
C{nativeString} raises a C{UnicodeError} if input bytes are not ASCII
|
||||
decodable.
|
||||
"""
|
||||
self.assertRaises(UnicodeError, nativeString, b"\xFF")
|
||||
|
||||
def test_nonASCIIUnicodeToString(self):
|
||||
"""
|
||||
C{nativeString} raises a C{UnicodeError} if input Unicode is not ASCII
|
||||
encodable.
|
||||
"""
|
||||
self.assertRaises(UnicodeError, nativeString, "\u1234")
|
||||
|
||||
def test_bytesToString(self):
|
||||
"""
|
||||
C{nativeString} converts bytes to the native string format, assuming
|
||||
an ASCII encoding if applicable.
|
||||
"""
|
||||
self.assertNativeString(b"hello", "hello")
|
||||
|
||||
def test_unicodeToString(self):
|
||||
"""
|
||||
C{nativeString} converts unicode to the native string format, assuming
|
||||
an ASCII encoding if applicable.
|
||||
"""
|
||||
self.assertNativeString("Good day", "Good day")
|
||||
|
||||
def test_stringToString(self):
|
||||
"""
|
||||
C{nativeString} leaves native strings as native strings.
|
||||
"""
|
||||
self.assertNativeString("Hello!", "Hello!")
|
||||
|
||||
def test_unexpectedType(self):
|
||||
"""
|
||||
C{nativeString} raises a C{TypeError} if given an object that is not a
|
||||
string of some sort.
|
||||
"""
|
||||
self.assertRaises(TypeError, nativeString, 1)
|
||||
|
||||
|
||||
class NetworkStringTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{networkString}.
|
||||
"""
|
||||
|
||||
def test_str(self):
|
||||
"""
|
||||
L{networkString} returns a C{unicode} object passed to it encoded into
|
||||
a C{bytes} instance.
|
||||
"""
|
||||
self.assertEqual(b"foo", networkString("foo"))
|
||||
|
||||
def test_unicodeOutOfRange(self):
|
||||
"""
|
||||
L{networkString} raises L{UnicodeError} if passed a C{unicode} instance
|
||||
containing characters not encodable in ASCII.
|
||||
"""
|
||||
self.assertRaises(UnicodeError, networkString, "\N{SNOWMAN}")
|
||||
|
||||
def test_nonString(self):
|
||||
"""
|
||||
L{networkString} raises L{TypeError} if passed a non-string object or
|
||||
the wrong type of string object.
|
||||
"""
|
||||
self.assertRaises(TypeError, networkString, object())
|
||||
self.assertRaises(TypeError, networkString, b"bytes")
|
||||
|
||||
|
||||
class ReraiseTests(SynchronousTestCase):
|
||||
"""
|
||||
L{reraise} re-raises exceptions on both Python 2 and Python 3.
|
||||
"""
|
||||
|
||||
def test_reraiseWithNone(self):
|
||||
"""
|
||||
Calling L{reraise} with an exception instance and a traceback of
|
||||
L{None} re-raises it with a new traceback.
|
||||
"""
|
||||
try:
|
||||
1 / 0
|
||||
except BaseException:
|
||||
typ, value, tb = sys.exc_info()
|
||||
try:
|
||||
reraise(value, None)
|
||||
except BaseException:
|
||||
typ2, value2, tb2 = sys.exc_info()
|
||||
self.assertEqual(typ2, ZeroDivisionError)
|
||||
self.assertIs(value, value2)
|
||||
self.assertNotEqual(
|
||||
traceback.format_tb(tb)[-1], traceback.format_tb(tb2)[-1]
|
||||
)
|
||||
else:
|
||||
self.fail("The exception was not raised.")
|
||||
|
||||
def test_reraiseWithTraceback(self):
|
||||
"""
|
||||
Calling L{reraise} with an exception instance and a traceback
|
||||
re-raises the exception with the given traceback.
|
||||
"""
|
||||
try:
|
||||
1 / 0
|
||||
except BaseException:
|
||||
typ, value, tb = sys.exc_info()
|
||||
try:
|
||||
reraise(value, tb)
|
||||
except BaseException:
|
||||
typ2, value2, tb2 = sys.exc_info()
|
||||
self.assertEqual(typ2, ZeroDivisionError)
|
||||
self.assertIs(value, value2)
|
||||
self.assertEqual(traceback.format_tb(tb)[-1], traceback.format_tb(tb2)[-1])
|
||||
else:
|
||||
self.fail("The exception was not raised.")
|
||||
|
||||
|
||||
class Python3BytesTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{iterbytes}, L{intToBytes}, L{lazyByteSlice}.
|
||||
"""
|
||||
|
||||
def test_iteration(self):
|
||||
"""
|
||||
When L{iterbytes} is called with a bytestring, the returned object
|
||||
can be iterated over, resulting in the individual bytes of the
|
||||
bytestring.
|
||||
"""
|
||||
input = b"abcd"
|
||||
result = list(iterbytes(input))
|
||||
self.assertEqual(result, [b"a", b"b", b"c", b"d"])
|
||||
|
||||
def test_intToBytes(self):
|
||||
"""
|
||||
When L{intToBytes} is called with an integer, the result is an
|
||||
ASCII-encoded string representation of the number.
|
||||
"""
|
||||
self.assertEqual(intToBytes(213), b"213")
|
||||
|
||||
def test_lazyByteSliceNoOffset(self):
|
||||
"""
|
||||
L{lazyByteSlice} called with some bytes returns a semantically equal
|
||||
version of these bytes.
|
||||
"""
|
||||
data = b"123XYZ"
|
||||
self.assertEqual(bytes(lazyByteSlice(data)), data)
|
||||
|
||||
def test_lazyByteSliceOffset(self):
|
||||
"""
|
||||
L{lazyByteSlice} called with some bytes and an offset returns a
|
||||
semantically equal version of these bytes starting at the given offset.
|
||||
"""
|
||||
data = b"123XYZ"
|
||||
self.assertEqual(bytes(lazyByteSlice(data, 2)), data[2:])
|
||||
|
||||
def test_lazyByteSliceOffsetAndLength(self):
|
||||
"""
|
||||
L{lazyByteSlice} called with some bytes, an offset and a length returns
|
||||
a semantically equal version of these bytes starting at the given
|
||||
offset, up to the given length.
|
||||
"""
|
||||
data = b"123XYZ"
|
||||
self.assertEqual(bytes(lazyByteSlice(data, 2, 3)), data[2:5])
|
||||
|
||||
|
||||
class BytesEnvironTests(TestCase):
|
||||
"""
|
||||
Tests for L{BytesEnviron}.
|
||||
"""
|
||||
|
||||
@skipIf(platform.isWindows(), "Environment vars are always str on Windows.")
|
||||
def test_alwaysBytes(self):
|
||||
"""
|
||||
The output of L{BytesEnviron} should always be a L{dict} with L{bytes}
|
||||
values and L{bytes} keys.
|
||||
"""
|
||||
result = bytesEnviron()
|
||||
types = set()
|
||||
|
||||
for key, val in result.items():
|
||||
types.add(type(key))
|
||||
types.add(type(val))
|
||||
|
||||
self.assertEqual(list(types), [bytes])
|
||||
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.context}.
|
||||
"""
|
||||
|
||||
|
||||
from twisted.python import context
|
||||
from twisted.trial.unittest import SynchronousTestCase
|
||||
|
||||
|
||||
class ContextTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for the module-scope APIs for L{twisted.python.context}.
|
||||
"""
|
||||
|
||||
def test_notPresentIfNotSet(self):
|
||||
"""
|
||||
Arbitrary keys which have not been set in the context have an associated
|
||||
value of L{None}.
|
||||
"""
|
||||
self.assertIsNone(context.get("x"))
|
||||
|
||||
def test_setByCall(self):
|
||||
"""
|
||||
Values may be associated with keys by passing them in a dictionary as
|
||||
the first argument to L{twisted.python.context.call}.
|
||||
"""
|
||||
self.assertEqual(context.call({"x": "y"}, context.get, "x"), "y")
|
||||
|
||||
def test_unsetAfterCall(self):
|
||||
"""
|
||||
After a L{twisted.python.context.call} completes, keys specified in the
|
||||
call are no longer associated with the values from that call.
|
||||
"""
|
||||
context.call({"x": "y"}, lambda: None)
|
||||
self.assertIsNone(context.get("x"))
|
||||
|
||||
def test_setDefault(self):
|
||||
"""
|
||||
A default value may be set for a key in the context using
|
||||
L{twisted.python.context.setDefault}.
|
||||
"""
|
||||
key = object()
|
||||
self.addCleanup(context.defaultContextDict.pop, key, None)
|
||||
context.setDefault(key, "y")
|
||||
self.assertEqual("y", context.get(key))
|
||||
@@ -0,0 +1,690 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains tests for L{twisted.internet.task.Cooperator} and
|
||||
related functionality.
|
||||
"""
|
||||
|
||||
|
||||
from twisted.internet import defer, reactor, task
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class FakeDelayedCall:
|
||||
"""
|
||||
Fake delayed call which lets us simulate the scheduler.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
"""
|
||||
A function to run, later.
|
||||
"""
|
||||
self.func = func
|
||||
self.cancelled = False
|
||||
|
||||
def cancel(self):
|
||||
"""
|
||||
Don't run my function later.
|
||||
"""
|
||||
self.cancelled = True
|
||||
|
||||
|
||||
class FakeScheduler:
|
||||
"""
|
||||
A fake scheduler for testing against.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Create a fake scheduler with a list of work to do.
|
||||
"""
|
||||
self.work = []
|
||||
|
||||
def __call__(self, thunk):
|
||||
"""
|
||||
Schedule a unit of work to be done later.
|
||||
"""
|
||||
unit = FakeDelayedCall(thunk)
|
||||
self.work.append(unit)
|
||||
return unit
|
||||
|
||||
def pump(self):
|
||||
"""
|
||||
Do all of the work that is currently available to be done.
|
||||
"""
|
||||
work, self.work = self.work, []
|
||||
for unit in work:
|
||||
if not unit.cancelled:
|
||||
unit.func()
|
||||
|
||||
|
||||
class CooperatorTests(unittest.TestCase):
|
||||
RESULT = "done"
|
||||
|
||||
def ebIter(self, err):
|
||||
err.trap(task.SchedulerStopped)
|
||||
return self.RESULT
|
||||
|
||||
def cbIter(self, ign):
|
||||
self.fail()
|
||||
|
||||
def testStoppedRejectsNewTasks(self):
|
||||
"""
|
||||
Test that Cooperators refuse new tasks when they have been stopped.
|
||||
"""
|
||||
|
||||
def testwith(stuff):
|
||||
c = task.Cooperator()
|
||||
c.stop()
|
||||
d = c.coiterate(iter(()), stuff)
|
||||
d.addCallback(self.cbIter)
|
||||
d.addErrback(self.ebIter)
|
||||
return d.addCallback(lambda result: self.assertEqual(result, self.RESULT))
|
||||
|
||||
return testwith(None).addCallback(lambda ign: testwith(defer.Deferred()))
|
||||
|
||||
def testStopRunning(self):
|
||||
"""
|
||||
Test that a running iterator will not run to completion when the
|
||||
cooperator is stopped.
|
||||
"""
|
||||
c = task.Cooperator()
|
||||
|
||||
def myiter():
|
||||
yield from range(3)
|
||||
|
||||
myiter.value = -1
|
||||
d = c.coiterate(myiter())
|
||||
d.addCallback(self.cbIter)
|
||||
d.addErrback(self.ebIter)
|
||||
c.stop()
|
||||
|
||||
def doasserts(result):
|
||||
self.assertEqual(result, self.RESULT)
|
||||
self.assertEqual(myiter.value, -1)
|
||||
|
||||
d.addCallback(doasserts)
|
||||
return d
|
||||
|
||||
def testStopOutstanding(self):
|
||||
"""
|
||||
An iterator run with L{Cooperator.coiterate} paused on a L{Deferred}
|
||||
yielded by that iterator will fire its own L{Deferred} (the one
|
||||
returned by C{coiterate}) when L{Cooperator.stop} is called.
|
||||
"""
|
||||
testControlD = defer.Deferred()
|
||||
outstandingD = defer.Deferred()
|
||||
|
||||
def myiter():
|
||||
reactor.callLater(0, testControlD.callback, None)
|
||||
yield outstandingD
|
||||
self.fail()
|
||||
|
||||
c = task.Cooperator()
|
||||
d = c.coiterate(myiter())
|
||||
|
||||
def stopAndGo(ign):
|
||||
c.stop()
|
||||
outstandingD.callback("arglebargle")
|
||||
|
||||
testControlD.addCallback(stopAndGo)
|
||||
d.addCallback(self.cbIter)
|
||||
d.addErrback(self.ebIter)
|
||||
|
||||
return d.addCallback(lambda result: self.assertEqual(result, self.RESULT))
|
||||
|
||||
def testUnexpectedError(self):
|
||||
c = task.Cooperator()
|
||||
|
||||
def myiter():
|
||||
if False:
|
||||
yield None
|
||||
else:
|
||||
raise RuntimeError()
|
||||
|
||||
d = c.coiterate(myiter())
|
||||
return self.assertFailure(d, RuntimeError)
|
||||
|
||||
def testUnexpectedErrorActuallyLater(self):
|
||||
def myiter():
|
||||
D = defer.Deferred()
|
||||
reactor.callLater(0, D.errback, RuntimeError())
|
||||
yield D
|
||||
|
||||
c = task.Cooperator()
|
||||
d = c.coiterate(myiter())
|
||||
return self.assertFailure(d, RuntimeError)
|
||||
|
||||
def testUnexpectedErrorNotActuallyLater(self):
|
||||
def myiter():
|
||||
yield defer.fail(RuntimeError())
|
||||
|
||||
c = task.Cooperator()
|
||||
d = c.coiterate(myiter())
|
||||
return self.assertFailure(d, RuntimeError)
|
||||
|
||||
def testCooperation(self):
|
||||
L = []
|
||||
|
||||
def myiter(things):
|
||||
for th in things:
|
||||
L.append(th)
|
||||
yield None
|
||||
|
||||
groupsOfThings = ["abc", (1, 2, 3), "def", (4, 5, 6)]
|
||||
|
||||
c = task.Cooperator()
|
||||
tasks = []
|
||||
for stuff in groupsOfThings:
|
||||
tasks.append(c.coiterate(myiter(stuff)))
|
||||
|
||||
return defer.DeferredList(tasks).addCallback(
|
||||
lambda ign: self.assertEqual(tuple(L), sum(zip(*groupsOfThings), ()))
|
||||
)
|
||||
|
||||
def testResourceExhaustion(self):
|
||||
output = []
|
||||
|
||||
def myiter():
|
||||
for i in range(100):
|
||||
output.append(i)
|
||||
if i == 9:
|
||||
_TPF.stopped = True
|
||||
yield i
|
||||
|
||||
class _TPF:
|
||||
stopped = False
|
||||
|
||||
def __call__(self):
|
||||
return self.stopped
|
||||
|
||||
c = task.Cooperator(terminationPredicateFactory=_TPF)
|
||||
c.coiterate(myiter()).addErrback(self.ebIter)
|
||||
c._delayedCall.cancel()
|
||||
# testing a private method because only the test case will ever care
|
||||
# about this, so we have to carefully clean up after ourselves.
|
||||
c._tick()
|
||||
c.stop()
|
||||
self.assertTrue(_TPF.stopped)
|
||||
self.assertEqual(output, list(range(10)))
|
||||
|
||||
def testCallbackReCoiterate(self):
|
||||
"""
|
||||
If a callback to a deferred returned by coiterate calls coiterate on
|
||||
the same Cooperator, we should make sure to only do the minimal amount
|
||||
of scheduling work. (This test was added to demonstrate a specific bug
|
||||
that was found while writing the scheduler.)
|
||||
"""
|
||||
calls = []
|
||||
|
||||
class FakeCall:
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<FakeCall {self.func!r}>"
|
||||
|
||||
def sched(f):
|
||||
self.assertFalse(calls, repr(calls))
|
||||
calls.append(FakeCall(f))
|
||||
return calls[-1]
|
||||
|
||||
c = task.Cooperator(
|
||||
scheduler=sched, terminationPredicateFactory=lambda: lambda: True
|
||||
)
|
||||
d = c.coiterate(iter(()))
|
||||
|
||||
done = []
|
||||
|
||||
def anotherTask(ign):
|
||||
c.coiterate(iter(())).addBoth(done.append)
|
||||
|
||||
d.addCallback(anotherTask)
|
||||
|
||||
work = 0
|
||||
while not done:
|
||||
work += 1
|
||||
while calls:
|
||||
calls.pop(0).func()
|
||||
work += 1
|
||||
if work > 50:
|
||||
self.fail("Cooperator took too long")
|
||||
|
||||
def test_removingLastTaskStopsScheduledCall(self):
|
||||
"""
|
||||
If the last task in a Cooperator is removed, the scheduled call for
|
||||
the next tick is cancelled, since it is no longer necessary.
|
||||
|
||||
This behavior is useful for tests that want to assert they have left
|
||||
no reactor state behind when they're done.
|
||||
"""
|
||||
calls = [None]
|
||||
|
||||
def sched(f):
|
||||
calls[0] = FakeDelayedCall(f)
|
||||
return calls[0]
|
||||
|
||||
coop = task.Cooperator(scheduler=sched)
|
||||
|
||||
# Add two task; this should schedule the tick:
|
||||
task1 = coop.cooperate(iter([1, 2]))
|
||||
task2 = coop.cooperate(iter([1, 2]))
|
||||
self.assertEqual(calls[0].func, coop._tick)
|
||||
|
||||
# Remove first task; scheduled call should still be going:
|
||||
task1.stop()
|
||||
self.assertFalse(calls[0].cancelled)
|
||||
self.assertEqual(coop._delayedCall, calls[0])
|
||||
|
||||
# Remove second task; scheduled call should be cancelled:
|
||||
task2.stop()
|
||||
self.assertTrue(calls[0].cancelled)
|
||||
self.assertIsNone(coop._delayedCall)
|
||||
|
||||
# Add another task; scheduled call will be recreated:
|
||||
coop.cooperate(iter([1, 2]))
|
||||
self.assertFalse(calls[0].cancelled)
|
||||
self.assertEqual(coop._delayedCall, calls[0])
|
||||
|
||||
def test_runningWhenStarted(self):
|
||||
"""
|
||||
L{Cooperator.running} reports C{True} if the L{Cooperator}
|
||||
was started on creation.
|
||||
"""
|
||||
c = task.Cooperator()
|
||||
self.assertTrue(c.running)
|
||||
|
||||
def test_runningWhenNotStarted(self):
|
||||
"""
|
||||
L{Cooperator.running} reports C{False} if the L{Cooperator}
|
||||
has not been started.
|
||||
"""
|
||||
c = task.Cooperator(started=False)
|
||||
self.assertFalse(c.running)
|
||||
|
||||
def test_runningWhenRunning(self):
|
||||
"""
|
||||
L{Cooperator.running} reports C{True} when the L{Cooperator}
|
||||
is running.
|
||||
"""
|
||||
c = task.Cooperator(started=False)
|
||||
c.start()
|
||||
self.addCleanup(c.stop)
|
||||
self.assertTrue(c.running)
|
||||
|
||||
def test_runningWhenStopped(self):
|
||||
"""
|
||||
L{Cooperator.running} reports C{False} after the L{Cooperator}
|
||||
has been stopped.
|
||||
"""
|
||||
c = task.Cooperator(started=False)
|
||||
c.start()
|
||||
c.stop()
|
||||
self.assertFalse(c.running)
|
||||
|
||||
|
||||
class UnhandledException(Exception):
|
||||
"""
|
||||
An exception that should go unhandled.
|
||||
"""
|
||||
|
||||
|
||||
class AliasTests(unittest.TestCase):
|
||||
"""
|
||||
Integration test to verify that the global singleton aliases do what
|
||||
they're supposed to.
|
||||
"""
|
||||
|
||||
def test_cooperate(self):
|
||||
"""
|
||||
L{twisted.internet.task.cooperate} ought to run the generator that it is
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
|
||||
def doit():
|
||||
yield 1
|
||||
yield 2
|
||||
yield 3
|
||||
d.callback("yay")
|
||||
|
||||
it = doit()
|
||||
theTask = task.cooperate(it)
|
||||
self.assertIn(theTask, task._theCooperator._tasks)
|
||||
return d
|
||||
|
||||
|
||||
class RunStateTests(unittest.TestCase):
|
||||
"""
|
||||
Tests to verify the behavior of L{CooperativeTask.pause},
|
||||
L{CooperativeTask.resume}, L{CooperativeTask.stop}, exhausting the
|
||||
underlying iterator, and their interactions with each other.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a cooperator with a fake scheduler and a termination predicate
|
||||
that ensures only one unit of work will take place per tick.
|
||||
"""
|
||||
self._doDeferNext = False
|
||||
self._doStopNext = False
|
||||
self._doDieNext = False
|
||||
self.work = []
|
||||
self.scheduler = FakeScheduler()
|
||||
self.cooperator = task.Cooperator(
|
||||
scheduler=self.scheduler,
|
||||
# Always stop after one iteration of work (return a function which
|
||||
# returns a function which always returns True)
|
||||
terminationPredicateFactory=lambda: lambda: True,
|
||||
)
|
||||
self.task = self.cooperator.cooperate(self.worker())
|
||||
self.cooperator.start()
|
||||
|
||||
def worker(self):
|
||||
"""
|
||||
This is a sample generator which yields Deferreds when we are testing
|
||||
deferral and an ascending integer count otherwise.
|
||||
"""
|
||||
i = 0
|
||||
while True:
|
||||
i += 1
|
||||
if self._doDeferNext:
|
||||
self._doDeferNext = False
|
||||
d = defer.Deferred()
|
||||
self.work.append(d)
|
||||
yield d
|
||||
elif self._doStopNext:
|
||||
return
|
||||
elif self._doDieNext:
|
||||
raise UnhandledException()
|
||||
else:
|
||||
self.work.append(i)
|
||||
yield i
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Drop references to interesting parts of the fixture to allow Deferred
|
||||
errors to be noticed when things start failing.
|
||||
"""
|
||||
del self.task
|
||||
del self.scheduler
|
||||
|
||||
def deferNext(self):
|
||||
"""
|
||||
Defer the next result from my worker iterator.
|
||||
"""
|
||||
self._doDeferNext = True
|
||||
|
||||
def stopNext(self):
|
||||
"""
|
||||
Make the next result from my worker iterator be completion (raising
|
||||
StopIteration).
|
||||
"""
|
||||
self._doStopNext = True
|
||||
|
||||
def dieNext(self):
|
||||
"""
|
||||
Make the next result from my worker iterator be raising an
|
||||
L{UnhandledException}.
|
||||
"""
|
||||
|
||||
def ignoreUnhandled(failure):
|
||||
failure.trap(UnhandledException)
|
||||
return None
|
||||
|
||||
self._doDieNext = True
|
||||
|
||||
def test_pauseResume(self):
|
||||
"""
|
||||
Cooperators should stop running their tasks when they're paused, and
|
||||
start again when they're resumed.
|
||||
"""
|
||||
# first, sanity check
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [1])
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [1, 2])
|
||||
|
||||
# OK, now for real
|
||||
self.task.pause()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [1, 2])
|
||||
self.task.resume()
|
||||
# Resuming itself shoult not do any work
|
||||
self.assertEqual(self.work, [1, 2])
|
||||
self.scheduler.pump()
|
||||
# But when the scheduler rolls around again...
|
||||
self.assertEqual(self.work, [1, 2, 3])
|
||||
|
||||
def test_resumeNotPaused(self):
|
||||
"""
|
||||
L{CooperativeTask.resume} should raise a L{TaskNotPaused} exception if
|
||||
it was not paused; e.g. if L{CooperativeTask.pause} was not invoked
|
||||
more times than L{CooperativeTask.resume} on that object.
|
||||
"""
|
||||
self.assertRaises(task.NotPaused, self.task.resume)
|
||||
self.task.pause()
|
||||
self.task.resume()
|
||||
self.assertRaises(task.NotPaused, self.task.resume)
|
||||
|
||||
def test_pauseTwice(self):
|
||||
"""
|
||||
Pauses on tasks should behave like a stack. If a task is paused twice,
|
||||
it needs to be resumed twice.
|
||||
"""
|
||||
# pause once
|
||||
self.task.pause()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
# pause twice
|
||||
self.task.pause()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
# resume once (it shouldn't)
|
||||
self.task.resume()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
# resume twice (now it should go)
|
||||
self.task.resume()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [1])
|
||||
|
||||
def test_pauseWhileDeferred(self):
|
||||
"""
|
||||
C{pause()}ing a task while it is waiting on an outstanding
|
||||
L{defer.Deferred} should put the task into a state where the
|
||||
outstanding L{defer.Deferred} must be called back I{and} the task is
|
||||
C{resume}d before it will continue processing.
|
||||
"""
|
||||
self.deferNext()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 1)
|
||||
self.assertIsInstance(self.work[0], defer.Deferred)
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 1)
|
||||
self.task.pause()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 1)
|
||||
self.task.resume()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 1)
|
||||
self.work[0].callback("STUFF!")
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 2)
|
||||
self.assertEqual(self.work[1], 2)
|
||||
|
||||
def test_whenDone(self):
|
||||
"""
|
||||
L{CooperativeTask.whenDone} returns a Deferred which fires when the
|
||||
Cooperator's iterator is exhausted. It returns a new Deferred each
|
||||
time it is called; callbacks added to other invocations will not modify
|
||||
the value that subsequent invocations will fire with.
|
||||
"""
|
||||
|
||||
deferred1 = self.task.whenDone()
|
||||
deferred2 = self.task.whenDone()
|
||||
results1 = []
|
||||
results2 = []
|
||||
final1 = []
|
||||
final2 = []
|
||||
|
||||
def callbackOne(result):
|
||||
results1.append(result)
|
||||
return 1
|
||||
|
||||
def callbackTwo(result):
|
||||
results2.append(result)
|
||||
return 2
|
||||
|
||||
deferred1.addCallback(callbackOne)
|
||||
deferred2.addCallback(callbackTwo)
|
||||
|
||||
deferred1.addCallback(final1.append)
|
||||
deferred2.addCallback(final2.append)
|
||||
|
||||
# exhaust the task iterator
|
||||
# callbacks fire
|
||||
self.stopNext()
|
||||
self.scheduler.pump()
|
||||
|
||||
self.assertEqual(len(results1), 1)
|
||||
self.assertEqual(len(results2), 1)
|
||||
|
||||
self.assertIs(results1[0], self.task._iterator)
|
||||
self.assertIs(results2[0], self.task._iterator)
|
||||
|
||||
self.assertEqual(final1, [1])
|
||||
self.assertEqual(final2, [2])
|
||||
|
||||
def test_whenDoneError(self):
|
||||
"""
|
||||
L{CooperativeTask.whenDone} returns a L{defer.Deferred} that will fail
|
||||
when the iterable's C{next} method raises an exception, with that
|
||||
exception.
|
||||
"""
|
||||
deferred1 = self.task.whenDone()
|
||||
results = []
|
||||
deferred1.addErrback(results.append)
|
||||
self.dieNext()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].check(UnhandledException), UnhandledException)
|
||||
|
||||
def test_whenDoneStop(self):
|
||||
"""
|
||||
L{CooperativeTask.whenDone} returns a L{defer.Deferred} that fails with
|
||||
L{TaskStopped} when the C{stop} method is called on that
|
||||
L{CooperativeTask}.
|
||||
"""
|
||||
deferred1 = self.task.whenDone()
|
||||
errors = []
|
||||
deferred1.addErrback(errors.append)
|
||||
self.task.stop()
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertEqual(errors[0].check(task.TaskStopped), task.TaskStopped)
|
||||
|
||||
def test_whenDoneAlreadyDone(self):
|
||||
"""
|
||||
L{CooperativeTask.whenDone} will return a L{defer.Deferred} that will
|
||||
succeed immediately if its iterator has already completed.
|
||||
"""
|
||||
self.stopNext()
|
||||
self.scheduler.pump()
|
||||
results = []
|
||||
self.task.whenDone().addCallback(results.append)
|
||||
self.assertEqual(results, [self.task._iterator])
|
||||
|
||||
def test_stopStops(self):
|
||||
"""
|
||||
C{stop()}ping a task should cause it to be removed from the run just as
|
||||
C{pause()}ing, with the distinction that C{resume()} will raise a
|
||||
L{TaskStopped} exception.
|
||||
"""
|
||||
self.task.stop()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(len(self.work), 0)
|
||||
self.assertRaises(task.TaskStopped, self.task.stop)
|
||||
self.assertRaises(task.TaskStopped, self.task.pause)
|
||||
# Sanity check - it's still not scheduled, is it?
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
|
||||
def test_pauseStopResume(self):
|
||||
"""
|
||||
C{resume()}ing a paused, stopped task should be a no-op; it should not
|
||||
raise an exception, because it's paused, but neither should it actually
|
||||
do more work from the task.
|
||||
"""
|
||||
self.task.pause()
|
||||
self.task.stop()
|
||||
self.task.resume()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(self.work, [])
|
||||
|
||||
def test_stopDeferred(self):
|
||||
"""
|
||||
As a corrolary of the interaction of C{pause()} and C{unpause()},
|
||||
C{stop()}ping a task which is waiting on a L{Deferred} should cause the
|
||||
task to gracefully shut down, meaning that it should not be unpaused
|
||||
when the deferred fires.
|
||||
"""
|
||||
self.deferNext()
|
||||
self.scheduler.pump()
|
||||
d = self.work.pop()
|
||||
self.assertEqual(self.task._pauseCount, 1)
|
||||
results = []
|
||||
d.addBoth(results.append)
|
||||
self.scheduler.pump()
|
||||
self.task.stop()
|
||||
self.scheduler.pump()
|
||||
d.callback(7)
|
||||
self.scheduler.pump()
|
||||
# Let's make sure that Deferred doesn't come out fried with an
|
||||
# unhandled error that will be logged. The value is None, rather than
|
||||
# our test value, 7, because this Deferred is returned to and consumed
|
||||
# by the cooperator code. Its callback therefore has no contract.
|
||||
self.assertEqual(results, [None])
|
||||
# But more importantly, no further work should have happened.
|
||||
self.assertEqual(self.work, [])
|
||||
|
||||
def test_stopExhausted(self):
|
||||
"""
|
||||
C{stop()}ping a L{CooperativeTask} whose iterator has been exhausted
|
||||
should raise L{TaskDone}.
|
||||
"""
|
||||
self.stopNext()
|
||||
self.scheduler.pump()
|
||||
self.assertRaises(task.TaskDone, self.task.stop)
|
||||
|
||||
def test_stopErrored(self):
|
||||
"""
|
||||
C{stop()}ping a L{CooperativeTask} whose iterator has encountered an
|
||||
error should raise L{TaskFailed}.
|
||||
"""
|
||||
self.dieNext()
|
||||
self.scheduler.pump()
|
||||
self.assertRaises(task.TaskFailed, self.task.stop)
|
||||
|
||||
def test_stopCooperatorReentrancy(self):
|
||||
"""
|
||||
If a callback of a L{Deferred} from L{CooperativeTask.whenDone} calls
|
||||
C{Cooperator.stop} on its L{CooperativeTask._cooperator}, the
|
||||
L{Cooperator} will stop, but the L{CooperativeTask} whose callback is
|
||||
calling C{stop} should already be considered 'stopped' by the time the
|
||||
callback is running, and therefore removed from the
|
||||
L{CoooperativeTask}.
|
||||
"""
|
||||
callbackPhases = []
|
||||
|
||||
def stopit(result):
|
||||
callbackPhases.append(result)
|
||||
self.cooperator.stop()
|
||||
# "done" here is a sanity check to make sure that we get all the
|
||||
# way through the callback; i.e. stop() shouldn't be raising an
|
||||
# exception due to the stopped-ness of our main task.
|
||||
callbackPhases.append("done")
|
||||
|
||||
self.task.whenDone().addCallback(stopit)
|
||||
self.stopNext()
|
||||
self.scheduler.pump()
|
||||
self.assertEqual(callbackPhases, [self.task._iterator, "done"])
|
||||
4039
.venv/lib/python3.12/site-packages/twisted/test/test_defer.py
Normal file
4039
.venv/lib/python3.12/site-packages/twisted/test/test_defer.py
Normal file
File diff suppressed because it is too large
Load Diff
235
.venv/lib/python3.12/site-packages/twisted/test/test_defgen.py
Normal file
235
.venv/lib/python3.12/site-packages/twisted/test/test_defgen.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.internet.defer.deferredGenerator} and related APIs.
|
||||
"""
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.defer import Deferred, deferredGenerator, waitForDeferred
|
||||
from twisted.python.util import runWithWarningsSuppressed
|
||||
from twisted.trial import unittest
|
||||
from twisted.trial.util import suppress as SUPPRESS
|
||||
|
||||
|
||||
def getThing():
|
||||
d = Deferred()
|
||||
reactor.callLater(0, d.callback, "hi")
|
||||
return d
|
||||
|
||||
|
||||
def getOwie():
|
||||
d = Deferred()
|
||||
|
||||
def CRAP():
|
||||
d.errback(ZeroDivisionError("OMG"))
|
||||
|
||||
reactor.callLater(0, CRAP)
|
||||
return d
|
||||
|
||||
|
||||
class TerminalException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def deprecatedDeferredGenerator(f):
|
||||
"""
|
||||
Calls L{deferredGenerator} while suppressing the deprecation warning.
|
||||
|
||||
@param f: Function to call
|
||||
@return: Return value of function.
|
||||
"""
|
||||
return runWithWarningsSuppressed(
|
||||
[
|
||||
SUPPRESS(
|
||||
message="twisted.internet.defer.deferredGenerator was " "deprecated"
|
||||
)
|
||||
],
|
||||
deferredGenerator,
|
||||
f,
|
||||
)
|
||||
|
||||
|
||||
class DeferredGeneratorTests(unittest.TestCase):
|
||||
def testBasics(self):
|
||||
"""
|
||||
Test that a normal deferredGenerator works. Tests yielding a
|
||||
deferred which callbacks, as well as a deferred errbacks. Also
|
||||
ensures returning a final value works.
|
||||
"""
|
||||
|
||||
@deprecatedDeferredGenerator
|
||||
def _genBasics():
|
||||
x = waitForDeferred(getThing())
|
||||
yield x
|
||||
x = x.getResult()
|
||||
|
||||
self.assertEqual(x, "hi")
|
||||
|
||||
ow = waitForDeferred(getOwie())
|
||||
yield ow
|
||||
try:
|
||||
ow.getResult()
|
||||
except ZeroDivisionError as e:
|
||||
self.assertEqual(str(e), "OMG")
|
||||
yield "WOOSH"
|
||||
return
|
||||
|
||||
return _genBasics().addCallback(self.assertEqual, "WOOSH")
|
||||
|
||||
def testProducesException(self):
|
||||
"""
|
||||
Ensure that a generator that produces an exception signals
|
||||
a Failure condition on result deferred by converting the exception to
|
||||
a L{Failure}.
|
||||
"""
|
||||
|
||||
@deprecatedDeferredGenerator
|
||||
def _genProduceException():
|
||||
yield waitForDeferred(getThing())
|
||||
1 // 0
|
||||
|
||||
return self.assertFailure(_genProduceException(), ZeroDivisionError)
|
||||
|
||||
def testNothing(self):
|
||||
"""Test that a generator which never yields results in None."""
|
||||
|
||||
@deprecatedDeferredGenerator
|
||||
def _genNothing():
|
||||
if False:
|
||||
yield 1 # pragma: no cover
|
||||
|
||||
return _genNothing().addCallback(self.assertEqual, None)
|
||||
|
||||
def testHandledTerminalFailure(self):
|
||||
"""
|
||||
Create a Deferred Generator which yields a Deferred which fails and
|
||||
handles the exception which results. Assert that the Deferred
|
||||
Generator does not errback its Deferred.
|
||||
"""
|
||||
|
||||
@deprecatedDeferredGenerator
|
||||
def _genHandledTerminalFailure():
|
||||
x = waitForDeferred(
|
||||
defer.fail(TerminalException("Handled Terminal Failure"))
|
||||
)
|
||||
yield x
|
||||
try:
|
||||
x.getResult()
|
||||
except TerminalException:
|
||||
pass
|
||||
|
||||
return _genHandledTerminalFailure().addCallback(self.assertEqual, None)
|
||||
|
||||
def testHandledTerminalAsyncFailure(self):
|
||||
"""
|
||||
Just like testHandledTerminalFailure, only with a Deferred which fires
|
||||
asynchronously with an error.
|
||||
"""
|
||||
|
||||
@deprecatedDeferredGenerator
|
||||
def _genHandledTerminalAsyncFailure(d):
|
||||
x = waitForDeferred(d)
|
||||
yield x
|
||||
try:
|
||||
x.getResult()
|
||||
except TerminalException:
|
||||
pass
|
||||
|
||||
d = defer.Deferred()
|
||||
deferredGeneratorResultDeferred = _genHandledTerminalAsyncFailure(d)
|
||||
d.errback(TerminalException("Handled Terminal Failure"))
|
||||
return deferredGeneratorResultDeferred.addCallback(self.assertEqual, None)
|
||||
|
||||
def testStackUsage(self):
|
||||
"""
|
||||
Make sure we don't blow the stack when yielding immediately
|
||||
available deferreds.
|
||||
"""
|
||||
|
||||
@deprecatedDeferredGenerator
|
||||
def _genStackUsage():
|
||||
for x in range(5000):
|
||||
# Test with yielding a deferred
|
||||
x = waitForDeferred(defer.succeed(1))
|
||||
yield x
|
||||
x = x.getResult()
|
||||
yield 0
|
||||
|
||||
return _genStackUsage().addCallback(self.assertEqual, 0)
|
||||
|
||||
def testStackUsage2(self):
|
||||
"""
|
||||
Make sure we don't blow the stack when yielding immediately
|
||||
available values.
|
||||
"""
|
||||
|
||||
@deprecatedDeferredGenerator
|
||||
def _genStackUsage2():
|
||||
for x in range(5000):
|
||||
# Test with yielding a random value
|
||||
yield 1
|
||||
yield 0
|
||||
|
||||
return _genStackUsage2().addCallback(self.assertEqual, 0)
|
||||
|
||||
def testDeferredYielding(self):
|
||||
"""
|
||||
Ensure that yielding a Deferred directly is trapped as an
|
||||
error.
|
||||
"""
|
||||
|
||||
# See the comment _deferGenerator about d.callback(Deferred).
|
||||
def _genDeferred():
|
||||
yield getThing()
|
||||
|
||||
_genDeferred = deprecatedDeferredGenerator(_genDeferred)
|
||||
|
||||
return self.assertFailure(_genDeferred(), TypeError)
|
||||
|
||||
suppress = [
|
||||
SUPPRESS(message="twisted.internet.defer.waitForDeferred was " "deprecated")
|
||||
]
|
||||
|
||||
|
||||
class DeprecateDeferredGeneratorTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests that L{DeferredGeneratorTests} and L{waitForDeferred} are
|
||||
deprecated.
|
||||
"""
|
||||
|
||||
def test_deferredGeneratorDeprecated(self):
|
||||
"""
|
||||
L{deferredGenerator} is deprecated.
|
||||
"""
|
||||
|
||||
@deferredGenerator
|
||||
def decoratedFunction():
|
||||
yield None
|
||||
|
||||
warnings = self.flushWarnings([self.test_deferredGeneratorDeprecated])
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(warnings[0]["category"], DeprecationWarning)
|
||||
self.assertEqual(
|
||||
warnings[0]["message"],
|
||||
"twisted.internet.defer.deferredGenerator was deprecated in "
|
||||
"Twisted 15.0.0; please use "
|
||||
"twisted.internet.defer.inlineCallbacks instead",
|
||||
)
|
||||
|
||||
def test_waitForDeferredDeprecated(self):
|
||||
"""
|
||||
L{waitForDeferred} is deprecated.
|
||||
"""
|
||||
d = Deferred()
|
||||
waitForDeferred(d)
|
||||
|
||||
warnings = self.flushWarnings([self.test_waitForDeferredDeprecated])
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(warnings[0]["category"], DeprecationWarning)
|
||||
self.assertEqual(
|
||||
warnings[0]["message"],
|
||||
"twisted.internet.defer.waitForDeferred was deprecated in "
|
||||
"Twisted 15.0.0; please use "
|
||||
"twisted.internet.defer.inlineCallbacks instead",
|
||||
)
|
||||
218
.venv/lib/python3.12/site-packages/twisted/test/test_dirdbm.py
Normal file
218
.venv/lib/python3.12/site-packages/twisted/test/test_dirdbm.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for dirdbm module.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from base64 import b64decode
|
||||
|
||||
from twisted.persisted import dirdbm
|
||||
from twisted.python import rebuild
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class DirDbmTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.path = FilePath(self.mktemp())
|
||||
self.dbm = dirdbm.open(self.path.path)
|
||||
self.items: tuple[tuple[bytes, bytes | int | float | tuple[None, int]], ...] = (
|
||||
(b"abc", b"foo"),
|
||||
(b"/lalal", b"\000\001"),
|
||||
(b"\000\012", b"baz"),
|
||||
)
|
||||
|
||||
def test_all(self) -> None:
|
||||
k = b64decode("//==")
|
||||
self.dbm[k] = b"a"
|
||||
self.dbm[k] = b"a"
|
||||
self.assertEqual(self.dbm[k], b"a")
|
||||
|
||||
def test_rebuildInteraction(self) -> None:
|
||||
s = dirdbm.Shelf("dirdbm.rebuild.test")
|
||||
s[b"key"] = b"value"
|
||||
rebuild.rebuild(dirdbm)
|
||||
|
||||
def test_dbm(self) -> None:
|
||||
d = self.dbm
|
||||
|
||||
# Insert keys
|
||||
keys = []
|
||||
values = set()
|
||||
for k, v in self.items:
|
||||
d[k] = v
|
||||
keys.append(k)
|
||||
values.add(v)
|
||||
keys.sort()
|
||||
|
||||
# Check they exist
|
||||
for k, v in self.items:
|
||||
self.assertIn(k, d)
|
||||
self.assertEqual(d[k], v)
|
||||
|
||||
# Check non existent key
|
||||
try:
|
||||
d[b"XXX"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
assert 0, "didn't raise KeyError on non-existent key"
|
||||
|
||||
# Check keys(), values() and items()
|
||||
dbkeys = d.keys()
|
||||
dbvalues = set(d.values())
|
||||
dbitems = set(d.items())
|
||||
dbkeys.sort()
|
||||
items = set(self.items)
|
||||
self.assertEqual(
|
||||
keys,
|
||||
dbkeys,
|
||||
f".keys() output didn't match: {repr(keys)} != {repr(dbkeys)}",
|
||||
)
|
||||
self.assertEqual(
|
||||
values,
|
||||
dbvalues,
|
||||
".values() output didn't match: {} != {}".format(
|
||||
repr(values), repr(dbvalues)
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
items,
|
||||
dbitems,
|
||||
f"items() didn't match: {repr(items)} != {repr(dbitems)}",
|
||||
)
|
||||
|
||||
copyPath = self.mktemp()
|
||||
d2 = d.copyTo(copyPath)
|
||||
|
||||
copykeys = d.keys()
|
||||
copyvalues = set(d.values())
|
||||
copyitems = set(d.items())
|
||||
copykeys.sort()
|
||||
|
||||
self.assertEqual(
|
||||
dbkeys,
|
||||
copykeys,
|
||||
".copyTo().keys() didn't match: {} != {}".format(
|
||||
repr(dbkeys), repr(copykeys)
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
dbvalues,
|
||||
copyvalues,
|
||||
".copyTo().values() didn't match: %s != %s"
|
||||
% (repr(dbvalues), repr(copyvalues)),
|
||||
)
|
||||
self.assertEqual(
|
||||
dbitems,
|
||||
copyitems,
|
||||
".copyTo().items() didn't match: %s != %s"
|
||||
% (repr(dbkeys), repr(copyitems)),
|
||||
)
|
||||
|
||||
d2.clear()
|
||||
self.assertTrue(
|
||||
len(d2.keys()) == len(d2.values()) == len(d2.items()) == len(d2) == 0,
|
||||
".clear() failed",
|
||||
)
|
||||
self.assertNotEqual(len(d), len(d2))
|
||||
shutil.rmtree(copyPath)
|
||||
|
||||
# Delete items
|
||||
for k, v in self.items:
|
||||
del d[k]
|
||||
self.assertNotIn(
|
||||
k, d, "key is still in database, even though we deleted it"
|
||||
)
|
||||
self.assertEqual(len(d.keys()), 0, "database has keys")
|
||||
self.assertEqual(len(d.values()), 0, "database has values")
|
||||
self.assertEqual(len(d.items()), 0, "database has items")
|
||||
self.assertEqual(len(d), 0, "database has items")
|
||||
|
||||
def test_modificationTime(self) -> None:
|
||||
import time
|
||||
|
||||
# The mtime value for files comes from a different place than the
|
||||
# gettimeofday() system call. On linux, gettimeofday() can be
|
||||
# slightly ahead (due to clock drift which gettimeofday() takes into
|
||||
# account but which open()/write()/close() do not), and if we are
|
||||
# close to the edge of the next second, time.time() can give a value
|
||||
# which is larger than the mtime which results from a subsequent
|
||||
# write(). I consider this a kernel bug, but it is beyond the scope
|
||||
# of this test. Thus we keep the range of acceptability to 3 seconds time.
|
||||
# -warner
|
||||
self.dbm[b"k"] = b"v"
|
||||
self.assertTrue(abs(time.time() - self.dbm.getModificationTime(b"k")) <= 3)
|
||||
self.assertRaises(KeyError, self.dbm.getModificationTime, b"nokey")
|
||||
|
||||
def test_recovery(self) -> None:
|
||||
"""
|
||||
DirDBM: test recovery from directory after a faked crash
|
||||
"""
|
||||
k = self.dbm._encode(b"key1")
|
||||
with self.path.child(k + b".rpl").open(mode="w") as f:
|
||||
f.write(b"value")
|
||||
|
||||
k2 = self.dbm._encode(b"key2")
|
||||
with self.path.child(k2).open(mode="w") as f:
|
||||
f.write(b"correct")
|
||||
with self.path.child(k2 + b".rpl").open(mode="w") as f:
|
||||
f.write(b"wrong")
|
||||
|
||||
with self.path.child("aa.new").open(mode="w") as f:
|
||||
f.write(b"deleted")
|
||||
|
||||
dbm = dirdbm.DirDBM(self.path.path)
|
||||
self.assertEqual(dbm[b"key1"], b"value")
|
||||
self.assertEqual(dbm[b"key2"], b"correct")
|
||||
self.assertFalse(self.path.globChildren("*.new"))
|
||||
self.assertFalse(self.path.globChildren("*.rpl"))
|
||||
|
||||
def test_nonStringKeys(self) -> None:
|
||||
"""
|
||||
L{dirdbm.DirDBM} operations only support string keys: other types
|
||||
should raise a L{TypeError}.
|
||||
"""
|
||||
self.assertRaises(TypeError, self.dbm.__setitem__, 2, "3")
|
||||
try:
|
||||
self.assertRaises(TypeError, self.dbm.__setitem__, "2", 3)
|
||||
except unittest.FailTest:
|
||||
# dirdbm.Shelf.__setitem__ supports non-string values
|
||||
self.assertIsInstance(self.dbm, dirdbm.Shelf)
|
||||
self.assertRaises(TypeError, self.dbm.__getitem__, 2)
|
||||
self.assertRaises(TypeError, self.dbm.__delitem__, 2)
|
||||
self.assertRaises(TypeError, self.dbm.has_key, 2)
|
||||
self.assertRaises(TypeError, self.dbm.__contains__, 2)
|
||||
self.assertRaises(TypeError, self.dbm.getModificationTime, 2)
|
||||
|
||||
def test_failSet(self) -> None:
|
||||
"""
|
||||
Failure path when setting an item.
|
||||
"""
|
||||
|
||||
def _writeFail(path: FilePath[str], data: bytes) -> None:
|
||||
path.setContent(data)
|
||||
raise OSError("fail to write")
|
||||
|
||||
self.dbm[b"failkey"] = b"test"
|
||||
self.patch(self.dbm, "_writeFile", _writeFail)
|
||||
self.assertRaises(IOError, self.dbm.__setitem__, b"failkey", b"test2")
|
||||
|
||||
|
||||
class ShelfTests(DirDbmTests):
|
||||
def setUp(self) -> None:
|
||||
self.path = FilePath(self.mktemp())
|
||||
self.dbm = dirdbm.Shelf(self.path.path)
|
||||
self.items = (
|
||||
(b"abc", b"foo"),
|
||||
(b"/lalal", b"\000\001"),
|
||||
(b"\000\012", b"baz"),
|
||||
(b"int", 12),
|
||||
(b"float", 12.0),
|
||||
(b"tuple", (None, 12)),
|
||||
)
|
||||
|
||||
|
||||
testCases = [DirDbmTests, ShelfTests]
|
||||
290
.venv/lib/python3.12/site-packages/twisted/test/test_error.py
Normal file
290
.venv/lib/python3.12/site-packages/twisted/test/test_error.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket
|
||||
import sys
|
||||
from typing import Sequence
|
||||
|
||||
from twisted.internet import error
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class StringificationTests(unittest.SynchronousTestCase):
|
||||
"""Test that the exceptions have useful stringifications."""
|
||||
|
||||
listOfTests: list[
|
||||
tuple[
|
||||
str,
|
||||
type[Exception],
|
||||
Sequence[str | int | Exception | None],
|
||||
dict[str, str | int],
|
||||
]
|
||||
] = [
|
||||
# (output, exception[, args[, kwargs]]),
|
||||
("An error occurred binding to an interface.", error.BindError, [], {}),
|
||||
(
|
||||
"An error occurred binding to an interface: foo.",
|
||||
error.BindError,
|
||||
["foo"],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"An error occurred binding to an interface: foo bar.",
|
||||
error.BindError,
|
||||
["foo", "bar"],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"Couldn't listen on eth0:4242: Foo.",
|
||||
error.CannotListenError,
|
||||
("eth0", 4242, socket.error("Foo")),
|
||||
{},
|
||||
),
|
||||
("Message is too long to send.", error.MessageLengthError, [], {}),
|
||||
(
|
||||
"Message is too long to send: foo bar.",
|
||||
error.MessageLengthError,
|
||||
["foo", "bar"],
|
||||
{},
|
||||
),
|
||||
("DNS lookup failed.", error.DNSLookupError, [], {}),
|
||||
("DNS lookup failed: foo bar.", error.DNSLookupError, ["foo", "bar"], {}),
|
||||
("An error occurred while connecting.", error.ConnectError, [], {}),
|
||||
(
|
||||
"An error occurred while connecting: someOsError.",
|
||||
error.ConnectError,
|
||||
["someOsError"],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"An error occurred while connecting: foo.",
|
||||
error.ConnectError,
|
||||
[],
|
||||
{"string": "foo"},
|
||||
),
|
||||
(
|
||||
"An error occurred while connecting: someOsError: foo.",
|
||||
error.ConnectError,
|
||||
["someOsError", "foo"],
|
||||
{},
|
||||
),
|
||||
("Couldn't bind.", error.ConnectBindError, [], {}),
|
||||
("Couldn't bind: someOsError.", error.ConnectBindError, ["someOsError"], {}),
|
||||
(
|
||||
"Couldn't bind: someOsError: foo.",
|
||||
error.ConnectBindError,
|
||||
["someOsError", "foo"],
|
||||
{},
|
||||
),
|
||||
("Hostname couldn't be looked up.", error.UnknownHostError, [], {}),
|
||||
("No route to host.", error.NoRouteError, [], {}),
|
||||
("Connection was refused by other side.", error.ConnectionRefusedError, [], {}),
|
||||
("TCP connection timed out.", error.TCPTimedOutError, [], {}),
|
||||
("File used for UNIX socket is no good.", error.BadFileError, [], {}),
|
||||
(
|
||||
"Service name given as port is unknown.",
|
||||
error.ServiceNameUnknownError,
|
||||
[],
|
||||
{},
|
||||
),
|
||||
("User aborted connection.", error.UserError, [], {}),
|
||||
("User timeout caused connection failure.", error.TimeoutError, [], {}),
|
||||
("An SSL error occurred.", error.SSLError, [], {}),
|
||||
(
|
||||
"Connection to the other side was lost in a non-clean fashion.",
|
||||
error.ConnectionLost,
|
||||
[],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"Connection to the other side was lost in a non-clean fashion: foo bar.",
|
||||
error.ConnectionLost,
|
||||
["foo", "bar"],
|
||||
{},
|
||||
),
|
||||
("Connection was closed cleanly.", error.ConnectionDone, [], {}),
|
||||
(
|
||||
"Connection was closed cleanly: foo bar.",
|
||||
error.ConnectionDone,
|
||||
["foo", "bar"],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"Uh.", # TODO nice docstring, you've got there.
|
||||
error.ConnectionFdescWentAway,
|
||||
[],
|
||||
{},
|
||||
),
|
||||
("Tried to cancel an already-called event.", error.AlreadyCalled, [], {}),
|
||||
(
|
||||
"Tried to cancel an already-called event: foo bar.",
|
||||
error.AlreadyCalled,
|
||||
["foo", "bar"],
|
||||
{},
|
||||
),
|
||||
("Tried to cancel an already-cancelled event.", error.AlreadyCancelled, [], {}),
|
||||
(
|
||||
"Tried to cancel an already-cancelled event: x 2.",
|
||||
error.AlreadyCancelled,
|
||||
["x", "2"],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"A process has ended without apparent errors: process finished with exit code 0.",
|
||||
error.ProcessDone,
|
||||
[None],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"A process has ended with a probable error condition: process ended.",
|
||||
error.ProcessTerminated,
|
||||
[],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"A process has ended with a probable error condition: process ended with exit code 42.",
|
||||
error.ProcessTerminated,
|
||||
[],
|
||||
{"exitCode": 42},
|
||||
),
|
||||
(
|
||||
"A process has ended with a probable error condition: process ended by signal SIGBUS.",
|
||||
error.ProcessTerminated,
|
||||
[],
|
||||
{"signal": "SIGBUS"},
|
||||
),
|
||||
(
|
||||
"The Connector was not connecting when it was asked to stop connecting.",
|
||||
error.NotConnectingError,
|
||||
[],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"The Connector was not connecting when it was asked to stop connecting: x 13.",
|
||||
error.NotConnectingError,
|
||||
["x", "13"],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"The Port was not listening when it was asked to stop listening.",
|
||||
error.NotListeningError,
|
||||
[],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"The Port was not listening when it was asked to stop listening: a 12.",
|
||||
error.NotListeningError,
|
||||
["a", "12"],
|
||||
{},
|
||||
),
|
||||
]
|
||||
|
||||
def testThemAll(self) -> None:
|
||||
for entry in self.listOfTests:
|
||||
output = entry[0]
|
||||
exception = entry[1]
|
||||
args = entry[2]
|
||||
kwargs = entry[3]
|
||||
|
||||
self.assertEqual(str(exception(*args, **kwargs)), output)
|
||||
|
||||
def test_connectingCancelledError(self) -> None:
|
||||
"""
|
||||
L{error.ConnectingCancelledError} has an C{address} attribute.
|
||||
"""
|
||||
address = object()
|
||||
e = error.ConnectingCancelledError(address)
|
||||
self.assertIs(e.address, address)
|
||||
|
||||
|
||||
class SubclassingTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Some exceptions are subclasses of other exceptions.
|
||||
"""
|
||||
|
||||
def test_connectionLostSubclassOfConnectionClosed(self) -> None:
|
||||
"""
|
||||
L{error.ConnectionClosed} is a superclass of L{error.ConnectionLost}.
|
||||
"""
|
||||
self.assertTrue(issubclass(error.ConnectionLost, error.ConnectionClosed))
|
||||
|
||||
def test_connectionDoneSubclassOfConnectionClosed(self) -> None:
|
||||
"""
|
||||
L{error.ConnectionClosed} is a superclass of L{error.ConnectionDone}.
|
||||
"""
|
||||
self.assertTrue(issubclass(error.ConnectionDone, error.ConnectionClosed))
|
||||
|
||||
def test_invalidAddressErrorSubclassOfValueError(self) -> None:
|
||||
"""
|
||||
L{ValueError} is a superclass of L{error.InvalidAddressError}.
|
||||
"""
|
||||
self.assertTrue(issubclass(error.InvalidAddressError, ValueError))
|
||||
|
||||
|
||||
class GetConnectErrorTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Given an exception instance thrown by C{socket.connect},
|
||||
L{error.getConnectError} returns the appropriate high-level Twisted
|
||||
exception instance.
|
||||
"""
|
||||
|
||||
def assertErrnoException(
|
||||
self, errno: int, expectedClass: type[error.ConnectError]
|
||||
) -> None:
|
||||
"""
|
||||
When called with a tuple with the given errno,
|
||||
L{error.getConnectError} returns an exception which is an instance of
|
||||
the expected class.
|
||||
"""
|
||||
e = (errno, "lalala")
|
||||
result = error.getConnectError(e)
|
||||
self.assertCorrectException(errno, "lalala", result, expectedClass)
|
||||
|
||||
def assertCorrectException(
|
||||
self,
|
||||
errno: int | None,
|
||||
message: object,
|
||||
result: error.ConnectError,
|
||||
expectedClass: type[error.ConnectError],
|
||||
) -> None:
|
||||
"""
|
||||
The given result of L{error.getConnectError} has the given attributes
|
||||
(C{osError} and C{args}), and is an instance of the given class.
|
||||
"""
|
||||
|
||||
# Want exact class match, not inherited classes, so no isinstance():
|
||||
self.assertEqual(result.__class__, expectedClass)
|
||||
self.assertEqual(result.osError, errno)
|
||||
self.assertEqual(result.args, (message,))
|
||||
|
||||
def test_errno(self) -> None:
|
||||
"""
|
||||
L{error.getConnectError} converts based on errno for C{socket.error}.
|
||||
"""
|
||||
self.assertErrnoException(errno.ENETUNREACH, error.NoRouteError)
|
||||
self.assertErrnoException(errno.ECONNREFUSED, error.ConnectionRefusedError)
|
||||
self.assertErrnoException(errno.ETIMEDOUT, error.TCPTimedOutError)
|
||||
if sys.platform == "win32":
|
||||
self.assertErrnoException(
|
||||
errno.WSAECONNREFUSED, error.ConnectionRefusedError
|
||||
)
|
||||
self.assertErrnoException(errno.WSAENETUNREACH, error.NoRouteError)
|
||||
|
||||
def test_gaierror(self) -> None:
|
||||
"""
|
||||
L{error.getConnectError} converts to a L{error.UnknownHostError} given
|
||||
a C{socket.gaierror} instance.
|
||||
"""
|
||||
result = error.getConnectError(socket.gaierror(12, "hello"))
|
||||
self.assertCorrectException(12, "hello", result, error.UnknownHostError)
|
||||
|
||||
def test_nonTuple(self) -> None:
|
||||
"""
|
||||
L{error.getConnectError} converts to a L{error.ConnectError} given
|
||||
an argument that cannot be unpacked.
|
||||
"""
|
||||
e = Exception()
|
||||
result = error.getConnectError(e)
|
||||
self.assertCorrectException(None, e, result, error.ConnectError)
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test code for basic Factory classes.
|
||||
"""
|
||||
|
||||
|
||||
import pickle
|
||||
|
||||
from twisted.internet.protocol import Protocol, ReconnectingClientFactory
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class FakeConnector:
|
||||
"""
|
||||
A fake connector class, to be used to mock connections failed or lost.
|
||||
"""
|
||||
|
||||
def stopConnecting(self):
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
pass
|
||||
|
||||
|
||||
class ReconnectingFactoryTests(TestCase):
|
||||
"""
|
||||
Tests for L{ReconnectingClientFactory}.
|
||||
"""
|
||||
|
||||
def test_stopTryingWhenConnected(self):
|
||||
"""
|
||||
If a L{ReconnectingClientFactory} has C{stopTrying} called while it is
|
||||
connected, it does not subsequently attempt to reconnect if the
|
||||
connection is later lost.
|
||||
"""
|
||||
|
||||
class NoConnectConnector:
|
||||
def stopConnecting(self):
|
||||
raise RuntimeError("Shouldn't be called, we're connected.")
|
||||
|
||||
def connect(self):
|
||||
raise RuntimeError("Shouldn't be reconnecting.")
|
||||
|
||||
c = ReconnectingClientFactory()
|
||||
c.protocol = Protocol
|
||||
# Let's pretend we've connected:
|
||||
c.buildProtocol(None)
|
||||
# Now we stop trying, then disconnect:
|
||||
c.stopTrying()
|
||||
c.clientConnectionLost(NoConnectConnector(), None)
|
||||
self.assertFalse(c.continueTrying)
|
||||
|
||||
def test_stopTryingDoesNotReconnect(self):
|
||||
"""
|
||||
Calling stopTrying on a L{ReconnectingClientFactory} doesn't attempt a
|
||||
retry on any active connector.
|
||||
"""
|
||||
|
||||
class FactoryAwareFakeConnector(FakeConnector):
|
||||
attemptedRetry = False
|
||||
|
||||
def stopConnecting(self):
|
||||
"""
|
||||
Behave as though an ongoing connection attempt has now
|
||||
failed, and notify the factory of this.
|
||||
"""
|
||||
f.clientConnectionFailed(self, None)
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
Record an attempt to reconnect, since this is what we
|
||||
are trying to avoid.
|
||||
"""
|
||||
self.attemptedRetry = True
|
||||
|
||||
f = ReconnectingClientFactory()
|
||||
f.clock = Clock()
|
||||
|
||||
# simulate an active connection - stopConnecting on this connector should
|
||||
# be triggered when we call stopTrying
|
||||
f.connector = FactoryAwareFakeConnector()
|
||||
f.stopTrying()
|
||||
|
||||
# make sure we never attempted to retry
|
||||
self.assertFalse(f.connector.attemptedRetry)
|
||||
self.assertFalse(f.clock.getDelayedCalls())
|
||||
|
||||
def test_serializeUnused(self):
|
||||
"""
|
||||
A L{ReconnectingClientFactory} which hasn't been used for anything
|
||||
can be pickled and unpickled and end up with the same state.
|
||||
"""
|
||||
original = ReconnectingClientFactory()
|
||||
reconstituted = pickle.loads(pickle.dumps(original))
|
||||
self.assertEqual(original.__dict__, reconstituted.__dict__)
|
||||
|
||||
def test_serializeWithClock(self):
|
||||
"""
|
||||
The clock attribute of L{ReconnectingClientFactory} is not serialized,
|
||||
and the restored value sets it to the default value, the reactor.
|
||||
"""
|
||||
clock = Clock()
|
||||
original = ReconnectingClientFactory()
|
||||
original.clock = clock
|
||||
reconstituted = pickle.loads(pickle.dumps(original))
|
||||
self.assertIsNone(reconstituted.clock)
|
||||
|
||||
def test_deserializationResetsParameters(self):
|
||||
"""
|
||||
A L{ReconnectingClientFactory} which is unpickled does not have an
|
||||
L{IConnector} and has its reconnecting timing parameters reset to their
|
||||
initial values.
|
||||
"""
|
||||
factory = ReconnectingClientFactory()
|
||||
factory.clientConnectionFailed(FakeConnector(), None)
|
||||
self.addCleanup(factory.stopTrying)
|
||||
|
||||
serialized = pickle.dumps(factory)
|
||||
unserialized = pickle.loads(serialized)
|
||||
self.assertIsNone(unserialized.connector)
|
||||
self.assertIsNone(unserialized._callID)
|
||||
self.assertEqual(unserialized.retries, 0)
|
||||
self.assertEqual(unserialized.delay, factory.initialDelay)
|
||||
self.assertTrue(unserialized.continueTrying)
|
||||
|
||||
def test_parametrizedClock(self):
|
||||
"""
|
||||
The clock used by L{ReconnectingClientFactory} can be parametrized, so
|
||||
that one can cleanly test reconnections.
|
||||
"""
|
||||
clock = Clock()
|
||||
factory = ReconnectingClientFactory()
|
||||
factory.clock = clock
|
||||
|
||||
factory.clientConnectionLost(FakeConnector(), None)
|
||||
self.assertEqual(len(clock.calls), 1)
|
||||
1035
.venv/lib/python3.12/site-packages/twisted/test/test_failure.py
Normal file
1035
.venv/lib/python3.12/site-packages/twisted/test/test_failure.py
Normal file
File diff suppressed because it is too large
Load Diff
258
.venv/lib/python3.12/site-packages/twisted/test/test_fdesc.py
Normal file
258
.venv/lib/python3.12/site-packages/twisted/test/test_fdesc.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.internet.fdesc}.
|
||||
"""
|
||||
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError:
|
||||
skip = "not supported on this platform"
|
||||
else:
|
||||
from twisted.internet import fdesc
|
||||
|
||||
from twisted.python.util import untilConcludes
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class NonBlockingTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{fdesc.setNonBlocking} and L{fdesc.setBlocking}.
|
||||
"""
|
||||
|
||||
def test_setNonBlocking(self):
|
||||
"""
|
||||
L{fdesc.setNonBlocking} sets a file description to non-blocking.
|
||||
"""
|
||||
r, w = os.pipe()
|
||||
self.addCleanup(os.close, r)
|
||||
self.addCleanup(os.close, w)
|
||||
self.assertFalse(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
|
||||
fdesc.setNonBlocking(r)
|
||||
self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
|
||||
|
||||
def test_setBlocking(self):
|
||||
"""
|
||||
L{fdesc.setBlocking} sets a file description to blocking.
|
||||
"""
|
||||
r, w = os.pipe()
|
||||
self.addCleanup(os.close, r)
|
||||
self.addCleanup(os.close, w)
|
||||
fdesc.setNonBlocking(r)
|
||||
fdesc.setBlocking(r)
|
||||
self.assertFalse(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
|
||||
|
||||
|
||||
class ReadWriteTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{fdesc.readFromFD}, L{fdesc.writeToFD}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a non-blocking pipe that can be used in tests.
|
||||
"""
|
||||
self.r, self.w = os.pipe()
|
||||
fdesc.setNonBlocking(self.r)
|
||||
fdesc.setNonBlocking(self.w)
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Close pipes.
|
||||
"""
|
||||
try:
|
||||
os.close(self.w)
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
os.close(self.r)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def write(self, d):
|
||||
"""
|
||||
Write data to the pipe.
|
||||
"""
|
||||
return fdesc.writeToFD(self.w, d)
|
||||
|
||||
def read(self):
|
||||
"""
|
||||
Read data from the pipe.
|
||||
"""
|
||||
l = []
|
||||
res = fdesc.readFromFD(self.r, l.append)
|
||||
if res is None:
|
||||
if l:
|
||||
return l[0]
|
||||
else:
|
||||
return b""
|
||||
else:
|
||||
return res
|
||||
|
||||
def test_writeAndRead(self):
|
||||
"""
|
||||
Test that the number of bytes L{fdesc.writeToFD} reports as written
|
||||
with its return value are seen by L{fdesc.readFromFD}.
|
||||
"""
|
||||
n = self.write(b"hello")
|
||||
self.assertTrue(n > 0)
|
||||
s = self.read()
|
||||
self.assertEqual(len(s), n)
|
||||
self.assertEqual(b"hello"[:n], s)
|
||||
|
||||
def test_writeAndReadLarge(self):
|
||||
"""
|
||||
Similar to L{test_writeAndRead}, but use a much larger string to verify
|
||||
the behavior for that case.
|
||||
"""
|
||||
orig = b"0123456879" * 10000
|
||||
written = self.write(orig)
|
||||
self.assertTrue(written > 0)
|
||||
result = []
|
||||
resultlength = 0
|
||||
i = 0
|
||||
while resultlength < written or i < 50:
|
||||
result.append(self.read())
|
||||
resultlength += len(result[-1])
|
||||
# Increment a counter to be sure we'll exit at some point
|
||||
i += 1
|
||||
result = b"".join(result)
|
||||
self.assertEqual(len(result), written)
|
||||
self.assertEqual(orig[:written], result)
|
||||
|
||||
def test_readFromEmpty(self):
|
||||
"""
|
||||
Verify that reading from a file descriptor with no data does not raise
|
||||
an exception and does not result in the callback function being called.
|
||||
"""
|
||||
l = []
|
||||
result = fdesc.readFromFD(self.r, l.append)
|
||||
self.assertEqual(l, [])
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_readFromCleanClose(self):
|
||||
"""
|
||||
Test that using L{fdesc.readFromFD} on a cleanly closed file descriptor
|
||||
returns a connection done indicator.
|
||||
"""
|
||||
os.close(self.w)
|
||||
self.assertEqual(self.read(), fdesc.CONNECTION_DONE)
|
||||
|
||||
def test_writeToClosed(self):
|
||||
"""
|
||||
Verify that writing with L{fdesc.writeToFD} when the read end is closed
|
||||
results in a connection lost indicator.
|
||||
"""
|
||||
os.close(self.r)
|
||||
self.assertEqual(self.write(b"s"), fdesc.CONNECTION_LOST)
|
||||
|
||||
def test_readFromInvalid(self):
|
||||
"""
|
||||
Verify that reading with L{fdesc.readFromFD} when the read end is
|
||||
closed results in a connection lost indicator.
|
||||
"""
|
||||
os.close(self.r)
|
||||
self.assertEqual(self.read(), fdesc.CONNECTION_LOST)
|
||||
|
||||
def test_writeToInvalid(self):
|
||||
"""
|
||||
Verify that writing with L{fdesc.writeToFD} when the write end is
|
||||
closed results in a connection lost indicator.
|
||||
"""
|
||||
os.close(self.w)
|
||||
self.assertEqual(self.write(b"s"), fdesc.CONNECTION_LOST)
|
||||
|
||||
def test_writeErrors(self):
|
||||
"""
|
||||
Test error path for L{fdesc.writeTod}.
|
||||
"""
|
||||
oldOsWrite = os.write
|
||||
|
||||
def eagainWrite(fd, data):
|
||||
err = OSError()
|
||||
err.errno = errno.EAGAIN
|
||||
raise err
|
||||
|
||||
os.write = eagainWrite
|
||||
try:
|
||||
self.assertEqual(self.write(b"s"), 0)
|
||||
finally:
|
||||
os.write = oldOsWrite
|
||||
|
||||
def eintrWrite(fd, data):
|
||||
err = OSError()
|
||||
err.errno = errno.EINTR
|
||||
raise err
|
||||
|
||||
os.write = eintrWrite
|
||||
try:
|
||||
self.assertEqual(self.write(b"s"), 0)
|
||||
finally:
|
||||
os.write = oldOsWrite
|
||||
|
||||
|
||||
class CloseOnExecTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{fdesc._setCloseOnExec} and L{fdesc._unsetCloseOnExec}.
|
||||
"""
|
||||
|
||||
program = """
|
||||
import os, errno
|
||||
try:
|
||||
os.write(%d, b'lul')
|
||||
except OSError as e:
|
||||
if e.errno == errno.EBADF:
|
||||
os._exit(0)
|
||||
os._exit(5)
|
||||
except BaseException:
|
||||
os._exit(10)
|
||||
else:
|
||||
os._exit(20)
|
||||
"""
|
||||
|
||||
def _execWithFileDescriptor(self, fObj):
|
||||
pid = os.fork()
|
||||
if pid == 0:
|
||||
try:
|
||||
os.execv(
|
||||
sys.executable,
|
||||
[sys.executable, "-c", self.program % (fObj.fileno(),)],
|
||||
)
|
||||
except BaseException:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
os._exit(30)
|
||||
else:
|
||||
# On Linux wait(2) doesn't seem ever able to fail with EINTR but
|
||||
# POSIX seems to allow it and on macOS it happens quite a lot.
|
||||
return untilConcludes(os.waitpid, pid, 0)[1]
|
||||
|
||||
def test_setCloseOnExec(self):
|
||||
"""
|
||||
A file descriptor passed to L{fdesc._setCloseOnExec} is not inherited
|
||||
by a new process image created with one of the exec family of
|
||||
functions.
|
||||
"""
|
||||
with open(self.mktemp(), "wb") as fObj:
|
||||
fdesc._setCloseOnExec(fObj.fileno())
|
||||
status = self._execWithFileDescriptor(fObj)
|
||||
self.assertTrue(os.WIFEXITED(status))
|
||||
self.assertEqual(os.WEXITSTATUS(status), 0)
|
||||
|
||||
def test_unsetCloseOnExec(self):
|
||||
"""
|
||||
A file descriptor passed to L{fdesc._unsetCloseOnExec} is inherited by
|
||||
a new process image created with one of the exec family of functions.
|
||||
"""
|
||||
with open(self.mktemp(), "wb") as fObj:
|
||||
fdesc._setCloseOnExec(fObj.fileno())
|
||||
fdesc._unsetCloseOnExec(fObj.fileno())
|
||||
status = self._execWithFileDescriptor(fObj)
|
||||
self.assertTrue(os.WIFEXITED(status))
|
||||
self.assertEqual(os.WEXITSTATUS(status), 20)
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.protocols.finger}.
|
||||
"""
|
||||
|
||||
from twisted.internet.testing import StringTransport
|
||||
from twisted.protocols import finger
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class FingerTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{finger.Finger}.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Create and connect a L{finger.Finger} instance.
|
||||
"""
|
||||
self.transport = StringTransport()
|
||||
self.protocol = finger.Finger()
|
||||
self.protocol.makeConnection(self.transport)
|
||||
|
||||
def test_simple(self) -> None:
|
||||
"""
|
||||
When L{finger.Finger} receives a CR LF terminated line, it responds
|
||||
with the default user status message - that no such user exists.
|
||||
"""
|
||||
self.protocol.dataReceived(b"moshez\r\n")
|
||||
self.assertEqual(self.transport.value(), b"Login: moshez\nNo such user\n")
|
||||
|
||||
def test_simpleW(self) -> None:
|
||||
"""
|
||||
The behavior for a query which begins with C{"/w"} is the same as the
|
||||
behavior for one which does not. The user is reported as not existing.
|
||||
"""
|
||||
self.protocol.dataReceived(b"/w moshez\r\n")
|
||||
self.assertEqual(self.transport.value(), b"Login: moshez\nNo such user\n")
|
||||
|
||||
def test_forwarding(self) -> None:
|
||||
"""
|
||||
When L{finger.Finger} receives a request for a remote user, it responds
|
||||
with a message rejecting the request.
|
||||
"""
|
||||
self.protocol.dataReceived(b"moshez@example.com\r\n")
|
||||
self.assertEqual(self.transport.value(), b"Finger forwarding service denied\n")
|
||||
|
||||
def test_list(self) -> None:
|
||||
"""
|
||||
When L{finger.Finger} receives a blank line, it responds with a message
|
||||
rejecting the request for all online users.
|
||||
"""
|
||||
self.protocol.dataReceived(b"\r\n")
|
||||
self.assertEqual(self.transport.value(), b"Finger online list denied\n")
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Test cases for formmethod module.
|
||||
"""
|
||||
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from twisted.python import formmethod
|
||||
from twisted.trial import unittest
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
class ArgumentTests(unittest.TestCase):
|
||||
def argTest(
|
||||
self,
|
||||
argKlass: Callable[Concatenate[str, _P], formmethod.Argument],
|
||||
testPairs: Iterable[tuple[object, object]],
|
||||
badValues: Iterable[object],
|
||||
*args: _P.args,
|
||||
**kwargs: _P.kwargs,
|
||||
) -> None:
|
||||
arg = argKlass("name", *args, **kwargs)
|
||||
for val, result in testPairs:
|
||||
self.assertEqual(arg.coerce(val), result)
|
||||
for val in badValues:
|
||||
self.assertRaises(formmethod.InputError, arg.coerce, val)
|
||||
|
||||
def test_argument(self) -> None:
|
||||
"""
|
||||
Test that corce correctly raises NotImplementedError.
|
||||
"""
|
||||
arg = formmethod.Argument("name")
|
||||
self.assertRaises(NotImplementedError, arg.coerce, "")
|
||||
|
||||
def testString(self) -> None:
|
||||
self.argTest(formmethod.String, [("a", "a"), (1, "1"), ("", "")], ())
|
||||
self.argTest(
|
||||
formmethod.String, [("ab", "ab"), ("abc", "abc")], ("2", ""), min=2
|
||||
)
|
||||
self.argTest(
|
||||
formmethod.String, [("ab", "ab"), ("a", "a")], ("223213", "345x"), max=3
|
||||
)
|
||||
self.argTest(
|
||||
formmethod.String,
|
||||
[("ab", "ab"), ("add", "add")],
|
||||
("223213", "x"),
|
||||
min=2,
|
||||
max=3,
|
||||
)
|
||||
|
||||
def testInt(self) -> None:
|
||||
self.argTest(
|
||||
formmethod.Integer, [("3", 3), ("-2", -2), ("", None)], ("q", "2.3")
|
||||
)
|
||||
self.argTest(
|
||||
formmethod.Integer, [("3", 3), ("-2", -2)], ("q", "2.3", ""), allowNone=0
|
||||
)
|
||||
|
||||
def testFloat(self) -> None:
|
||||
self.argTest(
|
||||
formmethod.Float, [("3", 3.0), ("-2.3", -2.3), ("", None)], ("q", "2.3z")
|
||||
)
|
||||
self.argTest(
|
||||
formmethod.Float,
|
||||
[("3", 3.0), ("-2.3", -2.3)],
|
||||
("q", "2.3z", ""),
|
||||
allowNone=0,
|
||||
)
|
||||
|
||||
def testChoice(self) -> None:
|
||||
choices = [("a", "apple", "an apple"), ("b", "banana", "ook")]
|
||||
self.argTest(
|
||||
formmethod.Choice,
|
||||
[("a", "apple"), ("b", "banana")],
|
||||
("c", 1),
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
def testFlags(self) -> None:
|
||||
flags = [("a", "apple", "an apple"), ("b", "banana", "ook")]
|
||||
self.argTest(
|
||||
formmethod.Flags,
|
||||
[(["a"], ["apple"]), (["b", "a"], ["banana", "apple"])],
|
||||
(["a", "c"], ["fdfs"]),
|
||||
flags=flags,
|
||||
)
|
||||
|
||||
def testBoolean(self) -> None:
|
||||
tests = [("yes", 1), ("", 0), ("False", 0), ("no", 0)]
|
||||
self.argTest(formmethod.Boolean, tests, ())
|
||||
|
||||
def test_file(self) -> None:
|
||||
"""
|
||||
Test the correctness of the coerce function.
|
||||
"""
|
||||
arg = formmethod.File("name", allowNone=0)
|
||||
self.assertEqual(arg.coerce("something"), "something")
|
||||
self.assertRaises(formmethod.InputError, arg.coerce, None)
|
||||
arg2 = formmethod.File("name")
|
||||
self.assertIsNone(arg2.coerce(None))
|
||||
|
||||
def testDate(self) -> None:
|
||||
goodTests = {
|
||||
("2002", "12", "21"): (2002, 12, 21),
|
||||
("1996", "2", "29"): (1996, 2, 29),
|
||||
("", "", ""): None,
|
||||
}.items()
|
||||
badTests = [
|
||||
("2002", "2", "29"),
|
||||
("xx", "2", "3"),
|
||||
("2002", "13", "1"),
|
||||
("1999", "12", "32"),
|
||||
("2002", "1"),
|
||||
("2002", "2", "3", "4"),
|
||||
]
|
||||
self.argTest(formmethod.Date, goodTests, badTests)
|
||||
|
||||
def testRangedInteger(self) -> None:
|
||||
goodTests = {"0": 0, "12": 12, "3": 3}.items()
|
||||
badTests = ["-1", "x", "13", "-2000", "3.4"]
|
||||
self.argTest(formmethod.IntegerRange, goodTests, badTests, 0, 12)
|
||||
|
||||
def testVerifiedPassword(self) -> None:
|
||||
goodTests = {("foo", "foo"): "foo", ("ab", "ab"): "ab"}.items()
|
||||
badTests = [
|
||||
("ab", "a"),
|
||||
("12345", "12345"),
|
||||
("", ""),
|
||||
("a", "a"),
|
||||
("a",),
|
||||
("a", "a", "a"),
|
||||
]
|
||||
self.argTest(formmethod.VerifiedPassword, goodTests, badTests, min=2, max=4)
|
||||
4183
.venv/lib/python3.12/site-packages/twisted/test/test_ftp.py
Normal file
4183
.venv/lib/python3.12/site-packages/twisted/test/test_ftp.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.tap.ftp}.
|
||||
"""
|
||||
|
||||
from twisted.cred import credentials, error
|
||||
from twisted.python import versions
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.tap.ftp import Options
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class FTPOptionsTests(TestCase):
|
||||
"""
|
||||
Tests for the command line option parser used for C{twistd ftp}.
|
||||
"""
|
||||
|
||||
usernamePassword = (b"iamuser", b"thisispassword")
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Create a file with two users.
|
||||
"""
|
||||
self.filename = self.mktemp()
|
||||
f = FilePath(self.filename)
|
||||
f.setContent(b":".join(self.usernamePassword))
|
||||
self.options = Options()
|
||||
|
||||
def test_passwordfileDeprecation(self) -> None:
|
||||
"""
|
||||
The C{--password-file} option will emit a warning stating that
|
||||
said option is deprecated.
|
||||
"""
|
||||
self.callDeprecated(
|
||||
versions.Version("Twisted", 11, 1, 0),
|
||||
self.options.opt_password_file,
|
||||
self.filename,
|
||||
)
|
||||
|
||||
def test_authAdded(self) -> None:
|
||||
"""
|
||||
The C{--auth} command-line option will add a checker to the list of
|
||||
checkers
|
||||
"""
|
||||
numCheckers = len(self.options["credCheckers"])
|
||||
self.options.parseOptions(["--auth", "file:" + self.filename])
|
||||
self.assertEqual(len(self.options["credCheckers"]), numCheckers + 1)
|
||||
|
||||
def test_authFailure(self):
|
||||
"""
|
||||
The checker created by the C{--auth} command-line option returns a
|
||||
L{Deferred} that fails with L{UnauthorizedLogin} when
|
||||
presented with credentials that are unknown to that checker.
|
||||
"""
|
||||
self.options.parseOptions(["--auth", "file:" + self.filename])
|
||||
checker = self.options["credCheckers"][-1]
|
||||
invalid = credentials.UsernamePassword(self.usernamePassword[0], "fake")
|
||||
return checker.requestAvatarId(invalid).addCallbacks(
|
||||
lambda ignore: self.fail("Wrong password should raise error"),
|
||||
lambda err: err.trap(error.UnauthorizedLogin),
|
||||
)
|
||||
|
||||
def test_authSuccess(self):
|
||||
"""
|
||||
The checker created by the C{--auth} command-line option returns a
|
||||
L{Deferred} that returns the avatar id when presented with credentials
|
||||
that are known to that checker.
|
||||
"""
|
||||
self.options.parseOptions(["--auth", "file:" + self.filename])
|
||||
checker = self.options["credCheckers"][-1]
|
||||
correct = credentials.UsernamePassword(*self.usernamePassword)
|
||||
return checker.requestAvatarId(correct).addCallback(
|
||||
lambda username: self.assertEqual(username, correct.username)
|
||||
)
|
||||
118
.venv/lib/python3.12/site-packages/twisted/test/test_htb.py
Normal file
118
.venv/lib/python3.12/site-packages/twisted/test/test_htb.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# -*- Python -*-
|
||||
|
||||
__version__ = "$Revision: 1.3 $"[11:-2]
|
||||
|
||||
from twisted.protocols import htb
|
||||
from twisted.trial import unittest
|
||||
from .test_pcp import DummyConsumer
|
||||
|
||||
|
||||
class DummyClock:
|
||||
time = 0
|
||||
|
||||
def set(self, when: int) -> None:
|
||||
self.time = when
|
||||
|
||||
def __call__(self) -> int:
|
||||
return self.time
|
||||
|
||||
|
||||
class SomeBucket(htb.Bucket):
|
||||
maxburst = 100
|
||||
rate = 2
|
||||
|
||||
|
||||
class TestBucketBase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self._realTimeFunc = htb.time
|
||||
self.clock = DummyClock()
|
||||
htb.time = self.clock
|
||||
|
||||
def tearDown(self) -> None:
|
||||
htb.time = self._realTimeFunc
|
||||
|
||||
|
||||
class BucketTests(TestBucketBase):
|
||||
def testBucketSize(self) -> None:
|
||||
"""
|
||||
Testing the size of the bucket.
|
||||
"""
|
||||
b = SomeBucket()
|
||||
fit = b.add(1000)
|
||||
self.assertEqual(100, fit)
|
||||
|
||||
def testBucketDrain(self) -> None:
|
||||
"""
|
||||
Testing the bucket's drain rate.
|
||||
"""
|
||||
b = SomeBucket()
|
||||
fit = b.add(1000)
|
||||
self.clock.set(10)
|
||||
fit = b.add(1000)
|
||||
self.assertEqual(20, fit)
|
||||
|
||||
def test_bucketEmpty(self) -> None:
|
||||
"""
|
||||
L{htb.Bucket.drip} returns C{True} if the bucket is empty after that drip.
|
||||
"""
|
||||
b = SomeBucket()
|
||||
b.add(20)
|
||||
self.clock.set(9)
|
||||
empty = b.drip()
|
||||
self.assertFalse(empty)
|
||||
self.clock.set(10)
|
||||
empty = b.drip()
|
||||
self.assertTrue(empty)
|
||||
|
||||
|
||||
class BucketNestingTests(TestBucketBase):
|
||||
def setUp(self) -> None:
|
||||
TestBucketBase.setUp(self)
|
||||
self.parent = SomeBucket()
|
||||
self.child1 = SomeBucket(self.parent)
|
||||
self.child2 = SomeBucket(self.parent)
|
||||
|
||||
def testBucketParentSize(self) -> None:
|
||||
# Use up most of the parent bucket.
|
||||
self.child1.add(90)
|
||||
fit = self.child2.add(90)
|
||||
self.assertEqual(10, fit)
|
||||
|
||||
def testBucketParentRate(self) -> None:
|
||||
# Make the parent bucket drain slower.
|
||||
self.parent.rate = 1
|
||||
# Fill both child1 and parent.
|
||||
self.child1.add(100)
|
||||
self.clock.set(10)
|
||||
fit = self.child1.add(100)
|
||||
# How much room was there? The child bucket would have had 20,
|
||||
# but the parent bucket only ten (so no, it wouldn't make too much
|
||||
# sense to have a child bucket draining faster than its parent in a real
|
||||
# application.)
|
||||
self.assertEqual(10, fit)
|
||||
|
||||
|
||||
# TODO: Test the Transport stuff?
|
||||
|
||||
|
||||
class ConsumerShaperTests(TestBucketBase):
|
||||
def setUp(self) -> None:
|
||||
TestBucketBase.setUp(self)
|
||||
self.underlying = DummyConsumer()
|
||||
self.bucket = SomeBucket()
|
||||
self.shaped = htb.ShapedConsumer(self.underlying, self.bucket)
|
||||
|
||||
def testRate(self) -> None:
|
||||
# Start off with a full bucket, so the burst-size doesn't factor in
|
||||
# to the calculations.
|
||||
delta_t = 10
|
||||
self.bucket.add(100)
|
||||
self.shaped.write("x" * 100)
|
||||
self.clock.set(delta_t)
|
||||
self.shaped.resumeProducing()
|
||||
self.assertEqual(len(self.underlying.getvalue()), delta_t * self.bucket.rate)
|
||||
|
||||
def testBucketRefs(self) -> None:
|
||||
self.assertEqual(self.bucket._refcount, 1)
|
||||
self.shaped.stopProducing()
|
||||
self.assertEqual(self.bucket._refcount, 0)
|
||||
211
.venv/lib/python3.12/site-packages/twisted/test/test_ident.py
Normal file
211
.venv/lib/python3.12/site-packages/twisted/test/test_ident.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Test cases for twisted.protocols.ident module.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import struct
|
||||
from io import StringIO
|
||||
|
||||
from twisted.internet import defer, error
|
||||
from twisted.internet.testing import StringTransport
|
||||
from twisted.protocols import ident
|
||||
from twisted.python import failure
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class ClassParserTests(unittest.TestCase):
|
||||
"""
|
||||
Test parsing of ident responses.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create an ident client used in tests.
|
||||
"""
|
||||
self.client = ident.IdentClient()
|
||||
|
||||
def test_indentError(self):
|
||||
"""
|
||||
'UNKNOWN-ERROR' error should map to the L{ident.IdentError} exception.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 123, 456))
|
||||
self.client.lineReceived("123, 456 : ERROR : UNKNOWN-ERROR")
|
||||
return self.assertFailure(d, ident.IdentError)
|
||||
|
||||
def test_noUSerError(self):
|
||||
"""
|
||||
'NO-USER' error should map to the L{ident.NoUser} exception.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 234, 456))
|
||||
self.client.lineReceived("234, 456 : ERROR : NO-USER")
|
||||
return self.assertFailure(d, ident.NoUser)
|
||||
|
||||
def test_invalidPortError(self):
|
||||
"""
|
||||
'INVALID-PORT' error should map to the L{ident.InvalidPort} exception.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 345, 567))
|
||||
self.client.lineReceived("345, 567 : ERROR : INVALID-PORT")
|
||||
return self.assertFailure(d, ident.InvalidPort)
|
||||
|
||||
def test_hiddenUserError(self):
|
||||
"""
|
||||
'HIDDEN-USER' error should map to the L{ident.HiddenUser} exception.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 567, 789))
|
||||
self.client.lineReceived("567, 789 : ERROR : HIDDEN-USER")
|
||||
return self.assertFailure(d, ident.HiddenUser)
|
||||
|
||||
def test_lostConnection(self):
|
||||
"""
|
||||
A pending query which failed because of a ConnectionLost should
|
||||
receive an L{ident.IdentError}.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.client.queries.append((d, 765, 432))
|
||||
self.client.connectionLost(failure.Failure(error.ConnectionLost()))
|
||||
return self.assertFailure(d, ident.IdentError)
|
||||
|
||||
|
||||
class TestIdentServer(ident.IdentServer):
|
||||
def lookup(self, serverAddress, clientAddress):
|
||||
return self.resultValue
|
||||
|
||||
|
||||
class TestErrorIdentServer(ident.IdentServer):
|
||||
def lookup(self, serverAddress, clientAddress):
|
||||
raise self.exceptionType()
|
||||
|
||||
|
||||
class NewException(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class ServerParserTests(unittest.TestCase):
|
||||
def testErrors(self):
|
||||
p = TestErrorIdentServer()
|
||||
p.makeConnection(StringTransport())
|
||||
L = []
|
||||
p.sendLine = L.append
|
||||
|
||||
p.exceptionType = ident.IdentError
|
||||
p.lineReceived("123, 345")
|
||||
self.assertEqual(L[0], "123, 345 : ERROR : UNKNOWN-ERROR")
|
||||
|
||||
p.exceptionType = ident.NoUser
|
||||
p.lineReceived("432, 210")
|
||||
self.assertEqual(L[1], "432, 210 : ERROR : NO-USER")
|
||||
|
||||
p.exceptionType = ident.InvalidPort
|
||||
p.lineReceived("987, 654")
|
||||
self.assertEqual(L[2], "987, 654 : ERROR : INVALID-PORT")
|
||||
|
||||
p.exceptionType = ident.HiddenUser
|
||||
p.lineReceived("756, 827")
|
||||
self.assertEqual(L[3], "756, 827 : ERROR : HIDDEN-USER")
|
||||
|
||||
p.exceptionType = NewException
|
||||
p.lineReceived("987, 789")
|
||||
self.assertEqual(L[4], "987, 789 : ERROR : UNKNOWN-ERROR")
|
||||
errs = self.flushLoggedErrors(NewException)
|
||||
self.assertEqual(len(errs), 1)
|
||||
|
||||
for port in -1, 0, 65536, 65537:
|
||||
del L[:]
|
||||
p.lineReceived("%d, 5" % (port,))
|
||||
p.lineReceived("5, %d" % (port,))
|
||||
self.assertEqual(
|
||||
L,
|
||||
[
|
||||
"%d, 5 : ERROR : INVALID-PORT" % (port,),
|
||||
"5, %d : ERROR : INVALID-PORT" % (port,),
|
||||
],
|
||||
)
|
||||
|
||||
def testSuccess(self):
|
||||
p = TestIdentServer()
|
||||
p.makeConnection(StringTransport())
|
||||
L = []
|
||||
p.sendLine = L.append
|
||||
|
||||
p.resultValue = ("SYS", "USER")
|
||||
p.lineReceived("123, 456")
|
||||
self.assertEqual(L[0], "123, 456 : USERID : SYS : USER")
|
||||
|
||||
|
||||
if struct.pack("=L", 1)[0:1] == b"\x01":
|
||||
_addr1 = "0100007F"
|
||||
_addr2 = "04030201"
|
||||
else:
|
||||
_addr1 = "7F000001"
|
||||
_addr2 = "01020304"
|
||||
|
||||
|
||||
class ProcMixinTests(unittest.TestCase):
|
||||
line = (
|
||||
"4: %s:0019 %s:02FA 0A 00000000:00000000 "
|
||||
"00:00000000 00000000 0 0 10927 1 f72a5b80 "
|
||||
"3000 0 0 2 -1"
|
||||
) % (_addr1, _addr2)
|
||||
sampleFile = (
|
||||
" sl local_address rem_address st tx_queue rx_queue tr "
|
||||
"tm->when retrnsmt uid timeout inode\n " + line
|
||||
)
|
||||
|
||||
def testDottedQuadFromHexString(self):
|
||||
p = ident.ProcServerMixin()
|
||||
self.assertEqual(p.dottedQuadFromHexString(_addr1), "127.0.0.1")
|
||||
|
||||
def testUnpackAddress(self):
|
||||
p = ident.ProcServerMixin()
|
||||
self.assertEqual(p.unpackAddress(_addr1 + ":0277"), ("127.0.0.1", 631))
|
||||
|
||||
def testLineParser(self):
|
||||
p = ident.ProcServerMixin()
|
||||
self.assertEqual(
|
||||
p.parseLine(self.line), (("127.0.0.1", 25), ("1.2.3.4", 762), 0)
|
||||
)
|
||||
|
||||
def testExistingAddress(self):
|
||||
username = []
|
||||
p = ident.ProcServerMixin()
|
||||
p.entries = lambda: iter([self.line])
|
||||
p.getUsername = lambda uid: (username.append(uid), "root")[1]
|
||||
self.assertEqual(
|
||||
p.lookup(("127.0.0.1", 25), ("1.2.3.4", 762)), (p.SYSTEM_NAME, "root")
|
||||
)
|
||||
self.assertEqual(username, [0])
|
||||
|
||||
def testNonExistingAddress(self):
|
||||
p = ident.ProcServerMixin()
|
||||
p.entries = lambda: iter([self.line])
|
||||
self.assertRaises(ident.NoUser, p.lookup, ("127.0.0.1", 26), ("1.2.3.4", 762))
|
||||
self.assertRaises(ident.NoUser, p.lookup, ("127.0.0.1", 25), ("1.2.3.5", 762))
|
||||
self.assertRaises(ident.NoUser, p.lookup, ("127.0.0.1", 25), ("1.2.3.4", 763))
|
||||
|
||||
def testLookupProcNetTcp(self):
|
||||
"""
|
||||
L{ident.ProcServerMixin.lookup} uses the Linux TCP process table.
|
||||
"""
|
||||
open_calls = []
|
||||
|
||||
def mocked_open(*args, **kwargs):
|
||||
"""
|
||||
Mock for the open call to prevent actually opening /proc/net/tcp.
|
||||
"""
|
||||
open_calls.append((args, kwargs))
|
||||
return StringIO(self.sampleFile)
|
||||
|
||||
self.patch(builtins, "open", mocked_open)
|
||||
|
||||
p = ident.ProcServerMixin()
|
||||
self.assertRaises(ident.NoUser, p.lookup, ("127.0.0.1", 26), ("1.2.3.4", 762))
|
||||
self.assertEqual([(("/proc/net/tcp",), {})], open_calls)
|
||||
1400
.venv/lib/python3.12/site-packages/twisted/test/test_internet.py
Normal file
1400
.venv/lib/python3.12/site-packages/twisted/test/test_internet.py
Normal file
File diff suppressed because it is too large
Load Diff
314
.venv/lib/python3.12/site-packages/twisted/test/test_iosim.py
Normal file
314
.venv/lib/python3.12/site-packages/twisted/test/test_iosim.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.test.iosim}.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.interfaces import IPushProducer
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.test.iosim import FakeTransport, connect, connectedServerAndClient
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class FakeTransportTests(TestCase):
|
||||
"""
|
||||
Tests for L{FakeTransport}.
|
||||
"""
|
||||
|
||||
def test_connectionSerial(self) -> None:
|
||||
"""
|
||||
Each L{FakeTransport} receives a serial number that uniquely identifies
|
||||
it.
|
||||
"""
|
||||
a = FakeTransport(object(), True)
|
||||
b = FakeTransport(object(), False)
|
||||
self.assertIsInstance(a.serial, int)
|
||||
self.assertIsInstance(b.serial, int)
|
||||
self.assertNotEqual(a.serial, b.serial)
|
||||
|
||||
def test_writeSequence(self) -> None:
|
||||
"""
|
||||
L{FakeTransport.writeSequence} will write a sequence of L{bytes} to the
|
||||
transport.
|
||||
"""
|
||||
a = FakeTransport(object(), False)
|
||||
|
||||
a.write(b"a")
|
||||
a.writeSequence([b"b", b"c", b"d"])
|
||||
|
||||
self.assertEqual(b"".join(a.stream), b"abcd")
|
||||
|
||||
def test_writeAfterClose(self) -> None:
|
||||
"""
|
||||
L{FakeTransport.write} will accept writes after transport was closed,
|
||||
but the data will be silently discarded.
|
||||
"""
|
||||
a = FakeTransport(object(), False)
|
||||
a.write(b"before")
|
||||
a.loseConnection()
|
||||
a.write(b"after")
|
||||
|
||||
self.assertEqual(b"".join(a.stream), b"before")
|
||||
|
||||
|
||||
@implementer(IPushProducer)
|
||||
class StrictPushProducer:
|
||||
"""
|
||||
An L{IPushProducer} implementation which produces nothing but enforces
|
||||
preconditions on its state transition methods.
|
||||
"""
|
||||
|
||||
_state = "running"
|
||||
|
||||
def stopProducing(self) -> None:
|
||||
if self._state == "stopped":
|
||||
raise ValueError("Cannot stop already-stopped IPushProducer")
|
||||
self._state = "stopped"
|
||||
|
||||
def pauseProducing(self) -> None:
|
||||
if self._state != "running":
|
||||
raise ValueError(f"Cannot pause {self._state} IPushProducer")
|
||||
self._state = "paused"
|
||||
|
||||
def resumeProducing(self) -> None:
|
||||
if self._state != "paused":
|
||||
raise ValueError(f"Cannot resume {self._state} IPushProducer")
|
||||
self._state = "running"
|
||||
|
||||
|
||||
class StrictPushProducerTests(TestCase):
|
||||
"""
|
||||
Tests for L{StrictPushProducer}.
|
||||
"""
|
||||
|
||||
def _initial(self) -> StrictPushProducer:
|
||||
"""
|
||||
@return: A new L{StrictPushProducer} which has not been through any state
|
||||
changes.
|
||||
"""
|
||||
return StrictPushProducer()
|
||||
|
||||
def _stopped(self) -> StrictPushProducer:
|
||||
"""
|
||||
@return: A new, stopped L{StrictPushProducer}.
|
||||
"""
|
||||
producer = StrictPushProducer()
|
||||
producer.stopProducing()
|
||||
return producer
|
||||
|
||||
def _paused(self) -> StrictPushProducer:
|
||||
"""
|
||||
@return: A new, paused L{StrictPushProducer}.
|
||||
"""
|
||||
producer = StrictPushProducer()
|
||||
producer.pauseProducing()
|
||||
return producer
|
||||
|
||||
def _resumed(self) -> StrictPushProducer:
|
||||
"""
|
||||
@return: A new L{StrictPushProducer} which has been paused and resumed.
|
||||
"""
|
||||
producer = StrictPushProducer()
|
||||
producer.pauseProducing()
|
||||
producer.resumeProducing()
|
||||
return producer
|
||||
|
||||
def assertStopped(self, producer: StrictPushProducer) -> None:
|
||||
"""
|
||||
Assert that the given producer is in the stopped state.
|
||||
|
||||
@param producer: The producer to verify.
|
||||
@type producer: L{StrictPushProducer}
|
||||
"""
|
||||
self.assertEqual(producer._state, "stopped")
|
||||
|
||||
def assertPaused(self, producer: StrictPushProducer) -> None:
|
||||
"""
|
||||
Assert that the given producer is in the paused state.
|
||||
|
||||
@param producer: The producer to verify.
|
||||
@type producer: L{StrictPushProducer}
|
||||
"""
|
||||
self.assertEqual(producer._state, "paused")
|
||||
|
||||
def assertRunning(self, producer: StrictPushProducer) -> None:
|
||||
"""
|
||||
Assert that the given producer is in the running state.
|
||||
|
||||
@param producer: The producer to verify.
|
||||
@type producer: L{StrictPushProducer}
|
||||
"""
|
||||
self.assertEqual(producer._state, "running")
|
||||
|
||||
def test_stopThenStop(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer.stopProducing} raises L{ValueError} if called when
|
||||
the producer is stopped.
|
||||
"""
|
||||
self.assertRaises(ValueError, self._stopped().stopProducing)
|
||||
|
||||
def test_stopThenPause(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer.pauseProducing} raises L{ValueError} if called when
|
||||
the producer is stopped.
|
||||
"""
|
||||
self.assertRaises(ValueError, self._stopped().pauseProducing)
|
||||
|
||||
def test_stopThenResume(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer.resumeProducing} raises L{ValueError} if called when
|
||||
the producer is stopped.
|
||||
"""
|
||||
self.assertRaises(ValueError, self._stopped().resumeProducing)
|
||||
|
||||
def test_pauseThenStop(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer} is stopped if C{stopProducing} is called on a paused
|
||||
producer.
|
||||
"""
|
||||
producer = self._paused()
|
||||
producer.stopProducing()
|
||||
self.assertStopped(producer)
|
||||
|
||||
def test_pauseThenPause(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer.pauseProducing} raises L{ValueError} if called on a
|
||||
paused producer.
|
||||
"""
|
||||
producer = self._paused()
|
||||
self.assertRaises(ValueError, producer.pauseProducing)
|
||||
|
||||
def test_pauseThenResume(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer} is resumed if C{resumeProducing} is called on a
|
||||
paused producer.
|
||||
"""
|
||||
producer = self._paused()
|
||||
producer.resumeProducing()
|
||||
self.assertRunning(producer)
|
||||
|
||||
def test_resumeThenStop(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer} is stopped if C{stopProducing} is called on a
|
||||
resumed producer.
|
||||
"""
|
||||
producer = self._resumed()
|
||||
producer.stopProducing()
|
||||
self.assertStopped(producer)
|
||||
|
||||
def test_resumeThenPause(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer} is paused if C{pauseProducing} is called on a
|
||||
resumed producer.
|
||||
"""
|
||||
producer = self._resumed()
|
||||
producer.pauseProducing()
|
||||
self.assertPaused(producer)
|
||||
|
||||
def test_resumeThenResume(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer.resumeProducing} raises L{ValueError} if called on a
|
||||
resumed producer.
|
||||
"""
|
||||
producer = self._resumed()
|
||||
self.assertRaises(ValueError, producer.resumeProducing)
|
||||
|
||||
def test_stop(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer} is stopped if C{stopProducing} is called in the
|
||||
initial state.
|
||||
"""
|
||||
producer = self._initial()
|
||||
producer.stopProducing()
|
||||
self.assertStopped(producer)
|
||||
|
||||
def test_pause(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer} is paused if C{pauseProducing} is called in the
|
||||
initial state.
|
||||
"""
|
||||
producer = self._initial()
|
||||
producer.pauseProducing()
|
||||
self.assertPaused(producer)
|
||||
|
||||
def test_resume(self) -> None:
|
||||
"""
|
||||
L{StrictPushProducer} raises L{ValueError} if C{resumeProducing} is called
|
||||
in the initial state.
|
||||
"""
|
||||
producer = self._initial()
|
||||
self.assertRaises(ValueError, producer.resumeProducing)
|
||||
|
||||
|
||||
class IOPumpTests(TestCase):
|
||||
"""
|
||||
Tests for L{IOPump}.
|
||||
"""
|
||||
|
||||
def _testStreamingProducer(self, mode: Literal["server", "client"]) -> None:
|
||||
"""
|
||||
Connect a couple protocol/transport pairs to an L{IOPump} and then pump
|
||||
it. Verify that a streaming producer registered with one of the
|
||||
transports does not receive invalid L{IPushProducer} method calls and
|
||||
ends in the right state.
|
||||
|
||||
@param mode: C{u"server"} to test a producer registered with the
|
||||
server transport. C{u"client"} to test a producer registered with
|
||||
the client transport.
|
||||
"""
|
||||
serverProto = Protocol()
|
||||
serverTransport = FakeTransport(serverProto, isServer=True)
|
||||
|
||||
clientProto = Protocol()
|
||||
clientTransport = FakeTransport(clientProto, isServer=False)
|
||||
|
||||
pump = connect(
|
||||
serverProto,
|
||||
serverTransport,
|
||||
clientProto,
|
||||
clientTransport,
|
||||
greet=False,
|
||||
)
|
||||
|
||||
producer = StrictPushProducer()
|
||||
victim = {
|
||||
"server": serverTransport,
|
||||
"client": clientTransport,
|
||||
}[mode]
|
||||
victim.registerProducer(producer, streaming=True)
|
||||
|
||||
pump.pump()
|
||||
self.assertEqual("running", producer._state)
|
||||
|
||||
def test_serverStreamingProducer(self) -> None:
|
||||
"""
|
||||
L{IOPump.pump} does not call C{resumeProducing} on a L{IPushProducer}
|
||||
(stream producer) registered with the server transport.
|
||||
"""
|
||||
self._testStreamingProducer(mode="server")
|
||||
|
||||
def test_clientStreamingProducer(self) -> None:
|
||||
"""
|
||||
L{IOPump.pump} does not call C{resumeProducing} on a L{IPushProducer}
|
||||
(stream producer) registered with the client transport.
|
||||
"""
|
||||
self._testStreamingProducer(mode="client")
|
||||
|
||||
def test_timeAdvances(self) -> None:
|
||||
"""
|
||||
L{IOPump.pump} advances time in the given L{Clock}.
|
||||
"""
|
||||
time_passed = []
|
||||
clock = Clock()
|
||||
_, _, pump = connectedServerAndClient(Protocol, Protocol, clock=clock)
|
||||
clock.callLater(0, lambda: time_passed.append(True))
|
||||
self.assertFalse(time_passed)
|
||||
pump.pump()
|
||||
self.assertTrue(time_passed)
|
||||
383
.venv/lib/python3.12/site-packages/twisted/test/test_iutils.py
Normal file
383
.venv/lib/python3.12/site-packages/twisted/test/test_iutils.py
Normal file
@@ -0,0 +1,383 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test running processes with the APIs in L{twisted.internet.utils}.
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import signal
|
||||
import stat
|
||||
import sys
|
||||
import warnings
|
||||
from unittest import skipIf
|
||||
|
||||
from twisted.internet import error, interfaces, reactor, utils
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.python.test.test_util import SuppressedWarningsTests
|
||||
from twisted.trial.unittest import SynchronousTestCase, TestCase
|
||||
|
||||
|
||||
class ProcessUtilsTests(TestCase):
|
||||
"""
|
||||
Test running a process using L{getProcessOutput}, L{getProcessValue}, and
|
||||
L{getProcessOutputAndValue}.
|
||||
"""
|
||||
|
||||
if interfaces.IReactorProcess(reactor, None) is None:
|
||||
skip = "reactor doesn't implement IReactorProcess"
|
||||
|
||||
output = None
|
||||
value = None
|
||||
exe = sys.executable
|
||||
|
||||
def makeSourceFile(self, sourceLines):
|
||||
"""
|
||||
Write the given list of lines to a text file and return the absolute
|
||||
path to it.
|
||||
"""
|
||||
script = self.mktemp()
|
||||
with open(script, "wt") as scriptFile:
|
||||
scriptFile.write(os.linesep.join(sourceLines) + os.linesep)
|
||||
return os.path.abspath(script)
|
||||
|
||||
def test_output(self):
|
||||
"""
|
||||
L{getProcessOutput} returns a L{Deferred} which fires with the complete
|
||||
output of the process it runs after that process exits.
|
||||
"""
|
||||
scriptFile = self.makeSourceFile(
|
||||
[
|
||||
"import sys",
|
||||
"for s in b'hello world\\n':",
|
||||
" s = bytes([s])",
|
||||
" sys.stdout.buffer.write(s)",
|
||||
" sys.stdout.flush()",
|
||||
]
|
||||
)
|
||||
d = utils.getProcessOutput(self.exe, ["-u", scriptFile])
|
||||
return d.addCallback(self.assertEqual, b"hello world\n")
|
||||
|
||||
def test_outputWithErrorIgnored(self):
|
||||
"""
|
||||
The L{Deferred} returned by L{getProcessOutput} is fired with an
|
||||
L{IOError} L{Failure} if the child process writes to stderr.
|
||||
"""
|
||||
# make sure stderr raises an error normally
|
||||
scriptFile = self.makeSourceFile(
|
||||
["import sys", 'sys.stderr.write("hello world\\n")']
|
||||
)
|
||||
|
||||
d = utils.getProcessOutput(self.exe, ["-u", scriptFile])
|
||||
d = self.assertFailure(d, IOError)
|
||||
|
||||
def cbFailed(err):
|
||||
return self.assertFailure(err.processEnded, error.ProcessDone)
|
||||
|
||||
d.addCallback(cbFailed)
|
||||
return d
|
||||
|
||||
def test_outputWithErrorCollected(self):
|
||||
"""
|
||||
If a C{True} value is supplied for the C{errortoo} parameter to
|
||||
L{getProcessOutput}, the returned L{Deferred} fires with the child's
|
||||
stderr output as well as its stdout output.
|
||||
"""
|
||||
scriptFile = self.makeSourceFile(
|
||||
[
|
||||
"import sys",
|
||||
# Write the same value to both because ordering isn't guaranteed so
|
||||
# this simplifies the test.
|
||||
'sys.stdout.write("foo")',
|
||||
"sys.stdout.flush()",
|
||||
'sys.stderr.write("foo")',
|
||||
"sys.stderr.flush()",
|
||||
]
|
||||
)
|
||||
|
||||
d = utils.getProcessOutput(self.exe, ["-u", scriptFile], errortoo=True)
|
||||
return d.addCallback(self.assertEqual, b"foofoo")
|
||||
|
||||
def test_value(self):
|
||||
"""
|
||||
The L{Deferred} returned by L{getProcessValue} is fired with the exit
|
||||
status of the child process.
|
||||
"""
|
||||
scriptFile = self.makeSourceFile(["raise SystemExit(1)"])
|
||||
|
||||
d = utils.getProcessValue(self.exe, ["-u", scriptFile])
|
||||
return d.addCallback(self.assertEqual, 1)
|
||||
|
||||
def test_outputAndValue(self):
|
||||
"""
|
||||
The L{Deferred} returned by L{getProcessOutputAndValue} fires with a
|
||||
three-tuple, the elements of which give the data written to the child's
|
||||
stdout, the data written to the child's stderr, and the exit status of
|
||||
the child.
|
||||
"""
|
||||
scriptFile = self.makeSourceFile(
|
||||
[
|
||||
"import sys",
|
||||
"sys.stdout.buffer.write(b'hello world!\\n')",
|
||||
"sys.stderr.buffer.write(b'goodbye world!\\n')",
|
||||
"sys.exit(1)",
|
||||
]
|
||||
)
|
||||
|
||||
def gotOutputAndValue(out_err_code):
|
||||
out, err, code = out_err_code
|
||||
self.assertEqual(out, b"hello world!\n")
|
||||
self.assertEqual(err, b"goodbye world!\n")
|
||||
self.assertEqual(code, 1)
|
||||
|
||||
d = utils.getProcessOutputAndValue(self.exe, ["-u", scriptFile])
|
||||
return d.addCallback(gotOutputAndValue)
|
||||
|
||||
@skipIf(platform.isWindows(), "Windows doesn't have real signals.")
|
||||
def test_outputSignal(self):
|
||||
"""
|
||||
If the child process exits because of a signal, the L{Deferred}
|
||||
returned by L{getProcessOutputAndValue} fires a L{Failure} of a tuple
|
||||
containing the child's stdout, stderr, and the signal which caused
|
||||
it to exit.
|
||||
"""
|
||||
# Use SIGKILL here because it's guaranteed to be delivered. Using
|
||||
# SIGHUP might not work in, e.g., a buildbot slave run under the
|
||||
# 'nohup' command.
|
||||
scriptFile = self.makeSourceFile(
|
||||
[
|
||||
"import sys, os, signal",
|
||||
"sys.stdout.write('stdout bytes\\n')",
|
||||
"sys.stderr.write('stderr bytes\\n')",
|
||||
"sys.stdout.flush()",
|
||||
"sys.stderr.flush()",
|
||||
"os.kill(os.getpid(), signal.SIGKILL)",
|
||||
]
|
||||
)
|
||||
|
||||
def gotOutputAndValue(out_err_sig):
|
||||
out, err, sig = out_err_sig
|
||||
self.assertEqual(out, b"stdout bytes\n")
|
||||
self.assertEqual(err, b"stderr bytes\n")
|
||||
self.assertEqual(sig, signal.SIGKILL)
|
||||
|
||||
d = utils.getProcessOutputAndValue(self.exe, ["-u", scriptFile])
|
||||
d = self.assertFailure(d, tuple)
|
||||
return d.addCallback(gotOutputAndValue)
|
||||
|
||||
def _pathTest(self, utilFunc, check):
|
||||
dir = os.path.abspath(self.mktemp())
|
||||
os.makedirs(dir)
|
||||
scriptFile = self.makeSourceFile(
|
||||
["import os, sys", "sys.stdout.write(os.getcwd())"]
|
||||
)
|
||||
d = utilFunc(self.exe, ["-u", scriptFile], path=dir)
|
||||
d.addCallback(check, dir.encode(sys.getfilesystemencoding()))
|
||||
return d
|
||||
|
||||
def test_getProcessOutputPath(self):
|
||||
"""
|
||||
L{getProcessOutput} runs the given command with the working directory
|
||||
given by the C{path} parameter.
|
||||
"""
|
||||
return self._pathTest(utils.getProcessOutput, self.assertEqual)
|
||||
|
||||
def test_getProcessValuePath(self):
|
||||
"""
|
||||
L{getProcessValue} runs the given command with the working directory
|
||||
given by the C{path} parameter.
|
||||
"""
|
||||
|
||||
def check(result, ignored):
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
return self._pathTest(utils.getProcessValue, check)
|
||||
|
||||
def test_getProcessOutputAndValuePath(self):
|
||||
"""
|
||||
L{getProcessOutputAndValue} runs the given command with the working
|
||||
directory given by the C{path} parameter.
|
||||
"""
|
||||
|
||||
def check(out_err_status, dir):
|
||||
out, err, status = out_err_status
|
||||
self.assertEqual(out, dir)
|
||||
self.assertEqual(status, 0)
|
||||
|
||||
return self._pathTest(utils.getProcessOutputAndValue, check)
|
||||
|
||||
def _defaultPathTest(self, utilFunc, check):
|
||||
# Make another directory to mess around with.
|
||||
dir = os.path.abspath(self.mktemp())
|
||||
os.makedirs(dir)
|
||||
|
||||
scriptFile = self.makeSourceFile(
|
||||
["import os, sys", "cdir = os.getcwd()", "sys.stdout.write(cdir)"]
|
||||
)
|
||||
|
||||
# Switch to it, but make sure we switch back
|
||||
self.addCleanup(os.chdir, os.getcwd())
|
||||
os.chdir(dir)
|
||||
|
||||
# Remember its default permissions.
|
||||
originalMode = stat.S_IMODE(os.stat(".").st_mode)
|
||||
|
||||
# On macOS Catalina (and maybe elsewhere), os.getcwd() sometimes fails
|
||||
# with EACCES if u+rx is missing from the working directory, so don't
|
||||
# reduce it further than this.
|
||||
os.chmod(dir, stat.S_IXUSR | stat.S_IRUSR)
|
||||
|
||||
# Restore the permissions to their original state later (probably
|
||||
# adding at least u+w), because otherwise it might be hard to delete
|
||||
# the trial temporary directory.
|
||||
self.addCleanup(os.chmod, dir, originalMode)
|
||||
|
||||
d = utilFunc(self.exe, ["-u", scriptFile])
|
||||
d.addCallback(check, dir.encode(sys.getfilesystemencoding()))
|
||||
return d
|
||||
|
||||
def test_getProcessOutputDefaultPath(self):
|
||||
"""
|
||||
If no value is supplied for the C{path} parameter, L{getProcessOutput}
|
||||
runs the given command in the same working directory as the parent
|
||||
process and succeeds even if the current working directory is not
|
||||
accessible.
|
||||
"""
|
||||
return self._defaultPathTest(utils.getProcessOutput, self.assertEqual)
|
||||
|
||||
def test_getProcessValueDefaultPath(self):
|
||||
"""
|
||||
If no value is supplied for the C{path} parameter, L{getProcessValue}
|
||||
runs the given command in the same working directory as the parent
|
||||
process and succeeds even if the current working directory is not
|
||||
accessible.
|
||||
"""
|
||||
|
||||
def check(result, ignored):
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
return self._defaultPathTest(utils.getProcessValue, check)
|
||||
|
||||
def test_getProcessOutputAndValueDefaultPath(self):
|
||||
"""
|
||||
If no value is supplied for the C{path} parameter,
|
||||
L{getProcessOutputAndValue} runs the given command in the same working
|
||||
directory as the parent process and succeeds even if the current
|
||||
working directory is not accessible.
|
||||
"""
|
||||
|
||||
def check(out_err_status, dir):
|
||||
out, err, status = out_err_status
|
||||
self.assertEqual(out, dir)
|
||||
self.assertEqual(status, 0)
|
||||
|
||||
return self._defaultPathTest(utils.getProcessOutputAndValue, check)
|
||||
|
||||
def test_get_processOutputAndValueStdin(self):
|
||||
"""
|
||||
Standard input can be made available to the child process by passing
|
||||
bytes for the `stdinBytes` parameter.
|
||||
"""
|
||||
scriptFile = self.makeSourceFile(
|
||||
[
|
||||
"import sys",
|
||||
"sys.stdout.write(sys.stdin.read())",
|
||||
]
|
||||
)
|
||||
stdinBytes = b"These are the bytes to see."
|
||||
d = utils.getProcessOutputAndValue(
|
||||
self.exe,
|
||||
["-u", scriptFile],
|
||||
stdinBytes=stdinBytes,
|
||||
)
|
||||
|
||||
def gotOutputAndValue(out_err_code):
|
||||
out, err, code = out_err_code
|
||||
# Avoid making an exact equality comparison in case there is extra
|
||||
# random output on stdout (warnings, stray print statements,
|
||||
# logging, who knows).
|
||||
self.assertIn(stdinBytes, out)
|
||||
self.assertEqual(0, code)
|
||||
|
||||
d.addCallback(gotOutputAndValue)
|
||||
return d
|
||||
|
||||
|
||||
class SuppressWarningsTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{utils.suppressWarnings}.
|
||||
"""
|
||||
|
||||
def test_suppressWarnings(self):
|
||||
"""
|
||||
L{utils.suppressWarnings} decorates a function so that the given
|
||||
warnings are suppressed.
|
||||
"""
|
||||
result = []
|
||||
|
||||
def showwarning(self, *a, **kw):
|
||||
result.append((a, kw))
|
||||
|
||||
self.patch(warnings, "showwarning", showwarning)
|
||||
|
||||
def f(msg):
|
||||
warnings.warn(msg)
|
||||
|
||||
g = utils.suppressWarnings(f, (("ignore",), dict(message="This is message")))
|
||||
|
||||
# Start off with a sanity check - calling the original function
|
||||
# should emit the warning.
|
||||
f("Sanity check message")
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
# Now that that's out of the way, call the wrapped function, and
|
||||
# make sure no new warnings show up.
|
||||
g("This is message")
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
# Finally, emit another warning which should not be ignored, and
|
||||
# make sure it is not.
|
||||
g("Unignored message")
|
||||
self.assertEqual(len(result), 2)
|
||||
|
||||
|
||||
class DeferredSuppressedWarningsTests(SuppressedWarningsTests):
|
||||
"""
|
||||
Tests for L{utils.runWithWarningsSuppressed}, the version that supports
|
||||
Deferreds.
|
||||
"""
|
||||
|
||||
# Override the non-Deferred-supporting function from the base class with
|
||||
# the function we are testing in this class:
|
||||
runWithWarningsSuppressed = staticmethod(utils.runWithWarningsSuppressed)
|
||||
|
||||
def test_deferredCallback(self):
|
||||
"""
|
||||
If the function called by L{utils.runWithWarningsSuppressed} returns a
|
||||
C{Deferred}, the warning filters aren't removed until the Deferred
|
||||
fires.
|
||||
"""
|
||||
filters = [(("ignore", ".*foo.*"), {}), (("ignore", ".*bar.*"), {})]
|
||||
result = Deferred()
|
||||
self.runWithWarningsSuppressed(filters, lambda: result)
|
||||
warnings.warn("ignore foo")
|
||||
result.callback(3)
|
||||
warnings.warn("ignore foo 2")
|
||||
self.assertEqual(["ignore foo 2"], [w["message"] for w in self.flushWarnings()])
|
||||
|
||||
def test_deferredErrback(self):
|
||||
"""
|
||||
If the function called by L{utils.runWithWarningsSuppressed} returns a
|
||||
C{Deferred}, the warning filters aren't removed until the Deferred
|
||||
fires with an errback.
|
||||
"""
|
||||
filters = [(("ignore", ".*foo.*"), {}), (("ignore", ".*bar.*"), {})]
|
||||
result = Deferred()
|
||||
d = self.runWithWarningsSuppressed(filters, lambda: result)
|
||||
warnings.warn("ignore foo")
|
||||
result.errback(ZeroDivisionError())
|
||||
d.addErrback(lambda f: f.trap(ZeroDivisionError))
|
||||
warnings.warn("ignore foo 2")
|
||||
self.assertEqual(["ignore foo 2"], [w["message"] for w in self.flushWarnings()])
|
||||
457
.venv/lib/python3.12/site-packages/twisted/test/test_lockfile.py
Normal file
457
.venv/lib/python3.12/site-packages/twisted/test/test_lockfile.py
Normal file
@@ -0,0 +1,457 @@
|
||||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.lockfile}.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import os
|
||||
from unittest import skipIf, skipUnless
|
||||
|
||||
from typing_extensions import NoReturn
|
||||
|
||||
from twisted.python import lockfile
|
||||
from twisted.python.reflect import requireModule
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
skipKill = False
|
||||
skipKillReason = ""
|
||||
if platform.isWindows():
|
||||
if (
|
||||
requireModule("win32api.OpenProcess") is None
|
||||
and requireModule("pywintypes") is None
|
||||
):
|
||||
skipKill = True
|
||||
skipKillReason = (
|
||||
"On windows, lockfile.kill is not implemented "
|
||||
"in the absence of win32api and/or pywintypes."
|
||||
)
|
||||
|
||||
|
||||
class UtilTests(TestCase):
|
||||
"""
|
||||
Tests for the helper functions used to implement L{FilesystemLock}.
|
||||
"""
|
||||
|
||||
def test_symlinkEEXIST(self) -> None:
|
||||
"""
|
||||
L{lockfile.symlink} raises L{OSError} with C{errno} set to L{EEXIST}
|
||||
when an attempt is made to create a symlink which already exists.
|
||||
"""
|
||||
name = self.mktemp()
|
||||
lockfile.symlink("foo", name)
|
||||
exc = self.assertRaises(OSError, lockfile.symlink, "foo", name)
|
||||
self.assertEqual(exc.errno, errno.EEXIST)
|
||||
|
||||
@skipUnless(
|
||||
platform.isWindows(),
|
||||
"special rename EIO handling only necessary and correct on " "Windows.",
|
||||
)
|
||||
def test_symlinkEIOWindows(self) -> None:
|
||||
"""
|
||||
L{lockfile.symlink} raises L{OSError} with C{errno} set to L{EIO} when
|
||||
the underlying L{rename} call fails with L{EIO}.
|
||||
|
||||
Renaming a file on Windows may fail if the target of the rename is in
|
||||
the process of being deleted (directory deletion appears not to be
|
||||
atomic).
|
||||
"""
|
||||
name = self.mktemp()
|
||||
|
||||
def fakeRename(src: str, dst: str) -> NoReturn:
|
||||
raise OSError(errno.EIO, None)
|
||||
|
||||
self.patch(lockfile, "rename", fakeRename)
|
||||
exc = self.assertRaises(IOError, lockfile.symlink, name, "foo")
|
||||
self.assertEqual(exc.errno, errno.EIO)
|
||||
|
||||
def test_readlinkENOENT(self) -> None:
|
||||
"""
|
||||
L{lockfile.readlink} raises L{OSError} with C{errno} set to L{ENOENT}
|
||||
when an attempt is made to read a symlink which does not exist.
|
||||
"""
|
||||
name = self.mktemp()
|
||||
exc = self.assertRaises(OSError, lockfile.readlink, name)
|
||||
self.assertEqual(exc.errno, errno.ENOENT)
|
||||
|
||||
@skipUnless(
|
||||
platform.isWindows(),
|
||||
"special readlink EACCES handling only necessary and " "correct on Windows.",
|
||||
)
|
||||
def test_readlinkEACCESWindows(self) -> None:
|
||||
"""
|
||||
L{lockfile.readlink} raises L{OSError} with C{errno} set to L{EACCES}
|
||||
on Windows when the underlying file open attempt fails with C{EACCES}.
|
||||
|
||||
Opening a file on Windows may fail if the path is inside a directory
|
||||
which is in the process of being deleted (directory deletion appears
|
||||
not to be atomic).
|
||||
"""
|
||||
name = self.mktemp()
|
||||
|
||||
def fakeOpen(path: str, mode: str) -> NoReturn:
|
||||
raise OSError(errno.EACCES, None)
|
||||
|
||||
self.patch(lockfile, "_open", fakeOpen)
|
||||
exc = self.assertRaises(IOError, lockfile.readlink, name)
|
||||
self.assertEqual(exc.errno, errno.EACCES)
|
||||
|
||||
@skipIf(skipKill, skipKillReason)
|
||||
def test_kill(self) -> None:
|
||||
"""
|
||||
L{lockfile.kill} returns without error if passed the PID of a
|
||||
process which exists and signal C{0}.
|
||||
"""
|
||||
lockfile.kill(os.getpid(), 0)
|
||||
|
||||
@skipIf(skipKill, skipKillReason)
|
||||
def test_killESRCH(self) -> None:
|
||||
"""
|
||||
L{lockfile.kill} raises L{OSError} with errno of L{ESRCH} if
|
||||
passed a PID which does not correspond to any process.
|
||||
"""
|
||||
# Hopefully there is no process with PID 2 ** 31 - 1
|
||||
exc = self.assertRaises(OSError, lockfile.kill, 2**31 - 1, 0)
|
||||
self.assertEqual(exc.errno, errno.ESRCH)
|
||||
|
||||
def test_noKillCall(self) -> None:
|
||||
"""
|
||||
Verify that when L{lockfile.kill} does end up as None (e.g. on Windows
|
||||
without pywin32), it doesn't end up being called and raising a
|
||||
L{TypeError}.
|
||||
"""
|
||||
self.patch(lockfile, "kill", None)
|
||||
fl = lockfile.FilesystemLock(self.mktemp())
|
||||
fl.lock()
|
||||
self.assertFalse(fl.lock())
|
||||
|
||||
|
||||
class LockingTests(TestCase):
|
||||
def _symlinkErrorTest(self, errno: int) -> None:
|
||||
def fakeSymlink(source: str, dest: str) -> NoReturn:
|
||||
raise OSError(errno, None)
|
||||
|
||||
self.patch(lockfile, "symlink", fakeSymlink)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
exc = self.assertRaises(OSError, lock.lock)
|
||||
self.assertEqual(exc.errno, errno)
|
||||
|
||||
def test_symlinkError(self) -> None:
|
||||
"""
|
||||
An exception raised by C{symlink} other than C{EEXIST} is passed up to
|
||||
the caller of L{FilesystemLock.lock}.
|
||||
"""
|
||||
self._symlinkErrorTest(errno.ENOSYS)
|
||||
|
||||
@skipIf(
|
||||
platform.isWindows(),
|
||||
"POSIX-specific error propagation not expected on Windows.",
|
||||
)
|
||||
def test_symlinkErrorPOSIX(self) -> None:
|
||||
"""
|
||||
An L{OSError} raised by C{symlink} on a POSIX platform with an errno of
|
||||
C{EACCES} or C{EIO} is passed to the caller of L{FilesystemLock.lock}.
|
||||
|
||||
On POSIX, unlike on Windows, these are unexpected errors which cannot
|
||||
be handled by L{FilesystemLock}.
|
||||
"""
|
||||
self._symlinkErrorTest(errno.EACCES)
|
||||
self._symlinkErrorTest(errno.EIO)
|
||||
|
||||
def test_cleanlyAcquire(self) -> None:
|
||||
"""
|
||||
If the lock has never been held, it can be acquired and the C{clean}
|
||||
and C{locked} attributes are set to C{True}.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
def test_cleanlyRelease(self) -> None:
|
||||
"""
|
||||
If a lock is released cleanly, it can be re-acquired and the C{clean}
|
||||
and C{locked} attributes are set to C{True}.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
lock.unlock()
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
def test_cannotLockLocked(self) -> None:
|
||||
"""
|
||||
If a lock is currently locked, it cannot be locked again.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
firstLock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(firstLock.lock())
|
||||
|
||||
secondLock = lockfile.FilesystemLock(lockf)
|
||||
self.assertFalse(secondLock.lock())
|
||||
self.assertFalse(secondLock.locked)
|
||||
|
||||
def test_uncleanlyAcquire(self) -> None:
|
||||
"""
|
||||
If a lock was held by a process which no longer exists, it can be
|
||||
acquired, the C{clean} attribute is set to C{False}, and the
|
||||
C{locked} attribute is set to C{True}.
|
||||
"""
|
||||
owner = 12345
|
||||
|
||||
def fakeKill(pid: int, signal: int) -> None:
|
||||
if signal != 0:
|
||||
raise OSError(errno.EPERM, None)
|
||||
if pid == owner:
|
||||
raise OSError(errno.ESRCH, None)
|
||||
|
||||
lockf = self.mktemp()
|
||||
self.patch(lockfile, "kill", fakeKill)
|
||||
lockfile.symlink(str(owner), lockf)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertFalse(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
self.assertEqual(lockfile.readlink(lockf), str(os.getpid()))
|
||||
|
||||
def test_lockReleasedBeforeCheck(self) -> None:
|
||||
"""
|
||||
If the lock is initially held but then released before it can be
|
||||
examined to determine if the process which held it still exists, it is
|
||||
acquired and the C{clean} and C{locked} attributes are set to C{True}.
|
||||
"""
|
||||
|
||||
def fakeReadlink(name: str) -> str:
|
||||
# Pretend to be another process releasing the lock.
|
||||
lockfile.rmlink(lockf)
|
||||
# Fall back to the real implementation of readlink.
|
||||
readlinkPatch.restore()
|
||||
return lockfile.readlink(name)
|
||||
|
||||
readlinkPatch = self.patch(lockfile, "readlink", fakeReadlink)
|
||||
|
||||
def fakeKill(pid: int, signal: int) -> None:
|
||||
if signal != 0:
|
||||
raise OSError(errno.EPERM, None)
|
||||
if pid == 43125:
|
||||
raise OSError(errno.ESRCH, None)
|
||||
|
||||
self.patch(lockfile, "kill", fakeKill)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
@skipUnless(
|
||||
platform.isWindows(),
|
||||
"special rename EIO handling only necessary and correct on " "Windows.",
|
||||
)
|
||||
def test_lockReleasedDuringAcquireSymlink(self) -> None:
|
||||
"""
|
||||
If the lock is released while an attempt is made to acquire
|
||||
it, the lock attempt fails and C{FilesystemLock.lock} returns
|
||||
C{False}. This can happen on Windows when L{lockfile.symlink}
|
||||
fails with L{IOError} of C{EIO} because another process is in
|
||||
the middle of a call to L{os.rmdir} (implemented in terms of
|
||||
RemoveDirectory) which is not atomic.
|
||||
"""
|
||||
|
||||
def fakeSymlink(src: str, dst: str) -> NoReturn:
|
||||
# While another process id doing os.rmdir which the Windows
|
||||
# implementation of rmlink does, a rename call will fail with EIO.
|
||||
raise OSError(errno.EIO, None)
|
||||
|
||||
self.patch(lockfile, "symlink", fakeSymlink)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertFalse(lock.lock())
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
@skipUnless(
|
||||
platform.isWindows(),
|
||||
"special readlink EACCES handling only necessary and " "correct on Windows.",
|
||||
)
|
||||
def test_lockReleasedDuringAcquireReadlink(self) -> None:
|
||||
"""
|
||||
If the lock is initially held but is released while an attempt
|
||||
is made to acquire it, the lock attempt fails and
|
||||
L{FilesystemLock.lock} returns C{False}.
|
||||
"""
|
||||
|
||||
def fakeReadlink(name: str) -> NoReturn:
|
||||
# While another process is doing os.rmdir which the
|
||||
# Windows implementation of rmlink does, a readlink call
|
||||
# will fail with EACCES.
|
||||
raise OSError(errno.EACCES, None)
|
||||
|
||||
self.patch(lockfile, "readlink", fakeReadlink)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
self.assertFalse(lock.lock())
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
def _readlinkErrorTest(
|
||||
self, exceptionType: type[OSError] | type[IOError], errno: int
|
||||
) -> None:
|
||||
def fakeReadlink(name: str) -> NoReturn:
|
||||
raise exceptionType(errno, None)
|
||||
|
||||
self.patch(lockfile, "readlink", fakeReadlink)
|
||||
|
||||
lockf = self.mktemp()
|
||||
|
||||
# Make it appear locked so it has to use readlink
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
exc = self.assertRaises(exceptionType, lock.lock)
|
||||
self.assertEqual(exc.errno, errno)
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
def test_readlinkError(self) -> None:
|
||||
"""
|
||||
An exception raised by C{readlink} other than C{ENOENT} is passed up to
|
||||
the caller of L{FilesystemLock.lock}.
|
||||
"""
|
||||
self._readlinkErrorTest(OSError, errno.ENOSYS)
|
||||
self._readlinkErrorTest(IOError, errno.ENOSYS)
|
||||
|
||||
@skipIf(
|
||||
platform.isWindows(),
|
||||
"POSIX-specific error propagation not expected on Windows.",
|
||||
)
|
||||
def test_readlinkErrorPOSIX(self) -> None:
|
||||
"""
|
||||
Any L{IOError} raised by C{readlink} on a POSIX platform passed to the
|
||||
caller of L{FilesystemLock.lock}.
|
||||
|
||||
On POSIX, unlike on Windows, these are unexpected errors which cannot
|
||||
be handled by L{FilesystemLock}.
|
||||
"""
|
||||
self._readlinkErrorTest(IOError, errno.ENOSYS)
|
||||
self._readlinkErrorTest(IOError, errno.EACCES)
|
||||
|
||||
def test_lockCleanedUpConcurrently(self) -> None:
|
||||
"""
|
||||
If a second process cleans up the lock after a first one checks the
|
||||
lock and finds that no process is holding it, the first process does
|
||||
not fail when it tries to clean up the lock.
|
||||
"""
|
||||
|
||||
def fakeRmlink(name: str) -> None:
|
||||
rmlinkPatch.restore()
|
||||
# Pretend to be another process cleaning up the lock.
|
||||
lockfile.rmlink(lockf)
|
||||
# Fall back to the real implementation of rmlink.
|
||||
return lockfile.rmlink(name)
|
||||
|
||||
rmlinkPatch = self.patch(lockfile, "rmlink", fakeRmlink)
|
||||
|
||||
def fakeKill(pid: int, signal: int) -> None:
|
||||
if signal != 0:
|
||||
raise OSError(errno.EPERM, None)
|
||||
if pid == 43125:
|
||||
raise OSError(errno.ESRCH, None)
|
||||
|
||||
self.patch(lockfile, "kill", fakeKill)
|
||||
|
||||
lockf = self.mktemp()
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lock.clean)
|
||||
self.assertTrue(lock.locked)
|
||||
|
||||
def test_rmlinkError(self) -> None:
|
||||
"""
|
||||
An exception raised by L{rmlink} other than C{ENOENT} is passed up
|
||||
to the caller of L{FilesystemLock.lock}.
|
||||
"""
|
||||
|
||||
def fakeRmlink(name: str) -> NoReturn:
|
||||
raise OSError(errno.ENOSYS, None)
|
||||
|
||||
self.patch(lockfile, "rmlink", fakeRmlink)
|
||||
|
||||
def fakeKill(pid: int, signal: int) -> None:
|
||||
if signal != 0:
|
||||
raise OSError(errno.EPERM, None)
|
||||
if pid == 43125:
|
||||
raise OSError(errno.ESRCH, None)
|
||||
|
||||
self.patch(lockfile, "kill", fakeKill)
|
||||
|
||||
lockf = self.mktemp()
|
||||
|
||||
# Make it appear locked so it has to use readlink
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
exc = self.assertRaises(OSError, lock.lock)
|
||||
self.assertEqual(exc.errno, errno.ENOSYS)
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
def test_killError(self) -> None:
|
||||
"""
|
||||
If L{kill} raises an exception other than L{OSError} with errno set to
|
||||
C{ESRCH}, the exception is passed up to the caller of
|
||||
L{FilesystemLock.lock}.
|
||||
"""
|
||||
|
||||
def fakeKill(pid: int, signal: int) -> NoReturn:
|
||||
raise OSError(errno.EPERM, None)
|
||||
|
||||
self.patch(lockfile, "kill", fakeKill)
|
||||
|
||||
lockf = self.mktemp()
|
||||
|
||||
# Make it appear locked so it has to use readlink
|
||||
lockfile.symlink(str(43125), lockf)
|
||||
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
exc = self.assertRaises(OSError, lock.lock)
|
||||
self.assertEqual(exc.errno, errno.EPERM)
|
||||
self.assertFalse(lock.locked)
|
||||
|
||||
def test_unlockOther(self) -> None:
|
||||
"""
|
||||
L{FilesystemLock.unlock} raises L{ValueError} if called for a lock
|
||||
which is held by a different process.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
lockfile.symlink(str(os.getpid() + 1), lockf)
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertRaises(ValueError, lock.unlock)
|
||||
|
||||
def test_isLocked(self) -> None:
|
||||
"""
|
||||
L{isLocked} returns C{True} if the named lock is currently locked,
|
||||
C{False} otherwise.
|
||||
"""
|
||||
lockf = self.mktemp()
|
||||
self.assertFalse(lockfile.isLocked(lockf))
|
||||
lock = lockfile.FilesystemLock(lockf)
|
||||
self.assertTrue(lock.lock())
|
||||
self.assertTrue(lockfile.isLocked(lockf))
|
||||
lock.unlock()
|
||||
self.assertFalse(lockfile.isLocked(lockf))
|
||||
1039
.venv/lib/python3.12/site-packages/twisted/test/test_log.py
Normal file
1039
.venv/lib/python3.12/site-packages/twisted/test/test_log.py
Normal file
File diff suppressed because it is too large
Load Diff
534
.venv/lib/python3.12/site-packages/twisted/test/test_logfile.py
Normal file
534
.venv/lib/python3.12/site-packages/twisted/test/test_logfile.py
Normal file
@@ -0,0 +1,534 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import errno
|
||||
import os
|
||||
import stat
|
||||
import time
|
||||
from unittest import skipIf
|
||||
|
||||
from twisted.python import logfile, runtime
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class LogFileTests(TestCase):
|
||||
"""
|
||||
Test the rotating log file.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.dir = self.mktemp()
|
||||
os.makedirs(self.dir)
|
||||
self.name = "test.log"
|
||||
self.path = os.path.join(self.dir, self.name)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
Restore back write rights on created paths: if tests modified the
|
||||
rights, that will allow the paths to be removed easily afterwards.
|
||||
"""
|
||||
os.chmod(self.dir, 0o777)
|
||||
if os.path.exists(self.path):
|
||||
os.chmod(self.path, 0o777)
|
||||
|
||||
def test_abstractShouldRotate(self) -> None:
|
||||
"""
|
||||
L{BaseLogFile.shouldRotate} is abstract and must be implemented by
|
||||
subclass.
|
||||
"""
|
||||
log = logfile.BaseLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
self.assertRaises(NotImplementedError, log.shouldRotate)
|
||||
|
||||
def test_writing(self) -> None:
|
||||
"""
|
||||
Log files can be written to, flushed and closed. Closing a log file
|
||||
also flushes it.
|
||||
"""
|
||||
with contextlib.closing(logfile.LogFile(self.name, self.dir)) as log:
|
||||
log.write("123")
|
||||
log.write("456")
|
||||
log.flush()
|
||||
log.write("7890")
|
||||
|
||||
with open(self.path) as f:
|
||||
self.assertEqual(f.read(), "1234567890")
|
||||
|
||||
def test_rotation(self) -> None:
|
||||
"""
|
||||
Rotating log files autorotate after a period of time, and can also be
|
||||
manually rotated.
|
||||
"""
|
||||
# this logfile should rotate every 10 bytes
|
||||
with contextlib.closing(
|
||||
logfile.LogFile(self.name, self.dir, rotateLength=10)
|
||||
) as log:
|
||||
# test automatic rotation
|
||||
log.write("123")
|
||||
log.write("4567890")
|
||||
log.write("1" * 11)
|
||||
self.assertTrue(os.path.exists(f"{self.path}.1"))
|
||||
self.assertFalse(os.path.exists(f"{self.path}.2"))
|
||||
log.write("")
|
||||
self.assertTrue(os.path.exists(f"{self.path}.1"))
|
||||
self.assertTrue(os.path.exists(f"{self.path}.2"))
|
||||
self.assertFalse(os.path.exists(f"{self.path}.3"))
|
||||
log.write("3")
|
||||
self.assertFalse(os.path.exists(f"{self.path}.3"))
|
||||
|
||||
# test manual rotation
|
||||
log.rotate()
|
||||
self.assertTrue(os.path.exists(f"{self.path}.3"))
|
||||
self.assertFalse(os.path.exists(f"{self.path}.4"))
|
||||
|
||||
self.assertEqual(log.listLogs(), [1, 2, 3])
|
||||
|
||||
def test_append(self) -> None:
|
||||
"""
|
||||
Log files can be written to, closed. Their size is the number of
|
||||
bytes written to them. Everything that was written to them can
|
||||
be read, even if the writing happened on separate occasions,
|
||||
and even if the log file was closed in between.
|
||||
"""
|
||||
with contextlib.closing(logfile.LogFile(self.name, self.dir)) as log:
|
||||
log.write("0123456789")
|
||||
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
self.assertEqual(log.size, 10)
|
||||
self.assertEqual(log._file.tell(), log.size)
|
||||
log.write("abc")
|
||||
log.write(b"def\xff")
|
||||
expectResult = b"0123456789abcdef\xff"
|
||||
self.assertEqual(log.size, len(expectResult))
|
||||
self.assertEqual(log._file.tell(), log.size)
|
||||
f = log._file
|
||||
f.seek(0, 0)
|
||||
self.assertEqual(f.read(), expectResult)
|
||||
|
||||
def test_logReader(self) -> None:
|
||||
"""
|
||||
Various tests for log readers.
|
||||
|
||||
First of all, log readers can get logs by number and read what
|
||||
was written to those log files. Getting nonexistent log files
|
||||
raises C{ValueError}. Using anything other than an integer
|
||||
index raises C{TypeError}. As logs get older, their log
|
||||
numbers increase.
|
||||
"""
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
log.write("abc\n")
|
||||
log.write("def\n")
|
||||
log.rotate()
|
||||
log.write("ghi\n")
|
||||
log.flush()
|
||||
|
||||
# check reading logs
|
||||
self.assertEqual(log.listLogs(), [1])
|
||||
with contextlib.closing(log.getCurrentLog()) as reader:
|
||||
reader._file.seek(0)
|
||||
self.assertEqual(reader.readLines(), ["ghi\n"])
|
||||
self.assertEqual(reader.readLines(), [])
|
||||
with contextlib.closing(log.getLog(1)) as reader:
|
||||
self.assertEqual(reader.readLines(), ["abc\n", "def\n"])
|
||||
self.assertEqual(reader.readLines(), [])
|
||||
|
||||
# check getting illegal log readers
|
||||
self.assertRaises(ValueError, log.getLog, 2)
|
||||
self.assertRaises(TypeError, log.getLog, "1")
|
||||
|
||||
# check that log numbers are higher for older logs
|
||||
log.rotate()
|
||||
self.assertEqual(log.listLogs(), [1, 2])
|
||||
with contextlib.closing(log.getLog(1)) as reader:
|
||||
reader._file.seek(0)
|
||||
self.assertEqual(reader.readLines(), ["ghi\n"])
|
||||
self.assertEqual(reader.readLines(), [])
|
||||
with contextlib.closing(log.getLog(2)) as reader:
|
||||
self.assertEqual(reader.readLines(), ["abc\n", "def\n"])
|
||||
self.assertEqual(reader.readLines(), [])
|
||||
|
||||
def test_LogReaderReadsZeroLine(self) -> None:
|
||||
"""
|
||||
L{LogReader.readLines} supports reading no line.
|
||||
"""
|
||||
# We don't need any content, just a file path that can be opened.
|
||||
with open(self.path, "w"):
|
||||
pass
|
||||
|
||||
reader = logfile.LogReader(self.path)
|
||||
self.addCleanup(reader.close)
|
||||
self.assertEqual([], reader.readLines(0))
|
||||
|
||||
def test_modePreservation(self) -> None:
|
||||
"""
|
||||
Check rotated files have same permissions as original.
|
||||
"""
|
||||
open(self.path, "w").close()
|
||||
os.chmod(self.path, 0o707)
|
||||
mode = os.stat(self.path)[stat.ST_MODE]
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
log.write("abc")
|
||||
log.rotate()
|
||||
self.assertEqual(mode, os.stat(self.path)[stat.ST_MODE])
|
||||
|
||||
def test_noPermission(self) -> None:
|
||||
"""
|
||||
Check it keeps working when permission on dir changes.
|
||||
"""
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
log.write("abc")
|
||||
|
||||
# change permissions so rotation would fail
|
||||
os.chmod(self.dir, 0o555)
|
||||
|
||||
# if this succeeds, chmod doesn't restrict us, so we can't
|
||||
# do the test
|
||||
try:
|
||||
f = open(os.path.join(self.dir, "xxx"), "w")
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
f.close()
|
||||
return
|
||||
|
||||
log.rotate() # this should not fail
|
||||
|
||||
log.write("def")
|
||||
log.flush()
|
||||
|
||||
f = log._file
|
||||
self.assertEqual(f.tell(), 6)
|
||||
f.seek(0, 0)
|
||||
self.assertEqual(f.read(), b"abcdef")
|
||||
|
||||
def test_maxNumberOfLog(self) -> None:
|
||||
"""
|
||||
Test it respect the limit on the number of files when maxRotatedFiles
|
||||
is not None.
|
||||
"""
|
||||
log = logfile.LogFile(self.name, self.dir, rotateLength=10, maxRotatedFiles=3)
|
||||
self.addCleanup(log.close)
|
||||
log.write("1" * 11)
|
||||
log.write("2" * 11)
|
||||
self.assertTrue(os.path.exists(f"{self.path}.1"))
|
||||
|
||||
log.write("3" * 11)
|
||||
self.assertTrue(os.path.exists(f"{self.path}.2"))
|
||||
|
||||
log.write("4" * 11)
|
||||
self.assertTrue(os.path.exists(f"{self.path}.3"))
|
||||
with open(f"{self.path}.3") as fp:
|
||||
self.assertEqual(fp.read(), "1" * 11)
|
||||
|
||||
log.write("5" * 11)
|
||||
with open(f"{self.path}.3") as fp:
|
||||
self.assertEqual(fp.read(), "2" * 11)
|
||||
self.assertFalse(os.path.exists(f"{self.path}.4"))
|
||||
|
||||
def test_fromFullPath(self) -> None:
|
||||
"""
|
||||
Test the fromFullPath method.
|
||||
"""
|
||||
log1 = logfile.LogFile(self.name, self.dir, 10, defaultMode=0o777)
|
||||
self.addCleanup(log1.close)
|
||||
log2 = logfile.LogFile.fromFullPath(self.path, 10, defaultMode=0o777)
|
||||
self.addCleanup(log2.close)
|
||||
self.assertEqual(log1.name, log2.name)
|
||||
self.assertEqual(os.path.abspath(log1.path), log2.path)
|
||||
self.assertEqual(log1.rotateLength, log2.rotateLength)
|
||||
self.assertEqual(log1.defaultMode, log2.defaultMode)
|
||||
|
||||
def test_defaultPermissions(self) -> None:
|
||||
"""
|
||||
Test the default permission of the log file: if the file exist, it
|
||||
should keep the permission.
|
||||
"""
|
||||
with open(self.path, "wb"):
|
||||
os.chmod(self.path, 0o707)
|
||||
currentMode = stat.S_IMODE(os.stat(self.path)[stat.ST_MODE])
|
||||
log1 = logfile.LogFile(self.name, self.dir)
|
||||
self.assertEqual(stat.S_IMODE(os.stat(self.path)[stat.ST_MODE]), currentMode)
|
||||
self.addCleanup(log1.close)
|
||||
|
||||
def test_specifiedPermissions(self) -> None:
|
||||
"""
|
||||
Test specifying the permissions used on the log file.
|
||||
"""
|
||||
log1 = logfile.LogFile(self.name, self.dir, defaultMode=0o066)
|
||||
self.addCleanup(log1.close)
|
||||
mode = stat.S_IMODE(os.stat(self.path)[stat.ST_MODE])
|
||||
if runtime.platform.isWindows():
|
||||
# The only thing we can get here is global read-only
|
||||
self.assertEqual(mode, 0o444)
|
||||
else:
|
||||
self.assertEqual(mode, 0o066)
|
||||
|
||||
@skipIf(runtime.platform.isWindows(), "Can't test reopen on Windows")
|
||||
def test_reopen(self) -> None:
|
||||
"""
|
||||
L{logfile.LogFile.reopen} allows to rename the currently used file and
|
||||
make L{logfile.LogFile} create a new file.
|
||||
"""
|
||||
with contextlib.closing(logfile.LogFile(self.name, self.dir)) as log1:
|
||||
log1.write("hello1")
|
||||
savePath = os.path.join(self.dir, "save.log")
|
||||
os.rename(self.path, savePath)
|
||||
log1.reopen()
|
||||
log1.write("hello2")
|
||||
|
||||
with open(self.path) as f:
|
||||
self.assertEqual(f.read(), "hello2")
|
||||
with open(savePath) as f:
|
||||
self.assertEqual(f.read(), "hello1")
|
||||
|
||||
def test_nonExistentDir(self) -> None:
|
||||
"""
|
||||
Specifying an invalid directory to L{LogFile} raises C{IOError}.
|
||||
"""
|
||||
e = self.assertRaises(
|
||||
IOError, logfile.LogFile, self.name, "this_dir_does_not_exist"
|
||||
)
|
||||
self.assertEqual(e.errno, errno.ENOENT)
|
||||
|
||||
def test_cantChangeFileMode(self) -> None:
|
||||
"""
|
||||
Opening a L{LogFile} which can be read and write but whose mode can't
|
||||
be changed doesn't trigger an error.
|
||||
"""
|
||||
if runtime.platform.isWindows():
|
||||
name, directory = "NUL", ""
|
||||
expectedPath = "NUL"
|
||||
else:
|
||||
name, directory = "null", "/dev"
|
||||
expectedPath = "/dev/null"
|
||||
|
||||
log = logfile.LogFile(name, directory, defaultMode=0o555)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
self.assertEqual(log.path, expectedPath)
|
||||
self.assertEqual(log.defaultMode, 0o555)
|
||||
|
||||
def test_listLogsWithBadlyNamedFiles(self) -> None:
|
||||
"""
|
||||
L{LogFile.listLogs} doesn't choke if it encounters a file with an
|
||||
unexpected name.
|
||||
"""
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
with open(f"{log.path}.1", "w") as fp:
|
||||
fp.write("123")
|
||||
with open(f"{log.path}.bad-file", "w") as fp:
|
||||
fp.write("123")
|
||||
|
||||
self.assertEqual([1], log.listLogs())
|
||||
|
||||
def test_listLogsIgnoresZeroSuffixedFiles(self) -> None:
|
||||
"""
|
||||
L{LogFile.listLogs} ignores log files which rotated suffix is 0.
|
||||
"""
|
||||
log = logfile.LogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
for i in range(0, 3):
|
||||
with open(f"{log.path}.{i}", "w") as fp:
|
||||
fp.write("123")
|
||||
|
||||
self.assertEqual([1, 2], log.listLogs())
|
||||
|
||||
|
||||
class RiggedDailyLogFile(logfile.DailyLogFile):
|
||||
_clock = 0.0
|
||||
|
||||
def _openFile(self) -> None:
|
||||
logfile.DailyLogFile._openFile(self)
|
||||
# rig the date to match _clock, not mtime
|
||||
self.lastDate = self.toDate()
|
||||
|
||||
def toDate(self, *args: float) -> tuple[int, int, int]:
|
||||
if args:
|
||||
return time.gmtime(*args)[:3]
|
||||
return time.gmtime(self._clock)[:3]
|
||||
|
||||
|
||||
class DailyLogFileTests(TestCase):
|
||||
"""
|
||||
Test rotating log file.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.dir = self.mktemp()
|
||||
os.makedirs(self.dir)
|
||||
self.name = "testdaily.log"
|
||||
self.path = os.path.join(self.dir, self.name)
|
||||
|
||||
def test_writing(self) -> None:
|
||||
"""
|
||||
A daily log file can be written to like an ordinary log file.
|
||||
"""
|
||||
with contextlib.closing(RiggedDailyLogFile(self.name, self.dir)) as log:
|
||||
log.write("123")
|
||||
log.write("456")
|
||||
log.flush()
|
||||
log.write("7890")
|
||||
|
||||
with open(self.path) as f:
|
||||
self.assertEqual(f.read(), "1234567890")
|
||||
|
||||
def test_rotation(self) -> None:
|
||||
"""
|
||||
Daily log files rotate daily.
|
||||
"""
|
||||
log = RiggedDailyLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
days = [(self.path + "." + log.suffix(day * 86400)) for day in range(3)]
|
||||
|
||||
# test automatic rotation
|
||||
log._clock = 0.0 # 1970/01/01 00:00.00
|
||||
log.write("123")
|
||||
log._clock = 43200 # 1970/01/01 12:00.00
|
||||
log.write("4567890")
|
||||
log._clock = 86400 # 1970/01/02 00:00.00
|
||||
log.write("1" * 11)
|
||||
self.assertTrue(os.path.exists(days[0]))
|
||||
self.assertFalse(os.path.exists(days[1]))
|
||||
log._clock = 172800 # 1970/01/03 00:00.00
|
||||
log.write("")
|
||||
self.assertTrue(os.path.exists(days[0]))
|
||||
self.assertTrue(os.path.exists(days[1]))
|
||||
self.assertFalse(os.path.exists(days[2]))
|
||||
log._clock = 259199 # 1970/01/03 23:59.59
|
||||
log.write("3")
|
||||
self.assertFalse(os.path.exists(days[2]))
|
||||
|
||||
def test_getLog(self) -> None:
|
||||
"""
|
||||
Test retrieving log files with L{DailyLogFile.getLog}.
|
||||
"""
|
||||
data = ["1\n", "2\n", "3\n"]
|
||||
log = RiggedDailyLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
for d in data:
|
||||
log.write(d)
|
||||
log.flush()
|
||||
|
||||
# This returns the current log file.
|
||||
r = log.getLog(0.0)
|
||||
self.addCleanup(r.close)
|
||||
|
||||
self.assertEqual(data, r.readLines())
|
||||
|
||||
# We can't get this log, it doesn't exist yet.
|
||||
self.assertRaises(ValueError, log.getLog, 86400)
|
||||
|
||||
log._clock = 86401 # New day
|
||||
r.close()
|
||||
log.rotate()
|
||||
r = log.getLog(0) # We get the previous log
|
||||
self.addCleanup(r.close)
|
||||
self.assertEqual(data, r.readLines())
|
||||
|
||||
def test_rotateAlreadyExists(self) -> None:
|
||||
"""
|
||||
L{DailyLogFile.rotate} doesn't do anything if they new log file already
|
||||
exists on the disk.
|
||||
"""
|
||||
log = RiggedDailyLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
# Build a new file with the same name as the file which would be created
|
||||
# if the log file is to be rotated.
|
||||
newFilePath = f"{log.path}.{log.suffix(log.lastDate)}"
|
||||
with open(newFilePath, "w") as fp:
|
||||
fp.write("123")
|
||||
previousFile = log._file
|
||||
log.rotate()
|
||||
self.assertEqual(previousFile, log._file)
|
||||
|
||||
@skipIf(
|
||||
runtime.platform.isWindows(),
|
||||
"Making read-only directories on Windows is too complex for this "
|
||||
"test to reasonably do.",
|
||||
)
|
||||
def test_rotatePermissionDirectoryNotOk(self) -> None:
|
||||
"""
|
||||
L{DailyLogFile.rotate} doesn't do anything if the directory containing
|
||||
the log files can't be written to.
|
||||
"""
|
||||
log = logfile.DailyLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
os.chmod(log.directory, 0o444)
|
||||
# Restore permissions so tests can be cleaned up.
|
||||
self.addCleanup(os.chmod, log.directory, 0o755)
|
||||
previousFile = log._file
|
||||
log.rotate()
|
||||
self.assertEqual(previousFile, log._file)
|
||||
|
||||
def test_rotatePermissionFileNotOk(self) -> None:
|
||||
"""
|
||||
L{DailyLogFile.rotate} doesn't do anything if the log file can't be
|
||||
written to.
|
||||
"""
|
||||
log = logfile.DailyLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
os.chmod(log.path, 0o444)
|
||||
previousFile = log._file
|
||||
log.rotate()
|
||||
self.assertEqual(previousFile, log._file)
|
||||
|
||||
def test_toDate(self) -> None:
|
||||
"""
|
||||
Test that L{DailyLogFile.toDate} converts its timestamp argument to a
|
||||
time tuple (year, month, day).
|
||||
"""
|
||||
log = logfile.DailyLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
timestamp = time.mktime((2000, 1, 1, 0, 0, 0, 0, 0, 0))
|
||||
self.assertEqual((2000, 1, 1), log.toDate(timestamp))
|
||||
|
||||
def test_toDateDefaultToday(self) -> None:
|
||||
"""
|
||||
Test that L{DailyLogFile.toDate} returns today's date by default.
|
||||
|
||||
By mocking L{time.localtime}, we ensure that L{DailyLogFile.toDate}
|
||||
returns the first 3 values of L{time.localtime} which is the current
|
||||
date.
|
||||
|
||||
Note that we don't compare the *real* result of L{DailyLogFile.toDate}
|
||||
to the *real* current date, as there's a slight possibility that the
|
||||
date changes between the 2 function calls.
|
||||
"""
|
||||
|
||||
def mock_localtime(*args: object) -> list[int]:
|
||||
self.assertEqual((), args)
|
||||
return list(range(0, 9))
|
||||
|
||||
log = logfile.DailyLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
self.patch(time, "localtime", mock_localtime)
|
||||
logDate = log.toDate()
|
||||
self.assertEqual([0, 1, 2], logDate)
|
||||
|
||||
def test_toDateUsesArgumentsToMakeADate(self) -> None:
|
||||
"""
|
||||
Test that L{DailyLogFile.toDate} uses its arguments to create a new
|
||||
date.
|
||||
"""
|
||||
log = logfile.DailyLogFile(self.name, self.dir)
|
||||
self.addCleanup(log.close)
|
||||
|
||||
date = (2014, 10, 22)
|
||||
seconds = time.mktime(date + (0,) * 6)
|
||||
|
||||
logDate = log.toDate(seconds)
|
||||
self.assertEqual(date, logDate)
|
||||
464
.venv/lib/python3.12/site-packages/twisted/test/test_loopback.py
Normal file
464
.venv/lib/python3.12/site-packages/twisted/test/test_loopback.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test case for L{twisted.protocols.loopback}.
|
||||
"""
|
||||
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, interfaces, reactor
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IAddress, IPullProducer, IPushProducer
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.protocols import basic, loopback
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class SimpleProtocol(basic.LineReceiver):
|
||||
def __init__(self):
|
||||
self.conn = defer.Deferred()
|
||||
self.lines = []
|
||||
self.connLost = []
|
||||
|
||||
def connectionMade(self):
|
||||
self.conn.callback(None)
|
||||
|
||||
def lineReceived(self, line):
|
||||
self.lines.append(line)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.connLost.append(reason)
|
||||
|
||||
|
||||
class DoomProtocol(SimpleProtocol):
|
||||
i = 0
|
||||
|
||||
def lineReceived(self, line):
|
||||
self.i += 1
|
||||
if self.i < 4:
|
||||
# by this point we should have connection closed,
|
||||
# but just in case we didn't we won't ever send 'Hello 4'
|
||||
self.sendLine(b"Hello %d" % (self.i,))
|
||||
SimpleProtocol.lineReceived(self, line)
|
||||
if self.lines[-1] == b"Hello 3":
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
class LoopbackTestCaseMixin:
|
||||
def testRegularFunction(self):
|
||||
s = SimpleProtocol()
|
||||
c = SimpleProtocol()
|
||||
|
||||
def sendALine(result):
|
||||
s.sendLine(b"THIS IS LINE ONE!")
|
||||
s.transport.loseConnection()
|
||||
|
||||
s.conn.addCallback(sendALine)
|
||||
|
||||
def check(ignored):
|
||||
self.assertEqual(c.lines, [b"THIS IS LINE ONE!"])
|
||||
self.assertEqual(len(s.connLost), 1)
|
||||
self.assertEqual(len(c.connLost), 1)
|
||||
|
||||
d = defer.maybeDeferred(self.loopbackFunc, s, c)
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
def testSneakyHiddenDoom(self):
|
||||
s = DoomProtocol()
|
||||
c = DoomProtocol()
|
||||
|
||||
def sendALine(result):
|
||||
s.sendLine(b"DOOM LINE")
|
||||
|
||||
s.conn.addCallback(sendALine)
|
||||
|
||||
def check(ignored):
|
||||
self.assertEqual(s.lines, [b"Hello 1", b"Hello 2", b"Hello 3"])
|
||||
self.assertEqual(
|
||||
c.lines, [b"DOOM LINE", b"Hello 1", b"Hello 2", b"Hello 3"]
|
||||
)
|
||||
self.assertEqual(len(s.connLost), 1)
|
||||
self.assertEqual(len(c.connLost), 1)
|
||||
|
||||
d = defer.maybeDeferred(self.loopbackFunc, s, c)
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
|
||||
class LoopbackAsyncTests(LoopbackTestCaseMixin, unittest.TestCase):
|
||||
loopbackFunc = staticmethod(loopback.loopbackAsync)
|
||||
|
||||
def test_makeConnection(self):
|
||||
"""
|
||||
Test that the client and server protocol both have makeConnection
|
||||
invoked on them by loopbackAsync.
|
||||
"""
|
||||
|
||||
class TestProtocol(Protocol):
|
||||
transport = None
|
||||
|
||||
def makeConnection(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
server = TestProtocol()
|
||||
client = TestProtocol()
|
||||
loopback.loopbackAsync(server, client)
|
||||
self.assertIsNotNone(client.transport)
|
||||
self.assertIsNotNone(server.transport)
|
||||
|
||||
def _hostpeertest(self, get, testServer):
|
||||
"""
|
||||
Test one of the permutations of client/server host/peer.
|
||||
"""
|
||||
|
||||
class TestProtocol(Protocol):
|
||||
def makeConnection(self, transport):
|
||||
Protocol.makeConnection(self, transport)
|
||||
self.onConnection.callback(transport)
|
||||
|
||||
if testServer:
|
||||
server = TestProtocol()
|
||||
d = server.onConnection = Deferred()
|
||||
client = Protocol()
|
||||
else:
|
||||
server = Protocol()
|
||||
client = TestProtocol()
|
||||
d = client.onConnection = Deferred()
|
||||
|
||||
loopback.loopbackAsync(server, client)
|
||||
|
||||
def connected(transport):
|
||||
host = getattr(transport, get)()
|
||||
self.assertTrue(IAddress.providedBy(host))
|
||||
|
||||
return d.addCallback(connected)
|
||||
|
||||
def test_serverHost(self):
|
||||
"""
|
||||
Test that the server gets a transport with a properly functioning
|
||||
implementation of L{ITransport.getHost}.
|
||||
"""
|
||||
return self._hostpeertest("getHost", True)
|
||||
|
||||
def test_serverPeer(self):
|
||||
"""
|
||||
Like C{test_serverHost} but for L{ITransport.getPeer}
|
||||
"""
|
||||
return self._hostpeertest("getPeer", True)
|
||||
|
||||
def test_clientHost(self, get="getHost"):
|
||||
"""
|
||||
Test that the client gets a transport with a properly functioning
|
||||
implementation of L{ITransport.getHost}.
|
||||
"""
|
||||
return self._hostpeertest("getHost", False)
|
||||
|
||||
def test_clientPeer(self):
|
||||
"""
|
||||
Like C{test_clientHost} but for L{ITransport.getPeer}.
|
||||
"""
|
||||
return self._hostpeertest("getPeer", False)
|
||||
|
||||
def _greetingtest(self, write, testServer):
|
||||
"""
|
||||
Test one of the permutations of write/writeSequence client/server.
|
||||
|
||||
@param write: The name of the method to test, C{"write"} or
|
||||
C{"writeSequence"}.
|
||||
"""
|
||||
|
||||
class GreeteeProtocol(Protocol):
|
||||
bytes = b""
|
||||
|
||||
def dataReceived(self, bytes):
|
||||
self.bytes += bytes
|
||||
if self.bytes == b"bytes":
|
||||
self.received.callback(None)
|
||||
|
||||
class GreeterProtocol(Protocol):
|
||||
def connectionMade(self):
|
||||
if write == "write":
|
||||
self.transport.write(b"bytes")
|
||||
else:
|
||||
self.transport.writeSequence([b"byt", b"es"])
|
||||
|
||||
if testServer:
|
||||
server = GreeterProtocol()
|
||||
client = GreeteeProtocol()
|
||||
d = client.received = Deferred()
|
||||
else:
|
||||
server = GreeteeProtocol()
|
||||
d = server.received = Deferred()
|
||||
client = GreeterProtocol()
|
||||
|
||||
loopback.loopbackAsync(server, client)
|
||||
return d
|
||||
|
||||
def test_clientGreeting(self):
|
||||
"""
|
||||
Test that on a connection where the client speaks first, the server
|
||||
receives the bytes sent by the client.
|
||||
"""
|
||||
return self._greetingtest("write", False)
|
||||
|
||||
def test_clientGreetingSequence(self):
|
||||
"""
|
||||
Like C{test_clientGreeting}, but use C{writeSequence} instead of
|
||||
C{write} to issue the greeting.
|
||||
"""
|
||||
return self._greetingtest("writeSequence", False)
|
||||
|
||||
def test_serverGreeting(self, write="write"):
|
||||
"""
|
||||
Test that on a connection where the server speaks first, the client
|
||||
receives the bytes sent by the server.
|
||||
"""
|
||||
return self._greetingtest("write", True)
|
||||
|
||||
def test_serverGreetingSequence(self):
|
||||
"""
|
||||
Like C{test_serverGreeting}, but use C{writeSequence} instead of
|
||||
C{write} to issue the greeting.
|
||||
"""
|
||||
return self._greetingtest("writeSequence", True)
|
||||
|
||||
def _producertest(self, producerClass):
|
||||
toProduce = [b"%d" % (i,) for i in range(0, 10)]
|
||||
|
||||
class ProducingProtocol(Protocol):
|
||||
def connectionMade(self):
|
||||
self.producer = producerClass(list(toProduce))
|
||||
self.producer.start(self.transport)
|
||||
|
||||
class ReceivingProtocol(Protocol):
|
||||
bytes = b""
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.bytes += data
|
||||
if self.bytes == b"".join(toProduce):
|
||||
self.received.callback((client, server))
|
||||
|
||||
server = ProducingProtocol()
|
||||
client = ReceivingProtocol()
|
||||
client.received = Deferred()
|
||||
|
||||
loopback.loopbackAsync(server, client)
|
||||
return client.received
|
||||
|
||||
def test_pushProducer(self):
|
||||
"""
|
||||
Test a push producer registered against a loopback transport.
|
||||
"""
|
||||
|
||||
@implementer(IPushProducer)
|
||||
class PushProducer:
|
||||
resumed = False
|
||||
|
||||
def __init__(self, toProduce):
|
||||
self.toProduce = toProduce
|
||||
|
||||
def resumeProducing(self):
|
||||
self.resumed = True
|
||||
|
||||
def start(self, consumer):
|
||||
self.consumer = consumer
|
||||
consumer.registerProducer(self, True)
|
||||
self._produceAndSchedule()
|
||||
|
||||
def _produceAndSchedule(self):
|
||||
if self.toProduce:
|
||||
self.consumer.write(self.toProduce.pop(0))
|
||||
reactor.callLater(0, self._produceAndSchedule)
|
||||
else:
|
||||
self.consumer.unregisterProducer()
|
||||
|
||||
d = self._producertest(PushProducer)
|
||||
|
||||
def finished(results):
|
||||
(client, server) = results
|
||||
self.assertFalse(
|
||||
server.producer.resumed,
|
||||
"Streaming producer should not have been resumed.",
|
||||
)
|
||||
|
||||
d.addCallback(finished)
|
||||
return d
|
||||
|
||||
def test_pullProducer(self):
|
||||
"""
|
||||
Test a pull producer registered against a loopback transport.
|
||||
"""
|
||||
|
||||
@implementer(IPullProducer)
|
||||
class PullProducer:
|
||||
def __init__(self, toProduce):
|
||||
self.toProduce = toProduce
|
||||
|
||||
def start(self, consumer):
|
||||
self.consumer = consumer
|
||||
self.consumer.registerProducer(self, False)
|
||||
|
||||
def resumeProducing(self):
|
||||
self.consumer.write(self.toProduce.pop(0))
|
||||
if not self.toProduce:
|
||||
self.consumer.unregisterProducer()
|
||||
|
||||
return self._producertest(PullProducer)
|
||||
|
||||
def test_writeNotReentrant(self):
|
||||
"""
|
||||
L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
|
||||
method while that protocol's transport's C{write} method is higher up
|
||||
on the stack.
|
||||
"""
|
||||
|
||||
class Server(Protocol):
|
||||
def dataReceived(self, bytes):
|
||||
self.transport.write(b"bytes")
|
||||
|
||||
class Client(Protocol):
|
||||
ready = False
|
||||
|
||||
def connectionMade(self):
|
||||
reactor.callLater(0, self.go)
|
||||
|
||||
def go(self):
|
||||
self.transport.write(b"foo")
|
||||
self.ready = True
|
||||
|
||||
def dataReceived(self, bytes):
|
||||
self.wasReady = self.ready
|
||||
self.transport.loseConnection()
|
||||
|
||||
server = Server()
|
||||
client = Client()
|
||||
d = loopback.loopbackAsync(client, server)
|
||||
|
||||
def cbFinished(ignored):
|
||||
self.assertTrue(client.wasReady)
|
||||
|
||||
d.addCallback(cbFinished)
|
||||
return d
|
||||
|
||||
def test_pumpPolicy(self):
|
||||
"""
|
||||
The callable passed as the value for the C{pumpPolicy} parameter to
|
||||
L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
|
||||
and a protocol to which they should be delivered.
|
||||
"""
|
||||
pumpCalls = []
|
||||
|
||||
def dummyPolicy(queue, target):
|
||||
bytes = []
|
||||
while queue:
|
||||
bytes.append(queue.get())
|
||||
pumpCalls.append((target, bytes))
|
||||
|
||||
client = Protocol()
|
||||
server = Protocol()
|
||||
|
||||
finished = loopback.loopbackAsync(server, client, dummyPolicy)
|
||||
self.assertEqual(pumpCalls, [])
|
||||
|
||||
client.transport.write(b"foo")
|
||||
client.transport.write(b"bar")
|
||||
server.transport.write(b"baz")
|
||||
server.transport.write(b"quux")
|
||||
server.transport.loseConnection()
|
||||
|
||||
def cbComplete(ignored):
|
||||
self.assertEqual(
|
||||
pumpCalls,
|
||||
# The order here is somewhat arbitrary. The implementation
|
||||
# happens to always deliver data to the client first.
|
||||
[(client, [b"baz", b"quux", None]), (server, [b"foo", b"bar"])],
|
||||
)
|
||||
|
||||
finished.addCallback(cbComplete)
|
||||
return finished
|
||||
|
||||
def test_identityPumpPolicy(self):
|
||||
"""
|
||||
L{identityPumpPolicy} is a pump policy which calls the target's
|
||||
C{dataReceived} method one for each string in the queue passed to it.
|
||||
"""
|
||||
bytes = []
|
||||
client = Protocol()
|
||||
client.dataReceived = bytes.append
|
||||
queue = loopback._LoopbackQueue()
|
||||
queue.put(b"foo")
|
||||
queue.put(b"bar")
|
||||
queue.put(None)
|
||||
|
||||
loopback.identityPumpPolicy(queue, client)
|
||||
|
||||
self.assertEqual(bytes, [b"foo", b"bar"])
|
||||
|
||||
def test_collapsingPumpPolicy(self):
|
||||
"""
|
||||
L{collapsingPumpPolicy} is a pump policy which calls the target's
|
||||
C{dataReceived} only once with all of the strings in the queue passed
|
||||
to it joined together.
|
||||
"""
|
||||
bytes = []
|
||||
client = Protocol()
|
||||
client.dataReceived = bytes.append
|
||||
queue = loopback._LoopbackQueue()
|
||||
queue.put(b"foo")
|
||||
queue.put(b"bar")
|
||||
queue.put(None)
|
||||
|
||||
loopback.collapsingPumpPolicy(queue, client)
|
||||
|
||||
self.assertEqual(bytes, [b"foobar"])
|
||||
|
||||
|
||||
class LoopbackTCPTests(LoopbackTestCaseMixin, unittest.TestCase):
|
||||
loopbackFunc = staticmethod(loopback.loopbackTCP)
|
||||
|
||||
|
||||
class LoopbackUNIXTests(LoopbackTestCaseMixin, unittest.TestCase):
|
||||
loopbackFunc = staticmethod(loopback.loopbackUNIX)
|
||||
|
||||
if interfaces.IReactorUNIX(reactor, None) is None:
|
||||
skip = "Current reactor does not support UNIX sockets"
|
||||
|
||||
|
||||
class LoopbackRelayTest(unittest.TestCase):
|
||||
"""
|
||||
Test for L{twisted.protocols.loopback.LoopbackRelay}
|
||||
"""
|
||||
|
||||
class Receiver(Protocol):
|
||||
"""
|
||||
Simple Receiver class used for testing LoopbackRelay
|
||||
"""
|
||||
|
||||
data = b""
|
||||
|
||||
def dataReceived(self, data):
|
||||
"Accumulate received data for verification"
|
||||
self.data += data
|
||||
|
||||
def test_write(self):
|
||||
"Test to verify that the write function works as expected"
|
||||
receiver = self.Receiver()
|
||||
relay = loopback.LoopbackRelay(receiver)
|
||||
relay.write(b"abc")
|
||||
relay.write(b"def")
|
||||
self.assertEqual(receiver.data, b"")
|
||||
relay.clearBuffer()
|
||||
self.assertEqual(receiver.data, b"abcdef")
|
||||
|
||||
def test_writeSequence(self):
|
||||
"Test to verify that the writeSequence function works as expected"
|
||||
receiver = self.Receiver()
|
||||
relay = loopback.LoopbackRelay(receiver)
|
||||
relay.writeSequence([b"The ", b"quick ", b"brown ", b"fox "])
|
||||
relay.writeSequence([b"jumps ", b"over ", b"the lazy dog"])
|
||||
self.assertEqual(receiver.data, b"")
|
||||
relay.clearBuffer()
|
||||
self.assertEqual(receiver.data, b"The quick brown fox jumps over the lazy dog")
|
||||
74
.venv/lib/python3.12/site-packages/twisted/test/test_main.py
Normal file
74
.venv/lib/python3.12/site-packages/twisted/test/test_main.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test that twisted scripts can be invoked as modules.
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
from io import StringIO
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.test.test_process import Accumulator
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class MainTests(TestCase):
|
||||
"""Test that twisted scripts can be invoked as modules."""
|
||||
|
||||
def test_twisted(self):
|
||||
"""Invoking python -m twisted should execute twist."""
|
||||
cmd = sys.executable
|
||||
p = Accumulator()
|
||||
d = p.endedDeferred = defer.Deferred()
|
||||
reactor.spawnProcess(p, cmd, [cmd, "-m", "twisted", "--help"], env=None)
|
||||
p.transport.closeStdin()
|
||||
|
||||
# Fix up our sys args to match the command we issued
|
||||
from twisted import __main__
|
||||
|
||||
self.patch(sys, "argv", [__main__.__file__, "--help"])
|
||||
|
||||
def processEnded(ign):
|
||||
f = p.outF
|
||||
output = f.getvalue()
|
||||
|
||||
self.assertTrue(
|
||||
b"-m twisted [options] plugin [plugin_options]" in output, output
|
||||
)
|
||||
|
||||
return d.addCallback(processEnded)
|
||||
|
||||
def test_trial(self):
|
||||
"""Invoking python -m twisted.trial should execute trial."""
|
||||
cmd = sys.executable
|
||||
p = Accumulator()
|
||||
d = p.endedDeferred = defer.Deferred()
|
||||
reactor.spawnProcess(p, cmd, [cmd, "-m", "twisted.trial", "--help"], env=None)
|
||||
p.transport.closeStdin()
|
||||
|
||||
# Fix up our sys args to match the command we issued
|
||||
from twisted.trial import __main__
|
||||
|
||||
self.patch(sys, "argv", [__main__.__file__, "--help"])
|
||||
|
||||
def processEnded(ign):
|
||||
f = p.outF
|
||||
output = f.getvalue()
|
||||
|
||||
self.assertTrue(b"-j, --jobs= " in output, output)
|
||||
|
||||
return d.addCallback(processEnded)
|
||||
|
||||
def test_twisted_import(self):
|
||||
"""Importing twisted.__main__ does not execute twist."""
|
||||
output = StringIO()
|
||||
monkey = self.patch(sys, "stdout", output)
|
||||
|
||||
import twisted.__main__
|
||||
|
||||
self.assertTrue(twisted.__main__) # Appease pyflakes
|
||||
|
||||
monkey.restore()
|
||||
self.assertEqual(output.getvalue(), "")
|
||||
714
.venv/lib/python3.12/site-packages/twisted/test/test_memcache.py
Normal file
714
.venv/lib/python3.12/site-packages/twisted/test/test_memcache.py
Normal file
@@ -0,0 +1,714 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test the memcache client protocol.
|
||||
"""
|
||||
|
||||
|
||||
from twisted.internet.defer import Deferred, DeferredList, TimeoutError, gatherResults
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.testing import StringTransportWithDisconnection
|
||||
from twisted.protocols.memcache import (
|
||||
ClientError,
|
||||
MemCacheProtocol,
|
||||
NoSuchCommand,
|
||||
ServerError,
|
||||
)
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class CommandMixin:
|
||||
"""
|
||||
Setup and tests for basic invocation of L{MemCacheProtocol} commands.
|
||||
"""
|
||||
|
||||
def _test(self, d, send, recv, result):
|
||||
"""
|
||||
Helper test method to test the resulting C{Deferred} of a
|
||||
L{MemCacheProtocol} command.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def test_get(self):
|
||||
"""
|
||||
L{MemCacheProtocol.get} returns a L{Deferred} which is called back with
|
||||
the value and the flag associated with the given key if the server
|
||||
returns a successful result.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.get(b"foo"),
|
||||
b"get foo\r\n",
|
||||
b"VALUE foo 0 3\r\nbar\r\nEND\r\n",
|
||||
(0, b"bar"),
|
||||
)
|
||||
|
||||
def test_emptyGet(self):
|
||||
"""
|
||||
Test getting a non-available key: it succeeds but return L{None} as
|
||||
value and C{0} as flag.
|
||||
"""
|
||||
return self._test(self.proto.get(b"foo"), b"get foo\r\n", b"END\r\n", (0, None))
|
||||
|
||||
def test_getMultiple(self):
|
||||
"""
|
||||
L{MemCacheProtocol.getMultiple} returns a L{Deferred} which is called
|
||||
back with a dictionary of flag, value for each given key.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple([b"foo", b"cow"]),
|
||||
b"get foo cow\r\n",
|
||||
b"VALUE foo 0 3\r\nbar\r\nVALUE cow 0 7\r\nchicken\r\nEND\r\n",
|
||||
{b"cow": (0, b"chicken"), b"foo": (0, b"bar")},
|
||||
)
|
||||
|
||||
def test_getMultipleWithEmpty(self):
|
||||
"""
|
||||
When L{MemCacheProtocol.getMultiple} is called with non-available keys,
|
||||
the corresponding tuples are (0, None).
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple([b"foo", b"cow"]),
|
||||
b"get foo cow\r\n",
|
||||
b"VALUE cow 1 3\r\nbar\r\nEND\r\n",
|
||||
{b"cow": (1, b"bar"), b"foo": (0, None)},
|
||||
)
|
||||
|
||||
def test_set(self):
|
||||
"""
|
||||
L{MemCacheProtocol.set} returns a L{Deferred} which is called back with
|
||||
C{True} when the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.set(b"foo", b"bar"),
|
||||
b"set foo 0 0 3\r\nbar\r\n",
|
||||
b"STORED\r\n",
|
||||
True,
|
||||
)
|
||||
|
||||
def test_add(self):
|
||||
"""
|
||||
L{MemCacheProtocol.add} returns a L{Deferred} which is called back with
|
||||
C{True} when the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.add(b"foo", b"bar"),
|
||||
b"add foo 0 0 3\r\nbar\r\n",
|
||||
b"STORED\r\n",
|
||||
True,
|
||||
)
|
||||
|
||||
def test_replace(self):
|
||||
"""
|
||||
L{MemCacheProtocol.replace} returns a L{Deferred} which is called back
|
||||
with C{True} when the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.replace(b"foo", b"bar"),
|
||||
b"replace foo 0 0 3\r\nbar\r\n",
|
||||
b"STORED\r\n",
|
||||
True,
|
||||
)
|
||||
|
||||
def test_errorAdd(self):
|
||||
"""
|
||||
Test an erroneous add: if a L{MemCacheProtocol.add} is called but the
|
||||
key already exists on the server, it returns a B{NOT STORED} answer,
|
||||
which calls back the resulting L{Deferred} with C{False}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.add(b"foo", b"bar"),
|
||||
b"add foo 0 0 3\r\nbar\r\n",
|
||||
b"NOT STORED\r\n",
|
||||
False,
|
||||
)
|
||||
|
||||
def test_errorReplace(self):
|
||||
"""
|
||||
Test an erroneous replace: if a L{MemCacheProtocol.replace} is called
|
||||
but the key doesn't exist on the server, it returns a B{NOT STORED}
|
||||
answer, which calls back the resulting L{Deferred} with C{False}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.replace(b"foo", b"bar"),
|
||||
b"replace foo 0 0 3\r\nbar\r\n",
|
||||
b"NOT STORED\r\n",
|
||||
False,
|
||||
)
|
||||
|
||||
def test_delete(self):
|
||||
"""
|
||||
L{MemCacheProtocol.delete} returns a L{Deferred} which is called back
|
||||
with C{True} when the server notifies a success.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.delete(b"bar"), b"delete bar\r\n", b"DELETED\r\n", True
|
||||
)
|
||||
|
||||
def test_errorDelete(self):
|
||||
"""
|
||||
Test an error during a delete: if key doesn't exist on the server, it
|
||||
returns a B{NOT FOUND} answer which calls back the resulting
|
||||
L{Deferred} with C{False}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.delete(b"bar"), b"delete bar\r\n", b"NOT FOUND\r\n", False
|
||||
)
|
||||
|
||||
def test_increment(self):
|
||||
"""
|
||||
Test incrementing a variable: L{MemCacheProtocol.increment} returns a
|
||||
L{Deferred} which is called back with the incremented value of the
|
||||
given key.
|
||||
"""
|
||||
return self._test(self.proto.increment(b"foo"), b"incr foo 1\r\n", b"4\r\n", 4)
|
||||
|
||||
def test_decrement(self):
|
||||
"""
|
||||
Test decrementing a variable: L{MemCacheProtocol.decrement} returns a
|
||||
L{Deferred} which is called back with the decremented value of the
|
||||
given key.
|
||||
"""
|
||||
return self._test(self.proto.decrement(b"foo"), b"decr foo 1\r\n", b"5\r\n", 5)
|
||||
|
||||
def test_incrementVal(self):
|
||||
"""
|
||||
L{MemCacheProtocol.increment} takes an optional argument C{value} which
|
||||
replaces the default value of 1 when specified.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.increment(b"foo", 8), b"incr foo 8\r\n", b"4\r\n", 4
|
||||
)
|
||||
|
||||
def test_decrementVal(self):
|
||||
"""
|
||||
L{MemCacheProtocol.decrement} takes an optional argument C{value} which
|
||||
replaces the default value of 1 when specified.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.decrement(b"foo", 3), b"decr foo 3\r\n", b"5\r\n", 5
|
||||
)
|
||||
|
||||
def test_stats(self):
|
||||
"""
|
||||
Test retrieving server statistics via the L{MemCacheProtocol.stats}
|
||||
command: it parses the data sent by the server and calls back the
|
||||
resulting L{Deferred} with a dictionary of the received statistics.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.stats(),
|
||||
b"stats\r\n",
|
||||
b"STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
|
||||
{b"foo": b"bar", b"egg": b"spam"},
|
||||
)
|
||||
|
||||
def test_statsWithArgument(self):
|
||||
"""
|
||||
L{MemCacheProtocol.stats} takes an optional C{bytes} argument which,
|
||||
if specified, is sent along with the I{STAT} command. The I{STAT}
|
||||
responses from the server are parsed as key/value pairs and returned
|
||||
as a C{dict} (as in the case where the argument is not specified).
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.stats(b"blah"),
|
||||
b"stats blah\r\n",
|
||||
b"STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
|
||||
{b"foo": b"bar", b"egg": b"spam"},
|
||||
)
|
||||
|
||||
def test_version(self):
|
||||
"""
|
||||
Test version retrieval via the L{MemCacheProtocol.version} command: it
|
||||
returns a L{Deferred} which is called back with the version sent by the
|
||||
server.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.version(), b"version\r\n", b"VERSION 1.1\r\n", b"1.1"
|
||||
)
|
||||
|
||||
def test_flushAll(self):
|
||||
"""
|
||||
L{MemCacheProtocol.flushAll} returns a L{Deferred} which is called back
|
||||
with C{True} if the server acknowledges success.
|
||||
"""
|
||||
return self._test(self.proto.flushAll(), b"flush_all\r\n", b"OK\r\n", True)
|
||||
|
||||
|
||||
class MemCacheTests(CommandMixin, TestCase):
|
||||
"""
|
||||
Test client protocol class L{MemCacheProtocol}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a memcache client, connect it to a string protocol, and make it
|
||||
use a deterministic clock.
|
||||
"""
|
||||
self.proto = MemCacheProtocol()
|
||||
self.clock = Clock()
|
||||
self.proto.callLater = self.clock.callLater
|
||||
self.transport = StringTransportWithDisconnection()
|
||||
self.transport.protocol = self.proto
|
||||
self.proto.makeConnection(self.transport)
|
||||
|
||||
def _test(self, d, send, recv, result):
|
||||
"""
|
||||
Implementation of C{_test} which checks that the command sends C{send}
|
||||
data, and that upon reception of C{recv} the result is C{result}.
|
||||
|
||||
@param d: the resulting deferred from the memcache command.
|
||||
@type d: C{Deferred}
|
||||
|
||||
@param send: the expected data to be sent.
|
||||
@type send: C{bytes}
|
||||
|
||||
@param recv: the data to simulate as reception.
|
||||
@type recv: C{bytes}
|
||||
|
||||
@param result: the expected result.
|
||||
@type result: C{any}
|
||||
"""
|
||||
|
||||
def cb(res):
|
||||
self.assertEqual(res, result)
|
||||
|
||||
self.assertEqual(self.transport.value(), send)
|
||||
d.addCallback(cb)
|
||||
self.proto.dataReceived(recv)
|
||||
return d
|
||||
|
||||
def test_invalidGetResponse(self):
|
||||
"""
|
||||
If the value returned doesn't match the expected key of the current
|
||||
C{get} command, an error is raised in L{MemCacheProtocol.dataReceived}.
|
||||
"""
|
||||
self.proto.get(b"foo")
|
||||
self.assertRaises(
|
||||
RuntimeError,
|
||||
self.proto.dataReceived,
|
||||
b"VALUE bar 0 7\r\nspamegg\r\nEND\r\n",
|
||||
)
|
||||
|
||||
def test_invalidMultipleGetResponse(self):
|
||||
"""
|
||||
If the value returned doesn't match one the expected keys of the
|
||||
current multiple C{get} command, an error is raised error in
|
||||
L{MemCacheProtocol.dataReceived}.
|
||||
"""
|
||||
self.proto.getMultiple([b"foo", b"bar"])
|
||||
self.assertRaises(
|
||||
RuntimeError,
|
||||
self.proto.dataReceived,
|
||||
b"VALUE egg 0 7\r\nspamegg\r\nEND\r\n",
|
||||
)
|
||||
|
||||
def test_invalidEndResponse(self):
|
||||
"""
|
||||
If an END is received in response to an operation that isn't C{get},
|
||||
C{gets}, or C{stats}, an error is raised in
|
||||
L{MemCacheProtocol.dataReceived}.
|
||||
"""
|
||||
self.proto.set(b"key", b"value")
|
||||
self.assertRaises(RuntimeError, self.proto.dataReceived, b"END\r\n")
|
||||
|
||||
def test_timeOut(self):
|
||||
"""
|
||||
Test the timeout on outgoing requests: when timeout is detected, all
|
||||
current commands fail with a L{TimeoutError}, and the connection is
|
||||
closed.
|
||||
"""
|
||||
d1 = self.proto.get(b"foo")
|
||||
d2 = self.proto.get(b"bar")
|
||||
d3 = Deferred()
|
||||
self.proto.connectionLost = d3.callback
|
||||
|
||||
self.clock.advance(self.proto.persistentTimeOut)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
self.assertFailure(d2, TimeoutError)
|
||||
|
||||
def checkMessage(error):
|
||||
self.assertEqual(str(error), "Connection timeout")
|
||||
|
||||
d1.addCallback(checkMessage)
|
||||
self.assertFailure(d3, ConnectionDone)
|
||||
return gatherResults([d1, d2, d3])
|
||||
|
||||
def test_timeoutRemoved(self):
|
||||
"""
|
||||
When a request gets a response, no pending timeout call remains around.
|
||||
"""
|
||||
d = self.proto.get(b"foo")
|
||||
|
||||
self.clock.advance(self.proto.persistentTimeOut - 1)
|
||||
self.proto.dataReceived(b"VALUE foo 0 3\r\nbar\r\nEND\r\n")
|
||||
|
||||
def check(result):
|
||||
self.assertEqual(result, (0, b"bar"))
|
||||
self.assertEqual(len(self.clock.calls), 0)
|
||||
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
def test_timeOutRaw(self):
|
||||
"""
|
||||
Test the timeout when raw mode was started: the timeout is not reset
|
||||
until all the data has been received, so we can have a L{TimeoutError}
|
||||
when waiting for raw data.
|
||||
"""
|
||||
d1 = self.proto.get(b"foo")
|
||||
d2 = Deferred()
|
||||
self.proto.connectionLost = d2.callback
|
||||
|
||||
self.proto.dataReceived(b"VALUE foo 0 10\r\n12345")
|
||||
self.clock.advance(self.proto.persistentTimeOut)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
self.assertFailure(d2, ConnectionDone)
|
||||
return gatherResults([d1, d2])
|
||||
|
||||
def test_timeOutStat(self):
|
||||
"""
|
||||
Test the timeout when stat command has started: the timeout is not
|
||||
reset until the final B{END} is received.
|
||||
"""
|
||||
d1 = self.proto.stats()
|
||||
d2 = Deferred()
|
||||
self.proto.connectionLost = d2.callback
|
||||
|
||||
self.proto.dataReceived(b"STAT foo bar\r\n")
|
||||
self.clock.advance(self.proto.persistentTimeOut)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
self.assertFailure(d2, ConnectionDone)
|
||||
return gatherResults([d1, d2])
|
||||
|
||||
def test_timeoutPipelining(self):
|
||||
"""
|
||||
When two requests are sent, a timeout call remains around for the
|
||||
second request, and its timeout time is correct.
|
||||
"""
|
||||
d1 = self.proto.get(b"foo")
|
||||
d2 = self.proto.get(b"bar")
|
||||
d3 = Deferred()
|
||||
self.proto.connectionLost = d3.callback
|
||||
|
||||
self.clock.advance(self.proto.persistentTimeOut - 1)
|
||||
self.proto.dataReceived(b"VALUE foo 0 3\r\nbar\r\nEND\r\n")
|
||||
|
||||
def check(result):
|
||||
self.assertEqual(result, (0, b"bar"))
|
||||
self.assertEqual(len(self.clock.calls), 1)
|
||||
for i in range(self.proto.persistentTimeOut):
|
||||
self.clock.advance(1)
|
||||
return self.assertFailure(d2, TimeoutError).addCallback(checkTime)
|
||||
|
||||
def checkTime(ignored):
|
||||
# Check that the timeout happened C{self.proto.persistentTimeOut}
|
||||
# after the last response
|
||||
self.assertEqual(self.clock.seconds(), 2 * self.proto.persistentTimeOut - 1)
|
||||
|
||||
d1.addCallback(check)
|
||||
self.assertFailure(d3, ConnectionDone)
|
||||
return d1
|
||||
|
||||
def test_timeoutNotReset(self):
|
||||
"""
|
||||
Check that timeout is not resetted for every command, but keep the
|
||||
timeout from the first command without response.
|
||||
"""
|
||||
d1 = self.proto.get(b"foo")
|
||||
d3 = Deferred()
|
||||
self.proto.connectionLost = d3.callback
|
||||
|
||||
self.clock.advance(self.proto.persistentTimeOut - 1)
|
||||
d2 = self.proto.get(b"bar")
|
||||
self.clock.advance(1)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
self.assertFailure(d2, TimeoutError)
|
||||
self.assertFailure(d3, ConnectionDone)
|
||||
return gatherResults([d1, d2, d3])
|
||||
|
||||
def test_timeoutCleanDeferreds(self):
|
||||
"""
|
||||
C{timeoutConnection} cleans the list of commands that it fires with
|
||||
C{TimeoutError}: C{connectionLost} doesn't try to fire them again, but
|
||||
sets the disconnected state so that future commands fail with a
|
||||
C{RuntimeError}.
|
||||
"""
|
||||
d1 = self.proto.get(b"foo")
|
||||
self.clock.advance(self.proto.persistentTimeOut)
|
||||
self.assertFailure(d1, TimeoutError)
|
||||
d2 = self.proto.get(b"bar")
|
||||
self.assertFailure(d2, RuntimeError)
|
||||
return gatherResults([d1, d2])
|
||||
|
||||
def test_connectionLost(self):
|
||||
"""
|
||||
When disconnection occurs while commands are still outstanding, the
|
||||
commands fail.
|
||||
"""
|
||||
d1 = self.proto.get(b"foo")
|
||||
d2 = self.proto.get(b"bar")
|
||||
self.transport.loseConnection()
|
||||
done = DeferredList([d1, d2], consumeErrors=True)
|
||||
|
||||
def checkFailures(results):
|
||||
for success, result in results:
|
||||
self.assertFalse(success)
|
||||
result.trap(ConnectionDone)
|
||||
|
||||
return done.addCallback(checkFailures)
|
||||
|
||||
def test_tooLongKey(self):
|
||||
"""
|
||||
An error is raised when trying to use a too long key: the called
|
||||
command returns a L{Deferred} which fails with a L{ClientError}.
|
||||
"""
|
||||
d1 = self.assertFailure(self.proto.set(b"a" * 500, b"bar"), ClientError)
|
||||
d2 = self.assertFailure(self.proto.increment(b"a" * 500), ClientError)
|
||||
d3 = self.assertFailure(self.proto.get(b"a" * 500), ClientError)
|
||||
d4 = self.assertFailure(self.proto.append(b"a" * 500, b"bar"), ClientError)
|
||||
d5 = self.assertFailure(self.proto.prepend(b"a" * 500, b"bar"), ClientError)
|
||||
d6 = self.assertFailure(
|
||||
self.proto.getMultiple([b"foo", b"a" * 500]), ClientError
|
||||
)
|
||||
return gatherResults([d1, d2, d3, d4, d5, d6])
|
||||
|
||||
def test_invalidCommand(self):
|
||||
"""
|
||||
When an unknown command is sent directly (not through public API), the
|
||||
server answers with an B{ERROR} token, and the command fails with
|
||||
L{NoSuchCommand}.
|
||||
"""
|
||||
d = self.proto._set(b"egg", b"foo", b"bar", 0, 0, b"")
|
||||
self.assertEqual(self.transport.value(), b"egg foo 0 0 3\r\nbar\r\n")
|
||||
self.assertFailure(d, NoSuchCommand)
|
||||
self.proto.dataReceived(b"ERROR\r\n")
|
||||
return d
|
||||
|
||||
def test_clientError(self):
|
||||
"""
|
||||
Test the L{ClientError} error: when the server sends a B{CLIENT_ERROR}
|
||||
token, the originating command fails with L{ClientError}, and the error
|
||||
contains the text sent by the server.
|
||||
"""
|
||||
a = b"eggspamm"
|
||||
d = self.proto.set(b"foo", a)
|
||||
self.assertEqual(self.transport.value(), b"set foo 0 0 8\r\neggspamm\r\n")
|
||||
self.assertFailure(d, ClientError)
|
||||
|
||||
def check(err):
|
||||
self.assertEqual(str(err), repr(b"We don't like egg and spam"))
|
||||
|
||||
d.addCallback(check)
|
||||
self.proto.dataReceived(b"CLIENT_ERROR We don't like egg and spam\r\n")
|
||||
return d
|
||||
|
||||
def test_serverError(self):
|
||||
"""
|
||||
Test the L{ServerError} error: when the server sends a B{SERVER_ERROR}
|
||||
token, the originating command fails with L{ServerError}, and the error
|
||||
contains the text sent by the server.
|
||||
"""
|
||||
a = b"eggspamm"
|
||||
d = self.proto.set(b"foo", a)
|
||||
self.assertEqual(self.transport.value(), b"set foo 0 0 8\r\neggspamm\r\n")
|
||||
self.assertFailure(d, ServerError)
|
||||
|
||||
def check(err):
|
||||
self.assertEqual(str(err), repr(b"zomg"))
|
||||
|
||||
d.addCallback(check)
|
||||
self.proto.dataReceived(b"SERVER_ERROR zomg\r\n")
|
||||
return d
|
||||
|
||||
def test_unicodeKey(self):
|
||||
"""
|
||||
Using a non-string key as argument to commands raises an error.
|
||||
"""
|
||||
d1 = self.assertFailure(self.proto.set("foo", b"bar"), ClientError)
|
||||
d2 = self.assertFailure(self.proto.increment("egg"), ClientError)
|
||||
d3 = self.assertFailure(self.proto.get(1), ClientError)
|
||||
d4 = self.assertFailure(self.proto.delete("bar"), ClientError)
|
||||
d5 = self.assertFailure(self.proto.append("foo", b"bar"), ClientError)
|
||||
d6 = self.assertFailure(self.proto.prepend("foo", b"bar"), ClientError)
|
||||
d7 = self.assertFailure(self.proto.getMultiple([b"egg", 1]), ClientError)
|
||||
return gatherResults([d1, d2, d3, d4, d5, d6, d7])
|
||||
|
||||
def test_unicodeValue(self):
|
||||
"""
|
||||
Using a non-string value raises an error.
|
||||
"""
|
||||
return self.assertFailure(self.proto.set(b"foo", "bar"), ClientError)
|
||||
|
||||
def test_pipelining(self):
|
||||
"""
|
||||
Multiple requests can be sent subsequently to the server, and the
|
||||
protocol orders the responses correctly and dispatch to the
|
||||
corresponding client command.
|
||||
"""
|
||||
d1 = self.proto.get(b"foo")
|
||||
d1.addCallback(self.assertEqual, (0, b"bar"))
|
||||
d2 = self.proto.set(b"bar", b"spamspamspam")
|
||||
d2.addCallback(self.assertEqual, True)
|
||||
d3 = self.proto.get(b"egg")
|
||||
d3.addCallback(self.assertEqual, (0, b"spam"))
|
||||
self.assertEqual(
|
||||
self.transport.value(),
|
||||
b"get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n",
|
||||
)
|
||||
self.proto.dataReceived(
|
||||
b"VALUE foo 0 3\r\nbar\r\nEND\r\n"
|
||||
b"STORED\r\n"
|
||||
b"VALUE egg 0 4\r\nspam\r\nEND\r\n"
|
||||
)
|
||||
return gatherResults([d1, d2, d3])
|
||||
|
||||
def test_getInChunks(self):
|
||||
"""
|
||||
If the value retrieved by a C{get} arrive in chunks, the protocol
|
||||
is able to reconstruct it and to produce the good value.
|
||||
"""
|
||||
d = self.proto.get(b"foo")
|
||||
d.addCallback(self.assertEqual, (0, b"0123456789"))
|
||||
self.assertEqual(self.transport.value(), b"get foo\r\n")
|
||||
self.proto.dataReceived(b"VALUE foo 0 10\r\n0123456")
|
||||
self.proto.dataReceived(b"789")
|
||||
self.proto.dataReceived(b"\r\nEND")
|
||||
self.proto.dataReceived(b"\r\n")
|
||||
return d
|
||||
|
||||
def test_append(self):
|
||||
"""
|
||||
L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
|
||||
method: it returns a L{Deferred} which is called back with C{True} when
|
||||
the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.append(b"foo", b"bar"),
|
||||
b"append foo 0 0 3\r\nbar\r\n",
|
||||
b"STORED\r\n",
|
||||
True,
|
||||
)
|
||||
|
||||
def test_prepend(self):
|
||||
"""
|
||||
L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
|
||||
method: it returns a L{Deferred} which is called back with C{True} when
|
||||
the operation succeeds.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.prepend(b"foo", b"bar"),
|
||||
b"prepend foo 0 0 3\r\nbar\r\n",
|
||||
b"STORED\r\n",
|
||||
True,
|
||||
)
|
||||
|
||||
def test_gets(self):
|
||||
"""
|
||||
L{MemCacheProtocol.get} handles an additional cas result when
|
||||
C{withIdentifier} is C{True} and forward it in the resulting
|
||||
L{Deferred}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.get(b"foo", True),
|
||||
b"gets foo\r\n",
|
||||
b"VALUE foo 0 3 1234\r\nbar\r\nEND\r\n",
|
||||
(0, b"1234", b"bar"),
|
||||
)
|
||||
|
||||
def test_emptyGets(self):
|
||||
"""
|
||||
Test getting a non-available key with gets: it succeeds but return
|
||||
L{None} as value, C{0} as flag and an empty cas value.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.get(b"foo", True), b"gets foo\r\n", b"END\r\n", (0, b"", None)
|
||||
)
|
||||
|
||||
def test_getsMultiple(self):
|
||||
"""
|
||||
L{MemCacheProtocol.getMultiple} handles an additional cas field in the
|
||||
returned tuples if C{withIdentifier} is C{True}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple([b"foo", b"bar"], True),
|
||||
b"gets foo bar\r\n",
|
||||
b"VALUE foo 0 3 1234\r\negg\r\n" b"VALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
|
||||
{b"bar": (0, b"2345", b"spam"), b"foo": (0, b"1234", b"egg")},
|
||||
)
|
||||
|
||||
def test_getsMultipleIterableKeys(self):
|
||||
"""
|
||||
L{MemCacheProtocol.getMultiple} accepts any iterable of keys.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple(iter([b"foo", b"bar"]), True),
|
||||
b"gets foo bar\r\n",
|
||||
b"VALUE foo 0 3 1234\r\negg\r\n" b"VALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
|
||||
{b"bar": (0, b"2345", b"spam"), b"foo": (0, b"1234", b"egg")},
|
||||
)
|
||||
|
||||
def test_getsMultipleWithEmpty(self):
|
||||
"""
|
||||
When getting a non-available key with L{MemCacheProtocol.getMultiple}
|
||||
when C{withIdentifier} is C{True}, the other keys are retrieved
|
||||
correctly, and the non-available key gets a tuple of C{0} as flag,
|
||||
L{None} as value, and an empty cas value.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.getMultiple([b"foo", b"bar"], True),
|
||||
b"gets foo bar\r\n",
|
||||
b"VALUE foo 0 3 1234\r\negg\r\nEND\r\n",
|
||||
{b"bar": (0, b"", None), b"foo": (0, b"1234", b"egg")},
|
||||
)
|
||||
|
||||
def test_checkAndSet(self):
|
||||
"""
|
||||
L{MemCacheProtocol.checkAndSet} passes an additional cas identifier
|
||||
that the server handles to check if the data has to be updated.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.checkAndSet(b"foo", b"bar", cas=b"1234"),
|
||||
b"cas foo 0 0 3 1234\r\nbar\r\n",
|
||||
b"STORED\r\n",
|
||||
True,
|
||||
)
|
||||
|
||||
def test_casUnknowKey(self):
|
||||
"""
|
||||
When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the
|
||||
resulting L{Deferred} fires with C{False}.
|
||||
"""
|
||||
return self._test(
|
||||
self.proto.checkAndSet(b"foo", b"bar", cas=b"1234"),
|
||||
b"cas foo 0 0 3 1234\r\nbar\r\n",
|
||||
b"EXISTS\r\n",
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
class CommandFailureTests(CommandMixin, TestCase):
|
||||
"""
|
||||
Tests for correct failure of commands on a disconnected
|
||||
L{MemCacheProtocol}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a disconnected memcache client, using a deterministic clock.
|
||||
"""
|
||||
self.proto = MemCacheProtocol()
|
||||
self.clock = Clock()
|
||||
self.proto.callLater = self.clock.callLater
|
||||
self.transport = StringTransportWithDisconnection()
|
||||
self.transport.protocol = self.proto
|
||||
self.proto.makeConnection(self.transport)
|
||||
self.transport.loseConnection()
|
||||
|
||||
def _test(self, d, send, recv, result):
|
||||
"""
|
||||
Implementation of C{_test} which checks that the command fails with
|
||||
C{RuntimeError} because the transport is disconnected. All the
|
||||
parameters except C{d} are ignored.
|
||||
"""
|
||||
return self.assertFailure(d, RuntimeError)
|
||||
503
.venv/lib/python3.12/site-packages/twisted/test/test_modules.py
Normal file
503
.venv/lib/python3.12/site-packages/twisted/test/test_modules.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for twisted.python.modules, abstract access to imported or importable
|
||||
objects.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import compileall
|
||||
import itertools
|
||||
import sys
|
||||
import zipfile
|
||||
from importlib.abc import PathEntryFinder
|
||||
from types import ModuleType
|
||||
from typing import Any, Generator
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
import twisted
|
||||
from twisted.python import modules
|
||||
from twisted.python.compat import networkString
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.reflect import namedAny
|
||||
from twisted.python.test.modules_helpers import TwistedModulesMixin
|
||||
from twisted.python.test.test_zippath import zipit
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class _SupportsWalkModules(Protocol):
|
||||
def walkModules(
|
||||
self, importPackages: bool
|
||||
) -> Generator[modules.PythonModule, None, None]:
|
||||
...
|
||||
|
||||
|
||||
class TwistedModulesTestCase(TwistedModulesMixin, TestCase):
|
||||
"""
|
||||
Base class for L{modules} test cases.
|
||||
"""
|
||||
|
||||
def findByIteration(
|
||||
self,
|
||||
modname: str,
|
||||
where: _SupportsWalkModules = modules,
|
||||
importPackages: bool = False,
|
||||
) -> modules.PythonModule:
|
||||
"""
|
||||
You don't ever actually want to do this, so it's not in the public
|
||||
API, but sometimes we want to compare the result of an iterative call
|
||||
with a lookup call and make sure they're the same for test purposes.
|
||||
"""
|
||||
for modinfo in where.walkModules(importPackages=importPackages):
|
||||
if modinfo.name == modname:
|
||||
return modinfo
|
||||
self.fail(f"Unable to find module {modname!r} through iteration.")
|
||||
|
||||
|
||||
class BasicTests(TwistedModulesTestCase):
|
||||
def test_namespacedPackages(self) -> None:
|
||||
"""
|
||||
Duplicate packages are not yielded when iterating over namespace
|
||||
packages.
|
||||
"""
|
||||
# Force pkgutil to be loaded already, since the probe package being
|
||||
# created depends on it, and the replaceSysPath call below will make
|
||||
# pretty much everything unimportable.
|
||||
__import__("pkgutil")
|
||||
|
||||
namespaceBoilerplate = (
|
||||
b"import pkgutil; " b"__path__ = pkgutil.extend_path(__path__, __name__)"
|
||||
)
|
||||
|
||||
# Create two temporary directories with packages:
|
||||
#
|
||||
# entry:
|
||||
# test_package/
|
||||
# __init__.py
|
||||
# nested_package/
|
||||
# __init__.py
|
||||
# module.py
|
||||
#
|
||||
# anotherEntry:
|
||||
# test_package/
|
||||
# __init__.py
|
||||
# nested_package/
|
||||
# __init__.py
|
||||
# module2.py
|
||||
#
|
||||
# test_package and test_package.nested_package are namespace packages,
|
||||
# and when both of these are in sys.path, test_package.nested_package
|
||||
# should become a virtual package containing both "module" and
|
||||
# "module2"
|
||||
|
||||
entry = self.pathEntryWithOnePackage()
|
||||
testPackagePath = entry.child("test_package")
|
||||
testPackagePath.child("__init__.py").setContent(namespaceBoilerplate)
|
||||
|
||||
nestedEntry = testPackagePath.child("nested_package")
|
||||
nestedEntry.makedirs()
|
||||
nestedEntry.child("__init__.py").setContent(namespaceBoilerplate)
|
||||
nestedEntry.child("module.py").setContent(b"")
|
||||
|
||||
anotherEntry = self.pathEntryWithOnePackage()
|
||||
anotherPackagePath = anotherEntry.child("test_package")
|
||||
anotherPackagePath.child("__init__.py").setContent(namespaceBoilerplate)
|
||||
|
||||
anotherNestedEntry = anotherPackagePath.child("nested_package")
|
||||
anotherNestedEntry.makedirs()
|
||||
anotherNestedEntry.child("__init__.py").setContent(namespaceBoilerplate)
|
||||
anotherNestedEntry.child("module2.py").setContent(b"")
|
||||
|
||||
self.replaceSysPath([entry.path, anotherEntry.path])
|
||||
|
||||
module = modules.getModule("test_package")
|
||||
|
||||
# We have to use importPackages=True in order to resolve the namespace
|
||||
# packages, so we remove the imported packages from sys.modules after
|
||||
# walking
|
||||
try:
|
||||
walkedNames = [mod.name for mod in module.walkModules(importPackages=True)]
|
||||
finally:
|
||||
for module in list(sys.modules.keys()):
|
||||
if module.startswith("test_package"):
|
||||
del sys.modules[module]
|
||||
|
||||
expected = [
|
||||
"test_package",
|
||||
"test_package.nested_package",
|
||||
"test_package.nested_package.module",
|
||||
"test_package.nested_package.module2",
|
||||
]
|
||||
|
||||
self.assertEqual(walkedNames, expected)
|
||||
|
||||
def test_unimportablePackageGetItem(self) -> None:
|
||||
"""
|
||||
If a package has been explicitly forbidden from importing by setting a
|
||||
L{None} key in sys.modules under its name,
|
||||
L{modules.PythonPath.__getitem__} should still be able to retrieve an
|
||||
unloaded L{modules.PythonModule} for that package.
|
||||
"""
|
||||
shouldNotLoad: list[str] = []
|
||||
path = modules.PythonPath(
|
||||
sysPath=[self.pathEntryWithOnePackage().path],
|
||||
moduleLoader=shouldNotLoad.append,
|
||||
importerCache={},
|
||||
sysPathHooks={},
|
||||
moduleDict={"test_package": None},
|
||||
)
|
||||
self.assertEqual(shouldNotLoad, [])
|
||||
self.assertFalse(path["test_package"].isLoaded())
|
||||
|
||||
def test_unimportablePackageWalkModules(self) -> None:
|
||||
"""
|
||||
If a package has been explicitly forbidden from importing by setting a
|
||||
L{None} key in sys.modules under its name, L{modules.walkModules} should
|
||||
still be able to retrieve an unloaded L{modules.PythonModule} for that
|
||||
package.
|
||||
"""
|
||||
existentPath = self.pathEntryWithOnePackage()
|
||||
self.replaceSysPath([existentPath.path])
|
||||
self.replaceSysModules({"test_package": None}) # type: ignore[dict-item]
|
||||
|
||||
walked = list(modules.walkModules())
|
||||
self.assertEqual([m.name for m in walked], ["test_package"])
|
||||
self.assertFalse(walked[0].isLoaded())
|
||||
|
||||
def test_nonexistentPaths(self) -> None:
|
||||
"""
|
||||
Verify that L{modules.walkModules} ignores entries in sys.path which
|
||||
do not exist in the filesystem.
|
||||
"""
|
||||
existentPath = self.pathEntryWithOnePackage()
|
||||
|
||||
nonexistentPath = FilePath(self.mktemp())
|
||||
self.assertFalse(nonexistentPath.exists())
|
||||
|
||||
self.replaceSysPath([existentPath.path])
|
||||
|
||||
expected = [modules.getModule("test_package")]
|
||||
|
||||
beforeModules = list(modules.walkModules())
|
||||
sys.path.append(nonexistentPath.path)
|
||||
afterModules = list(modules.walkModules())
|
||||
|
||||
self.assertEqual(beforeModules, expected)
|
||||
self.assertEqual(afterModules, expected)
|
||||
|
||||
def test_nonDirectoryPaths(self) -> None:
|
||||
"""
|
||||
Verify that L{modules.walkModules} ignores entries in sys.path which
|
||||
refer to regular files in the filesystem.
|
||||
"""
|
||||
existentPath = self.pathEntryWithOnePackage()
|
||||
|
||||
nonDirectoryPath = FilePath(self.mktemp())
|
||||
self.assertFalse(nonDirectoryPath.exists())
|
||||
nonDirectoryPath.setContent(b"zip file or whatever\n")
|
||||
|
||||
self.replaceSysPath([existentPath.path])
|
||||
|
||||
beforeModules = list(modules.walkModules())
|
||||
sys.path.append(nonDirectoryPath.path)
|
||||
afterModules = list(modules.walkModules())
|
||||
|
||||
self.assertEqual(beforeModules, afterModules)
|
||||
|
||||
def test_twistedShowsUp(self) -> None:
|
||||
"""
|
||||
Scrounge around in the top-level module namespace and make sure that
|
||||
Twisted shows up, and that the module thusly obtained is the same as
|
||||
the module that we find when we look for it explicitly by name.
|
||||
"""
|
||||
self.assertEqual(modules.getModule("twisted"), self.findByIteration("twisted"))
|
||||
|
||||
def test_dottedNames(self) -> None:
|
||||
"""
|
||||
Verify that the walkModules APIs will give us back subpackages, not just
|
||||
subpackages.
|
||||
"""
|
||||
self.assertEqual(
|
||||
modules.getModule("twisted.python"),
|
||||
self.findByIteration("twisted.python", where=modules.getModule("twisted")),
|
||||
)
|
||||
|
||||
def test_onlyTopModules(self) -> None:
|
||||
"""
|
||||
Verify that the iterModules API will only return top-level modules and
|
||||
packages, not submodules or subpackages.
|
||||
"""
|
||||
for module in modules.iterModules():
|
||||
self.assertFalse(
|
||||
"." in module.name,
|
||||
"no nested modules should be returned from iterModules: %r"
|
||||
% (module.filePath),
|
||||
)
|
||||
|
||||
def test_loadPackagesAndModules(self) -> None:
|
||||
"""
|
||||
Verify that we can locate and load packages, modules, submodules, and
|
||||
subpackages.
|
||||
"""
|
||||
for n in ["os", "twisted", "twisted.python", "twisted.python.reflect"]:
|
||||
m = namedAny(n)
|
||||
self.failUnlessIdentical(modules.getModule(n).load(), m)
|
||||
self.failUnlessIdentical(self.findByIteration(n).load(), m)
|
||||
|
||||
def test_pathEntriesOnPath(self) -> None:
|
||||
"""
|
||||
Verify that path entries discovered via module loading are, in fact, on
|
||||
sys.path somewhere.
|
||||
"""
|
||||
for n in ["os", "twisted", "twisted.python", "twisted.python.reflect"]:
|
||||
self.failUnlessIn(modules.getModule(n).pathEntry.filePath.path, sys.path)
|
||||
|
||||
def test_alwaysPreferPy(self) -> None:
|
||||
"""
|
||||
Verify that .py files will always be preferred to .pyc files, regardless of
|
||||
directory listing order.
|
||||
"""
|
||||
mypath = FilePath(self.mktemp())
|
||||
mypath.createDirectory()
|
||||
pp = modules.PythonPath(sysPath=[mypath.path])
|
||||
originalSmartPath = pp._smartPath
|
||||
|
||||
def _evilSmartPath(pathName: str) -> Any:
|
||||
o = originalSmartPath(pathName)
|
||||
originalChildren = o.children
|
||||
|
||||
def evilChildren() -> Any:
|
||||
# normally this order is random; let's make sure it always
|
||||
# comes up .pyc-first.
|
||||
x = list(originalChildren())
|
||||
x.sort()
|
||||
x.reverse()
|
||||
return x
|
||||
|
||||
o.children = evilChildren
|
||||
return o
|
||||
|
||||
mypath.child("abcd.py").setContent(b"\n")
|
||||
compileall.compile_dir(mypath.path, quiet=True)
|
||||
# sanity check
|
||||
self.assertEqual(len(list(mypath.children())), 2)
|
||||
pp._smartPath = _evilSmartPath # type: ignore[method-assign]
|
||||
self.assertEqual(pp["abcd"].filePath, mypath.child("abcd.py"))
|
||||
|
||||
def test_packageMissingPath(self) -> None:
|
||||
"""
|
||||
A package can delete its __path__ for some reasons,
|
||||
C{modules.PythonPath} should be able to deal with it.
|
||||
"""
|
||||
mypath = FilePath(self.mktemp())
|
||||
mypath.createDirectory()
|
||||
pp = modules.PythonPath(sysPath=[mypath.path])
|
||||
subpath = mypath.child("abcd")
|
||||
subpath.createDirectory()
|
||||
subpath.child("__init__.py").setContent(b"del __path__\n")
|
||||
sys.path.append(mypath.path)
|
||||
__import__("abcd")
|
||||
try:
|
||||
l = list(pp.walkModules())
|
||||
self.assertEqual(len(l), 1)
|
||||
self.assertEqual(l[0].name, "abcd")
|
||||
finally:
|
||||
del sys.modules["abcd"]
|
||||
sys.path.remove(mypath.path)
|
||||
|
||||
|
||||
class PathModificationTests(TwistedModulesTestCase):
|
||||
"""
|
||||
These tests share setup/cleanup behavior of creating a dummy package and
|
||||
stuffing some code in it.
|
||||
"""
|
||||
|
||||
_serialnum = itertools.count() # used to generate serial numbers for
|
||||
# package names.
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.pathExtensionName = self.mktemp()
|
||||
self.pathExtension = FilePath(self.pathExtensionName)
|
||||
self.pathExtension.createDirectory()
|
||||
self.packageName = "pyspacetests%d" % (next(self._serialnum),)
|
||||
self.packagePath = self.pathExtension.child(self.packageName)
|
||||
self.packagePath.createDirectory()
|
||||
self.packagePath.child("__init__.py").setContent(b"")
|
||||
self.packagePath.child("a.py").setContent(b"")
|
||||
self.packagePath.child("b.py").setContent(b"")
|
||||
self.packagePath.child("c__init__.py").setContent(b"")
|
||||
self.pathSetUp = False
|
||||
|
||||
def _setupSysPath(self) -> None:
|
||||
assert not self.pathSetUp
|
||||
self.pathSetUp = True
|
||||
sys.path.append(self.pathExtensionName)
|
||||
|
||||
def _underUnderPathTest(self, doImport: bool = True) -> None:
|
||||
moddir2 = self.mktemp()
|
||||
fpmd = FilePath(moddir2)
|
||||
fpmd.createDirectory()
|
||||
fpmd.child("foozle.py").setContent(b"x = 123\n")
|
||||
self.packagePath.child("__init__.py").setContent(
|
||||
networkString(f"__path__.append({repr(moddir2)})\n")
|
||||
)
|
||||
# Cut here
|
||||
self._setupSysPath()
|
||||
modinfo = modules.getModule(self.packageName)
|
||||
self.assertEqual(
|
||||
self.findByIteration(
|
||||
self.packageName + ".foozle", modinfo, importPackages=doImport
|
||||
),
|
||||
modinfo["foozle"],
|
||||
)
|
||||
self.assertEqual(modinfo["foozle"].load().x, 123)
|
||||
|
||||
def test_underUnderPathAlreadyImported(self) -> None:
|
||||
"""
|
||||
Verify that iterModules will honor the __path__ of already-loaded packages.
|
||||
"""
|
||||
self._underUnderPathTest()
|
||||
|
||||
def _listModules(self) -> None:
|
||||
pkginfo = modules.getModule(self.packageName)
|
||||
nfni = [modinfo.name.split(".")[-1] for modinfo in pkginfo.iterModules()]
|
||||
nfni.sort()
|
||||
self.assertEqual(nfni, ["a", "b", "c__init__"])
|
||||
|
||||
def test_listingModules(self) -> None:
|
||||
"""
|
||||
Make sure the module list comes back as we expect from iterModules on a
|
||||
package, whether zipped or not.
|
||||
"""
|
||||
self._setupSysPath()
|
||||
self._listModules()
|
||||
|
||||
def test_listingModulesAlreadyImported(self) -> None:
|
||||
"""
|
||||
Make sure the module list comes back as we expect from iterModules on a
|
||||
package, whether zipped or not, even if the package has already been
|
||||
imported.
|
||||
"""
|
||||
self._setupSysPath()
|
||||
namedAny(self.packageName)
|
||||
self._listModules()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
# Intentionally using 'assert' here, this is not a test assertion, this
|
||||
# is just an "oh fuck what is going ON" assertion. -glyph
|
||||
if self.pathSetUp:
|
||||
HORK = "path cleanup failed: don't be surprised if other tests break"
|
||||
assert sys.path.pop() is self.pathExtensionName, HORK + ", 1"
|
||||
assert self.pathExtensionName not in sys.path, HORK + ", 2"
|
||||
|
||||
|
||||
class RebindingTests(PathModificationTests):
|
||||
"""
|
||||
These tests verify that the default path interrogation API works properly
|
||||
even when sys.path has been rebound to a different object.
|
||||
"""
|
||||
|
||||
def _setupSysPath(self) -> None:
|
||||
assert not self.pathSetUp
|
||||
self.pathSetUp = True
|
||||
self.savedSysPath = sys.path
|
||||
sys.path = sys.path[:]
|
||||
sys.path.append(self.pathExtensionName)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
Clean up sys.path by re-binding our original object.
|
||||
"""
|
||||
if self.pathSetUp:
|
||||
sys.path = self.savedSysPath
|
||||
|
||||
|
||||
class ZipPathModificationTests(PathModificationTests):
|
||||
def _setupSysPath(self) -> None:
|
||||
assert not self.pathSetUp
|
||||
zipit(self.pathExtensionName, self.pathExtensionName + ".zip")
|
||||
self.pathExtensionName += ".zip"
|
||||
assert zipfile.is_zipfile(self.pathExtensionName)
|
||||
PathModificationTests._setupSysPath(self)
|
||||
|
||||
|
||||
class PythonPathTests(TestCase):
|
||||
"""
|
||||
Tests for the class which provides the implementation for all of the
|
||||
public API of L{twisted.python.modules}, L{PythonPath}.
|
||||
"""
|
||||
|
||||
def test_unhandledImporter(self) -> None:
|
||||
"""
|
||||
Make sure that the behavior when encountering an unknown importer
|
||||
type is not catastrophic failure.
|
||||
"""
|
||||
|
||||
class SecretImporter:
|
||||
pass
|
||||
|
||||
def hook(name: object) -> SecretImporter:
|
||||
return SecretImporter()
|
||||
|
||||
syspath = ["example/path"]
|
||||
sysmodules: dict[str, ModuleType] = {}
|
||||
syshooks = [hook]
|
||||
syscache: dict[str, PathEntryFinder | None] = {}
|
||||
|
||||
def sysloader(name: object) -> None:
|
||||
return None
|
||||
|
||||
space = modules.PythonPath(syspath, sysmodules, syshooks, syscache, sysloader)
|
||||
entries = list(space.iterEntries())
|
||||
self.assertEqual(len(entries), 1)
|
||||
self.assertRaises(KeyError, lambda: entries[0]["module"])
|
||||
|
||||
def test_inconsistentImporterCache(self) -> None:
|
||||
"""
|
||||
If the path a module loaded with L{PythonPath.__getitem__} is not
|
||||
present in the path importer cache, a warning is emitted, but the
|
||||
L{PythonModule} is returned as usual.
|
||||
"""
|
||||
space = modules.PythonPath([], sys.modules, [], {})
|
||||
thisModule = space[__name__]
|
||||
warnings = self.flushWarnings([self.test_inconsistentImporterCache])
|
||||
self.assertEqual(warnings[0]["category"], UserWarning)
|
||||
self.assertEqual(
|
||||
warnings[0]["message"],
|
||||
FilePath(twisted.__file__).parent().dirname()
|
||||
+ " (for module "
|
||||
+ __name__
|
||||
+ ") not in path importer cache "
|
||||
"(PEP 302 violation - check your local configuration).",
|
||||
)
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(thisModule.name, __name__)
|
||||
|
||||
def test_containsModule(self) -> None:
|
||||
"""
|
||||
L{PythonPath} implements the C{in} operator so that when it is the
|
||||
right-hand argument and the name of a module which exists on that
|
||||
L{PythonPath} is the left-hand argument, the result is C{True}.
|
||||
"""
|
||||
thePath = modules.PythonPath()
|
||||
self.assertIn("os", thePath)
|
||||
|
||||
def test_doesntContainModule(self) -> None:
|
||||
"""
|
||||
L{PythonPath} implements the C{in} operator so that when it is the
|
||||
right-hand argument and the name of a module which does not exist on
|
||||
that L{PythonPath} is the left-hand argument, the result is C{False}.
|
||||
"""
|
||||
thePath = modules.PythonPath()
|
||||
self.assertNotIn("bogusModule", thePath)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BasicTests",
|
||||
"PathModificationTests",
|
||||
"RebindingTests",
|
||||
"ZipPathModificationTests",
|
||||
"PythonPathTests",
|
||||
]
|
||||
175
.venv/lib/python3.12/site-packages/twisted/test/test_monkey.py
Normal file
175
.venv/lib/python3.12/site-packages/twisted/test/test_monkey.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.monkey}.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import NoReturn
|
||||
|
||||
from twisted.python.monkey import MonkeyPatcher
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class TestObj:
|
||||
def __init__(self) -> None:
|
||||
self.foo = "foo value"
|
||||
self.bar = "bar value"
|
||||
self.baz = "baz value"
|
||||
|
||||
|
||||
class MonkeyPatcherTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{MonkeyPatcher} monkey-patching class.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.testObject = TestObj()
|
||||
self.originalObject = TestObj()
|
||||
self.monkeyPatcher = MonkeyPatcher()
|
||||
|
||||
def test_empty(self) -> None:
|
||||
"""
|
||||
A monkey patcher without patches shouldn't change a thing.
|
||||
"""
|
||||
self.monkeyPatcher.patch()
|
||||
|
||||
# We can't assert that all state is unchanged, but at least we can
|
||||
# check our test object.
|
||||
self.assertEqual(self.originalObject.foo, self.testObject.foo)
|
||||
self.assertEqual(self.originalObject.bar, self.testObject.bar)
|
||||
self.assertEqual(self.originalObject.baz, self.testObject.baz)
|
||||
|
||||
def test_constructWithPatches(self) -> None:
|
||||
"""
|
||||
Constructing a L{MonkeyPatcher} with patches should add all of the
|
||||
given patches to the patch list.
|
||||
"""
|
||||
patcher = MonkeyPatcher(
|
||||
(self.testObject, "foo", "haha"), (self.testObject, "bar", "hehe")
|
||||
)
|
||||
patcher.patch()
|
||||
self.assertEqual("haha", self.testObject.foo)
|
||||
self.assertEqual("hehe", self.testObject.bar)
|
||||
self.assertEqual(self.originalObject.baz, self.testObject.baz)
|
||||
|
||||
def test_patchExisting(self) -> None:
|
||||
"""
|
||||
Patching an attribute that exists sets it to the value defined in the
|
||||
patch.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, "foo", "haha")
|
||||
self.monkeyPatcher.patch()
|
||||
self.assertEqual(self.testObject.foo, "haha")
|
||||
|
||||
def test_patchNonExisting(self) -> None:
|
||||
"""
|
||||
Patching a non-existing attribute fails with an C{AttributeError}.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, "nowhere", "blow up please")
|
||||
self.assertRaises(AttributeError, self.monkeyPatcher.patch)
|
||||
|
||||
def test_patchAlreadyPatched(self) -> None:
|
||||
"""
|
||||
Adding a patch for an object and attribute that already have a patch
|
||||
overrides the existing patch.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, "foo", "blah")
|
||||
self.monkeyPatcher.addPatch(self.testObject, "foo", "BLAH")
|
||||
self.monkeyPatcher.patch()
|
||||
self.assertEqual(self.testObject.foo, "BLAH")
|
||||
self.monkeyPatcher.restore()
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
|
||||
def test_restoreTwiceIsANoOp(self) -> None:
|
||||
"""
|
||||
Restoring an already-restored monkey patch is a no-op.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, "foo", "blah")
|
||||
self.monkeyPatcher.patch()
|
||||
self.monkeyPatcher.restore()
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
self.monkeyPatcher.restore()
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
|
||||
def test_runWithPatchesDecoration(self) -> None:
|
||||
"""
|
||||
runWithPatches should run the given callable, passing in all arguments
|
||||
and keyword arguments, and return the return value of the callable.
|
||||
"""
|
||||
log: list[tuple[int, int, int | None]] = []
|
||||
|
||||
def f(a: int, b: int, c: int | None = None) -> str:
|
||||
log.append((a, b, c))
|
||||
return "foo"
|
||||
|
||||
result = self.monkeyPatcher.runWithPatches(f, 1, 2, c=10)
|
||||
self.assertEqual("foo", result)
|
||||
self.assertEqual([(1, 2, 10)], log)
|
||||
|
||||
def test_repeatedRunWithPatches(self) -> None:
|
||||
"""
|
||||
We should be able to call the same function with runWithPatches more
|
||||
than once. All patches should apply for each call.
|
||||
"""
|
||||
|
||||
def f() -> tuple[str, str, str]:
|
||||
return (self.testObject.foo, self.testObject.bar, self.testObject.baz)
|
||||
|
||||
self.monkeyPatcher.addPatch(self.testObject, "foo", "haha")
|
||||
result = self.monkeyPatcher.runWithPatches(f)
|
||||
self.assertEqual(
|
||||
("haha", self.originalObject.bar, self.originalObject.baz), result
|
||||
)
|
||||
result = self.monkeyPatcher.runWithPatches(f)
|
||||
self.assertEqual(
|
||||
("haha", self.originalObject.bar, self.originalObject.baz), result
|
||||
)
|
||||
|
||||
def test_runWithPatchesRestores(self) -> None:
|
||||
"""
|
||||
C{runWithPatches} should restore the original values after the function
|
||||
has executed.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, "foo", "haha")
|
||||
self.assertEqual(self.originalObject.foo, self.testObject.foo)
|
||||
self.monkeyPatcher.runWithPatches(lambda: None)
|
||||
self.assertEqual(self.originalObject.foo, self.testObject.foo)
|
||||
|
||||
def test_runWithPatchesRestoresOnException(self) -> None:
|
||||
"""
|
||||
Test runWithPatches restores the original values even when the function
|
||||
raises an exception.
|
||||
"""
|
||||
|
||||
def _() -> NoReturn:
|
||||
self.assertEqual(self.testObject.foo, "haha")
|
||||
self.assertEqual(self.testObject.bar, "blahblah")
|
||||
raise RuntimeError("Something went wrong!")
|
||||
|
||||
self.monkeyPatcher.addPatch(self.testObject, "foo", "haha")
|
||||
self.monkeyPatcher.addPatch(self.testObject, "bar", "blahblah")
|
||||
|
||||
self.assertRaises(RuntimeError, self.monkeyPatcher.runWithPatches, _)
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
self.assertEqual(self.testObject.bar, self.originalObject.bar)
|
||||
|
||||
def test_contextManager(self) -> None:
|
||||
"""
|
||||
L{MonkeyPatcher} is a context manager that applies its patches on
|
||||
entry and restore original values on exit.
|
||||
"""
|
||||
self.monkeyPatcher.addPatch(self.testObject, "foo", "patched value")
|
||||
with self.monkeyPatcher:
|
||||
self.assertEqual(self.testObject.foo, "patched value")
|
||||
self.assertEqual(self.testObject.foo, self.originalObject.foo)
|
||||
|
||||
def test_contextManagerPropagatesExceptions(self) -> None:
|
||||
"""
|
||||
Exceptions propagate through the L{MonkeyPatcher} context-manager
|
||||
exit method.
|
||||
"""
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.monkeyPatcher:
|
||||
raise RuntimeError("something")
|
||||
2008
.venv/lib/python3.12/site-packages/twisted/test/test_paths.py
Normal file
2008
.venv/lib/python3.12/site-packages/twisted/test/test_paths.py
Normal file
File diff suppressed because it is too large
Load Diff
382
.venv/lib/python3.12/site-packages/twisted/test/test_pcp.py
Normal file
382
.venv/lib/python3.12/site-packages/twisted/test/test_pcp.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# -*- Python -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
__version__ = "$Revision: 1.5 $"[11:-2]
|
||||
|
||||
from twisted.protocols import pcp
|
||||
from twisted.trial import unittest
|
||||
|
||||
# Goal:
|
||||
|
||||
# Take a Protocol instance. Own all outgoing data - anything that
|
||||
# would go to p.transport.write. Own all incoming data - anything
|
||||
# that comes to p.dataReceived.
|
||||
|
||||
# I need:
|
||||
# Something with the AbstractFileDescriptor interface.
|
||||
# That is:
|
||||
# - acts as a Transport
|
||||
# - has a method write()
|
||||
# - which buffers
|
||||
# - acts as a Consumer
|
||||
# - has a registerProducer, unRegisterProducer
|
||||
# - tells the Producer to back off (pauseProducing) when its buffer is full.
|
||||
# - tells the Producer to resumeProducing when its buffer is not so full.
|
||||
# - acts as a Producer
|
||||
# - calls registerProducer
|
||||
# - calls write() on consumers
|
||||
# - honors requests to pause/resume producing
|
||||
# - honors stopProducing, and passes it along to upstream Producers
|
||||
|
||||
|
||||
class DummyTransport:
|
||||
"""A dumb transport to wrap around."""
|
||||
|
||||
def __init__(self):
|
||||
self._writes = []
|
||||
|
||||
def write(self, data):
|
||||
self._writes.append(data)
|
||||
|
||||
def getvalue(self):
|
||||
return "".join(self._writes)
|
||||
|
||||
|
||||
class DummyProducer:
|
||||
resumed = False
|
||||
stopped = False
|
||||
paused = False
|
||||
|
||||
def __init__(self, consumer):
|
||||
self.consumer = consumer
|
||||
|
||||
def resumeProducing(self):
|
||||
self.resumed = True
|
||||
self.paused = False
|
||||
|
||||
def pauseProducing(self):
|
||||
self.paused = True
|
||||
|
||||
def stopProducing(self):
|
||||
self.stopped = True
|
||||
|
||||
|
||||
class DummyConsumer(DummyTransport):
|
||||
producer = None
|
||||
finished = False
|
||||
unregistered = True
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = (producer, streaming)
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.unregistered = True
|
||||
|
||||
def finish(self):
|
||||
self.finished = True
|
||||
|
||||
|
||||
class TransportInterfaceTests(unittest.TestCase):
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.transport = self.proxyClass(self.underlying)
|
||||
|
||||
def testWrite(self):
|
||||
self.transport.write("some bytes")
|
||||
|
||||
|
||||
class ConsumerInterfaceTest:
|
||||
"""Test ProducerConsumerProxy as a Consumer.
|
||||
|
||||
Normally we have ProducingServer -> ConsumingTransport.
|
||||
|
||||
If I am to go between (Server -> Shaper -> Transport), I have to
|
||||
play the role of Consumer convincingly for the ProducingServer.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.consumer = self.proxyClass(self.underlying)
|
||||
self.producer = DummyProducer(self.consumer)
|
||||
|
||||
def testRegisterPush(self):
|
||||
self.consumer.registerProducer(self.producer, True)
|
||||
## Consumer should NOT have called PushProducer.resumeProducing
|
||||
self.assertFalse(self.producer.resumed)
|
||||
|
||||
## I'm I'm just a proxy, should I only do resumeProducing when
|
||||
## I get poked myself?
|
||||
# def testRegisterPull(self):
|
||||
# self.consumer.registerProducer(self.producer, False)
|
||||
# ## Consumer SHOULD have called PushProducer.resumeProducing
|
||||
# self.assertTrue(self.producer.resumed)
|
||||
|
||||
def testUnregister(self):
|
||||
self.consumer.registerProducer(self.producer, False)
|
||||
self.consumer.unregisterProducer()
|
||||
# Now when the consumer would ordinarily want more data, it
|
||||
# shouldn't ask producer for it.
|
||||
# The most succinct way to trigger "want more data" is to proxy for
|
||||
# a PullProducer and have someone ask me for data.
|
||||
self.producer.resumed = False
|
||||
self.consumer.resumeProducing()
|
||||
self.assertFalse(self.producer.resumed)
|
||||
|
||||
def testFinish(self):
|
||||
self.consumer.registerProducer(self.producer, False)
|
||||
self.consumer.finish()
|
||||
# I guess finish should behave like unregister?
|
||||
self.producer.resumed = False
|
||||
self.consumer.resumeProducing()
|
||||
self.assertFalse(self.producer.resumed)
|
||||
|
||||
|
||||
class ProducerInterfaceTest:
|
||||
"""Test ProducerConsumerProxy as a Producer.
|
||||
|
||||
Normally we have ProducingServer -> ConsumingTransport.
|
||||
|
||||
If I am to go between (Server -> Shaper -> Transport), I have to
|
||||
play the role of Producer convincingly for the ConsumingTransport.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.consumer = DummyConsumer()
|
||||
self.producer = self.proxyClass(self.consumer)
|
||||
|
||||
def testRegistersProducer(self):
|
||||
self.assertEqual(self.consumer.producer[0], self.producer)
|
||||
|
||||
def testPause(self):
|
||||
self.producer.pauseProducing()
|
||||
self.producer.write("yakkity yak")
|
||||
self.assertFalse(
|
||||
self.consumer.getvalue(), "Paused producer should not have sent data."
|
||||
)
|
||||
|
||||
def testResume(self):
|
||||
self.producer.pauseProducing()
|
||||
self.producer.resumeProducing()
|
||||
self.producer.write("yakkity yak")
|
||||
self.assertEqual(self.consumer.getvalue(), "yakkity yak")
|
||||
|
||||
def testResumeNoEmptyWrite(self):
|
||||
self.producer.pauseProducing()
|
||||
self.producer.resumeProducing()
|
||||
self.assertEqual(
|
||||
len(self.consumer._writes), 0, "Resume triggered an empty write."
|
||||
)
|
||||
|
||||
def testResumeBuffer(self):
|
||||
self.producer.pauseProducing()
|
||||
self.producer.write("buffer this")
|
||||
self.producer.resumeProducing()
|
||||
self.assertEqual(self.consumer.getvalue(), "buffer this")
|
||||
|
||||
def testStop(self):
|
||||
self.producer.stopProducing()
|
||||
self.producer.write("yakkity yak")
|
||||
self.assertFalse(
|
||||
self.consumer.getvalue(), "Stopped producer should not have sent data."
|
||||
)
|
||||
|
||||
|
||||
class PCP_ConsumerInterfaceTests(ConsumerInterfaceTest, unittest.TestCase):
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
|
||||
class PCPII_ConsumerInterfaceTests(ConsumerInterfaceTest, unittest.TestCase):
|
||||
proxyClass = pcp.ProducerConsumerProxy
|
||||
|
||||
|
||||
class PCP_ProducerInterfaceTests(ProducerInterfaceTest, unittest.TestCase):
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
|
||||
class PCPII_ProducerInterfaceTests(ProducerInterfaceTest, unittest.TestCase):
|
||||
proxyClass = pcp.ProducerConsumerProxy
|
||||
|
||||
|
||||
class ProducerProxyTests(unittest.TestCase):
|
||||
"""Producer methods on me should be relayed to the Producer I proxy."""
|
||||
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
def setUp(self):
|
||||
self.proxy = self.proxyClass(None)
|
||||
self.parentProducer = DummyProducer(self.proxy)
|
||||
self.proxy.registerProducer(self.parentProducer, True)
|
||||
|
||||
def testStop(self):
|
||||
self.proxy.stopProducing()
|
||||
self.assertTrue(self.parentProducer.stopped)
|
||||
|
||||
|
||||
class ConsumerProxyTests(unittest.TestCase):
|
||||
"""Consumer methods on me should be relayed to the Consumer I proxy."""
|
||||
|
||||
proxyClass = pcp.BasicProducerConsumerProxy
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.consumer = self.proxyClass(self.underlying)
|
||||
|
||||
def testWrite(self):
|
||||
# NOTE: This test only valid for streaming (Push) systems.
|
||||
self.consumer.write("some bytes")
|
||||
self.assertEqual(self.underlying.getvalue(), "some bytes")
|
||||
|
||||
def testFinish(self):
|
||||
self.consumer.finish()
|
||||
self.assertTrue(self.underlying.finished)
|
||||
|
||||
def testUnregister(self):
|
||||
self.consumer.unregisterProducer()
|
||||
self.assertTrue(self.underlying.unregistered)
|
||||
|
||||
|
||||
class PullProducerTest:
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.proxy = self.proxyClass(self.underlying)
|
||||
self.parentProducer = DummyProducer(self.proxy)
|
||||
self.proxy.registerProducer(self.parentProducer, True)
|
||||
|
||||
def testHoldWrites(self):
|
||||
self.proxy.write("hello")
|
||||
# Consumer should get no data before it says resumeProducing.
|
||||
self.assertFalse(
|
||||
self.underlying.getvalue(), "Pulling Consumer got data before it pulled."
|
||||
)
|
||||
|
||||
def testPull(self):
|
||||
self.proxy.write("hello")
|
||||
self.proxy.resumeProducing()
|
||||
self.assertEqual(self.underlying.getvalue(), "hello")
|
||||
|
||||
def testMergeWrites(self):
|
||||
self.proxy.write("hello ")
|
||||
self.proxy.write("sunshine")
|
||||
self.proxy.resumeProducing()
|
||||
nwrites = len(self.underlying._writes)
|
||||
self.assertEqual(
|
||||
nwrites, 1, "Pull resulted in %d writes instead " "of 1." % (nwrites,)
|
||||
)
|
||||
self.assertEqual(self.underlying.getvalue(), "hello sunshine")
|
||||
|
||||
def testLateWrite(self):
|
||||
# consumer sends its initial pull before we have data
|
||||
self.proxy.resumeProducing()
|
||||
self.proxy.write("data")
|
||||
# This data should answer that pull request.
|
||||
self.assertEqual(self.underlying.getvalue(), "data")
|
||||
|
||||
|
||||
class PCP_PullProducerTests(PullProducerTest, unittest.TestCase):
|
||||
class proxyClass(pcp.BasicProducerConsumerProxy):
|
||||
iAmStreaming = False
|
||||
|
||||
|
||||
class PCPII_PullProducerTests(PullProducerTest, unittest.TestCase):
|
||||
class proxyClass(pcp.ProducerConsumerProxy):
|
||||
iAmStreaming = False
|
||||
|
||||
|
||||
# Buffering!
|
||||
|
||||
|
||||
class BufferedConsumerTests(unittest.TestCase):
|
||||
"""As a consumer, ask the producer to pause after too much data."""
|
||||
|
||||
proxyClass = pcp.ProducerConsumerProxy
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.proxy = self.proxyClass(self.underlying)
|
||||
self.proxy.bufferSize = 100
|
||||
|
||||
self.parentProducer = DummyProducer(self.proxy)
|
||||
self.proxy.registerProducer(self.parentProducer, True)
|
||||
|
||||
def testRegisterPull(self):
|
||||
self.proxy.registerProducer(self.parentProducer, False)
|
||||
## Consumer SHOULD have called PushProducer.resumeProducing
|
||||
self.assertTrue(self.parentProducer.resumed)
|
||||
|
||||
def testPauseIntercept(self):
|
||||
self.proxy.pauseProducing()
|
||||
self.assertFalse(self.parentProducer.paused)
|
||||
|
||||
def testResumeIntercept(self):
|
||||
self.proxy.pauseProducing()
|
||||
self.proxy.resumeProducing()
|
||||
# With a streaming producer, just because the proxy was resumed is
|
||||
# not necessarily a reason to resume the parent producer. The state
|
||||
# of the buffer should decide that.
|
||||
self.assertFalse(self.parentProducer.resumed)
|
||||
|
||||
def testTriggerPause(self):
|
||||
"""Make sure I say \"when.\" """
|
||||
|
||||
# Pause the proxy so data sent to it builds up in its buffer.
|
||||
self.proxy.pauseProducing()
|
||||
self.assertFalse(self.parentProducer.paused, "don't pause yet")
|
||||
self.proxy.write("x" * 51)
|
||||
self.assertFalse(self.parentProducer.paused, "don't pause yet")
|
||||
self.proxy.write("x" * 51)
|
||||
self.assertTrue(self.parentProducer.paused)
|
||||
|
||||
def testTriggerResume(self):
|
||||
"""Make sure I resumeProducing when my buffer empties."""
|
||||
self.proxy.pauseProducing()
|
||||
self.proxy.write("x" * 102)
|
||||
self.assertTrue(self.parentProducer.paused, "should be paused")
|
||||
self.proxy.resumeProducing()
|
||||
# Resuming should have emptied my buffer, so I should tell my
|
||||
# parent to resume too.
|
||||
self.assertFalse(self.parentProducer.paused, "Producer should have resumed.")
|
||||
self.assertFalse(self.proxy.producerPaused)
|
||||
|
||||
|
||||
class BufferedPullTests(unittest.TestCase):
|
||||
class proxyClass(pcp.ProducerConsumerProxy):
|
||||
iAmStreaming = False
|
||||
|
||||
def _writeSomeData(self, data):
|
||||
pcp.ProducerConsumerProxy._writeSomeData(self, data[:100])
|
||||
return min(len(data), 100)
|
||||
|
||||
def setUp(self):
|
||||
self.underlying = DummyConsumer()
|
||||
self.proxy = self.proxyClass(self.underlying)
|
||||
self.proxy.bufferSize = 100
|
||||
|
||||
self.parentProducer = DummyProducer(self.proxy)
|
||||
self.proxy.registerProducer(self.parentProducer, False)
|
||||
|
||||
def testResumePull(self):
|
||||
# If proxy has no data to send on resumeProducing, it had better pull
|
||||
# some from its PullProducer.
|
||||
self.parentProducer.resumed = False
|
||||
self.proxy.resumeProducing()
|
||||
self.assertTrue(self.parentProducer.resumed)
|
||||
|
||||
def testLateWriteBuffering(self):
|
||||
# consumer sends its initial pull before we have data
|
||||
self.proxy.resumeProducing()
|
||||
self.proxy.write("datum" * 21)
|
||||
# This data should answer that pull request.
|
||||
self.assertEqual(self.underlying.getvalue(), "datum" * 20)
|
||||
# but there should be some left over
|
||||
self.assertEqual(self.proxy._buffer, ["datum"])
|
||||
|
||||
|
||||
# TODO:
|
||||
# test that web request finishing bug (when we weren't proxying
|
||||
# unregisterProducer but were proxying finish, web file transfers
|
||||
# would hang on the last block.)
|
||||
# test what happens if writeSomeBytes decided to write zero bytes.
|
||||
@@ -0,0 +1,536 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
from __future__ import annotations
|
||||
|
||||
# System Imports
|
||||
import copyreg
|
||||
import io
|
||||
import pickle
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
from typing_extensions import NoReturn
|
||||
|
||||
# Twisted Imports
|
||||
from twisted.persisted import aot, crefutil, styles
|
||||
from twisted.trial import unittest
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class VersionTests(TestCase):
|
||||
def test_nullVersionUpgrade(self) -> None:
|
||||
global NullVersioned
|
||||
|
||||
class NullVersioned:
|
||||
def __init__(self) -> None:
|
||||
self.ok = 0
|
||||
|
||||
pkcl = pickle.dumps(NullVersioned())
|
||||
|
||||
class NullVersioned(styles.Versioned): # type: ignore[no-redef]
|
||||
persistenceVersion = 1
|
||||
|
||||
def upgradeToVersion1(self) -> None:
|
||||
self.ok = 1
|
||||
|
||||
mnv = pickle.loads(pkcl)
|
||||
styles.doUpgrade()
|
||||
assert mnv.ok, "initial upgrade not run!"
|
||||
|
||||
def test_versionUpgrade(self) -> None:
|
||||
global MyVersioned
|
||||
|
||||
class MyVersioned(styles.Versioned):
|
||||
persistenceVersion = 2
|
||||
# persistenceForgets should be a tuple
|
||||
persistenceForgets = ["garbagedata"] # type: ignore[assignment]
|
||||
v3 = 0
|
||||
v4 = 0
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.somedata = "xxx"
|
||||
self.garbagedata = lambda q: "cant persist"
|
||||
|
||||
def upgradeToVersion3(self) -> None:
|
||||
self.v3 += 1
|
||||
|
||||
def upgradeToVersion4(self) -> None:
|
||||
self.v4 += 1
|
||||
|
||||
mv = MyVersioned()
|
||||
assert not (mv.v3 or mv.v4), "hasn't been upgraded yet"
|
||||
pickl = pickle.dumps(mv)
|
||||
MyVersioned.persistenceVersion = 4
|
||||
obj = pickle.loads(pickl)
|
||||
styles.doUpgrade()
|
||||
assert obj.v3, "didn't do version 3 upgrade"
|
||||
assert obj.v4, "didn't do version 4 upgrade"
|
||||
pickl = pickle.dumps(obj)
|
||||
obj = pickle.loads(pickl)
|
||||
styles.doUpgrade()
|
||||
assert obj.v3 == 1, "upgraded unnecessarily"
|
||||
assert obj.v4 == 1, "upgraded unnecessarily"
|
||||
|
||||
def test_nonIdentityHash(self) -> None:
|
||||
global ClassWithCustomHash
|
||||
|
||||
class ClassWithCustomHash(styles.Versioned):
|
||||
def __init__(self, unique: str, hash: int) -> None:
|
||||
self.unique = unique
|
||||
self.hash = hash
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.hash
|
||||
|
||||
v1 = ClassWithCustomHash("v1", 0)
|
||||
v2 = ClassWithCustomHash("v2", 0)
|
||||
|
||||
pkl = pickle.dumps((v1, v2))
|
||||
del v1, v2
|
||||
ClassWithCustomHash.persistenceVersion = 1
|
||||
ClassWithCustomHash.upgradeToVersion1 = lambda self: setattr( # type: ignore[attr-defined]
|
||||
self, "upgraded", True
|
||||
)
|
||||
v1, v2 = pickle.loads(pkl)
|
||||
styles.doUpgrade()
|
||||
self.assertEqual(v1.unique, "v1")
|
||||
self.assertEqual(v2.unique, "v2")
|
||||
self.assertTrue(v1.upgraded) # type: ignore[attr-defined]
|
||||
self.assertTrue(v2.upgraded) # type: ignore[attr-defined]
|
||||
|
||||
def test_upgradeDeserializesObjectsRequiringUpgrade(self) -> None:
|
||||
global ToyClassA, ToyClassB
|
||||
|
||||
class ToyClassA(styles.Versioned):
|
||||
pass
|
||||
|
||||
class ToyClassB(styles.Versioned):
|
||||
pass
|
||||
|
||||
x = ToyClassA()
|
||||
y = ToyClassB()
|
||||
pklA, pklB = pickle.dumps(x), pickle.dumps(y)
|
||||
del x, y
|
||||
ToyClassA.persistenceVersion = 1
|
||||
|
||||
def upgradeToVersion1(self: Any) -> None:
|
||||
self.y = pickle.loads(pklB)
|
||||
styles.doUpgrade()
|
||||
|
||||
ToyClassA.upgradeToVersion1 = upgradeToVersion1 # type: ignore[attr-defined]
|
||||
ToyClassB.persistenceVersion = 1
|
||||
|
||||
def setUpgraded(self: object) -> None:
|
||||
setattr(self, "upgraded", True)
|
||||
|
||||
ToyClassB.upgradeToVersion1 = setUpgraded # type: ignore[attr-defined]
|
||||
|
||||
x = pickle.loads(pklA)
|
||||
styles.doUpgrade()
|
||||
self.assertTrue(x.y.upgraded) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class VersionedSubClass(styles.Versioned):
|
||||
pass
|
||||
|
||||
|
||||
class SecondVersionedSubClass(styles.Versioned):
|
||||
pass
|
||||
|
||||
|
||||
class VersionedSubSubClass(VersionedSubClass):
|
||||
pass
|
||||
|
||||
|
||||
class VersionedDiamondSubClass(VersionedSubSubClass, SecondVersionedSubClass):
|
||||
pass
|
||||
|
||||
|
||||
class AybabtuTests(TestCase):
|
||||
"""
|
||||
L{styles._aybabtu} gets all of classes in the inheritance hierarchy of its
|
||||
argument that are strictly between L{Versioned} and the class itself.
|
||||
"""
|
||||
|
||||
def test_aybabtuStrictEmpty(self) -> None:
|
||||
"""
|
||||
L{styles._aybabtu} of L{Versioned} itself is an empty list.
|
||||
"""
|
||||
self.assertEqual(styles._aybabtu(styles.Versioned), [])
|
||||
|
||||
def test_aybabtuStrictSubclass(self) -> None:
|
||||
"""
|
||||
There are no classes I{between} L{VersionedSubClass} and L{Versioned},
|
||||
so L{styles._aybabtu} returns an empty list.
|
||||
"""
|
||||
self.assertEqual(styles._aybabtu(VersionedSubClass), [])
|
||||
|
||||
def test_aybabtuSubsubclass(self) -> None:
|
||||
"""
|
||||
With a sub-sub-class of L{Versioned}, L{styles._aybabtu} returns a list
|
||||
containing the intervening subclass.
|
||||
"""
|
||||
self.assertEqual(styles._aybabtu(VersionedSubSubClass), [VersionedSubClass])
|
||||
|
||||
def test_aybabtuStrict(self) -> None:
|
||||
"""
|
||||
For a diamond-shaped inheritance graph, L{styles._aybabtu} returns a
|
||||
list containing I{both} intermediate subclasses.
|
||||
"""
|
||||
self.assertEqual(
|
||||
styles._aybabtu(VersionedDiamondSubClass),
|
||||
[VersionedSubSubClass, VersionedSubClass, SecondVersionedSubClass],
|
||||
)
|
||||
|
||||
|
||||
class MyEphemeral(styles.Ephemeral):
|
||||
def __init__(self, x: int) -> None:
|
||||
self.x = x
|
||||
|
||||
|
||||
class EphemeralTests(TestCase):
|
||||
def test_ephemeral(self) -> None:
|
||||
o = MyEphemeral(3)
|
||||
self.assertEqual(o.__class__, MyEphemeral)
|
||||
self.assertEqual(o.x, 3)
|
||||
|
||||
pickl = pickle.dumps(o)
|
||||
o = pickle.loads(pickl)
|
||||
|
||||
self.assertEqual(o.__class__, styles.Ephemeral)
|
||||
self.assertFalse(hasattr(o, "x"))
|
||||
|
||||
|
||||
class Pickleable:
|
||||
def __init__(self, x: int) -> None:
|
||||
self.x = x
|
||||
|
||||
def getX(self) -> int:
|
||||
return self.x
|
||||
|
||||
|
||||
class NotPickleable:
|
||||
"""
|
||||
A class that is not pickleable.
|
||||
"""
|
||||
|
||||
def __reduce__(self) -> NoReturn:
|
||||
"""
|
||||
Raise an exception instead of pickling.
|
||||
"""
|
||||
raise TypeError("Not serializable.")
|
||||
|
||||
|
||||
class CopyRegistered:
|
||||
"""
|
||||
A class that is pickleable only because it is registered with the
|
||||
C{copyreg} module.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Ensure that this object is normally not pickleable.
|
||||
"""
|
||||
self.notPickleable = NotPickleable()
|
||||
|
||||
|
||||
class CopyRegisteredLoaded:
|
||||
"""
|
||||
L{CopyRegistered} after unserialization.
|
||||
"""
|
||||
|
||||
|
||||
def reduceCopyRegistered(cr: object) -> tuple[type[CopyRegisteredLoaded], tuple[()]]:
|
||||
"""
|
||||
Externally implement C{__reduce__} for L{CopyRegistered}.
|
||||
|
||||
@param cr: The L{CopyRegistered} instance.
|
||||
|
||||
@return: a 2-tuple of callable and argument list, in this case
|
||||
L{CopyRegisteredLoaded} and no arguments.
|
||||
"""
|
||||
return CopyRegisteredLoaded, ()
|
||||
|
||||
|
||||
copyreg.pickle(CopyRegistered, reduceCopyRegistered) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class A:
|
||||
"""
|
||||
dummy class
|
||||
"""
|
||||
|
||||
bmethod: Callable[[], None]
|
||||
|
||||
def amethod(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class B:
|
||||
"""
|
||||
dummy class
|
||||
"""
|
||||
|
||||
a: A
|
||||
|
||||
def bmethod(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def funktion() -> None:
|
||||
pass
|
||||
|
||||
|
||||
class PicklingTests(TestCase):
|
||||
"""Test pickling of extra object types."""
|
||||
|
||||
def test_module(self) -> None:
|
||||
pickl = pickle.dumps(styles)
|
||||
o = pickle.loads(pickl)
|
||||
self.assertEqual(o, styles)
|
||||
|
||||
def test_instanceMethod(self) -> None:
|
||||
obj = Pickleable(4)
|
||||
pickl = pickle.dumps(obj.getX)
|
||||
o = pickle.loads(pickl)
|
||||
self.assertEqual(o(), 4)
|
||||
self.assertEqual(type(o), type(obj.getX))
|
||||
|
||||
|
||||
class StringIOTransitionTests(TestCase):
|
||||
"""
|
||||
When pickling a cStringIO in Python 2, it should unpickle as a BytesIO or a
|
||||
StringIO in Python 3, depending on the type of its contents.
|
||||
"""
|
||||
|
||||
def test_unpickleBytesIO(self) -> None:
|
||||
"""
|
||||
A cStringIO pickled with bytes in it will yield an L{io.BytesIO} on
|
||||
python 3.
|
||||
"""
|
||||
pickledStringIWithText = (
|
||||
b"ctwisted.persisted.styles\nunpickleStringI\np0\n"
|
||||
b"(S'test'\np1\nI0\ntp2\nRp3\n."
|
||||
)
|
||||
loaded = pickle.loads(pickledStringIWithText)
|
||||
self.assertIsInstance(loaded, io.StringIO)
|
||||
self.assertEqual(loaded.getvalue(), "test")
|
||||
|
||||
|
||||
class EvilSourceror:
|
||||
a: EvilSourceror
|
||||
b: EvilSourceror
|
||||
c: object
|
||||
|
||||
def __init__(self, x: object) -> None:
|
||||
self.a = self
|
||||
self.a.b = self
|
||||
self.a.b.c = x
|
||||
|
||||
|
||||
class NonDictState:
|
||||
state: str
|
||||
|
||||
def __getstate__(self) -> str:
|
||||
return self.state
|
||||
|
||||
def __setstate__(self, state: str) -> None:
|
||||
self.state = state
|
||||
|
||||
|
||||
_CircularTupleType = List[Tuple["_CircularTupleType", int]]
|
||||
|
||||
|
||||
class AOTTests(TestCase):
|
||||
def test_simpleTypes(self) -> None:
|
||||
obj = (
|
||||
1,
|
||||
2.0,
|
||||
3j,
|
||||
True,
|
||||
slice(1, 2, 3),
|
||||
"hello",
|
||||
"world",
|
||||
sys.maxsize + 1,
|
||||
None,
|
||||
Ellipsis,
|
||||
)
|
||||
rtObj = aot.unjellyFromSource(aot.jellyToSource(obj))
|
||||
self.assertEqual(obj, rtObj)
|
||||
|
||||
def test_methodSelfIdentity(self) -> None:
|
||||
a = A()
|
||||
b = B()
|
||||
a.bmethod = b.bmethod
|
||||
b.a = a
|
||||
im_ = aot.unjellyFromSource(aot.jellyToSource(b)).a.bmethod
|
||||
self.assertEqual(aot._selfOfMethod(im_).__class__, aot._classOfMethod(im_))
|
||||
|
||||
def test_methodNotSelfIdentity(self) -> None:
|
||||
"""
|
||||
If a class change after an instance has been created,
|
||||
L{aot.unjellyFromSource} shoud raise a C{TypeError} when trying to
|
||||
unjelly the instance.
|
||||
"""
|
||||
a = A()
|
||||
b = B()
|
||||
a.bmethod = b.bmethod
|
||||
b.a = a
|
||||
savedbmethod = B.bmethod
|
||||
del B.bmethod
|
||||
try:
|
||||
self.assertRaises(TypeError, aot.unjellyFromSource, aot.jellyToSource(b))
|
||||
finally:
|
||||
B.bmethod = savedbmethod # type: ignore[method-assign]
|
||||
|
||||
def test_unsupportedType(self) -> None:
|
||||
"""
|
||||
L{aot.jellyToSource} should raise a C{TypeError} when trying to jelly
|
||||
an unknown type without a C{__dict__} property or C{__getstate__}
|
||||
method.
|
||||
"""
|
||||
|
||||
class UnknownType:
|
||||
@property
|
||||
def __dict__(self) -> NoReturn: # type: ignore[override]
|
||||
raise AttributeError()
|
||||
|
||||
@property
|
||||
def __getstate__(self) -> NoReturn:
|
||||
raise AttributeError()
|
||||
|
||||
self.assertRaises(TypeError, aot.jellyToSource, UnknownType())
|
||||
|
||||
def test_basicIdentity(self) -> None:
|
||||
# Anyone wanting to make this datastructure more complex, and thus this
|
||||
# test more comprehensive, is welcome to do so.
|
||||
aj = aot.AOTJellier().jellyToAO
|
||||
d = {"hello": "world", "method": aj}
|
||||
l = [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
'he\tllo\n\n"x world!',
|
||||
"goodbye \n\t\u1010 world!",
|
||||
1,
|
||||
1.0,
|
||||
100**100,
|
||||
unittest,
|
||||
aot.AOTJellier,
|
||||
d,
|
||||
funktion,
|
||||
]
|
||||
t = tuple(l)
|
||||
l.append(l)
|
||||
l.append(t)
|
||||
l.append(t)
|
||||
uj = aot.unjellyFromSource(aot.jellyToSource([l, l]))
|
||||
assert uj[0] is uj[1]
|
||||
assert uj[1][0:5] == l[0:5]
|
||||
|
||||
def test_nonDictState(self) -> None:
|
||||
a = NonDictState()
|
||||
a.state = "meringue!"
|
||||
assert aot.unjellyFromSource(aot.jellyToSource(a)).state == a.state
|
||||
|
||||
def test_copyReg(self) -> None:
|
||||
"""
|
||||
L{aot.jellyToSource} and L{aot.unjellyFromSource} honor functions
|
||||
registered in the pickle copy registry.
|
||||
"""
|
||||
uj = aot.unjellyFromSource(aot.jellyToSource(CopyRegistered()))
|
||||
self.assertIsInstance(uj, CopyRegisteredLoaded)
|
||||
|
||||
def test_funkyReferences(self) -> None:
|
||||
o = EvilSourceror(EvilSourceror([]))
|
||||
j1 = aot.jellyToAOT(o)
|
||||
oj = aot.unjellyFromAOT(j1)
|
||||
|
||||
assert oj.a is oj
|
||||
assert oj.a.b is oj.b
|
||||
assert oj.c is not oj.c.c
|
||||
|
||||
def test_circularTuple(self) -> None:
|
||||
"""
|
||||
L{aot.jellyToAOT} can persist circular references through tuples.
|
||||
"""
|
||||
l: _CircularTupleType = []
|
||||
t = (l, 4321)
|
||||
l.append(t)
|
||||
j1 = aot.jellyToAOT(l)
|
||||
oj = aot.unjellyFromAOT(j1)
|
||||
self.assertIsInstance(oj[0], tuple)
|
||||
self.assertIs(oj[0][0], oj)
|
||||
self.assertEqual(oj[0][1], 4321)
|
||||
|
||||
def testIndentify(self) -> None:
|
||||
"""
|
||||
The generated serialization is indented.
|
||||
"""
|
||||
self.assertEqual(
|
||||
aot.jellyToSource({"hello": {"world": []}}),
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
app={
|
||||
'hello':{
|
||||
'world':[],
|
||||
},
|
||||
}""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CrefUtilTests(TestCase):
|
||||
"""
|
||||
Tests for L{crefutil}.
|
||||
"""
|
||||
|
||||
def test_dictUnknownKey(self) -> None:
|
||||
"""
|
||||
L{crefutil._DictKeyAndValue} only support keys C{0} and C{1}.
|
||||
"""
|
||||
d = crefutil._DictKeyAndValue({})
|
||||
self.assertRaises(RuntimeError, d.__setitem__, 2, 3)
|
||||
|
||||
def test_deferSetMultipleTimes(self) -> None:
|
||||
"""
|
||||
L{crefutil._Defer} can be assigned a key only one time.
|
||||
"""
|
||||
d = crefutil._Defer()
|
||||
d[0] = 1
|
||||
self.assertRaises(RuntimeError, d.__setitem__, 0, 1)
|
||||
|
||||
def test_containerWhereAllElementsAreKnown(self) -> None:
|
||||
"""
|
||||
A L{crefutil._Container} where all of its elements are known at
|
||||
construction time is nonsensical and will result in errors in any call
|
||||
to addDependant.
|
||||
"""
|
||||
container = crefutil._Container([1, 2, 3], list)
|
||||
self.assertRaises(AssertionError, container.addDependant, {}, "ignore-me")
|
||||
|
||||
def test_dontPutCircularReferencesInDictionaryKeys(self) -> None:
|
||||
"""
|
||||
If a dictionary key contains a circular reference (which is probably a
|
||||
bad practice anyway) it will be resolved by a
|
||||
L{crefutil._DictKeyAndValue}, not by placing a L{crefutil.NotKnown}
|
||||
into a dictionary key.
|
||||
"""
|
||||
self.assertRaises(
|
||||
AssertionError, dict().__setitem__, crefutil.NotKnown(), "value"
|
||||
)
|
||||
|
||||
def test_dontCallInstanceMethodsThatArentReady(self) -> None:
|
||||
"""
|
||||
L{crefutil._InstanceMethod} raises L{AssertionError} to indicate it
|
||||
should not be called. This should not be possible with any of its API
|
||||
clients, but is provided for helping to debug.
|
||||
"""
|
||||
self.assertRaises(
|
||||
AssertionError,
|
||||
crefutil._InstanceMethod("no_name", crefutil.NotKnown(), type),
|
||||
)
|
||||
|
||||
|
||||
testCases = [VersionTests, EphemeralTests, PicklingTests]
|
||||
744
.venv/lib/python3.12/site-packages/twisted/test/test_plugin.py
Normal file
744
.venv/lib/python3.12/site-packages/twisted/test/test_plugin.py
Normal file
@@ -0,0 +1,744 @@
|
||||
# Copyright (c) 2005 Divmod, Inc.
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for Twisted plugin system.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import compileall
|
||||
import errno
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from importlib import invalidate_caches as invalidateImportCaches
|
||||
from types import ModuleType
|
||||
from typing import Callable, TypedDict, TypeVar
|
||||
|
||||
from zope.interface import Interface
|
||||
|
||||
from twisted import plugin
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.log import EventDict, addObserver, removeObserver, textFromEventDict
|
||||
from twisted.trial import unittest
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class ITestPlugin(Interface):
|
||||
"""
|
||||
A plugin for use by the plugin system's unit tests.
|
||||
|
||||
Do not use this.
|
||||
"""
|
||||
|
||||
|
||||
class ITestPlugin2(Interface):
|
||||
"""
|
||||
See L{ITestPlugin}.
|
||||
"""
|
||||
|
||||
|
||||
class PluginTests(unittest.TestCase):
|
||||
"""
|
||||
Tests which verify the behavior of the current, active Twisted plugins
|
||||
directory.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Save C{sys.path} and C{sys.modules}, and create a package for tests.
|
||||
"""
|
||||
self.originalPath = sys.path[:]
|
||||
self.savedModules = sys.modules.copy()
|
||||
|
||||
self.root = FilePath(self.mktemp())
|
||||
self.root.createDirectory()
|
||||
self.package = self.root.child("mypackage")
|
||||
self.package.createDirectory()
|
||||
self.package.child("__init__.py").setContent(b"")
|
||||
|
||||
FilePath(__file__).sibling("plugin_basic.py").copyTo(
|
||||
self.package.child("testplugin.py")
|
||||
)
|
||||
|
||||
self.originalPlugin = "testplugin"
|
||||
|
||||
sys.path.insert(0, self.root.path)
|
||||
import mypackage # type: ignore[import-not-found]
|
||||
|
||||
self.module = mypackage
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
Restore C{sys.path} and C{sys.modules} to their original values.
|
||||
"""
|
||||
sys.path[:] = self.originalPath
|
||||
sys.modules.clear()
|
||||
sys.modules.update(self.savedModules)
|
||||
|
||||
def _unimportPythonModule(
|
||||
self, module: ModuleType, deleteSource: bool = False
|
||||
) -> None:
|
||||
assert module.__file__ is not None
|
||||
modulePath = module.__name__.split(".")
|
||||
packageName = ".".join(modulePath[:-1])
|
||||
moduleName = modulePath[-1]
|
||||
|
||||
delattr(sys.modules[packageName], moduleName)
|
||||
del sys.modules[module.__name__]
|
||||
for ext in ["c", "o"] + (deleteSource and [""] or []):
|
||||
try:
|
||||
os.remove(module.__file__ + ext)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def _clearCache(self) -> None:
|
||||
"""
|
||||
Remove the plugins B{droping.cache} file.
|
||||
"""
|
||||
self.package.child("dropin.cache").remove()
|
||||
|
||||
def _withCacheness(
|
||||
meth: Callable[[PluginTests], object]
|
||||
) -> Callable[[PluginTests], None]:
|
||||
"""
|
||||
This is a paranoid test wrapper, that calls C{meth} 2 times, clear the
|
||||
cache, and calls it 2 other times. It's supposed to ensure that the
|
||||
plugin system behaves correctly no matter what the state of the cache
|
||||
is.
|
||||
"""
|
||||
|
||||
@functools.wraps(meth)
|
||||
def wrapped(self: PluginTests) -> None:
|
||||
meth(self)
|
||||
meth(self)
|
||||
self._clearCache()
|
||||
meth(self)
|
||||
meth(self)
|
||||
|
||||
return wrapped
|
||||
|
||||
@_withCacheness
|
||||
def test_cache(self) -> None:
|
||||
"""
|
||||
Check that the cache returned by L{plugin.getCache} hold the plugin
|
||||
B{testplugin}, and that this plugin has the properties we expect:
|
||||
provide L{TestPlugin}, has the good name and description, and can be
|
||||
loaded successfully.
|
||||
"""
|
||||
cache = plugin.getCache(self.module)
|
||||
|
||||
dropin = cache[self.originalPlugin]
|
||||
self.assertEqual(dropin.moduleName, f"mypackage.{self.originalPlugin}")
|
||||
self.assertIn("I'm a test drop-in.", dropin.description)
|
||||
|
||||
# Note, not the preferred way to get a plugin by its interface.
|
||||
p1 = [p for p in dropin.plugins if ITestPlugin in p.provided][0]
|
||||
self.assertIs(p1.dropin, dropin)
|
||||
self.assertEqual(p1.name, "TestPlugin")
|
||||
|
||||
# Check the content of the description comes from the plugin module
|
||||
# docstring
|
||||
self.assertEqual(
|
||||
p1.description.strip(), "A plugin used solely for testing purposes."
|
||||
)
|
||||
self.assertEqual(p1.provided, [ITestPlugin, plugin.IPlugin])
|
||||
realPlugin = p1.load()
|
||||
# The plugin should match the class present in sys.modules
|
||||
self.assertIs(
|
||||
realPlugin,
|
||||
sys.modules[f"mypackage.{self.originalPlugin}"].TestPlugin,
|
||||
)
|
||||
|
||||
# And it should also match if we import it classicly
|
||||
import mypackage.testplugin as tp # type: ignore[import-not-found]
|
||||
|
||||
self.assertIs(realPlugin, tp.TestPlugin)
|
||||
|
||||
def test_cacheRepr(self) -> None:
|
||||
"""
|
||||
L{CachedPlugin} has a helpful C{repr} which contains relevant
|
||||
information about it.
|
||||
"""
|
||||
cachedDropin = plugin.getCache(self.module)[self.originalPlugin]
|
||||
cachedPlugin = list(p for p in cachedDropin.plugins if p.name == "TestPlugin")[
|
||||
0
|
||||
]
|
||||
self.assertEqual(
|
||||
repr(cachedPlugin),
|
||||
"<CachedPlugin 'TestPlugin'/'mypackage.testplugin' "
|
||||
"(provides 'ITestPlugin, IPlugin')>",
|
||||
)
|
||||
|
||||
@_withCacheness
|
||||
def test_plugins(self) -> None:
|
||||
"""
|
||||
L{plugin.getPlugins} should return the list of plugins matching the
|
||||
specified interface (here, L{ITestPlugin2}), and these plugins
|
||||
should be instances of classes with a C{test} method, to be sure
|
||||
L{plugin.getPlugins} load classes correctly.
|
||||
"""
|
||||
plugins = list(plugin.getPlugins(ITestPlugin2, self.module))
|
||||
|
||||
self.assertEqual(len(plugins), 2)
|
||||
|
||||
names = ["AnotherTestPlugin", "ThirdTestPlugin"]
|
||||
for p in plugins:
|
||||
names.remove(p.__name__) # type: ignore[attr-defined]
|
||||
p.test() # type: ignore[attr-defined]
|
||||
|
||||
@_withCacheness
|
||||
def test_detectNewFiles(self) -> None:
|
||||
"""
|
||||
Check that L{plugin.getPlugins} is able to detect plugins added at
|
||||
runtime.
|
||||
"""
|
||||
FilePath(__file__).sibling("plugin_extra1.py").copyTo(
|
||||
self.package.child("pluginextra.py")
|
||||
)
|
||||
try:
|
||||
# Check that the current situation is clean
|
||||
self.failIfIn("mypackage.pluginextra", sys.modules)
|
||||
self.assertFalse(
|
||||
hasattr(sys.modules["mypackage"], "pluginextra"),
|
||||
"mypackage still has pluginextra module",
|
||||
)
|
||||
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
|
||||
# We should find 2 plugins: the one in testplugin, and the one in
|
||||
# pluginextra
|
||||
self.assertEqual(len(plgs), 2)
|
||||
|
||||
names = ["TestPlugin", "FourthTestPlugin"]
|
||||
for p in plgs:
|
||||
names.remove(p.__name__) # type: ignore[attr-defined]
|
||||
p.test1() # type: ignore[attr-defined]
|
||||
finally:
|
||||
self._unimportPythonModule(sys.modules["mypackage.pluginextra"], True)
|
||||
|
||||
@_withCacheness
|
||||
def test_detectFilesChanged(self) -> None:
|
||||
"""
|
||||
Check that if the content of a plugin change, L{plugin.getPlugins} is
|
||||
able to detect the new plugins added.
|
||||
"""
|
||||
FilePath(__file__).sibling("plugin_extra1.py").copyTo(
|
||||
self.package.child("pluginextra.py")
|
||||
)
|
||||
try:
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
# Sanity check
|
||||
self.assertEqual(len(plgs), 2)
|
||||
|
||||
FilePath(__file__).sibling("plugin_extra2.py").copyTo(
|
||||
self.package.child("pluginextra.py")
|
||||
)
|
||||
|
||||
# Fake out Python.
|
||||
self._unimportPythonModule(sys.modules["mypackage.pluginextra"])
|
||||
|
||||
# Make sure additions are noticed
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
|
||||
self.assertEqual(len(plgs), 3)
|
||||
|
||||
names = ["TestPlugin", "FourthTestPlugin", "FifthTestPlugin"]
|
||||
for p in plgs:
|
||||
names.remove(p.__name__) # type: ignore[attr-defined]
|
||||
p.test1() # type: ignore[attr-defined]
|
||||
finally:
|
||||
self._unimportPythonModule(sys.modules["mypackage.pluginextra"], True)
|
||||
|
||||
@_withCacheness
|
||||
def test_detectFilesRemoved(self) -> None:
|
||||
"""
|
||||
Check that when a dropin file is removed, L{plugin.getPlugins} doesn't
|
||||
return it anymore.
|
||||
"""
|
||||
FilePath(__file__).sibling("plugin_extra1.py").copyTo(
|
||||
self.package.child("pluginextra.py")
|
||||
)
|
||||
try:
|
||||
# Generate a cache with pluginextra in it.
|
||||
list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
|
||||
finally:
|
||||
self._unimportPythonModule(sys.modules["mypackage.pluginextra"], True)
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
self.assertEqual(1, len(plgs))
|
||||
|
||||
@_withCacheness
|
||||
def test_nonexistentPathEntry(self) -> None:
|
||||
"""
|
||||
Test that getCache skips over any entries in a plugin package's
|
||||
C{__path__} which do not exist.
|
||||
"""
|
||||
path = self.mktemp()
|
||||
self.assertFalse(os.path.exists(path))
|
||||
# Add the test directory to the plugins path
|
||||
self.module.__path__.append(path)
|
||||
try:
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
self.assertEqual(len(plgs), 1)
|
||||
finally:
|
||||
self.module.__path__.remove(path)
|
||||
|
||||
@_withCacheness
|
||||
def test_nonDirectoryChildEntry(self) -> None:
|
||||
"""
|
||||
Test that getCache skips over any entries in a plugin package's
|
||||
C{__path__} which refer to children of paths which are not directories.
|
||||
"""
|
||||
path = FilePath(self.mktemp())
|
||||
self.assertFalse(path.exists())
|
||||
path.touch()
|
||||
child = path.child("test_package").path
|
||||
self.module.__path__.append(child)
|
||||
try:
|
||||
plgs = list(plugin.getPlugins(ITestPlugin, self.module))
|
||||
self.assertEqual(len(plgs), 1)
|
||||
finally:
|
||||
self.module.__path__.remove(child)
|
||||
|
||||
def test_deployedMode(self) -> None:
|
||||
"""
|
||||
The C{dropin.cache} file may not be writable: the cache should still be
|
||||
attainable, but an error should be logged to show that the cache
|
||||
couldn't be updated.
|
||||
"""
|
||||
# Generate the cache
|
||||
plugin.getCache(self.module)
|
||||
|
||||
cachepath = self.package.child("dropin.cache")
|
||||
|
||||
# Add a new plugin
|
||||
FilePath(__file__).sibling("plugin_extra1.py").copyTo(
|
||||
self.package.child("pluginextra.py")
|
||||
)
|
||||
invalidateImportCaches()
|
||||
|
||||
os.chmod(self.package.path, 0o500)
|
||||
# Change the right of dropin.cache too for windows
|
||||
os.chmod(cachepath.path, 0o400)
|
||||
self.addCleanup(os.chmod, self.package.path, 0o700)
|
||||
self.addCleanup(os.chmod, cachepath.path, 0o700)
|
||||
|
||||
# Start observing log events to see the warning
|
||||
events: list[EventDict] = []
|
||||
addObserver(events.append)
|
||||
self.addCleanup(removeObserver, events.append)
|
||||
|
||||
cache = plugin.getCache(self.module)
|
||||
# The new plugin should be reported
|
||||
self.assertIn("pluginextra", cache)
|
||||
self.assertIn(self.originalPlugin, cache)
|
||||
|
||||
# Make sure something was logged about the cache.
|
||||
expected = "Unable to write to plugin cache %s: error number %d" % (
|
||||
cachepath.path,
|
||||
errno.EPERM,
|
||||
)
|
||||
for event in events:
|
||||
maybeText = textFromEventDict(event)
|
||||
assert maybeText is not None
|
||||
if expected in maybeText: # pragma: no branch
|
||||
break
|
||||
else: # pragma: no cover
|
||||
self.fail(
|
||||
"Did not observe unwriteable cache warning in log "
|
||||
"events: %r" % (events,)
|
||||
)
|
||||
|
||||
|
||||
# This is something like the Twisted plugins file.
|
||||
pluginInitFile = b"""
|
||||
from twisted.plugin import pluginPackagePaths
|
||||
__path__.extend(pluginPackagePaths(__name__))
|
||||
__all__ = []
|
||||
"""
|
||||
|
||||
|
||||
def pluginFileContents(name: str) -> bytes:
|
||||
return (
|
||||
(
|
||||
"from zope.interface import provider\n"
|
||||
"from twisted.plugin import IPlugin\n"
|
||||
"from twisted.test.test_plugin import ITestPlugin\n"
|
||||
"\n"
|
||||
"@provider(IPlugin, ITestPlugin)\n"
|
||||
"class {}:\n"
|
||||
" pass\n"
|
||||
)
|
||||
.format(name)
|
||||
.encode("ascii")
|
||||
)
|
||||
|
||||
|
||||
def _createPluginDummy(
|
||||
entrypath: FilePath[str], pluginContent: bytes, real: bool, pluginModule: str
|
||||
) -> FilePath[str]:
|
||||
"""
|
||||
Create a plugindummy package.
|
||||
"""
|
||||
entrypath.createDirectory()
|
||||
pkg = entrypath.child("plugindummy")
|
||||
pkg.createDirectory()
|
||||
if real:
|
||||
pkg.child("__init__.py").setContent(b"")
|
||||
plugs = pkg.child("plugins")
|
||||
plugs.createDirectory()
|
||||
if real:
|
||||
plugs.child("__init__.py").setContent(pluginInitFile)
|
||||
plugs.child(pluginModule + ".py").setContent(pluginContent)
|
||||
return plugs
|
||||
|
||||
|
||||
class _HasBoolLegacyKey(TypedDict):
|
||||
legacy: bool
|
||||
|
||||
|
||||
class DeveloperSetupTests(unittest.TestCase):
|
||||
"""
|
||||
These tests verify things about the plugin system without actually
|
||||
interacting with the deployed 'twisted.plugins' package, instead creating a
|
||||
temporary package.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Create a complex environment with multiple entries on sys.path, akin to
|
||||
a developer's environment who has a development (trunk) checkout of
|
||||
Twisted, a system installed version of Twisted (for their operating
|
||||
system's tools) and a project which provides Twisted plugins.
|
||||
"""
|
||||
self.savedPath = sys.path[:]
|
||||
self.savedModules = sys.modules.copy()
|
||||
self.fakeRoot = FilePath(self.mktemp())
|
||||
self.fakeRoot.createDirectory()
|
||||
self.systemPath = self.fakeRoot.child("system_path")
|
||||
self.devPath = self.fakeRoot.child("development_path")
|
||||
self.appPath = self.fakeRoot.child("application_path")
|
||||
self.systemPackage = _createPluginDummy(
|
||||
self.systemPath, pluginFileContents("system"), True, "plugindummy_builtin"
|
||||
)
|
||||
self.devPackage = _createPluginDummy(
|
||||
self.devPath, pluginFileContents("dev"), True, "plugindummy_builtin"
|
||||
)
|
||||
self.appPackage = _createPluginDummy(
|
||||
self.appPath, pluginFileContents("app"), False, "plugindummy_app"
|
||||
)
|
||||
|
||||
# Now we're going to do the system installation.
|
||||
sys.path.extend([x.path for x in [self.systemPath, self.appPath]])
|
||||
# Run all the way through the plugins list to cause the
|
||||
# L{plugin.getPlugins} generator to write cache files for the system
|
||||
# installation.
|
||||
self.getAllPlugins()
|
||||
self.sysplug = self.systemPath.child("plugindummy").child("plugins")
|
||||
self.syscache = self.sysplug.child("dropin.cache")
|
||||
# Make sure there's a nice big difference in modification times so that
|
||||
# we won't re-build the system cache.
|
||||
now = time.time()
|
||||
os.utime(self.sysplug.child("plugindummy_builtin.py").path, (now - 5000,) * 2)
|
||||
os.utime(self.syscache.path, (now - 2000,) * 2)
|
||||
# For extra realism, let's make sure that the system path is no longer
|
||||
# writable.
|
||||
self.lockSystem()
|
||||
self.resetEnvironment()
|
||||
|
||||
def lockSystem(self) -> None:
|
||||
"""
|
||||
Lock the system directories, as if they were unwritable by this user.
|
||||
"""
|
||||
os.chmod(self.sysplug.path, 0o555)
|
||||
os.chmod(self.syscache.path, 0o555)
|
||||
|
||||
def unlockSystem(self) -> None:
|
||||
"""
|
||||
Unlock the system directories, as if they were writable by this user.
|
||||
"""
|
||||
os.chmod(self.sysplug.path, 0o777)
|
||||
os.chmod(self.syscache.path, 0o777)
|
||||
|
||||
def getAllPlugins(self) -> list[str]:
|
||||
"""
|
||||
Get all the plugins loadable from our dummy package, and return their
|
||||
short names.
|
||||
"""
|
||||
# Import the module we just added to our path. (Local scope because
|
||||
# this package doesn't exist outside of this test.)
|
||||
import plugindummy.plugins # type: ignore[import-not-found]
|
||||
|
||||
x = list(plugin.getPlugins(ITestPlugin, plugindummy.plugins))
|
||||
return [plug.__name__ for plug in x] # type: ignore[attr-defined]
|
||||
|
||||
def resetEnvironment(self) -> None:
|
||||
"""
|
||||
Change the environment to what it should be just as the test is
|
||||
starting.
|
||||
"""
|
||||
self.unsetEnvironment()
|
||||
sys.path.extend([x.path for x in [self.devPath, self.systemPath, self.appPath]])
|
||||
|
||||
def unsetEnvironment(self) -> None:
|
||||
"""
|
||||
Change the Python environment back to what it was before the test was
|
||||
started.
|
||||
"""
|
||||
invalidateImportCaches()
|
||||
sys.modules.clear()
|
||||
sys.modules.update(self.savedModules)
|
||||
sys.path[:] = self.savedPath
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
Reset the Python environment to what it was before this test ran, and
|
||||
restore permissions on files which were marked read-only so that the
|
||||
directory may be cleanly cleaned up.
|
||||
"""
|
||||
self.unsetEnvironment()
|
||||
# Normally we wouldn't "clean up" the filesystem like this (leaving
|
||||
# things for post-test inspection), but if we left the permissions the
|
||||
# way they were, we'd be leaving files around that the buildbots
|
||||
# couldn't delete, and that would be bad.
|
||||
self.unlockSystem()
|
||||
|
||||
def test_developmentPluginAvailability(self) -> None:
|
||||
"""
|
||||
Plugins added in the development path should be loadable, even when
|
||||
the (now non-importable) system path contains its own idea of the
|
||||
list of plugins for a package. Inversely, plugins added in the
|
||||
system path should not be available.
|
||||
"""
|
||||
# Run 3 times: uncached, cached, and then cached again to make sure we
|
||||
# didn't overwrite / corrupt the cache on the cached try.
|
||||
for x in range(3):
|
||||
names = self.getAllPlugins()
|
||||
names.sort()
|
||||
self.assertEqual(names, ["app", "dev"])
|
||||
|
||||
def test_freshPyReplacesStalePyc(self) -> None:
|
||||
"""
|
||||
Verify that if a stale .pyc file on the PYTHONPATH is replaced by a
|
||||
fresh .py file, the plugins in the new .py are picked up rather than
|
||||
the stale .pyc, even if the .pyc is still around.
|
||||
"""
|
||||
mypath = self.appPackage.child("stale.py")
|
||||
mypath.setContent(pluginFileContents("one"))
|
||||
# Make it super stale
|
||||
x = time.time() - 1000
|
||||
os.utime(mypath.path, (x, x))
|
||||
pyc = mypath.sibling("stale.pyc")
|
||||
# compile it
|
||||
# On python 3, don't use the __pycache__ directory; the intention
|
||||
# of scanning for .pyc files is for configurations where you want
|
||||
# to intentionally include them, which means we _don't_ scan for
|
||||
# them inside cache directories.
|
||||
extra = _HasBoolLegacyKey(legacy=True)
|
||||
compileall.compile_dir(self.appPackage.path, quiet=1, **extra)
|
||||
os.utime(pyc.path, (x, x))
|
||||
# Eliminate the other option.
|
||||
mypath.remove()
|
||||
# Make sure it's the .pyc path getting cached.
|
||||
self.resetEnvironment()
|
||||
# Sanity check.
|
||||
self.assertIn("one", self.getAllPlugins())
|
||||
self.failIfIn("two", self.getAllPlugins())
|
||||
self.resetEnvironment()
|
||||
mypath.setContent(pluginFileContents("two"))
|
||||
self.failIfIn("one", self.getAllPlugins())
|
||||
self.assertIn("two", self.getAllPlugins())
|
||||
|
||||
def test_newPluginsOnReadOnlyPath(self) -> None:
|
||||
"""
|
||||
Verify that a failure to write the dropin.cache file on a read-only
|
||||
path will not affect the list of plugins returned.
|
||||
|
||||
Note: this test should pass on both Linux and Windows, but may not
|
||||
provide useful coverage on Windows due to the different meaning of
|
||||
"read-only directory".
|
||||
"""
|
||||
self.unlockSystem()
|
||||
self.sysplug.child("newstuff.py").setContent(pluginFileContents("one"))
|
||||
self.lockSystem()
|
||||
|
||||
# Take the developer path out, so that the system plugins are actually
|
||||
# examined.
|
||||
sys.path.remove(self.devPath.path)
|
||||
|
||||
# Start observing log events to see the warning
|
||||
events: list[EventDict] = []
|
||||
addObserver(events.append)
|
||||
self.addCleanup(removeObserver, events.append)
|
||||
|
||||
self.assertIn("one", self.getAllPlugins())
|
||||
|
||||
# Make sure something was logged about the cache.
|
||||
expected = "Unable to write to plugin cache %s: error number %d" % (
|
||||
self.syscache.path,
|
||||
errno.EPERM,
|
||||
)
|
||||
for event in events:
|
||||
maybeText = textFromEventDict(event)
|
||||
assert maybeText is not None
|
||||
if expected in maybeText: # pragma: no branch
|
||||
break
|
||||
else: # pragma: no cover
|
||||
self.fail(
|
||||
"Did not observe unwriteable cache warning in log "
|
||||
"events: %r" % (events,)
|
||||
)
|
||||
|
||||
|
||||
class AdjacentPackageTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the behavior of the plugin system when there are multiple
|
||||
installed copies of the package containing the plugins being loaded.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Save the elements of C{sys.path} and the items of C{sys.modules}.
|
||||
"""
|
||||
self.originalPath = sys.path[:]
|
||||
self.savedModules = sys.modules.copy()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
Restore C{sys.path} and C{sys.modules} to their original values.
|
||||
"""
|
||||
sys.path[:] = self.originalPath
|
||||
sys.modules.clear()
|
||||
sys.modules.update(self.savedModules)
|
||||
|
||||
def createDummyPackage(
|
||||
self, root: FilePath[str], name: str, pluginName: str
|
||||
) -> FilePath[str]:
|
||||
"""
|
||||
Create a directory containing a Python package named I{dummy} with a
|
||||
I{plugins} subpackage.
|
||||
|
||||
@type root: L{FilePath}
|
||||
@param root: The directory in which to create the hierarchy.
|
||||
|
||||
@type name: C{str}
|
||||
@param name: The name of the directory to create which will contain
|
||||
the package.
|
||||
|
||||
@type pluginName: C{str}
|
||||
@param pluginName: The name of a module to create in the
|
||||
I{dummy.plugins} package.
|
||||
|
||||
@rtype: L{FilePath}
|
||||
@return: The directory which was created to contain the I{dummy}
|
||||
package.
|
||||
"""
|
||||
directory = root.child(name)
|
||||
package = directory.child("dummy")
|
||||
package.makedirs()
|
||||
package.child("__init__.py").setContent(b"")
|
||||
plugins = package.child("plugins")
|
||||
plugins.makedirs()
|
||||
plugins.child("__init__.py").setContent(pluginInitFile)
|
||||
pluginModule = plugins.child(pluginName + ".py")
|
||||
pluginModule.setContent(pluginFileContents(name))
|
||||
return directory
|
||||
|
||||
def test_hiddenPackageSamePluginModuleNameObscured(self) -> None:
|
||||
"""
|
||||
Only plugins from the first package in sys.path should be returned by
|
||||
getPlugins in the case where there are two Python packages by the same
|
||||
name installed, each with a plugin module by a single name.
|
||||
"""
|
||||
root = FilePath(self.mktemp())
|
||||
root.makedirs()
|
||||
|
||||
firstDirectory = self.createDummyPackage(root, "first", "someplugin")
|
||||
secondDirectory = self.createDummyPackage(root, "second", "someplugin")
|
||||
|
||||
sys.path.append(firstDirectory.path)
|
||||
sys.path.append(secondDirectory.path)
|
||||
|
||||
import dummy.plugins # type: ignore[import-not-found]
|
||||
|
||||
plugins = list(plugin.getPlugins(ITestPlugin, dummy.plugins))
|
||||
self.assertEqual(["first"], [p.__name__ for p in plugins]) # type: ignore[attr-defined]
|
||||
|
||||
def test_hiddenPackageDifferentPluginModuleNameObscured(self) -> None:
|
||||
"""
|
||||
Plugins from the first package in sys.path should be returned by
|
||||
getPlugins in the case where there are two Python packages by the same
|
||||
name installed, each with a plugin module by a different name.
|
||||
"""
|
||||
root = FilePath(self.mktemp())
|
||||
root.makedirs()
|
||||
|
||||
firstDirectory = self.createDummyPackage(root, "first", "thisplugin")
|
||||
secondDirectory = self.createDummyPackage(root, "second", "thatplugin")
|
||||
|
||||
sys.path.append(firstDirectory.path)
|
||||
sys.path.append(secondDirectory.path)
|
||||
|
||||
import dummy.plugins
|
||||
|
||||
plugins = list(plugin.getPlugins(ITestPlugin, dummy.plugins))
|
||||
self.assertEqual(["first"], [p.__name__ for p in plugins]) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class PackagePathTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{plugin.pluginPackagePaths} which constructs search paths for
|
||||
plugin packages.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Save the elements of C{sys.path}.
|
||||
"""
|
||||
self.originalPath = sys.path[:]
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
Restore C{sys.path} to its original value.
|
||||
"""
|
||||
sys.path[:] = self.originalPath
|
||||
|
||||
def test_pluginDirectories(self) -> None:
|
||||
"""
|
||||
L{plugin.pluginPackagePaths} should return a list containing each
|
||||
directory in C{sys.path} with a suffix based on the supplied package
|
||||
name.
|
||||
"""
|
||||
foo = FilePath("foo")
|
||||
bar = FilePath("bar")
|
||||
sys.path = [foo.path, bar.path]
|
||||
self.assertEqual(
|
||||
plugin.pluginPackagePaths("dummy.plugins"),
|
||||
[
|
||||
foo.child("dummy").child("plugins").path,
|
||||
bar.child("dummy").child("plugins").path,
|
||||
],
|
||||
)
|
||||
|
||||
def test_pluginPackagesExcluded(self) -> None:
|
||||
"""
|
||||
L{plugin.pluginPackagePaths} should exclude directories which are
|
||||
Python packages. The only allowed plugin package (the only one
|
||||
associated with a I{dummy} package which Python will allow to be
|
||||
imported) will already be known to the caller of
|
||||
L{plugin.pluginPackagePaths} and will most commonly already be in
|
||||
the C{__path__} they are about to mutate.
|
||||
"""
|
||||
root = FilePath(self.mktemp())
|
||||
foo = root.child("foo").child("dummy").child("plugins")
|
||||
foo.makedirs()
|
||||
foo.child("__init__.py").setContent(b"")
|
||||
sys.path = [root.child("foo").path, root.child("bar").path]
|
||||
self.assertEqual(
|
||||
plugin.pluginPackagePaths("dummy.plugins"),
|
||||
[root.child("bar").child("dummy").child("plugins").path],
|
||||
)
|
||||
1001
.venv/lib/python3.12/site-packages/twisted/test/test_policies.py
Normal file
1001
.venv/lib/python3.12/site-packages/twisted/test/test_policies.py
Normal file
File diff suppressed because it is too large
Load Diff
140
.venv/lib/python3.12/site-packages/twisted/test/test_postfix.py
Normal file
140
.venv/lib/python3.12/site-packages/twisted/test/test_postfix.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for twisted.protocols.postfix module.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from twisted.internet.testing import StringTransport
|
||||
from twisted.protocols import postfix
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class PostfixTCPMapQuoteTests(unittest.TestCase):
|
||||
data = [
|
||||
# (raw, quoted, [aliasQuotedForms]),
|
||||
(b"foo", b"foo"),
|
||||
(b"foo bar", b"foo%20bar"),
|
||||
(b"foo\tbar", b"foo%09bar"),
|
||||
(b"foo\nbar", b"foo%0Abar", b"foo%0abar"),
|
||||
(
|
||||
b"foo\r\nbar",
|
||||
b"foo%0D%0Abar",
|
||||
b"foo%0D%0abar",
|
||||
b"foo%0d%0Abar",
|
||||
b"foo%0d%0abar",
|
||||
),
|
||||
(b"foo ", b"foo%20"),
|
||||
(b" foo", b"%20foo"),
|
||||
]
|
||||
|
||||
def testData(self):
|
||||
for entry in self.data:
|
||||
raw = entry[0]
|
||||
quoted = entry[1:]
|
||||
|
||||
self.assertEqual(postfix.quote(raw), quoted[0])
|
||||
for q in quoted:
|
||||
self.assertEqual(postfix.unquote(q), raw)
|
||||
|
||||
|
||||
class PostfixTCPMapServerTestCase:
|
||||
data: Dict[bytes, bytes] = {
|
||||
# 'key': 'value',
|
||||
}
|
||||
|
||||
chat: List[Tuple[bytes, bytes]] = [
|
||||
# (input, expected_output),
|
||||
]
|
||||
|
||||
def test_chat(self):
|
||||
"""
|
||||
Test that I{get} and I{put} commands are responded to correctly by
|
||||
L{postfix.PostfixTCPMapServer} when its factory is an instance of
|
||||
L{postifx.PostfixTCPMapDictServerFactory}.
|
||||
"""
|
||||
factory = postfix.PostfixTCPMapDictServerFactory(self.data)
|
||||
transport = StringTransport()
|
||||
|
||||
protocol = postfix.PostfixTCPMapServer()
|
||||
protocol.service = factory
|
||||
protocol.factory = factory
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
for input, expected_output in self.chat:
|
||||
protocol.lineReceived(input)
|
||||
self.assertEqual(
|
||||
transport.value(),
|
||||
expected_output,
|
||||
"For %r, expected %r but got %r"
|
||||
% (input, expected_output, transport.value()),
|
||||
)
|
||||
transport.clear()
|
||||
protocol.setTimeout(None)
|
||||
|
||||
def test_deferredChat(self):
|
||||
"""
|
||||
Test that I{get} and I{put} commands are responded to correctly by
|
||||
L{postfix.PostfixTCPMapServer} when its factory is an instance of
|
||||
L{postifx.PostfixTCPMapDeferringDictServerFactory}.
|
||||
"""
|
||||
factory = postfix.PostfixTCPMapDeferringDictServerFactory(self.data)
|
||||
transport = StringTransport()
|
||||
|
||||
protocol = postfix.PostfixTCPMapServer()
|
||||
protocol.service = factory
|
||||
protocol.factory = factory
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
for input, expected_output in self.chat:
|
||||
protocol.lineReceived(input)
|
||||
self.assertEqual(
|
||||
transport.value(),
|
||||
expected_output,
|
||||
"For {!r}, expected {!r} but got {!r}".format(
|
||||
input, expected_output, transport.value()
|
||||
),
|
||||
)
|
||||
transport.clear()
|
||||
protocol.setTimeout(None)
|
||||
|
||||
def test_getException(self):
|
||||
"""
|
||||
If the factory throws an exception,
|
||||
error code 400 must be returned.
|
||||
"""
|
||||
|
||||
class ErrorFactory:
|
||||
"""
|
||||
Factory that raises an error on key lookup.
|
||||
"""
|
||||
|
||||
def get(self, key):
|
||||
raise Exception("This is a test error")
|
||||
|
||||
server = postfix.PostfixTCPMapServer()
|
||||
server.factory = ErrorFactory()
|
||||
server.transport = StringTransport()
|
||||
server.lineReceived(b"get example")
|
||||
self.assertEqual(server.transport.value(), b"400 This is a test error\n")
|
||||
|
||||
|
||||
class ValidTests(PostfixTCPMapServerTestCase, unittest.TestCase):
|
||||
data = {
|
||||
b"foo": b"ThisIs Foo",
|
||||
b"bar": b" bar really is found\r\n",
|
||||
}
|
||||
chat = [
|
||||
(b"get", b"400 Command 'get' takes 1 parameters.\n"),
|
||||
(b"get foo bar", b"500 \n"),
|
||||
(b"put", b"400 Command 'put' takes 2 parameters.\n"),
|
||||
(b"put foo", b"400 Command 'put' takes 2 parameters.\n"),
|
||||
(b"put foo bar baz", b"500 put is not implemented yet.\n"),
|
||||
(b"put foo bar", b"500 put is not implemented yet.\n"),
|
||||
(b"get foo", b"200 ThisIs%20Foo\n"),
|
||||
(b"get bar", b"200 %20bar%20really%20is%20found%0D%0A\n"),
|
||||
(b"get baz", b"500 \n"),
|
||||
(b"foo", b"400 unknown command\n"),
|
||||
]
|
||||
2789
.venv/lib/python3.12/site-packages/twisted/test/test_process.py
Normal file
2789
.venv/lib/python3.12/site-packages/twisted/test/test_process.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,227 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for twisted.protocols package.
|
||||
"""
|
||||
|
||||
from twisted.internet import address, defer, protocol, reactor
|
||||
from twisted.protocols import portforward, wire
|
||||
from twisted.python.compat import iterbytes
|
||||
from twisted.test import proto_helpers
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class WireTests(unittest.TestCase):
|
||||
"""
|
||||
Test wire protocols.
|
||||
"""
|
||||
|
||||
def test_echo(self):
|
||||
"""
|
||||
Test wire.Echo protocol: send some data and check it send it back.
|
||||
"""
|
||||
t = proto_helpers.StringTransport()
|
||||
a = wire.Echo()
|
||||
a.makeConnection(t)
|
||||
a.dataReceived(b"hello")
|
||||
a.dataReceived(b"world")
|
||||
a.dataReceived(b"how")
|
||||
a.dataReceived(b"are")
|
||||
a.dataReceived(b"you")
|
||||
self.assertEqual(t.value(), b"helloworldhowareyou")
|
||||
|
||||
def test_who(self):
|
||||
"""
|
||||
Test wire.Who protocol.
|
||||
"""
|
||||
t = proto_helpers.StringTransport()
|
||||
a = wire.Who()
|
||||
a.makeConnection(t)
|
||||
self.assertEqual(t.value(), b"root\r\n")
|
||||
|
||||
def test_QOTD(self):
|
||||
"""
|
||||
Test wire.QOTD protocol.
|
||||
"""
|
||||
t = proto_helpers.StringTransport()
|
||||
a = wire.QOTD()
|
||||
a.makeConnection(t)
|
||||
self.assertEqual(t.value(), b"An apple a day keeps the doctor away.\r\n")
|
||||
|
||||
def test_discard(self):
|
||||
"""
|
||||
Test wire.Discard protocol.
|
||||
"""
|
||||
t = proto_helpers.StringTransport()
|
||||
a = wire.Discard()
|
||||
a.makeConnection(t)
|
||||
a.dataReceived(b"hello")
|
||||
a.dataReceived(b"world")
|
||||
a.dataReceived(b"how")
|
||||
a.dataReceived(b"are")
|
||||
a.dataReceived(b"you")
|
||||
self.assertEqual(t.value(), b"")
|
||||
|
||||
|
||||
class TestableProxyClientFactory(portforward.ProxyClientFactory):
|
||||
"""
|
||||
Test proxy client factory that keeps the last created protocol instance.
|
||||
|
||||
@ivar protoInstance: the last instance of the protocol.
|
||||
@type protoInstance: L{portforward.ProxyClient}
|
||||
"""
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
"""
|
||||
Create the protocol instance and keeps track of it.
|
||||
"""
|
||||
proto = portforward.ProxyClientFactory.buildProtocol(self, addr)
|
||||
self.protoInstance = proto
|
||||
return proto
|
||||
|
||||
|
||||
class TestableProxyFactory(portforward.ProxyFactory):
|
||||
"""
|
||||
Test proxy factory that keeps the last created protocol instance.
|
||||
|
||||
@ivar protoInstance: the last instance of the protocol.
|
||||
@type protoInstance: L{portforward.ProxyServer}
|
||||
|
||||
@ivar clientFactoryInstance: client factory used by C{protoInstance} to
|
||||
create forward connections.
|
||||
@type clientFactoryInstance: L{TestableProxyClientFactory}
|
||||
"""
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
"""
|
||||
Create the protocol instance, keeps track of it, and makes it use
|
||||
C{clientFactoryInstance} as client factory.
|
||||
"""
|
||||
proto = portforward.ProxyFactory.buildProtocol(self, addr)
|
||||
self.clientFactoryInstance = TestableProxyClientFactory()
|
||||
# Force the use of this specific instance
|
||||
proto.clientProtocolFactory = lambda: self.clientFactoryInstance
|
||||
self.protoInstance = proto
|
||||
return proto
|
||||
|
||||
|
||||
class PortforwardingTests(unittest.TestCase):
|
||||
"""
|
||||
Test port forwarding.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.serverProtocol = wire.Echo()
|
||||
self.clientProtocol = protocol.Protocol()
|
||||
self.openPorts = []
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
self.proxyServerFactory.protoInstance.transport.loseConnection()
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
pi = self.proxyServerFactory.clientFactoryInstance.protoInstance
|
||||
pi.transport.loseConnection()
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
self.clientProtocol.transport.loseConnection()
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
self.serverProtocol.transport.loseConnection()
|
||||
except AttributeError:
|
||||
pass
|
||||
return defer.gatherResults(
|
||||
[defer.maybeDeferred(p.stopListening) for p in self.openPorts]
|
||||
)
|
||||
|
||||
def test_portforward(self):
|
||||
"""
|
||||
Test port forwarding through Echo protocol.
|
||||
"""
|
||||
realServerFactory = protocol.ServerFactory()
|
||||
realServerFactory.protocol = lambda: self.serverProtocol
|
||||
realServerPort = reactor.listenTCP(0, realServerFactory, interface="127.0.0.1")
|
||||
self.openPorts.append(realServerPort)
|
||||
self.proxyServerFactory = TestableProxyFactory(
|
||||
"127.0.0.1", realServerPort.getHost().port
|
||||
)
|
||||
proxyServerPort = reactor.listenTCP(
|
||||
0, self.proxyServerFactory, interface="127.0.0.1"
|
||||
)
|
||||
self.openPorts.append(proxyServerPort)
|
||||
|
||||
nBytes = 1000
|
||||
received = []
|
||||
d = defer.Deferred()
|
||||
|
||||
def testDataReceived(data):
|
||||
received.extend(iterbytes(data))
|
||||
if len(received) >= nBytes:
|
||||
self.assertEqual(b"".join(received), b"x" * nBytes)
|
||||
d.callback(None)
|
||||
|
||||
self.clientProtocol.dataReceived = testDataReceived
|
||||
|
||||
def testConnectionMade():
|
||||
self.clientProtocol.transport.write(b"x" * nBytes)
|
||||
|
||||
self.clientProtocol.connectionMade = testConnectionMade
|
||||
|
||||
clientFactory = protocol.ClientFactory()
|
||||
clientFactory.protocol = lambda: self.clientProtocol
|
||||
|
||||
reactor.connectTCP("127.0.0.1", proxyServerPort.getHost().port, clientFactory)
|
||||
|
||||
return d
|
||||
|
||||
def test_registerProducers(self):
|
||||
"""
|
||||
The proxy client registers itself as a producer of the proxy server and
|
||||
vice versa.
|
||||
"""
|
||||
# create a ProxyServer instance
|
||||
addr = address.IPv4Address("TCP", "127.0.0.1", 0)
|
||||
server = portforward.ProxyFactory("127.0.0.1", 0).buildProtocol(addr)
|
||||
|
||||
# set the reactor for this test
|
||||
reactor = proto_helpers.MemoryReactor()
|
||||
server.reactor = reactor
|
||||
|
||||
# make the connection
|
||||
serverTransport = proto_helpers.StringTransport()
|
||||
server.makeConnection(serverTransport)
|
||||
|
||||
# check that the ProxyClientFactory is connecting to the backend
|
||||
self.assertEqual(len(reactor.tcpClients), 1)
|
||||
# get the factory instance and check it's the one we expect
|
||||
host, port, clientFactory, timeout, _ = reactor.tcpClients[0]
|
||||
self.assertIsInstance(clientFactory, portforward.ProxyClientFactory)
|
||||
|
||||
# Connect it
|
||||
client = clientFactory.buildProtocol(addr)
|
||||
clientTransport = proto_helpers.StringTransport()
|
||||
client.makeConnection(clientTransport)
|
||||
|
||||
# check that the producers are registered
|
||||
self.assertIs(clientTransport.producer, serverTransport)
|
||||
self.assertIs(serverTransport.producer, clientTransport)
|
||||
# check the streaming attribute in both transports
|
||||
self.assertTrue(clientTransport.streaming)
|
||||
self.assertTrue(serverTransport.streaming)
|
||||
|
||||
|
||||
class StringTransportTests(unittest.TestCase):
|
||||
"""
|
||||
Test L{proto_helpers.StringTransport} helper behaviour.
|
||||
"""
|
||||
|
||||
def test_noUnicode(self):
|
||||
"""
|
||||
Test that L{proto_helpers.StringTransport} doesn't accept unicode data.
|
||||
"""
|
||||
s = proto_helpers.StringTransport()
|
||||
self.assertRaises(TypeError, s.write, "foo")
|
||||
@@ -0,0 +1,126 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{twisted.python.randbytes}.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from typing_extensions import NoReturn, Protocol
|
||||
|
||||
from twisted.python import randbytes
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class _SupportsAssertions(Protocol):
|
||||
def assertEqual(self, a: object, b: object) -> object:
|
||||
...
|
||||
|
||||
def assertNotEqual(self, a: object, b: object) -> object:
|
||||
...
|
||||
|
||||
|
||||
class SecureRandomTestCaseBase:
|
||||
"""
|
||||
Base class for secureRandom test cases.
|
||||
"""
|
||||
|
||||
def _check(self: _SupportsAssertions, source: Callable[[int], bytes]) -> None:
|
||||
"""
|
||||
The given random bytes source should return the number of bytes
|
||||
requested each time it is called and should probably not return the
|
||||
same bytes on two consecutive calls (although this is a perfectly
|
||||
legitimate occurrence and rejecting it may generate a spurious failure
|
||||
-- maybe we'll get lucky and the heat death with come first).
|
||||
"""
|
||||
for nbytes in range(17, 25):
|
||||
s = source(nbytes)
|
||||
self.assertEqual(len(s), nbytes)
|
||||
s2 = source(nbytes)
|
||||
self.assertEqual(len(s2), nbytes)
|
||||
# This is crude but hey
|
||||
self.assertNotEqual(s2, s)
|
||||
|
||||
|
||||
class SecureRandomTests(SecureRandomTestCaseBase, unittest.TestCase):
|
||||
"""
|
||||
Test secureRandom under normal conditions.
|
||||
"""
|
||||
|
||||
def test_normal(self) -> None:
|
||||
"""
|
||||
L{randbytes.secureRandom} should return a string of the requested
|
||||
length and make some effort to make its result otherwise unpredictable.
|
||||
"""
|
||||
self._check(randbytes.secureRandom)
|
||||
|
||||
|
||||
class ConditionalSecureRandomTests(
|
||||
SecureRandomTestCaseBase, unittest.SynchronousTestCase
|
||||
):
|
||||
"""
|
||||
Test random sources one by one, then remove it to.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Create a L{randbytes.RandomFactory} to use in the tests.
|
||||
"""
|
||||
self.factory = randbytes.RandomFactory()
|
||||
|
||||
def errorFactory(self, nbytes: object) -> NoReturn:
|
||||
"""
|
||||
A factory raising an error when a source is not available.
|
||||
"""
|
||||
raise randbytes.SourceNotAvailable()
|
||||
|
||||
def test_osUrandom(self) -> None:
|
||||
"""
|
||||
L{RandomFactory._osUrandom} should work as a random source whenever
|
||||
L{os.urandom} is available.
|
||||
"""
|
||||
self._check(self.factory._osUrandom)
|
||||
|
||||
def test_withoutAnything(self) -> None:
|
||||
"""
|
||||
Remove all secure sources and assert it raises a failure. Then try the
|
||||
fallback parameter.
|
||||
"""
|
||||
self.factory._osUrandom = self.errorFactory # type: ignore[method-assign]
|
||||
self.assertRaises(
|
||||
randbytes.SecureRandomNotAvailable, self.factory.secureRandom, 18
|
||||
)
|
||||
|
||||
def wrapper() -> bytes:
|
||||
return self.factory.secureRandom(18, fallback=True)
|
||||
|
||||
s = self.assertWarns(
|
||||
RuntimeWarning,
|
||||
"urandom unavailable - "
|
||||
"proceeding with non-cryptographically secure random source",
|
||||
__file__,
|
||||
wrapper,
|
||||
)
|
||||
self.assertEqual(len(s), 18)
|
||||
|
||||
|
||||
class RandomBaseTests(SecureRandomTestCaseBase, unittest.SynchronousTestCase):
|
||||
"""
|
||||
'Normal' random test cases.
|
||||
"""
|
||||
|
||||
def test_normal(self) -> None:
|
||||
"""
|
||||
Test basic case.
|
||||
"""
|
||||
self._check(randbytes.insecureRandom)
|
||||
|
||||
def test_withoutGetrandbits(self) -> None:
|
||||
"""
|
||||
Test C{insecureRandom} without C{random.getrandbits}.
|
||||
"""
|
||||
factory = randbytes.RandomFactory()
|
||||
factory.getrandbits = None
|
||||
self._check(factory.insecureRandom)
|
||||
266
.venv/lib/python3.12/site-packages/twisted/test/test_rebuild.py
Normal file
266
.venv/lib/python3.12/site-packages/twisted/test/test_rebuild.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
|
||||
from typing_extensions import NoReturn
|
||||
|
||||
from twisted.python import rebuild
|
||||
from twisted.trial.unittest import TestCase
|
||||
from . import crash_test_dummy
|
||||
|
||||
f = crash_test_dummy.foo
|
||||
|
||||
|
||||
class Foo:
|
||||
pass
|
||||
|
||||
|
||||
class Bar(Foo):
|
||||
pass
|
||||
|
||||
|
||||
class Baz:
|
||||
pass
|
||||
|
||||
|
||||
class Buz(Bar, Baz):
|
||||
pass
|
||||
|
||||
|
||||
class HashRaisesRuntimeError:
|
||||
"""
|
||||
Things that don't hash (raise an Exception) should be ignored by the
|
||||
rebuilder.
|
||||
|
||||
@ivar hashCalled: C{bool} set to True when __hash__ is called.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.hashCalled = False
|
||||
|
||||
def __hash__(self) -> NoReturn:
|
||||
self.hashCalled = True
|
||||
raise RuntimeError("not a TypeError!")
|
||||
|
||||
|
||||
# Set in test_hashException
|
||||
unhashableObject = None
|
||||
|
||||
|
||||
class RebuildTests(TestCase):
|
||||
"""
|
||||
Simple testcase for rebuilding, to at least exercise the code.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.libPath = self.mktemp()
|
||||
os.mkdir(self.libPath)
|
||||
self.fakelibPath = os.path.join(self.libPath, "twisted_rebuild_fakelib")
|
||||
os.mkdir(self.fakelibPath)
|
||||
open(os.path.join(self.fakelibPath, "__init__.py"), "w").close()
|
||||
sys.path.insert(0, self.libPath)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
sys.path.remove(self.libPath)
|
||||
|
||||
def test_FileRebuild(self) -> None:
|
||||
import shutil
|
||||
import time
|
||||
|
||||
from twisted.python.util import sibpath
|
||||
|
||||
shutil.copyfile(
|
||||
sibpath(__file__, "myrebuilder1.py"),
|
||||
os.path.join(self.fakelibPath, "myrebuilder.py"),
|
||||
)
|
||||
from twisted_rebuild_fakelib import ( # type: ignore[import-not-found]
|
||||
myrebuilder,
|
||||
)
|
||||
|
||||
a = myrebuilder.A()
|
||||
b = myrebuilder.B()
|
||||
i = myrebuilder.Inherit()
|
||||
self.assertEqual(a.a(), "a")
|
||||
# Necessary because the file has not "changed" if a second has not gone
|
||||
# by in unix. This sucks, but it's not often that you'll be doing more
|
||||
# than one reload per second.
|
||||
time.sleep(1.1)
|
||||
shutil.copyfile(
|
||||
sibpath(__file__, "myrebuilder2.py"),
|
||||
os.path.join(self.fakelibPath, "myrebuilder.py"),
|
||||
)
|
||||
rebuild.rebuild(myrebuilder)
|
||||
b2 = myrebuilder.B()
|
||||
self.assertEqual(b2.b(), "c")
|
||||
self.assertEqual(b.b(), "c")
|
||||
self.assertEqual(i.a(), "d")
|
||||
self.assertEqual(a.a(), "b")
|
||||
|
||||
def test_Rebuild(self) -> None:
|
||||
"""
|
||||
Rebuilding an unchanged module.
|
||||
"""
|
||||
# This test would actually pass if rebuild was a no-op, but it
|
||||
# ensures rebuild doesn't break stuff while being a less
|
||||
# complex test than testFileRebuild.
|
||||
|
||||
x = crash_test_dummy.X("a")
|
||||
|
||||
rebuild.rebuild(crash_test_dummy, doLog=False)
|
||||
# Instance rebuilding is triggered by attribute access.
|
||||
x.do()
|
||||
self.assertEqual(x.__class__, crash_test_dummy.X)
|
||||
|
||||
self.assertEqual(f, crash_test_dummy.foo)
|
||||
|
||||
def test_ComponentInteraction(self) -> None:
|
||||
x = crash_test_dummy.XComponent()
|
||||
x.setAdapter(crash_test_dummy.IX, crash_test_dummy.XA)
|
||||
x.getComponent(crash_test_dummy.IX)
|
||||
rebuild.rebuild(crash_test_dummy, 0)
|
||||
newComponent = x.getComponent(crash_test_dummy.IX)
|
||||
|
||||
newComponent.method()
|
||||
|
||||
self.assertEqual(newComponent.__class__, crash_test_dummy.XA)
|
||||
|
||||
# Test that a duplicate registerAdapter is not allowed
|
||||
from twisted.python import components
|
||||
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
components.registerAdapter,
|
||||
crash_test_dummy.XA,
|
||||
crash_test_dummy.X,
|
||||
crash_test_dummy.IX,
|
||||
)
|
||||
|
||||
def test_UpdateInstance(self) -> None:
|
||||
global Foo, Buz
|
||||
|
||||
b = Buz()
|
||||
|
||||
class Foo:
|
||||
def foo(self) -> None:
|
||||
"""
|
||||
Dummy method
|
||||
"""
|
||||
|
||||
class Buz(Bar, Baz):
|
||||
x = 10
|
||||
|
||||
rebuild.updateInstance(b)
|
||||
assert hasattr(b, "foo"), "Missing method on rebuilt instance"
|
||||
assert hasattr(b, "x"), "Missing class attribute on rebuilt instance"
|
||||
|
||||
def test_BananaInteraction(self) -> None:
|
||||
from twisted.python import rebuild
|
||||
from twisted.spread import banana
|
||||
|
||||
rebuild.latestClass(banana.Banana)
|
||||
|
||||
def test_hashException(self) -> None:
|
||||
"""
|
||||
Rebuilding something that has a __hash__ that raises a non-TypeError
|
||||
shouldn't cause rebuild to die.
|
||||
"""
|
||||
global unhashableObject
|
||||
unhashableObject = HashRaisesRuntimeError()
|
||||
|
||||
def _cleanup() -> None:
|
||||
global unhashableObject
|
||||
unhashableObject = None
|
||||
|
||||
self.addCleanup(_cleanup)
|
||||
rebuild.rebuild(rebuild)
|
||||
self.assertTrue(unhashableObject.hashCalled)
|
||||
|
||||
def test_Sensitive(self) -> None:
|
||||
"""
|
||||
L{twisted.python.rebuild.Sensitive}
|
||||
"""
|
||||
from twisted.python import rebuild
|
||||
from twisted.python.rebuild import Sensitive
|
||||
|
||||
class TestSensitive(Sensitive):
|
||||
def test_method(self) -> None:
|
||||
"""
|
||||
Dummy method
|
||||
"""
|
||||
|
||||
testSensitive = TestSensitive()
|
||||
testSensitive.rebuildUpToDate()
|
||||
self.assertFalse(testSensitive.needRebuildUpdate())
|
||||
|
||||
# Test rebuilding a builtin class
|
||||
newException = rebuild.latestClass(Exception)
|
||||
self.assertEqual(repr(Exception), repr(newException))
|
||||
self.assertEqual(newException, testSensitive.latestVersionOf(newException))
|
||||
|
||||
# Test types.MethodType on method in class
|
||||
self.assertEqual(
|
||||
TestSensitive.test_method,
|
||||
testSensitive.latestVersionOf(TestSensitive.test_method),
|
||||
)
|
||||
# Test types.MethodType on method in instance of class
|
||||
self.assertEqual(
|
||||
testSensitive.test_method,
|
||||
testSensitive.latestVersionOf(testSensitive.test_method),
|
||||
)
|
||||
# Test a class
|
||||
self.assertEqual(TestSensitive, testSensitive.latestVersionOf(TestSensitive))
|
||||
|
||||
def myFunction() -> None:
|
||||
"""
|
||||
Dummy method
|
||||
"""
|
||||
|
||||
# Test types.FunctionType
|
||||
self.assertEqual(myFunction, testSensitive.latestVersionOf(myFunction))
|
||||
|
||||
|
||||
class NewStyleTests(TestCase):
|
||||
"""
|
||||
Tests for rebuilding new-style classes of various sorts.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.m = types.ModuleType("whipping")
|
||||
sys.modules["whipping"] = self.m
|
||||
|
||||
def tearDown(self) -> None:
|
||||
del sys.modules["whipping"]
|
||||
del self.m
|
||||
|
||||
def test_slots(self) -> None:
|
||||
"""
|
||||
Try to rebuild a new style class with slots defined.
|
||||
"""
|
||||
classDefinition = "class SlottedClass:\n" " __slots__ = ['a']\n"
|
||||
|
||||
exec(classDefinition, self.m.__dict__)
|
||||
inst = self.m.SlottedClass()
|
||||
inst.a = 7
|
||||
exec(classDefinition, self.m.__dict__)
|
||||
rebuild.updateInstance(inst)
|
||||
self.assertEqual(inst.a, 7)
|
||||
self.assertIs(type(inst), self.m.SlottedClass)
|
||||
|
||||
def test_typeSubclass(self) -> None:
|
||||
"""
|
||||
Try to rebuild a base type subclass.
|
||||
"""
|
||||
classDefinition = "class ListSubclass(list):\n" " pass\n"
|
||||
|
||||
exec(classDefinition, self.m.__dict__)
|
||||
inst = self.m.ListSubclass()
|
||||
inst.append(2)
|
||||
exec(classDefinition, self.m.__dict__)
|
||||
rebuild.updateInstance(inst)
|
||||
self.assertEqual(inst[0], 2)
|
||||
self.assertIs(type(inst), self.m.ListSubclass)
|
||||
825
.venv/lib/python3.12/site-packages/twisted/test/test_reflect.py
Normal file
825
.venv/lib/python3.12/site-packages/twisted/test/test_reflect.py
Normal file
@@ -0,0 +1,825 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for the L{twisted.python.reflect} module.
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import weakref
|
||||
from collections import deque
|
||||
|
||||
from twisted.python import reflect
|
||||
from twisted.python.reflect import (
|
||||
accumulateMethods,
|
||||
addMethodNamesToDict,
|
||||
fullyQualifiedName,
|
||||
prefixedMethodNames,
|
||||
prefixedMethods,
|
||||
)
|
||||
from twisted.trial.unittest import SynchronousTestCase as TestCase
|
||||
|
||||
|
||||
class Base:
|
||||
"""
|
||||
A no-op class which can be used to verify the behavior of
|
||||
method-discovering APIs.
|
||||
"""
|
||||
|
||||
def method(self):
|
||||
"""
|
||||
A no-op method which can be discovered.
|
||||
"""
|
||||
|
||||
|
||||
class Sub(Base):
|
||||
"""
|
||||
A subclass of a class with a method which can be discovered.
|
||||
"""
|
||||
|
||||
|
||||
class Separate:
|
||||
"""
|
||||
A no-op class with methods with differing prefixes.
|
||||
"""
|
||||
|
||||
def good_method(self):
|
||||
"""
|
||||
A no-op method which a matching prefix to be discovered.
|
||||
"""
|
||||
|
||||
def bad_method(self):
|
||||
"""
|
||||
A no-op method with a mismatched prefix to not be discovered.
|
||||
"""
|
||||
|
||||
|
||||
class AccumulateMethodsTests(TestCase):
|
||||
"""
|
||||
Tests for L{accumulateMethods} which finds methods on a class hierarchy and
|
||||
adds them to a dictionary.
|
||||
"""
|
||||
|
||||
def test_ownClass(self):
|
||||
"""
|
||||
If x is and instance of Base and Base defines a method named method,
|
||||
L{accumulateMethods} adds an item to the given dictionary with
|
||||
C{"method"} as the key and a bound method object for Base.method value.
|
||||
"""
|
||||
x = Base()
|
||||
output = {}
|
||||
accumulateMethods(x, output)
|
||||
self.assertEqual({"method": x.method}, output)
|
||||
|
||||
def test_baseClass(self):
|
||||
"""
|
||||
If x is an instance of Sub and Sub is a subclass of Base and Base
|
||||
defines a method named method, L{accumulateMethods} adds an item to the
|
||||
given dictionary with C{"method"} as the key and a bound method object
|
||||
for Base.method as the value.
|
||||
"""
|
||||
x = Sub()
|
||||
output = {}
|
||||
accumulateMethods(x, output)
|
||||
self.assertEqual({"method": x.method}, output)
|
||||
|
||||
def test_prefix(self):
|
||||
"""
|
||||
If a prefix is given, L{accumulateMethods} limits its results to
|
||||
methods beginning with that prefix. Keys in the resulting dictionary
|
||||
also have the prefix removed from them.
|
||||
"""
|
||||
x = Separate()
|
||||
output = {}
|
||||
accumulateMethods(x, output, "good_")
|
||||
self.assertEqual({"method": x.good_method}, output)
|
||||
|
||||
|
||||
class PrefixedMethodsTests(TestCase):
|
||||
"""
|
||||
Tests for L{prefixedMethods} which finds methods on a class hierarchy and
|
||||
adds them to a dictionary.
|
||||
"""
|
||||
|
||||
def test_onlyObject(self):
|
||||
"""
|
||||
L{prefixedMethods} returns a list of the methods discovered on an
|
||||
object.
|
||||
"""
|
||||
x = Base()
|
||||
output = prefixedMethods(x)
|
||||
self.assertEqual([x.method], output)
|
||||
|
||||
def test_prefix(self):
|
||||
"""
|
||||
If a prefix is given, L{prefixedMethods} returns only methods named
|
||||
with that prefix.
|
||||
"""
|
||||
x = Separate()
|
||||
output = prefixedMethods(x, "good_")
|
||||
self.assertEqual([x.good_method], output)
|
||||
|
||||
|
||||
class PrefixedMethodNamesTests(TestCase):
|
||||
"""
|
||||
Tests for L{prefixedMethodNames}.
|
||||
"""
|
||||
|
||||
def test_method(self):
|
||||
"""
|
||||
L{prefixedMethodNames} returns a list including methods with the given
|
||||
prefix defined on the class passed to it.
|
||||
"""
|
||||
self.assertEqual(["method"], prefixedMethodNames(Separate, "good_"))
|
||||
|
||||
def test_inheritedMethod(self):
|
||||
"""
|
||||
L{prefixedMethodNames} returns a list included methods with the given
|
||||
prefix defined on base classes of the class passed to it.
|
||||
"""
|
||||
|
||||
class Child(Separate):
|
||||
pass
|
||||
|
||||
self.assertEqual(["method"], prefixedMethodNames(Child, "good_"))
|
||||
|
||||
|
||||
class AddMethodNamesToDictTests(TestCase):
|
||||
"""
|
||||
Tests for L{addMethodNamesToDict}.
|
||||
"""
|
||||
|
||||
def test_baseClass(self):
|
||||
"""
|
||||
If C{baseClass} is passed to L{addMethodNamesToDict}, only methods which
|
||||
are a subclass of C{baseClass} are added to the result dictionary.
|
||||
"""
|
||||
|
||||
class Alternate:
|
||||
pass
|
||||
|
||||
class Child(Separate, Alternate):
|
||||
def good_alternate(self):
|
||||
pass
|
||||
|
||||
result = {}
|
||||
addMethodNamesToDict(Child, result, "good_", Alternate)
|
||||
self.assertEqual({"alternate": 1}, result)
|
||||
|
||||
|
||||
class Summer:
|
||||
"""
|
||||
A class we look up as part of the LookupsTests.
|
||||
"""
|
||||
|
||||
def reallySet(self):
|
||||
"""
|
||||
Do something.
|
||||
"""
|
||||
|
||||
|
||||
class LookupsTests(TestCase):
|
||||
"""
|
||||
Tests for L{namedClass}, L{namedModule}, and L{namedAny}.
|
||||
"""
|
||||
|
||||
def test_namedClassLookup(self):
|
||||
"""
|
||||
L{namedClass} should return the class object for the name it is passed.
|
||||
"""
|
||||
self.assertIs(reflect.namedClass("twisted.test.test_reflect.Summer"), Summer)
|
||||
|
||||
def test_namedModuleLookup(self):
|
||||
"""
|
||||
L{namedModule} should return the module object for the name it is
|
||||
passed.
|
||||
"""
|
||||
from twisted.python import monkey
|
||||
|
||||
self.assertIs(reflect.namedModule("twisted.python.monkey"), monkey)
|
||||
|
||||
def test_namedAnyPackageLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the package object for the name it is passed.
|
||||
"""
|
||||
import twisted.python
|
||||
|
||||
self.assertIs(reflect.namedAny("twisted.python"), twisted.python)
|
||||
|
||||
def test_namedAnyModuleLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the module object for the name it is passed.
|
||||
"""
|
||||
from twisted.python import monkey
|
||||
|
||||
self.assertIs(reflect.namedAny("twisted.python.monkey"), monkey)
|
||||
|
||||
def test_namedAnyClassLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the class object for the name it is passed.
|
||||
"""
|
||||
self.assertIs(reflect.namedAny("twisted.test.test_reflect.Summer"), Summer)
|
||||
|
||||
def test_namedAnyAttributeLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the object an attribute of a non-module,
|
||||
non-package object is bound to for the name it is passed.
|
||||
"""
|
||||
# Note - not assertIs because unbound method lookup creates a new
|
||||
# object every time. This is a foolishness of Python's object
|
||||
# implementation, not a bug in Twisted.
|
||||
self.assertEqual(
|
||||
reflect.namedAny("twisted.test.test_reflect.Summer.reallySet"),
|
||||
Summer.reallySet,
|
||||
)
|
||||
|
||||
def test_namedAnySecondAttributeLookup(self):
|
||||
"""
|
||||
L{namedAny} should return the object an attribute of an object which
|
||||
itself was an attribute of a non-module, non-package object is bound to
|
||||
for the name it is passed.
|
||||
"""
|
||||
self.assertIs(
|
||||
reflect.namedAny("twisted.test.test_reflect." "Summer.reallySet.__doc__"),
|
||||
Summer.reallySet.__doc__,
|
||||
)
|
||||
|
||||
def test_importExceptions(self):
|
||||
"""
|
||||
Exceptions raised by modules which L{namedAny} causes to be imported
|
||||
should pass through L{namedAny} to the caller.
|
||||
"""
|
||||
self.assertRaises(
|
||||
ZeroDivisionError, reflect.namedAny, "twisted.test.reflect_helper_ZDE"
|
||||
)
|
||||
# Make sure that there is post-failed-import cleanup
|
||||
self.assertRaises(
|
||||
ZeroDivisionError, reflect.namedAny, "twisted.test.reflect_helper_ZDE"
|
||||
)
|
||||
self.assertRaises(
|
||||
ValueError, reflect.namedAny, "twisted.test.reflect_helper_VE"
|
||||
)
|
||||
# Modules which themselves raise ImportError when imported should
|
||||
# result in an ImportError
|
||||
self.assertRaises(
|
||||
ImportError, reflect.namedAny, "twisted.test.reflect_helper_IE"
|
||||
)
|
||||
|
||||
def test_attributeExceptions(self):
|
||||
"""
|
||||
If segments on the end of a fully-qualified Python name represents
|
||||
attributes which aren't actually present on the object represented by
|
||||
the earlier segments, L{namedAny} should raise an L{AttributeError}.
|
||||
"""
|
||||
self.assertRaises(
|
||||
AttributeError, reflect.namedAny, "twisted.nosuchmoduleintheworld"
|
||||
)
|
||||
# ImportError behaves somewhat differently between "import
|
||||
# extant.nonextant" and "import extant.nonextant.nonextant", so test
|
||||
# the latter as well.
|
||||
self.assertRaises(
|
||||
AttributeError, reflect.namedAny, "twisted.nosuch.modulein.theworld"
|
||||
)
|
||||
self.assertRaises(
|
||||
AttributeError,
|
||||
reflect.namedAny,
|
||||
"twisted.test.test_reflect.Summer.nosuchattribute",
|
||||
)
|
||||
|
||||
def test_invalidNames(self):
|
||||
"""
|
||||
Passing a name which isn't a fully-qualified Python name to L{namedAny}
|
||||
should result in one of the following exceptions:
|
||||
- L{InvalidName}: the name is not a dot-separated list of Python
|
||||
objects
|
||||
- L{ObjectNotFound}: the object doesn't exist
|
||||
- L{ModuleNotFound}: the object doesn't exist and there is only one
|
||||
component in the name
|
||||
"""
|
||||
err = self.assertRaises(
|
||||
reflect.ModuleNotFound, reflect.namedAny, "nosuchmoduleintheworld"
|
||||
)
|
||||
self.assertEqual(str(err), "No module named 'nosuchmoduleintheworld'")
|
||||
|
||||
# This is a dot-separated list, but it isn't valid!
|
||||
err = self.assertRaises(
|
||||
reflect.ObjectNotFound, reflect.namedAny, "@#$@(#.!@(#!@#"
|
||||
)
|
||||
self.assertEqual(str(err), "'@#$@(#.!@(#!@#' does not name an object")
|
||||
|
||||
err = self.assertRaises(
|
||||
reflect.ObjectNotFound, reflect.namedAny, "tcelfer.nohtyp.detsiwt"
|
||||
)
|
||||
self.assertEqual(str(err), "'tcelfer.nohtyp.detsiwt' does not name an object")
|
||||
|
||||
err = self.assertRaises(reflect.InvalidName, reflect.namedAny, "")
|
||||
self.assertEqual(str(err), "Empty module name")
|
||||
|
||||
for invalidName in [".twisted", "twisted.", "twisted..python"]:
|
||||
err = self.assertRaises(reflect.InvalidName, reflect.namedAny, invalidName)
|
||||
self.assertEqual(
|
||||
str(err),
|
||||
"name must be a string giving a '.'-separated list of Python "
|
||||
"identifiers, not %r" % (invalidName,),
|
||||
)
|
||||
|
||||
def test_requireModuleImportError(self):
|
||||
"""
|
||||
When module import fails with ImportError it returns the specified
|
||||
default value.
|
||||
"""
|
||||
for name in ["nosuchmtopodule", "no.such.module"]:
|
||||
default = object()
|
||||
|
||||
result = reflect.requireModule(name, default=default)
|
||||
|
||||
self.assertIs(result, default)
|
||||
|
||||
def test_requireModuleDefaultNone(self):
|
||||
"""
|
||||
When module import fails it returns L{None} by default.
|
||||
"""
|
||||
result = reflect.requireModule("no.such.module")
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_requireModuleRequestedImport(self):
|
||||
"""
|
||||
When module import succeed it returns the module and not the default
|
||||
value.
|
||||
"""
|
||||
from twisted.python import monkey
|
||||
|
||||
default = object()
|
||||
|
||||
self.assertIs(
|
||||
reflect.requireModule("twisted.python.monkey", default=default),
|
||||
monkey,
|
||||
)
|
||||
|
||||
|
||||
class Breakable:
|
||||
breakRepr = False
|
||||
breakStr = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.breakStr:
|
||||
raise RuntimeError("str!")
|
||||
else:
|
||||
return "<Breakable>"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.breakRepr:
|
||||
raise RuntimeError("repr!")
|
||||
else:
|
||||
return "Breakable()"
|
||||
|
||||
|
||||
class BrokenType(Breakable, type):
|
||||
breakName = False
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
if self.breakName:
|
||||
raise RuntimeError("no name")
|
||||
return "BrokenType"
|
||||
|
||||
|
||||
BTBase = BrokenType("BTBase", (Breakable,), {"breakRepr": True, "breakStr": True})
|
||||
|
||||
|
||||
class NoClassAttr(Breakable):
|
||||
__class__ = property(lambda x: x.not_class) # type: ignore[assignment]
|
||||
|
||||
|
||||
class SafeReprTests(TestCase):
|
||||
"""
|
||||
Tests for L{reflect.safe_repr} function.
|
||||
"""
|
||||
|
||||
def test_workingRepr(self):
|
||||
"""
|
||||
L{reflect.safe_repr} produces the same output as C{repr} on a working
|
||||
object.
|
||||
"""
|
||||
xs = ([1, 2, 3], b"a")
|
||||
self.assertEqual(list(map(reflect.safe_repr, xs)), list(map(repr, xs)))
|
||||
|
||||
def test_brokenRepr(self):
|
||||
"""
|
||||
L{reflect.safe_repr} returns a string with class name, address, and
|
||||
traceback when the repr call failed.
|
||||
"""
|
||||
b = Breakable()
|
||||
b.breakRepr = True
|
||||
bRepr = reflect.safe_repr(b)
|
||||
self.assertIn("Breakable instance at 0x", bRepr)
|
||||
# Check that the file is in the repr, but without the extension as it
|
||||
# can be .py/.pyc
|
||||
self.assertIn(os.path.splitext(__file__)[0], bRepr)
|
||||
self.assertIn("RuntimeError: repr!", bRepr)
|
||||
|
||||
def test_brokenStr(self):
|
||||
"""
|
||||
L{reflect.safe_repr} isn't affected by a broken C{__str__} method.
|
||||
"""
|
||||
b = Breakable()
|
||||
b.breakStr = True
|
||||
self.assertEqual(reflect.safe_repr(b), repr(b))
|
||||
|
||||
def test_brokenClassRepr(self):
|
||||
class X(BTBase):
|
||||
breakRepr = True
|
||||
|
||||
reflect.safe_repr(X)
|
||||
reflect.safe_repr(X())
|
||||
|
||||
def test_brokenReprIncludesID(self):
|
||||
"""
|
||||
C{id} is used to print the ID of the object in case of an error.
|
||||
|
||||
L{safe_repr} includes a traceback after a newline, so we only check
|
||||
against the first line of the repr.
|
||||
"""
|
||||
|
||||
class X(BTBase):
|
||||
breakRepr = True
|
||||
|
||||
xRepr = reflect.safe_repr(X)
|
||||
xReprExpected = f"<BrokenType instance at 0x{id(X):x} with repr error:"
|
||||
self.assertEqual(xReprExpected, xRepr.split("\n")[0])
|
||||
|
||||
def test_brokenClassStr(self):
|
||||
class X(BTBase):
|
||||
breakStr = True
|
||||
|
||||
reflect.safe_repr(X)
|
||||
reflect.safe_repr(X())
|
||||
|
||||
def test_brokenClassAttribute(self):
|
||||
"""
|
||||
If an object raises an exception when accessing its C{__class__}
|
||||
attribute, L{reflect.safe_repr} uses C{type} to retrieve the class
|
||||
object.
|
||||
"""
|
||||
b = NoClassAttr()
|
||||
b.breakRepr = True
|
||||
bRepr = reflect.safe_repr(b)
|
||||
self.assertIn("NoClassAttr instance at 0x", bRepr)
|
||||
self.assertIn(os.path.splitext(__file__)[0], bRepr)
|
||||
self.assertIn("RuntimeError: repr!", bRepr)
|
||||
|
||||
def test_brokenClassNameAttribute(self):
|
||||
"""
|
||||
If a class raises an exception when accessing its C{__name__} attribute
|
||||
B{and} when calling its C{__str__} implementation, L{reflect.safe_repr}
|
||||
returns 'BROKEN CLASS' instead of the class name.
|
||||
"""
|
||||
|
||||
class X(BTBase):
|
||||
breakName = True
|
||||
|
||||
xRepr = reflect.safe_repr(X())
|
||||
self.assertIn("<BROKEN CLASS AT 0x", xRepr)
|
||||
self.assertIn(os.path.splitext(__file__)[0], xRepr)
|
||||
self.assertIn("RuntimeError: repr!", xRepr)
|
||||
|
||||
|
||||
class SafeStrTests(TestCase):
|
||||
"""
|
||||
Tests for L{reflect.safe_str} function.
|
||||
"""
|
||||
|
||||
def test_workingStr(self):
|
||||
x = [1, 2, 3]
|
||||
self.assertEqual(reflect.safe_str(x), str(x))
|
||||
|
||||
def test_brokenStr(self):
|
||||
b = Breakable()
|
||||
b.breakStr = True
|
||||
reflect.safe_str(b)
|
||||
|
||||
def test_workingAscii(self):
|
||||
"""
|
||||
L{safe_str} for C{str} with ascii-only data should return the
|
||||
value unchanged.
|
||||
"""
|
||||
x = "a"
|
||||
self.assertEqual(reflect.safe_str(x), "a")
|
||||
|
||||
def test_workingUtf8_3(self):
|
||||
"""
|
||||
L{safe_str} for C{bytes} with utf-8 encoded data should return
|
||||
the value decoded into C{str}.
|
||||
"""
|
||||
x = b"t\xc3\xbcst"
|
||||
self.assertEqual(reflect.safe_str(x), x.decode("utf-8"))
|
||||
|
||||
def test_brokenUtf8(self):
|
||||
"""
|
||||
Use str() for non-utf8 bytes: "b'non-utf8'"
|
||||
"""
|
||||
x = b"\xff"
|
||||
xStr = reflect.safe_str(x)
|
||||
self.assertEqual(xStr, str(x))
|
||||
|
||||
def test_brokenRepr(self):
|
||||
b = Breakable()
|
||||
b.breakRepr = True
|
||||
reflect.safe_str(b)
|
||||
|
||||
def test_brokenClassStr(self):
|
||||
class X(BTBase):
|
||||
breakStr = True
|
||||
|
||||
reflect.safe_str(X)
|
||||
reflect.safe_str(X())
|
||||
|
||||
def test_brokenClassRepr(self):
|
||||
class X(BTBase):
|
||||
breakRepr = True
|
||||
|
||||
reflect.safe_str(X)
|
||||
reflect.safe_str(X())
|
||||
|
||||
def test_brokenClassAttribute(self):
|
||||
"""
|
||||
If an object raises an exception when accessing its C{__class__}
|
||||
attribute, L{reflect.safe_str} uses C{type} to retrieve the class
|
||||
object.
|
||||
"""
|
||||
b = NoClassAttr()
|
||||
b.breakStr = True
|
||||
bStr = reflect.safe_str(b)
|
||||
self.assertIn("NoClassAttr instance at 0x", bStr)
|
||||
self.assertIn(os.path.splitext(__file__)[0], bStr)
|
||||
self.assertIn("RuntimeError: str!", bStr)
|
||||
|
||||
def test_brokenClassNameAttribute(self):
|
||||
"""
|
||||
If a class raises an exception when accessing its C{__name__} attribute
|
||||
B{and} when calling its C{__str__} implementation, L{reflect.safe_str}
|
||||
returns 'BROKEN CLASS' instead of the class name.
|
||||
"""
|
||||
|
||||
class X(BTBase):
|
||||
breakName = True
|
||||
|
||||
xStr = reflect.safe_str(X())
|
||||
self.assertIn("<BROKEN CLASS AT 0x", xStr)
|
||||
self.assertIn(os.path.splitext(__file__)[0], xStr)
|
||||
self.assertIn("RuntimeError: str!", xStr)
|
||||
|
||||
|
||||
class FilenameToModuleTests(TestCase):
|
||||
"""
|
||||
Test L{filenameToModuleName} detection.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.path = os.path.join(self.mktemp(), "fakepackage", "test")
|
||||
os.makedirs(self.path)
|
||||
with open(os.path.join(self.path, "__init__.py"), "w") as f:
|
||||
f.write("")
|
||||
with open(os.path.join(os.path.dirname(self.path), "__init__.py"), "w") as f:
|
||||
f.write("")
|
||||
|
||||
def test_directory(self):
|
||||
"""
|
||||
L{filenameToModuleName} returns the correct module (a package) given a
|
||||
directory.
|
||||
"""
|
||||
module = reflect.filenameToModuleName(self.path)
|
||||
self.assertEqual(module, "fakepackage.test")
|
||||
module = reflect.filenameToModuleName(self.path + os.path.sep)
|
||||
self.assertEqual(module, "fakepackage.test")
|
||||
|
||||
def test_file(self):
|
||||
"""
|
||||
L{filenameToModuleName} returns the correct module given the path to
|
||||
its file.
|
||||
"""
|
||||
module = reflect.filenameToModuleName(
|
||||
os.path.join(self.path, "test_reflect.py")
|
||||
)
|
||||
self.assertEqual(module, "fakepackage.test.test_reflect")
|
||||
|
||||
def test_bytes(self):
|
||||
"""
|
||||
L{filenameToModuleName} returns the correct module given a C{bytes}
|
||||
path to its file.
|
||||
"""
|
||||
module = reflect.filenameToModuleName(
|
||||
os.path.join(self.path.encode("utf-8"), b"test_reflect.py")
|
||||
)
|
||||
# Module names are always native string:
|
||||
self.assertEqual(module, "fakepackage.test.test_reflect")
|
||||
|
||||
|
||||
class FullyQualifiedNameTests(TestCase):
|
||||
"""
|
||||
Test for L{fullyQualifiedName}.
|
||||
"""
|
||||
|
||||
def _checkFullyQualifiedName(self, obj, expected):
|
||||
"""
|
||||
Helper to check that fully qualified name of C{obj} results to
|
||||
C{expected}.
|
||||
"""
|
||||
self.assertEqual(fullyQualifiedName(obj), expected)
|
||||
|
||||
def test_package(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the full name of a package and a
|
||||
subpackage.
|
||||
"""
|
||||
import twisted
|
||||
|
||||
self._checkFullyQualifiedName(twisted, "twisted")
|
||||
import twisted.python
|
||||
|
||||
self._checkFullyQualifiedName(twisted.python, "twisted.python")
|
||||
|
||||
def test_module(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of a module inside a package.
|
||||
"""
|
||||
import twisted.python.compat
|
||||
|
||||
self._checkFullyQualifiedName(twisted.python.compat, "twisted.python.compat")
|
||||
|
||||
def test_class(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of a class and its module.
|
||||
"""
|
||||
self._checkFullyQualifiedName(
|
||||
FullyQualifiedNameTests, f"{__name__}.FullyQualifiedNameTests"
|
||||
)
|
||||
|
||||
def test_function(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of a function inside its module.
|
||||
"""
|
||||
self._checkFullyQualifiedName(
|
||||
fullyQualifiedName, "twisted.python.reflect.fullyQualifiedName"
|
||||
)
|
||||
|
||||
def test_boundMethod(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of a bound method inside its
|
||||
class and its module.
|
||||
"""
|
||||
self._checkFullyQualifiedName(
|
||||
self.test_boundMethod,
|
||||
f"{__name__}.{self.__class__.__name__}.test_boundMethod",
|
||||
)
|
||||
|
||||
def test_unboundMethod(self):
|
||||
"""
|
||||
L{fullyQualifiedName} returns the name of an unbound method inside its
|
||||
class and its module.
|
||||
"""
|
||||
self._checkFullyQualifiedName(
|
||||
self.__class__.test_unboundMethod,
|
||||
f"{__name__}.{self.__class__.__name__}.test_unboundMethod",
|
||||
)
|
||||
|
||||
|
||||
class ObjectGrepTests(TestCase):
|
||||
def test_dictionary(self):
|
||||
"""
|
||||
Test references search through a dictionary, as a key or as a value.
|
||||
"""
|
||||
o = object()
|
||||
d1 = {None: o}
|
||||
d2 = {o: None}
|
||||
|
||||
self.assertIn("[None]", reflect.objgrep(d1, o, reflect.isSame))
|
||||
self.assertIn("{None}", reflect.objgrep(d2, o, reflect.isSame))
|
||||
|
||||
def test_list(self):
|
||||
"""
|
||||
Test references search through a list.
|
||||
"""
|
||||
o = object()
|
||||
L = [None, o]
|
||||
|
||||
self.assertIn("[1]", reflect.objgrep(L, o, reflect.isSame))
|
||||
|
||||
def test_tuple(self):
|
||||
"""
|
||||
Test references search through a tuple.
|
||||
"""
|
||||
o = object()
|
||||
T = (o, None)
|
||||
|
||||
self.assertIn("[0]", reflect.objgrep(T, o, reflect.isSame))
|
||||
|
||||
def test_instance(self):
|
||||
"""
|
||||
Test references search through an object attribute.
|
||||
"""
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
o = object()
|
||||
d = Dummy()
|
||||
d.o = o
|
||||
|
||||
self.assertIn(".o", reflect.objgrep(d, o, reflect.isSame))
|
||||
|
||||
def test_weakref(self):
|
||||
"""
|
||||
Test references search through a weakref object.
|
||||
"""
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
o = Dummy()
|
||||
w1 = weakref.ref(o)
|
||||
|
||||
self.assertIn("()", reflect.objgrep(w1, o, reflect.isSame))
|
||||
|
||||
def test_boundMethod(self):
|
||||
"""
|
||||
Test references search through method special attributes.
|
||||
"""
|
||||
|
||||
class Dummy:
|
||||
def dummy(self):
|
||||
pass
|
||||
|
||||
o = Dummy()
|
||||
m = o.dummy
|
||||
|
||||
self.assertIn(".__self__", reflect.objgrep(m, m.__self__, reflect.isSame))
|
||||
self.assertIn(
|
||||
".__self__.__class__",
|
||||
reflect.objgrep(m, m.__self__.__class__, reflect.isSame),
|
||||
)
|
||||
self.assertIn(".__func__", reflect.objgrep(m, m.__func__, reflect.isSame))
|
||||
|
||||
def test_everything(self):
|
||||
"""
|
||||
Test references search using complex set of objects.
|
||||
"""
|
||||
|
||||
class Dummy:
|
||||
def method(self):
|
||||
pass
|
||||
|
||||
o = Dummy()
|
||||
D1 = {(): "baz", None: "Quux", o: "Foosh"}
|
||||
L = [None, (), D1, 3]
|
||||
T = (L, {}, Dummy())
|
||||
D2 = {0: "foo", 1: "bar", 2: T}
|
||||
i = Dummy()
|
||||
i.attr = D2
|
||||
m = i.method
|
||||
w = weakref.ref(m)
|
||||
|
||||
self.assertIn(
|
||||
"().__self__.attr[2][0][2]{'Foosh'}", reflect.objgrep(w, o, reflect.isSame)
|
||||
)
|
||||
|
||||
def test_depthLimit(self):
|
||||
"""
|
||||
Test the depth of references search.
|
||||
"""
|
||||
a = []
|
||||
b = [a]
|
||||
c = [a, b]
|
||||
d = [a, c]
|
||||
|
||||
self.assertEqual(["[0]"], reflect.objgrep(d, a, reflect.isSame, maxDepth=1))
|
||||
self.assertEqual(
|
||||
["[0]", "[1][0]"], reflect.objgrep(d, a, reflect.isSame, maxDepth=2)
|
||||
)
|
||||
self.assertEqual(
|
||||
["[0]", "[1][0]", "[1][1][0]"],
|
||||
reflect.objgrep(d, a, reflect.isSame, maxDepth=3),
|
||||
)
|
||||
|
||||
def test_deque(self):
|
||||
"""
|
||||
Test references search through a deque object.
|
||||
"""
|
||||
o = object()
|
||||
D = deque()
|
||||
D.append(None)
|
||||
D.append(o)
|
||||
|
||||
self.assertIn("[1]", reflect.objgrep(D, o, reflect.isSame))
|
||||
|
||||
|
||||
class GetClassTests(TestCase):
|
||||
def test_new(self):
|
||||
class NewClass:
|
||||
pass
|
||||
|
||||
new = NewClass()
|
||||
self.assertEqual(reflect.getClass(NewClass).__name__, "type")
|
||||
self.assertEqual(reflect.getClass(new).__name__, "NewClass")
|
||||
@@ -0,0 +1,58 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from twisted.python import roots
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class RootsTests(unittest.TestCase):
|
||||
def testExceptions(self) -> None:
|
||||
request = roots.Request()
|
||||
try:
|
||||
request.write(b"blah")
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self.fail()
|
||||
try:
|
||||
request.finish()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self.fail()
|
||||
|
||||
def testCollection(self) -> None:
|
||||
collection = roots.Collection()
|
||||
collection.putEntity("x", "test")
|
||||
self.assertEqual(collection.getStaticEntity("x"), "test")
|
||||
collection.delEntity("x")
|
||||
self.assertEqual(collection.getStaticEntity("x"), None)
|
||||
try:
|
||||
collection.storeEntity("x", None)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self.fail()
|
||||
try:
|
||||
collection.removeEntity("x", None)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self.fail()
|
||||
|
||||
def testConstrained(self) -> None:
|
||||
class const(roots.Constrained):
|
||||
def nameConstraint(self, name: str) -> bool:
|
||||
return name == "x"
|
||||
|
||||
c = const()
|
||||
self.assertIsNone(c.putEntity("x", "test"))
|
||||
self.assertRaises(roots.ConstraintViolation, c.putEntity, "y", "test")
|
||||
|
||||
def testHomogenous(self) -> None:
|
||||
h = roots.Homogenous()
|
||||
h.entityType = int
|
||||
h.putEntity("a", 1)
|
||||
self.assertEqual(h.getStaticEntity("a"), 1)
|
||||
self.assertRaises(roots.ConstraintViolation, h.putEntity, "x", "y")
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Test win32 shortcut script
|
||||
"""
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
skipReason = None
|
||||
try:
|
||||
from win32com.shell import shell
|
||||
|
||||
from twisted.python import shortcut
|
||||
except ImportError:
|
||||
skipReason = "Only runs on Windows with win32com"
|
||||
|
||||
if sys.version_info[0:2] >= (3, 7):
|
||||
skipReason = "Broken on Python 3.7+."
|
||||
|
||||
|
||||
class ShortcutTests(unittest.TestCase):
|
||||
skip = skipReason
|
||||
|
||||
def test_create(self) -> None:
|
||||
"""
|
||||
Create a simple shortcut.
|
||||
"""
|
||||
testFilename = __file__
|
||||
baseFileName = os.path.basename(testFilename)
|
||||
s1 = shortcut.Shortcut(testFilename)
|
||||
tempname = self.mktemp() + ".lnk"
|
||||
s1.save(tempname)
|
||||
self.assertTrue(os.path.exists(tempname))
|
||||
sc = shortcut.open(tempname)
|
||||
scPath = sc.GetPath(shell.SLGP_RAWPATH)[0]
|
||||
self.assertEqual(scPath[-len(baseFileName) :].lower(), baseFileName.lower())
|
||||
|
||||
def test_createPythonShortcut(self) -> None:
|
||||
"""
|
||||
Create a shortcut to the Python executable,
|
||||
and set some values.
|
||||
"""
|
||||
testFilename = sys.executable
|
||||
baseFileName = os.path.basename(testFilename)
|
||||
tempDir = tempfile.gettempdir()
|
||||
s1 = shortcut.Shortcut(
|
||||
path=testFilename,
|
||||
arguments="-V",
|
||||
description="The Python executable",
|
||||
workingdir=tempDir,
|
||||
iconpath=tempDir,
|
||||
iconidx=1,
|
||||
)
|
||||
tempname = self.mktemp() + ".lnk"
|
||||
s1.save(tempname)
|
||||
self.assertTrue(os.path.exists(tempname))
|
||||
sc = shortcut.open(tempname)
|
||||
scPath = sc.GetPath(shell.SLGP_RAWPATH)[0]
|
||||
self.assertEqual(scPath[-len(baseFileName) :].lower(), baseFileName.lower())
|
||||
self.assertEqual(sc.GetDescription(), "The Python executable")
|
||||
self.assertEqual(sc.GetWorkingDirectory(), tempDir)
|
||||
self.assertEqual(sc.GetIconLocation(), (tempDir, 1))
|
||||
780
.venv/lib/python3.12/site-packages/twisted/test/test_sip.py
Normal file
780
.venv/lib/python3.12/site-packages/twisted/test/test_sip.py
Normal file
@@ -0,0 +1,780 @@
|
||||
# -*- test-case-name: twisted.test.test_sip -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Session Initialization Protocol tests.
|
||||
"""
|
||||
|
||||
from twisted.cred import checkers, portal
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.protocols import sip
|
||||
from twisted.trial import unittest
|
||||
|
||||
try:
|
||||
from twisted.internet.asyncioreactor import AsyncioSelectorReactor
|
||||
except BaseException:
|
||||
AsyncioSelectorReactor = None # type: ignore[assignment,misc]
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
# request, prefixed by random CRLFs
|
||||
request1 = (
|
||||
"\n\r\n\n\r"
|
||||
+ """\
|
||||
INVITE sip:foo SIP/2.0
|
||||
From: mo
|
||||
To: joe
|
||||
Content-Length: 4
|
||||
|
||||
abcd""".replace(
|
||||
"\n", "\r\n"
|
||||
)
|
||||
)
|
||||
|
||||
# request, no content-length
|
||||
request2 = """INVITE sip:foo SIP/2.0
|
||||
From: mo
|
||||
To: joe
|
||||
|
||||
1234""".replace(
|
||||
"\n", "\r\n"
|
||||
)
|
||||
|
||||
# request, with garbage after
|
||||
request3 = """INVITE sip:foo SIP/2.0
|
||||
From: mo
|
||||
To: joe
|
||||
Content-Length: 4
|
||||
|
||||
1234
|
||||
|
||||
lalalal""".replace(
|
||||
"\n", "\r\n"
|
||||
)
|
||||
|
||||
# three requests
|
||||
request4 = """INVITE sip:foo SIP/2.0
|
||||
From: mo
|
||||
To: joe
|
||||
Content-Length: 0
|
||||
|
||||
INVITE sip:loop SIP/2.0
|
||||
From: foo
|
||||
To: bar
|
||||
Content-Length: 4
|
||||
|
||||
abcdINVITE sip:loop SIP/2.0
|
||||
From: foo
|
||||
To: bar
|
||||
Content-Length: 4
|
||||
|
||||
1234""".replace(
|
||||
"\n", "\r\n"
|
||||
)
|
||||
|
||||
# response, no content
|
||||
response1 = """SIP/2.0 200 OK
|
||||
From: foo
|
||||
To:bar
|
||||
Content-Length: 0
|
||||
|
||||
""".replace(
|
||||
"\n", "\r\n"
|
||||
)
|
||||
|
||||
# short header version
|
||||
request_short = """\
|
||||
INVITE sip:foo SIP/2.0
|
||||
f: mo
|
||||
t: joe
|
||||
l: 4
|
||||
|
||||
abcd""".replace(
|
||||
"\n", "\r\n"
|
||||
)
|
||||
|
||||
request_natted = """\
|
||||
INVITE sip:foo SIP/2.0
|
||||
Via: SIP/2.0/UDP 10.0.0.1:5060;rport
|
||||
|
||||
""".replace(
|
||||
"\n", "\r\n"
|
||||
)
|
||||
|
||||
# multiline headers (example from RFC 3621).
|
||||
response_multiline = """\
|
||||
SIP/2.0 200 OK
|
||||
Via: SIP/2.0/UDP server10.biloxi.com
|
||||
;branch=z9hG4bKnashds8;received=192.0.2.3
|
||||
Via: SIP/2.0/UDP bigbox3.site3.atlanta.com
|
||||
;branch=z9hG4bK77ef4c2312983.1;received=192.0.2.2
|
||||
Via: SIP/2.0/UDP pc33.atlanta.com
|
||||
;branch=z9hG4bK776asdhds ;received=192.0.2.1
|
||||
To: Bob <sip:bob@biloxi.com>;tag=a6c85cf
|
||||
From: Alice <sip:alice@atlanta.com>;tag=1928301774
|
||||
Call-ID: a84b4c76e66710@pc33.atlanta.com
|
||||
CSeq: 314159 INVITE
|
||||
Contact: <sip:bob@192.0.2.4>
|
||||
Content-Type: application/sdp
|
||||
Content-Length: 0
|
||||
\n""".replace(
|
||||
"\n", "\r\n"
|
||||
)
|
||||
|
||||
|
||||
class TestRealm:
|
||||
def requestAvatar(self, avatarId, mind, *interfaces):
|
||||
return sip.IContact, None, lambda: None
|
||||
|
||||
|
||||
class MessageParsingTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
self.parser = sip.MessagesParser(self.l.append)
|
||||
|
||||
def feedMessage(self, message):
|
||||
self.parser.dataReceived(message)
|
||||
self.parser.dataDone()
|
||||
|
||||
def validateMessage(self, m, method, uri, headers, body):
|
||||
"""
|
||||
Validate Requests.
|
||||
"""
|
||||
self.assertEqual(m.method, method)
|
||||
self.assertEqual(m.uri.toString(), uri)
|
||||
self.assertEqual(m.headers, headers)
|
||||
self.assertEqual(m.body, body)
|
||||
self.assertEqual(m.finished, 1)
|
||||
|
||||
def testSimple(self):
|
||||
l = self.l
|
||||
self.feedMessage(request1)
|
||||
self.assertEqual(len(l), 1)
|
||||
self.validateMessage(
|
||||
l[0],
|
||||
"INVITE",
|
||||
"sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
|
||||
"abcd",
|
||||
)
|
||||
|
||||
def testTwoMessages(self):
|
||||
l = self.l
|
||||
self.feedMessage(request1)
|
||||
self.feedMessage(request2)
|
||||
self.assertEqual(len(l), 2)
|
||||
self.validateMessage(
|
||||
l[0],
|
||||
"INVITE",
|
||||
"sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
|
||||
"abcd",
|
||||
)
|
||||
self.validateMessage(
|
||||
l[1], "INVITE", "sip:foo", {"from": ["mo"], "to": ["joe"]}, "1234"
|
||||
)
|
||||
|
||||
def testGarbage(self):
|
||||
l = self.l
|
||||
self.feedMessage(request3)
|
||||
self.assertEqual(len(l), 1)
|
||||
self.validateMessage(
|
||||
l[0],
|
||||
"INVITE",
|
||||
"sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
|
||||
"1234",
|
||||
)
|
||||
|
||||
def testThreeInOne(self):
|
||||
l = self.l
|
||||
self.feedMessage(request4)
|
||||
self.assertEqual(len(l), 3)
|
||||
self.validateMessage(
|
||||
l[0],
|
||||
"INVITE",
|
||||
"sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["0"]},
|
||||
"",
|
||||
)
|
||||
self.validateMessage(
|
||||
l[1],
|
||||
"INVITE",
|
||||
"sip:loop",
|
||||
{"from": ["foo"], "to": ["bar"], "content-length": ["4"]},
|
||||
"abcd",
|
||||
)
|
||||
self.validateMessage(
|
||||
l[2],
|
||||
"INVITE",
|
||||
"sip:loop",
|
||||
{"from": ["foo"], "to": ["bar"], "content-length": ["4"]},
|
||||
"1234",
|
||||
)
|
||||
|
||||
def testShort(self):
|
||||
l = self.l
|
||||
self.feedMessage(request_short)
|
||||
self.assertEqual(len(l), 1)
|
||||
self.validateMessage(
|
||||
l[0],
|
||||
"INVITE",
|
||||
"sip:foo",
|
||||
{"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
|
||||
"abcd",
|
||||
)
|
||||
|
||||
def testSimpleResponse(self):
|
||||
l = self.l
|
||||
self.feedMessage(response1)
|
||||
self.assertEqual(len(l), 1)
|
||||
m = l[0]
|
||||
self.assertEqual(m.code, 200)
|
||||
self.assertEqual(m.phrase, "OK")
|
||||
self.assertEqual(
|
||||
m.headers, {"from": ["foo"], "to": ["bar"], "content-length": ["0"]}
|
||||
)
|
||||
self.assertEqual(m.body, "")
|
||||
self.assertEqual(m.finished, 1)
|
||||
|
||||
def test_multiLine(self):
|
||||
"""
|
||||
A header may be split across multiple lines. Subsequent lines begin
|
||||
with C{" "} or C{"\\t"}.
|
||||
"""
|
||||
l = self.l
|
||||
self.feedMessage(response_multiline)
|
||||
self.assertEqual(len(l), 1)
|
||||
m = l[0]
|
||||
self.assertEqual(
|
||||
m.headers["via"][0],
|
||||
"SIP/2.0/UDP server10.biloxi.com;"
|
||||
"branch=z9hG4bKnashds8;received=192.0.2.3",
|
||||
)
|
||||
self.assertEqual(
|
||||
m.headers["via"][1],
|
||||
"SIP/2.0/UDP bigbox3.site3.atlanta.com;"
|
||||
"branch=z9hG4bK77ef4c2312983.1;received=192.0.2.2",
|
||||
)
|
||||
self.assertEqual(
|
||||
m.headers["via"][2],
|
||||
"SIP/2.0/UDP pc33.atlanta.com;"
|
||||
"branch=z9hG4bK776asdhds ;received=192.0.2.1",
|
||||
)
|
||||
|
||||
|
||||
class MessageParsingFeedDataCharByCharTests(MessageParsingTests):
|
||||
"""
|
||||
Same as base class, but feed data char by char.
|
||||
"""
|
||||
|
||||
def feedMessage(self, message):
|
||||
for c in message:
|
||||
self.parser.dataReceived(c)
|
||||
self.parser.dataDone()
|
||||
|
||||
|
||||
class MakeMessageTests(unittest.TestCase):
|
||||
def testRequest(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("foo", "bar")
|
||||
self.assertEqual(r.toString(), "INVITE sip:foo SIP/2.0\r\nFoo: bar\r\n\r\n")
|
||||
|
||||
def testResponse(self):
|
||||
r = sip.Response(200, "OK")
|
||||
r.addHeader("foo", "bar")
|
||||
r.addHeader("Content-Length", "4")
|
||||
r.bodyDataReceived("1234")
|
||||
self.assertEqual(
|
||||
r.toString(), "SIP/2.0 200 OK\r\nFoo: bar\r\nContent-Length: 4\r\n\r\n1234"
|
||||
)
|
||||
|
||||
def testStatusCode(self):
|
||||
r = sip.Response(200)
|
||||
self.assertEqual(r.toString(), "SIP/2.0 200 OK\r\n\r\n")
|
||||
|
||||
|
||||
class ViaTests(unittest.TestCase):
|
||||
def checkRoundtrip(self, v):
|
||||
s = v.toString()
|
||||
self.assertEqual(s, sip.parseViaHeader(s).toString())
|
||||
|
||||
def testExtraWhitespace(self):
|
||||
v1 = sip.parseViaHeader("SIP/2.0/UDP 192.168.1.1:5060")
|
||||
v2 = sip.parseViaHeader("SIP/2.0/UDP 192.168.1.1:5060")
|
||||
self.assertEqual(v1.transport, v2.transport)
|
||||
self.assertEqual(v1.host, v2.host)
|
||||
self.assertEqual(v1.port, v2.port)
|
||||
|
||||
def test_complex(self):
|
||||
"""
|
||||
Test parsing a Via header with one of everything.
|
||||
"""
|
||||
s = (
|
||||
"SIP/2.0/UDP first.example.com:4000;ttl=16;maddr=224.2.0.1"
|
||||
" ;branch=a7c6a8dlze (Example)"
|
||||
)
|
||||
v = sip.parseViaHeader(s)
|
||||
self.assertEqual(v.transport, "UDP")
|
||||
self.assertEqual(v.host, "first.example.com")
|
||||
self.assertEqual(v.port, 4000)
|
||||
self.assertIsNone(v.rport)
|
||||
self.assertIsNone(v.rportValue)
|
||||
self.assertFalse(v.rportRequested)
|
||||
self.assertEqual(v.ttl, 16)
|
||||
self.assertEqual(v.maddr, "224.2.0.1")
|
||||
self.assertEqual(v.branch, "a7c6a8dlze")
|
||||
self.assertEqual(v.hidden, 0)
|
||||
self.assertEqual(
|
||||
v.toString(),
|
||||
"SIP/2.0/UDP first.example.com:4000"
|
||||
";ttl=16;branch=a7c6a8dlze;maddr=224.2.0.1",
|
||||
)
|
||||
self.checkRoundtrip(v)
|
||||
|
||||
def test_simple(self):
|
||||
"""
|
||||
Test parsing a simple Via header.
|
||||
"""
|
||||
s = "SIP/2.0/UDP example.com;hidden"
|
||||
v = sip.parseViaHeader(s)
|
||||
self.assertEqual(v.transport, "UDP")
|
||||
self.assertEqual(v.host, "example.com")
|
||||
self.assertEqual(v.port, 5060)
|
||||
self.assertIsNone(v.rport)
|
||||
self.assertIsNone(v.rportValue)
|
||||
self.assertFalse(v.rportRequested)
|
||||
self.assertIsNone(v.ttl)
|
||||
self.assertIsNone(v.maddr)
|
||||
self.assertIsNone(v.branch)
|
||||
self.assertTrue(v.hidden)
|
||||
self.assertEqual(v.toString(), "SIP/2.0/UDP example.com:5060;hidden")
|
||||
self.checkRoundtrip(v)
|
||||
|
||||
def testSimpler(self):
|
||||
v = sip.Via("example.com")
|
||||
self.checkRoundtrip(v)
|
||||
|
||||
def test_deprecatedRPort(self):
|
||||
"""
|
||||
Setting rport to True is deprecated, but still produces a Via header
|
||||
with the expected properties.
|
||||
"""
|
||||
v = sip.Via("foo.bar", rport=True)
|
||||
|
||||
warnings = self.flushWarnings(offendingFunctions=[self.test_deprecatedRPort])
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(
|
||||
warnings[0]["message"], "rport=True is deprecated since Twisted 9.0."
|
||||
)
|
||||
self.assertEqual(warnings[0]["category"], DeprecationWarning)
|
||||
|
||||
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport")
|
||||
self.assertTrue(v.rport)
|
||||
self.assertTrue(v.rportRequested)
|
||||
self.assertIsNone(v.rportValue)
|
||||
|
||||
def test_rport(self):
|
||||
"""
|
||||
An rport setting of None should insert the parameter with no value.
|
||||
"""
|
||||
v = sip.Via("foo.bar", rport=None)
|
||||
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport")
|
||||
self.assertTrue(v.rportRequested)
|
||||
self.assertIsNone(v.rportValue)
|
||||
|
||||
def test_rportValue(self):
|
||||
"""
|
||||
An rport numeric setting should insert the parameter with the number
|
||||
value given.
|
||||
"""
|
||||
v = sip.Via("foo.bar", rport=1)
|
||||
self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport=1")
|
||||
self.assertFalse(v.rportRequested)
|
||||
self.assertEqual(v.rportValue, 1)
|
||||
self.assertEqual(v.rport, 1)
|
||||
|
||||
def testNAT(self):
|
||||
s = "SIP/2.0/UDP 10.0.0.1:5060;received=22.13.1.5;rport=12345"
|
||||
v = sip.parseViaHeader(s)
|
||||
self.assertEqual(v.transport, "UDP")
|
||||
self.assertEqual(v.host, "10.0.0.1")
|
||||
self.assertEqual(v.port, 5060)
|
||||
self.assertEqual(v.received, "22.13.1.5")
|
||||
self.assertEqual(v.rport, 12345)
|
||||
|
||||
self.assertNotEqual(v.toString().find("rport=12345"), -1)
|
||||
|
||||
def test_unknownParams(self):
|
||||
"""
|
||||
Parsing and serializing Via headers with unknown parameters should work.
|
||||
"""
|
||||
s = "SIP/2.0/UDP example.com:5060;branch=a12345b;bogus;pie=delicious"
|
||||
v = sip.parseViaHeader(s)
|
||||
self.assertEqual(v.toString(), s)
|
||||
|
||||
|
||||
class URLTests(unittest.TestCase):
|
||||
def testRoundtrip(self):
|
||||
for url in [
|
||||
"sip:j.doe@big.com",
|
||||
"sip:j.doe:secret@big.com;transport=tcp",
|
||||
"sip:j.doe@big.com?subject=project",
|
||||
"sip:example.com",
|
||||
]:
|
||||
self.assertEqual(sip.parseURL(url).toString(), url)
|
||||
|
||||
def testComplex(self):
|
||||
s = (
|
||||
"sip:user:pass@hosta:123;transport=udp;user=phone;method=foo;"
|
||||
"ttl=12;maddr=1.2.3.4;blah;goo=bar?a=b&c=d"
|
||||
)
|
||||
url = sip.parseURL(s)
|
||||
for k, v in [
|
||||
("username", "user"),
|
||||
("password", "pass"),
|
||||
("host", "hosta"),
|
||||
("port", 123),
|
||||
("transport", "udp"),
|
||||
("usertype", "phone"),
|
||||
("method", "foo"),
|
||||
("ttl", 12),
|
||||
("maddr", "1.2.3.4"),
|
||||
("other", ["blah", "goo=bar"]),
|
||||
("headers", {"a": "b", "c": "d"}),
|
||||
]:
|
||||
self.assertEqual(getattr(url, k), v)
|
||||
|
||||
|
||||
class ParseTests(unittest.TestCase):
|
||||
def testParseAddress(self):
|
||||
for address, name, urls, params in [
|
||||
(
|
||||
'"A. G. Bell" <sip:foo@example.com>',
|
||||
"A. G. Bell",
|
||||
"sip:foo@example.com",
|
||||
{},
|
||||
),
|
||||
("Anon <sip:foo@example.com>", "Anon", "sip:foo@example.com", {}),
|
||||
("sip:foo@example.com", "", "sip:foo@example.com", {}),
|
||||
("<sip:foo@example.com>", "", "sip:foo@example.com", {}),
|
||||
(
|
||||
"foo <sip:foo@example.com>;tag=bar;foo=baz",
|
||||
"foo",
|
||||
"sip:foo@example.com",
|
||||
{"tag": "bar", "foo": "baz"},
|
||||
),
|
||||
]:
|
||||
gname, gurl, gparams = sip.parseAddress(address)
|
||||
self.assertEqual(name, gname)
|
||||
self.assertEqual(gurl.toString(), urls)
|
||||
self.assertEqual(gparams, params)
|
||||
|
||||
|
||||
@implementer(sip.ILocator)
|
||||
class DummyLocator:
|
||||
def getAddress(self, logicalURL):
|
||||
return defer.succeed(sip.URL("server.com", port=5060))
|
||||
|
||||
|
||||
@implementer(sip.ILocator)
|
||||
class FailingLocator:
|
||||
def getAddress(self, logicalURL):
|
||||
return defer.fail(LookupError())
|
||||
|
||||
|
||||
class ProxyTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.proxy = sip.Proxy("127.0.0.1")
|
||||
self.proxy.locator = DummyLocator()
|
||||
self.sent = []
|
||||
self.proxy.sendMessage = lambda dest, msg: self.sent.append((dest, msg))
|
||||
|
||||
def testRequestForward(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("via", sip.Via("1.2.3.4").toString())
|
||||
r.addHeader("via", sip.Via("1.2.3.5").toString())
|
||||
r.addHeader("foo", "bar")
|
||||
r.addHeader("to", "<sip:joe@server.com>")
|
||||
r.addHeader("contact", "<sip:joe@1.2.3.5>")
|
||||
self.proxy.datagramReceived(r.toString(), ("1.2.3.4", 5060))
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual(dest.port, 5060)
|
||||
self.assertEqual(dest.host, "server.com")
|
||||
self.assertEqual(m.uri.toString(), "sip:foo")
|
||||
self.assertEqual(m.method, "INVITE")
|
||||
self.assertEqual(
|
||||
m.headers["via"],
|
||||
[
|
||||
"SIP/2.0/UDP 127.0.0.1:5060",
|
||||
"SIP/2.0/UDP 1.2.3.4:5060",
|
||||
"SIP/2.0/UDP 1.2.3.5:5060",
|
||||
],
|
||||
)
|
||||
|
||||
def testReceivedRequestForward(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("via", sip.Via("1.2.3.4").toString())
|
||||
r.addHeader("foo", "bar")
|
||||
r.addHeader("to", "<sip:joe@server.com>")
|
||||
r.addHeader("contact", "<sip:joe@1.2.3.4>")
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual(
|
||||
m.headers["via"],
|
||||
["SIP/2.0/UDP 127.0.0.1:5060", "SIP/2.0/UDP 1.2.3.4:5060;received=1.1.1.1"],
|
||||
)
|
||||
|
||||
def testResponseWrongVia(self):
|
||||
# first via must match proxy's address
|
||||
r = sip.Response(200)
|
||||
r.addHeader("via", sip.Via("foo.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
self.assertEqual(len(self.sent), 0)
|
||||
|
||||
def testResponseForward(self):
|
||||
r = sip.Response(200)
|
||||
r.addHeader("via", sip.Via("127.0.0.1").toString())
|
||||
r.addHeader("via", sip.Via("client.com", port=1234).toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual((dest.host, dest.port), ("client.com", 1234))
|
||||
self.assertEqual(m.code, 200)
|
||||
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:1234"])
|
||||
|
||||
def testReceivedResponseForward(self):
|
||||
r = sip.Response(200)
|
||||
r.addHeader("via", sip.Via("127.0.0.1").toString())
|
||||
r.addHeader("via", sip.Via("10.0.0.1", received="client.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
|
||||
|
||||
def testResponseToUs(self):
|
||||
r = sip.Response(200)
|
||||
r.addHeader("via", sip.Via("127.0.0.1").toString())
|
||||
l = []
|
||||
self.proxy.gotResponse = lambda *a: l.append(a)
|
||||
self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
|
||||
self.assertEqual(len(l), 1)
|
||||
m, addr = l[0]
|
||||
self.assertEqual(len(m.headers.get("via", [])), 0)
|
||||
self.assertEqual(m.code, 200)
|
||||
|
||||
def testLoop(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("via", sip.Via("1.2.3.4").toString())
|
||||
r.addHeader("via", sip.Via("127.0.0.1").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
self.assertEqual(self.sent, [])
|
||||
|
||||
def testCantForwardRequest(self):
|
||||
r = sip.Request("INVITE", "sip:foo")
|
||||
r.addHeader("via", sip.Via("1.2.3.4").toString())
|
||||
r.addHeader("to", "<sip:joe@server.com>")
|
||||
self.proxy.locator = FailingLocator()
|
||||
self.proxy.datagramReceived(r.toString(), ("1.2.3.4", 5060))
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual((dest.host, dest.port), ("1.2.3.4", 5060))
|
||||
self.assertEqual(m.code, 404)
|
||||
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP 1.2.3.4:5060"])
|
||||
|
||||
|
||||
class RegistrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.proxy = sip.RegisterProxy(host="127.0.0.1")
|
||||
self.registry = sip.InMemoryRegistry("bell.example.com")
|
||||
self.proxy.registry = self.proxy.locator = self.registry
|
||||
self.sent = []
|
||||
self.proxy.sendMessage = lambda dest, msg: self.sent.append((dest, msg))
|
||||
|
||||
def tearDown(self):
|
||||
for d, uri in self.registry.users.values():
|
||||
d.cancel()
|
||||
del self.proxy
|
||||
|
||||
def register(self):
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@client.com:1234")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
|
||||
def unregister(self):
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "*")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
r.addHeader("expires", "0")
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
|
||||
def testRegister(self):
|
||||
self.register()
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
|
||||
self.assertEqual(m.code, 200)
|
||||
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:5060"])
|
||||
self.assertEqual(m.headers["to"], ["sip:joe@bell.example.com"])
|
||||
self.assertEqual(m.headers["contact"], ["sip:joe@client.com:5060"])
|
||||
#
|
||||
# XX: See http://tm.tl/8886
|
||||
#
|
||||
if type(reactor) != AsyncioSelectorReactor:
|
||||
self.assertTrue(int(m.headers["expires"][0]) in (3600, 3601, 3599, 3598))
|
||||
self.assertEqual(len(self.registry.users), 1)
|
||||
dc, uri = self.registry.users["joe"]
|
||||
self.assertEqual(uri.toString(), "sip:joe@client.com:5060")
|
||||
d = self.proxy.locator.getAddress(
|
||||
sip.URL(username="joe", host="bell.example.com")
|
||||
)
|
||||
d.addCallback(lambda desturl: (desturl.host, desturl.port))
|
||||
d.addCallback(self.assertEqual, ("client.com", 5060))
|
||||
return d
|
||||
|
||||
def testUnregister(self):
|
||||
self.register()
|
||||
self.unregister()
|
||||
dest, m = self.sent[1]
|
||||
self.assertEqual((dest.host, dest.port), ("client.com", 5060))
|
||||
self.assertEqual(m.code, 200)
|
||||
self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:5060"])
|
||||
self.assertEqual(m.headers["to"], ["sip:joe@bell.example.com"])
|
||||
self.assertEqual(m.headers["contact"], ["sip:joe@client.com:5060"])
|
||||
self.assertEqual(m.headers["expires"], ["0"])
|
||||
self.assertEqual(self.registry.users, {})
|
||||
|
||||
def addPortal(self):
|
||||
r = TestRealm()
|
||||
p = portal.Portal(r)
|
||||
c = checkers.InMemoryUsernamePasswordDatabaseDontUse()
|
||||
c.addUser("userXname@127.0.0.1", "passXword")
|
||||
p.registerChecker(c)
|
||||
self.proxy.portal = p
|
||||
|
||||
def testFailedAuthentication(self):
|
||||
self.addPortal()
|
||||
self.register()
|
||||
|
||||
self.assertEqual(len(self.registry.users), 0)
|
||||
self.assertEqual(len(self.sent), 1)
|
||||
dest, m = self.sent[0]
|
||||
self.assertEqual(m.code, 401)
|
||||
|
||||
def testWrongDomainRegister(self):
|
||||
r = sip.Request("REGISTER", "sip:wrong.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@client.com:1234")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
self.assertEqual(len(self.sent), 0)
|
||||
|
||||
def testWrongToDomainRegister(self):
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@foo.com")
|
||||
r.addHeader("contact", "sip:joe@client.com:1234")
|
||||
r.addHeader("via", sip.Via("client.com").toString())
|
||||
self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
|
||||
self.assertEqual(len(self.sent), 0)
|
||||
|
||||
def testWrongDomainLookup(self):
|
||||
self.register()
|
||||
url = sip.URL(username="joe", host="foo.com")
|
||||
d = self.proxy.locator.getAddress(url)
|
||||
self.assertFailure(d, LookupError)
|
||||
return d
|
||||
|
||||
def testNoContactLookup(self):
|
||||
self.register()
|
||||
url = sip.URL(username="jane", host="bell.example.com")
|
||||
d = self.proxy.locator.getAddress(url)
|
||||
self.assertFailure(d, LookupError)
|
||||
return d
|
||||
|
||||
|
||||
class Client(sip.Base):
|
||||
def __init__(self):
|
||||
sip.Base.__init__(self)
|
||||
self.received = []
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
def handle_response(self, response, addr):
|
||||
self.received.append(response)
|
||||
self.deferred.callback(self.received)
|
||||
|
||||
|
||||
class LiveTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.proxy = sip.RegisterProxy(host="127.0.0.1")
|
||||
self.registry = sip.InMemoryRegistry("bell.example.com")
|
||||
self.proxy.registry = self.proxy.locator = self.registry
|
||||
self.serverPort = reactor.listenUDP(0, self.proxy, interface="127.0.0.1")
|
||||
self.client = Client()
|
||||
self.clientPort = reactor.listenUDP(0, self.client, interface="127.0.0.1")
|
||||
self.serverAddress = (
|
||||
self.serverPort.getHost().host,
|
||||
self.serverPort.getHost().port,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
for d, uri in self.registry.users.values():
|
||||
d.cancel()
|
||||
d1 = defer.maybeDeferred(self.clientPort.stopListening)
|
||||
d2 = defer.maybeDeferred(self.serverPort.stopListening)
|
||||
return defer.gatherResults([d1, d2])
|
||||
|
||||
def testRegister(self):
|
||||
p = self.clientPort.getHost().port
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@127.0.0.1:%d" % p)
|
||||
r.addHeader("via", sip.Via("127.0.0.1", port=p).toString())
|
||||
self.client.sendMessage(
|
||||
sip.URL(host="127.0.0.1", port=self.serverAddress[1]), r
|
||||
)
|
||||
d = self.client.deferred
|
||||
|
||||
def check(received):
|
||||
self.assertEqual(len(received), 1)
|
||||
r = received[0]
|
||||
self.assertEqual(r.code, 200)
|
||||
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
def test_amoralRPort(self):
|
||||
"""
|
||||
rport is allowed without a value, apparently because server
|
||||
implementors might be too stupid to check the received port
|
||||
against 5060 and see if they're equal, and because client
|
||||
implementors might be too stupid to bind to port 5060, or set a
|
||||
value on the rport parameter they send if they bind to another
|
||||
port.
|
||||
"""
|
||||
p = self.clientPort.getHost().port
|
||||
r = sip.Request("REGISTER", "sip:bell.example.com")
|
||||
r.addHeader("to", "sip:joe@bell.example.com")
|
||||
r.addHeader("contact", "sip:joe@127.0.0.1:%d" % p)
|
||||
r.addHeader("via", sip.Via("127.0.0.1", port=p, rport=True).toString())
|
||||
warnings = self.flushWarnings(offendingFunctions=[self.test_amoralRPort])
|
||||
self.assertEqual(len(warnings), 1)
|
||||
self.assertEqual(
|
||||
warnings[0]["message"], "rport=True is deprecated since Twisted 9.0."
|
||||
)
|
||||
self.assertEqual(warnings[0]["category"], DeprecationWarning)
|
||||
self.client.sendMessage(
|
||||
sip.URL(host="127.0.0.1", port=self.serverAddress[1]), r
|
||||
)
|
||||
d = self.client.deferred
|
||||
|
||||
def check(received):
|
||||
self.assertEqual(len(received), 1)
|
||||
r = received[0]
|
||||
self.assertEqual(r.code, 200)
|
||||
|
||||
d.addCallback(check)
|
||||
return d
|
||||
176
.venv/lib/python3.12/site-packages/twisted/test/test_sob.py
Normal file
176
.venv/lib/python3.12/site-packages/twisted/test/test_sob.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
|
||||
from twisted.persisted import sob
|
||||
from twisted.persisted.styles import Ephemeral
|
||||
from twisted.python import components
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class Dummy(components.Componentized):
|
||||
pass
|
||||
|
||||
|
||||
objects = [
|
||||
1,
|
||||
"hello",
|
||||
(1, "hello"),
|
||||
[1, "hello"],
|
||||
{1: "hello"},
|
||||
]
|
||||
|
||||
|
||||
class FakeModule:
|
||||
pass
|
||||
|
||||
|
||||
class PersistTests(unittest.TestCase):
|
||||
def testStyles(self):
|
||||
for o in objects:
|
||||
p = sob.Persistent(o, "")
|
||||
for style in "source pickle".split():
|
||||
p.setStyle(style)
|
||||
p.save(filename="persisttest." + style)
|
||||
o1 = sob.load("persisttest." + style, style)
|
||||
self.assertEqual(o, o1)
|
||||
|
||||
def testStylesBeingSet(self):
|
||||
o = Dummy()
|
||||
o.foo = 5
|
||||
o.setComponent(sob.IPersistable, sob.Persistent(o, "lala"))
|
||||
for style in "source pickle".split():
|
||||
sob.IPersistable(o).setStyle(style)
|
||||
sob.IPersistable(o).save(filename="lala." + style)
|
||||
o1 = sob.load("lala." + style, style)
|
||||
self.assertEqual(o.foo, o1.foo)
|
||||
self.assertEqual(sob.IPersistable(o1).style, style)
|
||||
|
||||
def testPassphraseError(self):
|
||||
"""
|
||||
Calling save() with a passphrase is an error.
|
||||
"""
|
||||
p = sob.Persistant(None, "object")
|
||||
self.assertRaises(TypeError, p.save, "filename.pickle", passphrase="abc")
|
||||
|
||||
def testNames(self):
|
||||
o = [1, 2, 3]
|
||||
p = sob.Persistent(o, "object")
|
||||
for style in "source pickle".split():
|
||||
p.setStyle(style)
|
||||
p.save()
|
||||
o1 = sob.load("object.ta" + style[0], style)
|
||||
self.assertEqual(o, o1)
|
||||
for tag in "lala lolo".split():
|
||||
p.save(tag)
|
||||
o1 = sob.load("object-" + tag + ".ta" + style[0], style)
|
||||
self.assertEqual(o, o1)
|
||||
|
||||
def testPython(self):
|
||||
with open("persisttest.python", "w") as f:
|
||||
f.write("foo=[1,2,3] ")
|
||||
o = sob.loadValueFromFile("persisttest.python", "foo")
|
||||
self.assertEqual(o, [1, 2, 3])
|
||||
|
||||
def testTypeGuesser(self):
|
||||
self.assertRaises(KeyError, sob.guessType, "file.blah")
|
||||
self.assertEqual("python", sob.guessType("file.py"))
|
||||
self.assertEqual("python", sob.guessType("file.tac"))
|
||||
self.assertEqual("python", sob.guessType("file.etac"))
|
||||
self.assertEqual("pickle", sob.guessType("file.tap"))
|
||||
self.assertEqual("pickle", sob.guessType("file.etap"))
|
||||
self.assertEqual("source", sob.guessType("file.tas"))
|
||||
self.assertEqual("source", sob.guessType("file.etas"))
|
||||
|
||||
def testEverythingEphemeralGetattr(self):
|
||||
"""
|
||||
L{_EverythingEphermal.__getattr__} will proxy the __main__ module as an
|
||||
L{Ephemeral} object, and during load will be transparent, but after
|
||||
load will return L{Ephemeral} objects from any accessed attributes.
|
||||
"""
|
||||
self.fakeMain.testMainModGetattr = 1
|
||||
|
||||
dirname = self.mktemp()
|
||||
os.mkdir(dirname)
|
||||
|
||||
filename = os.path.join(dirname, "persisttest.ee_getattr")
|
||||
|
||||
global mainWhileLoading
|
||||
mainWhileLoading = None
|
||||
with open(filename, "w") as f:
|
||||
f.write(
|
||||
dedent(
|
||||
"""
|
||||
app = []
|
||||
import __main__
|
||||
app.append(__main__.testMainModGetattr == 1)
|
||||
try:
|
||||
__main__.somethingElse
|
||||
except AttributeError:
|
||||
app.append(True)
|
||||
else:
|
||||
app.append(False)
|
||||
from twisted.test import test_sob
|
||||
test_sob.mainWhileLoading = __main__
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
loaded = sob.load(filename, "source")
|
||||
self.assertIsInstance(loaded, list)
|
||||
self.assertTrue(loaded[0], "Expected attribute not set.")
|
||||
self.assertTrue(loaded[1], "Unexpected attribute set.")
|
||||
self.assertIsInstance(mainWhileLoading, Ephemeral)
|
||||
self.assertIsInstance(mainWhileLoading.somethingElse, Ephemeral)
|
||||
del mainWhileLoading
|
||||
|
||||
def testEverythingEphemeralSetattr(self):
|
||||
"""
|
||||
Verify that _EverythingEphemeral.__setattr__ won't affect __main__.
|
||||
"""
|
||||
self.fakeMain.testMainModSetattr = 1
|
||||
|
||||
dirname = self.mktemp()
|
||||
os.mkdir(dirname)
|
||||
|
||||
filename = os.path.join(dirname, "persisttest.ee_setattr")
|
||||
with open(filename, "w") as f:
|
||||
f.write("import __main__\n")
|
||||
f.write("__main__.testMainModSetattr = 2\n")
|
||||
f.write("app = None\n")
|
||||
|
||||
sob.load(filename, "source")
|
||||
|
||||
self.assertEqual(self.fakeMain.testMainModSetattr, 1)
|
||||
|
||||
def testEverythingEphemeralException(self):
|
||||
"""
|
||||
Test that an exception during load() won't cause _EE to mask __main__
|
||||
"""
|
||||
dirname = self.mktemp()
|
||||
os.mkdir(dirname)
|
||||
filename = os.path.join(dirname, "persisttest.ee_exception")
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write("raise ValueError\n")
|
||||
|
||||
self.assertRaises(ValueError, sob.load, filename, "source")
|
||||
self.assertEqual(type(sys.modules["__main__"]), FakeModule)
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Replace the __main__ module with a fake one, so that it can be mutated
|
||||
in tests
|
||||
"""
|
||||
self.realMain = sys.modules["__main__"]
|
||||
self.fakeMain = sys.modules["__main__"] = FakeModule()
|
||||
|
||||
def tearDown(self):
|
||||
"""
|
||||
Restore __main__ to its original value
|
||||
"""
|
||||
sys.modules["__main__"] = self.realMain
|
||||
498
.venv/lib/python3.12/site-packages/twisted/test/test_socks.py
Normal file
498
.venv/lib/python3.12/site-packages/twisted/test/test_socks.py
Normal file
@@ -0,0 +1,498 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.protocol.socks}, an implementation of the SOCKSv4 and
|
||||
SOCKSv4a protocols.
|
||||
"""
|
||||
|
||||
import socket
|
||||
import struct
|
||||
|
||||
from twisted.internet import address, defer
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.protocols import socks
|
||||
from twisted.python.compat import iterbytes
|
||||
from twisted.test import proto_helpers
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class StringTCPTransport(proto_helpers.StringTransport):
|
||||
stringTCPTransport_closing = False
|
||||
peer = None
|
||||
|
||||
def getPeer(self):
|
||||
return self.peer
|
||||
|
||||
def getHost(self):
|
||||
return address.IPv4Address("TCP", "2.3.4.5", 42)
|
||||
|
||||
def loseConnection(self):
|
||||
self.stringTCPTransport_closing = True
|
||||
|
||||
|
||||
class FakeResolverReactor:
|
||||
"""
|
||||
Bare-bones reactor with deterministic behavior for the resolve method.
|
||||
"""
|
||||
|
||||
def __init__(self, names):
|
||||
"""
|
||||
@type names: L{dict} containing L{str} keys and L{str} values.
|
||||
@param names: A hostname to IP address mapping. The IP addresses are
|
||||
stringified dotted quads.
|
||||
"""
|
||||
self.names = names
|
||||
|
||||
def resolve(self, hostname):
|
||||
"""
|
||||
Resolve a hostname by looking it up in the C{names} dictionary.
|
||||
"""
|
||||
try:
|
||||
return defer.succeed(self.names[hostname])
|
||||
except KeyError:
|
||||
return defer.fail(
|
||||
DNSLookupError(
|
||||
"FakeResolverReactor couldn't find " + hostname.decode("utf-8")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SOCKSv4Driver(socks.SOCKSv4):
|
||||
# last SOCKSv4Outgoing instantiated
|
||||
driver_outgoing = None
|
||||
|
||||
# last SOCKSv4IncomingFactory instantiated
|
||||
driver_listen = None
|
||||
|
||||
def connectClass(self, host, port, klass, *args):
|
||||
# fake it
|
||||
proto = klass(*args)
|
||||
proto.transport = StringTCPTransport()
|
||||
proto.transport.peer = address.IPv4Address("TCP", host, port)
|
||||
proto.connectionMade()
|
||||
self.driver_outgoing = proto
|
||||
return defer.succeed(proto)
|
||||
|
||||
def listenClass(self, port, klass, *args):
|
||||
# fake it
|
||||
factory = klass(*args)
|
||||
self.driver_listen = factory
|
||||
if port == 0:
|
||||
port = 1234
|
||||
return defer.succeed(("6.7.8.9", port))
|
||||
|
||||
|
||||
class ConnectTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for SOCKS and SOCKSv4a connect requests using the L{SOCKSv4} protocol.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.sock = SOCKSv4Driver()
|
||||
self.sock.transport = StringTCPTransport()
|
||||
self.sock.connectionMade()
|
||||
self.sock.reactor = FakeResolverReactor({b"localhost": "127.0.0.1"})
|
||||
|
||||
def tearDown(self):
|
||||
outgoing = self.sock.driver_outgoing
|
||||
if outgoing is not None:
|
||||
self.assertTrue(
|
||||
outgoing.transport.stringTCPTransport_closing,
|
||||
"Outgoing SOCKS connections need to be closed.",
|
||||
)
|
||||
|
||||
def test_simple(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 1, 34)
|
||||
+ socket.inet_aton("1.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 90, 34) + socket.inet_aton("1.2.3.4")
|
||||
)
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIsNotNone(self.sock.driver_outgoing)
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived(b"hello, world")
|
||||
self.assertEqual(self.sock.driver_outgoing.transport.value(), b"hello, world")
|
||||
|
||||
# the other way around
|
||||
self.sock.driver_outgoing.dataReceived(b"hi there")
|
||||
self.assertEqual(self.sock.transport.value(), b"hi there")
|
||||
|
||||
self.sock.connectionLost("fake reason")
|
||||
|
||||
def test_socks4aSuccessfulResolution(self):
|
||||
"""
|
||||
If the destination IP address has zeros for the first three octets and
|
||||
non-zero for the fourth octet, the client is attempting a v4a
|
||||
connection. A hostname is specified after the user ID string and the
|
||||
server connects to the address that hostname resolves to.
|
||||
|
||||
@see: U{http://en.wikipedia.org/wiki/SOCKS#SOCKS_4a_protocol}
|
||||
"""
|
||||
# send the domain name "localhost" to be resolved
|
||||
clientRequest = (
|
||||
struct.pack("!BBH", 4, 1, 34)
|
||||
+ socket.inet_aton("0.0.0.1")
|
||||
+ b"fooBAZ\0"
|
||||
+ b"localhost\0"
|
||||
)
|
||||
|
||||
# Deliver the bytes one by one to exercise the protocol's buffering
|
||||
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
|
||||
# the hostname.
|
||||
for byte in iterbytes(clientRequest):
|
||||
self.sock.dataReceived(byte)
|
||||
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# Verify that the server responded with the address which will be
|
||||
# connected to.
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 90, 34) + socket.inet_aton("127.0.0.1")
|
||||
)
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIsNotNone(self.sock.driver_outgoing)
|
||||
|
||||
# Pass some data through and verify it is forwarded to the outgoing
|
||||
# connection.
|
||||
self.sock.dataReceived(b"hello, world")
|
||||
self.assertEqual(self.sock.driver_outgoing.transport.value(), b"hello, world")
|
||||
|
||||
# Deliver some data from the output connection and verify it is
|
||||
# passed along to the incoming side.
|
||||
self.sock.driver_outgoing.dataReceived(b"hi there")
|
||||
self.assertEqual(self.sock.transport.value(), b"hi there")
|
||||
|
||||
self.sock.connectionLost("fake reason")
|
||||
|
||||
def test_socks4aFailedResolution(self):
|
||||
"""
|
||||
Failed hostname resolution on a SOCKSv4a packet results in a 91 error
|
||||
response and the connection getting closed.
|
||||
"""
|
||||
# send the domain name "failinghost" to be resolved
|
||||
clientRequest = (
|
||||
struct.pack("!BBH", 4, 1, 34)
|
||||
+ socket.inet_aton("0.0.0.1")
|
||||
+ b"fooBAZ\0"
|
||||
+ b"failinghost\0"
|
||||
)
|
||||
|
||||
# Deliver the bytes one by one to exercise the protocol's buffering
|
||||
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
|
||||
# the hostname.
|
||||
for byte in iterbytes(clientRequest):
|
||||
self.sock.dataReceived(byte)
|
||||
|
||||
# Verify that the server responds with a 91 error.
|
||||
sent = self.sock.transport.value()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0")
|
||||
)
|
||||
|
||||
# A failed resolution causes the transport to drop the connection.
|
||||
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIsNone(self.sock.driver_outgoing)
|
||||
|
||||
def test_accessDenied(self):
|
||||
self.sock.authorize = lambda code, server, port, user: 0
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 1, 4242)
|
||||
+ socket.inet_aton("10.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
self.assertEqual(
|
||||
self.sock.transport.value(),
|
||||
struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0"),
|
||||
)
|
||||
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIsNone(self.sock.driver_outgoing)
|
||||
|
||||
def test_eofRemote(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 1, 34)
|
||||
+ socket.inet_aton("1.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
self.sock.transport.clear()
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived(b"hello, world")
|
||||
self.assertEqual(self.sock.driver_outgoing.transport.value(), b"hello, world")
|
||||
|
||||
# now close it from the server side
|
||||
self.sock.driver_outgoing.transport.loseConnection()
|
||||
self.sock.driver_outgoing.connectionLost("fake reason")
|
||||
|
||||
def test_eofLocal(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 1, 34)
|
||||
+ socket.inet_aton("1.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
self.sock.transport.clear()
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived(b"hello, world")
|
||||
self.assertEqual(self.sock.driver_outgoing.transport.value(), b"hello, world")
|
||||
|
||||
# now close it from the client side
|
||||
self.sock.connectionLost("fake reason")
|
||||
|
||||
|
||||
class BindTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for SOCKS and SOCKSv4a bind requests using the L{SOCKSv4} protocol.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.sock = SOCKSv4Driver()
|
||||
self.sock.transport = StringTCPTransport()
|
||||
self.sock.connectionMade()
|
||||
self.sock.reactor = FakeResolverReactor({b"localhost": "127.0.0.1"})
|
||||
|
||||
## def tearDown(self):
|
||||
## # TODO ensure the listen port is closed
|
||||
## listen = self.sock.driver_listen
|
||||
## if listen is not None:
|
||||
## self.assert_(incoming.transport.stringTCPTransport_closing,
|
||||
## "Incoming SOCKS connections need to be closed.")
|
||||
|
||||
def test_simple(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 2, 34)
|
||||
+ socket.inet_aton("1.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 90, 1234) + socket.inet_aton("6.7.8.9")
|
||||
)
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIsNotNone(self.sock.driver_listen)
|
||||
|
||||
# connect
|
||||
incoming = self.sock.driver_listen.buildProtocol(("1.2.3.4", 5345))
|
||||
self.assertIsNotNone(incoming)
|
||||
incoming.transport = StringTCPTransport()
|
||||
incoming.connectionMade()
|
||||
|
||||
# now we should have the second reply packet
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 90, 0) + socket.inet_aton("0.0.0.0")
|
||||
)
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived(b"hello, world")
|
||||
self.assertEqual(incoming.transport.value(), b"hello, world")
|
||||
|
||||
# the other way around
|
||||
incoming.dataReceived(b"hi there")
|
||||
self.assertEqual(self.sock.transport.value(), b"hi there")
|
||||
|
||||
self.sock.connectionLost("fake reason")
|
||||
|
||||
def test_socks4a(self):
|
||||
"""
|
||||
If the destination IP address has zeros for the first three octets and
|
||||
non-zero for the fourth octet, the client is attempting a v4a
|
||||
connection. A hostname is specified after the user ID string and the
|
||||
server connects to the address that hostname resolves to.
|
||||
|
||||
@see: U{http://en.wikipedia.org/wiki/SOCKS#SOCKS_4a_protocol}
|
||||
"""
|
||||
# send the domain name "localhost" to be resolved
|
||||
clientRequest = (
|
||||
struct.pack("!BBH", 4, 2, 34)
|
||||
+ socket.inet_aton("0.0.0.1")
|
||||
+ b"fooBAZ\0"
|
||||
+ b"localhost\0"
|
||||
)
|
||||
|
||||
# Deliver the bytes one by one to exercise the protocol's buffering
|
||||
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
|
||||
# the hostname.
|
||||
for byte in iterbytes(clientRequest):
|
||||
self.sock.dataReceived(byte)
|
||||
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# Verify that the server responded with the address which will be
|
||||
# connected to.
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 90, 1234) + socket.inet_aton("6.7.8.9")
|
||||
)
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIsNotNone(self.sock.driver_listen)
|
||||
|
||||
# connect
|
||||
incoming = self.sock.driver_listen.buildProtocol(("127.0.0.1", 5345))
|
||||
self.assertIsNotNone(incoming)
|
||||
incoming.transport = StringTCPTransport()
|
||||
incoming.connectionMade()
|
||||
|
||||
# now we should have the second reply packet
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 90, 0) + socket.inet_aton("0.0.0.0")
|
||||
)
|
||||
self.assertIsNot(self.sock.transport.stringTCPTransport_closing, None)
|
||||
|
||||
# Deliver some data from the output connection and verify it is
|
||||
# passed along to the incoming side.
|
||||
self.sock.dataReceived(b"hi there")
|
||||
self.assertEqual(incoming.transport.value(), b"hi there")
|
||||
|
||||
# the other way around
|
||||
incoming.dataReceived(b"hi there")
|
||||
self.assertEqual(self.sock.transport.value(), b"hi there")
|
||||
|
||||
self.sock.connectionLost("fake reason")
|
||||
|
||||
def test_socks4aFailedResolution(self):
|
||||
"""
|
||||
Failed hostname resolution on a SOCKSv4a packet results in a 91 error
|
||||
response and the connection getting closed.
|
||||
"""
|
||||
# send the domain name "failinghost" to be resolved
|
||||
clientRequest = (
|
||||
struct.pack("!BBH", 4, 2, 34)
|
||||
+ socket.inet_aton("0.0.0.1")
|
||||
+ b"fooBAZ\0"
|
||||
+ b"failinghost\0"
|
||||
)
|
||||
|
||||
# Deliver the bytes one by one to exercise the protocol's buffering
|
||||
# logic. FakeResolverReactor's resolve method is invoked to "resolve"
|
||||
# the hostname.
|
||||
for byte in iterbytes(clientRequest):
|
||||
self.sock.dataReceived(byte)
|
||||
|
||||
# Verify that the server responds with a 91 error.
|
||||
sent = self.sock.transport.value()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0")
|
||||
)
|
||||
|
||||
# A failed resolution causes the transport to drop the connection.
|
||||
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIsNone(self.sock.driver_outgoing)
|
||||
|
||||
def test_accessDenied(self):
|
||||
self.sock.authorize = lambda code, server, port, user: 0
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 2, 4242)
|
||||
+ socket.inet_aton("10.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
self.assertEqual(
|
||||
self.sock.transport.value(),
|
||||
struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0"),
|
||||
)
|
||||
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
|
||||
self.assertIsNone(self.sock.driver_listen)
|
||||
|
||||
def test_eofRemote(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 2, 34)
|
||||
+ socket.inet_aton("1.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# connect
|
||||
incoming = self.sock.driver_listen.buildProtocol(("1.2.3.4", 5345))
|
||||
self.assertIsNotNone(incoming)
|
||||
incoming.transport = StringTCPTransport()
|
||||
incoming.connectionMade()
|
||||
|
||||
# now we should have the second reply packet
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 90, 0) + socket.inet_aton("0.0.0.0")
|
||||
)
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived(b"hello, world")
|
||||
self.assertEqual(incoming.transport.value(), b"hello, world")
|
||||
|
||||
# now close it from the server side
|
||||
incoming.transport.loseConnection()
|
||||
incoming.connectionLost("fake reason")
|
||||
|
||||
def test_eofLocal(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 2, 34)
|
||||
+ socket.inet_aton("1.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# connect
|
||||
incoming = self.sock.driver_listen.buildProtocol(("1.2.3.4", 5345))
|
||||
self.assertIsNotNone(incoming)
|
||||
incoming.transport = StringTCPTransport()
|
||||
incoming.connectionMade()
|
||||
|
||||
# now we should have the second reply packet
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 90, 0) + socket.inet_aton("0.0.0.0")
|
||||
)
|
||||
self.assertFalse(self.sock.transport.stringTCPTransport_closing)
|
||||
|
||||
# pass some data through
|
||||
self.sock.dataReceived(b"hello, world")
|
||||
self.assertEqual(incoming.transport.value(), b"hello, world")
|
||||
|
||||
# now close it from the client side
|
||||
self.sock.connectionLost("fake reason")
|
||||
|
||||
def test_badSource(self):
|
||||
self.sock.dataReceived(
|
||||
struct.pack("!BBH", 4, 2, 34)
|
||||
+ socket.inet_aton("1.2.3.4")
|
||||
+ b"fooBAR"
|
||||
+ b"\0"
|
||||
)
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
|
||||
# connect from WRONG address
|
||||
incoming = self.sock.driver_listen.buildProtocol(("1.6.6.6", 666))
|
||||
self.assertIsNone(incoming)
|
||||
|
||||
# Now we should have the second reply packet and it should
|
||||
# be a failure. The connection should be closing.
|
||||
sent = self.sock.transport.value()
|
||||
self.sock.transport.clear()
|
||||
self.assertEqual(
|
||||
sent, struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0")
|
||||
)
|
||||
self.assertTrue(self.sock.transport.stringTCPTransport_closing)
|
||||
713
.venv/lib/python3.12/site-packages/twisted/test/test_ssl.py
Normal file
713
.venv/lib/python3.12/site-packages/twisted/test/test_ssl.py
Normal file
@@ -0,0 +1,713 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for twisted SSL support.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import hamcrest
|
||||
|
||||
from twisted.internet import defer, interfaces, protocol, reactor
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from twisted.internet.testing import waitUntilAllDisconnected
|
||||
from twisted.protocols import basic
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.test.test_tcp import ProperlyCloseFilesMixin
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
try:
|
||||
from OpenSSL import SSL, crypto
|
||||
|
||||
from twisted.internet import ssl
|
||||
from twisted.test.ssl_helpers import ClientTLSContext, certPath
|
||||
except ImportError:
|
||||
|
||||
def _noSSL():
|
||||
# ugh, make pyflakes happy.
|
||||
global SSL
|
||||
global ssl
|
||||
SSL = ssl = None
|
||||
|
||||
_noSSL()
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
|
||||
class UnintelligentProtocol(basic.LineReceiver):
|
||||
"""
|
||||
@ivar deferred: a deferred that will fire at connection lost.
|
||||
@type deferred: L{defer.Deferred}
|
||||
|
||||
@cvar pretext: text sent before TLS is set up.
|
||||
@type pretext: C{bytes}
|
||||
|
||||
@cvar posttext: text sent after TLS is set up.
|
||||
@type posttext: C{bytes}
|
||||
"""
|
||||
|
||||
pretext = [b"first line", b"last thing before tls starts", b"STARTTLS"]
|
||||
|
||||
posttext = [b"first thing after tls started", b"last thing ever"]
|
||||
|
||||
def __init__(self):
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
for l in self.pretext:
|
||||
self.sendLine(l)
|
||||
|
||||
def lineReceived(self, line):
|
||||
if line == b"READY":
|
||||
self.transport.startTLS(ClientTLSContext(), self.factory.client)
|
||||
for l in self.posttext:
|
||||
self.sendLine(l)
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
class LineCollector(basic.LineReceiver):
|
||||
"""
|
||||
@ivar deferred: a deferred that will fire at connection lost.
|
||||
@type deferred: L{defer.Deferred}
|
||||
|
||||
@ivar doTLS: whether the protocol is initiate TLS or not.
|
||||
@type doTLS: C{bool}
|
||||
|
||||
@ivar fillBuffer: if set to True, it will send lots of data once
|
||||
C{STARTTLS} is received.
|
||||
@type fillBuffer: C{bool}
|
||||
"""
|
||||
|
||||
def __init__(self, doTLS, fillBuffer=False):
|
||||
self.doTLS = doTLS
|
||||
self.fillBuffer = fillBuffer
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
self.factory.rawdata = b""
|
||||
self.factory.lines = []
|
||||
|
||||
def lineReceived(self, line):
|
||||
self.factory.lines.append(line)
|
||||
if line == b"STARTTLS":
|
||||
if self.fillBuffer:
|
||||
for x in range(500):
|
||||
self.sendLine(b"X" * 1000)
|
||||
self.sendLine(b"READY")
|
||||
if self.doTLS:
|
||||
ctx = ServerTLSContext(
|
||||
privateKeyFileName=certPath,
|
||||
certificateFileName=certPath,
|
||||
)
|
||||
self.transport.startTLS(ctx, self.factory.server)
|
||||
else:
|
||||
self.setRawMode()
|
||||
|
||||
def rawDataReceived(self, data):
|
||||
self.factory.rawdata += data
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
class SingleLineServerProtocol(protocol.Protocol):
|
||||
"""
|
||||
A protocol that sends a single line of data at C{connectionMade}.
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(b"+OK <some crap>\r\n")
|
||||
self.transport.getPeerCertificate()
|
||||
|
||||
|
||||
class RecordingClientProtocol(protocol.Protocol):
|
||||
"""
|
||||
@ivar deferred: a deferred that will fire with first received content.
|
||||
@type deferred: L{defer.Deferred}
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.getPeerCertificate()
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.deferred.callback(data)
|
||||
|
||||
|
||||
@implementer(interfaces.IHandshakeListener)
|
||||
class ImmediatelyDisconnectingProtocol(protocol.Protocol):
|
||||
"""
|
||||
A protocol that disconnect immediately on connection. It fires the
|
||||
C{connectionDisconnected} deferred of its factory on connetion lost.
|
||||
"""
|
||||
|
||||
def handshakeCompleted(self):
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.factory.connectionDisconnected.callback(None)
|
||||
|
||||
|
||||
def generateCertificateObjects(organization, organizationalUnit):
|
||||
"""
|
||||
Create a certificate for given C{organization} and C{organizationalUnit}.
|
||||
|
||||
@return: a tuple of (key, request, certificate) objects.
|
||||
"""
|
||||
pkey = crypto.PKey()
|
||||
pkey.generate_key(crypto.TYPE_RSA, 2048)
|
||||
req = crypto.X509Req()
|
||||
subject = req.get_subject()
|
||||
subject.O = organization
|
||||
subject.OU = organizationalUnit
|
||||
req.set_pubkey(pkey)
|
||||
req.sign(pkey, "md5")
|
||||
|
||||
# Here comes the actual certificate
|
||||
cert = crypto.X509()
|
||||
cert.set_serial_number(1)
|
||||
cert.gmtime_adj_notBefore(0)
|
||||
cert.gmtime_adj_notAfter(60) # Testing certificates need not be long lived
|
||||
cert.set_issuer(req.get_subject())
|
||||
cert.set_subject(req.get_subject())
|
||||
cert.set_pubkey(req.get_pubkey())
|
||||
cert.sign(pkey, "md5")
|
||||
|
||||
return pkey, req, cert
|
||||
|
||||
|
||||
def generateCertificateFiles(basename, organization, organizationalUnit):
|
||||
"""
|
||||
Create certificate files key, req and cert prefixed by C{basename} for
|
||||
given C{organization} and C{organizationalUnit}.
|
||||
"""
|
||||
pkey, req, cert = generateCertificateObjects(organization, organizationalUnit)
|
||||
|
||||
for ext, obj, dumpFunc in [
|
||||
("key", pkey, crypto.dump_privatekey),
|
||||
("req", req, crypto.dump_certificate_request),
|
||||
("cert", cert, crypto.dump_certificate),
|
||||
]:
|
||||
fName = os.extsep.join((basename, ext)).encode("utf-8")
|
||||
FilePath(fName).setContent(dumpFunc(crypto.FILETYPE_PEM, obj))
|
||||
|
||||
|
||||
class ContextGeneratingMixin:
|
||||
"""
|
||||
Offer methods to create L{ssl.DefaultOpenSSLContextFactory} for both client
|
||||
and server.
|
||||
|
||||
@ivar clientBase: prefix of client certificate files.
|
||||
@type clientBase: C{str}
|
||||
|
||||
@ivar serverBase: prefix of server certificate files.
|
||||
@type serverBase: C{str}
|
||||
|
||||
@ivar clientCtxFactory: a generated context factory to be used in
|
||||
L{IReactorSSL.connectSSL}.
|
||||
@type clientCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
|
||||
|
||||
@ivar serverCtxFactory: a generated context factory to be used in
|
||||
L{IReactorSSL.listenSSL}.
|
||||
@type serverCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
|
||||
"""
|
||||
|
||||
def makeContextFactory(self, org, orgUnit, *args, **kwArgs):
|
||||
base = self.mktemp()
|
||||
generateCertificateFiles(base, org, orgUnit)
|
||||
serverCtxFactory = ssl.DefaultOpenSSLContextFactory(
|
||||
os.extsep.join((base, "key")),
|
||||
os.extsep.join((base, "cert")),
|
||||
*args,
|
||||
**kwArgs,
|
||||
)
|
||||
|
||||
return base, serverCtxFactory
|
||||
|
||||
def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs, serverKwArgs):
|
||||
self.clientBase, self.clientCtxFactory = self.makeContextFactory(
|
||||
*clientArgs, **clientKwArgs
|
||||
)
|
||||
self.serverBase, self.serverCtxFactory = self.makeContextFactory(
|
||||
*serverArgs, **serverKwArgs
|
||||
)
|
||||
|
||||
|
||||
if SSL is not None:
|
||||
|
||||
class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
|
||||
"""
|
||||
A context factory with a default method set to
|
||||
L{OpenSSL.SSL.SSLv23_METHOD}.
|
||||
"""
|
||||
|
||||
isClient = False
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
kw["sslmethod"] = SSL.SSLv23_METHOD
|
||||
ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw)
|
||||
|
||||
|
||||
class StolenTCPTests(ProperlyCloseFilesMixin, TestCase):
|
||||
"""
|
||||
For SSL transports, test many of the same things which are tested for
|
||||
TCP transports.
|
||||
"""
|
||||
|
||||
if interfaces.IReactorSSL(reactor, None) is None:
|
||||
skip = "Reactor does not support SSL, cannot run SSL tests"
|
||||
|
||||
def createServer(self, address, portNumber, factory):
|
||||
"""
|
||||
Create an SSL server with a certificate using L{IReactorSSL.listenSSL}.
|
||||
"""
|
||||
cert = ssl.PrivateCertificate.loadPEM(FilePath(certPath).getContent())
|
||||
contextFactory = cert.options()
|
||||
return reactor.listenSSL(portNumber, factory, contextFactory, interface=address)
|
||||
|
||||
def connectClient(self, address, portNumber, clientCreator):
|
||||
"""
|
||||
Create an SSL client using L{IReactorSSL.connectSSL}.
|
||||
"""
|
||||
contextFactory = ssl.CertificateOptions()
|
||||
return clientCreator.connectSSL(address, portNumber, contextFactory)
|
||||
|
||||
def getHandleExceptionType(self):
|
||||
"""
|
||||
Return L{OpenSSL.SSL.Error} as the expected error type which will be
|
||||
raised by a write to the L{OpenSSL.SSL.Connection} object after it has
|
||||
been closed.
|
||||
"""
|
||||
return SSL.Error
|
||||
|
||||
def getHandleErrorCodeMatcher(self):
|
||||
"""
|
||||
Return a L{hamcrest.core.matcher.Matcher} for the argument
|
||||
L{OpenSSL.SSL.Error} will be constructed with for this case.
|
||||
This is basically just a random OpenSSL implementation detail.
|
||||
It would be better if this test worked in a way which did not
|
||||
require this.
|
||||
"""
|
||||
# We expect an error about how we tried to write to a shutdown
|
||||
# connection. This is terribly implementation-specific.
|
||||
return hamcrest.contains(
|
||||
hamcrest.contains(
|
||||
hamcrest.equal_to("SSL routines"),
|
||||
hamcrest.any_of(
|
||||
hamcrest.equal_to("SSL_write"),
|
||||
hamcrest.equal_to("ssl_write_internal"),
|
||||
hamcrest.equal_to(""),
|
||||
),
|
||||
hamcrest.equal_to("protocol is shutdown"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TLSTests(TestCase):
|
||||
"""
|
||||
Tests for startTLS support.
|
||||
|
||||
@ivar fillBuffer: forwarded to L{LineCollector.fillBuffer}
|
||||
@type fillBuffer: C{bool}
|
||||
"""
|
||||
|
||||
if interfaces.IReactorSSL(reactor, None) is None:
|
||||
skip = "Reactor does not support SSL, cannot run SSL tests"
|
||||
|
||||
fillBuffer = False
|
||||
|
||||
clientProto = None
|
||||
serverProto = None
|
||||
|
||||
def tearDown(self):
|
||||
if self.clientProto.transport is not None:
|
||||
self.clientProto.transport.loseConnection()
|
||||
if self.serverProto.transport is not None:
|
||||
self.serverProto.transport.loseConnection()
|
||||
|
||||
def _runTest(self, clientProto, serverProto, clientIsServer=False):
|
||||
"""
|
||||
Helper method to run TLS tests.
|
||||
|
||||
@param clientProto: protocol instance attached to the client
|
||||
connection.
|
||||
@param serverProto: protocol instance attached to the server
|
||||
connection.
|
||||
@param clientIsServer: flag indicated if client should initiate
|
||||
startTLS instead of server.
|
||||
|
||||
@return: a L{defer.Deferred} that will fire when both connections are
|
||||
lost.
|
||||
"""
|
||||
self.clientProto = clientProto
|
||||
cf = self.clientFactory = protocol.ClientFactory()
|
||||
cf.protocol = lambda: clientProto
|
||||
if clientIsServer:
|
||||
cf.server = False
|
||||
else:
|
||||
cf.client = True
|
||||
|
||||
self.serverProto = serverProto
|
||||
sf = self.serverFactory = protocol.ServerFactory()
|
||||
sf.protocol = lambda: serverProto
|
||||
if clientIsServer:
|
||||
sf.client = False
|
||||
else:
|
||||
sf.server = True
|
||||
|
||||
port = reactor.listenTCP(0, sf, interface="127.0.0.1")
|
||||
self.addCleanup(port.stopListening)
|
||||
|
||||
reactor.connectTCP("127.0.0.1", port.getHost().port, cf)
|
||||
|
||||
return defer.gatherResults([clientProto.deferred, serverProto.deferred])
|
||||
|
||||
def test_TLS(self):
|
||||
"""
|
||||
Test for server and client startTLS: client should received data both
|
||||
before and after the startTLS.
|
||||
"""
|
||||
|
||||
def check(ignore):
|
||||
self.assertEqual(
|
||||
self.serverFactory.lines,
|
||||
UnintelligentProtocol.pretext + UnintelligentProtocol.posttext,
|
||||
)
|
||||
|
||||
d = self._runTest(UnintelligentProtocol(), LineCollector(True, self.fillBuffer))
|
||||
return d.addCallback(check)
|
||||
|
||||
def test_unTLS(self):
|
||||
"""
|
||||
Test for server startTLS not followed by a startTLS in client: the data
|
||||
received after server startTLS should be received as raw.
|
||||
"""
|
||||
|
||||
def check(ignored):
|
||||
self.assertEqual(self.serverFactory.lines, UnintelligentProtocol.pretext)
|
||||
self.assertTrue(self.serverFactory.rawdata, "No encrypted bytes received")
|
||||
|
||||
d = self._runTest(
|
||||
UnintelligentProtocol(), LineCollector(False, self.fillBuffer)
|
||||
)
|
||||
return d.addCallback(check)
|
||||
|
||||
def test_backwardsTLS(self):
|
||||
"""
|
||||
Test startTLS first initiated by client.
|
||||
"""
|
||||
|
||||
def check(ignored):
|
||||
self.assertEqual(
|
||||
self.clientFactory.lines,
|
||||
UnintelligentProtocol.pretext + UnintelligentProtocol.posttext,
|
||||
)
|
||||
|
||||
d = self._runTest(
|
||||
LineCollector(True, self.fillBuffer), UnintelligentProtocol(), True
|
||||
)
|
||||
return d.addCallback(check)
|
||||
|
||||
|
||||
class SpammyTLSTests(TLSTests):
|
||||
"""
|
||||
Test TLS features with bytes sitting in the out buffer.
|
||||
"""
|
||||
|
||||
if interfaces.IReactorSSL(reactor, None) is None:
|
||||
skip = "Reactor does not support SSL, cannot run SSL tests"
|
||||
|
||||
fillBuffer = True
|
||||
|
||||
|
||||
class BufferingTests(TestCase):
|
||||
if interfaces.IReactorSSL(reactor, None) is None:
|
||||
skip = "Reactor does not support SSL, cannot run SSL tests"
|
||||
|
||||
serverProto = None
|
||||
clientProto = None
|
||||
|
||||
def tearDown(self):
|
||||
if self.serverProto.transport is not None:
|
||||
self.serverProto.transport.loseConnection()
|
||||
if self.clientProto.transport is not None:
|
||||
self.clientProto.transport.loseConnection()
|
||||
|
||||
return waitUntilAllDisconnected(reactor, [self.serverProto, self.clientProto])
|
||||
|
||||
def test_openSSLBuffering(self):
|
||||
serverProto = self.serverProto = SingleLineServerProtocol()
|
||||
clientProto = self.clientProto = RecordingClientProtocol()
|
||||
|
||||
server = protocol.ServerFactory()
|
||||
client = self.client = protocol.ClientFactory()
|
||||
|
||||
server.protocol = lambda: serverProto
|
||||
client.protocol = lambda: clientProto
|
||||
|
||||
sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
|
||||
cCTX = ssl.ClientContextFactory()
|
||||
|
||||
port = reactor.listenSSL(0, server, sCTX, interface="127.0.0.1")
|
||||
self.addCleanup(port.stopListening)
|
||||
|
||||
clientConnector = reactor.connectSSL(
|
||||
"127.0.0.1", port.getHost().port, client, cCTX
|
||||
)
|
||||
self.addCleanup(clientConnector.disconnect)
|
||||
|
||||
return clientProto.deferred.addCallback(
|
||||
self.assertEqual, b"+OK <some crap>\r\n"
|
||||
)
|
||||
|
||||
|
||||
class ConnectionLostTests(TestCase, ContextGeneratingMixin):
|
||||
"""
|
||||
SSL connection closing tests.
|
||||
"""
|
||||
|
||||
if interfaces.IReactorSSL(reactor, None) is None:
|
||||
skip = "Reactor does not support SSL, cannot run SSL tests"
|
||||
|
||||
def testImmediateDisconnect(self):
|
||||
org = "twisted.test.test_ssl"
|
||||
self.setupServerAndClient(
|
||||
(org, org + ", client"), {}, (org, org + ", server"), {}
|
||||
)
|
||||
|
||||
# Set up a server, connect to it with a client, which should work since our verifiers
|
||||
# allow anything, then disconnect.
|
||||
serverProtocolFactory = protocol.ServerFactory()
|
||||
serverProtocolFactory.protocol = protocol.Protocol
|
||||
self.serverPort = serverPort = reactor.listenSSL(
|
||||
0, serverProtocolFactory, self.serverCtxFactory
|
||||
)
|
||||
|
||||
clientProtocolFactory = protocol.ClientFactory()
|
||||
clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol
|
||||
clientProtocolFactory.connectionDisconnected = defer.Deferred()
|
||||
reactor.connectSSL(
|
||||
"127.0.0.1",
|
||||
serverPort.getHost().port,
|
||||
clientProtocolFactory,
|
||||
self.clientCtxFactory,
|
||||
)
|
||||
|
||||
return clientProtocolFactory.connectionDisconnected.addCallback(
|
||||
lambda ignoredResult: self.serverPort.stopListening()
|
||||
)
|
||||
|
||||
def test_bothSidesLoseConnection(self):
|
||||
"""
|
||||
Both sides of SSL connection close connection; the connections should
|
||||
close cleanly, and only after the underlying TCP connection has
|
||||
disconnected.
|
||||
"""
|
||||
|
||||
@implementer(interfaces.IHandshakeListener)
|
||||
class CloseAfterHandshake(protocol.Protocol):
|
||||
gotData = False
|
||||
|
||||
def __init__(self):
|
||||
self.done = defer.Deferred()
|
||||
|
||||
def handshakeCompleted(self):
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.done.errback(reason)
|
||||
del self.done
|
||||
|
||||
org = "twisted.test.test_ssl"
|
||||
self.setupServerAndClient(
|
||||
(org, org + ", client"), {}, (org, org + ", server"), {}
|
||||
)
|
||||
|
||||
serverProtocol = CloseAfterHandshake()
|
||||
serverProtocolFactory = protocol.ServerFactory()
|
||||
serverProtocolFactory.protocol = lambda: serverProtocol
|
||||
serverPort = reactor.listenSSL(0, serverProtocolFactory, self.serverCtxFactory)
|
||||
self.addCleanup(serverPort.stopListening)
|
||||
|
||||
clientProtocol = CloseAfterHandshake()
|
||||
clientProtocolFactory = protocol.ClientFactory()
|
||||
clientProtocolFactory.protocol = lambda: clientProtocol
|
||||
reactor.connectSSL(
|
||||
"127.0.0.1",
|
||||
serverPort.getHost().port,
|
||||
clientProtocolFactory,
|
||||
self.clientCtxFactory,
|
||||
)
|
||||
|
||||
def checkResult(failure):
|
||||
failure.trap(ConnectionDone)
|
||||
|
||||
return defer.gatherResults(
|
||||
[
|
||||
clientProtocol.done.addErrback(checkResult),
|
||||
serverProtocol.done.addErrback(checkResult),
|
||||
]
|
||||
)
|
||||
|
||||
def testFailedVerify(self):
|
||||
org = "twisted.test.test_ssl"
|
||||
self.setupServerAndClient(
|
||||
(org, org + ", client"), {}, (org, org + ", server"), {}
|
||||
)
|
||||
|
||||
def verify(*a):
|
||||
return False
|
||||
|
||||
self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify)
|
||||
|
||||
serverConnLost = defer.Deferred()
|
||||
serverProtocol = protocol.Protocol()
|
||||
serverProtocol.connectionLost = serverConnLost.callback
|
||||
serverProtocolFactory = protocol.ServerFactory()
|
||||
serverProtocolFactory.protocol = lambda: serverProtocol
|
||||
self.serverPort = serverPort = reactor.listenSSL(
|
||||
0, serverProtocolFactory, self.serverCtxFactory
|
||||
)
|
||||
|
||||
clientConnLost = defer.Deferred()
|
||||
clientProtocol = protocol.Protocol()
|
||||
clientProtocol.connectionLost = clientConnLost.callback
|
||||
clientProtocolFactory = protocol.ClientFactory()
|
||||
clientProtocolFactory.protocol = lambda: clientProtocol
|
||||
reactor.connectSSL(
|
||||
"127.0.0.1",
|
||||
serverPort.getHost().port,
|
||||
clientProtocolFactory,
|
||||
self.clientCtxFactory,
|
||||
)
|
||||
|
||||
dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=True)
|
||||
return dl.addCallback(self._cbLostConns)
|
||||
|
||||
def _cbLostConns(self, results):
|
||||
(sSuccess, sResult), (cSuccess, cResult) = results
|
||||
|
||||
self.assertFalse(sSuccess)
|
||||
self.assertFalse(cSuccess)
|
||||
|
||||
acceptableErrors = [SSL.Error]
|
||||
|
||||
# Rather than getting a verification failure on Windows, we are getting
|
||||
# a connection failure. Without something like sslverify proxying
|
||||
# in-between we can't fix up the platform's errors, so let's just
|
||||
# specifically say it is only OK in this one case to keep the tests
|
||||
# passing. Normally we'd like to be as strict as possible here, so
|
||||
# we're not going to allow this to report errors incorrectly on any
|
||||
# other platforms.
|
||||
|
||||
if platform.isWindows():
|
||||
from twisted.internet.error import ConnectionLost
|
||||
|
||||
acceptableErrors.append(ConnectionLost)
|
||||
|
||||
sResult.trap(*acceptableErrors)
|
||||
cResult.trap(*acceptableErrors)
|
||||
|
||||
return self.serverPort.stopListening()
|
||||
|
||||
|
||||
class FakeContext:
|
||||
"""
|
||||
L{OpenSSL.SSL.Context} double which can more easily be inspected.
|
||||
"""
|
||||
|
||||
def __init__(self, method):
|
||||
self._method = method
|
||||
self._options = 0
|
||||
|
||||
def set_options(self, options):
|
||||
self._options |= options
|
||||
|
||||
def use_certificate_file(self, fileName):
|
||||
pass
|
||||
|
||||
def use_privatekey_file(self, fileName):
|
||||
pass
|
||||
|
||||
|
||||
class DefaultOpenSSLContextFactoryTests(TestCase):
|
||||
"""
|
||||
Tests for L{ssl.DefaultOpenSSLContextFactory}.
|
||||
"""
|
||||
|
||||
if interfaces.IReactorSSL(reactor, None) is None:
|
||||
skip = "Reactor does not support SSL, cannot run SSL tests"
|
||||
|
||||
def setUp(self):
|
||||
# pyOpenSSL Context objects aren't introspectable enough. Pass in
|
||||
# an alternate context factory so we can inspect what is done to it.
|
||||
self.contextFactory = ssl.DefaultOpenSSLContextFactory(
|
||||
certPath, certPath, _contextFactory=FakeContext
|
||||
)
|
||||
self.context = self.contextFactory.getContext()
|
||||
|
||||
def test_method(self):
|
||||
"""
|
||||
L{ssl.DefaultOpenSSLContextFactory.getContext} returns an SSL context
|
||||
which can use SSLv3 or TLSv1 but not SSLv2.
|
||||
"""
|
||||
# TLS_METHOD allows for negotiating multiple versions of TLS
|
||||
self.assertEqual(self.context._method, SSL.TLS_METHOD)
|
||||
|
||||
# OP_NO_SSLv2 disables SSLv2 support
|
||||
self.assertEqual(self.context._options & SSL.OP_NO_SSLv2, SSL.OP_NO_SSLv2)
|
||||
|
||||
# Make sure TLSv1.2 isn't disabled though.
|
||||
self.assertFalse(self.context._options & SSL.OP_NO_TLSv1_2)
|
||||
|
||||
def test_missingCertificateFile(self):
|
||||
"""
|
||||
Instantiating L{ssl.DefaultOpenSSLContextFactory} with a certificate
|
||||
filename which does not identify an existing file results in the
|
||||
initializer raising L{OpenSSL.SSL.Error}.
|
||||
"""
|
||||
self.assertRaises(
|
||||
SSL.Error, ssl.DefaultOpenSSLContextFactory, certPath, self.mktemp()
|
||||
)
|
||||
|
||||
def test_missingPrivateKeyFile(self):
|
||||
"""
|
||||
Instantiating L{ssl.DefaultOpenSSLContextFactory} with a private key
|
||||
filename which does not identify an existing file results in the
|
||||
initializer raising L{OpenSSL.SSL.Error}.
|
||||
"""
|
||||
self.assertRaises(
|
||||
SSL.Error, ssl.DefaultOpenSSLContextFactory, self.mktemp(), certPath
|
||||
)
|
||||
|
||||
|
||||
class ClientContextFactoryTests(TestCase):
|
||||
"""
|
||||
Tests for L{ssl.ClientContextFactory}.
|
||||
"""
|
||||
|
||||
if interfaces.IReactorSSL(reactor, None) is None:
|
||||
skip = "Reactor does not support SSL, cannot run SSL tests"
|
||||
|
||||
def setUp(self):
|
||||
self.contextFactory = ssl.ClientContextFactory()
|
||||
self.contextFactory._contextFactory = FakeContext
|
||||
self.context = self.contextFactory.getContext()
|
||||
|
||||
def test_method(self):
|
||||
"""
|
||||
L{ssl.ClientContextFactory.getContext} returns a context which can use
|
||||
TLSv1.2 or 1.3 but nothing earlier.
|
||||
"""
|
||||
self.assertEqual(self.context._method, SSL.TLS_METHOD)
|
||||
self.assertEqual(self.context._options & SSL.OP_NO_SSLv2, SSL.OP_NO_SSLv2)
|
||||
self.assertTrue(self.context._options & SSL.OP_NO_SSLv3)
|
||||
self.assertTrue(self.context._options & SSL.OP_NO_TLSv1)
|
||||
3430
.venv/lib/python3.12/site-packages/twisted/test/test_sslverify.py
Normal file
3430
.venv/lib/python3.12/site-packages/twisted/test/test_sslverify.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Test cases for twisted.protocols.stateful
|
||||
"""
|
||||
|
||||
from struct import calcsize, pack, unpack
|
||||
|
||||
from twisted.protocols.stateful import StatefulProtocol
|
||||
from twisted.protocols.test import test_basic
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class MyInt32StringReceiver(StatefulProtocol):
|
||||
"""
|
||||
A stateful Int32StringReceiver.
|
||||
"""
|
||||
|
||||
MAX_LENGTH = 99999
|
||||
structFormat = "!I"
|
||||
prefixLength = calcsize(structFormat)
|
||||
|
||||
def getInitialState(self):
|
||||
return self._getHeader, 4
|
||||
|
||||
def lengthLimitExceeded(self, length):
|
||||
self.transport.loseConnection()
|
||||
|
||||
def _getHeader(self, msg):
|
||||
(length,) = unpack("!i", msg)
|
||||
if length > self.MAX_LENGTH:
|
||||
self.lengthLimitExceeded(length)
|
||||
return
|
||||
return self._getString, length
|
||||
|
||||
def _getString(self, msg):
|
||||
self.stringReceived(msg)
|
||||
return self._getHeader, 4
|
||||
|
||||
def stringReceived(self, msg):
|
||||
"""
|
||||
Override this.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sendString(self, data):
|
||||
"""
|
||||
Send an int32-prefixed string to the other end of the connection.
|
||||
"""
|
||||
self.transport.write(pack(self.structFormat, len(data)) + data)
|
||||
|
||||
|
||||
class TestInt32(MyInt32StringReceiver):
|
||||
def connectionMade(self):
|
||||
self.received = []
|
||||
|
||||
def stringReceived(self, s):
|
||||
self.received.append(s)
|
||||
|
||||
MAX_LENGTH = 50
|
||||
closed = 0
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.closed = 1
|
||||
|
||||
|
||||
class Int32Tests(TestCase, test_basic.IntNTestCaseMixin):
|
||||
protocol = TestInt32
|
||||
strings = [b"a", b"b" * 16]
|
||||
illegalStrings = [b"\x10\x00\x00\x00aaaaaa"]
|
||||
partialStrings = [b"\x00\x00\x00", b"hello there", b""]
|
||||
|
||||
def test_bigReceive(self):
|
||||
r = self.getProtocol()
|
||||
big = b""
|
||||
for s in self.strings * 4:
|
||||
big += pack("!i", len(s)) + s
|
||||
r.dataReceived(big)
|
||||
self.assertEqual(r.received, self.strings * 4)
|
||||
406
.venv/lib/python3.12/site-packages/twisted/test/test_stdio.py
Normal file
406
.venv/lib/python3.12/site-packages/twisted/test/test_stdio.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.internet.stdio}.
|
||||
|
||||
@var properEnv: A copy of L{os.environ} which has L{bytes} keys/values on POSIX
|
||||
platforms and native L{str} keys/values on Windows.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Callable
|
||||
from unittest import skipIf
|
||||
|
||||
from twisted.internet import defer, error, protocol, reactor, stdio
|
||||
from twisted.internet.interfaces import IProcessTransport, IReactorProcess
|
||||
from twisted.internet.protocol import ProcessProtocol
|
||||
from twisted.python import filepath, log
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.python.reflect import requireModule
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.test.test_tcp import ConnectionLostNotifyingProtocol
|
||||
from twisted.trial.unittest import SkipTest, TestCase
|
||||
|
||||
# A short string which is intended to appear here and nowhere else,
|
||||
# particularly not in any random garbage output CPython unavoidable
|
||||
# generates (such as in warning text and so forth). This is searched
|
||||
# for in the output from stdio_test_lastwrite and if it is found at
|
||||
# the end, the functionality works.
|
||||
UNIQUE_LAST_WRITE_STRING = b"xyz123abc Twisted is great!"
|
||||
|
||||
properEnv = dict(os.environ)
|
||||
properEnv["PYTHONPATH"] = os.pathsep.join(sys.path)
|
||||
|
||||
|
||||
class StandardIOTestProcessProtocol(protocol.ProcessProtocol):
|
||||
"""
|
||||
Test helper for collecting output from a child process and notifying
|
||||
something when it exits.
|
||||
|
||||
@ivar onConnection: A L{defer.Deferred} which will be called back with
|
||||
L{None} when the connection to the child process is established.
|
||||
|
||||
@ivar onCompletion: A L{defer.Deferred} which will be errbacked with the
|
||||
failure associated with the child process exiting when it exits.
|
||||
|
||||
@ivar onDataReceived: A L{defer.Deferred} which will be called back with
|
||||
this instance whenever C{childDataReceived} is called, or L{None} to
|
||||
suppress these callbacks.
|
||||
|
||||
@ivar data: A C{dict} mapping file descriptors to strings containing all
|
||||
bytes received from the child process on each file descriptor.
|
||||
"""
|
||||
|
||||
onDataReceived: defer.Deferred[None] | None = None
|
||||
transport: IProcessTransport
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.onConnection: defer.Deferred[None] = defer.Deferred()
|
||||
self.onCompletion: defer.Deferred[None] = defer.Deferred()
|
||||
self.data: dict[str, bytes] = {}
|
||||
|
||||
def connectionMade(self):
|
||||
self.onConnection.callback(None)
|
||||
|
||||
def childDataReceived(self, name, bytes):
|
||||
"""
|
||||
Record all bytes received from the child process in the C{data}
|
||||
dictionary. Fire C{onDataReceived} if it is not L{None}.
|
||||
"""
|
||||
self.data[name] = self.data.get(name, b"") + bytes
|
||||
if self.onDataReceived is not None:
|
||||
d, self.onDataReceived = self.onDataReceived, None
|
||||
d.callback(self)
|
||||
|
||||
def processEnded(self, reason):
|
||||
self.onCompletion.callback(reason)
|
||||
|
||||
|
||||
class StandardInputOutputTests(TestCase):
|
||||
if platform.isWindows() and requireModule("win32process") is None:
|
||||
skip = (
|
||||
"On windows, spawnProcess is not available in the "
|
||||
"absence of win32process."
|
||||
)
|
||||
|
||||
def _spawnProcess(
|
||||
self, proto: ProcessProtocol, sibling: str | bytes, *args: str, **kw: Any
|
||||
) -> IProcessTransport:
|
||||
"""
|
||||
Launch a child Python process and communicate with it using the given
|
||||
ProcessProtocol.
|
||||
|
||||
@param proto: A L{ProcessProtocol} instance which will be connected to
|
||||
the child process.
|
||||
|
||||
@param sibling: The basename of a file containing the Python program to
|
||||
run in the child process.
|
||||
|
||||
@param *args: strings which will be passed to the child process on the
|
||||
command line as C{argv[2:]}.
|
||||
|
||||
@param **kw: additional arguments to pass to L{reactor.spawnProcess}.
|
||||
|
||||
@return: The L{IProcessTransport} provider for the spawned process.
|
||||
"""
|
||||
if isinstance(sibling, bytes):
|
||||
sibling = sibling.decode()
|
||||
procargs = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"twisted.test." + sibling,
|
||||
reactor.__class__.__module__,
|
||||
] + list(args)
|
||||
return IReactorProcess(reactor).spawnProcess(
|
||||
proto, sys.executable, procargs, env=properEnv, **kw
|
||||
)
|
||||
|
||||
def _requireFailure(
|
||||
self, d: defer.Deferred[None], callback: Callable[[Failure], object]
|
||||
) -> defer.Deferred[None]:
|
||||
def cb(result):
|
||||
self.fail(f"Process terminated with non-Failure: {result!r}")
|
||||
|
||||
def eb(err):
|
||||
return callback(err)
|
||||
|
||||
return d.addCallbacks(cb, eb)
|
||||
|
||||
def test_loseConnection(self):
|
||||
"""
|
||||
Verify that a protocol connected to L{StandardIO} can disconnect
|
||||
itself using C{transport.loseConnection}.
|
||||
"""
|
||||
errorLogFile = self.mktemp()
|
||||
log.msg("Child process logging to " + errorLogFile)
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
self._spawnProcess(p, b"stdio_test_loseconn", errorLogFile)
|
||||
|
||||
def processEnded(reason):
|
||||
# Copy the child's log to ours so it's more visible.
|
||||
with open(errorLogFile) as f:
|
||||
for line in f:
|
||||
log.msg("Child logged: " + line.rstrip())
|
||||
|
||||
self.failIfIn(1, p.data)
|
||||
reason.trap(error.ProcessDone)
|
||||
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
def exampleOutputsAndZeroExitCode(
|
||||
self, example: str, out: bool = False
|
||||
) -> defer.Deferred[None]:
|
||||
errorLogFile = self.mktemp()
|
||||
p = StandardIOTestProcessProtocol()
|
||||
p.onDataReceived = defer.Deferred()
|
||||
|
||||
def cbBytes(ignored: None) -> defer.Deferred[None]:
|
||||
d = p.onCompletion
|
||||
if out:
|
||||
p.transport.closeStdout()
|
||||
else:
|
||||
p.transport.closeStdin()
|
||||
return d
|
||||
|
||||
p.onDataReceived.addCallback(cbBytes)
|
||||
|
||||
def processEnded(reason):
|
||||
reason.trap(error.ProcessDone)
|
||||
|
||||
d = self._requireFailure(p.onDataReceived, processEnded)
|
||||
|
||||
self._spawnProcess(p, example, errorLogFile)
|
||||
return d
|
||||
|
||||
def test_readConnectionLost(self) -> defer.Deferred[None]:
|
||||
"""
|
||||
When stdin is closed and the protocol connected to it implements
|
||||
L{IHalfCloseableProtocol}, the protocol's C{readConnectionLost} method
|
||||
is called.
|
||||
"""
|
||||
return self.exampleOutputsAndZeroExitCode("stdio_test_halfclose")
|
||||
|
||||
def test_buggyReadConnectionLost(self) -> defer.Deferred[None]:
|
||||
"""
|
||||
When stdin is closed and the protocol connnected to it implements
|
||||
L{IHalfCloseableProtocol} but its C{readConnectionLost} method raises
|
||||
an exception its regular C{connectionLost} method will be called.
|
||||
"""
|
||||
return self.exampleOutputsAndZeroExitCode("stdio_test_halfclose_buggy")
|
||||
|
||||
def test_buggyWriteConnectionLost(self) -> defer.Deferred[None]:
|
||||
"""
|
||||
When stdin is closed and the protocol connnected to it implements
|
||||
L{IHalfCloseableProtocol} but its C{readConnectionLost} method raises
|
||||
an exception its regular C{connectionLost} method will be called.
|
||||
"""
|
||||
return self.exampleOutputsAndZeroExitCode(
|
||||
"stdio_test_halfclose_buggy_write", out=True
|
||||
)
|
||||
|
||||
def test_lastWriteReceived(self):
|
||||
"""
|
||||
Verify that a write made directly to stdout using L{os.write}
|
||||
after StandardIO has finished is reliably received by the
|
||||
process reading that stdout.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
|
||||
# Note: the macOS bug which prompted the addition of this test
|
||||
# is an apparent race condition involving non-blocking PTYs.
|
||||
# Delaying the parent process significantly increases the
|
||||
# likelihood of the race going the wrong way. If you need to
|
||||
# fiddle with this code at all, uncommenting the next line
|
||||
# will likely make your life much easier. It is commented out
|
||||
# because it makes the test quite slow.
|
||||
|
||||
# p.onConnection.addCallback(lambda ign: __import__('time').sleep(5))
|
||||
|
||||
try:
|
||||
self._spawnProcess(
|
||||
p, b"stdio_test_lastwrite", UNIQUE_LAST_WRITE_STRING, usePTY=True
|
||||
)
|
||||
except ValueError as e:
|
||||
# Some platforms don't work with usePTY=True
|
||||
raise SkipTest(str(e))
|
||||
|
||||
def processEnded(reason):
|
||||
"""
|
||||
Asserts that the parent received the bytes written by the child
|
||||
immediately after the child starts.
|
||||
"""
|
||||
self.assertTrue(
|
||||
p.data[1].endswith(UNIQUE_LAST_WRITE_STRING),
|
||||
f"Received {p.data!r} from child, did not find expected bytes.",
|
||||
)
|
||||
reason.trap(error.ProcessDone)
|
||||
|
||||
return self._requireFailure(p.onCompletion, processEnded)
|
||||
|
||||
def test_hostAndPeer(self):
|
||||
"""
|
||||
Verify that the transport of a protocol connected to L{StandardIO}
|
||||
has C{getHost} and C{getPeer} methods.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
self._spawnProcess(p, b"stdio_test_hostpeer")
|
||||
|
||||
def processEnded(reason):
|
||||
host, peer = p.data[1].splitlines()
|
||||
self.assertTrue(host)
|
||||
self.assertTrue(peer)
|
||||
reason.trap(error.ProcessDone)
|
||||
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
def test_write(self):
|
||||
"""
|
||||
Verify that the C{write} method of the transport of a protocol
|
||||
connected to L{StandardIO} sends bytes to standard out.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
|
||||
self._spawnProcess(p, b"stdio_test_write")
|
||||
|
||||
def processEnded(reason):
|
||||
self.assertEqual(p.data[1], b"ok!")
|
||||
reason.trap(error.ProcessDone)
|
||||
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
def test_writeSequence(self):
|
||||
"""
|
||||
Verify that the C{writeSequence} method of the transport of a
|
||||
protocol connected to L{StandardIO} sends bytes to standard out.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
|
||||
self._spawnProcess(p, b"stdio_test_writeseq")
|
||||
|
||||
def processEnded(reason):
|
||||
self.assertEqual(p.data[1], b"ok!")
|
||||
reason.trap(error.ProcessDone)
|
||||
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
def _junkPath(self):
|
||||
junkPath = self.mktemp()
|
||||
with open(junkPath, "wb") as junkFile:
|
||||
for i in range(1024):
|
||||
junkFile.write(b"%d\n" % (i,))
|
||||
return junkPath
|
||||
|
||||
def test_producer(self):
|
||||
"""
|
||||
Verify that the transport of a protocol connected to L{StandardIO}
|
||||
is a working L{IProducer} provider.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
|
||||
written = []
|
||||
toWrite = list(range(100))
|
||||
|
||||
def connectionMade(ign):
|
||||
if toWrite:
|
||||
written.append(b"%d\n" % (toWrite.pop(),))
|
||||
proc.write(written[-1])
|
||||
reactor.callLater(0.01, connectionMade, None)
|
||||
|
||||
proc = self._spawnProcess(p, b"stdio_test_producer")
|
||||
|
||||
p.onConnection.addCallback(connectionMade)
|
||||
|
||||
def processEnded(reason):
|
||||
self.assertEqual(p.data[1], b"".join(written))
|
||||
self.assertFalse(
|
||||
toWrite, "Connection lost with %d writes left to go." % (len(toWrite),)
|
||||
)
|
||||
reason.trap(error.ProcessDone)
|
||||
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
def test_consumer(self):
|
||||
"""
|
||||
Verify that the transport of a protocol connected to L{StandardIO}
|
||||
is a working L{IConsumer} provider.
|
||||
"""
|
||||
p = StandardIOTestProcessProtocol()
|
||||
d = p.onCompletion
|
||||
|
||||
junkPath = self._junkPath()
|
||||
|
||||
self._spawnProcess(p, b"stdio_test_consumer", junkPath)
|
||||
|
||||
def processEnded(reason):
|
||||
with open(junkPath, "rb") as f:
|
||||
self.assertEqual(p.data[1], f.read())
|
||||
reason.trap(error.ProcessDone)
|
||||
|
||||
return self._requireFailure(d, processEnded)
|
||||
|
||||
@skipIf(
|
||||
platform.isWindows(),
|
||||
"StandardIO does not accept stdout as an argument to Windows. "
|
||||
"Testing redirection to a file is therefore harder.",
|
||||
)
|
||||
def test_normalFileStandardOut(self):
|
||||
"""
|
||||
If L{StandardIO} is created with a file descriptor which refers to a
|
||||
normal file (ie, a file from the filesystem), L{StandardIO.write}
|
||||
writes bytes to that file. In particular, it does not immediately
|
||||
consider the file closed or call its protocol's C{connectionLost}
|
||||
method.
|
||||
"""
|
||||
onConnLost = defer.Deferred()
|
||||
proto = ConnectionLostNotifyingProtocol(onConnLost)
|
||||
path = filepath.FilePath(self.mktemp())
|
||||
self.normal = normal = path.open("wb")
|
||||
self.addCleanup(normal.close)
|
||||
|
||||
kwargs = dict(stdout=normal.fileno())
|
||||
if not platform.isWindows():
|
||||
# Make a fake stdin so that StandardIO doesn't mess with the *real*
|
||||
# stdin.
|
||||
r, w = os.pipe()
|
||||
self.addCleanup(os.close, r)
|
||||
self.addCleanup(os.close, w)
|
||||
kwargs["stdin"] = r
|
||||
connection = stdio.StandardIO(proto, **kwargs)
|
||||
|
||||
# The reactor needs to spin a bit before it might have incorrectly
|
||||
# decided stdout is closed. Use this counter to keep track of how
|
||||
# much we've let it spin. If it closes before we expected, this
|
||||
# counter will have a value that's too small and we'll know.
|
||||
howMany = 5
|
||||
count = itertools.count()
|
||||
|
||||
def spin():
|
||||
for value in count:
|
||||
if value == howMany:
|
||||
connection.loseConnection()
|
||||
return
|
||||
connection.write(b"%d" % (value,))
|
||||
break
|
||||
reactor.callLater(0, spin)
|
||||
|
||||
reactor.callLater(0, spin)
|
||||
|
||||
# Once the connection is lost, make sure the counter is at the
|
||||
# appropriate value.
|
||||
def cbLost(reason):
|
||||
self.assertEqual(next(count), howMany + 1)
|
||||
self.assertEqual(
|
||||
path.getContent(), b"".join(b"%d" % (i,) for i in range(howMany))
|
||||
)
|
||||
|
||||
onConnLost.addCallback(cbLost)
|
||||
return onConnLost
|
||||
157
.venv/lib/python3.12/site-packages/twisted/test/test_strerror.py
Normal file
157
.venv/lib/python3.12/site-packages/twisted/test/test_strerror.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test strerror
|
||||
"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
from unittest import skipIf
|
||||
|
||||
from twisted.internet.tcp import ECONNABORTED
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.python.win32 import _ErrorFormatter, formatError
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class _MyWindowsException(OSError):
|
||||
"""
|
||||
An exception type like L{ctypes.WinError}, but available on all platforms.
|
||||
"""
|
||||
|
||||
|
||||
class ErrorFormatingTests(TestCase):
|
||||
"""
|
||||
Tests for C{_ErrorFormatter.formatError}.
|
||||
"""
|
||||
|
||||
probeErrorCode = ECONNABORTED
|
||||
probeMessage = "correct message value"
|
||||
|
||||
def test_strerrorFormatting(self):
|
||||
"""
|
||||
L{_ErrorFormatter.formatError} should use L{os.strerror} to format
|
||||
error messages if it is constructed without any better mechanism.
|
||||
"""
|
||||
formatter = _ErrorFormatter(None, None, None)
|
||||
message = formatter.formatError(self.probeErrorCode)
|
||||
self.assertEqual(message, os.strerror(self.probeErrorCode))
|
||||
|
||||
def test_emptyErrorTab(self):
|
||||
"""
|
||||
L{_ErrorFormatter.formatError} should use L{os.strerror} to format
|
||||
error messages if it is constructed with only an error tab which does
|
||||
not contain the error code it is called with.
|
||||
"""
|
||||
error = 1
|
||||
# Sanity check
|
||||
self.assertNotEqual(self.probeErrorCode, error)
|
||||
formatter = _ErrorFormatter(None, None, {error: "wrong message"})
|
||||
message = formatter.formatError(self.probeErrorCode)
|
||||
self.assertEqual(message, os.strerror(self.probeErrorCode))
|
||||
|
||||
def test_errorTab(self):
|
||||
"""
|
||||
L{_ErrorFormatter.formatError} should use C{errorTab} if it is supplied
|
||||
and contains the requested error code.
|
||||
"""
|
||||
formatter = _ErrorFormatter(
|
||||
None, None, {self.probeErrorCode: self.probeMessage}
|
||||
)
|
||||
message = formatter.formatError(self.probeErrorCode)
|
||||
self.assertEqual(message, self.probeMessage)
|
||||
|
||||
def test_formatMessage(self):
|
||||
"""
|
||||
L{_ErrorFormatter.formatError} should return the return value of
|
||||
C{formatMessage} if it is supplied.
|
||||
"""
|
||||
formatCalls = []
|
||||
|
||||
def formatMessage(errorCode):
|
||||
formatCalls.append(errorCode)
|
||||
return self.probeMessage
|
||||
|
||||
formatter = _ErrorFormatter(
|
||||
None, formatMessage, {self.probeErrorCode: "wrong message"}
|
||||
)
|
||||
message = formatter.formatError(self.probeErrorCode)
|
||||
self.assertEqual(message, self.probeMessage)
|
||||
self.assertEqual(formatCalls, [self.probeErrorCode])
|
||||
|
||||
def test_winError(self):
|
||||
"""
|
||||
L{_ErrorFormatter.formatError} should return the message argument from
|
||||
the exception L{winError} returns, if L{winError} is supplied.
|
||||
"""
|
||||
winCalls = []
|
||||
|
||||
def winError(errorCode):
|
||||
winCalls.append(errorCode)
|
||||
return _MyWindowsException(errorCode, self.probeMessage)
|
||||
|
||||
formatter = _ErrorFormatter(
|
||||
winError,
|
||||
lambda error: "formatMessage: wrong message",
|
||||
{self.probeErrorCode: "errorTab: wrong message"},
|
||||
)
|
||||
message = formatter.formatError(self.probeErrorCode)
|
||||
self.assertEqual(message, self.probeMessage)
|
||||
|
||||
@skipIf(platform.getType() != "win32", "Test will run only on Windows.")
|
||||
def test_fromEnvironment(self):
|
||||
"""
|
||||
L{_ErrorFormatter.fromEnvironment} should create an L{_ErrorFormatter}
|
||||
instance with attributes populated from available modules.
|
||||
"""
|
||||
formatter = _ErrorFormatter.fromEnvironment()
|
||||
|
||||
if formatter.winError is not None:
|
||||
from ctypes import WinError
|
||||
|
||||
self.assertEqual(
|
||||
formatter.formatError(self.probeErrorCode),
|
||||
WinError(self.probeErrorCode).strerror,
|
||||
)
|
||||
formatter.winError = None
|
||||
|
||||
if formatter.formatMessage is not None:
|
||||
from win32api import FormatMessage
|
||||
|
||||
self.assertEqual(
|
||||
formatter.formatError(self.probeErrorCode),
|
||||
FormatMessage(self.probeErrorCode),
|
||||
)
|
||||
formatter.formatMessage = None
|
||||
|
||||
if formatter.errorTab is not None:
|
||||
from socket import errorTab
|
||||
|
||||
self.assertEqual(
|
||||
formatter.formatError(self.probeErrorCode),
|
||||
errorTab[self.probeErrorCode],
|
||||
)
|
||||
|
||||
@skipIf(platform.getType() != "win32", "Test will run only on Windows.")
|
||||
def test_correctLookups(self):
|
||||
"""
|
||||
Given a known-good errno, make sure that formatMessage gives results
|
||||
matching either C{socket.errorTab}, C{ctypes.WinError}, or
|
||||
C{win32api.FormatMessage}.
|
||||
"""
|
||||
acceptable = [socket.errorTab[ECONNABORTED]]
|
||||
try:
|
||||
from ctypes import WinError
|
||||
|
||||
acceptable.append(WinError(ECONNABORTED).strerror)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from win32api import FormatMessage
|
||||
|
||||
acceptable.append(FormatMessage(ECONNABORTED))
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
self.assertIn(formatError(ECONNABORTED), acceptable)
|
||||
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.strports}.
|
||||
"""
|
||||
|
||||
|
||||
from twisted.application import internet, strports
|
||||
from twisted.internet.endpoints import TCP4ServerEndpoint
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class ServiceTests(TestCase):
|
||||
"""
|
||||
Tests for L{strports.service}.
|
||||
"""
|
||||
|
||||
def test_service(self):
|
||||
"""
|
||||
L{strports.service} returns a L{StreamServerEndpointService}
|
||||
constructed with an endpoint produced from
|
||||
L{endpoint.serverFromString}, using the same syntax.
|
||||
"""
|
||||
reactor = object() # the cake is a lie
|
||||
aFactory = Factory()
|
||||
aGoodPort = 1337
|
||||
svc = strports.service("tcp:" + str(aGoodPort), aFactory, reactor=reactor)
|
||||
self.assertIsInstance(svc, internet.StreamServerEndpointService)
|
||||
|
||||
# See twisted.application.test.test_internet.EndpointServiceTests.
|
||||
# test_synchronousRaiseRaisesSynchronously
|
||||
self.assertTrue(svc._raiseSynchronously)
|
||||
self.assertIsInstance(svc.endpoint, TCP4ServerEndpoint)
|
||||
# Maybe we should implement equality for endpoints.
|
||||
self.assertEqual(svc.endpoint._port, aGoodPort)
|
||||
self.assertIs(svc.factory, aFactory)
|
||||
self.assertIs(svc.endpoint._reactor, reactor)
|
||||
|
||||
def test_serviceDefaultReactor(self):
|
||||
"""
|
||||
L{strports.service} will use the default reactor when none is provided
|
||||
as an argument.
|
||||
"""
|
||||
from twisted.internet import reactor as globalReactor
|
||||
|
||||
aService = strports.service("tcp:80", None)
|
||||
self.assertIs(aService.endpoint._reactor, globalReactor)
|
||||
1467
.venv/lib/python3.12/site-packages/twisted/test/test_task.py
Normal file
1467
.venv/lib/python3.12/site-packages/twisted/test/test_task.py
Normal file
File diff suppressed because it is too large
Load Diff
1870
.venv/lib/python3.12/site-packages/twisted/test/test_tcp.py
Normal file
1870
.venv/lib/python3.12/site-packages/twisted/test/test_tcp.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,380 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Whitebox tests for TCP APIs.
|
||||
"""
|
||||
|
||||
|
||||
import errno
|
||||
import os
|
||||
import socket
|
||||
|
||||
try:
|
||||
import resource
|
||||
except ImportError:
|
||||
resource = None # type: ignore[assignment]
|
||||
|
||||
from unittest import skipIf
|
||||
|
||||
from twisted.internet import interfaces, reactor
|
||||
from twisted.internet.defer import gatherResults, maybeDeferred
|
||||
from twisted.internet.protocol import Protocol, ServerFactory
|
||||
from twisted.internet.tcp import (
|
||||
_ACCEPT_ERRORS,
|
||||
EAGAIN,
|
||||
ECONNABORTED,
|
||||
EINPROGRESS,
|
||||
EMFILE,
|
||||
ENFILE,
|
||||
ENOBUFS,
|
||||
ENOMEM,
|
||||
EPERM,
|
||||
EWOULDBLOCK,
|
||||
Port,
|
||||
)
|
||||
from twisted.python import log
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
@skipIf(
|
||||
not interfaces.IReactorFDSet.providedBy(reactor),
|
||||
"This test only applies to reactors that implement IReactorFDset",
|
||||
)
|
||||
class PlatformAssumptionsTests(TestCase):
|
||||
"""
|
||||
Test assumptions about platform behaviors.
|
||||
"""
|
||||
|
||||
socketLimit = 8192
|
||||
|
||||
def setUp(self):
|
||||
self.openSockets = []
|
||||
if resource is not None:
|
||||
# On some buggy platforms we might leak FDs, and the test will
|
||||
# fail creating the initial two sockets we *do* want to
|
||||
# succeed. So, we make the soft limit the current number of fds
|
||||
# plus two more (for the two sockets we want to succeed). If we've
|
||||
# leaked too many fds for that to work, there's nothing we can
|
||||
# do.
|
||||
from twisted.internet.process import _listOpenFDs
|
||||
|
||||
newLimit = len(_listOpenFDs()) + 2
|
||||
self.originalFileLimit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(
|
||||
resource.RLIMIT_NOFILE, (newLimit, self.originalFileLimit[1])
|
||||
)
|
||||
self.socketLimit = newLimit + 100
|
||||
|
||||
def tearDown(self):
|
||||
while self.openSockets:
|
||||
self.openSockets.pop().close()
|
||||
if resource is not None:
|
||||
# `macOS` implicitly lowers the hard limit in the setrlimit call
|
||||
# above. Retrieve the new hard limit to pass in to this
|
||||
# setrlimit call, so that it doesn't give us a permission denied
|
||||
# error.
|
||||
currentHardLimit = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
|
||||
newSoftLimit = min(self.originalFileLimit[0], currentHardLimit)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (newSoftLimit, currentHardLimit))
|
||||
|
||||
def socket(self):
|
||||
"""
|
||||
Create and return a new socket object, also tracking it so it can be
|
||||
closed in the test tear down.
|
||||
"""
|
||||
s = socket.socket()
|
||||
self.openSockets.append(s)
|
||||
return s
|
||||
|
||||
@skipIf(
|
||||
platform.getType() == "win32",
|
||||
"Windows requires an unacceptably large amount of resources to "
|
||||
"provoke this behavior in the naive manner.",
|
||||
)
|
||||
def test_acceptOutOfFiles(self):
|
||||
"""
|
||||
Test that the platform accept(2) call fails with either L{EMFILE} or
|
||||
L{ENOBUFS} when there are too many file descriptors open.
|
||||
"""
|
||||
# Make a server to which to connect
|
||||
port = self.socket()
|
||||
port.bind(("127.0.0.1", 0))
|
||||
serverPortNumber = port.getsockname()[1]
|
||||
port.listen(5)
|
||||
|
||||
# Make a client to use to connect to the server
|
||||
client = self.socket()
|
||||
client.setblocking(False)
|
||||
|
||||
# Use up all the rest of the file descriptors.
|
||||
for i in range(self.socketLimit):
|
||||
try:
|
||||
self.socket()
|
||||
except OSError as e:
|
||||
if e.args[0] in (EMFILE, ENOBUFS):
|
||||
# The desired state has been achieved.
|
||||
break
|
||||
else:
|
||||
# Some unexpected error occurred.
|
||||
raise
|
||||
else:
|
||||
self.fail("Could provoke neither EMFILE nor ENOBUFS from platform.")
|
||||
|
||||
# Non-blocking connect is supposed to fail, but this is not true
|
||||
# everywhere (e.g. freeBSD)
|
||||
self.assertIn(
|
||||
client.connect_ex(("127.0.0.1", serverPortNumber)), (0, EINPROGRESS)
|
||||
)
|
||||
|
||||
# Make sure that the accept call fails in the way we expect.
|
||||
exc = self.assertRaises(socket.error, port.accept)
|
||||
self.assertIn(exc.args[0], (EMFILE, ENOBUFS))
|
||||
|
||||
|
||||
@skipIf(
|
||||
not interfaces.IReactorFDSet.providedBy(reactor),
|
||||
"This test only applies to reactors that implement IReactorFDset",
|
||||
)
|
||||
class SelectReactorTests(TestCase):
|
||||
"""
|
||||
Tests for select-specific failure conditions.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.ports = []
|
||||
self.messages = []
|
||||
log.addObserver(self.messages.append)
|
||||
|
||||
def tearDown(self):
|
||||
log.removeObserver(self.messages.append)
|
||||
return gatherResults([maybeDeferred(p.stopListening) for p in self.ports])
|
||||
|
||||
def port(self, portNumber, factory, interface):
|
||||
"""
|
||||
Create, start, and return a new L{Port}, also tracking it so it can
|
||||
be stopped in the test tear down.
|
||||
"""
|
||||
p = Port(portNumber, factory, interface=interface)
|
||||
p.startListening()
|
||||
self.ports.append(p)
|
||||
return p
|
||||
|
||||
def _acceptFailureTest(self, socketErrorNumber):
|
||||
"""
|
||||
Test behavior in the face of an exception from C{accept(2)}.
|
||||
|
||||
On any exception which indicates the platform is unable or unwilling
|
||||
to allocate further resources to us, the existing port should remain
|
||||
listening, a message should be logged, and the exception should not
|
||||
propagate outward from doRead.
|
||||
|
||||
@param socketErrorNumber: The errno to simulate from accept.
|
||||
"""
|
||||
|
||||
class FakeSocket:
|
||||
"""
|
||||
Pretend to be a socket in an overloaded system.
|
||||
"""
|
||||
|
||||
def accept(self):
|
||||
raise OSError(socketErrorNumber, os.strerror(socketErrorNumber))
|
||||
|
||||
factory = ServerFactory()
|
||||
port = self.port(0, factory, interface="127.0.0.1")
|
||||
self.patch(port, "socket", FakeSocket())
|
||||
|
||||
port.doRead()
|
||||
|
||||
expectedFormat = "Could not accept new connection ({acceptError})"
|
||||
expectedErrorCode = errno.errorcode[socketErrorNumber]
|
||||
matchingMessages = [
|
||||
(
|
||||
msg.get("log_format") == expectedFormat
|
||||
and msg.get("acceptError") == expectedErrorCode
|
||||
)
|
||||
for msg in self.messages
|
||||
]
|
||||
self.assertGreater(
|
||||
len(matchingMessages),
|
||||
0,
|
||||
"Log event for failed accept not found in " "%r" % (self.messages,),
|
||||
)
|
||||
|
||||
def test_tooManyFilesFromAccept(self):
|
||||
"""
|
||||
C{accept(2)} can fail with C{EMFILE} when there are too many open file
|
||||
descriptors in the process. Test that this doesn't negatively impact
|
||||
any other existing connections.
|
||||
|
||||
C{EMFILE} mainly occurs on Linux when the open file rlimit is
|
||||
encountered.
|
||||
"""
|
||||
return self._acceptFailureTest(EMFILE)
|
||||
|
||||
def test_noBufferSpaceFromAccept(self):
|
||||
"""
|
||||
Similar to L{test_tooManyFilesFromAccept}, but test the case where
|
||||
C{accept(2)} fails with C{ENOBUFS}.
|
||||
|
||||
This mainly occurs on Windows and FreeBSD, but may be possible on
|
||||
Linux and other platforms as well.
|
||||
"""
|
||||
return self._acceptFailureTest(ENOBUFS)
|
||||
|
||||
def test_connectionAbortedFromAccept(self):
|
||||
"""
|
||||
Similar to L{test_tooManyFilesFromAccept}, but test the case where
|
||||
C{accept(2)} fails with C{ECONNABORTED}.
|
||||
|
||||
It is not clear whether this is actually possible for TCP
|
||||
connections on modern versions of Linux.
|
||||
"""
|
||||
return self._acceptFailureTest(ECONNABORTED)
|
||||
|
||||
@skipIf(platform.getType() == "win32", "Windows accept(2) cannot generate ENFILE")
|
||||
def test_noFilesFromAccept(self):
|
||||
"""
|
||||
Similar to L{test_tooManyFilesFromAccept}, but test the case where
|
||||
C{accept(2)} fails with C{ENFILE}.
|
||||
|
||||
This can occur on Linux when the system has exhausted (!) its supply
|
||||
of inodes.
|
||||
"""
|
||||
return self._acceptFailureTest(ENFILE)
|
||||
|
||||
@skipIf(platform.getType() == "win32", "Windows accept(2) cannot generate ENOMEM")
|
||||
def test_noMemoryFromAccept(self):
|
||||
"""
|
||||
Similar to L{test_tooManyFilesFromAccept}, but test the case where
|
||||
C{accept(2)} fails with C{ENOMEM}.
|
||||
|
||||
On Linux at least, this can sensibly occur, even in a Python program
|
||||
(which eats memory like no ones business), when memory has become
|
||||
fragmented or low memory has been filled (d_alloc calls
|
||||
kmem_cache_alloc calls kmalloc - kmalloc only allocates out of low
|
||||
memory).
|
||||
"""
|
||||
return self._acceptFailureTest(ENOMEM)
|
||||
|
||||
@skipIf(
|
||||
os.environ.get("INFRASTRUCTURE") == "AZUREPIPELINES",
|
||||
"Hangs on Azure Pipelines due to firewall",
|
||||
)
|
||||
def test_acceptScaling(self):
|
||||
"""
|
||||
L{tcp.Port.doRead} increases the number of consecutive
|
||||
C{accept} calls it performs if all of the previous C{accept}
|
||||
calls succeed; otherwise, it reduces the number to the amount
|
||||
of successful calls.
|
||||
"""
|
||||
factory = ServerFactory()
|
||||
factory.protocol = Protocol
|
||||
port = self.port(0, factory, interface="127.0.0.1")
|
||||
self.addCleanup(port.stopListening)
|
||||
|
||||
clients = []
|
||||
|
||||
def closeAll():
|
||||
for client in clients:
|
||||
client.close()
|
||||
|
||||
self.addCleanup(closeAll)
|
||||
|
||||
def connect():
|
||||
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
client.connect(("127.0.0.1", port.getHost().port))
|
||||
return client
|
||||
|
||||
clients.append(connect())
|
||||
port.numberAccepts = 1
|
||||
port.doRead()
|
||||
self.assertGreater(port.numberAccepts, 1)
|
||||
|
||||
clients.append(connect())
|
||||
port.doRead()
|
||||
# There was only one outstanding client connection, so only
|
||||
# one accept(2) was possible.
|
||||
self.assertEqual(port.numberAccepts, 1)
|
||||
|
||||
port.doRead()
|
||||
# There were no outstanding client connections, so only one
|
||||
# accept should be tried next.
|
||||
self.assertEqual(port.numberAccepts, 1)
|
||||
|
||||
@skipIf(platform.getType() == "win32", "Windows accept(2) cannot generate EPERM")
|
||||
def test_permissionFailure(self):
|
||||
"""
|
||||
C{accept(2)} returning C{EPERM} is treated as a transient
|
||||
failure and the call retried no more than the maximum number
|
||||
of consecutive C{accept(2)} calls.
|
||||
"""
|
||||
maximumNumberOfAccepts = 123
|
||||
acceptCalls = [0]
|
||||
|
||||
class FakeSocketWithAcceptLimit:
|
||||
"""
|
||||
Pretend to be a socket in an overloaded system whose
|
||||
C{accept} method can only be called
|
||||
C{maximumNumberOfAccepts} times.
|
||||
"""
|
||||
|
||||
def accept(oself):
|
||||
acceptCalls[0] += 1
|
||||
if acceptCalls[0] > maximumNumberOfAccepts:
|
||||
self.fail("Maximum number of accept calls exceeded.")
|
||||
raise OSError(EPERM, os.strerror(EPERM))
|
||||
|
||||
# Verify that FakeSocketWithAcceptLimit.accept() fails the
|
||||
# test if the number of accept calls exceeds the maximum.
|
||||
for _ in range(maximumNumberOfAccepts):
|
||||
self.assertRaises(socket.error, FakeSocketWithAcceptLimit().accept)
|
||||
|
||||
self.assertRaises(self.failureException, FakeSocketWithAcceptLimit().accept)
|
||||
|
||||
acceptCalls = [0]
|
||||
|
||||
factory = ServerFactory()
|
||||
port = self.port(0, factory, interface="127.0.0.1")
|
||||
port.numberAccepts = 123
|
||||
self.patch(port, "socket", FakeSocketWithAcceptLimit())
|
||||
|
||||
# This should not loop infinitely.
|
||||
port.doRead()
|
||||
|
||||
# This is scaled down to 1 because no accept(2)s returned
|
||||
# successfully.
|
||||
self.assertEquals(port.numberAccepts, 1)
|
||||
|
||||
def test_unknownSocketErrorRaise(self):
|
||||
"""
|
||||
A C{socket.error} raised by C{accept(2)} whose C{errno} is
|
||||
unknown to the recovery logic is logged.
|
||||
"""
|
||||
knownErrors = list(_ACCEPT_ERRORS)
|
||||
knownErrors.extend([EAGAIN, EPERM, EWOULDBLOCK])
|
||||
# Windows has object()s stubs for some errnos.
|
||||
unknownAcceptError = (
|
||||
max(error for error in knownErrors if isinstance(error, int)) + 1
|
||||
)
|
||||
|
||||
class FakeSocketWithUnknownAcceptError:
|
||||
"""
|
||||
Pretend to be a socket in an overloaded system whose
|
||||
C{accept} method can only be called
|
||||
C{maximumNumberOfAccepts} times.
|
||||
"""
|
||||
|
||||
def accept(oself):
|
||||
raise OSError(unknownAcceptError, "unknown socket error message")
|
||||
|
||||
factory = ServerFactory()
|
||||
port = self.port(0, factory, interface="127.0.0.1")
|
||||
self.patch(port, "socket", FakeSocketWithUnknownAcceptError())
|
||||
|
||||
port.doRead()
|
||||
|
||||
failures = self.flushLoggedErrors(socket.error)
|
||||
self.assertEqual(1, len(failures))
|
||||
self.assertEqual(failures[0].value.args[0], unknownAcceptError)
|
||||
235
.venv/lib/python3.12/site-packages/twisted/test/test_text.py
Normal file
235
.venv/lib/python3.12/site-packages/twisted/test/test_text.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.text}.
|
||||
"""
|
||||
|
||||
from io import StringIO
|
||||
|
||||
from twisted.python import text
|
||||
from twisted.trial import unittest
|
||||
|
||||
sampleText = """Every attempt to employ mathematical methods in the study of chemical
|
||||
questions must be considered profoundly irrational and contrary to the
|
||||
spirit of chemistry ... If mathematical analysis should ever hold a
|
||||
prominent place in chemistry - an aberration which is happily almost
|
||||
impossible - it would occasion a rapid and widespread degeneration of that
|
||||
science.
|
||||
|
||||
-- Auguste Comte, Philosophie Positive, Paris, 1838
|
||||
"""
|
||||
|
||||
|
||||
class WrapTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{text.greedyWrap}.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.lineWidth = 72
|
||||
self.sampleSplitText = sampleText.split()
|
||||
self.output = text.wordWrap(sampleText, self.lineWidth)
|
||||
|
||||
def test_wordCount(self) -> None:
|
||||
"""
|
||||
Compare the number of words.
|
||||
"""
|
||||
words = []
|
||||
for line in self.output:
|
||||
words.extend(line.split())
|
||||
wordCount = len(words)
|
||||
sampleTextWordCount = len(self.sampleSplitText)
|
||||
|
||||
self.assertEqual(wordCount, sampleTextWordCount)
|
||||
|
||||
def test_wordMatch(self) -> None:
|
||||
"""
|
||||
Compare the lists of words.
|
||||
"""
|
||||
words = []
|
||||
for line in self.output:
|
||||
words.extend(line.split())
|
||||
|
||||
# Using assertEqual here prints out some
|
||||
# rather too long lists.
|
||||
self.assertTrue(self.sampleSplitText == words)
|
||||
|
||||
def test_lineLength(self) -> None:
|
||||
"""
|
||||
Check the length of the lines.
|
||||
"""
|
||||
failures = []
|
||||
for line in self.output:
|
||||
if not len(line) <= self.lineWidth:
|
||||
failures.append(len(line))
|
||||
|
||||
if failures:
|
||||
self.fail(
|
||||
"%d of %d lines were too long.\n"
|
||||
"%d < %s" % (len(failures), len(self.output), self.lineWidth, failures)
|
||||
)
|
||||
|
||||
def test_doubleNewline(self) -> None:
|
||||
"""
|
||||
Allow paragraphs delimited by two \ns.
|
||||
"""
|
||||
sampleText = "et\n\nphone\nhome."
|
||||
result = text.wordWrap(sampleText, self.lineWidth)
|
||||
self.assertEqual(result, ["et", "", "phone home.", ""])
|
||||
|
||||
|
||||
class LineTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{isMultiline} and L{endsInNewline}.
|
||||
"""
|
||||
|
||||
def test_isMultiline(self) -> None:
|
||||
"""
|
||||
L{text.isMultiline} returns C{True} if the string has a newline in it.
|
||||
"""
|
||||
s = 'This code\n "breaks."'
|
||||
m = text.isMultiline(s)
|
||||
self.assertTrue(m)
|
||||
|
||||
s = 'This code does not "break."'
|
||||
m = text.isMultiline(s)
|
||||
self.assertFalse(m)
|
||||
|
||||
def test_endsInNewline(self) -> None:
|
||||
"""
|
||||
L{text.endsInNewline} returns C{True} if the string ends in a newline.
|
||||
"""
|
||||
s = "newline\n"
|
||||
m = text.endsInNewline(s)
|
||||
self.assertTrue(m)
|
||||
|
||||
s = "oldline"
|
||||
m = text.endsInNewline(s)
|
||||
self.assertFalse(m)
|
||||
|
||||
|
||||
class StringyStringTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{text.stringyString}.
|
||||
"""
|
||||
|
||||
def test_tuple(self) -> None:
|
||||
"""
|
||||
Tuple elements are displayed on separate lines.
|
||||
"""
|
||||
s = ("a", "b")
|
||||
m = text.stringyString(s)
|
||||
self.assertEqual(m, "(a,\n b,)\n")
|
||||
|
||||
def test_dict(self) -> None:
|
||||
"""
|
||||
Dicts elements are displayed using C{str()}.
|
||||
"""
|
||||
s = {"a": 0}
|
||||
m = text.stringyString(s)
|
||||
self.assertEqual(m, "{a: 0}")
|
||||
|
||||
def test_list(self) -> None:
|
||||
"""
|
||||
List elements are displayed on separate lines using C{str()}.
|
||||
"""
|
||||
s = ["a", "b"]
|
||||
m = text.stringyString(s)
|
||||
self.assertEqual(m, "[a,\n b,]\n")
|
||||
|
||||
|
||||
class SplitTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{text.splitQuoted}.
|
||||
"""
|
||||
|
||||
def test_oneWord(self) -> None:
|
||||
"""
|
||||
Splitting strings with one-word phrases.
|
||||
"""
|
||||
s = 'This code "works."'
|
||||
r = text.splitQuoted(s)
|
||||
self.assertEqual(["This", "code", "works."], r)
|
||||
|
||||
def test_multiWord(self) -> None:
|
||||
s = 'The "hairy monkey" likes pie.'
|
||||
r = text.splitQuoted(s)
|
||||
self.assertEqual(["The", "hairy monkey", "likes", "pie."], r)
|
||||
|
||||
# Some of the many tests that would fail:
|
||||
|
||||
# def test_preserveWhitespace(self):
|
||||
# phrase = '"MANY SPACES"'
|
||||
# s = 'With %s between.' % (phrase,)
|
||||
# r = text.splitQuoted(s)
|
||||
# self.assertEqual(['With', phrase, 'between.'], r)
|
||||
|
||||
# def test_escapedSpace(self):
|
||||
# s = r"One\ Phrase"
|
||||
# r = text.splitQuoted(s)
|
||||
# self.assertEqual(["One Phrase"], r)
|
||||
|
||||
|
||||
class StrFileTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = StringIO("this is a test string")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
pass
|
||||
|
||||
def test_1_f(self) -> None:
|
||||
self.assertFalse(text.strFile("x", self.io))
|
||||
|
||||
def test_1_1(self) -> None:
|
||||
self.assertTrue(text.strFile("t", self.io))
|
||||
|
||||
def test_1_2(self) -> None:
|
||||
self.assertTrue(text.strFile("h", self.io))
|
||||
|
||||
def test_1_3(self) -> None:
|
||||
self.assertTrue(text.strFile("i", self.io))
|
||||
|
||||
def test_1_4(self) -> None:
|
||||
self.assertTrue(text.strFile("s", self.io))
|
||||
|
||||
def test_1_5(self) -> None:
|
||||
self.assertTrue(text.strFile("n", self.io))
|
||||
|
||||
def test_1_6(self) -> None:
|
||||
self.assertTrue(text.strFile("g", self.io))
|
||||
|
||||
def test_3_1(self) -> None:
|
||||
self.assertTrue(text.strFile("thi", self.io))
|
||||
|
||||
def test_3_2(self) -> None:
|
||||
self.assertTrue(text.strFile("his", self.io))
|
||||
|
||||
def test_3_3(self) -> None:
|
||||
self.assertTrue(text.strFile("is ", self.io))
|
||||
|
||||
def test_3_4(self) -> None:
|
||||
self.assertTrue(text.strFile("ing", self.io))
|
||||
|
||||
def test_3_f(self) -> None:
|
||||
self.assertFalse(text.strFile("bla", self.io))
|
||||
|
||||
def test_large_1(self) -> None:
|
||||
self.assertTrue(text.strFile("this is a test", self.io))
|
||||
|
||||
def test_large_2(self) -> None:
|
||||
self.assertTrue(text.strFile("is a test string", self.io))
|
||||
|
||||
def test_large_f(self) -> None:
|
||||
self.assertFalse(text.strFile("ds jhfsa k fdas", self.io))
|
||||
|
||||
def test_overlarge_f(self) -> None:
|
||||
self.assertFalse(
|
||||
text.strFile("djhsakj dhsa fkhsa s,mdbnfsauiw bndasdf hreew", self.io)
|
||||
)
|
||||
|
||||
def test_self(self) -> None:
|
||||
self.assertTrue(text.strFile("this is a test string", self.io))
|
||||
|
||||
def test_insensitive(self) -> None:
|
||||
self.assertTrue(text.strFile("ThIs is A test STRING", self.io, False))
|
||||
@@ -0,0 +1,119 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.threadable}.
|
||||
"""
|
||||
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
from unittest import skipIf
|
||||
|
||||
try:
|
||||
import threading
|
||||
except ImportError:
|
||||
threadingSkip = True
|
||||
else:
|
||||
threadingSkip = False
|
||||
|
||||
from twisted.python import threadable
|
||||
from twisted.trial.unittest import FailTest, SynchronousTestCase
|
||||
|
||||
|
||||
class TestObject:
|
||||
synchronized = ["aMethod"]
|
||||
|
||||
x = -1
|
||||
y = 1
|
||||
|
||||
def aMethod(self):
|
||||
for i in range(10):
|
||||
self.x, self.y = self.y, self.x
|
||||
self.z = self.x + self.y
|
||||
assert self.z == 0, "z == %d, not 0 as expected" % (self.z,)
|
||||
|
||||
|
||||
threadable.synchronize(TestObject)
|
||||
|
||||
|
||||
class SynchronizationTests(SynchronousTestCase):
|
||||
def setUp(self):
|
||||
"""
|
||||
Reduce the CPython check interval so that thread switches happen much
|
||||
more often, hopefully exercising more possible race conditions. Also,
|
||||
delay actual test startup until the reactor has been started.
|
||||
"""
|
||||
self.addCleanup(sys.setswitchinterval, sys.getswitchinterval())
|
||||
sys.setswitchinterval(0.0000001)
|
||||
|
||||
def test_synchronizedName(self):
|
||||
"""
|
||||
The name of a synchronized method is inaffected by the synchronization
|
||||
decorator.
|
||||
"""
|
||||
self.assertEqual("aMethod", TestObject.aMethod.__name__)
|
||||
|
||||
@skipIf(threadingSkip, "Platform does not support threads")
|
||||
def test_isInIOThread(self):
|
||||
"""
|
||||
L{threadable.isInIOThread} returns C{True} if and only if it is called
|
||||
in the same thread as L{threadable.registerAsIOThread}.
|
||||
"""
|
||||
threadable.registerAsIOThread()
|
||||
foreignResult = []
|
||||
t = threading.Thread(
|
||||
target=lambda: foreignResult.append(threadable.isInIOThread())
|
||||
)
|
||||
t.start()
|
||||
t.join()
|
||||
self.assertFalse(foreignResult[0], "Non-IO thread reported as IO thread")
|
||||
self.assertTrue(
|
||||
threadable.isInIOThread(), "IO thread reported as not IO thread"
|
||||
)
|
||||
|
||||
@skipIf(threadingSkip, "Platform does not support threads")
|
||||
def testThreadedSynchronization(self):
|
||||
o = TestObject()
|
||||
|
||||
errors = []
|
||||
|
||||
def callMethodLots():
|
||||
try:
|
||||
for i in range(1000):
|
||||
o.aMethod()
|
||||
except AssertionError as e:
|
||||
errors.append(str(e))
|
||||
|
||||
threads = []
|
||||
for x in range(5):
|
||||
t = threading.Thread(target=callMethodLots)
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
if errors:
|
||||
raise FailTest(errors)
|
||||
|
||||
def testUnthreadedSynchronization(self):
|
||||
o = TestObject()
|
||||
for i in range(1000):
|
||||
o.aMethod()
|
||||
|
||||
|
||||
class SerializationTests(SynchronousTestCase):
|
||||
@skipIf(threadingSkip, "Platform does not support threads")
|
||||
def testPickling(self):
|
||||
lock = threadable.XLock()
|
||||
lockType = type(lock)
|
||||
lockPickle = pickle.dumps(lock)
|
||||
newLock = pickle.loads(lockPickle)
|
||||
self.assertIsInstance(newLock, lockType)
|
||||
|
||||
def testUnpickling(self):
|
||||
lockPickle = b"ctwisted.python.threadable\nunpickle_lock\np0\n(tp1\nRp2\n."
|
||||
lock = pickle.loads(lockPickle)
|
||||
newPickle = pickle.dumps(lock, 2)
|
||||
pickle.loads(newPickle)
|
||||
@@ -0,0 +1,710 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.python.threadpool}
|
||||
"""
|
||||
|
||||
|
||||
import gc
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
|
||||
from twisted._threads import Team, createMemoryWorker
|
||||
from twisted.python import context, failure, threadable, threadpool
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class Synchronization:
|
||||
failures = 0
|
||||
|
||||
def __init__(self, N, waiting):
|
||||
self.N = N
|
||||
self.waiting = waiting
|
||||
self.lock = threading.Lock()
|
||||
self.runs = []
|
||||
|
||||
def run(self):
|
||||
# This is the testy part: this is supposed to be invoked
|
||||
# serially from multiple threads. If that is actually the
|
||||
# case, we will never fail to acquire this lock. If it is
|
||||
# *not* the case, we might get here while someone else is
|
||||
# holding the lock.
|
||||
if self.lock.acquire(False):
|
||||
if not len(self.runs) % 5:
|
||||
# Constant selected based on empirical data to maximize the
|
||||
# chance of a quick failure if this code is broken.
|
||||
time.sleep(0.0002)
|
||||
self.lock.release()
|
||||
else:
|
||||
self.failures += 1
|
||||
|
||||
# This is just the only way I can think of to wake up the test
|
||||
# method. It doesn't actually have anything to do with the
|
||||
# test.
|
||||
self.lock.acquire()
|
||||
self.runs.append(None)
|
||||
if len(self.runs) == self.N:
|
||||
self.waiting.release()
|
||||
self.lock.release()
|
||||
|
||||
synchronized = ["run"]
|
||||
|
||||
|
||||
threadable.synchronize(Synchronization)
|
||||
|
||||
|
||||
class ThreadPoolTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Test threadpools.
|
||||
"""
|
||||
|
||||
def getTimeout(self):
|
||||
"""
|
||||
Return number of seconds to wait before giving up.
|
||||
"""
|
||||
return 5 # Really should be order of magnitude less
|
||||
|
||||
def _waitForLock(self, lock):
|
||||
items = range(1000000)
|
||||
for i in items:
|
||||
if lock.acquire(False):
|
||||
break
|
||||
time.sleep(1e-5)
|
||||
else:
|
||||
self.fail("A long time passed without succeeding")
|
||||
|
||||
def test_attributes(self):
|
||||
"""
|
||||
L{ThreadPool.min} and L{ThreadPool.max} are set to the values passed to
|
||||
L{ThreadPool.__init__}.
|
||||
"""
|
||||
pool = threadpool.ThreadPool(12, 22)
|
||||
self.assertEqual(pool.min, 12)
|
||||
self.assertEqual(pool.max, 22)
|
||||
|
||||
def test_start(self):
|
||||
"""
|
||||
L{ThreadPool.start} creates the minimum number of threads specified.
|
||||
"""
|
||||
pool = threadpool.ThreadPool(0, 5)
|
||||
pool.start()
|
||||
self.addCleanup(pool.stop)
|
||||
self.assertEqual(len(pool.threads), 0)
|
||||
|
||||
pool = threadpool.ThreadPool(3, 10)
|
||||
self.assertEqual(len(pool.threads), 0)
|
||||
pool.start()
|
||||
self.addCleanup(pool.stop)
|
||||
self.assertEqual(len(pool.threads), 3)
|
||||
|
||||
def test_adjustingWhenPoolStopped(self):
|
||||
"""
|
||||
L{ThreadPool.adjustPoolsize} only modifies the pool size and does not
|
||||
start new workers while the pool is not running.
|
||||
"""
|
||||
pool = threadpool.ThreadPool(0, 5)
|
||||
pool.start()
|
||||
pool.stop()
|
||||
pool.adjustPoolsize(2)
|
||||
self.assertEqual(len(pool.threads), 0)
|
||||
|
||||
def test_threadCreationArguments(self):
|
||||
"""
|
||||
Test that creating threads in the threadpool with application-level
|
||||
objects as arguments doesn't results in those objects never being
|
||||
freed, with the thread maintaining a reference to them as long as it
|
||||
exists.
|
||||
"""
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.start()
|
||||
self.addCleanup(tp.stop)
|
||||
|
||||
# Sanity check - no threads should have been started yet.
|
||||
self.assertEqual(tp.threads, [])
|
||||
|
||||
# Here's our function
|
||||
def worker(arg):
|
||||
pass
|
||||
|
||||
# weakref needs an object subclass
|
||||
class Dumb:
|
||||
pass
|
||||
|
||||
# And here's the unique object
|
||||
unique = Dumb()
|
||||
|
||||
workerRef = weakref.ref(worker)
|
||||
uniqueRef = weakref.ref(unique)
|
||||
|
||||
# Put some work in
|
||||
tp.callInThread(worker, unique)
|
||||
|
||||
# Add an event to wait completion
|
||||
event = threading.Event()
|
||||
tp.callInThread(event.set)
|
||||
event.wait(self.getTimeout())
|
||||
|
||||
del worker
|
||||
del unique
|
||||
gc.collect()
|
||||
self.assertIsNone(uniqueRef())
|
||||
self.assertIsNone(workerRef())
|
||||
|
||||
def test_threadCreationArgumentsCallInThreadWithCallback(self):
|
||||
"""
|
||||
As C{test_threadCreationArguments} above, but for
|
||||
callInThreadWithCallback.
|
||||
"""
|
||||
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.start()
|
||||
self.addCleanup(tp.stop)
|
||||
|
||||
# Sanity check - no threads should have been started yet.
|
||||
self.assertEqual(tp.threads, [])
|
||||
|
||||
# this holds references obtained in onResult
|
||||
refdict = {} # name -> ref value
|
||||
|
||||
onResultWait = threading.Event()
|
||||
onResultDone = threading.Event()
|
||||
|
||||
resultRef = []
|
||||
|
||||
# result callback
|
||||
def onResult(success, result):
|
||||
# Spin the GC, which should now delete worker and unique if it's
|
||||
# not held on to by callInThreadWithCallback after it is complete
|
||||
gc.collect()
|
||||
onResultWait.wait(self.getTimeout())
|
||||
refdict["workerRef"] = workerRef()
|
||||
refdict["uniqueRef"] = uniqueRef()
|
||||
onResultDone.set()
|
||||
resultRef.append(weakref.ref(result))
|
||||
|
||||
# Here's our function
|
||||
def worker(arg, test):
|
||||
return Dumb()
|
||||
|
||||
# weakref needs an object subclass
|
||||
class Dumb:
|
||||
pass
|
||||
|
||||
# And here's the unique object
|
||||
unique = Dumb()
|
||||
|
||||
onResultRef = weakref.ref(onResult)
|
||||
workerRef = weakref.ref(worker)
|
||||
uniqueRef = weakref.ref(unique)
|
||||
|
||||
# Put some work in
|
||||
tp.callInThreadWithCallback(onResult, worker, unique, test=unique)
|
||||
|
||||
del worker
|
||||
del unique
|
||||
|
||||
# let onResult collect the refs
|
||||
onResultWait.set()
|
||||
# wait for onResult
|
||||
onResultDone.wait(self.getTimeout())
|
||||
gc.collect()
|
||||
|
||||
self.assertIsNone(uniqueRef())
|
||||
self.assertIsNone(workerRef())
|
||||
|
||||
# XXX There's a race right here - has onResult in the worker thread
|
||||
# returned and the locals in _worker holding it and the result been
|
||||
# deleted yet?
|
||||
|
||||
del onResult
|
||||
gc.collect()
|
||||
self.assertIsNone(onResultRef())
|
||||
self.assertIsNone(resultRef[0]())
|
||||
|
||||
# The callback shouldn't have been able to resolve the references.
|
||||
self.assertEqual(list(refdict.values()), [None, None])
|
||||
|
||||
def test_persistence(self):
|
||||
"""
|
||||
Threadpools can be pickled and unpickled, which should preserve the
|
||||
number of threads and other parameters.
|
||||
"""
|
||||
pool = threadpool.ThreadPool(7, 20)
|
||||
|
||||
self.assertEqual(pool.min, 7)
|
||||
self.assertEqual(pool.max, 20)
|
||||
|
||||
# check that unpickled threadpool has same number of threads
|
||||
copy = pickle.loads(pickle.dumps(pool))
|
||||
|
||||
self.assertEqual(copy.min, 7)
|
||||
self.assertEqual(copy.max, 20)
|
||||
|
||||
def _threadpoolTest(self, method):
|
||||
"""
|
||||
Test synchronization of calls made with C{method}, which should be
|
||||
one of the mechanisms of the threadpool to execute work in threads.
|
||||
"""
|
||||
# This is a schizophrenic test: it seems to be trying to test
|
||||
# both the callInThread()/dispatch() behavior of the ThreadPool as well
|
||||
# as the serialization behavior of threadable.synchronize(). It
|
||||
# would probably make more sense as two much simpler tests.
|
||||
N = 10
|
||||
|
||||
tp = threadpool.ThreadPool()
|
||||
tp.start()
|
||||
self.addCleanup(tp.stop)
|
||||
|
||||
waiting = threading.Lock()
|
||||
waiting.acquire()
|
||||
actor = Synchronization(N, waiting)
|
||||
|
||||
for i in range(N):
|
||||
method(tp, actor)
|
||||
|
||||
self._waitForLock(waiting)
|
||||
|
||||
self.assertFalse(actor.failures, f"run() re-entered {actor.failures} times")
|
||||
|
||||
def test_callInThread(self):
|
||||
"""
|
||||
Call C{_threadpoolTest} with C{callInThread}.
|
||||
"""
|
||||
return self._threadpoolTest(lambda tp, actor: tp.callInThread(actor.run))
|
||||
|
||||
def test_callInThreadException(self):
|
||||
"""
|
||||
L{ThreadPool.callInThread} logs exceptions raised by the callable it
|
||||
is passed.
|
||||
"""
|
||||
|
||||
class NewError(Exception):
|
||||
pass
|
||||
|
||||
def raiseError():
|
||||
raise NewError()
|
||||
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.callInThread(raiseError)
|
||||
tp.start()
|
||||
tp.stop()
|
||||
|
||||
errors = self.flushLoggedErrors(NewError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
|
||||
def test_callInThreadWithCallback(self):
|
||||
"""
|
||||
L{ThreadPool.callInThreadWithCallback} calls C{onResult} with a
|
||||
two-tuple of C{(True, result)} where C{result} is the value returned
|
||||
by the callable supplied.
|
||||
"""
|
||||
waiter = threading.Lock()
|
||||
waiter.acquire()
|
||||
|
||||
results = []
|
||||
|
||||
def onResult(success, result):
|
||||
waiter.release()
|
||||
results.append(success)
|
||||
results.append(result)
|
||||
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.callInThreadWithCallback(onResult, lambda: "test")
|
||||
tp.start()
|
||||
|
||||
try:
|
||||
self._waitForLock(waiter)
|
||||
finally:
|
||||
tp.stop()
|
||||
|
||||
self.assertTrue(results[0])
|
||||
self.assertEqual(results[1], "test")
|
||||
|
||||
def test_callInThreadWithCallbackExceptionInCallback(self):
|
||||
"""
|
||||
L{ThreadPool.callInThreadWithCallback} calls C{onResult} with a
|
||||
two-tuple of C{(False, failure)} where C{failure} represents the
|
||||
exception raised by the callable supplied.
|
||||
"""
|
||||
|
||||
class NewError(Exception):
|
||||
pass
|
||||
|
||||
def raiseError():
|
||||
raise NewError()
|
||||
|
||||
waiter = threading.Lock()
|
||||
waiter.acquire()
|
||||
|
||||
results = []
|
||||
|
||||
def onResult(success, result):
|
||||
waiter.release()
|
||||
results.append(success)
|
||||
results.append(result)
|
||||
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.callInThreadWithCallback(onResult, raiseError)
|
||||
tp.start()
|
||||
|
||||
try:
|
||||
self._waitForLock(waiter)
|
||||
finally:
|
||||
tp.stop()
|
||||
|
||||
self.assertFalse(results[0])
|
||||
self.assertIsInstance(results[1], failure.Failure)
|
||||
self.assertTrue(issubclass(results[1].type, NewError))
|
||||
|
||||
def test_callInThreadWithCallbackExceptionInOnResult(self):
|
||||
"""
|
||||
L{ThreadPool.callInThreadWithCallback} logs the exception raised by
|
||||
C{onResult}.
|
||||
"""
|
||||
|
||||
class NewError(Exception):
|
||||
pass
|
||||
|
||||
waiter = threading.Lock()
|
||||
waiter.acquire()
|
||||
|
||||
results = []
|
||||
|
||||
def onResult(success, result):
|
||||
results.append(success)
|
||||
results.append(result)
|
||||
raise NewError()
|
||||
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.callInThreadWithCallback(onResult, lambda: None)
|
||||
tp.callInThread(waiter.release)
|
||||
tp.start()
|
||||
|
||||
try:
|
||||
self._waitForLock(waiter)
|
||||
finally:
|
||||
tp.stop()
|
||||
|
||||
errors = self.flushLoggedErrors(NewError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
|
||||
self.assertTrue(results[0])
|
||||
self.assertIsNone(results[1])
|
||||
|
||||
def test_callbackThread(self):
|
||||
"""
|
||||
L{ThreadPool.callInThreadWithCallback} calls the function it is
|
||||
given and the C{onResult} callback in the same thread.
|
||||
"""
|
||||
threadIds = []
|
||||
|
||||
event = threading.Event()
|
||||
|
||||
def onResult(success, result):
|
||||
threadIds.append(threading.current_thread().ident)
|
||||
event.set()
|
||||
|
||||
def func():
|
||||
threadIds.append(threading.current_thread().ident)
|
||||
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.callInThreadWithCallback(onResult, func)
|
||||
tp.start()
|
||||
self.addCleanup(tp.stop)
|
||||
|
||||
event.wait(self.getTimeout())
|
||||
self.assertEqual(len(threadIds), 2)
|
||||
self.assertEqual(threadIds[0], threadIds[1])
|
||||
|
||||
def test_callbackContext(self):
|
||||
"""
|
||||
The context L{ThreadPool.callInThreadWithCallback} is invoked in is
|
||||
shared by the context the callable and C{onResult} callback are
|
||||
invoked in.
|
||||
"""
|
||||
myctx = context.theContextTracker.currentContext().contexts[-1]
|
||||
myctx["testing"] = "this must be present"
|
||||
|
||||
contexts = []
|
||||
|
||||
event = threading.Event()
|
||||
|
||||
def onResult(success, result):
|
||||
ctx = context.theContextTracker.currentContext().contexts[-1]
|
||||
contexts.append(ctx)
|
||||
event.set()
|
||||
|
||||
def func():
|
||||
ctx = context.theContextTracker.currentContext().contexts[-1]
|
||||
contexts.append(ctx)
|
||||
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.callInThreadWithCallback(onResult, func)
|
||||
tp.start()
|
||||
self.addCleanup(tp.stop)
|
||||
|
||||
event.wait(self.getTimeout())
|
||||
|
||||
self.assertEqual(len(contexts), 2)
|
||||
self.assertEqual(myctx, contexts[0])
|
||||
self.assertEqual(myctx, contexts[1])
|
||||
|
||||
def test_existingWork(self):
|
||||
"""
|
||||
Work added to the threadpool before its start should be executed once
|
||||
the threadpool is started: this is ensured by trying to release a lock
|
||||
previously acquired.
|
||||
"""
|
||||
waiter = threading.Lock()
|
||||
waiter.acquire()
|
||||
|
||||
tp = threadpool.ThreadPool(0, 1)
|
||||
tp.callInThread(waiter.release) # Before start()
|
||||
tp.start()
|
||||
|
||||
try:
|
||||
self._waitForLock(waiter)
|
||||
finally:
|
||||
tp.stop()
|
||||
|
||||
def test_workerStateTransition(self):
|
||||
"""
|
||||
As the worker receives and completes work, it transitions between
|
||||
the working and waiting states.
|
||||
"""
|
||||
pool = threadpool.ThreadPool(0, 1)
|
||||
pool.start()
|
||||
self.addCleanup(pool.stop)
|
||||
|
||||
# Sanity check
|
||||
self.assertEqual(pool.workers, 0)
|
||||
self.assertEqual(len(pool.waiters), 0)
|
||||
self.assertEqual(len(pool.working), 0)
|
||||
|
||||
# Fire up a worker and give it some 'work'
|
||||
threadWorking = threading.Event()
|
||||
threadFinish = threading.Event()
|
||||
|
||||
def _thread():
|
||||
threadWorking.set()
|
||||
threadFinish.wait(10)
|
||||
|
||||
pool.callInThread(_thread)
|
||||
threadWorking.wait(10)
|
||||
self.assertEqual(pool.workers, 1)
|
||||
self.assertEqual(len(pool.waiters), 0)
|
||||
self.assertEqual(len(pool.working), 1)
|
||||
|
||||
# Finish work, and spin until state changes
|
||||
threadFinish.set()
|
||||
while not len(pool.waiters):
|
||||
time.sleep(0.0005)
|
||||
|
||||
# Make sure state changed correctly
|
||||
self.assertEqual(len(pool.waiters), 1)
|
||||
self.assertEqual(len(pool.working), 0)
|
||||
|
||||
def test_q(self) -> None:
|
||||
"""
|
||||
There is a property '_queue' for legacy purposes
|
||||
"""
|
||||
pool = threadpool.ThreadPool(0, 1)
|
||||
self.assertEqual(pool._queue.qsize(), 0)
|
||||
|
||||
|
||||
class RaceConditionTests(unittest.SynchronousTestCase):
|
||||
def setUp(self):
|
||||
self.threadpool = threadpool.ThreadPool(0, 10)
|
||||
self.event = threading.Event()
|
||||
self.threadpool.start()
|
||||
|
||||
def done():
|
||||
self.threadpool.stop()
|
||||
del self.threadpool
|
||||
|
||||
self.addCleanup(done)
|
||||
|
||||
def getTimeout(self):
|
||||
"""
|
||||
A reasonable number of seconds to time out.
|
||||
"""
|
||||
return 5
|
||||
|
||||
def test_synchronization(self):
|
||||
"""
|
||||
If multiple threads are waiting on an event (via blocking on something
|
||||
in a callable passed to L{threadpool.ThreadPool.callInThread}), and
|
||||
there is spare capacity in the threadpool, sending another callable
|
||||
which will cause those to un-block to
|
||||
L{threadpool.ThreadPool.callInThread} will reliably run that callable
|
||||
and un-block the blocked threads promptly.
|
||||
|
||||
@note: This is not really a unit test, it is a stress-test. You may
|
||||
need to run it with C{trial -u} to fail reliably if there is a
|
||||
problem. It is very hard to regression-test for this particular
|
||||
bug - one where the thread pool may consider itself as having
|
||||
"enough capacity" when it really needs to spin up a new thread if
|
||||
it possibly can - in a deterministic way, since the bug can only be
|
||||
provoked by subtle race conditions.
|
||||
"""
|
||||
timeout = self.getTimeout()
|
||||
self.threadpool.callInThread(self.event.set)
|
||||
self.event.wait(timeout)
|
||||
self.event.clear()
|
||||
for i in range(3):
|
||||
self.threadpool.callInThread(self.event.wait)
|
||||
self.threadpool.callInThread(self.event.set)
|
||||
self.event.wait(timeout)
|
||||
if not self.event.isSet():
|
||||
self.event.set()
|
||||
self.fail("'set' did not run in thread; timed out waiting on 'wait'.")
|
||||
|
||||
|
||||
class MemoryPool(threadpool.ThreadPool):
|
||||
"""
|
||||
A deterministic threadpool that uses in-memory data structures to queue
|
||||
work rather than threads to execute work.
|
||||
"""
|
||||
|
||||
def __init__(self, coordinator, failTest, newWorker, *args, **kwargs):
|
||||
"""
|
||||
Initialize this L{MemoryPool} with a test case.
|
||||
|
||||
@param coordinator: a worker used to coordinate work in the L{Team}
|
||||
underlying this threadpool.
|
||||
@type coordinator: L{twisted._threads.IExclusiveWorker}
|
||||
|
||||
@param failTest: A 1-argument callable taking an exception and raising
|
||||
a test-failure exception.
|
||||
@type failTest: 1-argument callable taking (L{Failure}) and raising
|
||||
L{unittest.FailTest}.
|
||||
|
||||
@param newWorker: a 0-argument callable that produces a new
|
||||
L{twisted._threads.IWorker} provider on each invocation.
|
||||
@type newWorker: 0-argument callable returning
|
||||
L{twisted._threads.IWorker}.
|
||||
"""
|
||||
self._coordinator = coordinator
|
||||
self._failTest = failTest
|
||||
self._newWorker = newWorker
|
||||
threadpool.ThreadPool.__init__(self, *args, **kwargs)
|
||||
|
||||
def _pool(self, currentLimit, threadFactory):
|
||||
"""
|
||||
Override testing hook to create a deterministic threadpool.
|
||||
|
||||
@param currentLimit: A 1-argument callable which returns the current
|
||||
threadpool size limit.
|
||||
|
||||
@param threadFactory: ignored in this invocation; a 0-argument callable
|
||||
that would produce a thread.
|
||||
|
||||
@return: a L{Team} backed by the coordinator and worker passed to
|
||||
L{MemoryPool.__init__}.
|
||||
"""
|
||||
|
||||
def respectLimit():
|
||||
# The expression in this method copied and pasted from
|
||||
# twisted.threads._pool, which is unfortunately bound up
|
||||
# with lots of actual-threading stuff.
|
||||
stats = team.statistics()
|
||||
if (stats.busyWorkerCount + stats.idleWorkerCount) >= currentLimit():
|
||||
return None
|
||||
return self._newWorker()
|
||||
|
||||
team = Team(
|
||||
coordinator=self._coordinator,
|
||||
createWorker=respectLimit,
|
||||
logException=self._failTest,
|
||||
)
|
||||
return team
|
||||
|
||||
|
||||
class PoolHelper:
|
||||
"""
|
||||
A L{PoolHelper} constructs a L{threadpool.ThreadPool} that doesn't actually
|
||||
use threads, by using the internal interfaces in L{twisted._threads}.
|
||||
|
||||
@ivar performCoordination: a 0-argument callable that will perform one unit
|
||||
of "coordination" - work involved in delegating work to other threads -
|
||||
and return L{True} if it did any work, L{False} otherwise.
|
||||
|
||||
@ivar workers: the workers which represent the threads within the pool -
|
||||
the workers other than the coordinator.
|
||||
@type workers: L{list} of 2-tuple of (L{IWorker}, C{workPerformer}) where
|
||||
C{workPerformer} is a 0-argument callable like C{performCoordination}.
|
||||
|
||||
@ivar threadpool: a modified L{threadpool.ThreadPool} to test.
|
||||
@type threadpool: L{MemoryPool}
|
||||
"""
|
||||
|
||||
def __init__(self, testCase, *args, **kwargs):
|
||||
"""
|
||||
Create a L{PoolHelper}.
|
||||
|
||||
@param testCase: a test case attached to this helper.
|
||||
|
||||
@type args: The arguments passed to a L{threadpool.ThreadPool}.
|
||||
|
||||
@type kwargs: The arguments passed to a L{threadpool.ThreadPool}
|
||||
"""
|
||||
coordinator, self.performCoordination = createMemoryWorker()
|
||||
self.workers = []
|
||||
|
||||
def newWorker():
|
||||
self.workers.append(createMemoryWorker())
|
||||
return self.workers[-1][0]
|
||||
|
||||
self.threadpool = MemoryPool(
|
||||
coordinator, testCase.fail, newWorker, *args, **kwargs
|
||||
)
|
||||
|
||||
def performAllCoordination(self):
|
||||
"""
|
||||
Perform all currently scheduled "coordination", which is the work
|
||||
involved in delegating work to other threads.
|
||||
"""
|
||||
while self.performCoordination():
|
||||
pass
|
||||
|
||||
|
||||
class MemoryBackedTests(unittest.SynchronousTestCase):
|
||||
"""
|
||||
Tests using L{PoolHelper} to deterministically test properties of the
|
||||
threadpool implementation.
|
||||
"""
|
||||
|
||||
def test_workBeforeStarting(self):
|
||||
"""
|
||||
If a threadpool is told to do work before starting, then upon starting
|
||||
up, it will start enough workers to handle all of the enqueued work
|
||||
that it's been given.
|
||||
"""
|
||||
helper = PoolHelper(self, 0, 10)
|
||||
n = 5
|
||||
for x in range(n):
|
||||
helper.threadpool.callInThread(lambda: None)
|
||||
helper.performAllCoordination()
|
||||
self.assertEqual(helper.workers, [])
|
||||
helper.threadpool.start()
|
||||
helper.performAllCoordination()
|
||||
self.assertEqual(len(helper.workers), n)
|
||||
|
||||
def test_tooMuchWorkBeforeStarting(self):
|
||||
"""
|
||||
If the amount of work before starting exceeds the maximum number of
|
||||
threads allowed to the threadpool, only the maximum count will be
|
||||
started.
|
||||
"""
|
||||
helper = PoolHelper(self, 0, 10)
|
||||
n = 50
|
||||
for x in range(n):
|
||||
helper.threadpool.callInThread(lambda: None)
|
||||
helper.performAllCoordination()
|
||||
self.assertEqual(helper.workers, [])
|
||||
helper.threadpool.start()
|
||||
helper.performAllCoordination()
|
||||
self.assertEqual(len(helper.workers), helper.threadpool.max)
|
||||
431
.venv/lib/python3.12/site-packages/twisted/test/test_threads.py
Normal file
431
.venv/lib/python3.12/site-packages/twisted/test/test_threads.py
Normal file
@@ -0,0 +1,431 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Test methods in twisted.internet.threads and reactor thread APIs.
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from unittest import skipIf
|
||||
|
||||
from twisted.internet import defer, error, interfaces, protocol, reactor, threads
|
||||
from twisted.python import failure, log, threadable, threadpool
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
try:
|
||||
import threading
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@skipIf(
|
||||
not interfaces.IReactorThreads(reactor, None),
|
||||
"No thread support, nothing to test here.",
|
||||
)
|
||||
class ReactorThreadsTests(TestCase):
|
||||
"""
|
||||
Tests for the reactor threading API.
|
||||
"""
|
||||
|
||||
def test_suggestThreadPoolSize(self):
|
||||
"""
|
||||
Try to change maximum number of threads.
|
||||
"""
|
||||
reactor.suggestThreadPoolSize(34)
|
||||
self.assertEqual(reactor.threadpool.max, 34)
|
||||
reactor.suggestThreadPoolSize(4)
|
||||
self.assertEqual(reactor.threadpool.max, 4)
|
||||
|
||||
def _waitForThread(self):
|
||||
"""
|
||||
The reactor's threadpool is only available when the reactor is running,
|
||||
so to have a sane behavior during the tests we make a dummy
|
||||
L{threads.deferToThread} call.
|
||||
"""
|
||||
return threads.deferToThread(time.sleep, 0)
|
||||
|
||||
def test_callInThread(self):
|
||||
"""
|
||||
Test callInThread functionality: set a C{threading.Event}, and check
|
||||
that it's not in the main thread.
|
||||
"""
|
||||
|
||||
def cb(ign):
|
||||
waiter = threading.Event()
|
||||
result = []
|
||||
|
||||
def threadedFunc():
|
||||
result.append(threadable.isInIOThread())
|
||||
waiter.set()
|
||||
|
||||
reactor.callInThread(threadedFunc)
|
||||
waiter.wait(120)
|
||||
if not waiter.isSet():
|
||||
self.fail("Timed out waiting for event.")
|
||||
else:
|
||||
self.assertEqual(result, [False])
|
||||
|
||||
return self._waitForThread().addCallback(cb)
|
||||
|
||||
def test_callFromThread(self):
|
||||
"""
|
||||
Test callFromThread functionality: from the main thread, and from
|
||||
another thread.
|
||||
"""
|
||||
|
||||
def cb(ign):
|
||||
firedByReactorThread = defer.Deferred()
|
||||
firedByOtherThread = defer.Deferred()
|
||||
|
||||
def threadedFunc():
|
||||
reactor.callFromThread(firedByOtherThread.callback, None)
|
||||
|
||||
reactor.callInThread(threadedFunc)
|
||||
reactor.callFromThread(firedByReactorThread.callback, None)
|
||||
|
||||
return defer.DeferredList(
|
||||
[firedByReactorThread, firedByOtherThread], fireOnOneErrback=True
|
||||
)
|
||||
|
||||
return self._waitForThread().addCallback(cb)
|
||||
|
||||
def test_wakerOverflow(self):
|
||||
"""
|
||||
Try to make an overflow on the reactor waker using callFromThread.
|
||||
"""
|
||||
|
||||
def cb(ign):
|
||||
self.failure = None
|
||||
waiter = threading.Event()
|
||||
|
||||
def threadedFunction():
|
||||
# Hopefully a hundred thousand queued calls is enough to
|
||||
# trigger the error condition
|
||||
for i in range(100000):
|
||||
try:
|
||||
reactor.callFromThread(lambda: None)
|
||||
except BaseException:
|
||||
self.failure = failure.Failure()
|
||||
break
|
||||
waiter.set()
|
||||
|
||||
reactor.callInThread(threadedFunction)
|
||||
waiter.wait(120)
|
||||
if not waiter.isSet():
|
||||
self.fail("Timed out waiting for event")
|
||||
if self.failure is not None:
|
||||
return defer.fail(self.failure)
|
||||
|
||||
return self._waitForThread().addCallback(cb)
|
||||
|
||||
def _testBlockingCallFromThread(self, reactorFunc):
|
||||
"""
|
||||
Utility method to test L{threads.blockingCallFromThread}.
|
||||
"""
|
||||
waiter = threading.Event()
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def cb1(ign):
|
||||
def threadedFunc():
|
||||
try:
|
||||
r = threads.blockingCallFromThread(reactor, reactorFunc)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
else:
|
||||
results.append(r)
|
||||
waiter.set()
|
||||
|
||||
reactor.callInThread(threadedFunc)
|
||||
return threads.deferToThread(waiter.wait, self.getTimeout())
|
||||
|
||||
def cb2(ign):
|
||||
if not waiter.isSet():
|
||||
self.fail("Timed out waiting for event")
|
||||
return results, errors
|
||||
|
||||
return self._waitForThread().addCallback(cb1).addBoth(cb2)
|
||||
|
||||
def test_blockingCallFromThread(self):
|
||||
"""
|
||||
Test blockingCallFromThread facility: create a thread, call a function
|
||||
in the reactor using L{threads.blockingCallFromThread}, and verify the
|
||||
result returned.
|
||||
"""
|
||||
|
||||
def reactorFunc():
|
||||
return defer.succeed("foo")
|
||||
|
||||
def cb(res):
|
||||
self.assertEqual(res[0][0], "foo")
|
||||
|
||||
return self._testBlockingCallFromThread(reactorFunc).addCallback(cb)
|
||||
|
||||
def test_asyncBlockingCallFromThread(self):
|
||||
"""
|
||||
Test blockingCallFromThread as above, but be sure the resulting
|
||||
Deferred is not already fired.
|
||||
"""
|
||||
|
||||
def reactorFunc():
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.1, d.callback, "egg")
|
||||
return d
|
||||
|
||||
def cb(res):
|
||||
self.assertEqual(res[0][0], "egg")
|
||||
|
||||
return self._testBlockingCallFromThread(reactorFunc).addCallback(cb)
|
||||
|
||||
def test_errorBlockingCallFromThread(self):
|
||||
"""
|
||||
Test error report for blockingCallFromThread.
|
||||
"""
|
||||
|
||||
def reactorFunc():
|
||||
return defer.fail(RuntimeError("bar"))
|
||||
|
||||
def cb(res):
|
||||
self.assertIsInstance(res[1][0], RuntimeError)
|
||||
self.assertEqual(res[1][0].args[0], "bar")
|
||||
|
||||
return self._testBlockingCallFromThread(reactorFunc).addCallback(cb)
|
||||
|
||||
def test_asyncErrorBlockingCallFromThread(self):
|
||||
"""
|
||||
Test error report for blockingCallFromThread as above, but be sure the
|
||||
resulting Deferred is not already fired.
|
||||
"""
|
||||
|
||||
def reactorFunc():
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.1, d.errback, RuntimeError("spam"))
|
||||
return d
|
||||
|
||||
def cb(res):
|
||||
self.assertIsInstance(res[1][0], RuntimeError)
|
||||
self.assertEqual(res[1][0].args[0], "spam")
|
||||
|
||||
return self._testBlockingCallFromThread(reactorFunc).addCallback(cb)
|
||||
|
||||
|
||||
class Counter:
|
||||
index = 0
|
||||
problem = 0
|
||||
|
||||
def add(self):
|
||||
"""A non thread-safe method."""
|
||||
next = self.index + 1
|
||||
# another thread could jump in here and increment self.index on us
|
||||
if next != self.index + 1:
|
||||
self.problem = 1
|
||||
raise ValueError
|
||||
# or here, same issue but we wouldn't catch it. We'd overwrite
|
||||
# their results, and the index will have lost a count. If
|
||||
# several threads get in here, we will actually make the count
|
||||
# go backwards when we overwrite it.
|
||||
self.index = next
|
||||
|
||||
|
||||
@skipIf(
|
||||
not interfaces.IReactorThreads(reactor, None),
|
||||
"No thread support, nothing to test here.",
|
||||
)
|
||||
class DeferredResultTests(TestCase):
|
||||
"""
|
||||
Test twisted.internet.threads.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
reactor.suggestThreadPoolSize(8)
|
||||
|
||||
def tearDown(self):
|
||||
reactor.suggestThreadPoolSize(0)
|
||||
|
||||
def test_callMultiple(self):
|
||||
"""
|
||||
L{threads.callMultipleInThread} calls multiple functions in a thread.
|
||||
"""
|
||||
L = []
|
||||
N = 10
|
||||
d = defer.Deferred()
|
||||
|
||||
def finished():
|
||||
self.assertEqual(L, list(range(N)))
|
||||
d.callback(None)
|
||||
|
||||
threads.callMultipleInThread(
|
||||
[(L.append, (i,), {}) for i in range(N)]
|
||||
+ [(reactor.callFromThread, (finished,), {})]
|
||||
)
|
||||
return d
|
||||
|
||||
def test_deferredResult(self):
|
||||
"""
|
||||
L{threads.deferToThread} executes the function passed, and correctly
|
||||
handles the positional and keyword arguments given.
|
||||
"""
|
||||
d = threads.deferToThread(lambda x, y=5: x + y, 3, y=4)
|
||||
d.addCallback(self.assertEqual, 7)
|
||||
return d
|
||||
|
||||
def test_deferredFailure(self):
|
||||
"""
|
||||
Check that L{threads.deferToThread} return a failure object
|
||||
with an appropriate exception instance when the called
|
||||
function raises an exception.
|
||||
"""
|
||||
|
||||
class NewError(Exception):
|
||||
pass
|
||||
|
||||
def raiseError():
|
||||
raise NewError()
|
||||
|
||||
d = threads.deferToThread(raiseError)
|
||||
return self.assertFailure(d, NewError)
|
||||
|
||||
def test_deferredFailureAfterSuccess(self):
|
||||
"""
|
||||
Check that a successful L{threads.deferToThread} followed by a one
|
||||
that raises an exception correctly result as a failure.
|
||||
"""
|
||||
# set up a condition that causes cReactor to hang. These conditions
|
||||
# can also be set by other tests when the full test suite is run in
|
||||
# alphabetical order (test_flow.FlowTest.testThreaded followed by
|
||||
# test_internet.ReactorCoreTestCase.testStop, to be precise). By
|
||||
# setting them up explicitly here, we can reproduce the hang in a
|
||||
# single precise test case instead of depending upon side effects of
|
||||
# other tests.
|
||||
#
|
||||
# alas, this test appears to flunk the default reactor too
|
||||
|
||||
d = threads.deferToThread(lambda: None)
|
||||
d.addCallback(lambda ign: threads.deferToThread(lambda: 1 // 0))
|
||||
return self.assertFailure(d, ZeroDivisionError)
|
||||
|
||||
|
||||
class DeferToThreadPoolTests(TestCase):
|
||||
"""
|
||||
Test L{twisted.internet.threads.deferToThreadPool}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.tp = threadpool.ThreadPool(0, 8)
|
||||
self.tp.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.tp.stop()
|
||||
|
||||
def test_deferredResult(self):
|
||||
"""
|
||||
L{threads.deferToThreadPool} executes the function passed, and
|
||||
correctly handles the positional and keyword arguments given.
|
||||
"""
|
||||
d = threads.deferToThreadPool(reactor, self.tp, lambda x, y=5: x + y, 3, y=4)
|
||||
d.addCallback(self.assertEqual, 7)
|
||||
return d
|
||||
|
||||
def test_deferredFailure(self):
|
||||
"""
|
||||
Check that L{threads.deferToThreadPool} return a failure object with an
|
||||
appropriate exception instance when the called function raises an
|
||||
exception.
|
||||
"""
|
||||
|
||||
class NewError(Exception):
|
||||
pass
|
||||
|
||||
def raiseError():
|
||||
raise NewError()
|
||||
|
||||
d = threads.deferToThreadPool(reactor, self.tp, raiseError)
|
||||
return self.assertFailure(d, NewError)
|
||||
|
||||
|
||||
_callBeforeStartupProgram = """
|
||||
import time
|
||||
import %(reactor)s
|
||||
%(reactor)s.install()
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
def threadedCall():
|
||||
print('threaded call')
|
||||
|
||||
reactor.callInThread(threadedCall)
|
||||
|
||||
# Spin very briefly to try to give the thread a chance to run, if it
|
||||
# is going to. Is there a better way to achieve this behavior?
|
||||
for i in range(100):
|
||||
time.sleep(0.0)
|
||||
"""
|
||||
|
||||
|
||||
class ThreadStartupProcessProtocol(protocol.ProcessProtocol):
|
||||
def __init__(self, finished):
|
||||
self.finished = finished
|
||||
self.out = []
|
||||
self.err = []
|
||||
|
||||
def outReceived(self, out):
|
||||
self.out.append(out)
|
||||
|
||||
def errReceived(self, err):
|
||||
self.err.append(err)
|
||||
|
||||
def processEnded(self, reason):
|
||||
self.finished.callback((self.out, self.err, reason))
|
||||
|
||||
|
||||
@skipIf(
|
||||
not interfaces.IReactorThreads(reactor, None),
|
||||
"No thread support, nothing to test here.",
|
||||
)
|
||||
@skipIf(
|
||||
not interfaces.IReactorProcess(reactor, None),
|
||||
"No process support, cannot run subprocess thread tests.",
|
||||
)
|
||||
class StartupBehaviorTests(TestCase):
|
||||
"""
|
||||
Test cases for the behavior of the reactor threadpool near startup
|
||||
boundary conditions.
|
||||
|
||||
In particular, this asserts that no threaded calls are attempted
|
||||
until the reactor starts up, that calls attempted before it starts
|
||||
are in fact executed once it has started, and that in both cases,
|
||||
the reactor properly cleans itself up (which is tested for
|
||||
somewhat implicitly, by requiring a child process be able to exit,
|
||||
something it cannot do unless the threadpool has been properly
|
||||
torn down).
|
||||
"""
|
||||
|
||||
def testCallBeforeStartupUnexecuted(self):
|
||||
progname = self.mktemp()
|
||||
with open(progname, "w") as progfile:
|
||||
progfile.write(_callBeforeStartupProgram % {"reactor": reactor.__module__})
|
||||
|
||||
def programFinished(result):
|
||||
(out, err, reason) = result
|
||||
if reason.check(error.ProcessTerminated):
|
||||
self.fail(f"Process did not exit cleanly (out: {out} err: {err})")
|
||||
|
||||
if err:
|
||||
log.msg(f"Unexpected output on standard error: {err}")
|
||||
self.assertFalse(out, f"Expected no output, instead received:\n{out}")
|
||||
|
||||
def programTimeout(err):
|
||||
err.trap(error.TimeoutError)
|
||||
proto.signalProcess("KILL")
|
||||
return err
|
||||
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = os.pathsep.join(sys.path)
|
||||
d = defer.Deferred().addCallbacks(programFinished, programTimeout)
|
||||
proto = ThreadStartupProcessProtocol(d)
|
||||
reactor.spawnProcess(proto, sys.executable, ("python", progname), env)
|
||||
return d
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from twisted.internet import abstract, defer, protocol
|
||||
from twisted.protocols import basic, loopback
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class BufferingServer(protocol.Protocol):
|
||||
buffer = b""
|
||||
|
||||
def dataReceived(self, data: bytes) -> None:
|
||||
self.buffer += data
|
||||
|
||||
|
||||
class FileSendingClient(protocol.Protocol):
|
||||
def __init__(self, f: BytesIO) -> None:
|
||||
self.f = f
|
||||
|
||||
def connectionMade(self) -> None:
|
||||
assert self.transport is not None
|
||||
s = basic.FileSender()
|
||||
d = s.beginFileTransfer(self.f, self.transport, lambda x: x)
|
||||
d.addCallback(lambda r: self.transport.loseConnection())
|
||||
|
||||
|
||||
class FileSenderTests(unittest.TestCase):
|
||||
def testSendingFile(self) -> defer.Deferred[None]:
|
||||
testStr = b"xyz" * 100 + b"abc" * 100 + b"123" * 100
|
||||
s = BufferingServer()
|
||||
c = FileSendingClient(BytesIO(testStr))
|
||||
|
||||
d: defer.Deferred[None] = loopback.loopbackTCP(s, c)
|
||||
|
||||
def callback(x: object) -> None:
|
||||
self.assertEqual(s.buffer, testStr)
|
||||
|
||||
return d.addCallback(callback)
|
||||
|
||||
def testSendingEmptyFile(self) -> None:
|
||||
fileSender = basic.FileSender()
|
||||
consumer = abstract.FileDescriptor()
|
||||
consumer.connected = 1
|
||||
emptyFile = BytesIO(b"")
|
||||
|
||||
d = fileSender.beginFileTransfer(emptyFile, consumer, lambda x: x)
|
||||
|
||||
# The producer will be immediately exhausted, and so immediately
|
||||
# unregistered
|
||||
self.assertIsNone(consumer.producer)
|
||||
|
||||
# Which means the Deferred from FileSender should have been called
|
||||
self.assertTrue(d.called, "producer unregistered with deferred being called")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user