okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 01c6004a79
commit 27eb239e97
10020 changed files with 1935769 additions and 2364 deletions

View File

@@ -0,0 +1,19 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted's unit tests.
"""
from twisted.python.deprecate import deprecatedModuleAttribute
from twisted.python.versions import Version
from twisted.test import proto_helpers
# Every helper re-exported from proto_helpers is deprecated in favour of its
# counterpart in twisted.internet.testing.
_version = Version("Twisted", 19, 7, 0)
for _name in proto_helpers.__all__:
    deprecatedModuleAttribute(
        _version,
        f"Please use twisted.internet.testing.{_name} instead.",
        "twisted.test.proto_helpers",
        _name,
    )

View File

@@ -0,0 +1,25 @@
-----BEGIN CERTIFICATE-----
[AWS-SECRET-REMOVED]A9HEyFEwDQYJKoZIhvcNAQEL
[AWS-SECRET-REMOVED]eTELMAkGA1UEBhMCVFIxDzAN
[AWS-SECRET-REMOVED]a8OnxLExHDAaBgNVBAoME1R3
[AWS-SECRET-REMOVED]dG9tYXRlZCBUZXN0aW5nIEF1
[AWS-SECRET-REMOVED]dHlAdHdpc3RlZG1hdHJpeC5j
[AWS-SECRET-REMOVED]MzQwMjhaMIG9MRgwFgYDVQQD
[AWS-SECRET-REMOVED]MQ8wDQYDVQQIDAbDh29ydW0x
[AWS-SECRET-REMOVED]DBNUd2lzdGVkIE1hdHJpeCBM
[AWS-SECRET-REMOVED]ZyBBdXRob3JpdHkxKTAnBgkq
[AWS-SECRET-REMOVED]aXguY29tMIIBIjANBgkqhkiG
[AWS-SECRET-REMOVED]7qXms9PZWHskXZGXLPiYVmiY
[AWS-SECRET-REMOVED]Ch4liyxdWkBLw9maxMoE+r6d
[AWS-SECRET-REMOVED]D4GvTby6xpoR09AqrfjuEIYR
[AWS-SECRET-REMOVED]55t7UW6Ebj2X2WTO6Zh7gJ1d
[AWS-SECRET-REMOVED]agE3evUv/BECJLONNYLaFjLt
[AWS-SECRET-REMOVED]lHiQnXdyrwIDAQABoxgwFjAU
[AWS-SECRET-REMOVED]AQELBQADggEBAEHAErq/Fs8h
[AWS-SECRET-REMOVED]a/u+ajoxrZaOheg8E2MYVwQi
[AWS-SECRET-REMOVED]iQIzXON2RvgJpwFfkLNtq0t9
[AWS-SECRET-REMOVED]IUcO0tU8O4kWrLIFPpJbcHQq
[AWS-SECRET-REMOVED]93OZJgwE2x3iUd3k8HbwxfoY
[AWS-SECRET-REMOVED]zYlxFBoDyalR7NJjJGdTwNFt
3CPGCQ28cDk=
-----END CERTIFICATE-----

View File

@@ -0,0 +1,40 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import annotations
from typing import Literal
from zope.interface import Interface, implementer
from twisted.python import components
def foo() -> Literal[2]:
    """Return the constant ``2`` (exercises a ``Literal`` return annotation)."""
    result: Literal[2] = 2
    return result
class X:
    """Trivial adaptee holding a single value in ``self.x``."""

    def __init__(self, x: str) -> None:
        self.x = x

    def do(self) -> None:
        # Formerly printed a debug message; intentionally a no-op now.
        pass
class XComponent(components.Componentized):
    # Componentized subclass used to exercise adapter lookup in the tests.
    pass
class IX(Interface):
    # Marker interface; XA below is registered as the adapter from X to IX.
    pass
@implementer(IX)
class XA(components.Adapter):
    """Adapter providing L{IX} for L{X} instances."""

    def method(self) -> None:
        # Deliberately empty; exists only so the adapter has a method.
        pass
# Register XA as the adapter that converts X instances into IX providers.
components.registerAdapter(XA, X, IX)

View File

@@ -0,0 +1,577 @@
# -*- test-case-name: twisted.test.test_amp,twisted.test.test_iosim -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Utilities and helpers for simulating a network
"""
import itertools
try:
from OpenSSL.SSL import Error as NativeOpenSSLError
except ImportError:
pass
from zope.interface import directlyProvides, implementer
from twisted.internet import error, interfaces
from twisted.internet.endpoints import TCP4ClientEndpoint, TCP4ServerEndpoint
from twisted.internet.error import ConnectionRefusedError
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.testing import MemoryReactorClock
from twisted.python.failure import Failure
class TLSNegotiation:
    """
    In-memory stand-in for a TLS handshake message exchanged between two
    L{FakeTransport}s.
    """

    def __init__(self, obj, connectState):
        self.obj = obj
        self.connectState = self.readyToSend = connectState
        self.sent = False

    def __repr__(self) -> str:
        return "TLSNegotiation({!r})".format(self.obj)

    def pretendToVerify(self, other, tpt):
        # Simulated certificate verification: when the two context objects do
        # not verify against each other, fail the transport with an OpenSSL
        # error, the way a real handshake failure would.
        if self.obj.iosimVerify(other.obj):
            return
        tpt.disconnectReason = NativeOpenSSLError()
        tpt.loseConnection()
@implementer(interfaces.IAddress)
class FakeAddress:
    """
    Placeholder address object handed out by L{FakeTransport.getHost} and
    L{FakeTransport.getPeer} when no explicit address was supplied.
    """
@implementer(interfaces.ITransport, interfaces.ITLSTransport)
class FakeTransport:
    """
    A wrapper around a file-like object to make it behave as a Transport.
    This doesn't actually stream the file to the attached protocol,
    and is thus useful mainly as a utility for debugging protocols.

    Outgoing bytes are buffered in C{self.stream} (or C{self.tlsbuf} while a
    simulated TLS handshake is in flight) and drained by C{getOutBuffer};
    L{IOPump} moves those buffers between paired transports.
    """
    # Unique per-instance serial numbers: the default-argument counter is
    # created once when the class body executes, so all instances share one
    # monotonically increasing sequence.
    _nextserial = staticmethod(lambda counter=itertools.count(): int(next(counter)))
    # Class-level connection-state defaults, shadowed per instance as the
    # connection progresses.
    closed = 0
    disconnecting = 0
    disconnected = 0
    disconnectReason = error.ConnectionDone("Connection done")
    producer = None
    streamingProducer = 0
    # A TLSNegotiation while a simulated handshake is pending; None otherwise.
    tls = None
    def __init__(self, protocol, isServer, hostAddress=None, peerAddress=None):
        """
        @param protocol: This transport will deliver bytes to this protocol.
        @type protocol: L{IProtocol} provider
        @param isServer: C{True} if this is the accepting side of the
            connection, C{False} if it is the connecting side.
        @type isServer: L{bool}
        @param hostAddress: The value to return from C{getHost}. L{None}
            results in a new L{FakeAddress} being created to use as the value.
        @type hostAddress: L{IAddress} provider or L{None}
        @param peerAddress: The value to return from C{getPeer}. L{None}
            results in a new L{FakeAddress} being created to use as the value.
        @type peerAddress: L{IAddress} provider or L{None}
        """
        self.protocol = protocol
        self.isServer = isServer
        # Pending outgoing writes, drained by getOutBuffer().
        self.stream = []
        self.serial = self._nextserial()
        if hostAddress is None:
            hostAddress = FakeAddress()
        self.hostAddress = hostAddress
        if peerAddress is None:
            peerAddress = FakeAddress()
        self.peerAddress = peerAddress
    def __repr__(self) -> str:
        # e.g. FakeTransport<S,3,AMP> -- side, serial, protocol class name.
        return "FakeTransport<{},{},{}>".format(
            self.isServer and "S" or "C",
            self.serial,
            self.protocol.__class__.__name__,
        )
    def write(self, data):
        """
        Buffer C{data} for later delivery, or silently drop it once the
        transport is disconnecting.
        """
        # If transport is closed, we should accept writes but drop the data.
        if self.disconnecting:
            return
        if self.tls is not None:
            # Mid-handshake: hold application data until TLS "completes".
            self.tlsbuf.append(data)
        else:
            self.stream.append(data)
    def _checkProducer(self):
        # Cheating; this is called at "idle" times to allow producers to be
        # found and dealt with
        if self.producer and not self.streamingProducer:
            self.producer.resumeProducing()
    def registerProducer(self, producer, streaming):
        """
        From abstract.FileDescriptor
        """
        self.producer = producer
        self.streamingProducer = streaming
        # Pull producers get one immediate resume, mirroring the real
        # transport behaviour.
        if not streaming:
            producer.resumeProducing()
    def unregisterProducer(self):
        # Forget the current producer; no notification is sent.
        self.producer = None
    def stopConsuming(self):
        # IConsumer shutdown: drop the producer and close the connection.
        self.unregisterProducer()
        self.loseConnection()
    def writeSequence(self, iovec):
        # Coalesce the iovec into a single write.
        self.write(b"".join(iovec))
    def loseConnection(self):
        # Mark for disconnection; IOPump.pump notices and reports it.
        self.disconnecting = True
    def abortConnection(self):
        """
        For the time being, this is the same as loseConnection; no buffered
        data will be lost.
        """
        self.disconnecting = True
    def reportDisconnect(self):
        # Deliver connectionLost to the protocol with an appropriate reason.
        if self.tls is not None:
            # We were in the middle of negotiating!  Must have been a TLS
            # problem.
            err = NativeOpenSSLError()
        else:
            err = self.disconnectReason
        self.protocol.connectionLost(Failure(err))
    def logPrefix(self):
        """
        Identify this transport/event source to the logging system.
        """
        return "iosim"
    def getPeer(self):
        return self.peerAddress
    def getHost(self):
        return self.hostAddress
    def resumeProducing(self):
        # Never sends data anyways
        pass
    def pauseProducing(self):
        # Never sends data anyways
        pass
    def stopProducing(self):
        self.loseConnection()
    def startTLS(self, contextFactory, beNormal=True):
        # Nothing's using this feature yet, but startTLS has an undocumented
        # second argument which defaults to true; if set to False, servers will
        # behave like clients and clients will behave like servers.
        connectState = self.isServer ^ beNormal
        self.tls = TLSNegotiation(contextFactory, connectState)
        # From here on, write() buffers into tlsbuf until the simulated
        # handshake finishes in bufferReceived().
        self.tlsbuf = []
    def getOutBuffer(self):
        """
        Get the pending writes from this transport, clearing them from the
        pending buffer.
        @return: the bytes written with C{transport.write}, or the pending
            L{TLSNegotiation} marker when a handshake is ready to be sent,
            or L{None} when there is nothing to deliver.
        @rtype: L{bytes}
        """
        S = self.stream
        if S:
            self.stream = []
            return b"".join(S)
        elif self.tls is not None:
            if self.tls.readyToSend:
                # Only _send_ the TLS negotiation "packet" if I'm ready to.
                self.tls.sent = True
                return self.tls
            else:
                return None
        else:
            return None
    def bufferReceived(self, buf):
        # Accept either application bytes or a TLSNegotiation marker from the
        # peer transport.
        if isinstance(buf, TLSNegotiation):
            assert self.tls is not None  # By the time you're receiving a
            # negotiation, you have to have called
            # startTLS already.
            if self.tls.sent:
                # Both sides have sent their handshake: "verify" and, on
                # success, flush the data buffered during negotiation.
                self.tls.pretendToVerify(buf, self)
                self.tls = None  # We're done with the handshake if we've gotten
                # this far... although maybe it failed...?
                # TLS started!  Unbuffer...
                b, self.tlsbuf = self.tlsbuf, None
                self.writeSequence(b)
                directlyProvides(self, interfaces.ISSLTransport)
            else:
                # We haven't sent our own TLS negotiation: time to do that!
                self.tls.readyToSend = True
        else:
            self.protocol.dataReceived(buf)
    def getTcpKeepAlive(self):
        # ITCPTransport.getTcpKeepAlive
        pass
    def getTcpNoDelay(self):
        # ITCPTransport.getTcpNoDelay
        pass
    def loseWriteConnection(self):
        # ITCPTransport.loseWriteConnection
        pass
    def setTcpKeepAlive(self, enabled):
        # ITCPTransport.setTcpKeepAlive
        pass
    def setTcpNoDelay(self, enabled):
        # ITCPTransport.setTcpNoDelay
        pass
def makeFakeClient(clientProtocol):
    """
    Wrap C{clientProtocol} in a new in-memory transport configured as the
    connecting (client) side.
    @param clientProtocol: The client protocol to use.
    @type clientProtocol: L{IProtocol} provider
    @return: The transport.
    @rtype: L{FakeTransport}
    """
    transport = FakeTransport(clientProtocol, isServer=False)
    return transport
def makeFakeServer(serverProtocol):
    """
    Wrap C{serverProtocol} in a new in-memory transport configured as the
    accepting (server) side.
    @param serverProtocol: The server protocol to use.
    @type serverProtocol: L{IProtocol} provider
    @return: The transport.
    @rtype: L{FakeTransport}
    """
    transport = FakeTransport(serverProtocol, isServer=True)
    return transport
class IOPump:
    """
    Utility to pump data between clients and servers for protocol testing.
    Perhaps this is a utility worthy of being in protocol.py?
    """
    def __init__(self, client, server, clientIO, serverIO, debug, clock=None):
        """
        @param client: the client-side protocol.
        @param server: the server-side protocol.
        @param clientIO: the client's L{FakeTransport}.
        @param serverIO: the server's L{FakeTransport}.
        @param debug: if true, dump traffic to stdout while pumping.
        @param clock: the clock advanced by each pump; a fresh
            L{MemoryReactorClock} is created when not supplied.
        """
        self.client = client
        self.server = server
        self.clientIO = clientIO
        self.serverIO = serverIO
        self.debug = debug
        if clock is None:
            clock = MemoryReactorClock()
        self.clock = clock
    def flush(self, debug=False, advanceClock=True):
        """
        Pump until there is no more input or output.
        Returns whether any data was moved.
        """
        result = False
        # Cap the number of pumps so a protocol pair that keeps generating
        # traffic forever fails loudly instead of hanging the test run.
        for _ in range(1000):
            if self.pump(debug, advanceClock):
                result = True
            else:
                break
        else:
            # for/else: only reached when the loop above never broke, i.e.
            # 1000 consecutive pumps all moved data.
            assert 0, "Too long"
        return result
    def pump(self, debug=False, advanceClock=True):
        """
        Move data back and forth, while also triggering any currently pending
        scheduled calls (i.e. C{callLater(0, f)}).
        Returns whether any data was moved.
        """
        if advanceClock:
            self.clock.advance(0)
        if self.debug or debug:
            print("-- GLUG --")
        # Drain both outgoing buffers first, then give pull producers a
        # chance to refill them for the next pump.
        sData = self.serverIO.getOutBuffer()
        cData = self.clientIO.getOutBuffer()
        self.clientIO._checkProducer()
        self.serverIO._checkProducer()
        if self.debug or debug:
            print(".")
            # XXX slightly buggy in the face of incremental output
            if cData:
                print("C: " + repr(cData))
            if sData:
                print("S: " + repr(sData))
        # Cross-deliver: client output goes to the server transport and
        # vice versa.
        if cData:
            self.serverIO.bufferReceived(cData)
        if sData:
            self.clientIO.bufferReceived(sData)
        if cData or sData:
            return True
        # No data moved: propagate at most one pending disconnection, marking
        # both sides and reporting connectionLost to the opposite protocol.
        if self.serverIO.disconnecting and not self.serverIO.disconnected:
            if self.debug or debug:
                print("* C")
            self.serverIO.disconnected = True
            self.clientIO.disconnecting = True
            self.clientIO.reportDisconnect()
            return True
        if self.clientIO.disconnecting and not self.clientIO.disconnected:
            if self.debug or debug:
                print("* S")
            self.clientIO.disconnected = True
            self.serverIO.disconnecting = True
            self.serverIO.reportDisconnect()
            return True
        return False
def connect(
    serverProtocol,
    serverTransport,
    clientProtocol,
    clientTransport,
    debug=False,
    greet=True,
    clock=None,
):
    """
    Create a new L{IOPump} connecting two protocols.
    @param serverProtocol: the accepting-side L{IProtocol} provider.
    @param serverTransport: the L{FakeTransport} to attach to it.
    @param clientProtocol: the initiating-side L{IProtocol} provider.
    @param clientTransport: the L{FakeTransport} to attach to it.
    @param debug: if true, log what the L{IOPump} is doing.
    @type debug: L{bool}
    @param greet: flush once before returning so server greetings and
        handshakes have already been delivered?
    @type greet: L{bool}
    @param clock: An optional L{Clock}; pumping the resulting L{IOPump} will
        also advance it by a small increment.
    @return: an L{IOPump} which delivers bytes between C{serverProtocol} and
        C{clientProtocol} when it is pumped.
    @rtype: L{IOPump}
    """
    # Attach the transports -- server side first, matching what a real
    # reactor does when an accepted connection precedes the client's.
    for proto, transport in (
        (serverProtocol, serverTransport),
        (clientProtocol, clientTransport),
    ):
        proto.makeConnection(transport)
    pump = IOPump(
        clientProtocol,
        serverProtocol,
        clientTransport,
        serverTransport,
        debug,
        clock=clock,
    )
    if greet:
        # Kick off server greeting, etc
        pump.flush()
    return pump
def connectedServerAndClient(
    ServerClass,
    ClientClass,
    clientTransportFactory=makeFakeClient,
    serverTransportFactory=makeFakeServer,
    debug=False,
    greet=True,
    clock=None,
):
    """
    Instantiate C{ServerClass} and C{ClientClass}, attach each to an
    in-memory transport made by the corresponding factory, and join the two
    with an L{IOPump}.
    @param ServerClass: 0-argument callable returning the server-side
        L{IProtocol} provider.
    @param ClientClass: 0-argument callable returning the client-side
        L{IProtocol} provider.
    @param clientTransportFactory: callable taking the client protocol and
        returning its L{FakeTransport}.
    @param serverTransportFactory: callable taking the server protocol and
        returning its L{FakeTransport}.
    @param debug: dump an escaped version of all traffic to stdout?
    @type debug: L{bool}
    @param greet: flush once before returning so greetings and handshakes
        have been delivered?
    @type greet: L{bool}
    @param clock: An optional L{Clock}; pumping the resulting L{IOPump} will
        also advance it by a small increment.
    @return: the client protocol, the server protocol, and an L{IOPump}
        which moves data between them when pumped.
    @rtype: 3-L{tuple} of L{IProtocol}, L{IProtocol}, L{IOPump}
    """
    clientProtocol = ClientClass()
    serverProtocol = ServerClass()
    clientIO = clientTransportFactory(clientProtocol)
    serverIO = serverTransportFactory(serverProtocol)
    pump = connect(
        serverProtocol, serverIO, clientProtocol, clientIO, debug, greet, clock=clock
    )
    return clientProtocol, serverProtocol, pump
def _factoriesShouldConnect(clientInfo, serverInfo):
"""
Should the client and server described by the arguments be connected to
each other, i.e. do their port numbers match?
@param clientInfo: the args for connectTCP
@type clientInfo: L{tuple}
@param serverInfo: the args for listenTCP
@type serverInfo: L{tuple}
@return: If they do match, return factories for the client and server that
should connect; otherwise return L{None}, indicating they shouldn't be
connected.
@rtype: L{None} or 2-L{tuple} of (L{ClientFactory},
L{IProtocolFactory})
"""
(
clientHost,
clientPort,
clientFactory,
clientTimeout,
clientBindAddress,
) = clientInfo
(serverPort, serverFactory, serverBacklog, serverInterface) = serverInfo
if serverPort == clientPort:
return clientFactory, serverFactory
else:
return None
class ConnectionCompleter:
    """
    A L{ConnectionCompleter} can cause synthetic TCP connections established by
    L{MemoryReactor.connectTCP} and L{MemoryReactor.listenTCP} to succeed or
    fail.
    """
    def __init__(self, memoryReactor):
        """
        Create a L{ConnectionCompleter} from a L{MemoryReactor}.
        @param memoryReactor: The reactor to attach to.
        @type memoryReactor: L{MemoryReactor}
        """
        self._reactor = memoryReactor
    def succeedOnce(self, debug=False):
        """
        Complete a single TCP connection established on this
        L{ConnectionCompleter}'s L{MemoryReactor}.
        @param debug: A flag; whether to dump output from the established
            connection to stdout.
        @type debug: L{bool}
        @return: a pump for the connection, or L{None} if no connection could
            be established.
        @rtype: L{IOPump} or L{None}
        """
        memoryReactor = self._reactor
        for clientIdx, clientInfo in enumerate(memoryReactor.tcpClients):
            for serverInfo in memoryReactor.tcpServers:
                factories = _factoriesShouldConnect(clientInfo, serverInfo)
                if factories:
                    # Consume this pending client attempt and its connector so
                    # a subsequent call completes the *next* connection.
                    memoryReactor.tcpClients.remove(clientInfo)
                    memoryReactor.connectors.pop(clientIdx)
                    clientFactory, serverFactory = factories
                    clientProtocol = clientFactory.buildProtocol(None)
                    serverProtocol = serverFactory.buildProtocol(None)
                    serverTransport = makeFakeServer(serverProtocol)
                    clientTransport = makeFakeClient(clientProtocol)
                    return connect(
                        serverProtocol,
                        serverTransport,
                        clientProtocol,
                        clientTransport,
                        debug,
                    )
        # No matching client/server port pair found: implicitly return None.
    def failOnce(self, reason=None):
        """
        Fail a single TCP connection established on this
        L{ConnectionCompleter}'s L{MemoryReactor}.
        @param reason: the reason to provide that the connection failed.  A
            fresh L{Failure} wrapping L{ConnectionRefusedError} is created
            per call when not given.  (Previously the default was a single
            L{Failure} instance constructed at import time and shared by
            every call -- a mutable-default bug.)
        @type reason: L{Failure} or L{None}
        """
        if reason is None:
            reason = Failure(ConnectionRefusedError())
        # Pop the oldest pending client attempt; element [2] of the
        # connectTCP record is the client factory.
        self._reactor.tcpClients.pop(0)[2].clientConnectionFailed(
            self._reactor.connectors.pop(0), reason
        )
def connectableEndpoint(debug=False):
    """
    Create a client endpoint whose connection attempts can be completed on
    demand via a L{ConnectionCompleter}.
    @param debug: A flag; whether to dump output from the established
        connection to stdout.  (Currently unused; accepted for signature
        compatibility.)
    @type debug: L{bool}
    @return: A client endpoint, and an object that will cause one of the
        L{Deferred}s returned by that client endpoint.
    @rtype: 2-L{tuple} of (L{IStreamClientEndpoint}, L{ConnectionCompleter})
    """
    reactor = MemoryReactorClock()
    listeningEndpoint = TCP4ServerEndpoint(reactor, 4321)
    listeningEndpoint.listen(Factory.forProtocol(Protocol))
    connectingEndpoint = TCP4ClientEndpoint(reactor, "0.0.0.0", 4321)
    return connectingEndpoint, ConnectionCompleter(reactor)

View File

@@ -0,0 +1,27 @@
[RSA-PRIVATE-KEY-REMOVED]
[AWS-SECRET-REMOVED]XZGXLPiYVmiYjsVeJAOtHAYq
[AWS-SECRET-REMOVED]w9maxMoE+r6dW1zZ8Tllunbd
[AWS-SECRET-REMOVED]09AqrfjuEIYR8V/y+8dG3mR5
[AWS-SECRET-REMOVED]2WTO6Zh7gJ1dyHPMVkUHJF9J
[AWS-SECRET-REMOVED]JLONNYLaFjLtWnsCEJDV9owC
[AWS-SECRET-REMOVED]AQABAoH/Ib7aSjKDHXTaFV58
[AWS-SECRET-REMOVED]Rn5wphAt/mlXbx7IW0X1cali
[AWS-SECRET-REMOVED]J9qiUUGDyCnGKWbofN9VpCYg
[AWS-SECRET-REMOVED]cCLVrhVrHzw1HFTIlA51LjfI
[AWS-SECRET-REMOVED]rf7URSudS+Us6vr6gDVpKAky
[AWS-SECRET-REMOVED]sK/ELQfhPoyHyRvL1woUIO5C
[AWS-SECRET-REMOVED]3N0jWm8R8ENOnuIjhCl5aKsB
[AWS-SECRET-REMOVED]/k8zoqgcj9CmmDofBka4XxZb
[AWS-SECRET-REMOVED]UME2t2EaryUzAoGBANxpb4Jz
[AWS-SECRET-REMOVED]01KTaKyYpq+9q9VxXhWxYsh3
[AWS-SECRET-REMOVED]QxDEfKeFedskxORs+FIUzqBb
[AWS-SECRET-REMOVED]WNFhLoGPYgsTcnrk0N1QLmnZ
[AWS-SECRET-REMOVED]fbAbaU3gDy0K24z+YeNbWCjI
[AWS-SECRET-REMOVED]kmcCuEyBqlSinLslRd/997Bx
[AWS-SECRET-REMOVED]5k94Qjn4wBf7WwmgfDm6HHbs
[AWS-SECRET-REMOVED]Eiwcq4r2aBSNsI305Z5sUWtn
[AWS-SECRET-REMOVED]Yg/jP5EyqSiXtUZfSodL7yeH
[AWS-SECRET-REMOVED]iZ8P86FnWBf1iDeuywEZJqvG
[AWS-SECRET-REMOVED]vrjk0ONnXX7VXNgJ3/e7aJTx
[AWS-SECRET-REMOVED]jVuh1lyd4C4=
-----END RSA PRIVATE KEY-----

View File

@@ -0,0 +1,52 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This is a mock win32process module.
The purpose of this module is mock process creation for the PID test.
CreateProcess(...) will spawn a process, and always return a PID of 42.
"""
import win32process
# Re-export the pieces of win32process that the tests use unchanged; only
# CreateProcess below is mocked.
GetExitCodeProcess = win32process.GetExitCodeProcess
STARTUPINFO = win32process.STARTUPINFO
STARTF_USESTDHANDLES = win32process.STARTF_USESTDHANDLES
def CreateProcess(
    appName,
    cmdline,
    procSecurity,
    threadSecurity,
    inheritHandles,
    newEnvironment,
    env,
    workingDir,
    startupInfo,
):
    """
    Proxy C{win32process.CreateProcess} but report a fixed PID.

    The real CreateProcess is invoked with every argument unchanged, and all
    of its return values are passed back except the process id, which is
    replaced with the hardcoded value 42 so the PID test has a predictable
    result.
    """
    hProcess, hThread, _realPid, dwTid = win32process.CreateProcess(
        appName,
        cmdline,
        procSecurity,
        threadSecurity,
        inheritHandles,
        newEnvironment,
        env,
        workingDir,
        startupInfo,
    )
    # Discard the real pid and substitute the sentinel value.
    return (hProcess, hThread, 42, dwTid)

View File

@@ -0,0 +1,13 @@
class A:
    """Base fixture class; C{a()} identifies it with the string C{"a"}."""

    def a(self) -> str:
        return "a"
class B(A):
    """Subclass of L{A} adding its own C{b()} method."""

    def b(self) -> str:
        return "b"
class Inherit(A):
    """Subclass of L{A} that overrides C{a()} with a different value."""

    def a(self) -> str:
        return "c"

View File

@@ -0,0 +1,13 @@
class A:
    """Base fixture class; in this variant C{a()} returns C{"b"}."""

    def a(self) -> str:
        return "b"
class B(A):
    """Subclass of L{A} adding its own C{b()} method."""

    def b(self) -> str:
        return "c"
class Inherit(A):
    """Subclass of L{A} that overrides C{a()} with a different value."""

    def a(self) -> str:
        return "d"

View File

@@ -0,0 +1,47 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# Don't change the docstring, it's part of the tests
"""
I'm a test drop-in. The plugin system's unit tests use me. No one
else should.
"""
from zope.interface import provider
from twisted.plugin import IPlugin
from twisted.test.test_plugin import ITestPlugin, ITestPlugin2
@provider(ITestPlugin, IPlugin)
class TestPlugin:
    """
    A plugin used solely for testing purposes.
    """
    # NOTE(review): per the module header, docstrings here are part of the
    # tests -- leave them unchanged.
    @staticmethod
    def test1() -> None:
        # Deliberately a no-op; only the attribute's existence matters.
        pass
@provider(ITestPlugin2, IPlugin)
class AnotherTestPlugin:
    """
    Another plugin used solely for testing purposes.
    """
    @staticmethod
    def test() -> None:
        # Deliberately a no-op; only the attribute's existence matters.
        pass
@provider(ITestPlugin2, IPlugin)
class ThirdTestPlugin:
    """
    Another plugin used solely for testing purposes.
    """
    @staticmethod
    def test() -> None:
        # Deliberately a no-op; only the attribute's existence matters.
        pass

View File

@@ -0,0 +1,19 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test plugin used in L{twisted.test.test_plugin}.
"""
from zope.interface import provider
from twisted.plugin import IPlugin
from twisted.test.test_plugin import ITestPlugin
@provider(ITestPlugin, IPlugin)
class FourthTestPlugin:
    # Minimal ITestPlugin provider, discovered dynamically by the plugin
    # tests after this module is dropped into a plugins package.
    @staticmethod
    def test1() -> None:
        # Deliberately a no-op; only the attribute's existence matters.
        pass

View File

@@ -0,0 +1,30 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test plugin used in L{twisted.test.test_plugin}.
"""
from zope.interface import provider
from twisted.plugin import IPlugin
from twisted.test.test_plugin import ITestPlugin
@provider(ITestPlugin, IPlugin)
class FourthTestPlugin:
    # Same name as the plugin in the sibling drop-in module; the tests use
    # this to exercise re-discovery of an updated plugin file.
    @staticmethod
    def test1() -> None:
        # Deliberately a no-op; only the attribute's existence matters.
        pass
@provider(ITestPlugin, IPlugin)
class FifthTestPlugin:
    """
    More documentation: I hate you.
    """
    # NOTE(review): docstring left untouched; it may be part of the tests.
    @staticmethod
    def test1() -> None:
        # Deliberately a no-op; only the attribute's existence matters.
        pass

View File

@@ -0,0 +1,9 @@
"""
Write to stdout the command line args it received, one per line.
"""
import sys
# Echo each command-line argument (skipping the script name) to stdout,
# one per line.
for argument in sys.argv[1:]:
    sys.stdout.write(argument + "\n")

View File

@@ -0,0 +1,11 @@
"""Write back all data it receives."""
import sys
# Pump stdin to stdout one character at a time, flushing after every write,
# then emit a farewell marker on stderr.
for data in iter(lambda: sys.stdin.read(1), ""):
    sys.stdout.write(data)
    sys.stdout.flush()
sys.stderr.write("byebye")
sys.stderr.flush()

View File

@@ -0,0 +1,50 @@
"""Write to a handful of file descriptors, to test the childFDs= argument of
reactor.spawnProcess()
"""
import os
import sys
if __name__ == "__main__":
    # Child process for the childFDs= tests: it reads and writes on specific
    # file descriptors (0, 1, 3, 4, 5) and exits with a distinct non-zero
    # status for each possible failure so the parent can tell which step
    # went wrong.
    debug = 0
    if debug:
        stderr = os.fdopen(2, "w")
    if debug:
        print("this is stderr", file=stderr)
    abcd = os.read(0, 4)
    if debug:
        print("read(0):", abcd, file=stderr)
    if abcd != b"abcd":
        sys.exit(1)
    if debug:
        print("os.write(1, righto)", file=stderr)
    os.write(1, b"righto")
    efgh = os.read(3, 4)
    if debug:
        print("read(3):", file=stderr)
    if efgh != b"efgh":
        sys.exit(2)
    if debug:
        print("os.close(4)", file=stderr)
    os.close(4)
    # fd 5 should already be at EOF: an empty read means the parent closed it.
    eof = os.read(5, 4)
    if debug:
        print("read(5):", eof, file=stderr)
    if eof != b"":
        sys.exit(3)
    if debug:
        print("os.write(1, closed)", file=stderr)
    os.write(1, b"closed")
    if debug:
        print("sys.exit(0)", file=stderr)
    sys.exit(0)

View File

@@ -0,0 +1,13 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Used by L{twisted.test.test_process}.
"""
from sys import argv, stdout
if __name__ == "__main__":
    # Emit every argv element (including the script name) joined by NUL
    # bytes so the parent process can split the values unambiguously.
    stdout.write("\0".join(argv))
    stdout.flush()

View File

@@ -0,0 +1,13 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Used by L{twisted.test.test_process}.
"""
from os import environ
from sys import stdout
# Serialize the whole environment as NUL-separated key/value tokens
# (key NUL value, pairs themselves joined by NUL) and write it to stdout.
items = environ.items()
stdout.write(chr(0).join([k + chr(0) + v for k, v in items]))
stdout.flush()

View File

@@ -0,0 +1,17 @@
"""Write to a file descriptor and then close it, waiting a few seconds before
quitting. This serves to make sure SIGCHLD is actually being noticed.
"""
import os
import sys
import time
# Emit one line, linger, emit a second line, then close both stdio
# descriptors and stay alive a little longer so the parent can confirm
# SIGCHLD is noticed independently of the pipes closing.
print("here is some text")
time.sleep(1)
print("goodbye")
for descriptor in (1, 2):
    os.close(descriptor)
time.sleep(2)
sys.exit(0)

View File

@@ -0,0 +1,10 @@
"""Script used by test_process.TestTwoProcesses"""
# run until stdin is closed, then quit
import sys
while True:
    # Block on stdin; an empty read means EOF, so exit cleanly.
    if not sys.stdin.read():
        sys.exit(0)

View File

@@ -0,0 +1,9 @@
import signal
import sys
# Restore default handling for SIGINT (and SIGHUP where the platform defines
# it) so signals delivered by the parent test terminate us with the default
# action rather than being intercepted.
signal.signal(signal.SIGINT, signal.SIG_DFL)
if getattr(signal, "SIGHUP", None) is not None:
    signal.signal(signal.SIGHUP, signal.SIG_DFL)
# Tell the parent we are ready, then block on stdin until killed; exiting
# normally (status 1) means no signal arrived.
print("ok, signal us")
sys.stdin.read()
sys.exit(1)

View File

@@ -0,0 +1,34 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Script used by twisted.test.test_process on win32.
"""
import msvcrt
import os
import sys
# Switch stdout/stderr to binary mode so the Windows C runtime does not
# translate b"\n" into b"\r\n".
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)  # type:ignore[attr-defined]
msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY)  # type:ignore[attr-defined]
# We want to write bytes directly to the output, not text, because otherwise
# newlines get mangled. Get the underlying buffer.
stdout = sys.stdout.buffer
stderr = sys.stderr.buffer
stdin = sys.stdin.buffer
# Announce on both streams, echo all of stdin back to stdout, then repeat
# the markers and flush everything before exiting.
stdout.write(b"out\n")
stdout.flush()
stderr.write(b"err\n")
stderr.flush()
data = stdin.read()
stdout.write(data)
stdout.write(b"\nout\n")
stderr.write(b"err\n")
sys.stdout.flush()
sys.stderr.flush()

View File

@@ -0,0 +1,44 @@
"""Test program for processes."""
import os
import sys
test_file_match = "process_test.log.*"
test_file = "process_test.log.%d" % os.getpid()
def main() -> None:
    """
    Drive a three-stage stdin/stdout/stderr conversation with the parent
    test: log each 4-byte message received to a per-PID file, echo the first
    on stdout and the second on stderr (closing each stream afterwards),
    then exit with status 23.
    """
    f = open(test_file, "wb")
    stdin = sys.stdin.buffer
    stderr = sys.stderr.buffer
    stdout = sys.stdout.buffer
    # stage 1
    b = stdin.read(4)
    f.write(b"one: " + b + b"\n")
    # stage 2
    stdout.write(b)
    stdout.flush()
    os.close(sys.stdout.fileno())
    # and a one, and a two, and a...
    b = stdin.read(4)
    f.write(b"two: " + b + b"\n")
    # stage 3
    stderr.write(b)
    stderr.flush()
    os.close(stderr.fileno())
    # stage 4
    b = stdin.read(4)
    f.write(b"three: " + b + b"\n")
    # exit with status code 23
    sys.exit(23)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,5 @@
"""Test to make sure we can open /dev/tty"""
# Open the controlling terminal read/write, unbuffered and in binary mode,
# then echo a single line straight back to it.
with open("/dev/tty", "rb+", buffering=0) as f:
    a = f.readline()
    f.write(a)

View File

@@ -0,0 +1,49 @@
"""A process that reads from stdin and out using Twisted."""
# Twisted Preamble
# This makes sure that users don't have to set up their environment
# specially in order to run these programs from bin/.
import os
import sys
pos = os.path.abspath(sys.argv[0]).find(os.sep + "Twisted")
if pos != -1:
sys.path.insert(0, os.path.abspath(sys.argv[0])[: pos + 8])
sys.path.insert(0, os.curdir)
# end of preamble
from zope.interface import implementer
from twisted.internet import interfaces
from twisted.python import log
log.startLogging(sys.stderr)
from twisted.internet import protocol, reactor, stdio
@implementer(interfaces.IHalfCloseableProtocol)
class Echo(protocol.Protocol):
    """
    Stdio echo protocol: writes back every byte received, prints lifecycle
    events to stdout, and stops the reactor once the connection is lost.
    """
    def connectionMade(self):
        print("connection made")
    def dataReceived(self, data):
        self.transport.write(data)
    def readConnectionLost(self):
        # The read side reached EOF; close our write side too.
        print("readConnectionLost")
        self.transport.loseConnection()
    def writeConnectionLost(self):
        print("writeConnectionLost")
    def connectionLost(self, reason):
        # Full disconnection: report it and shut the process down.
        print("connectionLost", reason)
        reactor.stop()
stdio.StandardIO(Echo())
reactor.run() # type: ignore[attr-defined]

View File

@@ -0,0 +1,43 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Assorted functionality which is commonly useful when writing unit tests.
This module has been deprecated, please use twisted.internet.testing
instead.
"""
from twisted.internet import testing
# Public names re-exported for backwards compatibility; each is also marked
# deprecated (in twisted/test/__init__.py) in favour of the implementations
# in twisted.internet.testing.
__all__ = [
    "AccumulatingProtocol",
    "LineSendingProtocol",
    "FakeDatagramTransport",
    "StringTransport",
    "StringTransportWithDisconnection",
    "StringIOWithoutClosing",
    "_FakeConnector",
    "_FakePort",
    "MemoryReactor",
    "MemoryReactorClock",
    "RaisingMemoryReactor",
    "NonStreamingProducer",
    "waitUntilAllDisconnected",
    "EventLoggingObserver",
]
# Aliases to the canonical implementations in twisted.internet.testing.
AccumulatingProtocol = testing.AccumulatingProtocol
LineSendingProtocol = testing.LineSendingProtocol
FakeDatagramTransport = testing.FakeDatagramTransport
StringTransport = testing.StringTransport
StringTransportWithDisconnection = testing.StringTransportWithDisconnection
StringIOWithoutClosing = testing.StringIOWithoutClosing
_FakeConnector = testing._FakeConnector
_FakePort = testing._FakePort
MemoryReactor = testing.MemoryReactor
MemoryReactorClock = testing.MemoryReactorClock
RaisingMemoryReactor = testing.RaisingMemoryReactor
NonStreamingProducer = testing.NonStreamingProducer
waitUntilAllDisconnected = testing.waitUntilAllDisconnected
EventLoggingObserver = testing.EventLoggingObserver

View File

@@ -0,0 +1,3 @@
# Helper for a test_reflect test
__import__("idonotexist")

View File

@@ -0,0 +1,3 @@
# Helper for a test_reflect test
raise ValueError("Stuff is broken and things")

View File

@@ -0,0 +1,3 @@
# Helper module for a test_reflect test
1 // 0

View File

@@ -0,0 +1,123 @@
# coding: utf-8
from datetime import datetime, timedelta
from inspect import getsource
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
)
from cryptography.x509 import (
CertificateBuilder,
Name,
NameAttribute,
NameOID,
SubjectAlternativeName,
DNSName,
random_serial_number,
)
# Generate a fresh 2048-bit RSA key and a self-signed certificate for
# "localhost", then rewrite server.pem as this script's own source followed
# by the new PEM data embedded in a trailing docstring, plus two fixture
# files that deliberately lack a trailing newline.
pk = generate_private_key(key_size=2048, public_exponent=65537)
# Subject (and issuer -- the cert is self-signed) distinguished name;
# includes non-ASCII values to exercise text handling.
me = Name(
    [
        NameAttribute(NameOID.COMMON_NAME, "A Host, Locally"),
        NameAttribute(NameOID.COUNTRY_NAME, "TR"),
        NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Çorum"),
        NameAttribute(NameOID.LOCALITY_NAME, "Başmakçı"),
        NameAttribute(NameOID.ORGANIZATION_NAME, "Twisted Matrix Labs"),
        NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Automated Testing Authority"),
        NameAttribute(NameOID.EMAIL_ADDRESS, "security@twistedmatrix.com"),
    ]
)
# Valid for ~100 years, with a SubjectAlternativeName of "localhost" so
# hostname verification of local test connections succeeds.
certificate_bytes = (
    CertificateBuilder()
    .serial_number(random_serial_number())
    .not_valid_before(datetime.now())
    .not_valid_after(datetime.now() + timedelta(seconds=60 * 60 * 24 * 365 * 100))
    .subject_name(me)
    .add_extension(SubjectAlternativeName([DNSName("localhost")]), critical=False)
    .issuer_name(me)
    .public_key(pk.public_key())
    .sign(pk, algorithm=SHA256())
).public_bytes(Encoding.PEM)
privkey_bytes = pk.private_bytes(
    Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption()
)
import __main__
# Take this script's own source and cut it at the first "\n-----" (the start
# of the previously embedded PEM data), then drop the final remaining line
# -- presumably the docstring's opening triple-quote; TODO confirm.
source = getsource(__main__)
source = source.split("\n" + "-" * 5)[0].rsplit("\n", 1)[0]
with open("server.pem", "w") as fObj:
    fObj.write(source)
    fObj.write("\n")
    fObj.write('"""\n')
    fObj.write(privkey_bytes.decode("ascii"))
    fObj.write(certificate_bytes.decode("ascii"))
    fObj.write('"""\n')
# Fixtures for code paths that must cope with PEM files lacking a trailing
# newline; note the bytes filenames are intentional (bytes-path handling).
with open(b"key.pem.no_trailing_newline", "w") as fObj:
    fObj.write(privkey_bytes.decode("ascii").rstrip("\n"))
with open(b"cert.pem.no_trailing_newline", "w") as fObj:
    fObj.write(certificate_bytes.decode("ascii").rstrip("\n"))
"""
[RSA-PRIVATE-KEY-REMOVED]
[AWS-SECRET-REMOVED]XZGXLPiYVmiYjsVeJAOtHAYq
[AWS-SECRET-REMOVED]w9maxMoE+r6dW1zZ8Tllunbd
[AWS-SECRET-REMOVED]09AqrfjuEIYR8V/y+8dG3mR5
[AWS-SECRET-REMOVED]2WTO6Zh7gJ1dyHPMVkUHJF9J
[AWS-SECRET-REMOVED]JLONNYLaFjLtWnsCEJDV9owC
[AWS-SECRET-REMOVED]AQABAoH/Ib7aSjKDHXTaFV58
[AWS-SECRET-REMOVED]Rn5wphAt/mlXbx7IW0X1cali
[AWS-SECRET-REMOVED]J9qiUUGDyCnGKWbofN9VpCYg
[AWS-SECRET-REMOVED]cCLVrhVrHzw1HFTIlA51LjfI
[AWS-SECRET-REMOVED]rf7URSudS+Us6vr6gDVpKAky
[AWS-SECRET-REMOVED]sK/ELQfhPoyHyRvL1woUIO5C
[AWS-SECRET-REMOVED]3N0jWm8R8ENOnuIjhCl5aKsB
[AWS-SECRET-REMOVED]/k8zoqgcj9CmmDofBka4XxZb
[AWS-SECRET-REMOVED]UME2t2EaryUzAoGBANxpb4Jz
[AWS-SECRET-REMOVED]01KTaKyYpq+9q9VxXhWxYsh3
[AWS-SECRET-REMOVED]QxDEfKeFedskxORs+FIUzqBb
[AWS-SECRET-REMOVED]WNFhLoGPYgsTcnrk0N1QLmnZ
[AWS-SECRET-REMOVED]fbAbaU3gDy0K24z+YeNbWCjI
[AWS-SECRET-REMOVED]kmcCuEyBqlSinLslRd/997Bx
[AWS-SECRET-REMOVED]5k94Qjn4wBf7WwmgfDm6HHbs
[AWS-SECRET-REMOVED]Eiwcq4r2aBSNsI305Z5sUWtn
[AWS-SECRET-REMOVED]Yg/jP5EyqSiXtUZfSodL7yeH
[AWS-SECRET-REMOVED]iZ8P86FnWBf1iDeuywEZJqvG
[AWS-SECRET-REMOVED]vrjk0ONnXX7VXNgJ3/e7aJTx
[AWS-SECRET-REMOVED]jVuh1lyd4C4=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
[AWS-SECRET-REMOVED]A9HEyFEwDQYJKoZIhvcNAQEL
[AWS-SECRET-REMOVED]eTELMAkGA1UEBhMCVFIxDzAN
[AWS-SECRET-REMOVED]a8OnxLExHDAaBgNVBAoME1R3
[AWS-SECRET-REMOVED]dG9tYXRlZCBUZXN0aW5nIEF1
[AWS-SECRET-REMOVED]dHlAdHdpc3RlZG1hdHJpeC5j
[AWS-SECRET-REMOVED]MzQwMjhaMIG9MRgwFgYDVQQD
[AWS-SECRET-REMOVED]MQ8wDQYDVQQIDAbDh29ydW0x
[AWS-SECRET-REMOVED]DBNUd2lzdGVkIE1hdHJpeCBM
[AWS-SECRET-REMOVED]ZyBBdXRob3JpdHkxKTAnBgkq
[AWS-SECRET-REMOVED]aXguY29tMIIBIjANBgkqhkiG
[AWS-SECRET-REMOVED]7qXms9PZWHskXZGXLPiYVmiY
[AWS-SECRET-REMOVED]Ch4liyxdWkBLw9maxMoE+r6d
[AWS-SECRET-REMOVED]D4GvTby6xpoR09AqrfjuEIYR
[AWS-SECRET-REMOVED]55t7UW6Ebj2X2WTO6Zh7gJ1d
[AWS-SECRET-REMOVED]agE3evUv/BECJLONNYLaFjLt
[AWS-SECRET-REMOVED]lHiQnXdyrwIDAQABoxgwFjAU
[AWS-SECRET-REMOVED]AQELBQADggEBAEHAErq/Fs8h
[AWS-SECRET-REMOVED]a/u+ajoxrZaOheg8E2MYVwQi
[AWS-SECRET-REMOVED]iQIzXON2RvgJpwFfkLNtq0t9
[AWS-SECRET-REMOVED]IUcO0tU8O4kWrLIFPpJbcHQq
[AWS-SECRET-REMOVED]93OZJgwE2x3iUd3k8HbwxfoY
[AWS-SECRET-REMOVED]zYlxFBoDyalR7NJjJGdTwNFt
3CPGCQ28cDk=
-----END CERTIFICATE-----
"""

View File

@@ -0,0 +1,66 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Helper classes for twisted.test.test_ssl.
They are in a separate module so they will not prevent test_ssl importing if
pyOpenSSL is unavailable.
"""
from __future__ import annotations
from OpenSSL import SSL
from twisted.internet import ssl
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
# Absolute native-string path to the PEM file (certificate + key) shipped next
# to this module; used as the default certificate file by ServerTLSContext.
certPath = nativeString(FilePath(__file__.encode("utf-8")).sibling(b"server.pem").path)
class ClientTLSContext(ssl.ClientContextFactory):
    """
    Context factory for the client side of test connections.
    """

    isClient = 1

    def getContext(self) -> SSL.Context:
        """
        Build and return a brand-new client-side L{SSL.Context}.

        Caching is deliberately skipped: most real implementations cache
        their context, so returning a fresh one each call improves test
        coverage of the non-cached path.
        """
        context = SSL.Context(SSL.SSLv23_METHOD)
        return context
class ServerTLSContext:
    """
    Context factory for the server side of test connections.
    """

    isClient = 0

    def __init__(
        self, filename: str | bytes = certPath, method: int | None = None
    ) -> None:
        """
        @param filename: path of a PEM file containing both the certificate
            and the private key.
        @param method: an OpenSSL method constant; C{SSL.SSLv23_METHOD} is
            used when not supplied.
        """
        self.filename = filename
        self._method = SSL.SSLv23_METHOD if method is None else method

    def getContext(self) -> SSL.Context:
        """
        Build and return a brand-new server-side L{SSL.Context} loaded with
        the certificate and private key from C{self.filename}.

        No caching is done, so the context-creation path is exercised on
        every call.
        """
        context = SSL.Context(self._method)
        context.use_certificate_file(self.filename)
        context.use_privatekey_file(self.filename)
        return context

View File

@@ -0,0 +1,44 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_consumer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_consumer} to test
that process transports implement IConsumer properly.
"""
import sys
from twisted.internet import protocol, stdio
from twisted.protocols import basic
from twisted.python import log, reflect
def failed(err):
    # Errback for the file transfer: start logging to stderr (logging is not
    # otherwise started in this child process) and record the failure there.
    log.startLogging(sys.stderr)
    log.err(err)
class ConsumerChild(protocol.Protocol):
    """
    Child-process protocol that streams the contents of a junk file to its
    transport via L{basic.FileSender}, then drops the connection.
    """

    def __init__(self, junkPath):
        # Path of the file whose bytes are pushed through the transport.
        self.junkPath = junkPath

    def connectionMade(self):
        source = open(self.junkPath, "rb")
        transfer = basic.FileSender().beginFileTransfer(source, self.transport)
        # Log any transfer failure first; the connection is closed either way
        # (the errback returns None, so the callback still runs afterwards).
        transfer.addErrback(failed)
        transfer.addCallback(lambda ignored: self.transport.loseConnection())

    def connectionLost(self, reason):
        # The child is finished once the connection goes away.
        reactor.stop()
if __name__ == "__main__":
    # argv[1] names the reactor to use; it must be installed *before*
    # twisted.internet.reactor is imported, so the import order is deliberate.
    reflect.namedAny(sys.argv[1]).install()
    from twisted.internet import reactor

    # argv[2] is the path of the junk file whose contents are sent to stdout.
    stdio.StandardIO(ConsumerChild(sys.argv[2]))
    reactor.run()  # type: ignore[attr-defined]

View File

@@ -0,0 +1,68 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_readConnectionLost -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_readConnectionLost} to
test that IHalfCloseableProtocol.readConnectionLost works for stdio transports.
"""
import sys
from zope.interface import implementer
from twisted.internet import protocol, stdio
from twisted.internet.interfaces import IHalfCloseableProtocol
from twisted.python import log, reflect
@implementer(IHalfCloseableProtocol)
class HalfCloseProtocol(protocol.Protocol):
    """
    A protocol to hook up to stdio and observe its transport being
    half-closed. If all goes as expected, C{exitCode} will be set to C{0};
    otherwise it will be set to C{1} to indicate failure.
    """

    # None until an outcome is known; 0 on success, 1 on failure. Read by the
    # __main__ block of this script to set the process exit status.
    exitCode = None

    def connectionMade(self):
        """
        Signal the parent process that we're ready.
        """
        self.transport.write(b"x")

    def readConnectionLost(self):
        """
        This is the desired event. Once it has happened, stop the reactor so
        the process will exit.
        """
        self.exitCode = 0
        reactor.stop()

    def connectionLost(self, reason):
        """
        This may only be invoked after C{readConnectionLost}. If it happens
        otherwise, mark it as an error and shut down.
        """
        if self.exitCode is None:
            # readConnectionLost never fired: report failure.
            self.exitCode = 1
            log.err(reason, "Unexpected call to connectionLost")
        reactor.stop()

    def writeConnectionLost(self):
        # IHalfCloseableProtocol.writeConnectionLost
        pass
if __name__ == "__main__":
    # argv[1] names the reactor to install -- installation must precede the
    # twisted.internet.reactor import below. argv[2] is a log file path.
    reflect.namedAny(sys.argv[1]).install()
    log.startLogging(open(sys.argv[2], "wb"))
    from twisted.internet import reactor

    halfCloseProtocol = HalfCloseProtocol()
    stdio.StandardIO(halfCloseProtocol)
    reactor.run()  # type: ignore[attr-defined]
    # Propagate the protocol's verdict as the process exit status.
    sys.exit(halfCloseProtocol.exitCode)

View File

@@ -0,0 +1,67 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_buggyReadConnectionLost -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_readConnectionLost} to
test that IHalfCloseableProtocol.readConnectionLost works for stdio transports.
"""
import sys
from zope.interface import implementer
from twisted.internet import protocol, stdio
from twisted.internet.interfaces import IHalfCloseableProtocol, IReactorCore, ITransport
from twisted.python import log, reflect
@implementer(IHalfCloseableProtocol)
class HalfCloseProtocol(protocol.Protocol):
    """
    A protocol to hook up to stdio and observe its transport being
    half-closed. If all goes as expected, C{exitCode} will be set to C{0};
    otherwise it will be set to C{1} to indicate failure.
    """

    # Process exit status; set in connectionLost and read by __main__.
    exitCode = None
    transport: ITransport

    def connectionMade(self) -> None:
        """
        Signal the parent process that we're ready.
        """
        self.transport.write(b"x")

    def readConnectionLost(self) -> None:
        """
        This is the desired event. Once it has happened, stop the reactor so
        the process will exit.

        Deliberately buggy: raises so the parent test can observe how errors
        raised from C{readConnectionLost} are handled.
        """
        raise ValueError("something went wrong")

    def connectionLost(self, reason: object = None) -> None:
        """
        This may only be invoked after C{readConnectionLost}. If it happens
        otherwise, mark it as an error and shut down.
        """
        self.exitCode = 0
        reactor.stop()

    def writeConnectionLost(self) -> None:
        # IHalfCloseableProtocol.writeConnectionLost
        pass
if __name__ == "__main__":  # pragma: no branch
    # argv[1] names the reactor to install; installation must precede the
    # twisted.internet.reactor import below. argv[2] is a log file path.
    reflect.namedAny(sys.argv[1]).install()
    log.startLogging(open(sys.argv[2], "w"))
    reactor: IReactorCore
    from twisted.internet import reactor  # type:ignore[assignment]

    halfCloseProtocol = HalfCloseProtocol()
    stdio.StandardIO(halfCloseProtocol)
    reactor.run()
    # Propagate the protocol's verdict as the process exit status.
    sys.exit(halfCloseProtocol.exitCode)

View File

@@ -0,0 +1,69 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_buggyWriteConnectionLost -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_buggyWriteConnectionLost}
to test that IHalfCloseableProtocol.writeConnectionLost works for stdio
transports.
"""
import sys
from zope.interface import implementer
from twisted.internet import protocol, stdio
from twisted.internet.interfaces import IHalfCloseableProtocol, IReactorCore, ITransport
from twisted.python import log, reflect
@implementer(IHalfCloseableProtocol)
class HalfCloseProtocol(protocol.Protocol):
    """
    A protocol to hook up to stdio and observe its transport being
    half-closed. If all goes as expected, C{exitCode} will be set to C{0};
    otherwise it will be set to C{1} to indicate failure.
    """

    # 9 marks "no outcome yet"; becomes 0 only if writeConnectionLost fired
    # before the connection fully closed.
    exitCode = 9
    wasWriteConnectionLost = False
    transport: ITransport

    def connectionMade(self) -> None:
        """
        Signal the parent process that we're ready.
        """
        self.transport.write(b"x")

    def readConnectionLost(self) -> None:
        """
        This is the desired event. Once it has happened, stop the reactor so
        the process will exit.
        """

    def connectionLost(self, reason: object = None) -> None:
        """
        This may only be invoked after C{readConnectionLost}. If it happens
        otherwise, mark it as an error and shut down.
        """
        if self.wasWriteConnectionLost:  # pragma: no branch
            self.exitCode = 0
        reactor.stop()

    def writeConnectionLost(self) -> None:
        # Deliberately buggy: record that the event fired, then raise so the
        # parent test can observe how errors from writeConnectionLost are
        # handled.
        self.wasWriteConnectionLost = True
        raise ValueError("something went wrong")
if __name__ == "__main__":  # pragma: no branch
    # argv[1] names the reactor to install; installation must precede the
    # twisted.internet.reactor import below. argv[2] is a log file path.
    reflect.namedAny(sys.argv[1]).install()
    log.startLogging(open(sys.argv[2], "w"))
    reactor: IReactorCore
    from twisted.internet import reactor  # type:ignore[assignment]

    halfCloseProtocol = HalfCloseProtocol()
    stdio.StandardIO(halfCloseProtocol)
    reactor.run()
    # Propagate the protocol's verdict as the process exit status.
    sys.exit(halfCloseProtocol.exitCode)

View File

@@ -0,0 +1,39 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_hostAndPeer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_hostAndPeer} to test
that ITransport.getHost() and ITransport.getPeer() work for process transports.
"""
import sys
from twisted.internet import protocol, stdio
from twisted.python import reflect
class HostPeerChild(protocol.Protocol):
    """
    Child-process protocol that writes its transport's host and peer
    addresses (one per line) back to the parent, then closes the connection.
    """

    def connectionMade(self):
        host = str(self.transport.getHost()).encode("ascii")
        peer = str(self.transport.getPeer()).encode("ascii")
        self.transport.write(host + b"\n" + peer)
        self.transport.loseConnection()

    def connectionLost(self, reason):
        # Nothing left to do once the parent has our report.
        reactor.stop()
if __name__ == "__main__":
    # argv[1] names the reactor to install; installation must precede the
    # twisted.internet.reactor import below.
    reflect.namedAny(sys.argv[1]).install()
    from twisted.internet import reactor

    stdio.StandardIO(HostPeerChild())
    reactor.run()  # type: ignore[attr-defined]

View File

@@ -0,0 +1,43 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_lastWriteReceived -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_lastWriteReceived}
to test that L{os.write} can be reliably used after
L{twisted.internet.stdio.StandardIO} has finished.
"""
import sys
from twisted.internet.protocol import Protocol
from twisted.internet.stdio import StandardIO
from twisted.python.reflect import namedAny
class LastWriteChild(Protocol):
    """
    Protocol that writes one magic string to its transport, closes the
    connection, and stops the given reactor once the connection is gone.
    """

    def __init__(self, reactor, magicString):
        self.reactor = reactor
        self.magicString = magicString

    def connectionMade(self):
        transport = self.transport
        transport.write(self.magicString)
        transport.loseConnection()

    def connectionLost(self, reason):
        self.reactor.stop()
def main(reactor, magicString):
    """
    Connect a L{LastWriteChild} to standard I/O and run the reactor until
    the protocol stops it.

    @param reactor: the already-installed reactor to run.
    @param magicString: text to ASCII-encode and write before closing.
    """
    p = LastWriteChild(reactor, magicString.encode("ascii"))
    StandardIO(p)
    reactor.run()


if __name__ == "__main__":
    # The reactor named by argv[1] must be installed before
    # twisted.internet.reactor is imported. argv[2] is the magic string.
    namedAny(sys.argv[1]).install()
    from twisted.internet import reactor

    main(reactor, sys.argv[2])

View File

@@ -0,0 +1,50 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_loseConnection -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_loseConnection} to
test that ITransport.loseConnection() works for process transports.
"""
import sys
from twisted.internet import protocol, stdio
from twisted.internet.error import ConnectionDone
from twisted.python import log, reflect
class LoseConnChild(protocol.Protocol):
    # Process exit status: stays 0 unless connectionLost got a bad reason.
    exitCode = 0

    def connectionMade(self):
        # Close immediately; the point of this child is the resulting
        # connectionLost reason observed below.
        self.transport.loseConnection()

    def connectionLost(self, reason):
        """
        Check that C{reason} is a L{Failure} wrapping a L{ConnectionDone}
        instance and stop the reactor. If C{reason} is wrong for some reason,
        log something about that in C{self.errorLogFile} and make sure the
        process exits with a non-zero status.
        """
        try:
            try:
                # trap() re-raises if the failure is not ConnectionDone.
                reason.trap(ConnectionDone)
            except BaseException:
                log.err(None, "Problem with reason passed to connectionLost")
                self.exitCode = 1
        finally:
            # Always stop the reactor, whatever the reason was.
            reactor.stop()
if __name__ == "__main__":
    # argv[1] names the reactor to install (must precede the import below);
    # argv[2] is the error-log file path.
    reflect.namedAny(sys.argv[1]).install()
    log.startLogging(open(sys.argv[2], "wb"))
    from twisted.internet import reactor

    protocolLoseConnChild = LoseConnChild()
    stdio.StandardIO(protocolLoseConnChild)
    reactor.run()  # type: ignore[attr-defined]
    # Propagate the protocol's verdict as the process exit status.
    sys.exit(protocolLoseConnChild.exitCode)

View File

@@ -0,0 +1,54 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_producer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_producer} to test
that process transports implement IProducer properly.
"""
import sys
from twisted.internet import protocol, stdio
from twisted.python import log, reflect
class ProducerChild(protocol.Protocol):
    # True while we have asked the transport to stop producing; any data
    # delivered inside that window is a flow-control bug.
    _paused = False
    # Everything received so far; used to spot the terminating "\n0\n".
    buf = b""

    def connectionLost(self, reason):
        log.msg("*****OVER*****")
        # Give pending output a moment to flush before stopping.
        reactor.callLater(1, reactor.stop)

    def dataReceived(self, data):
        self.buf += data
        if self._paused:
            # Transport delivered data while paused: report and bail out.
            log.startLogging(sys.stderr)
            log.msg("dataReceived while transport paused!")
            self.transport.loseConnection()
        else:
            # Echo the data back, then either finish or pause briefly to
            # exercise pauseProducing/resumeProducing.
            self.transport.write(data)
            if self.buf.endswith(b"\n0\n"):
                self.transport.loseConnection()
            else:
                self.pause()

    def pause(self):
        # Stop production and schedule a resume shortly afterwards.
        self._paused = True
        self.transport.pauseProducing()
        reactor.callLater(0.01, self.unpause)

    def unpause(self):
        self._paused = False
        self.transport.resumeProducing()
if __name__ == "__main__":
    # argv[1] names the reactor to install; installation must precede the
    # twisted.internet.reactor import below.
    reflect.namedAny(sys.argv[1]).install()
    from twisted.internet import reactor

    stdio.StandardIO(ProducerChild())
    reactor.run()  # type: ignore[attr-defined]

View File

@@ -0,0 +1,34 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_write -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_write} to test that
ITransport.write() works for process transports.
"""
import sys
from twisted.internet import protocol, stdio
from twisted.python import reflect
class WriteChild(protocol.Protocol):
    """
    Child-process protocol that emits C{b"ok!"} via three separate
    one-byte writes, then closes the connection.
    """

    def connectionMade(self):
        for byte in (b"o", b"k", b"!"):
            self.transport.write(byte)
        self.transport.loseConnection()

    def connectionLost(self, reason):
        reactor.stop()
if __name__ == "__main__":
    # argv[1] names the reactor to install; installation must precede the
    # twisted.internet.reactor import below.
    reflect.namedAny(sys.argv[1]).install()
    from twisted.internet import reactor

    stdio.StandardIO(WriteChild())
    reactor.run()  # type: ignore[attr-defined]

View File

@@ -0,0 +1,32 @@
# -*- test-case-name: twisted.test.test_stdio.StandardInputOutputTests.test_writeSequence -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Main program for the child process run by
L{twisted.test.test_stdio.StandardInputOutputTests.test_writeSequence} to test
that ITransport.writeSequence() works for process transports.
"""
import sys
from twisted.internet import protocol, stdio
from twisted.python import reflect
class WriteSequenceChild(protocol.Protocol):
    """
    Child-process protocol exercising C{writeSequence}: sends C{b"ok!"}
    as a three-element sequence, then closes the connection.
    """

    def connectionMade(self):
        chunks = [b"o", b"k", b"!"]
        self.transport.writeSequence(chunks)
        self.transport.loseConnection()

    def connectionLost(self, reason):
        reactor.stop()
if __name__ == "__main__":
    # argv[1] names the reactor to install; installation must precede the
    # twisted.internet.reactor import below.
    reflect.namedAny(sys.argv[1]).install()
    from twisted.internet import reactor

    stdio.StandardIO(WriteSequenceChild())
    reactor.run()  # type: ignore[attr-defined]

View File

@@ -0,0 +1,106 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for generic file descriptor based reactor support code.
"""
from socket import AF_IPX
from twisted.internet.abstract import isIPAddress
from twisted.trial.unittest import TestCase
class AddressTests(TestCase):
    """
    Tests for address-related functionality.
    """

    def test_decimalDotted(self) -> None:
        """
        L{isIPAddress} should return C{True} for any decimal dotted
        representation of an IPv4 address.
        """
        for valid in ("0.1.2.3", "252.253.254.255"):
            self.assertTrue(isIPAddress(valid))

    def test_shortDecimalDotted(self) -> None:
        """
        L{isIPAddress} should return C{False} for a dotted decimal
        representation with fewer or more than four octets.
        """
        for invalid in ("0", "0.1", "0.1.2", "0.1.2.3.4"):
            self.assertFalse(isIPAddress(invalid))

    def test_invalidLetters(self) -> None:
        """
        L{isIPAddress} should return C{False} for any non-decimal dotted
        representation including letters.
        """
        for invalid in ("a.2.3.4", "1.b.3.4"):
            self.assertFalse(isIPAddress(invalid))

    def test_invalidPunctuation(self) -> None:
        """
        L{isIPAddress} should return C{False} for a string containing
        strange punctuation.
        """
        for invalid in (",", "1,2", "1,2,3", "1.,.3,4"):
            self.assertFalse(isIPAddress(invalid))

    def test_emptyString(self) -> None:
        """
        L{isIPAddress} should return C{False} for the empty string.
        """
        self.assertFalse(isIPAddress(""))

    def test_invalidNegative(self) -> None:
        """
        L{isIPAddress} should return C{False} for negative decimal values.
        """
        for invalid in ("-1", "1.-2", "1.2.-3", "1.2.-3.4"):
            self.assertFalse(isIPAddress(invalid))

    def test_invalidPositive(self) -> None:
        """
        L{isIPAddress} should return C{False} for a string containing
        positive decimal values greater than 255.
        """
        for invalid in (
            "256.0.0.0",
            "0.256.0.0",
            "0.0.256.0",
            "0.0.0.256",
            "256.256.256.256",
        ):
            self.assertFalse(isIPAddress(invalid))

    def test_unicodeAndBytes(self) -> None:
        """
        L{isIPAddress} evaluates ASCII-encoded bytes as well as text.
        """
        # we test passing bytes but don't support bytes in the type annotation
        self.assertFalse(isIPAddress(b"256.0.0.0"))  # type: ignore[arg-type]
        self.assertFalse(isIPAddress("256.0.0.0"))
        self.assertTrue(isIPAddress(b"252.253.254.255"))  # type: ignore[arg-type]
        self.assertTrue(isIPAddress("252.253.254.255"))

    def test_nonIPAddressFamily(self) -> None:
        """
        All address families other than C{AF_INET} and C{AF_INET6} result in a
        L{ValueError} being raised.
        """
        self.assertRaises(ValueError, isIPAddress, b"anything", AF_IPX)

    def test_nonASCII(self) -> None:
        """
        All IP addresses must be encodable as ASCII; non-ASCII should result in
        a L{False} result.
        """
        # we test passing bytes but don't support bytes in the type annotation
        self.assertFalse(isIPAddress(b"\xff.notascii"))  # type: ignore[arg-type]
        self.assertFalse(isIPAddress("\u4321.notascii"))

View File

@@ -0,0 +1,869 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for twisted.enterprise.adbapi.
"""
import os
import stat
from typing import Dict, Optional
from twisted.enterprise.adbapi import (
Connection,
ConnectionLost,
ConnectionPool,
Transaction,
)
from twisted.internet import defer, interfaces, reactor
from twisted.python.failure import Failure
from twisted.python.reflect import requireModule
from twisted.trial import unittest
simple_table_schema = """
CREATE TABLE simple (
x integer
)
"""
class ADBAPITestBase:
    """
    Test the asynchronous DB-API code.

    Mixed together with a concrete L{DBTestConnector} subclass and
    C{unittest.TestCase} by L{makeSQLTests}; the connector supplies
    C{makePool}, C{startDB}/C{stopDB} and the capability flags used here
    (C{test_failures}, C{can_rollback}, C{num_iterations}, ...).
    """

    # Records which connections the cp_openfun hook has been called for.
    # NOTE(review): class-level dict shared by all instances -- entries
    # persist across test cases; presumably acceptable for these tests.
    openfun_called: Dict[object, bool] = {}

    if interfaces.IReactorThreads(reactor, None) is None:
        skip = "ADB-API requires threads, no way to test without them"

    def extraSetUp(self):
        """
        Set up the database and create a connection pool pointing at it.
        """
        self.startDB()
        self.dbpool = self.makePool(cp_openfun=self.openfun)
        self.dbpool.start()

    def tearDown(self):
        # Drop the table, then shut down the pool and the database, in order.
        d = self.dbpool.runOperation("DROP TABLE simple")
        d.addCallback(lambda res: self.dbpool.close())
        d.addCallback(lambda res: self.stopDB())
        return d

    def openfun(self, conn):
        # cp_openfun hook: remember each connection it was invoked with.
        self.openfun_called[conn] = True

    def checkOpenfunCalled(self, conn=None):
        """
        Assert the open hook ran: for C{conn} specifically when given,
        otherwise for at least one connection.
        """
        if not conn:
            self.assertTrue(self.openfun_called)
        else:
            self.assertIn(conn, self.openfun_called)

    def test_pool(self):
        # One deferred chain exercising the whole pool API in sequence; the
        # _testPool_* order is significant (failures first, then rollback
        # check, inserts, interactions, cleanup).
        d = self.dbpool.runOperation(simple_table_schema)
        if self.test_failures:
            d.addCallback(self._testPool_1_1)
            d.addCallback(self._testPool_1_2)
            d.addCallback(self._testPool_1_3)
            d.addCallback(self._testPool_1_4)
            d.addCallback(lambda res: self.flushLoggedErrors())
        d.addCallback(self._testPool_2)
        d.addCallback(self._testPool_3)
        d.addCallback(self._testPool_4)
        d.addCallback(self._testPool_5)
        d.addCallback(self._testPool_6)
        d.addCallback(self._testPool_7)
        d.addCallback(self._testPool_8)
        d.addCallback(self._testPool_9)
        return d

    def _testPool_1_1(self, res):
        # runQuery on a missing table must errback.
        d = defer.maybeDeferred(self.dbpool.runQuery, "select * from NOTABLE")
        d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
        return d

    def _testPool_1_2(self, res):
        # runOperation with invalid SQL must errback.
        d = defer.maybeDeferred(self.dbpool.runOperation, "deletexxx from NOTABLE")
        d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
        return d

    def _testPool_1_3(self, res):
        # A failing interaction must errback (and roll back; see _testPool_2).
        d = defer.maybeDeferred(self.dbpool.runInteraction, self.bad_interaction)
        d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
        return d

    def _testPool_1_4(self, res):
        # A failing runWithConnection must errback.
        d = defer.maybeDeferred(self.dbpool.runWithConnection, self.bad_withConnection)
        d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
        return d

    def _testPool_2(self, res):
        # verify simple table is empty
        sql = "select count(1) from simple"
        d = self.dbpool.runQuery(sql)

        def _check(row):
            self.assertTrue(int(row[0][0]) == 0, "Interaction not rolled back")
            self.checkOpenfunCalled()

        d.addCallback(_check)
        return d

    def _testPool_3(self, res):
        sql = "select count(1) from simple"
        inserts = []
        # add some rows to simple table (runOperation)
        for i in range(self.num_iterations):
            sql = "insert into simple(x) values(%d)" % i
            inserts.append(self.dbpool.runOperation(sql))
        d = defer.gatherResults(inserts)

        def _select(res):
            # make sure they were added (runQuery)
            sql = "select x from simple order by x"
            d = self.dbpool.runQuery(sql)
            return d

        d.addCallback(_select)

        def _check(rows):
            self.assertTrue(len(rows) == self.num_iterations, "Wrong number of rows")
            for i in range(self.num_iterations):
                self.assertTrue(len(rows[i]) == 1, "Wrong size row")
                self.assertTrue(rows[i][0] == i, "Values not returned.")

        d.addCallback(_check)
        return d

    def _testPool_4(self, res):
        # runInteraction
        d = self.dbpool.runInteraction(self.interaction)
        d.addCallback(lambda res: self.assertEqual(res, "done"))
        return d

    def _testPool_5(self, res):
        # withConnection
        d = self.dbpool.runWithConnection(self.withConnection)
        d.addCallback(lambda res: self.assertEqual(res, "done"))
        return d

    def _testPool_6(self, res):
        # Test a withConnection cannot be closed
        d = self.dbpool.runWithConnection(self.close_withConnection)
        return d

    def _testPool_7(self, res):
        # give the pool a workout
        ds = []
        for i in range(self.num_iterations):
            sql = "select x from simple where x = %d" % i
            ds.append(self.dbpool.runQuery(sql))
        dlist = defer.DeferredList(ds, fireOnOneErrback=True)

        def _check(result):
            # DeferredList results are (success, value) pairs.
            for i in range(self.num_iterations):
                self.assertTrue(result[i][1][0][0] == i, "Value not returned")

        dlist.addCallback(_check)
        return dlist

    def _testPool_8(self, res):
        # now delete everything
        ds = []
        for i in range(self.num_iterations):
            sql = "delete from simple where x = %d" % i
            ds.append(self.dbpool.runOperation(sql))
        dlist = defer.DeferredList(ds, fireOnOneErrback=True)
        return dlist

    def _testPool_9(self, res):
        # verify simple table is empty
        sql = "select count(1) from simple"
        d = self.dbpool.runQuery(sql)

        def _check(row):
            self.assertTrue(
                int(row[0][0]) == 0, "Didn't successfully delete table contents"
            )
            self.checkConnect()

        d.addCallback(_check)
        return d

    def checkConnect(self):
        """Check the connect/disconnect synchronous calls."""
        conn = self.dbpool.connect()
        self.checkOpenfunCalled(conn)
        curs = conn.cursor()
        curs.execute("insert into simple(x) values(1)")
        curs.execute("select x from simple")
        res = curs.fetchall()
        self.assertEqual(len(res), 1)
        self.assertEqual(len(res[0]), 1)
        self.assertEqual(res[0][0], 1)
        curs.execute("delete from simple")
        curs.execute("select x from simple")
        self.assertEqual(len(curs.fetchall()), 0)
        curs.close()
        self.dbpool.disconnect(conn)

    def interaction(self, transaction):
        """
        Interaction for runInteraction: verify the rows written by
        C{_testPool_3} are all present, in order, with no extras.
        """
        transaction.execute("select x from simple order by x")
        for i in range(self.num_iterations):
            row = transaction.fetchone()
            self.assertTrue(len(row) == 1, "Wrong size row")
            self.assertTrue(row[0] == i, "Value not returned.")
        self.assertIsNone(transaction.fetchone(), "Too many rows")
        return "done"

    def bad_interaction(self, transaction):
        # Insert first so _testPool_2 can verify the rollback, then fail.
        if self.can_rollback:
            transaction.execute("insert into simple(x) values(0)")
        transaction.execute("select * from NOTABLE")

    def withConnection(self, conn):
        curs = conn.cursor()
        try:
            curs.execute("select x from simple order by x")
            for i in range(self.num_iterations):
                row = curs.fetchone()
                self.assertTrue(len(row) == 1, "Wrong size row")
                self.assertTrue(row[0] == i, "Value not returned.")
        finally:
            curs.close()
        return "done"

    def close_withConnection(self, conn):
        # Closing inside runWithConnection must be tolerated by the pool.
        conn.close()

    def bad_withConnection(self, conn):
        # Execute bad SQL but still close the cursor.
        curs = conn.cursor()
        try:
            curs.execute("select * from NOTABLE")
        finally:
            curs.close()
class ReconnectTestBase:
    """
    Test the asynchronous DB-API code with reconnect.

    Like L{ADBAPITestBase}, mixed with a L{DBTestConnector} subclass and
    C{unittest.TestCase} by L{makeSQLTests}.
    """

    if interfaces.IReactorThreads(reactor, None) is None:
        skip = "ADB-API requires threads, no way to test without them"

    def extraSetUp(self):
        """
        Skip the test if C{good_sql} is unavailable. Otherwise, set up the
        database, create a connection pool pointed at it, and set up a simple
        schema in it.
        """
        if self.good_sql is None:
            raise unittest.SkipTest("no good sql for reconnect test")
        self.startDB()
        # cp_max=1 so the single pooled connection is the one we break.
        self.dbpool = self.makePool(
            cp_max=1, cp_reconnect=True, cp_good_sql=self.good_sql
        )
        self.dbpool.start()
        return self.dbpool.runOperation(simple_table_schema)

    def tearDown(self):
        d = self.dbpool.runOperation("DROP TABLE simple")
        d.addCallback(lambda res: self.dbpool.close())
        d.addCallback(lambda res: self.stopDB())
        return d

    def test_pool(self):
        # Ordered chain: works, break connection, (maybe) first query fails,
        # reconnects transparently, bad SQL still reports the driver error.
        d = defer.succeed(None)
        d.addCallback(self._testPool_1)
        d.addCallback(self._testPool_2)
        if not self.early_reconnect:
            d.addCallback(self._testPool_3)
        d.addCallback(self._testPool_4)
        d.addCallback(self._testPool_5)
        return d

    def _testPool_1(self, res):
        # Sanity check: queries work before the connection is broken.
        sql = "select count(1) from simple"
        d = self.dbpool.runQuery(sql)

        def _check(row):
            self.assertTrue(int(row[0][0]) == 0, "Table not empty")

        d.addCallback(_check)
        return d

    def _testPool_2(self, res):
        # reach in and close the connection manually
        list(self.dbpool.connections.values())[0].close()

    def _testPool_3(self, res):
        # The first query on the dead connection is expected to fail.
        sql = "select count(1) from simple"
        d = defer.maybeDeferred(self.dbpool.runQuery, sql)
        d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
        return d

    def _testPool_4(self, res):
        # The pool should now have reconnected transparently.
        sql = "select count(1) from simple"
        d = self.dbpool.runQuery(sql)

        def _check(row):
            self.assertTrue(int(row[0][0]) == 0, "Table not empty")

        d.addCallback(_check)
        return d

    def _testPool_5(self, res):
        # Bad SQL on a live connection must fail with the driver's own error,
        # not be misreported as ConnectionLost.
        self.flushLoggedErrors()
        sql = "select * from NOTABLE"  # bad sql
        d = defer.maybeDeferred(self.dbpool.runQuery, sql)
        d.addCallbacks(
            lambda res: self.fail("no exception"),
            lambda f: self.assertFalse(f.check(ConnectionLost)),
        )
        return d
class DBTestConnector:
    """
    A class which knows how to test for the presence of
    and establish a connection to a relational database.

    To enable test cases which use a central, system database,
    you must create a database named DB_NAME with a user DB_USER
    and password DB_PASS with full access rights to database DB_NAME.
    """

    # used for creating new test cases
    TEST_PREFIX: Optional[str] = None

    DB_NAME = "twisted_test"
    DB_USER = "twisted_test"
    DB_PASS = "twisted_test"
    DB_DIR = None  # directory for database storage

    # Capability flags consulted by the test mixins above.
    nulls_ok = True  # nulls supported
    trailing_spaces_ok = True  # trailing spaces in strings preserved
    can_rollback = True  # rollback supported
    test_failures = True  # test bad sql?
    escape_slashes = True  # escape \ in sql?
    good_sql: Optional[str] = ConnectionPool.good_sql
    early_reconnect = True  # cursor() will fail on closed connection
    can_clear = True  # can try to clear out tables when starting
    # number of iterations for test loop (lower this for slow db's)
    num_iterations = 50

    def setUp(self):
        # Each test run gets a private scratch directory for file-based DBs.
        self.DB_DIR = self.mktemp()
        os.mkdir(self.DB_DIR)
        if not self.can_connect():
            raise unittest.SkipTest("%s: Cannot access db" % self.TEST_PREFIX)
        return self.extraSetUp()

    def can_connect(self):
        """Return true if this database is present on the system
        and can be used in a test."""
        raise NotImplementedError()

    def startDB(self):
        """Take any steps needed to bring database up."""
        pass

    def stopDB(self):
        """Bring database down, if needed."""
        pass

    def makePool(self, **newkw):
        """Create a connection pool with additional keyword arguments."""
        args, kw = self.getPoolArgs()
        kw = kw.copy()
        kw.update(newkw)
        return ConnectionPool(*args, **kw)

    def getPoolArgs(self):
        """Return a tuple (args, kw) of list and keyword arguments
        that need to be passed to ConnectionPool to create a connection
        to this database."""
        raise NotImplementedError()
class SQLite3Connector(DBTestConnector):
    """
    Connector that uses the stdlib SQLite3 database support.
    """

    TEST_PREFIX = "SQLite3"
    escape_slashes = False
    num_iterations = 1  # slow

    def can_connect(self):
        """
        Report whether the stdlib C{sqlite3} module is importable.
        """
        return requireModule("sqlite3") is not None

    def startDB(self):
        # Use a fresh database file inside the per-test scratch directory.
        self.database = os.path.join(self.DB_DIR, self.DB_NAME)
        if os.path.exists(self.database):
            os.unlink(self.database)

    def getPoolArgs(self):
        keywords = {"database": self.database, "cp_max": 1, "check_same_thread": False}
        return ("sqlite3",), keywords
class PySQLite2Connector(DBTestConnector):
    """
    Connector that uses pysqlite's SQLite database support.
    """

    TEST_PREFIX = "pysqlite2"
    escape_slashes = False
    num_iterations = 1  # slow

    def can_connect(self):
        """
        Report whether C{pysqlite2.dbapi2} is importable.
        """
        return requireModule("pysqlite2.dbapi2") is not None

    def startDB(self):
        # Use a fresh database file inside the per-test scratch directory.
        self.database = os.path.join(self.DB_DIR, self.DB_NAME)
        if os.path.exists(self.database):
            os.unlink(self.database)

    def getPoolArgs(self):
        keywords = {"database": self.database, "cp_max": 1, "check_same_thread": False}
        return ("pysqlite2.dbapi2",), keywords
class PyPgSQLConnector(DBTestConnector):
    """
    Connector for the PyPgSQL PostgreSQL bindings.
    """

    TEST_PREFIX = "PyPgSQL"

    def can_connect(self):
        # Importable and able to open (and close) a real connection?
        try:
            from pyPgSQL import PgSQL
        except BaseException:
            return False
        try:
            PgSQL.connect(
                database=self.DB_NAME, user=self.DB_USER, password=self.DB_PASS
            ).close()
            return True
        except BaseException:
            return False

    def getPoolArgs(self):
        return (
            ("pyPgSQL.PgSQL",),
            {
                "database": self.DB_NAME,
                "user": self.DB_USER,
                "password": self.DB_PASS,
                "cp_min": 0,
            },
        )
class PsycopgConnector(DBTestConnector):
    """
    Connector for the psycopg PostgreSQL bindings.
    """

    TEST_PREFIX = "Psycopg"

    def can_connect(self):
        # Importable and able to open (and close) a real connection?
        try:
            import psycopg
        except BaseException:
            return False
        try:
            psycopg.connect(
                database=self.DB_NAME, user=self.DB_USER, password=self.DB_PASS
            ).close()
            return True
        except BaseException:
            return False

    def getPoolArgs(self):
        return (
            ("psycopg",),
            {
                "database": self.DB_NAME,
                "user": self.DB_USER,
                "password": self.DB_PASS,
                "cp_min": 0,
            },
        )
class MySQLConnector(DBTestConnector):
    """
    Connector for the MySQLdb bindings.
    """

    TEST_PREFIX = "MySQL"
    trailing_spaces_ok = False
    can_rollback = False
    early_reconnect = False

    def can_connect(self):
        # Importable and able to open (and close) a real connection?
        try:
            import MySQLdb
        except BaseException:
            return False
        try:
            MySQLdb.connect(
                db=self.DB_NAME, user=self.DB_USER, passwd=self.DB_PASS
            ).close()
            return True
        except BaseException:
            return False

    def getPoolArgs(self):
        return (
            ("MySQLdb",),
            {"db": self.DB_NAME, "user": self.DB_USER, "passwd": self.DB_PASS},
        )
class FirebirdConnector(DBTestConnector):
    """
    Connector for the kinterbasdb Firebird bindings.
    """

    TEST_PREFIX = "Firebird"
    test_failures = False  # failure testing causes problems
    escape_slashes = False
    good_sql = None  # firebird doesn't handle failed sql well
    can_clear = False  # firebird is not so good
    num_iterations = 5  # slow

    def can_connect(self):
        # Module importable, and a full create/drop round trip succeeds.
        if requireModule("kinterbasdb") is None:
            return False
        try:
            self.startDB()
            self.stopDB()
        except BaseException:
            return False
        return True

    def startDB(self):
        import kinterbasdb

        self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
        # Firebird's server process needs access to the scratch directory.
        os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
        template = 'create database "%s" user "%s" password "%s"'
        creation = kinterbasdb.create_database(
            template % (self.DB_NAME, self.DB_USER, self.DB_PASS)
        )
        creation.close()

    def getPoolArgs(self):
        return (
            ("kinterbasdb",),
            {
                "database": self.DB_NAME,
                "host": "127.0.0.1",
                "user": self.DB_USER,
                "password": self.DB_PASS,
            },
        )

    def stopDB(self):
        import kinterbasdb

        kinterbasdb.connect(
            database=self.DB_NAME,
            host="127.0.0.1",
            user=self.DB_USER,
            password=self.DB_PASS,
        ).drop_database()
def makeSQLTests(base, suffix, globals):
    """
    Make a test case for every db connector which can connect.

    @param base: Base class for test case. Additional base classes
        will be a DBConnector subclass and unittest.TestCase.
    @param suffix: A suffix used to create test case names. Prefixes
        are defined in the DBConnector subclasses.
    @param globals: The namespace (usually the caller's C{globals()}) into
        which the generated test case classes are installed.
    """
    connectors = [
        PySQLite2Connector,
        SQLite3Connector,
        PyPgSQLConnector,
        PsycopgConnector,
        MySQLConnector,
        FirebirdConnector,
    ]
    tests = {}
    for connclass in connectors:
        name = connclass.TEST_PREFIX + suffix

        class testcase(connclass, base, unittest.TestCase):
            __module__ = connclass.__module__

        testcase.__name__ = name
        if hasattr(connclass, "__qualname__"):
            # Rebuild the qualname with the generated name in place of the
            # connector's own. Split on "." (qualname components), not on
            # whitespace: str.split() with no argument treats the whole
            # dotted name as a single token, so the enclosing-scope prefix
            # of a nested class would be silently dropped.
            testcase.__qualname__ = ".".join(
                connclass.__qualname__.split(".")[0:-1] + [name]
            )
        tests[name] = testcase
    globals.update(tests)
# Generate per-backend TestCase subclasses directly into this module's
# namespace.  The comments list the class names each call produces so the
# tests remain discoverable by grep:
# PySQLite2Connector SQLite3ADBAPITests PyPgSQLADBAPITests
# PsycopgADBAPITests MySQLADBAPITests FirebirdADBAPITests
makeSQLTests(ADBAPITestBase, "ADBAPITests", globals())
# PySQLite2Connector SQLite3ReconnectTests PyPgSQLReconnectTests
# PsycopgReconnectTests MySQLReconnectTests FirebirdReconnectTests
makeSQLTests(ReconnectTestBase, "ReconnectTests", globals())
class FakePool:
    """
    A stand-in for L{ConnectionPool} that hands out connections produced
    by an arbitrary factory and forgets them again on disconnect.

    @ivar connectionFactory: factory for making connections returned by the
        C{connect} method.
    @type connectionFactory: any callable
    """

    reconnect = True
    noisy = True

    def __init__(self, connectionFactory):
        self.connectionFactory = connectionFactory

    def connect(self):
        """
        Create and return a fresh connection via C{self.connectionFactory}.
        """
        return self.connectionFactory()

    def disconnect(self, connection):
        """
        Discard C{connection} without doing anything at all.
        """
class ConnectionTests(unittest.TestCase):
    """
    Tests for the L{Connection} class.
    """

    def test_rollbackErrorLogged(self):
        """
        If an error happens during rollback, L{ConnectionLost} is raised but
        the original error is logged.
        """

        class RollbackFailingConnection:
            def rollback(self):
                raise RuntimeError("problem!")

        connection = Connection(FakePool(RollbackFailingConnection))
        self.assertRaises(ConnectionLost, connection.rollback)
        errors = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(errors), 1)
        self.assertEqual(errors[0].value.args[0], "problem!")
class TransactionTests(unittest.TestCase):
    """
    Tests for the L{Transaction} class.
    """

    def test_reopenLogErrorIfReconnect(self):
        """
        If the cursor creation raises an error in L{Transaction.reopen}, it
        reconnects but logs the error that occurred.
        """

        class FlakyCursorConnection:
            # Fail only the very first cursor() call, then behave.
            count = 0

            def reconnect(self):
                pass

            def cursor(self):
                if self.count == 0:
                    self.count += 1
                    raise RuntimeError("problem!")

        transaction = Transaction(FakePool(None), FlakyCursorConnection())
        transaction.reopen()
        errors = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(errors), 1)
        self.assertEqual(errors[0].value.args[0], "problem!")
class NonThreadPool:
    """
    A thread-pool stand-in that runs submitted work synchronously in the
    calling thread and reports the outcome through the supplied callback.
    """

    def callInThreadWithCallback(self, onResult, f, *a, **kw):
        """
        Invoke C{f(*a, **kw)} immediately, then call C{onResult} with a
        success flag and either the return value or a L{Failure}.
        """
        try:
            outcome = f(*a, **kw)
            succeeded = True
        except Exception:
            outcome = Failure()
            succeeded = False
        onResult(succeeded, outcome)
class DummyConnectionPool(ConnectionPool):
    """
    A testable L{ConnectionPool} whose thread pool runs jobs synchronously
    (via L{NonThreadPool}) and whose initializer deliberately skips
    L{ConnectionPool.__init__}.
    """
    threadpool = NonThreadPool()
    def __init__(self):
        """
        Don't forward init call; only record the reactor the pool will use.
        """
        self._reactor = reactor
class EventReactor:
    """
    Partial L{IReactorCore} implementation with simple event-related
    methods.

    @ivar _running: A C{bool} indicating whether the reactor is pretending
        to have been started already or not.

    @ivar triggers: A C{list} of pending system event triggers.
    """

    def __init__(self, running):
        self._running = running
        self.triggers = []

    def callWhenRunning(self, function):
        """
        Run C{function} now if this fake reactor is "running"; otherwise
        register it as an after-startup trigger and return the handle.
        """
        if not self._running:
            return self.addSystemEventTrigger("after", "startup", function)
        function()

    def addSystemEventTrigger(self, phase, event, trigger):
        """
        Record a trigger and return a handle usable for later removal.
        """
        entry = (phase, event, trigger)
        self.triggers.append(entry)
        return entry

    def removeSystemEventTrigger(self, handle):
        """
        Forget a previously registered trigger.
        """
        self.triggers.remove(handle)
class ConnectionPoolTests(unittest.TestCase):
    """
    Unit tests for L{ConnectionPool}.
    """
    def test_runWithConnectionRaiseOriginalError(self):
        """
        If rollback fails, L{ConnectionPool.runWithConnection} raises the
        original exception and logs the error of the rollback.
        """
        class ConnectionRollbackRaise:
            def __init__(self, pool):
                pass
            def rollback(self):
                raise RuntimeError("problem!")
        def raisingFunction(connection):
            raise ValueError("foo")
        pool = DummyConnectionPool()
        pool.connectionFactory = ConnectionRollbackRaise
        d = pool.runWithConnection(raisingFunction)
        # The ValueError from raisingFunction is what callers see; the
        # rollback's RuntimeError only shows up in the log.
        d = self.assertFailure(d, ValueError)
        def cbFailed(ignored):
            errors = self.flushLoggedErrors(RuntimeError)
            self.assertEqual(len(errors), 1)
            self.assertEqual(errors[0].value.args[0], "problem!")
        d.addCallback(cbFailed)
        return d
    def test_closeLogError(self):
        """
        L{ConnectionPool._close} logs exceptions.
        """
        class ConnectionCloseRaise:
            def close(self):
                raise RuntimeError("problem!")
        pool = DummyConnectionPool()
        pool._close(ConnectionCloseRaise())
        errors = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(errors), 1)
        self.assertEqual(errors[0].value.args[0], "problem!")
    def test_runWithInteractionRaiseOriginalError(self):
        """
        If rollback fails, L{ConnectionPool.runInteraction} raises the
        original exception and logs the error of the rollback.
        """
        class ConnectionRollbackRaise:
            def __init__(self, pool):
                pass
            def rollback(self):
                raise RuntimeError("problem!")
        class DummyTransaction:
            def __init__(self, pool, connection):
                pass
        def raisingFunction(transaction):
            raise ValueError("foo")
        pool = DummyConnectionPool()
        pool.connectionFactory = ConnectionRollbackRaise
        pool.transactionFactory = DummyTransaction
        d = pool.runInteraction(raisingFunction)
        d = self.assertFailure(d, ValueError)
        def cbFailed(ignored):
            errors = self.flushLoggedErrors(RuntimeError)
            self.assertEqual(len(errors), 1)
            self.assertEqual(errors[0].value.args[0], "problem!")
        d.addCallback(cbFailed)
        return d
    def test_unstartedClose(self):
        """
        If L{ConnectionPool.close} is called without L{ConnectionPool.start}
        having been called, the pool's startup event is cancelled.
        """
        reactor = EventReactor(False)
        pool = ConnectionPool("twisted.test.test_adbapi", cp_reactor=reactor)
        # There should be a startup trigger waiting.
        self.assertEqual(reactor.triggers, [("after", "startup", pool._start)])
        pool.close()
        # But not anymore.
        self.assertFalse(reactor.triggers)
    def test_startedClose(self):
        """
        If L{ConnectionPool.close} is called after it has been started, but
        not by its shutdown trigger, the shutdown trigger is cancelled.
        """
        reactor = EventReactor(True)
        pool = ConnectionPool("twisted.test.test_adbapi", cp_reactor=reactor)
        # There should be a shutdown trigger waiting.
        self.assertEqual(reactor.triggers, [("during", "shutdown", pool.finalClose)])
        pool.close()
        # But not anymore.
        self.assertFalse(reactor.triggers)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,548 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.compat}.
"""
import codecs
import io
import sys
import traceback
from unittest import skipIf
from twisted.python.compat import (
_PYPY,
bytesEnviron,
cmp,
comparable,
execfile,
intToBytes,
ioType,
iterbytes,
lazyByteSlice,
nativeString,
networkString,
reraise,
)
from twisted.python.filepath import FilePath
from twisted.python.runtime import platform
from twisted.trial.unittest import SynchronousTestCase, TestCase
class IOTypeTests(SynchronousTestCase):
    """
    Test cases for determining a file-like object's type.
    """

    def test_3StringIO(self):
        """
        An L{io.StringIO} accepts and returns text.
        """
        self.assertEqual(ioType(io.StringIO()), str)

    def test_3BytesIO(self):
        """
        An L{io.BytesIO} accepts and returns bytes.
        """
        self.assertEqual(ioType(io.BytesIO()), bytes)

    def test_3openTextMode(self):
        """
        A file opened via 'io.open' in text mode accepts and returns text.
        """
        with open(self.mktemp(), "w") as stream:
            self.assertEqual(ioType(stream), str)

    def test_3openBinaryMode(self):
        """
        A file opened via 'io.open' in binary mode accepts and returns bytes.
        """
        with open(self.mktemp(), "wb") as stream:
            self.assertEqual(ioType(stream), bytes)

    def test_codecsOpenBytes(self):
        """
        The L{codecs} module, oddly, returns a file-like object which returns
        bytes when not passed an 'encoding' argument.
        """
        with codecs.open(self.mktemp(), "wb") as stream:
            self.assertEqual(ioType(stream), bytes)

    def test_codecsOpenText(self):
        """
        When passed an encoding, however, the L{codecs} module returns
        unicode.
        """
        with codecs.open(self.mktemp(), "wb", encoding="utf-8") as stream:
            self.assertEqual(ioType(stream), str)

    def test_defaultToText(self):
        """
        When passed an object about which no sensible decision can be made,
        err on the side of unicode.
        """
        self.assertEqual(ioType(object()), str)
class CompatTests(SynchronousTestCase):
    """
    Various utility functions in C{twisted.python.compat} provide same
    functionality as modern Python variants.
    """

    def test_set(self):
        """
        L{set} should behave like the expected set interface.
        """
        s = set()
        for item in "bca":
            s.add(item)
        self.assertEqual(sorted(s), ["a", "b", "c"])
        s.remove("b")
        self.assertEqual(sorted(s), ["a", "c"])
        # Discarding an absent element is a no-op, not an error.
        s.discard("d")
        merged = s.union({"r", "s"})
        self.assertEqual(sorted(merged), ["a", "c", "r", "s"])

    def test_frozenset(self):
        """
        L{frozenset} should behave like the expected frozenset interface.
        """
        frozen = frozenset(["a", "b"])
        # Immutable: no add() method at all.
        self.assertRaises(AttributeError, getattr, frozen, "add")
        self.assertEqual(sorted(frozen), ["a", "b"])
        merged = frozen.union(frozenset(["r", "s"]))
        self.assertEqual(sorted(merged), ["a", "b", "r", "s"])
class ExecfileCompatTests(SynchronousTestCase):
    """
    Tests for the Python 3-friendly L{execfile} implementation.
    """
    def writeScript(self, content):
        """
        Write L{content} to a new temporary file, returning the L{FilePath}
        for the new file.
        """
        path = self.mktemp()
        with open(path, "wb") as f:
            f.write(content.encode("ascii"))
        # NOTE(review): the path is encoded to bytes here, so the returned
        # FilePath carries a bytes .path — presumably to exercise execfile()
        # with bytes paths; confirm this is intentional.
        return FilePath(path.encode("utf-8"))
    def test_execfileGlobals(self):
        """
        L{execfile} executes the specified file in the given global namespace.
        """
        script = self.writeScript("foo += 1\n")
        globalNamespace = {"foo": 1}
        execfile(script.path, globalNamespace)
        self.assertEqual(2, globalNamespace["foo"])
    def test_execfileGlobalsAndLocals(self):
        """
        L{execfile} executes the specified file in the given global and local
        namespaces.
        """
        script = self.writeScript("foo += 1\n")
        globalNamespace = {"foo": 10}
        localNamespace = {"foo": 20}
        execfile(script.path, globalNamespace, localNamespace)
        # Only the local binding changes; the global stays untouched.
        self.assertEqual(10, globalNamespace["foo"])
        self.assertEqual(21, localNamespace["foo"])
    def test_execfileUniversalNewlines(self):
        """
        L{execfile} reads in the specified file using universal newlines so
        that scripts written on one platform will work on another.
        """
        for lineEnding in "\n", "\r", "\r\n":
            script = self.writeScript("foo = 'okay'" + lineEnding)
            globalNamespace = {"foo": None}
            execfile(script.path, globalNamespace)
            self.assertEqual("okay", globalNamespace["foo"])
class PYPYTest(SynchronousTestCase):
    """
    Identification of PyPy.
    """

    def test_PYPY(self):
        """
        On PyPy, L{_PYPY} is True.
        """
        onPyPy = "PyPy" in sys.version
        self.assertEqual(bool(_PYPY), onPyPy)
@comparable
class Comparable:
    """
    Objects that can be compared to each other, but not others.
    """

    def __init__(self, value):
        self.value = value

    def __cmp__(self, other):
        """
        Compare by wrapped value, but only against other L{Comparable}s.
        """
        if isinstance(other, Comparable):
            return cmp(self.value, other.value)
        return NotImplemented
class ComparableTests(SynchronousTestCase):
    """
    L{comparable} decorated classes emulate Python 2's C{__cmp__} semantics.

    Each test spells out the rich-comparison operator explicitly so the
    decorator-synthesized method is the thing under test.
    """
    def test_equality(self):
        """
        Instances of a class that is decorated by C{comparable} support
        equality comparisons.
        """
        # Make explicitly sure we're using ==:
        self.assertTrue(Comparable(1) == Comparable(1))
        self.assertFalse(Comparable(2) == Comparable(1))
    def test_nonEquality(self):
        """
        Instances of a class that is decorated by C{comparable} support
        inequality comparisons.
        """
        # Make explicitly sure we're using !=:
        self.assertFalse(Comparable(1) != Comparable(1))
        self.assertTrue(Comparable(2) != Comparable(1))
    def test_greaterThan(self):
        """
        Instances of a class that is decorated by C{comparable} support
        greater-than comparisons.
        """
        self.assertTrue(Comparable(2) > Comparable(1))
        self.assertFalse(Comparable(0) > Comparable(3))
    def test_greaterThanOrEqual(self):
        """
        Instances of a class that is decorated by C{comparable} support
        greater-than-or-equal comparisons.
        """
        self.assertTrue(Comparable(1) >= Comparable(1))
        self.assertTrue(Comparable(2) >= Comparable(1))
        self.assertFalse(Comparable(0) >= Comparable(3))
    def test_lessThan(self):
        """
        Instances of a class that is decorated by C{comparable} support
        less-than comparisons.
        """
        self.assertTrue(Comparable(0) < Comparable(3))
        self.assertFalse(Comparable(2) < Comparable(0))
    def test_lessThanOrEqual(self):
        """
        Instances of a class that is decorated by C{comparable} support
        less-than-or-equal comparisons.
        """
        self.assertTrue(Comparable(3) <= Comparable(3))
        self.assertTrue(Comparable(0) <= Comparable(3))
        self.assertFalse(Comparable(2) <= Comparable(0))
class Python3ComparableTests(SynchronousTestCase):
    """
    Python 3-specific functionality of C{comparable}.

    These call the dunder methods directly (rather than using operators) so
    the C{NotImplemented} return value can be observed instead of triggering
    the reflected operation.
    """
    def test_notImplementedEquals(self):
        """
        Instances of a class that is decorated by C{comparable} support
        returning C{NotImplemented} from C{__eq__} if it is returned by the
        underlying C{__cmp__} call.
        """
        self.assertEqual(Comparable(1).__eq__(object()), NotImplemented)
    def test_notImplementedNotEquals(self):
        """
        Instances of a class that is decorated by C{comparable} support
        returning C{NotImplemented} from C{__ne__} if it is returned by the
        underlying C{__cmp__} call.
        """
        self.assertEqual(Comparable(1).__ne__(object()), NotImplemented)
    def test_notImplementedGreaterThan(self):
        """
        Instances of a class that is decorated by C{comparable} support
        returning C{NotImplemented} from C{__gt__} if it is returned by the
        underlying C{__cmp__} call.
        """
        self.assertEqual(Comparable(1).__gt__(object()), NotImplemented)
    def test_notImplementedLessThan(self):
        """
        Instances of a class that is decorated by C{comparable} support
        returning C{NotImplemented} from C{__lt__} if it is returned by the
        underlying C{__cmp__} call.
        """
        self.assertEqual(Comparable(1).__lt__(object()), NotImplemented)
    def test_notImplementedGreaterThanEquals(self):
        """
        Instances of a class that is decorated by C{comparable} support
        returning C{NotImplemented} from C{__ge__} if it is returned by the
        underlying C{__cmp__} call.
        """
        self.assertEqual(Comparable(1).__ge__(object()), NotImplemented)
    def test_notImplementedLessThanEquals(self):
        """
        Instances of a class that is decorated by C{comparable} support
        returning C{NotImplemented} from C{__le__} if it is returned by the
        underlying C{__cmp__} call.
        """
        self.assertEqual(Comparable(1).__le__(object()), NotImplemented)
class CmpTests(SynchronousTestCase):
    """
    L{cmp} should behave like the built-in Python 2 C{cmp}: negative, zero
    or positive depending on how its two arguments order.
    """
    def test_equals(self):
        """
        L{cmp} returns 0 for equal objects.
        """
        self.assertEqual(cmp("a", "a"), 0)
        self.assertEqual(cmp(1, 1), 0)
        self.assertEqual(cmp([1], [1]), 0)
    def test_greaterThan(self):
        """
        L{cmp} returns 1 if its first argument is bigger than its second.
        """
        self.assertEqual(cmp(4, 0), 1)
        self.assertEqual(cmp(b"z", b"a"), 1)
    def test_lessThan(self):
        """
        L{cmp} returns -1 if its first argument is smaller than its second.
        """
        self.assertEqual(cmp(0.1, 2.3), -1)
        self.assertEqual(cmp(b"a", b"d"), -1)
class StringTests(SynchronousTestCase):
    """
    Compatibility functions and types for strings.
    """

    def assertNativeString(self, original, expected):
        """
        Fail unless C{nativeString(original)} equals C{expected} and is a
        native L{str}.
        """
        converted = nativeString(original)
        self.assertEqual(converted, expected)
        self.assertIsInstance(converted, str)

    def test_nonASCIIBytesToString(self):
        """
        C{nativeString} raises a C{UnicodeError} if input bytes are not
        ASCII decodable.
        """
        self.assertRaises(UnicodeError, nativeString, b"\xFF")

    def test_nonASCIIUnicodeToString(self):
        """
        C{nativeString} raises a C{UnicodeError} if input Unicode is not
        ASCII encodable.
        """
        self.assertRaises(UnicodeError, nativeString, "\u1234")

    def test_bytesToString(self):
        """
        C{nativeString} converts bytes to the native string format, assuming
        an ASCII encoding if applicable.
        """
        self.assertNativeString(b"hello", "hello")

    def test_unicodeToString(self):
        """
        C{nativeString} converts unicode to the native string format,
        assuming an ASCII encoding if applicable.
        """
        self.assertNativeString("Good day", "Good day")

    def test_stringToString(self):
        """
        C{nativeString} leaves native strings as native strings.
        """
        self.assertNativeString("Hello!", "Hello!")

    def test_unexpectedType(self):
        """
        C{nativeString} raises a C{TypeError} if given an object that is not
        a string of some sort.
        """
        self.assertRaises(TypeError, nativeString, 1)
class NetworkStringTests(SynchronousTestCase):
    """
    Tests for L{networkString}.
    """
    def test_str(self):
        """
        L{networkString} returns a C{unicode} object passed to it encoded into
        a C{bytes} instance.
        """
        self.assertEqual(b"foo", networkString("foo"))
    def test_unicodeOutOfRange(self):
        """
        L{networkString} raises L{UnicodeError} if passed a C{unicode}
        instance containing characters not encodable in ASCII.
        """
        self.assertRaises(UnicodeError, networkString, "\N{SNOWMAN}")
    def test_nonString(self):
        """
        L{networkString} raises L{TypeError} if passed a non-string object or
        the wrong type of string object.
        """
        self.assertRaises(TypeError, networkString, object())
        # bytes are also rejected: only text may be encoded.
        self.assertRaises(TypeError, networkString, b"bytes")
class ReraiseTests(SynchronousTestCase):
    """
    L{reraise} re-raises exceptions on both Python 2 and Python 3.
    """
    def test_reraiseWithNone(self):
        """
        Calling L{reraise} with an exception instance and a traceback of
        L{None} re-raises it with a new traceback.
        """
        try:
            1 / 0
        except BaseException:
            typ, value, tb = sys.exc_info()
        try:
            reraise(value, None)
        except BaseException:
            typ2, value2, tb2 = sys.exc_info()
            self.assertEqual(typ2, ZeroDivisionError)
            self.assertIs(value, value2)
            # With no traceback supplied, the innermost frame of the new
            # traceback is the re-raise site, not the original raise site.
            self.assertNotEqual(
                traceback.format_tb(tb)[-1], traceback.format_tb(tb2)[-1]
            )
        else:
            self.fail("The exception was not raised.")
    def test_reraiseWithTraceback(self):
        """
        Calling L{reraise} with an exception instance and a traceback
        re-raises the exception with the given traceback.
        """
        try:
            1 / 0
        except BaseException:
            typ, value, tb = sys.exc_info()
        try:
            reraise(value, tb)
        except BaseException:
            typ2, value2, tb2 = sys.exc_info()
            self.assertEqual(typ2, ZeroDivisionError)
            self.assertIs(value, value2)
            # Supplying the original traceback preserves the innermost frame.
            self.assertEqual(traceback.format_tb(tb)[-1], traceback.format_tb(tb2)[-1])
        else:
            self.fail("The exception was not raised.")
class Python3BytesTests(SynchronousTestCase):
    """
    Tests for L{iterbytes}, L{intToBytes}, L{lazyByteSlice}.
    """

    def test_iteration(self):
        """
        When L{iterbytes} is called with a bytestring, the returned object
        can be iterated over, resulting in the individual bytes of the
        bytestring.
        """
        data = b"abcd"
        self.assertEqual(list(iterbytes(data)), [b"a", b"b", b"c", b"d"])

    def test_intToBytes(self):
        """
        When L{intToBytes} is called with an integer, the result is an
        ASCII-encoded string representation of the number.
        """
        self.assertEqual(intToBytes(213), b"213")

    def test_lazyByteSliceNoOffset(self):
        """
        L{lazyByteSlice} called with some bytes returns a semantically equal
        version of these bytes.
        """
        data = b"123XYZ"
        self.assertEqual(bytes(lazyByteSlice(data)), data)

    def test_lazyByteSliceOffset(self):
        """
        L{lazyByteSlice} called with some bytes and an offset returns a
        semantically equal version of these bytes starting at the given
        offset.
        """
        data = b"123XYZ"
        self.assertEqual(bytes(lazyByteSlice(data, 2)), data[2:])

    def test_lazyByteSliceOffsetAndLength(self):
        """
        L{lazyByteSlice} called with some bytes, an offset and a length
        returns a semantically equal version of these bytes starting at the
        given offset, up to the given length.
        """
        data = b"123XYZ"
        self.assertEqual(bytes(lazyByteSlice(data, 2, 3)), data[2:5])
class BytesEnvironTests(TestCase):
    """
    Tests for L{bytesEnviron}.
    """

    @skipIf(platform.isWindows(), "Environment vars are always str on Windows.")
    def test_alwaysBytes(self):
        """
        The output of L{bytesEnviron} should always be a L{dict} with
        L{bytes} values and L{bytes} keys.
        """
        result = bytesEnviron()
        types = set()
        for key, val in result.items():
            types.add(type(key))
            types.add(type(val))
        # Compare the sets directly: the previous list(types) == [bytes]
        # comparison depended on set iteration order and would have produced
        # a confusing failure had more than one type ever appeared.
        self.assertEqual(types, {bytes})

View File

@@ -0,0 +1,48 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.context}.
"""
from twisted.python import context
from twisted.trial.unittest import SynchronousTestCase
class ContextTests(SynchronousTestCase):
    """
    Tests for the module-scope APIs for L{twisted.python.context}.
    """

    def test_notPresentIfNotSet(self):
        """
        Arbitrary keys which have not been set in the context have an
        associated value of L{None}.
        """
        self.assertIsNone(context.get("x"))

    def test_setByCall(self):
        """
        Values may be associated with keys by passing them in a dictionary as
        the first argument to L{twisted.python.context.call}.
        """
        result = context.call({"x": "y"}, context.get, "x")
        self.assertEqual(result, "y")

    def test_unsetAfterCall(self):
        """
        After a L{twisted.python.context.call} completes, keys specified in
        the call are no longer associated with the values from that call.
        """
        context.call({"x": "y"}, lambda: None)
        self.assertIsNone(context.get("x"))

    def test_setDefault(self):
        """
        A default value may be set for a key in the context using
        L{twisted.python.context.setDefault}.
        """
        key = object()
        self.addCleanup(context.defaultContextDict.pop, key, None)
        context.setDefault(key, "y")
        self.assertEqual("y", context.get(key))

View File

@@ -0,0 +1,690 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains tests for L{twisted.internet.task.Cooperator} and
related functionality.
"""
from twisted.internet import defer, reactor, task
from twisted.trial import unittest
class FakeDelayedCall:
    """
    Fake delayed call which lets us simulate the scheduler.
    """

    def __init__(self, func):
        """
        Remember the function to run later; start out not cancelled.
        """
        self.func = func
        self.cancelled = False

    def cancel(self):
        """
        Mark this call so its function will never be run.
        """
        self.cancelled = True
class FakeScheduler:
    """
    A fake scheduler for testing against.
    """

    def __init__(self):
        """
        Start with an empty queue of pending work.
        """
        self.work = []

    def __call__(self, thunk):
        """
        Queue C{thunk} to be run on the next L{pump} and return its handle.
        """
        pending = FakeDelayedCall(thunk)
        self.work.append(pending)
        return pending

    def pump(self):
        """
        Run every currently queued, non-cancelled unit of work.  Work queued
        while pumping waits for the next pump.
        """
        pending, self.work = self.work, []
        for unit in pending:
            if not unit.cancelled:
                unit.func()
class CooperatorTests(unittest.TestCase):
    """
    Tests for the scheduling behavior of L{task.Cooperator}.
    """
    # Sentinel returned by ebIter so callbacks can verify that the
    # SchedulerStopped failure path was taken.
    RESULT = "done"
    def ebIter(self, err):
        """
        Errback: convert an expected L{task.SchedulerStopped} failure into
        C{self.RESULT}; any other failure propagates.
        """
        err.trap(task.SchedulerStopped)
        return self.RESULT
    def cbIter(self, ign):
        """
        Callback that must never run; reaching it fails the test.
        """
        self.fail()
    def testStoppedRejectsNewTasks(self):
        """
        Test that Cooperators refuse new tasks when they have been stopped.
        """
        def testwith(stuff):
            c = task.Cooperator()
            c.stop()
            d = c.coiterate(iter(()), stuff)
            d.addCallback(self.cbIter)
            d.addErrback(self.ebIter)
            return d.addCallback(lambda result: self.assertEqual(result, self.RESULT))
        # Exercise both the default done-Deferred and an explicit one.
        return testwith(None).addCallback(lambda ign: testwith(defer.Deferred()))
    def testStopRunning(self):
        """
        Test that a running iterator will not run to completion when the
        cooperator is stopped.
        """
        c = task.Cooperator()
        def myiter():
            yield from range(3)
        # NOTE(review): nothing ever mutates myiter.value; the assertion
        # below only checks that it was left untouched.
        myiter.value = -1
        d = c.coiterate(myiter())
        d.addCallback(self.cbIter)
        d.addErrback(self.ebIter)
        c.stop()
        def doasserts(result):
            self.assertEqual(result, self.RESULT)
            self.assertEqual(myiter.value, -1)
        d.addCallback(doasserts)
        return d
    def testStopOutstanding(self):
        """
        An iterator run with L{Cooperator.coiterate} paused on a L{Deferred}
        yielded by that iterator will fire its own L{Deferred} (the one
        returned by C{coiterate}) when L{Cooperator.stop} is called.
        """
        testControlD = defer.Deferred()
        outstandingD = defer.Deferred()
        def myiter():
            reactor.callLater(0, testControlD.callback, None)
            yield outstandingD
            self.fail()
        c = task.Cooperator()
        d = c.coiterate(myiter())
        def stopAndGo(ign):
            c.stop()
            outstandingD.callback("arglebargle")
        testControlD.addCallback(stopAndGo)
        d.addCallback(self.cbIter)
        d.addErrback(self.ebIter)
        return d.addCallback(lambda result: self.assertEqual(result, self.RESULT))
    def testUnexpectedError(self):
        """
        An exception raised synchronously by the iterator is delivered as a
        failure of the Deferred returned by C{coiterate}.
        """
        c = task.Cooperator()
        def myiter():
            # The unreachable yield makes this a generator function.
            if False:
                yield None
            else:
                raise RuntimeError()
        d = c.coiterate(myiter())
        return self.assertFailure(d, RuntimeError)
    def testUnexpectedErrorActuallyLater(self):
        """
        A failure delivered asynchronously through a yielded Deferred is
        propagated to the Deferred returned by C{coiterate}.
        """
        def myiter():
            D = defer.Deferred()
            reactor.callLater(0, D.errback, RuntimeError())
            yield D
        c = task.Cooperator()
        d = c.coiterate(myiter())
        return self.assertFailure(d, RuntimeError)
    def testUnexpectedErrorNotActuallyLater(self):
        """
        Yielding an already-failed Deferred also fails the Deferred returned
        by C{coiterate}.
        """
        def myiter():
            yield defer.fail(RuntimeError())
        c = task.Cooperator()
        d = c.coiterate(myiter())
        return self.assertFailure(d, RuntimeError)
    def testCooperation(self):
        """
        Multiple iterators passed to C{coiterate} are interleaved: one
        element is consumed from each in turn.
        """
        L = []
        def myiter(things):
            for th in things:
                L.append(th)
                yield None
        groupsOfThings = ["abc", (1, 2, 3), "def", (4, 5, 6)]
        c = task.Cooperator()
        tasks = []
        for stuff in groupsOfThings:
            tasks.append(c.coiterate(myiter(stuff)))
        return defer.DeferredList(tasks).addCallback(
            lambda ign: self.assertEqual(tuple(L), sum(zip(*groupsOfThings), ()))
        )
    def testResourceExhaustion(self):
        """
        When the termination predicate reports exhaustion, the cooperator
        stops iterating for the current tick.
        """
        output = []
        def myiter():
            for i in range(100):
                output.append(i)
                if i == 9:
                    _TPF.stopped = True
                yield i
        class _TPF:
            stopped = False
            def __call__(self):
                return self.stopped
        c = task.Cooperator(terminationPredicateFactory=_TPF)
        c.coiterate(myiter()).addErrback(self.ebIter)
        c._delayedCall.cancel()
        # testing a private method because only the test case will ever care
        # about this, so we have to carefully clean up after ourselves.
        c._tick()
        c.stop()
        self.assertTrue(_TPF.stopped)
        self.assertEqual(output, list(range(10)))
    def testCallbackReCoiterate(self):
        """
        If a callback to a deferred returned by coiterate calls coiterate on
        the same Cooperator, we should make sure to only do the minimal amount
        of scheduling work. (This test was added to demonstrate a specific bug
        that was found while writing the scheduler.)
        """
        calls = []
        class FakeCall:
            def __init__(self, func):
                self.func = func
            def __repr__(self) -> str:
                return f"<FakeCall {self.func!r}>"
        def sched(f):
            # At most one scheduled call may be outstanding at a time.
            self.assertFalse(calls, repr(calls))
            calls.append(FakeCall(f))
            return calls[-1]
        c = task.Cooperator(
            scheduler=sched, terminationPredicateFactory=lambda: lambda: True
        )
        d = c.coiterate(iter(()))
        done = []
        def anotherTask(ign):
            c.coiterate(iter(())).addBoth(done.append)
        d.addCallback(anotherTask)
        work = 0
        while not done:
            work += 1
            while calls:
                calls.pop(0).func()
                work += 1
            if work > 50:
                self.fail("Cooperator took too long")
    def test_removingLastTaskStopsScheduledCall(self):
        """
        If the last task in a Cooperator is removed, the scheduled call for
        the next tick is cancelled, since it is no longer necessary.
        This behavior is useful for tests that want to assert they have left
        no reactor state behind when they're done.
        """
        calls = [None]
        def sched(f):
            calls[0] = FakeDelayedCall(f)
            return calls[0]
        coop = task.Cooperator(scheduler=sched)
        # Add two tasks; this should schedule the tick:
        task1 = coop.cooperate(iter([1, 2]))
        task2 = coop.cooperate(iter([1, 2]))
        self.assertEqual(calls[0].func, coop._tick)
        # Remove first task; scheduled call should still be going:
        task1.stop()
        self.assertFalse(calls[0].cancelled)
        self.assertEqual(coop._delayedCall, calls[0])
        # Remove second task; scheduled call should be cancelled:
        task2.stop()
        self.assertTrue(calls[0].cancelled)
        self.assertIsNone(coop._delayedCall)
        # Add another task; scheduled call will be recreated:
        coop.cooperate(iter([1, 2]))
        self.assertFalse(calls[0].cancelled)
        self.assertEqual(coop._delayedCall, calls[0])
    def test_runningWhenStarted(self):
        """
        L{Cooperator.running} reports C{True} if the L{Cooperator}
        was started on creation.
        """
        c = task.Cooperator()
        self.assertTrue(c.running)
    def test_runningWhenNotStarted(self):
        """
        L{Cooperator.running} reports C{False} if the L{Cooperator}
        has not been started.
        """
        c = task.Cooperator(started=False)
        self.assertFalse(c.running)
    def test_runningWhenRunning(self):
        """
        L{Cooperator.running} reports C{True} when the L{Cooperator}
        is running.
        """
        c = task.Cooperator(started=False)
        c.start()
        self.addCleanup(c.stop)
        self.assertTrue(c.running)
    def test_runningWhenStopped(self):
        """
        L{Cooperator.running} reports C{False} after the L{Cooperator}
        has been stopped.
        """
        c = task.Cooperator(started=False)
        c.start()
        c.stop()
        self.assertFalse(c.running)
class UnhandledException(Exception):
    """
    Marker exception type that the tests expect to propagate unhandled.
    """
class AliasTests(unittest.TestCase):
    """
    Integration test to verify that the global singleton aliases do what
    they're supposed to.
    """

    def test_cooperate(self):
        """
        L{twisted.internet.task.cooperate} ought to run the generator that it
        is given, via the global cooperator singleton.
        """
        finished = defer.Deferred()

        def producer():
            yield 1
            yield 2
            yield 3
            finished.callback("yay")

        runningTask = task.cooperate(producer())
        self.assertIn(runningTask, task._theCooperator._tasks)
        return finished
class RunStateTests(unittest.TestCase):
"""
Tests to verify the behavior of L{CooperativeTask.pause},
L{CooperativeTask.resume}, L{CooperativeTask.stop}, exhausting the
underlying iterator, and their interactions with each other.
"""
    def setUp(self):
        """
        Create a cooperator with a fake scheduler and a termination predicate
        that ensures only one unit of work will take place per tick.
        """
        # One-shot flags consumed by worker() to pick its next action.
        self._doDeferNext = False
        self._doStopNext = False
        self._doDieNext = False
        self.work = []
        self.scheduler = FakeScheduler()
        self.cooperator = task.Cooperator(
            scheduler=self.scheduler,
            # Always stop after one iteration of work (return a function which
            # returns a function which always returns True)
            terminationPredicateFactory=lambda: lambda: True,
        )
        self.task = self.cooperator.cooperate(self.worker())
        self.cooperator.start()
    def worker(self):
        """
        This is a sample generator which yields Deferreds when we are testing
        deferral and an ascending integer count otherwise.
        """
        i = 0
        while True:
            i += 1
            if self._doDeferNext:
                self._doDeferNext = False  # one-shot: defer only this result
                d = defer.Deferred()
                self.work.append(d)
                yield d
            elif self._doStopNext:
                return
            elif self._doDieNext:
                raise UnhandledException()
            else:
                self.work.append(i)
                yield i
    def tearDown(self):
        """
        Drop references to interesting parts of the fixture to allow Deferred
        errors to be noticed when things start failing.
        """
        # Only the task and scheduler references are dropped; other fixture
        # attributes are harmless to keep.
        del self.task
        del self.scheduler
def deferNext(self):
"""
Defer the next result from my worker iterator.
"""
self._doDeferNext = True
def stopNext(self):
"""
Make the next result from my worker iterator be completion (raising
StopIteration).
"""
self._doStopNext = True
def dieNext(self):
"""
Make the next result from my worker iterator be raising an
L{UnhandledException}.
"""
def ignoreUnhandled(failure):
failure.trap(UnhandledException)
return None
self._doDieNext = True
    def test_pauseResume(self):
        """
        Cooperators should stop running their tasks when they're paused, and
        start again when they're resumed.
        """
        # first, sanity check
        self.scheduler.pump()
        self.assertEqual(self.work, [1])
        self.scheduler.pump()
        self.assertEqual(self.work, [1, 2])
        # OK, now for real
        self.task.pause()
        self.scheduler.pump()
        self.assertEqual(self.work, [1, 2])
        self.task.resume()
        # Resuming itself should not do any work
        self.assertEqual(self.work, [1, 2])
        self.scheduler.pump()
        # But when the scheduler rolls around again...
        self.assertEqual(self.work, [1, 2, 3])
    def test_resumeNotPaused(self):
        """
        L{CooperativeTask.resume} should raise a L{TaskNotPaused} exception if
        it was not paused; e.g. if L{CooperativeTask.pause} was not invoked
        more times than L{CooperativeTask.resume} on that object.
        """
        self.assertRaises(task.NotPaused, self.task.resume)
        self.task.pause()
        self.task.resume()
        # The pause count is balanced again, so another resume must fail.
        self.assertRaises(task.NotPaused, self.task.resume)
    def test_pauseTwice(self):
        """
        Pauses on tasks should behave like a stack. If a task is paused twice,
        it needs to be resumed twice.
        """
        # pause once
        self.task.pause()
        self.scheduler.pump()
        self.assertEqual(self.work, [])
        # pause twice
        self.task.pause()
        self.scheduler.pump()
        self.assertEqual(self.work, [])
        # resume once (no work should happen yet)
        self.task.resume()
        self.scheduler.pump()
        self.assertEqual(self.work, [])
        # resume twice (now it should go)
        self.task.resume()
        self.scheduler.pump()
        self.assertEqual(self.work, [1])
    def test_pauseWhileDeferred(self):
        """
        C{pause()}ing a task while it is waiting on an outstanding
        L{defer.Deferred} should put the task into a state where the
        outstanding L{defer.Deferred} must be called back I{and} the task is
        C{resume}d before it will continue processing.
        """
        self.deferNext()
        self.scheduler.pump()
        # The worker yielded a Deferred; no further results until it fires.
        self.assertEqual(len(self.work), 1)
        self.assertIsInstance(self.work[0], defer.Deferred)
        self.scheduler.pump()
        self.assertEqual(len(self.work), 1)
        self.task.pause()
        self.scheduler.pump()
        self.assertEqual(len(self.work), 1)
        self.task.resume()
        self.scheduler.pump()
        # resume() alone is not enough -- the Deferred has not fired yet.
        self.assertEqual(len(self.work), 1)
        self.work[0].callback("STUFF!")
        self.scheduler.pump()
        # With the Deferred fired AND the explicit pause resumed, work resumes.
        self.assertEqual(len(self.work), 2)
        self.assertEqual(self.work[1], 2)
def test_whenDone(self):
"""
L{CooperativeTask.whenDone} returns a Deferred which fires when the
Cooperator's iterator is exhausted. It returns a new Deferred each
time it is called; callbacks added to other invocations will not modify
the value that subsequent invocations will fire with.
"""
deferred1 = self.task.whenDone()
deferred2 = self.task.whenDone()
results1 = []
results2 = []
final1 = []
final2 = []
def callbackOne(result):
results1.append(result)
return 1
def callbackTwo(result):
results2.append(result)
return 2
deferred1.addCallback(callbackOne)
deferred2.addCallback(callbackTwo)
deferred1.addCallback(final1.append)
deferred2.addCallback(final2.append)
# exhaust the task iterator
# callbacks fire
self.stopNext()
self.scheduler.pump()
self.assertEqual(len(results1), 1)
self.assertEqual(len(results2), 1)
self.assertIs(results1[0], self.task._iterator)
self.assertIs(results2[0], self.task._iterator)
self.assertEqual(final1, [1])
self.assertEqual(final2, [2])
def test_whenDoneError(self):
"""
L{CooperativeTask.whenDone} returns a L{defer.Deferred} that will fail
when the iterable's C{next} method raises an exception, with that
exception.
"""
deferred1 = self.task.whenDone()
results = []
deferred1.addErrback(results.append)
self.dieNext()
self.scheduler.pump()
self.assertEqual(len(results), 1)
self.assertEqual(results[0].check(UnhandledException), UnhandledException)
def test_whenDoneStop(self):
"""
L{CooperativeTask.whenDone} returns a L{defer.Deferred} that fails with
L{TaskStopped} when the C{stop} method is called on that
L{CooperativeTask}.
"""
deferred1 = self.task.whenDone()
errors = []
deferred1.addErrback(errors.append)
self.task.stop()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].check(task.TaskStopped), task.TaskStopped)
def test_whenDoneAlreadyDone(self):
"""
L{CooperativeTask.whenDone} will return a L{defer.Deferred} that will
succeed immediately if its iterator has already completed.
"""
self.stopNext()
self.scheduler.pump()
results = []
self.task.whenDone().addCallback(results.append)
self.assertEqual(results, [self.task._iterator])
    def test_stopStops(self):
        """
        C{stop()}ping a task should cause it to be removed from the run just as
        C{pause()}ing, with the distinction that C{resume()} will raise a
        L{TaskStopped} exception.
        """
        self.task.stop()
        self.scheduler.pump()
        self.assertEqual(len(self.work), 0)
        # Once stopped, further lifecycle calls are rejected.
        self.assertRaises(task.TaskStopped, self.task.stop)
        self.assertRaises(task.TaskStopped, self.task.pause)
        # Sanity check - it's still not scheduled, is it?
        self.scheduler.pump()
        self.assertEqual(self.work, [])
    def test_pauseStopResume(self):
        """
        C{resume()}ing a paused, stopped task should be a no-op; it should not
        raise an exception, because it's paused, but neither should it actually
        do more work from the task.
        """
        self.task.pause()
        self.task.stop()
        # Neither raises nor performs work:
        self.task.resume()
        self.scheduler.pump()
        self.assertEqual(self.work, [])
    def test_stopDeferred(self):
        """
        As a corollary of the interaction of C{pause()} and C{resume()},
        C{stop()}ping a task which is waiting on a L{Deferred} should cause the
        task to gracefully shut down, meaning that it should not be unpaused
        when the deferred fires.
        """
        self.deferNext()
        self.scheduler.pump()
        d = self.work.pop()
        # Waiting on the yielded Deferred counts as one (implicit) pause.
        self.assertEqual(self.task._pauseCount, 1)
        results = []
        d.addBoth(results.append)
        self.scheduler.pump()
        self.task.stop()
        self.scheduler.pump()
        d.callback(7)
        self.scheduler.pump()
        # Let's make sure that Deferred doesn't come out fried with an
        # unhandled error that will be logged. The value is None, rather than
        # our test value, 7, because this Deferred is returned to and consumed
        # by the cooperator code. Its callback therefore has no contract.
        self.assertEqual(results, [None])
        # But more importantly, no further work should have happened.
        self.assertEqual(self.work, [])
    def test_stopExhausted(self):
        """
        C{stop()}ping a L{CooperativeTask} whose iterator has been exhausted
        should raise L{TaskDone}.
        """
        # Exhaust the iterator first, then verify stop() is rejected.
        self.stopNext()
        self.scheduler.pump()
        self.assertRaises(task.TaskDone, self.task.stop)
    def test_stopErrored(self):
        """
        C{stop()}ping a L{CooperativeTask} whose iterator has encountered an
        error should raise L{TaskFailed}.
        """
        # Make the iterator raise first, then verify stop() is rejected.
        self.dieNext()
        self.scheduler.pump()
        self.assertRaises(task.TaskFailed, self.task.stop)
    def test_stopCooperatorReentrancy(self):
        """
        If a callback of a L{Deferred} from L{CooperativeTask.whenDone} calls
        C{Cooperator.stop} on its L{CooperativeTask._cooperator}, the
        L{Cooperator} will stop, but the L{CooperativeTask} whose callback is
        calling C{stop} should already be considered 'stopped' by the time the
        callback is running, and therefore removed from the
        L{Cooperator}.
        """
        callbackPhases = []

        def stopit(result):
            callbackPhases.append(result)
            self.cooperator.stop()
            # "done" here is a sanity check to make sure that we get all the
            # way through the callback; i.e. stop() shouldn't be raising an
            # exception due to the stopped-ness of our main task.
            callbackPhases.append("done")

        self.task.whenDone().addCallback(stopit)
        self.stopNext()
        self.scheduler.pump()
        self.assertEqual(callbackPhases, [self.task._iterator, "done"])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,235 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.internet.defer.deferredGenerator} and related APIs.
"""
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred, deferredGenerator, waitForDeferred
from twisted.python.util import runWithWarningsSuppressed
from twisted.trial import unittest
from twisted.trial.util import suppress as SUPPRESS
def getThing():
    """
    Return a L{Deferred} which is fired with C{"hi"} on the next reactor
    iteration.
    """
    result = Deferred()
    reactor.callLater(0, result.callback, "hi")
    return result
def getOwie():
    """
    Return a L{Deferred} which errbacks with L{ZeroDivisionError} on the
    next reactor iteration.
    """
    result = Deferred()

    def fireError():
        result.errback(ZeroDivisionError("OMG"))

    reactor.callLater(0, fireError)
    return result
class TerminalException(Exception):
    """
    A distinctive exception type raised and caught by fixtures in this
    module to mark a terminal failure.
    """
def deprecatedDeferredGenerator(f):
    """
    Apply L{deferredGenerator} to C{f} while suppressing the deprecation
    warning that doing so emits.

    @param f: Function to decorate.
    @return: The decorated function.
    """
    suppressions = [
        SUPPRESS(
            message="twisted.internet.defer.deferredGenerator was " "deprecated"
        )
    ]
    return runWithWarningsSuppressed(suppressions, deferredGenerator, f)
class DeferredGeneratorTests(unittest.TestCase):
    """
    Tests for the deprecated generator-based coroutine support provided by
    L{deferredGenerator} and L{waitForDeferred}.
    """

    def testBasics(self):
        """
        Test that a normal deferredGenerator works.  Tests yielding a
        deferred which callbacks, as well as a deferred errbacks. Also
        ensures returning a final value works.
        """

        @deprecatedDeferredGenerator
        def _genBasics():
            # Wait for a Deferred that fires asynchronously with a value...
            x = waitForDeferred(getThing())
            yield x
            x = x.getResult()
            self.assertEqual(x, "hi")
            # ...and for one that fires with a failure.
            ow = waitForDeferred(getOwie())
            yield ow
            try:
                ow.getResult()
            except ZeroDivisionError as e:
                self.assertEqual(str(e), "OMG")
            # Yielding a plain value sets the generator's final result.
            yield "WOOSH"
            return

        return _genBasics().addCallback(self.assertEqual, "WOOSH")

    def testProducesException(self):
        """
        Ensure that a generator that produces an exception signals
        a Failure condition on result deferred by converting the exception to
        a L{Failure}.
        """

        @deprecatedDeferredGenerator
        def _genProduceException():
            yield waitForDeferred(getThing())
            1 // 0

        return self.assertFailure(_genProduceException(), ZeroDivisionError)

    def testNothing(self):
        """Test that a generator which never yields results in None."""

        @deprecatedDeferredGenerator
        def _genNothing():
            if False:
                yield 1  # pragma: no cover

        return _genNothing().addCallback(self.assertEqual, None)

    def testHandledTerminalFailure(self):
        """
        Create a Deferred Generator which yields a Deferred which fails and
        handles the exception which results.  Assert that the Deferred
        Generator does not errback its Deferred.
        """

        @deprecatedDeferredGenerator
        def _genHandledTerminalFailure():
            x = waitForDeferred(
                defer.fail(TerminalException("Handled Terminal Failure"))
            )
            yield x
            try:
                x.getResult()
            except TerminalException:
                pass

        return _genHandledTerminalFailure().addCallback(self.assertEqual, None)

    def testHandledTerminalAsyncFailure(self):
        """
        Just like testHandledTerminalFailure, only with a Deferred which fires
        asynchronously with an error.
        """

        @deprecatedDeferredGenerator
        def _genHandledTerminalAsyncFailure(d):
            x = waitForDeferred(d)
            yield x
            try:
                x.getResult()
            except TerminalException:
                pass

        d = defer.Deferred()
        deferredGeneratorResultDeferred = _genHandledTerminalAsyncFailure(d)
        # Fire the error only after the generator is already waiting on it.
        d.errback(TerminalException("Handled Terminal Failure"))
        return deferredGeneratorResultDeferred.addCallback(self.assertEqual, None)

    def testStackUsage(self):
        """
        Make sure we don't blow the stack when yielding immediately
        available deferreds.
        """

        @deprecatedDeferredGenerator
        def _genStackUsage():
            for x in range(5000):
                # Test with yielding a deferred
                x = waitForDeferred(defer.succeed(1))
                yield x
                x = x.getResult()
            yield 0

        return _genStackUsage().addCallback(self.assertEqual, 0)

    def testStackUsage2(self):
        """
        Make sure we don't blow the stack when yielding immediately
        available values.
        """

        @deprecatedDeferredGenerator
        def _genStackUsage2():
            for x in range(5000):
                # Test with yielding a random value
                yield 1
            yield 0

        return _genStackUsage2().addCallback(self.assertEqual, 0)

    def testDeferredYielding(self):
        """
        Ensure that yielding a Deferred directly is trapped as an
        error.
        """

        # See the comment _deferGenerator about d.callback(Deferred).
        def _genDeferred():
            yield getThing()

        _genDeferred = deprecatedDeferredGenerator(_genDeferred)

        return self.assertFailure(_genDeferred(), TypeError)

    # Suppress waitForDeferred's own deprecation warning for every test
    # method in this class; trial consults this class attribute.
    suppress = [
        SUPPRESS(message="twisted.internet.defer.waitForDeferred was " "deprecated")
    ]
class DeprecateDeferredGeneratorTests(unittest.SynchronousTestCase):
    """
    Both L{deferredGenerator} and L{waitForDeferred} are deprecated and
    emit warnings pointing users at C{inlineCallbacks}.
    """

    def test_deferredGeneratorDeprecated(self):
        """
        L{deferredGenerator} is deprecated.
        """

        @deferredGenerator
        def decorated():
            yield None

        flushed = self.flushWarnings([self.test_deferredGeneratorDeprecated])
        self.assertEqual(len(flushed), 1)
        self.assertEqual(flushed[0]["category"], DeprecationWarning)
        self.assertEqual(
            flushed[0]["message"],
            "twisted.internet.defer.deferredGenerator was deprecated in "
            "Twisted 15.0.0; please use "
            "twisted.internet.defer.inlineCallbacks instead",
        )

    def test_waitForDeferredDeprecated(self):
        """
        L{waitForDeferred} is deprecated.
        """
        waitForDeferred(Deferred())
        flushed = self.flushWarnings([self.test_waitForDeferredDeprecated])
        self.assertEqual(len(flushed), 1)
        self.assertEqual(flushed[0]["category"], DeprecationWarning)
        self.assertEqual(
            flushed[0]["message"],
            "twisted.internet.defer.waitForDeferred was deprecated in "
            "Twisted 15.0.0; please use "
            "twisted.internet.defer.inlineCallbacks instead",
        )

View File

@@ -0,0 +1,218 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for dirdbm module.
"""
import shutil
from base64 import b64decode
from twisted.persisted import dirdbm
from twisted.python import rebuild
from twisted.python.filepath import FilePath
from twisted.trial import unittest
class DirDbmTests(unittest.TestCase):
    """
    Tests for L{dirdbm.DirDBM}, a dict-like store keeping each entry in its
    own file inside a directory.
    """

    def setUp(self) -> None:
        # A fresh database directory plus sample key/value pairs, including
        # keys that require dirdbm's filename encoding.
        self.path = FilePath(self.mktemp())
        self.dbm = dirdbm.open(self.path.path)
        self.items: tuple[tuple[bytes, bytes | int | float | tuple[None, int]], ...] = (
            (b"abc", b"foo"),
            (b"/lalal", b"\000\001"),
            (b"\000\012", b"baz"),
        )

    def test_all(self) -> None:
        """
        Assigning the same key twice simply overwrites the value.
        """
        # b"\xff" -- a key whose encoded filename exercises dirdbm's
        # key-encoding round trip.
        k = b64decode("//==")
        self.dbm[k] = b"a"
        self.dbm[k] = b"a"
        self.assertEqual(self.dbm[k], b"a")

    def test_rebuildInteraction(self) -> None:
        """
        L{rebuild.rebuild} can be applied to the dirdbm module while a
        L{dirdbm.Shelf} is in use without raising.
        """
        s = dirdbm.Shelf("dirdbm.rebuild.test")
        s[b"key"] = b"value"
        rebuild.rebuild(dirdbm)

    def test_dbm(self) -> None:
        """
        Exercise the core mapping protocol: insertion, lookup, C{keys},
        C{values}, C{items}, C{copyTo}, C{clear} and deletion.
        """
        d = self.dbm

        # Insert keys
        keys = []
        values = set()
        for k, v in self.items:
            d[k] = v
            keys.append(k)
            values.add(v)
        keys.sort()

        # Check they exist
        for k, v in self.items:
            self.assertIn(k, d)
            self.assertEqual(d[k], v)

        # Check non existent key
        try:
            d[b"XXX"]
        except KeyError:
            pass
        else:
            assert 0, "didn't raise KeyError on non-existent key"

        # Check keys(), values() and items()
        dbkeys = d.keys()
        dbvalues = set(d.values())
        dbitems = set(d.items())
        dbkeys.sort()
        items = set(self.items)
        self.assertEqual(
            keys,
            dbkeys,
            f".keys() output didn't match: {repr(keys)} != {repr(dbkeys)}",
        )
        self.assertEqual(
            values,
            dbvalues,
            ".values() output didn't match: {} != {}".format(
                repr(values), repr(dbvalues)
            ),
        )
        self.assertEqual(
            items,
            dbitems,
            f"items() didn't match: {repr(items)} != {repr(dbitems)}",
        )

        # Make a copy and compare it with the original.
        copyPath = self.mktemp()
        d2 = d.copyTo(copyPath)
        # NOTE(review): the three reads below use ``d`` rather than ``d2``,
        # so the copy's contents are only checked indirectly via clear()
        # further down -- confirm whether ``d2`` was intended here.
        copykeys = d.keys()
        copyvalues = set(d.values())
        copyitems = set(d.items())
        copykeys.sort()

        self.assertEqual(
            dbkeys,
            copykeys,
            ".copyTo().keys() didn't match: {} != {}".format(
                repr(dbkeys), repr(copykeys)
            ),
        )
        self.assertEqual(
            dbvalues,
            copyvalues,
            ".copyTo().values() didn't match: %s != %s"
            % (repr(dbvalues), repr(copyvalues)),
        )
        self.assertEqual(
            dbitems,
            copyitems,
            ".copyTo().items() didn't match: %s != %s"
            % (repr(dbkeys), repr(copyitems)),
        )

        # Clearing the copy empties it without touching the original.
        d2.clear()
        self.assertTrue(
            len(d2.keys()) == len(d2.values()) == len(d2.items()) == len(d2) == 0,
            ".clear() failed",
        )
        self.assertNotEqual(len(d), len(d2))
        shutil.rmtree(copyPath)

        # Delete items
        for k, v in self.items:
            del d[k]
            self.assertNotIn(
                k, d, "key is still in database, even though we deleted it"
            )
        self.assertEqual(len(d.keys()), 0, "database has keys")
        self.assertEqual(len(d.values()), 0, "database has values")
        self.assertEqual(len(d.items()), 0, "database has items")
        self.assertEqual(len(d), 0, "database has items")

    def test_modificationTime(self) -> None:
        """
        L{dirdbm.DirDBM.getModificationTime} reports a recent mtime for a
        freshly-written key and raises L{KeyError} for a missing one.
        """
        import time

        # The mtime value for files comes from a different place than the
        # gettimeofday() system call. On linux, gettimeofday() can be
        # slightly ahead (due to clock drift which gettimeofday() takes into
        # account but which open()/write()/close() do not), and if we are
        # close to the edge of the next second, time.time() can give a value
        # which is larger than the mtime which results from a subsequent
        # write(). I consider this a kernel bug, but it is beyond the scope
        # of this test. Thus we keep the range of acceptability to 3 seconds time.
        # -warner
        self.dbm[b"k"] = b"v"
        self.assertTrue(abs(time.time() - self.dbm.getModificationTime(b"k")) <= 3)
        self.assertRaises(KeyError, self.dbm.getModificationTime, b"nokey")

    def test_recovery(self) -> None:
        """
        DirDBM: test recovery from directory after a faked crash
        """
        # A leftover .rpl (replacement) file with no committed value is
        # adopted as the value for its key...
        k = self.dbm._encode(b"key1")
        with self.path.child(k + b".rpl").open(mode="w") as f:
            f.write(b"value")

        # ...a .rpl next to committed data is discarded in favour of the
        # committed value...
        k2 = self.dbm._encode(b"key2")
        with self.path.child(k2).open(mode="w") as f:
            f.write(b"correct")
        with self.path.child(k2 + b".rpl").open(mode="w") as f:
            f.write(b"wrong")

        # ...and stray .new files are removed entirely.
        with self.path.child("aa.new").open(mode="w") as f:
            f.write(b"deleted")

        dbm = dirdbm.DirDBM(self.path.path)
        self.assertEqual(dbm[b"key1"], b"value")
        self.assertEqual(dbm[b"key2"], b"correct")
        self.assertFalse(self.path.globChildren("*.new"))
        self.assertFalse(self.path.globChildren("*.rpl"))

    def test_nonStringKeys(self) -> None:
        """
        L{dirdbm.DirDBM} operations only support string keys: other types
        should raise a L{TypeError}.
        """
        self.assertRaises(TypeError, self.dbm.__setitem__, 2, "3")
        try:
            self.assertRaises(TypeError, self.dbm.__setitem__, "2", 3)
        except unittest.FailTest:
            # dirdbm.Shelf.__setitem__ supports non-string values
            self.assertIsInstance(self.dbm, dirdbm.Shelf)
        self.assertRaises(TypeError, self.dbm.__getitem__, 2)
        self.assertRaises(TypeError, self.dbm.__delitem__, 2)
        self.assertRaises(TypeError, self.dbm.has_key, 2)
        self.assertRaises(TypeError, self.dbm.__contains__, 2)
        self.assertRaises(TypeError, self.dbm.getModificationTime, 2)

    def test_failSet(self) -> None:
        """
        Failure path when setting an item.
        """

        def _writeFail(path: FilePath[str], data: bytes) -> None:
            # Write the data, then pretend the write failed, so DirDBM's
            # error handling runs with the file already on disk.
            path.setContent(data)
            raise OSError("fail to write")

        self.dbm[b"failkey"] = b"test"
        self.patch(self.dbm, "_writeFile", _writeFail)
        self.assertRaises(IOError, self.dbm.__setitem__, b"failkey", b"test2")
class ShelfTests(DirDbmTests):
    """
    Re-run the L{DirDbmTests} suite against L{dirdbm.Shelf}, whose fixture
    also includes picklable non-string values.
    """

    def setUp(self) -> None:
        self.items = (
            (b"abc", b"foo"),
            (b"/lalal", b"\000\001"),
            (b"\000\012", b"baz"),
            (b"int", 12),
            (b"float", 12.0),
            (b"tuple", (None, 12)),
        )
        self.path = FilePath(self.mktemp())
        self.dbm = dirdbm.Shelf(self.path.path)
# Module-level list of suites to run -- presumably consumed by trial's test
# discovery; the Shelf subclass re-runs the DirDBM tests against Shelf.
testCases = [DirDbmTests, ShelfTests]

View File

@@ -0,0 +1,290 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import annotations
import errno
import socket
import sys
from typing import Sequence
from twisted.internet import error
from twisted.trial import unittest
class StringificationTests(unittest.SynchronousTestCase):
    """Test that the exceptions have useful stringifications."""

    # Each entry is (expected str() output, exception class, positional
    # args, keyword args).
    listOfTests: list[
        tuple[
            str,
            type[Exception],
            Sequence[str | int | Exception | None],
            dict[str, str | int],
        ]
    ] = [
        # (output, exception[, args[, kwargs]]),
        ("An error occurred binding to an interface.", error.BindError, [], {}),
        (
            "An error occurred binding to an interface: foo.",
            error.BindError,
            ["foo"],
            {},
        ),
        (
            "An error occurred binding to an interface: foo bar.",
            error.BindError,
            ["foo", "bar"],
            {},
        ),
        (
            "Couldn't listen on eth0:4242: Foo.",
            error.CannotListenError,
            ("eth0", 4242, socket.error("Foo")),
            {},
        ),
        ("Message is too long to send.", error.MessageLengthError, [], {}),
        (
            "Message is too long to send: foo bar.",
            error.MessageLengthError,
            ["foo", "bar"],
            {},
        ),
        ("DNS lookup failed.", error.DNSLookupError, [], {}),
        ("DNS lookup failed: foo bar.", error.DNSLookupError, ["foo", "bar"], {}),
        ("An error occurred while connecting.", error.ConnectError, [], {}),
        (
            "An error occurred while connecting: someOsError.",
            error.ConnectError,
            ["someOsError"],
            {},
        ),
        (
            "An error occurred while connecting: foo.",
            error.ConnectError,
            [],
            {"string": "foo"},
        ),
        (
            "An error occurred while connecting: someOsError: foo.",
            error.ConnectError,
            ["someOsError", "foo"],
            {},
        ),
        ("Couldn't bind.", error.ConnectBindError, [], {}),
        ("Couldn't bind: someOsError.", error.ConnectBindError, ["someOsError"], {}),
        (
            "Couldn't bind: someOsError: foo.",
            error.ConnectBindError,
            ["someOsError", "foo"],
            {},
        ),
        ("Hostname couldn't be looked up.", error.UnknownHostError, [], {}),
        ("No route to host.", error.NoRouteError, [], {}),
        ("Connection was refused by other side.", error.ConnectionRefusedError, [], {}),
        ("TCP connection timed out.", error.TCPTimedOutError, [], {}),
        ("File used for UNIX socket is no good.", error.BadFileError, [], {}),
        (
            "Service name given as port is unknown.",
            error.ServiceNameUnknownError,
            [],
            {},
        ),
        ("User aborted connection.", error.UserError, [], {}),
        ("User timeout caused connection failure.", error.TimeoutError, [], {}),
        ("An SSL error occurred.", error.SSLError, [], {}),
        (
            "Connection to the other side was lost in a non-clean fashion.",
            error.ConnectionLost,
            [],
            {},
        ),
        (
            "Connection to the other side was lost in a non-clean fashion: foo bar.",
            error.ConnectionLost,
            ["foo", "bar"],
            {},
        ),
        ("Connection was closed cleanly.", error.ConnectionDone, [], {}),
        (
            "Connection was closed cleanly: foo bar.",
            error.ConnectionDone,
            ["foo", "bar"],
            {},
        ),
        (
            "Uh.",  # TODO nice docstring, you've got there.
            error.ConnectionFdescWentAway,
            [],
            {},
        ),
        ("Tried to cancel an already-called event.", error.AlreadyCalled, [], {}),
        (
            "Tried to cancel an already-called event: foo bar.",
            error.AlreadyCalled,
            ["foo", "bar"],
            {},
        ),
        ("Tried to cancel an already-cancelled event.", error.AlreadyCancelled, [], {}),
        (
            "Tried to cancel an already-cancelled event: x 2.",
            error.AlreadyCancelled,
            ["x", "2"],
            {},
        ),
        (
            "A process has ended without apparent errors: process finished with exit code 0.",
            error.ProcessDone,
            [None],
            {},
        ),
        (
            "A process has ended with a probable error condition: process ended.",
            error.ProcessTerminated,
            [],
            {},
        ),
        (
            "A process has ended with a probable error condition: process ended with exit code 42.",
            error.ProcessTerminated,
            [],
            {"exitCode": 42},
        ),
        (
            "A process has ended with a probable error condition: process ended by signal SIGBUS.",
            error.ProcessTerminated,
            [],
            {"signal": "SIGBUS"},
        ),
        (
            "The Connector was not connecting when it was asked to stop connecting.",
            error.NotConnectingError,
            [],
            {},
        ),
        (
            "The Connector was not connecting when it was asked to stop connecting: x 13.",
            error.NotConnectingError,
            ["x", "13"],
            {},
        ),
        (
            "The Port was not listening when it was asked to stop listening.",
            error.NotListeningError,
            [],
            {},
        ),
        (
            "The Port was not listening when it was asked to stop listening: a 12.",
            error.NotListeningError,
            ["a", "12"],
            {},
        ),
    ]

    def testThemAll(self) -> None:
        """
        Each entry in C{listOfTests} stringifies to its expected message.
        """
        # Unpack each four-tuple directly instead of indexing into it.
        for output, exception, args, kwargs in self.listOfTests:
            self.assertEqual(str(exception(*args, **kwargs)), output)

    def test_connectingCancelledError(self) -> None:
        """
        L{error.ConnectingCancelledError} has an C{address} attribute.
        """
        address = object()
        e = error.ConnectingCancelledError(address)
        self.assertIs(e.address, address)
class SubclassingTests(unittest.SynchronousTestCase):
    """
    Some exceptions are subclasses of other exceptions.
    """

    # The test method names below were destroyed by an automated secret
    # scrubber ("test_[AWS-SECRET-REMOVED]"), which is not valid Python;
    # they are restored to their upstream Twisted names.
    def test_connectionLostSubclassOfConnectionClosed(self) -> None:
        """
        L{error.ConnectionClosed} is a superclass of L{error.ConnectionLost}.
        """
        self.assertTrue(issubclass(error.ConnectionLost, error.ConnectionClosed))

    def test_connectionDoneSubclassOfConnectionClosed(self) -> None:
        """
        L{error.ConnectionClosed} is a superclass of L{error.ConnectionDone}.
        """
        self.assertTrue(issubclass(error.ConnectionDone, error.ConnectionClosed))

    def test_invalidAddressErrorSubclassOfValueError(self) -> None:
        """
        L{ValueError} is a superclass of L{error.InvalidAddressError}.
        """
        self.assertTrue(issubclass(error.InvalidAddressError, ValueError))
class GetConnectErrorTests(unittest.SynchronousTestCase):
    """
    Given an exception instance thrown by C{socket.connect},
    L{error.getConnectError} returns the appropriate high-level Twisted
    exception instance.
    """

    def assertErrnoException(
        self, errno: int, expectedClass: type[error.ConnectError]
    ) -> None:
        """
        When called with a tuple with the given errno,
        L{error.getConnectError} returns an exception which is an instance of
        the expected class.

        NOTE: the C{errno} parameter shadows the module-level C{errno}
        import inside this method.
        """
        e = (errno, "lalala")
        result = error.getConnectError(e)
        self.assertCorrectException(errno, "lalala", result, expectedClass)

    def assertCorrectException(
        self,
        errno: int | None,
        message: object,
        result: error.ConnectError,
        expectedClass: type[error.ConnectError],
    ) -> None:
        """
        The given result of L{error.getConnectError} has the given attributes
        (C{osError} and C{args}), and is an instance of the given class.
        """
        # Want exact class match, not inherited classes, so no isinstance():
        self.assertEqual(result.__class__, expectedClass)
        self.assertEqual(result.osError, errno)
        self.assertEqual(result.args, (message,))

    def test_errno(self) -> None:
        """
        L{error.getConnectError} converts based on errno for C{socket.error}.
        """
        self.assertErrnoException(errno.ENETUNREACH, error.NoRouteError)
        self.assertErrnoException(errno.ECONNREFUSED, error.ConnectionRefusedError)
        self.assertErrnoException(errno.ETIMEDOUT, error.TCPTimedOutError)
        if sys.platform == "win32":
            # Windows reports WSA-prefixed error numbers instead.
            self.assertErrnoException(
                errno.WSAECONNREFUSED, error.ConnectionRefusedError
            )
            self.assertErrnoException(errno.WSAENETUNREACH, error.NoRouteError)

    def test_gaierror(self) -> None:
        """
        L{error.getConnectError} converts to a L{error.UnknownHostError} given
        a C{socket.gaierror} instance.
        """
        result = error.getConnectError(socket.gaierror(12, "hello"))
        self.assertCorrectException(12, "hello", result, error.UnknownHostError)

    def test_nonTuple(self) -> None:
        """
        L{error.getConnectError} converts to a L{error.ConnectError} given
        an argument that cannot be unpacked.
        """
        e = Exception()
        result = error.getConnectError(e)
        self.assertCorrectException(None, e, result, error.ConnectError)

View File

@@ -0,0 +1,139 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test code for basic Factory classes.
"""
import pickle
from twisted.internet.protocol import Protocol, ReconnectingClientFactory
from twisted.internet.task import Clock
from twisted.trial.unittest import TestCase
class FakeConnector:
    """
    A connector stand-in whose connection-management methods do nothing,
    used to simulate failed or lost connections.
    """

    def stopConnecting(self):
        """Pretend to abort an in-progress connection attempt."""

    def connect(self):
        """Pretend to initiate a connection attempt."""
class ReconnectingFactoryTests(TestCase):
    """
    Tests for L{ReconnectingClientFactory}.
    """

    def test_stopTryingWhenConnected(self):
        """
        If a L{ReconnectingClientFactory} has C{stopTrying} called while it is
        connected, it does not subsequently attempt to reconnect if the
        connection is later lost.
        """

        class NoConnectConnector:
            # Both methods fail loudly: neither should be reached once
            # stopTrying() has been called while connected.
            def stopConnecting(self):
                raise RuntimeError("Shouldn't be called, we're connected.")

            def connect(self):
                raise RuntimeError("Shouldn't be reconnecting.")

        c = ReconnectingClientFactory()
        c.protocol = Protocol
        # Let's pretend we've connected:
        c.buildProtocol(None)
        # Now we stop trying, then disconnect:
        c.stopTrying()
        c.clientConnectionLost(NoConnectConnector(), None)
        self.assertFalse(c.continueTrying)

    def test_stopTryingDoesNotReconnect(self):
        """
        Calling stopTrying on a L{ReconnectingClientFactory} doesn't attempt a
        retry on any active connector.
        """

        class FactoryAwareFakeConnector(FakeConnector):
            attemptedRetry = False

            def stopConnecting(self):
                """
                Behave as though an ongoing connection attempt has now
                failed, and notify the factory of this.
                """
                f.clientConnectionFailed(self, None)

            def connect(self):
                """
                Record an attempt to reconnect, since this is what we
                are trying to avoid.
                """
                self.attemptedRetry = True

        f = ReconnectingClientFactory()
        f.clock = Clock()
        # simulate an active connection - stopConnecting on this connector should
        # be triggered when we call stopTrying
        f.connector = FactoryAwareFakeConnector()
        f.stopTrying()
        # make sure we never attempted to retry
        self.assertFalse(f.connector.attemptedRetry)
        # ...and no retry is scheduled on the clock either.
        self.assertFalse(f.clock.getDelayedCalls())

    def test_serializeUnused(self):
        """
        A L{ReconnectingClientFactory} which hasn't been used for anything
        can be pickled and unpickled and end up with the same state.
        """
        original = ReconnectingClientFactory()
        reconstituted = pickle.loads(pickle.dumps(original))
        self.assertEqual(original.__dict__, reconstituted.__dict__)

    def test_serializeWithClock(self):
        """
        The clock attribute of L{ReconnectingClientFactory} is not serialized,
        and the restored value sets it to the default value, the reactor.
        """
        clock = Clock()
        original = ReconnectingClientFactory()
        original.clock = clock
        reconstituted = pickle.loads(pickle.dumps(original))
        # NOTE(review): the restored attribute is None here; presumably the
        # reactor is substituted lazily when the factory is next used.
        self.assertIsNone(reconstituted.clock)

    def test_deserializationResetsParameters(self):
        """
        A L{ReconnectingClientFactory} which is unpickled does not have an
        L{IConnector} and has its reconnecting timing parameters reset to their
        initial values.
        """
        factory = ReconnectingClientFactory()
        factory.clientConnectionFailed(FakeConnector(), None)
        self.addCleanup(factory.stopTrying)
        serialized = pickle.dumps(factory)
        unserialized = pickle.loads(serialized)
        self.assertIsNone(unserialized.connector)
        self.assertIsNone(unserialized._callID)
        self.assertEqual(unserialized.retries, 0)
        self.assertEqual(unserialized.delay, factory.initialDelay)
        self.assertTrue(unserialized.continueTrying)

    def test_parametrizedClock(self):
        """
        The clock used by L{ReconnectingClientFactory} can be parametrized, so
        that one can cleanly test reconnections.
        """
        clock = Clock()
        factory = ReconnectingClientFactory()
        factory.clock = clock
        factory.clientConnectionLost(FakeConnector(), None)
        # The reconnect attempt was scheduled on our Clock, not the reactor.
        self.assertEqual(len(clock.calls), 1)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,258 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.internet.fdesc}.
"""
import errno
import os
import sys
try:
    import fcntl
except ImportError:
    # Setting a module-level ``skip`` string makes trial skip every test in
    # this module on platforms without fcntl (e.g. Windows).
    skip = "not supported on this platform"
else:
    from twisted.internet import fdesc
from twisted.trial import unittest
class NonBlockingTests(unittest.SynchronousTestCase):
    """
    Tests for L{fdesc.setNonBlocking} and L{fdesc.setBlocking}.
    """

    def test_setNonBlocking(self):
        """
        L{fdesc.setNonBlocking} sets a file description to non-blocking.
        """
        r, w = os.pipe()
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        # A freshly created pipe starts out blocking.
        self.assertFalse(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
        fdesc.setNonBlocking(r)
        self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)

    def test_setBlocking(self):
        """
        L{fdesc.setBlocking} sets a file description to blocking.
        """
        r, w = os.pipe()
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        fdesc.setNonBlocking(r)
        fdesc.setBlocking(r)
        # The O_NONBLOCK flag has been cleared again.
        self.assertFalse(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
class ReadWriteTests(unittest.SynchronousTestCase):
    """
    Tests for L{fdesc.readFromFD}, L{fdesc.writeToFD}.
    """

    def setUp(self):
        """
        Create a non-blocking pipe that can be used in tests.
        """
        self.r, self.w = os.pipe()
        fdesc.setNonBlocking(self.r)
        fdesc.setNonBlocking(self.w)

    def tearDown(self):
        """
        Close pipes.
        """
        # Either end may already have been closed by a test; ignore that.
        try:
            os.close(self.w)
        except OSError:
            pass
        try:
            os.close(self.r)
        except OSError:
            pass

    def write(self, d):
        """
        Write data to the pipe.

        @param d: Bytes to write.
        @return: L{fdesc.writeToFD}'s result: the number of bytes written
            or a connection-state constant.
        """
        return fdesc.writeToFD(self.w, d)

    def read(self):
        """
        Read data from the pipe.

        @return: The bytes read (possibly C{b""}), or a connection-state
            constant returned by L{fdesc.readFromFD}.
        """
        # Renamed from ``l`` (E741 ambiguous variable name).
        chunks = []
        res = fdesc.readFromFD(self.r, chunks.append)
        if res is None:
            if chunks:
                return chunks[0]
            else:
                return b""
        else:
            return res

    def test_writeAndRead(self):
        """
        Test that the number of bytes L{fdesc.writeToFD} reports as written
        with its return value are seen by L{fdesc.readFromFD}.
        """
        n = self.write(b"hello")
        self.assertTrue(n > 0)
        s = self.read()
        self.assertEqual(len(s), n)
        self.assertEqual(b"hello"[:n], s)

    def test_writeAndReadLarge(self):
        """
        Similar to L{test_writeAndRead}, but use a much larger string to verify
        the behavior for that case.
        """
        orig = b"0123456879" * 10000
        written = self.write(orig)
        self.assertTrue(written > 0)
        result = []
        resultlength = 0
        i = 0
        while resultlength < written or i < 50:
            result.append(self.read())
            resultlength += len(result[-1])
            # Increment a counter to be sure we'll exit at some point
            i += 1
        result = b"".join(result)
        self.assertEqual(len(result), written)
        self.assertEqual(orig[:written], result)

    def test_readFromEmpty(self):
        """
        Verify that reading from a file descriptor with no data does not raise
        an exception and does not result in the callback function being called.
        """
        received = []
        result = fdesc.readFromFD(self.r, received.append)
        self.assertEqual(received, [])
        self.assertIsNone(result)

    def test_readFromCleanClose(self):
        """
        Test that using L{fdesc.readFromFD} on a cleanly closed file descriptor
        returns a connection done indicator.
        """
        os.close(self.w)
        self.assertEqual(self.read(), fdesc.CONNECTION_DONE)

    def test_writeToClosed(self):
        """
        Verify that writing with L{fdesc.writeToFD} when the read end is closed
        results in a connection lost indicator.
        """
        os.close(self.r)
        self.assertEqual(self.write(b"s"), fdesc.CONNECTION_LOST)

    def test_readFromInvalid(self):
        """
        Verify that reading with L{fdesc.readFromFD} when the read end is
        closed results in a connection lost indicator.
        """
        os.close(self.r)
        self.assertEqual(self.read(), fdesc.CONNECTION_LOST)

    def test_writeToInvalid(self):
        """
        Verify that writing with L{fdesc.writeToFD} when the write end is
        closed results in a connection lost indicator.
        """
        os.close(self.w)
        self.assertEqual(self.write(b"s"), fdesc.CONNECTION_LOST)

    def test_writeErrors(self):
        """
        Test error paths for L{fdesc.writeToFD}: both C{EAGAIN} and C{EINTR}
        are reported as zero bytes written.
        """
        oldOsWrite = os.write

        def eagainWrite(fd, data):
            err = OSError()
            err.errno = errno.EAGAIN
            raise err

        os.write = eagainWrite
        try:
            self.assertEqual(self.write(b"s"), 0)
        finally:
            os.write = oldOsWrite

        def eintrWrite(fd, data):
            err = OSError()
            err.errno = errno.EINTR
            raise err

        os.write = eintrWrite
        try:
            self.assertEqual(self.write(b"s"), 0)
        finally:
            os.write = oldOsWrite
class CloseOnExecTests(unittest.SynchronousTestCase):
    """
    Tests for L{fdesc._setCloseOnExec} and L{fdesc._unsetCloseOnExec}.
    """

    # A Python program (with a file descriptor number substituted in) that
    # the child process execs.  It exits 0 if writing to the descriptor
    # fails with EBADF (i.e. the descriptor was closed across exec), 20 if
    # the write succeeds (the descriptor was inherited), and 5/10 for
    # unexpected failures.
    program = """
import os, errno
try:
    os.write(%d, b'lul')
except OSError as e:
    if e.errno == errno.EBADF:
        os._exit(0)
    os._exit(5)
except BaseException:
    os._exit(10)
else:
    os._exit(20)
"""

    def _execWithFileDescriptor(self, fObj):
        """
        Fork, exec a child Python running C{self.program} against C{fObj}'s
        descriptor, and return the child's wait status.
        """
        pid = os.fork()
        if pid == 0:
            try:
                os.execv(
                    sys.executable,
                    [sys.executable, "-c", self.program % (fObj.fileno(),)],
                )
            except BaseException:
                # If exec fails, the child must not fall back into the test
                # runner; report the failure and exit with a distinct code.
                import traceback

                traceback.print_exc()
                os._exit(30)
        else:
            # On Linux wait(2) doesn't seem ever able to fail with EINTR but
            # POSIX seems to allow it and on macOS it happens quite a lot.
            return untilConcludes(os.waitpid, pid, 0)[1]

    def test_setCloseOnExec(self):
        """
        A file descriptor passed to L{fdesc._setCloseOnExec} is not inherited
        by a new process image created with one of the exec family of
        functions.
        """
        with open(self.mktemp(), "wb") as fObj:
            fdesc._setCloseOnExec(fObj.fileno())
            status = self._execWithFileDescriptor(fObj)
            self.assertTrue(os.WIFEXITED(status))
            # Exit code 0 means the child's write failed with EBADF.
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_unsetCloseOnExec(self):
        """
        A file descriptor passed to L{fdesc._unsetCloseOnExec} is inherited by
        a new process image created with one of the exec family of functions.
        """
        with open(self.mktemp(), "wb") as fObj:
            fdesc._setCloseOnExec(fObj.fileno())
            fdesc._unsetCloseOnExec(fObj.fileno())
            status = self._execWithFileDescriptor(fObj)
            self.assertTrue(os.WIFEXITED(status))
            # Exit code 20 means the child's write succeeded.
            self.assertEqual(os.WEXITSTATUS(status), 20)

View File

@@ -0,0 +1,56 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.protocols.finger}.
"""
from twisted.internet.testing import StringTransport
from twisted.protocols import finger
from twisted.trial import unittest
class FingerTests(unittest.TestCase):
    """
    Tests for L{finger.Finger}.
    """

    def setUp(self) -> None:
        """
        Hook a fresh L{finger.Finger} protocol up to a string transport.
        """
        self.transport = StringTransport()
        self.protocol = finger.Finger()
        self.protocol.makeConnection(self.transport)

    def _query(self, line: bytes, expected: bytes) -> None:
        """
        Deliver C{line} to the protocol and assert that the full response
        written to the transport equals C{expected}.
        """
        self.protocol.dataReceived(line)
        self.assertEqual(self.transport.value(), expected)

    def test_simple(self) -> None:
        """
        When L{finger.Finger} receives a CR LF terminated line, it responds
        with the default user status message - that no such user exists.
        """
        self._query(b"moshez\r\n", b"Login: moshez\nNo such user\n")

    def test_simpleW(self) -> None:
        """
        The behavior for a query which begins with C{"/w"} is the same as the
        behavior for one which does not.  The user is reported as not
        existing.
        """
        self._query(b"/w moshez\r\n", b"Login: moshez\nNo such user\n")

    def test_forwarding(self) -> None:
        """
        When L{finger.Finger} receives a request for a remote user, it
        responds with a message rejecting the request.
        """
        self._query(b"moshez@example.com\r\n", b"Finger forwarding service denied\n")

    def test_list(self) -> None:
        """
        When L{finger.Finger} receives a blank line, it responds with a
        message rejecting the request for all online users.
        """
        self._query(b"\r\n", b"Finger online list denied\n")

View File

@@ -0,0 +1,139 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import annotations
"""
Test cases for formmethod module.
"""
from typing import Callable, Iterable
from typing_extensions import Concatenate, ParamSpec
from twisted.python import formmethod
from twisted.trial import unittest
_P = ParamSpec("_P")
class ArgumentTests(unittest.TestCase):
    """
    Tests for the coercion behaviour of the L{formmethod.Argument}
    subclasses.
    """

    def argTest(
        self,
        argKlass: Callable[Concatenate[str, _P], formmethod.Argument],
        testPairs: Iterable[tuple[object, object]],
        badValues: Iterable[object],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> None:
        """
        Instantiate C{argKlass} with the extra positional and keyword
        arguments, then check that each value in C{testPairs} coerces to its
        paired result and that each value in C{badValues} raises
        L{formmethod.InputError}.
        """
        arg = argKlass("name", *args, **kwargs)
        for val, result in testPairs:
            self.assertEqual(arg.coerce(val), result)
        for val in badValues:
            self.assertRaises(formmethod.InputError, arg.coerce, val)

    def test_argument(self) -> None:
        """
        Test that coerce correctly raises NotImplementedError.
        """
        arg = formmethod.Argument("name")
        self.assertRaises(NotImplementedError, arg.coerce, "")

    def testString(self) -> None:
        """
        L{formmethod.String} stringifies values and enforces min/max length.
        """
        self.argTest(formmethod.String, [("a", "a"), (1, "1"), ("", "")], ())
        self.argTest(
            formmethod.String, [("ab", "ab"), ("abc", "abc")], ("2", ""), min=2
        )
        self.argTest(
            formmethod.String, [("ab", "ab"), ("a", "a")], ("223213", "345x"), max=3
        )
        self.argTest(
            formmethod.String,
            [("ab", "ab"), ("add", "add")],
            ("223213", "x"),
            min=2,
            max=3,
        )

    def testInt(self) -> None:
        """
        L{formmethod.Integer} parses integers; the empty string maps to
        C{None} unless C{allowNone} is false.
        """
        self.argTest(
            formmethod.Integer, [("3", 3), ("-2", -2), ("", None)], ("q", "2.3")
        )
        self.argTest(
            formmethod.Integer, [("3", 3), ("-2", -2)], ("q", "2.3", ""), allowNone=0
        )

    def testFloat(self) -> None:
        """
        L{formmethod.Float} parses floats; the empty string maps to C{None}
        unless C{allowNone} is false.
        """
        self.argTest(
            formmethod.Float, [("3", 3.0), ("-2.3", -2.3), ("", None)], ("q", "2.3z")
        )
        self.argTest(
            formmethod.Float,
            [("3", 3.0), ("-2.3", -2.3)],
            ("q", "2.3z", ""),
            allowNone=0,
        )

    def testChoice(self) -> None:
        """
        L{formmethod.Choice} maps a tag to its value and rejects unknown tags.
        """
        choices = [("a", "apple", "an apple"), ("b", "banana", "ook")]
        self.argTest(
            formmethod.Choice,
            [("a", "apple"), ("b", "banana")],
            ("c", 1),
            choices=choices,
        )

    def testFlags(self) -> None:
        """
        L{formmethod.Flags} maps lists of tags to lists of values and rejects
        lists containing any unknown tag.
        """
        flags = [("a", "apple", "an apple"), ("b", "banana", "ook")]
        self.argTest(
            formmethod.Flags,
            [(["a"], ["apple"]), (["b", "a"], ["banana", "apple"])],
            (["a", "c"], ["fdfs"]),
            flags=flags,
        )

    def testBoolean(self) -> None:
        """
        L{formmethod.Boolean} coerces truthy/falsy strings to 1/0.
        """
        tests = [("yes", 1), ("", 0), ("False", 0), ("no", 0)]
        self.argTest(formmethod.Boolean, tests, ())

    def test_file(self) -> None:
        """
        Test the correctness of the coerce function.
        """
        arg = formmethod.File("name", allowNone=0)
        self.assertEqual(arg.coerce("something"), "something")
        self.assertRaises(formmethod.InputError, arg.coerce, None)
        arg2 = formmethod.File("name")
        self.assertIsNone(arg2.coerce(None))

    def testDate(self) -> None:
        """
        L{formmethod.Date} coerces (year, month, day) string triples,
        rejecting impossible dates and wrong-arity tuples.
        """
        goodTests = {
            ("2002", "12", "21"): (2002, 12, 21),
            ("1996", "2", "29"): (1996, 2, 29),
            ("", "", ""): None,
        }.items()
        badTests = [
            ("2002", "2", "29"),
            ("xx", "2", "3"),
            ("2002", "13", "1"),
            ("1999", "12", "32"),
            ("2002", "1"),
            ("2002", "2", "3", "4"),
        ]
        self.argTest(formmethod.Date, goodTests, badTests)

    def testRangedInteger(self) -> None:
        """
        L{formmethod.IntegerRange} accepts only integers within the bounds.
        """
        goodTests = {"0": 0, "12": 12, "3": 3}.items()
        badTests = ["-1", "x", "13", "-2000", "3.4"]
        self.argTest(formmethod.IntegerRange, goodTests, badTests, 0, 12)

    def testVerifiedPassword(self) -> None:
        """
        L{formmethod.VerifiedPassword} requires two matching entries within
        the configured length bounds.
        """
        goodTests = {("foo", "foo"): "foo", ("ab", "ab"): "ab"}.items()
        badTests = [
            ("ab", "a"),
            ("12345", "12345"),
            ("", ""),
            ("a", "a"),
            ("a",),
            ("a", "a", "a"),
        ]
        self.argTest(formmethod.VerifiedPassword, goodTests, badTests, min=2, max=4)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,76 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.tap.ftp}.
"""
from twisted.cred import credentials, error
from twisted.python import versions
from twisted.python.filepath import FilePath
from twisted.tap.ftp import Options
from twisted.trial.unittest import TestCase
class FTPOptionsTests(TestCase):
    """
    Tests for the command line option parser used for C{twistd ftp}.
    """

    # The single username/password pair written to the credentials file.
    usernamePassword = (b"iamuser", b"thisispassword")

    def setUp(self) -> None:
        """
        Create a password file containing one user and a fresh L{Options}.
        """
        self.filename = self.mktemp()
        f = FilePath(self.filename)
        # The file format is "username:password".
        f.setContent(b":".join(self.usernamePassword))
        self.options = Options()

    def test_passwordfileDeprecation(self) -> None:
        """
        The C{--password-file} option will emit a warning stating that
        said option is deprecated.
        """
        self.callDeprecated(
            versions.Version("Twisted", 11, 1, 0),
            self.options.opt_password_file,
            self.filename,
        )

    def test_authAdded(self) -> None:
        """
        The C{--auth} command-line option will add a checker to the list of
        checkers
        """
        numCheckers = len(self.options["credCheckers"])
        self.options.parseOptions(["--auth", "file:" + self.filename])
        self.assertEqual(len(self.options["credCheckers"]), numCheckers + 1)

    def test_authFailure(self):
        """
        The checker created by the C{--auth} command-line option returns a
        L{Deferred} that fails with L{UnauthorizedLogin} when
        presented with credentials that are unknown to that checker.
        """
        self.options.parseOptions(["--auth", "file:" + self.filename])
        checker = self.options["credCheckers"][-1]
        invalid = credentials.UsernamePassword(self.usernamePassword[0], "fake")
        return checker.requestAvatarId(invalid).addCallbacks(
            lambda ignore: self.fail("Wrong password should raise error"),
            lambda err: err.trap(error.UnauthorizedLogin),
        )

    def test_authSuccess(self):
        """
        The checker created by the C{--auth} command-line option returns a
        L{Deferred} that returns the avatar id when presented with credentials
        that are known to that checker.
        """
        self.options.parseOptions(["--auth", "file:" + self.filename])
        checker = self.options["credCheckers"][-1]
        correct = credentials.UsernamePassword(*self.usernamePassword)
        return checker.requestAvatarId(correct).addCallback(
            lambda username: self.assertEqual(username, correct.username)
        )

View File

@@ -0,0 +1,118 @@
# -*- Python -*-
__version__ = "$Revision: 1.3 $"[11:-2]
from twisted.protocols import htb
from twisted.trial import unittest
from .test_pcp import DummyConsumer
class DummyClock:
    """
    A manually controlled stand-in for a time function: calling the
    instance reports whatever instant it was last set to.
    """

    # The instant reported by the next call; starts at zero.
    time = 0

    def set(self, when: int) -> None:
        """
        Move the clock to C{when}.
        """
        self.time = when

    def __call__(self) -> int:
        """
        Return the currently configured time, mimicking C{time.time}.
        """
        return self.time
class SomeBucket(htb.Bucket):
    # A small concrete bucket for the tests: holds at most 100 tokens and
    # drains at 2 tokens per unit of (fake) time.
    maxburst = 100
    rate = 2
class TestBucketBase(unittest.TestCase):
    """
    Base class which replaces L{htb.time} with a L{DummyClock} for the
    duration of each test so that tests control the passage of time.
    """

    def setUp(self) -> None:
        # Remember the real time function so it can be restored.
        self._realTimeFunc = htb.time
        self.clock = DummyClock()
        htb.time = self.clock

    def tearDown(self) -> None:
        # Undo the monkey-patch applied in setUp.
        htb.time = self._realTimeFunc
class BucketTests(TestBucketBase):
    """
    Tests for basic L{htb.Bucket} behaviour: capacity limiting and
    draining over (fake) time.
    """

    def testBucketSize(self) -> None:
        """
        Testing the size of the bucket.
        """
        bucket = SomeBucket()
        accepted = bucket.add(1000)
        # Only maxburst (100) tokens fit into an empty bucket.
        self.assertEqual(100, accepted)

    def testBucketDrain(self) -> None:
        """
        Testing the bucket's drain rate.
        """
        bucket = SomeBucket()
        bucket.add(1000)
        # After 10 time units at rate 2, twenty tokens' room has opened up.
        self.clock.set(10)
        accepted = bucket.add(1000)
        self.assertEqual(20, accepted)

    def test_bucketEmpty(self) -> None:
        """
        L{htb.Bucket.drip} returns C{True} if the bucket is empty after that drip.
        """
        bucket = SomeBucket()
        bucket.add(20)
        self.clock.set(9)
        # 9 units at rate 2 drains 18 of the 20 tokens: not yet empty.
        self.assertFalse(bucket.drip())
        self.clock.set(10)
        # One more unit drains the rest.
        self.assertTrue(bucket.drip())
class BucketNestingTests(TestBucketBase):
    """
    Tests for hierarchical buckets: a child bucket is additionally limited
    by the capacity and drain rate of its parent.
    """

    def setUp(self) -> None:
        TestBucketBase.setUp(self)
        # Two children sharing one parent bucket.
        self.parent = SomeBucket()
        self.child1 = SomeBucket(self.parent)
        self.child2 = SomeBucket(self.parent)

    def testBucketParentSize(self) -> None:
        # Use up most of the parent bucket.
        self.child1.add(90)
        # Only the parent's remaining 10 tokens fit, despite child2 being empty.
        fit = self.child2.add(90)
        self.assertEqual(10, fit)

    def testBucketParentRate(self) -> None:
        # Make the parent bucket drain slower.
        self.parent.rate = 1
        # Fill both child1 and parent.
        self.child1.add(100)
        self.clock.set(10)
        fit = self.child1.add(100)
        # How much room was there? The child bucket would have had 20,
        # but the parent bucket only ten (so no, it wouldn't make too much
        # sense to have a child bucket draining faster than its parent in a real
        # application.)
        self.assertEqual(10, fit)
# TODO: Test the Transport stuff?
class ConsumerShaperTests(TestBucketBase):
    """
    Tests for L{htb.ShapedConsumer}, which limits the rate at which writes
    reach an underlying consumer according to a bucket's drain rate.
    """

    def setUp(self) -> None:
        TestBucketBase.setUp(self)
        self.underlying = DummyConsumer()
        self.bucket = SomeBucket()
        self.shaped = htb.ShapedConsumer(self.underlying, self.bucket)

    def testRate(self) -> None:
        # Start off with a full bucket, so the burst-size doesn't factor in
        # to the calculations.
        delta_t = 10
        self.bucket.add(100)
        self.shaped.write("x" * 100)
        self.clock.set(delta_t)
        self.shaped.resumeProducing()
        # Only rate * elapsed-time bytes may have reached the consumer.
        self.assertEqual(len(self.underlying.getvalue()), delta_t * self.bucket.rate)

    def testBucketRefs(self) -> None:
        # The shaper holds one reference on its bucket, released when
        # production stops.
        self.assertEqual(self.bucket._refcount, 1)
        self.shaped.stopProducing()
        self.assertEqual(self.bucket._refcount, 0)

View File

@@ -0,0 +1,211 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.protocols.ident module.
"""
import builtins
import struct
from io import StringIO
from twisted.internet import defer, error
from twisted.internet.testing import StringTransport
from twisted.protocols import ident
from twisted.python import failure
from twisted.trial import unittest
class ClassParserTests(unittest.TestCase):
    """
    Test parsing of ident responses.
    """

    def setUp(self):
        """
        Create an ident client used in tests.
        """
        self.client = ident.IdentClient()

    def test_indentError(self):
        """
        'UNKNOWN-ERROR' error should map to the L{ident.IdentError} exception.
        """
        # NOTE(review): "indent" in this method name looks like a typo for
        # "ident"; it is left as-is because renaming changes the test id.
        d = defer.Deferred()
        self.client.queries.append((d, 123, 456))
        self.client.lineReceived("123, 456 : ERROR : UNKNOWN-ERROR")
        return self.assertFailure(d, ident.IdentError)

    def test_noUSerError(self):
        """
        'NO-USER' error should map to the L{ident.NoUser} exception.
        """
        # NOTE(review): "USer" looks like a capitalization typo; left as-is
        # for the same reason as above.
        d = defer.Deferred()
        self.client.queries.append((d, 234, 456))
        self.client.lineReceived("234, 456 : ERROR : NO-USER")
        return self.assertFailure(d, ident.NoUser)

    def test_invalidPortError(self):
        """
        'INVALID-PORT' error should map to the L{ident.InvalidPort} exception.
        """
        d = defer.Deferred()
        self.client.queries.append((d, 345, 567))
        self.client.lineReceived("345, 567 : ERROR : INVALID-PORT")
        return self.assertFailure(d, ident.InvalidPort)

    def test_hiddenUserError(self):
        """
        'HIDDEN-USER' error should map to the L{ident.HiddenUser} exception.
        """
        d = defer.Deferred()
        self.client.queries.append((d, 567, 789))
        self.client.lineReceived("567, 789 : ERROR : HIDDEN-USER")
        return self.assertFailure(d, ident.HiddenUser)

    def test_lostConnection(self):
        """
        A pending query which failed because of a ConnectionLost should
        receive an L{ident.IdentError}.
        """
        d = defer.Deferred()
        self.client.queries.append((d, 765, 432))
        self.client.connectionLost(failure.Failure(error.ConnectionLost()))
        return self.assertFailure(d, ident.IdentError)
class TestIdentServer(ident.IdentServer):
    # Test double whose lookup always succeeds with a canned result; tests
    # assign the desired (system, user) pair to C{resultValue}.
    def lookup(self, serverAddress, clientAddress):
        return self.resultValue
class TestErrorIdentServer(ident.IdentServer):
    # Test double whose lookup always fails; tests assign the exception
    # class to raise to C{exceptionType}.
    def lookup(self, serverAddress, clientAddress):
        raise self.exceptionType()
class NewException(RuntimeError):
    # An exception type unknown to the ident protocol, used to verify that
    # unexpected errors are logged and reported as UNKNOWN-ERROR.
    pass
class ServerParserTests(unittest.TestCase):
    """
    Tests for the server side of the ident protocol, L{ident.IdentServer}.
    """

    def testErrors(self):
        """
        Each L{ident.IdentError} subclass raised from C{lookup} is reported
        with its own protocol error token; unknown exceptions are logged
        and reported as UNKNOWN-ERROR; out-of-range ports are rejected as
        INVALID-PORT in either position.
        """
        p = TestErrorIdentServer()
        p.makeConnection(StringTransport())
        L = []
        # Capture outgoing lines instead of writing to the transport.
        p.sendLine = L.append
        p.exceptionType = ident.IdentError
        p.lineReceived("123, 345")
        self.assertEqual(L[0], "123, 345 : ERROR : UNKNOWN-ERROR")
        p.exceptionType = ident.NoUser
        p.lineReceived("432, 210")
        self.assertEqual(L[1], "432, 210 : ERROR : NO-USER")
        p.exceptionType = ident.InvalidPort
        p.lineReceived("987, 654")
        self.assertEqual(L[2], "987, 654 : ERROR : INVALID-PORT")
        p.exceptionType = ident.HiddenUser
        p.lineReceived("756, 827")
        self.assertEqual(L[3], "756, 827 : ERROR : HIDDEN-USER")
        p.exceptionType = NewException
        p.lineReceived("987, 789")
        self.assertEqual(L[4], "987, 789 : ERROR : UNKNOWN-ERROR")
        # The unexpected exception must have been logged exactly once.
        errs = self.flushLoggedErrors(NewException)
        self.assertEqual(len(errs), 1)
        # Ports outside the valid 1-65535 range are invalid on both sides.
        for port in -1, 0, 65536, 65537:
            del L[:]
            p.lineReceived("%d, 5" % (port,))
            p.lineReceived("5, %d" % (port,))
            self.assertEqual(
                L,
                [
                    "%d, 5 : ERROR : INVALID-PORT" % (port,),
                    "5, %d : ERROR : INVALID-PORT" % (port,),
                ],
            )

    def testSuccess(self):
        """
        A successful C{lookup} result is reported as a USERID response
        carrying the system and user names.
        """
        p = TestIdentServer()
        p.makeConnection(StringTransport())
        L = []
        p.sendLine = L.append
        p.resultValue = ("SYS", "USER")
        p.lineReceived("123, 456")
        self.assertEqual(L[0], "123, 456 : USERID : SYS : USER")
# /proc/net/tcp renders IP addresses as native-endian hexadecimal, so the
# expected encodings of 127.0.0.1 and 1.2.3.4 depend on this host's byte
# order, detected by packing 1 as a native 32-bit integer.
if struct.pack("=L", 1)[0:1] == b"\x01":
    # Little-endian host.
    _addr1 = "0100007F"
    _addr2 = "04030201"
else:
    # Big-endian host.
    _addr1 = "7F000001"
    _addr2 = "01020304"
class ProcMixinTests(unittest.TestCase):
    """
    Tests for L{ident.ProcServerMixin}, which resolves connections to user
    names by parsing Linux's C{/proc/net/tcp} table.
    """

    # One well-formed /proc/net/tcp entry: local address 127.0.0.1 port
    # 0x0019 (25), remote address 1.2.3.4 port 0x02FA (762), uid 0.
    line = (
        "4: %s:0019 %s:02FA 0A 00000000:00000000 "
        "00:00000000 00000000 0 0 10927 1 f72a5b80 "
        "3000 0 0 2 -1"
    ) % (_addr1, _addr2)
    # The same entry preceded by the table's header line, mimicking the
    # full file contents.
    sampleFile = (
        " sl local_address rem_address st tx_queue rx_queue tr "
        "tm->when retrnsmt uid timeout inode\n " + line
    )

    def testDottedQuadFromHexString(self):
        # The native-endian hex form decodes back to dotted-quad notation.
        p = ident.ProcServerMixin()
        self.assertEqual(p.dottedQuadFromHexString(_addr1), "127.0.0.1")

    def testUnpackAddress(self):
        # "addr:port" unpacks to (dotted-quad, int); 0x0277 == 631.
        p = ident.ProcServerMixin()
        self.assertEqual(p.unpackAddress(_addr1 + ":0277"), ("127.0.0.1", 631))

    def testLineParser(self):
        # A table row parses to (local, remote, uid).
        p = ident.ProcServerMixin()
        self.assertEqual(
            p.parseLine(self.line), (("127.0.0.1", 25), ("1.2.3.4", 762), 0)
        )

    def testExistingAddress(self):
        # A lookup matching a table entry resolves the owning uid to a
        # user name via getUsername.
        username = []
        p = ident.ProcServerMixin()
        p.entries = lambda: iter([self.line])
        p.getUsername = lambda uid: (username.append(uid), "root")[1]
        self.assertEqual(
            p.lookup(("127.0.0.1", 25), ("1.2.3.4", 762)), (p.SYSTEM_NAME, "root")
        )
        self.assertEqual(username, [0])

    def testNonExistingAddress(self):
        # Any mismatch in local port, remote host or remote port raises
        # NoUser.
        p = ident.ProcServerMixin()
        p.entries = lambda: iter([self.line])
        self.assertRaises(ident.NoUser, p.lookup, ("127.0.0.1", 26), ("1.2.3.4", 762))
        self.assertRaises(ident.NoUser, p.lookup, ("127.0.0.1", 25), ("1.2.3.5", 762))
        self.assertRaises(ident.NoUser, p.lookup, ("127.0.0.1", 25), ("1.2.3.4", 763))

    def testLookupProcNetTcp(self):
        """
        L{ident.ProcServerMixin.lookup} uses the Linux TCP process table.
        """
        open_calls = []

        def mocked_open(*args, **kwargs):
            """
            Mock for the open call to prevent actually opening /proc/net/tcp.
            """
            open_calls.append((args, kwargs))
            return StringIO(self.sampleFile)

        self.patch(builtins, "open", mocked_open)
        p = ident.ProcServerMixin()
        self.assertRaises(ident.NoUser, p.lookup, ("127.0.0.1", 26), ("1.2.3.4", 762))
        # Exactly one open of the real table path must have been attempted.
        self.assertEqual([(("/proc/net/tcp",), {})], open_calls)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,314 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.test.iosim}.
"""
from __future__ import annotations
from typing import Literal
from zope.interface import implementer
from twisted.internet.interfaces import IPushProducer
from twisted.internet.protocol import Protocol
from twisted.internet.task import Clock
from twisted.test.iosim import FakeTransport, connect, connectedServerAndClient
from twisted.trial.unittest import TestCase
class FakeTransportTests(TestCase):
    """
    Tests for L{FakeTransport}.
    """

    def test_connectionSerial(self) -> None:
        """
        Each L{FakeTransport} receives a serial number that uniquely identifies
        it.
        """
        a = FakeTransport(object(), True)
        b = FakeTransport(object(), False)
        self.assertIsInstance(a.serial, int)
        self.assertIsInstance(b.serial, int)
        self.assertNotEqual(a.serial, b.serial)

    def test_writeSequence(self) -> None:
        """
        L{FakeTransport.writeSequence} will write a sequence of L{bytes} to the
        transport.
        """
        a = FakeTransport(object(), False)
        a.write(b"a")
        a.writeSequence([b"b", b"c", b"d"])
        # The buffered stream reflects both write styles in order.
        self.assertEqual(b"".join(a.stream), b"abcd")

    def test_writeAfterClose(self) -> None:
        """
        L{FakeTransport.write} will accept writes after transport was closed,
        but the data will be silently discarded.
        """
        a = FakeTransport(object(), False)
        a.write(b"before")
        a.loseConnection()
        a.write(b"after")
        self.assertEqual(b"".join(a.stream), b"before")
@implementer(IPushProducer)
class StrictPushProducer:
    """
    An L{IPushProducer} implementation which produces nothing but enforces
    preconditions on its state transition methods.
    """

    # Current state: one of "running", "paused" or "stopped".
    _state = "running"

    def stopProducing(self) -> None:
        # Stopping is legal from any state except an already-stopped one.
        if self._state == "stopped":
            raise ValueError("Cannot stop already-stopped IPushProducer")
        self._state = "stopped"

    def pauseProducing(self) -> None:
        # Pausing is legal only while running.
        if self._state != "running":
            raise ValueError(f"Cannot pause {self._state} IPushProducer")
        self._state = "paused"

    def resumeProducing(self) -> None:
        # Resuming is legal only while paused.
        if self._state != "paused":
            raise ValueError(f"Cannot resume {self._state} IPushProducer")
        self._state = "running"
class StrictPushProducerTests(TestCase):
    """
    Tests for L{StrictPushProducer}.
    """

    # Factory helpers for producers in each reachable state.

    def _initial(self) -> StrictPushProducer:
        """
        @return: A new L{StrictPushProducer} which has not been through any state
            changes.
        """
        return StrictPushProducer()

    def _stopped(self) -> StrictPushProducer:
        """
        @return: A new, stopped L{StrictPushProducer}.
        """
        producer = StrictPushProducer()
        producer.stopProducing()
        return producer

    def _paused(self) -> StrictPushProducer:
        """
        @return: A new, paused L{StrictPushProducer}.
        """
        producer = StrictPushProducer()
        producer.pauseProducing()
        return producer

    def _resumed(self) -> StrictPushProducer:
        """
        @return: A new L{StrictPushProducer} which has been paused and resumed.
        """
        producer = StrictPushProducer()
        producer.pauseProducing()
        producer.resumeProducing()
        return producer

    # State assertions.

    def assertStopped(self, producer: StrictPushProducer) -> None:
        """
        Assert that the given producer is in the stopped state.

        @param producer: The producer to verify.
        @type producer: L{StrictPushProducer}
        """
        self.assertEqual(producer._state, "stopped")

    def assertPaused(self, producer: StrictPushProducer) -> None:
        """
        Assert that the given producer is in the paused state.

        @param producer: The producer to verify.
        @type producer: L{StrictPushProducer}
        """
        self.assertEqual(producer._state, "paused")

    def assertRunning(self, producer: StrictPushProducer) -> None:
        """
        Assert that the given producer is in the running state.

        @param producer: The producer to verify.
        @type producer: L{StrictPushProducer}
        """
        self.assertEqual(producer._state, "running")

    # Exhaustive two-step transition coverage.

    def test_stopThenStop(self) -> None:
        """
        L{StrictPushProducer.stopProducing} raises L{ValueError} if called when
        the producer is stopped.
        """
        self.assertRaises(ValueError, self._stopped().stopProducing)

    def test_stopThenPause(self) -> None:
        """
        L{StrictPushProducer.pauseProducing} raises L{ValueError} if called when
        the producer is stopped.
        """
        self.assertRaises(ValueError, self._stopped().pauseProducing)

    def test_stopThenResume(self) -> None:
        """
        L{StrictPushProducer.resumeProducing} raises L{ValueError} if called when
        the producer is stopped.
        """
        self.assertRaises(ValueError, self._stopped().resumeProducing)

    def test_pauseThenStop(self) -> None:
        """
        L{StrictPushProducer} is stopped if C{stopProducing} is called on a paused
        producer.
        """
        producer = self._paused()
        producer.stopProducing()
        self.assertStopped(producer)

    def test_pauseThenPause(self) -> None:
        """
        L{StrictPushProducer.pauseProducing} raises L{ValueError} if called on a
        paused producer.
        """
        producer = self._paused()
        self.assertRaises(ValueError, producer.pauseProducing)

    def test_pauseThenResume(self) -> None:
        """
        L{StrictPushProducer} is resumed if C{resumeProducing} is called on a
        paused producer.
        """
        producer = self._paused()
        producer.resumeProducing()
        self.assertRunning(producer)

    def test_resumeThenStop(self) -> None:
        """
        L{StrictPushProducer} is stopped if C{stopProducing} is called on a
        resumed producer.
        """
        producer = self._resumed()
        producer.stopProducing()
        self.assertStopped(producer)

    def test_resumeThenPause(self) -> None:
        """
        L{StrictPushProducer} is paused if C{pauseProducing} is called on a
        resumed producer.
        """
        producer = self._resumed()
        producer.pauseProducing()
        self.assertPaused(producer)

    def test_resumeThenResume(self) -> None:
        """
        L{StrictPushProducer.resumeProducing} raises L{ValueError} if called on a
        resumed producer.
        """
        producer = self._resumed()
        self.assertRaises(ValueError, producer.resumeProducing)

    def test_stop(self) -> None:
        """
        L{StrictPushProducer} is stopped if C{stopProducing} is called in the
        initial state.
        """
        producer = self._initial()
        producer.stopProducing()
        self.assertStopped(producer)

    def test_pause(self) -> None:
        """
        L{StrictPushProducer} is paused if C{pauseProducing} is called in the
        initial state.
        """
        producer = self._initial()
        producer.pauseProducing()
        self.assertPaused(producer)

    def test_resume(self) -> None:
        """
        L{StrictPushProducer} raises L{ValueError} if C{resumeProducing} is called
        in the initial state.
        """
        producer = self._initial()
        self.assertRaises(ValueError, producer.resumeProducing)
class IOPumpTests(TestCase):
    """
    Tests for L{IOPump}.
    """

    def _testStreamingProducer(self, mode: Literal["server", "client"]) -> None:
        """
        Connect a couple protocol/transport pairs to an L{IOPump} and then pump
        it.  Verify that a streaming producer registered with one of the
        transports does not receive invalid L{IPushProducer} method calls and
        ends in the right state.

        @param mode: C{u"server"} to test a producer registered with the
            server transport.  C{u"client"} to test a producer registered with
            the client transport.
        """
        serverProto = Protocol()
        serverTransport = FakeTransport(serverProto, isServer=True)
        clientProto = Protocol()
        clientTransport = FakeTransport(clientProto, isServer=False)
        pump = connect(
            serverProto,
            serverTransport,
            clientProto,
            clientTransport,
            greet=False,
        )
        producer = StrictPushProducer()
        # Pick which side's transport gets the producer under test.
        victim = {
            "server": serverTransport,
            "client": clientTransport,
        }[mode]
        victim.registerProducer(producer, streaming=True)
        pump.pump()
        # StrictPushProducer raises on any illegal transition, so reaching
        # this assertion means no invalid calls were made.
        self.assertEqual("running", producer._state)

    def test_serverStreamingProducer(self) -> None:
        """
        L{IOPump.pump} does not call C{resumeProducing} on a L{IPushProducer}
        (stream producer) registered with the server transport.
        """
        self._testStreamingProducer(mode="server")

    def test_clientStreamingProducer(self) -> None:
        """
        L{IOPump.pump} does not call C{resumeProducing} on a L{IPushProducer}
        (stream producer) registered with the client transport.
        """
        self._testStreamingProducer(mode="client")

    def test_timeAdvances(self) -> None:
        """
        L{IOPump.pump} advances time in the given L{Clock}.
        """
        time_passed = []
        clock = Clock()
        _, _, pump = connectedServerAndClient(Protocol, Protocol, clock=clock)
        # A zero-delay call only fires once the clock advances.
        clock.callLater(0, lambda: time_passed.append(True))
        self.assertFalse(time_passed)
        pump.pump()
        self.assertTrue(time_passed)

View File

@@ -0,0 +1,383 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test running processes with the APIs in L{twisted.internet.utils}.
"""
import os
import signal
import stat
import sys
import warnings
from unittest import skipIf
from twisted.internet import error, interfaces, reactor, utils
from twisted.internet.defer import Deferred
from twisted.python.runtime import platform
from twisted.python.test.test_util import SuppressedWarningsTests
from twisted.trial.unittest import SynchronousTestCase, TestCase
class ProcessUtilsTests(TestCase):
"""
Test running a process using L{getProcessOutput}, L{getProcessValue}, and
L{getProcessOutputAndValue}.
"""
if interfaces.IReactorProcess(reactor, None) is None:
skip = "reactor doesn't implement IReactorProcess"
output = None
value = None
exe = sys.executable
def makeSourceFile(self, sourceLines):
"""
Write the given list of lines to a text file and return the absolute
path to it.
"""
script = self.mktemp()
with open(script, "wt") as scriptFile:
scriptFile.write(os.linesep.join(sourceLines) + os.linesep)
return os.path.abspath(script)
def test_output(self):
"""
L{getProcessOutput} returns a L{Deferred} which fires with the complete
output of the process it runs after that process exits.
"""
scriptFile = self.makeSourceFile(
[
"import sys",
"for s in b'hello world\\n':",
" s = bytes([s])",
" sys.stdout.buffer.write(s)",
" sys.stdout.flush()",
]
)
d = utils.getProcessOutput(self.exe, ["-u", scriptFile])
return d.addCallback(self.assertEqual, b"hello world\n")
def test_outputWithErrorIgnored(self):
"""
The L{Deferred} returned by L{getProcessOutput} is fired with an
L{IOError} L{Failure} if the child process writes to stderr.
"""
# make sure stderr raises an error normally
scriptFile = self.makeSourceFile(
["import sys", 'sys.stderr.write("hello world\\n")']
)
d = utils.getProcessOutput(self.exe, ["-u", scriptFile])
d = self.assertFailure(d, IOError)
def cbFailed(err):
return self.assertFailure(err.processEnded, error.ProcessDone)
d.addCallback(cbFailed)
return d
def test_outputWithErrorCollected(self):
"""
If a C{True} value is supplied for the C{errortoo} parameter to
L{getProcessOutput}, the returned L{Deferred} fires with the child's
stderr output as well as its stdout output.
"""
scriptFile = self.makeSourceFile(
[
"import sys",
# Write the same value to both because ordering isn't guaranteed so
# this simplifies the test.
'sys.stdout.write("foo")',
"sys.stdout.flush()",
'sys.stderr.write("foo")',
"sys.stderr.flush()",
]
)
d = utils.getProcessOutput(self.exe, ["-u", scriptFile], errortoo=True)
return d.addCallback(self.assertEqual, b"foofoo")
def test_value(self):
"""
The L{Deferred} returned by L{getProcessValue} is fired with the exit
status of the child process.
"""
scriptFile = self.makeSourceFile(["raise SystemExit(1)"])
d = utils.getProcessValue(self.exe, ["-u", scriptFile])
return d.addCallback(self.assertEqual, 1)
def test_outputAndValue(self):
"""
The L{Deferred} returned by L{getProcessOutputAndValue} fires with a
three-tuple, the elements of which give the data written to the child's
stdout, the data written to the child's stderr, and the exit status of
the child.
"""
scriptFile = self.makeSourceFile(
[
"import sys",
"sys.stdout.buffer.write(b'hello world!\\n')",
"sys.stderr.buffer.write(b'goodbye world!\\n')",
"sys.exit(1)",
]
)
def gotOutputAndValue(out_err_code):
out, err, code = out_err_code
self.assertEqual(out, b"hello world!\n")
self.assertEqual(err, b"goodbye world!\n")
self.assertEqual(code, 1)
d = utils.getProcessOutputAndValue(self.exe, ["-u", scriptFile])
return d.addCallback(gotOutputAndValue)
@skipIf(platform.isWindows(), "Windows doesn't have real signals.")
def test_outputSignal(self):
"""
If the child process exits because of a signal, the L{Deferred}
returned by L{getProcessOutputAndValue} fires a L{Failure} of a tuple
containing the child's stdout, stderr, and the signal which caused
it to exit.
"""
# Use SIGKILL here because it's guaranteed to be delivered. Using
# SIGHUP might not work in, e.g., a buildbot slave run under the
# 'nohup' command.
scriptFile = self.makeSourceFile(
[
"import sys, os, signal",
"sys.stdout.write('stdout bytes\\n')",
"sys.stderr.write('stderr bytes\\n')",
"sys.stdout.flush()",
"sys.stderr.flush()",
"os.kill(os.getpid(), signal.SIGKILL)",
]
)
def gotOutputAndValue(out_err_sig):
out, err, sig = out_err_sig
self.assertEqual(out, b"stdout bytes\n")
self.assertEqual(err, b"stderr bytes\n")
self.assertEqual(sig, signal.SIGKILL)
d = utils.getProcessOutputAndValue(self.exe, ["-u", scriptFile])
d = self.assertFailure(d, tuple)
return d.addCallback(gotOutputAndValue)
def _pathTest(self, utilFunc, check):
dir = os.path.abspath(self.mktemp())
os.makedirs(dir)
scriptFile = self.makeSourceFile(
["import os, sys", "sys.stdout.write(os.getcwd())"]
)
d = utilFunc(self.exe, ["-u", scriptFile], path=dir)
d.addCallback(check, dir.encode(sys.getfilesystemencoding()))
return d
def test_getProcessOutputPath(self):
"""
L{getProcessOutput} runs the given command with the working directory
given by the C{path} parameter.
"""
return self._pathTest(utils.getProcessOutput, self.assertEqual)
def test_getProcessValuePath(self):
"""
L{getProcessValue} runs the given command with the working directory
given by the C{path} parameter.
"""
def check(result, ignored):
self.assertEqual(result, 0)
return self._pathTest(utils.getProcessValue, check)
def test_getProcessOutputAndValuePath(self):
"""
L{getProcessOutputAndValue} runs the given command with the working
directory given by the C{path} parameter.
"""
def check(out_err_status, dir):
out, err, status = out_err_status
self.assertEqual(out, dir)
self.assertEqual(status, 0)
return self._pathTest(utils.getProcessOutputAndValue, check)
def _defaultPathTest(self, utilFunc, check):
# Make another directory to mess around with.
dir = os.path.abspath(self.mktemp())
os.makedirs(dir)
scriptFile = self.makeSourceFile(
["import os, sys", "cdir = os.getcwd()", "sys.stdout.write(cdir)"]
)
# Switch to it, but make sure we switch back
self.addCleanup(os.chdir, os.getcwd())
os.chdir(dir)
# Remember its default permissions.
originalMode = stat.S_IMODE(os.stat(".").st_mode)
# On macOS Catalina (and maybe elsewhere), os.getcwd() sometimes fails
# with EACCES if u+rx is missing from the working directory, so don't
# reduce it further than this.
os.chmod(dir, stat.S_IXUSR | stat.S_IRUSR)
# Restore the permissions to their original state later (probably
# adding at least u+w), because otherwise it might be hard to delete
# the trial temporary directory.
self.addCleanup(os.chmod, dir, originalMode)
d = utilFunc(self.exe, ["-u", scriptFile])
d.addCallback(check, dir.encode(sys.getfilesystemencoding()))
return d
def test_getProcessOutputDefaultPath(self):
    """
    If no value is supplied for the C{path} parameter, L{getProcessOutput}
    runs the given command in the same working directory as the parent
    process and succeeds even if the current working directory is not
    accessible.
    """
    # The child's reported cwd can be compared directly against the
    # expected directory, so assertEqual serves as the checker.
    compare = self.assertEqual
    return self._defaultPathTest(utils.getProcessOutput, compare)
def test_getProcessValueDefaultPath(self):
    """
    If no value is supplied for the C{path} parameter, L{getProcessValue}
    runs the given command in the same working directory as the parent
    process and succeeds even if the current working directory is not
    accessible.
    """

    def ensureCleanExit(status, _expectedDir):
        self.assertEqual(status, 0)

    return self._defaultPathTest(utils.getProcessValue, ensureCleanExit)
def test_getProcessOutputAndValueDefaultPath(self):
    """
    If no value is supplied for the C{path} parameter,
    L{getProcessOutputAndValue} runs the given command in the same working
    directory as the parent process and succeeds even if the current
    working directory is not accessible.
    """

    def verify(result, expectedDir):
        stdout, _stderr, exitStatus = result
        self.assertEqual(stdout, expectedDir)
        self.assertEqual(exitStatus, 0)

    return self._defaultPathTest(utils.getProcessOutputAndValue, verify)
def test_get_processOutputAndValueStdin(self):
    """
    Standard input can be made available to the child process by passing
    bytes for the `stdinBytes` parameter.
    """
    echoScript = self.makeSourceFile(
        [
            "import sys",
            "sys.stdout.write(sys.stdin.read())",
        ]
    )
    payload = b"These are the bytes to see."
    resultDeferred = utils.getProcessOutputAndValue(
        self.exe,
        ["-u", echoScript],
        stdinBytes=payload,
    )

    def verify(outErrCode):
        stdout, _stderr, exitCode = outErrCode
        # Substring check rather than exact equality: stray output on
        # stdout (warnings, print statements, logging) may accompany the
        # echoed bytes.
        self.assertIn(payload, stdout)
        self.assertEqual(0, exitCode)

    resultDeferred.addCallback(verify)
    return resultDeferred
class SuppressWarningsTests(SynchronousTestCase):
    """
    Tests for L{utils.suppressWarnings}.
    """

    def test_suppressWarnings(self):
        """
        L{utils.suppressWarnings} decorates a function so that the given
        warnings are suppressed.
        """
        captured = []

        def showwarning(message, *args, **kwargs):
            # Record each emission; only the count matters below.
            captured.append((args, kwargs))

        self.patch(warnings, "showwarning", showwarning)

        def noisy(msg):
            warnings.warn(msg)

        quiet = utils.suppressWarnings(
            noisy, (("ignore",), {"message": "This is message"})
        )
        # Sanity check: the undecorated function emits its warning.
        noisy("Sanity check message")
        self.assertEqual(len(captured), 1)
        # The wrapped function suppresses the warning matching the filter...
        quiet("This is message")
        self.assertEqual(len(captured), 1)
        # ...but lets a non-matching warning through.
        quiet("Unignored message")
        self.assertEqual(len(captured), 2)
class DeferredSuppressedWarningsTests(SuppressedWarningsTests):
    """
    Tests for L{utils.runWithWarningsSuppressed}, the version that supports
    Deferreds.
    """

    # Swap in the Deferred-aware implementation; the inherited test methods
    # exercise it through this attribute.
    runWithWarningsSuppressed = staticmethod(utils.runWithWarningsSuppressed)

    def test_deferredCallback(self):
        """
        If the function called by L{utils.runWithWarningsSuppressed} returns a
        C{Deferred}, the warning filters aren't removed until the Deferred
        fires.
        """
        suppressions = [(("ignore", ".*foo.*"), {}), (("ignore", ".*bar.*"), {})]
        pending = Deferred()
        self.runWithWarningsSuppressed(suppressions, lambda: pending)
        # Still suppressed: the Deferred has not fired yet.
        warnings.warn("ignore foo")
        pending.callback(3)
        # Fired: the filters are gone, so this warning is observable.
        warnings.warn("ignore foo 2")
        self.assertEqual(
            ["ignore foo 2"], [w["message"] for w in self.flushWarnings()]
        )

    def test_deferredErrback(self):
        """
        If the function called by L{utils.runWithWarningsSuppressed} returns a
        C{Deferred}, the warning filters aren't removed until the Deferred
        fires with an errback.
        """
        suppressions = [(("ignore", ".*foo.*"), {}), (("ignore", ".*bar.*"), {})]
        pending = Deferred()
        result = self.runWithWarningsSuppressed(suppressions, lambda: pending)
        warnings.warn("ignore foo")
        pending.errback(ZeroDivisionError())
        result.addErrback(lambda failure: failure.trap(ZeroDivisionError))
        warnings.warn("ignore foo 2")
        self.assertEqual(
            ["ignore foo 2"], [w["message"] for w in self.flushWarnings()]
        )

View File

@@ -0,0 +1,457 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.lockfile}.
"""
from __future__ import annotations
import errno
import os
from unittest import skipIf, skipUnless
from typing_extensions import NoReturn
from twisted.python import lockfile
from twisted.python.reflect import requireModule
from twisted.python.runtime import platform
from twisted.trial.unittest import TestCase
# Whether lockfile.kill is unusable on this platform, and why, for use with
# @skipIf on the kill-related tests below.
skipKill = False
skipKillReason = ""
if platform.isWindows():
    if (
        requireModule("win32api.OpenProcess") is None
        and requireModule("pywintypes") is None
    ):
        # NOTE(review): "win32api.OpenProcess" names a function, not a
        # module, so this requireModule call presumably always returns None
        # and the effective check is pywintypes alone — confirm whether
        # plain "win32api" was intended.
        skipKill = True
        skipKillReason = (
            "On windows, lockfile.kill is not implemented "
            "in the absence of win32api and/or pywintypes."
        )
class UtilTests(TestCase):
    """
    Tests for the helper functions used to implement L{FilesystemLock}.
    """

    def test_symlinkEEXIST(self) -> None:
        """
        L{lockfile.symlink} raises L{OSError} with C{errno} set to L{EEXIST}
        when an attempt is made to create a symlink which already exists.
        """
        name = self.mktemp()
        lockfile.symlink("foo", name)
        exc = self.assertRaises(OSError, lockfile.symlink, "foo", name)
        self.assertEqual(exc.errno, errno.EEXIST)

    @skipUnless(
        platform.isWindows(),
        "special rename EIO handling only necessary and correct on " "Windows.",
    )
    def test_symlinkEIOWindows(self) -> None:
        """
        L{lockfile.symlink} raises L{OSError} with C{errno} set to L{EIO} when
        the underlying L{rename} call fails with L{EIO}.

        Renaming a file on Windows may fail if the target of the rename is in
        the process of being deleted (directory deletion appears not to be
        atomic).
        """
        name = self.mktemp()

        def fakeRename(src: str, dst: str) -> NoReturn:
            raise OSError(errno.EIO, None)

        self.patch(lockfile, "rename", fakeRename)
        exc = self.assertRaises(IOError, lockfile.symlink, name, "foo")
        self.assertEqual(exc.errno, errno.EIO)

    def test_readlinkENOENT(self) -> None:
        """
        L{lockfile.readlink} raises L{OSError} with C{errno} set to L{ENOENT}
        when an attempt is made to read a symlink which does not exist.
        """
        name = self.mktemp()
        exc = self.assertRaises(OSError, lockfile.readlink, name)
        self.assertEqual(exc.errno, errno.ENOENT)

    @skipUnless(
        platform.isWindows(),
        "special readlink EACCES handling only necessary and " "correct on Windows.",
    )
    def test_readlinkEACCESWindows(self) -> None:
        """
        L{lockfile.readlink} raises L{OSError} with C{errno} set to L{EACCES}
        on Windows when the underlying file open attempt fails with C{EACCES}.

        Opening a file on Windows may fail if the path is inside a directory
        which is in the process of being deleted (directory deletion appears
        not to be atomic).
        """
        name = self.mktemp()

        def fakeOpen(path: str, mode: str) -> NoReturn:
            raise OSError(errno.EACCES, None)

        self.patch(lockfile, "_open", fakeOpen)
        exc = self.assertRaises(IOError, lockfile.readlink, name)
        self.assertEqual(exc.errno, errno.EACCES)

    @skipIf(skipKill, skipKillReason)
    def test_kill(self) -> None:
        """
        L{lockfile.kill} returns without error if passed the PID of a
        process which exists and signal C{0}.
        """
        # Signal 0 only checks for process existence; it sends nothing.
        lockfile.kill(os.getpid(), 0)

    @skipIf(skipKill, skipKillReason)
    def test_killESRCH(self) -> None:
        """
        L{lockfile.kill} raises L{OSError} with errno of L{ESRCH} if
        passed a PID which does not correspond to any process.
        """
        # Hopefully there is no process with PID 2 ** 31 - 1
        exc = self.assertRaises(OSError, lockfile.kill, 2**31 - 1, 0)
        self.assertEqual(exc.errno, errno.ESRCH)

    def test_noKillCall(self) -> None:
        """
        Verify that when L{lockfile.kill} does end up as None (e.g. on Windows
        without pywin32), it doesn't end up being called and raising a
        L{TypeError}.
        """
        self.patch(lockfile, "kill", None)
        fl = lockfile.FilesystemLock(self.mktemp())
        fl.lock()
        # The second attempt must simply fail, not crash trying to call None.
        self.assertFalse(fl.lock())
class LockingTests(TestCase):
    """
    Tests for the acquisition and release behavior of
    L{lockfile.FilesystemLock}, including recovery from stale locks and
    races with other processes.
    """

    def _symlinkErrorTest(self, errno: int) -> None:
        """
        Patch C{lockfile.symlink} to fail with the given C{errno} and
        assert that L{FilesystemLock.lock} propagates the resulting
        L{OSError}.
        """
        # NOTE: the parameter shadows the errno module inside this method.
        def fakeSymlink(source: str, dest: str) -> NoReturn:
            raise OSError(errno, None)

        self.patch(lockfile, "symlink", fakeSymlink)
        lockf = self.mktemp()
        lock = lockfile.FilesystemLock(lockf)
        exc = self.assertRaises(OSError, lock.lock)
        self.assertEqual(exc.errno, errno)

    def test_symlinkError(self) -> None:
        """
        An exception raised by C{symlink} other than C{EEXIST} is passed up to
        the caller of L{FilesystemLock.lock}.
        """
        self._symlinkErrorTest(errno.ENOSYS)

    @skipIf(
        platform.isWindows(),
        "POSIX-specific error propagation not expected on Windows.",
    )
    def test_symlinkErrorPOSIX(self) -> None:
        """
        An L{OSError} raised by C{symlink} on a POSIX platform with an errno of
        C{EACCES} or C{EIO} is passed to the caller of L{FilesystemLock.lock}.

        On POSIX, unlike on Windows, these are unexpected errors which cannot
        be handled by L{FilesystemLock}.
        """
        self._symlinkErrorTest(errno.EACCES)
        self._symlinkErrorTest(errno.EIO)

    def test_cleanlyAcquire(self) -> None:
        """
        If the lock has never been held, it can be acquired and the C{clean}
        and C{locked} attributes are set to C{True}.
        """
        lockf = self.mktemp()
        lock = lockfile.FilesystemLock(lockf)
        self.assertTrue(lock.lock())
        self.assertTrue(lock.clean)
        self.assertTrue(lock.locked)

    def test_cleanlyRelease(self) -> None:
        """
        If a lock is released cleanly, it can be re-acquired and the C{clean}
        and C{locked} attributes are set to C{True}.
        """
        lockf = self.mktemp()
        lock = lockfile.FilesystemLock(lockf)
        self.assertTrue(lock.lock())
        lock.unlock()
        self.assertFalse(lock.locked)
        lock = lockfile.FilesystemLock(lockf)
        self.assertTrue(lock.lock())
        self.assertTrue(lock.clean)
        self.assertTrue(lock.locked)

    def test_cannotLockLocked(self) -> None:
        """
        If a lock is currently locked, it cannot be locked again.
        """
        lockf = self.mktemp()
        firstLock = lockfile.FilesystemLock(lockf)
        self.assertTrue(firstLock.lock())
        secondLock = lockfile.FilesystemLock(lockf)
        self.assertFalse(secondLock.lock())
        self.assertFalse(secondLock.locked)

    def test_uncleanlyAcquire(self) -> None:
        """
        If a lock was held by a process which no longer exists, it can be
        acquired, the C{clean} attribute is set to C{False}, and the
        C{locked} attribute is set to C{True}.
        """
        owner = 12345

        def fakeKill(pid: int, signal: int) -> None:
            # Simulate kill(2): signal 0 probes for existence; the stale
            # owner PID appears dead (ESRCH).
            if signal != 0:
                raise OSError(errno.EPERM, None)
            if pid == owner:
                raise OSError(errno.ESRCH, None)

        lockf = self.mktemp()
        self.patch(lockfile, "kill", fakeKill)
        lockfile.symlink(str(owner), lockf)
        lock = lockfile.FilesystemLock(lockf)
        self.assertTrue(lock.lock())
        self.assertFalse(lock.clean)
        self.assertTrue(lock.locked)
        # The lock symlink is rewritten to point at our own PID.
        self.assertEqual(lockfile.readlink(lockf), str(os.getpid()))

    def test_lockReleasedBeforeCheck(self) -> None:
        """
        If the lock is initially held but then released before it can be
        examined to determine if the process which held it still exists, it is
        acquired and the C{clean} and C{locked} attributes are set to C{True}.
        """

        def fakeReadlink(name: str) -> str:
            # Pretend to be another process releasing the lock.
            lockfile.rmlink(lockf)
            # Fall back to the real implementation of readlink.
            readlinkPatch.restore()
            return lockfile.readlink(name)

        readlinkPatch = self.patch(lockfile, "readlink", fakeReadlink)

        def fakeKill(pid: int, signal: int) -> None:
            if signal != 0:
                raise OSError(errno.EPERM, None)
            if pid == 43125:
                raise OSError(errno.ESRCH, None)

        self.patch(lockfile, "kill", fakeKill)
        lockf = self.mktemp()
        lock = lockfile.FilesystemLock(lockf)
        lockfile.symlink(str(43125), lockf)
        self.assertTrue(lock.lock())
        self.assertTrue(lock.clean)
        self.assertTrue(lock.locked)

    @skipUnless(
        platform.isWindows(),
        "special rename EIO handling only necessary and correct on " "Windows.",
    )
    def test_lockReleasedDuringAcquireSymlink(self) -> None:
        """
        If the lock is released while an attempt is made to acquire
        it, the lock attempt fails and C{FilesystemLock.lock} returns
        C{False}.  This can happen on Windows when L{lockfile.symlink}
        fails with L{IOError} of C{EIO} because another process is in
        the middle of a call to L{os.rmdir} (implemented in terms of
        RemoveDirectory) which is not atomic.
        """

        def fakeSymlink(src: str, dst: str) -> NoReturn:
            # While another process id doing os.rmdir which the Windows
            # implementation of rmlink does, a rename call will fail with EIO.
            raise OSError(errno.EIO, None)

        self.patch(lockfile, "symlink", fakeSymlink)
        lockf = self.mktemp()
        lock = lockfile.FilesystemLock(lockf)
        self.assertFalse(lock.lock())
        self.assertFalse(lock.locked)

    @skipUnless(
        platform.isWindows(),
        "special readlink EACCES handling only necessary and " "correct on Windows.",
    )
    def test_lockReleasedDuringAcquireReadlink(self) -> None:
        """
        If the lock is initially held but is released while an attempt
        is made to acquire it, the lock attempt fails and
        L{FilesystemLock.lock} returns C{False}.
        """

        def fakeReadlink(name: str) -> NoReturn:
            # While another process is doing os.rmdir which the
            # Windows implementation of rmlink does, a readlink call
            # will fail with EACCES.
            raise OSError(errno.EACCES, None)

        self.patch(lockfile, "readlink", fakeReadlink)
        lockf = self.mktemp()
        lock = lockfile.FilesystemLock(lockf)
        lockfile.symlink(str(43125), lockf)
        self.assertFalse(lock.lock())
        self.assertFalse(lock.locked)

    def _readlinkErrorTest(
        self, exceptionType: type[OSError] | type[IOError], errno: int
    ) -> None:
        """
        Patch C{lockfile.readlink} to raise C{exceptionType} with the given
        C{errno} and assert that L{FilesystemLock.lock} propagates it.
        """
        # NOTE: the parameter shadows the errno module inside this method.
        def fakeReadlink(name: str) -> NoReturn:
            raise exceptionType(errno, None)

        self.patch(lockfile, "readlink", fakeReadlink)
        lockf = self.mktemp()
        # Make it appear locked so it has to use readlink
        lockfile.symlink(str(43125), lockf)
        lock = lockfile.FilesystemLock(lockf)
        exc = self.assertRaises(exceptionType, lock.lock)
        self.assertEqual(exc.errno, errno)
        self.assertFalse(lock.locked)

    def test_readlinkError(self) -> None:
        """
        An exception raised by C{readlink} other than C{ENOENT} is passed up to
        the caller of L{FilesystemLock.lock}.
        """
        self._readlinkErrorTest(OSError, errno.ENOSYS)
        self._readlinkErrorTest(IOError, errno.ENOSYS)

    @skipIf(
        platform.isWindows(),
        "POSIX-specific error propagation not expected on Windows.",
    )
    def test_readlinkErrorPOSIX(self) -> None:
        """
        Any L{IOError} raised by C{readlink} on a POSIX platform passed to the
        caller of L{FilesystemLock.lock}.

        On POSIX, unlike on Windows, these are unexpected errors which cannot
        be handled by L{FilesystemLock}.
        """
        self._readlinkErrorTest(IOError, errno.ENOSYS)
        self._readlinkErrorTest(IOError, errno.EACCES)

    def test_lockCleanedUpConcurrently(self) -> None:
        """
        If a second process cleans up the lock after a first one checks the
        lock and finds that no process is holding it, the first process does
        not fail when it tries to clean up the lock.
        """

        def fakeRmlink(name: str) -> None:
            rmlinkPatch.restore()
            # Pretend to be another process cleaning up the lock.
            lockfile.rmlink(lockf)
            # Fall back to the real implementation of rmlink.
            return lockfile.rmlink(name)

        rmlinkPatch = self.patch(lockfile, "rmlink", fakeRmlink)

        def fakeKill(pid: int, signal: int) -> None:
            if signal != 0:
                raise OSError(errno.EPERM, None)
            if pid == 43125:
                raise OSError(errno.ESRCH, None)

        self.patch(lockfile, "kill", fakeKill)
        lockf = self.mktemp()
        lock = lockfile.FilesystemLock(lockf)
        lockfile.symlink(str(43125), lockf)
        self.assertTrue(lock.lock())
        self.assertTrue(lock.clean)
        self.assertTrue(lock.locked)

    def test_rmlinkError(self) -> None:
        """
        An exception raised by L{rmlink} other than C{ENOENT} is passed up
        to the caller of L{FilesystemLock.lock}.
        """

        def fakeRmlink(name: str) -> NoReturn:
            raise OSError(errno.ENOSYS, None)

        self.patch(lockfile, "rmlink", fakeRmlink)

        def fakeKill(pid: int, signal: int) -> None:
            if signal != 0:
                raise OSError(errno.EPERM, None)
            if pid == 43125:
                raise OSError(errno.ESRCH, None)

        self.patch(lockfile, "kill", fakeKill)
        lockf = self.mktemp()
        # Make it appear locked so it has to use readlink
        lockfile.symlink(str(43125), lockf)
        lock = lockfile.FilesystemLock(lockf)
        exc = self.assertRaises(OSError, lock.lock)
        self.assertEqual(exc.errno, errno.ENOSYS)
        self.assertFalse(lock.locked)

    def test_killError(self) -> None:
        """
        If L{kill} raises an exception other than L{OSError} with errno set to
        C{ESRCH}, the exception is passed up to the caller of
        L{FilesystemLock.lock}.
        """

        def fakeKill(pid: int, signal: int) -> NoReturn:
            raise OSError(errno.EPERM, None)

        self.patch(lockfile, "kill", fakeKill)
        lockf = self.mktemp()
        # Make it appear locked so it has to use readlink
        lockfile.symlink(str(43125), lockf)
        lock = lockfile.FilesystemLock(lockf)
        exc = self.assertRaises(OSError, lock.lock)
        self.assertEqual(exc.errno, errno.EPERM)
        self.assertFalse(lock.locked)

    def test_unlockOther(self) -> None:
        """
        L{FilesystemLock.unlock} raises L{ValueError} if called for a lock
        which is held by a different process.
        """
        lockf = self.mktemp()
        # A lock owned by a PID that is not ours.
        lockfile.symlink(str(os.getpid() + 1), lockf)
        lock = lockfile.FilesystemLock(lockf)
        self.assertRaises(ValueError, lock.unlock)

    def test_isLocked(self) -> None:
        """
        L{isLocked} returns C{True} if the named lock is currently locked,
        C{False} otherwise.
        """
        lockf = self.mktemp()
        self.assertFalse(lockfile.isLocked(lockf))
        lock = lockfile.FilesystemLock(lockf)
        self.assertTrue(lock.lock())
        self.assertTrue(lockfile.isLocked(lockf))
        lock.unlock()
        self.assertFalse(lockfile.isLocked(lockf))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,534 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import annotations
import contextlib
import errno
import os
import stat
import time
from unittest import skipIf
from twisted.python import logfile, runtime
from twisted.trial.unittest import TestCase
class LogFileTests(TestCase):
    """
    Test the rotating log file.
    """

    def setUp(self) -> None:
        # A fresh directory per test; self.path is the log file inside it.
        self.dir = self.mktemp()
        os.makedirs(self.dir)
        self.name = "test.log"
        self.path = os.path.join(self.dir, self.name)

    def tearDown(self) -> None:
        """
        Restore back write rights on created paths: if tests modified the
        rights, that will allow the paths to be removed easily afterwards.
        """
        os.chmod(self.dir, 0o777)
        if os.path.exists(self.path):
            os.chmod(self.path, 0o777)

    def test_abstractShouldRotate(self) -> None:
        """
        L{BaseLogFile.shouldRotate} is abstract and must be implemented by
        subclass.
        """
        log = logfile.BaseLogFile(self.name, self.dir)
        self.addCleanup(log.close)
        self.assertRaises(NotImplementedError, log.shouldRotate)

    def test_writing(self) -> None:
        """
        Log files can be written to, flushed and closed. Closing a log file
        also flushes it.
        """
        with contextlib.closing(logfile.LogFile(self.name, self.dir)) as log:
            log.write("123")
            log.write("456")
            log.flush()
            log.write("7890")
        with open(self.path) as f:
            self.assertEqual(f.read(), "1234567890")

    def test_rotation(self) -> None:
        """
        Rotating log files autorotate after a period of time, and can also be
        manually rotated.
        """
        # this logfile should rotate every 10 bytes
        with contextlib.closing(
            logfile.LogFile(self.name, self.dir, rotateLength=10)
        ) as log:
            # test automatic rotation
            log.write("123")
            log.write("4567890")
            log.write("1" * 11)
            self.assertTrue(os.path.exists(f"{self.path}.1"))
            self.assertFalse(os.path.exists(f"{self.path}.2"))
            log.write("")
            self.assertTrue(os.path.exists(f"{self.path}.1"))
            self.assertTrue(os.path.exists(f"{self.path}.2"))
            self.assertFalse(os.path.exists(f"{self.path}.3"))
            log.write("3")
            self.assertFalse(os.path.exists(f"{self.path}.3"))
            # test manual rotation
            log.rotate()
            self.assertTrue(os.path.exists(f"{self.path}.3"))
            self.assertFalse(os.path.exists(f"{self.path}.4"))
        self.assertEqual(log.listLogs(), [1, 2, 3])

    def test_append(self) -> None:
        """
        Log files can be written to, closed. Their size is the number of
        bytes written to them. Everything that was written to them can
        be read, even if the writing happened on separate occasions,
        and even if the log file was closed in between.
        """
        with contextlib.closing(logfile.LogFile(self.name, self.dir)) as log:
            log.write("0123456789")

        # Re-open: the file must be appended to, not truncated.
        log = logfile.LogFile(self.name, self.dir)
        self.addCleanup(log.close)
        self.assertEqual(log.size, 10)
        self.assertEqual(log._file.tell(), log.size)

        log.write("abc")
        log.write(b"def\xff")
        expectResult = b"0123456789abcdef\xff"
        self.assertEqual(log.size, len(expectResult))
        self.assertEqual(log._file.tell(), log.size)

        f = log._file
        f.seek(0, 0)
        self.assertEqual(f.read(), expectResult)

    def test_logReader(self) -> None:
        """
        Various tests for log readers.

        First of all, log readers can get logs by number and read what
        was written to those log files. Getting nonexistent log files
        raises C{ValueError}. Using anything other than an integer
        index raises C{TypeError}. As logs get older, their log
        numbers increase.
        """
        log = logfile.LogFile(self.name, self.dir)
        self.addCleanup(log.close)

        log.write("abc\n")
        log.write("def\n")
        log.rotate()
        log.write("ghi\n")
        log.flush()

        # check reading logs
        self.assertEqual(log.listLogs(), [1])
        with contextlib.closing(log.getCurrentLog()) as reader:
            reader._file.seek(0)
            self.assertEqual(reader.readLines(), ["ghi\n"])
            self.assertEqual(reader.readLines(), [])
        with contextlib.closing(log.getLog(1)) as reader:
            self.assertEqual(reader.readLines(), ["abc\n", "def\n"])
            self.assertEqual(reader.readLines(), [])

        # check getting illegal log readers
        self.assertRaises(ValueError, log.getLog, 2)
        self.assertRaises(TypeError, log.getLog, "1")

        # check that log numbers are higher for older logs
        log.rotate()
        self.assertEqual(log.listLogs(), [1, 2])
        with contextlib.closing(log.getLog(1)) as reader:
            reader._file.seek(0)
            self.assertEqual(reader.readLines(), ["ghi\n"])
            self.assertEqual(reader.readLines(), [])
        with contextlib.closing(log.getLog(2)) as reader:
            self.assertEqual(reader.readLines(), ["abc\n", "def\n"])
            self.assertEqual(reader.readLines(), [])

    def test_LogReaderReadsZeroLine(self) -> None:
        """
        L{LogReader.readLines} supports reading no line.
        """
        # We don't need any content, just a file path that can be opened.
        with open(self.path, "w"):
            pass

        reader = logfile.LogReader(self.path)
        self.addCleanup(reader.close)
        self.assertEqual([], reader.readLines(0))

    def test_modePreservation(self) -> None:
        """
        Check rotated files have same permissions as original.
        """
        open(self.path, "w").close()
        os.chmod(self.path, 0o707)
        mode = os.stat(self.path)[stat.ST_MODE]
        log = logfile.LogFile(self.name, self.dir)
        self.addCleanup(log.close)
        log.write("abc")
        log.rotate()
        self.assertEqual(mode, os.stat(self.path)[stat.ST_MODE])

    def test_noPermission(self) -> None:
        """
        Check it keeps working when permission on dir changes.
        """
        log = logfile.LogFile(self.name, self.dir)
        self.addCleanup(log.close)
        log.write("abc")

        # change permissions so rotation would fail
        os.chmod(self.dir, 0o555)

        # if this succeeds, chmod doesn't restrict us, so we can't
        # do the test
        try:
            f = open(os.path.join(self.dir, "xxx"), "w")
        except OSError:
            pass
        else:
            f.close()
            return

        log.rotate()  # this should not fail

        log.write("def")
        log.flush()

        f = log._file
        self.assertEqual(f.tell(), 6)
        f.seek(0, 0)
        self.assertEqual(f.read(), b"abcdef")

    def test_maxNumberOfLog(self) -> None:
        """
        Test that the limit on the number of rotated files is respected when
        maxRotatedFiles is not None.
        """
        log = logfile.LogFile(self.name, self.dir, rotateLength=10, maxRotatedFiles=3)
        self.addCleanup(log.close)
        log.write("1" * 11)
        log.write("2" * 11)
        self.assertTrue(os.path.exists(f"{self.path}.1"))

        log.write("3" * 11)
        self.assertTrue(os.path.exists(f"{self.path}.2"))

        log.write("4" * 11)
        self.assertTrue(os.path.exists(f"{self.path}.3"))
        with open(f"{self.path}.3") as fp:
            self.assertEqual(fp.read(), "1" * 11)

        # Writing a fifth chunk drops the oldest file's content.
        log.write("5" * 11)
        with open(f"{self.path}.3") as fp:
            self.assertEqual(fp.read(), "2" * 11)
        self.assertFalse(os.path.exists(f"{self.path}.4"))

    def test_fromFullPath(self) -> None:
        """
        Test the fromFullPath method.
        """
        log1 = logfile.LogFile(self.name, self.dir, 10, defaultMode=0o777)
        self.addCleanup(log1.close)
        log2 = logfile.LogFile.fromFullPath(self.path, 10, defaultMode=0o777)
        self.addCleanup(log2.close)
        self.assertEqual(log1.name, log2.name)
        self.assertEqual(os.path.abspath(log1.path), log2.path)
        self.assertEqual(log1.rotateLength, log2.rotateLength)
        self.assertEqual(log1.defaultMode, log2.defaultMode)

    def test_defaultPermissions(self) -> None:
        """
        Test the default permission of the log file: if the file exist, it
        should keep the permission.
        """
        with open(self.path, "wb"):
            os.chmod(self.path, 0o707)
            currentMode = stat.S_IMODE(os.stat(self.path)[stat.ST_MODE])
        log1 = logfile.LogFile(self.name, self.dir)
        self.assertEqual(stat.S_IMODE(os.stat(self.path)[stat.ST_MODE]), currentMode)
        self.addCleanup(log1.close)

    def test_specifiedPermissions(self) -> None:
        """
        Test specifying the permissions used on the log file.
        """
        log1 = logfile.LogFile(self.name, self.dir, defaultMode=0o066)
        self.addCleanup(log1.close)
        mode = stat.S_IMODE(os.stat(self.path)[stat.ST_MODE])
        if runtime.platform.isWindows():
            # The only thing we can get here is global read-only
            self.assertEqual(mode, 0o444)
        else:
            self.assertEqual(mode, 0o066)

    @skipIf(runtime.platform.isWindows(), "Can't test reopen on Windows")
    def test_reopen(self) -> None:
        """
        L{logfile.LogFile.reopen} allows to rename the currently used file and
        make L{logfile.LogFile} create a new file.
        """
        with contextlib.closing(logfile.LogFile(self.name, self.dir)) as log1:
            log1.write("hello1")
            savePath = os.path.join(self.dir, "save.log")
            os.rename(self.path, savePath)
            log1.reopen()
            log1.write("hello2")

        with open(self.path) as f:
            self.assertEqual(f.read(), "hello2")
        with open(savePath) as f:
            self.assertEqual(f.read(), "hello1")

    def test_nonExistentDir(self) -> None:
        """
        Specifying an invalid directory to L{LogFile} raises C{IOError}.
        """
        e = self.assertRaises(
            IOError, logfile.LogFile, self.name, "this_dir_does_not_exist"
        )
        self.assertEqual(e.errno, errno.ENOENT)

    def test_cantChangeFileMode(self) -> None:
        """
        Opening a L{LogFile} which can be read and write but whose mode can't
        be changed doesn't trigger an error.
        """
        # The null device can be opened but not chmod'ed.
        if runtime.platform.isWindows():
            name, directory = "NUL", ""
            expectedPath = "NUL"
        else:
            name, directory = "null", "/dev"
            expectedPath = "/dev/null"

        log = logfile.LogFile(name, directory, defaultMode=0o555)
        self.addCleanup(log.close)

        self.assertEqual(log.path, expectedPath)
        self.assertEqual(log.defaultMode, 0o555)

    def test_listLogsWithBadlyNamedFiles(self) -> None:
        """
        L{LogFile.listLogs} doesn't choke if it encounters a file with an
        unexpected name.
        """
        log = logfile.LogFile(self.name, self.dir)
        self.addCleanup(log.close)

        with open(f"{log.path}.1", "w") as fp:
            fp.write("123")
        with open(f"{log.path}.bad-file", "w") as fp:
            fp.write("123")

        self.assertEqual([1], log.listLogs())

    def test_listLogsIgnoresZeroSuffixedFiles(self) -> None:
        """
        L{LogFile.listLogs} ignores log files which rotated suffix is 0.
        """
        log = logfile.LogFile(self.name, self.dir)
        self.addCleanup(log.close)

        for i in range(0, 3):
            with open(f"{log.path}.{i}", "w") as fp:
                fp.write("123")

        self.assertEqual([1, 2], log.listLogs())
class RiggedDailyLogFile(logfile.DailyLogFile):
    """
    A L{logfile.DailyLogFile} whose idea of "today" is driven by the
    C{_clock} attribute rather than the file's modification time, so tests
    can advance time deterministically.
    """

    # Fake timestamp (seconds since the epoch) consulted when toDate() is
    # called without arguments.
    _clock = 0.0

    def _openFile(self) -> None:
        super()._openFile()
        # Rig the date to match _clock, not mtime.
        self.lastDate = self.toDate()

    def toDate(self, *args: float) -> tuple[int, int, int]:
        when = args if args else (self._clock,)
        return time.gmtime(*when)[:3]
class DailyLogFileTests(TestCase):
    """
    Test rotating log file.
    """

    def setUp(self) -> None:
        # A fresh directory per test; self.path is the log file inside it.
        self.dir = self.mktemp()
        os.makedirs(self.dir)
        self.name = "testdaily.log"
        self.path = os.path.join(self.dir, self.name)

    def test_writing(self) -> None:
        """
        A daily log file can be written to like an ordinary log file.
        """
        with contextlib.closing(RiggedDailyLogFile(self.name, self.dir)) as log:
            log.write("123")
            log.write("456")
            log.flush()
            log.write("7890")

        with open(self.path) as f:
            self.assertEqual(f.read(), "1234567890")

    def test_rotation(self) -> None:
        """
        Daily log files rotate daily.
        """
        log = RiggedDailyLogFile(self.name, self.dir)
        self.addCleanup(log.close)
        days = [(self.path + "." + log.suffix(day * 86400)) for day in range(3)]

        # test automatic rotation
        log._clock = 0.0  # 1970/01/01 00:00.00
        log.write("123")
        log._clock = 43200  # 1970/01/01 12:00.00
        log.write("4567890")
        log._clock = 86400  # 1970/01/02 00:00.00
        log.write("1" * 11)
        self.assertTrue(os.path.exists(days[0]))
        self.assertFalse(os.path.exists(days[1]))
        log._clock = 172800  # 1970/01/03 00:00.00
        log.write("")
        self.assertTrue(os.path.exists(days[0]))
        self.assertTrue(os.path.exists(days[1]))
        self.assertFalse(os.path.exists(days[2]))
        log._clock = 259199  # 1970/01/03 23:59.59
        log.write("3")
        # Still the same day, so no new rotation yet.
        self.assertFalse(os.path.exists(days[2]))

    def test_getLog(self) -> None:
        """
        Test retrieving log files with L{DailyLogFile.getLog}.
        """
        data = ["1\n", "2\n", "3\n"]
        log = RiggedDailyLogFile(self.name, self.dir)
        self.addCleanup(log.close)
        for d in data:
            log.write(d)
        log.flush()

        # This returns the current log file.
        r = log.getLog(0.0)
        self.addCleanup(r.close)

        self.assertEqual(data, r.readLines())

        # We can't get this log, it doesn't exist yet.
        self.assertRaises(ValueError, log.getLog, 86400)

        log._clock = 86401  # New day
        r.close()
        log.rotate()
        r = log.getLog(0)  # We get the previous log
        self.addCleanup(r.close)
        self.assertEqual(data, r.readLines())

    def test_rotateAlreadyExists(self) -> None:
        """
        L{DailyLogFile.rotate} doesn't do anything if the new log file already
        exists on the disk.
        """
        log = RiggedDailyLogFile(self.name, self.dir)
        self.addCleanup(log.close)

        # Build a new file with the same name as the file which would be created
        # if the log file is to be rotated.
        newFilePath = f"{log.path}.{log.suffix(log.lastDate)}"
        with open(newFilePath, "w") as fp:
            fp.write("123")
        previousFile = log._file
        log.rotate()
        # Rotation was a no-op: the underlying file object is unchanged.
        self.assertEqual(previousFile, log._file)

    @skipIf(
        runtime.platform.isWindows(),
        "Making read-only directories on Windows is too complex for this "
        "test to reasonably do.",
    )
    def test_rotatePermissionDirectoryNotOk(self) -> None:
        """
        L{DailyLogFile.rotate} doesn't do anything if the directory containing
        the log files can't be written to.
        """
        log = logfile.DailyLogFile(self.name, self.dir)
        self.addCleanup(log.close)

        os.chmod(log.directory, 0o444)
        # Restore permissions so tests can be cleaned up.
        self.addCleanup(os.chmod, log.directory, 0o755)
        previousFile = log._file
        log.rotate()
        self.assertEqual(previousFile, log._file)

    def test_rotatePermissionFileNotOk(self) -> None:
        """
        L{DailyLogFile.rotate} doesn't do anything if the log file can't be
        written to.
        """
        log = logfile.DailyLogFile(self.name, self.dir)
        self.addCleanup(log.close)

        os.chmod(log.path, 0o444)
        previousFile = log._file
        log.rotate()
        self.assertEqual(previousFile, log._file)

    def test_toDate(self) -> None:
        """
        Test that L{DailyLogFile.toDate} converts its timestamp argument to a
        time tuple (year, month, day).
        """
        log = logfile.DailyLogFile(self.name, self.dir)
        self.addCleanup(log.close)

        timestamp = time.mktime((2000, 1, 1, 0, 0, 0, 0, 0, 0))
        self.assertEqual((2000, 1, 1), log.toDate(timestamp))

    def test_toDateDefaultToday(self) -> None:
        """
        Test that L{DailyLogFile.toDate} returns today's date by default.

        By mocking L{time.localtime}, we ensure that L{DailyLogFile.toDate}
        returns the first 3 values of L{time.localtime} which is the current
        date.

        Note that we don't compare the *real* result of L{DailyLogFile.toDate}
        to the *real* current date, as there's a slight possibility that the
        date changes between the 2 function calls.
        """

        def mock_localtime(*args: object) -> list[int]:
            self.assertEqual((), args)
            return list(range(0, 9))

        log = logfile.DailyLogFile(self.name, self.dir)
        self.addCleanup(log.close)

        self.patch(time, "localtime", mock_localtime)
        logDate = log.toDate()
        # First three fields of the (mocked) localtime result.
        self.assertEqual([0, 1, 2], logDate)

    def test_toDateUsesArgumentsToMakeADate(self) -> None:
        """
        Test that L{DailyLogFile.toDate} uses its arguments to create a new
        date.
        """
        log = logfile.DailyLogFile(self.name, self.dir)
        self.addCleanup(log.close)

        date = (2014, 10, 22)
        seconds = time.mktime(date + (0,) * 6)

        logDate = log.toDate(seconds)
        self.assertEqual(date, logDate)

View File

@@ -0,0 +1,464 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test case for L{twisted.protocols.loopback}.
"""
from zope.interface import implementer
from twisted.internet import defer, interfaces, reactor
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IPullProducer, IPushProducer
from twisted.internet.protocol import Protocol
from twisted.protocols import basic, loopback
from twisted.trial import unittest
class SimpleProtocol(basic.LineReceiver):
    """
    A line receiver that records everything interesting that happens to it:
    connection establishment (via the C{conn} deferred), each received line
    (in C{lines}), and every connection-loss reason (in C{connLost}).
    """

    def __init__(self):
        # Fired with None as soon as the connection is established.
        self.conn = defer.Deferred()
        # Every line received so far, in order.
        self.lines = []
        # One Failure per connectionLost call; tests assert its length.
        self.connLost = []

    def connectionMade(self):
        self.conn.callback(None)

    def lineReceived(self, line):
        self.lines.append(line)

    def connectionLost(self, reason):
        self.connLost.append(reason)
class DoomProtocol(SimpleProtocol):
    """
    A L{SimpleProtocol} that answers each received line with a numbered
    "Hello" greeting and hangs up once its own third greeting is recorded.
    """

    # Count of lines received so far; shadowed per-instance on first use.
    i = 0

    def lineReceived(self, line):
        self.i += 1
        # Stop greeting after the third line; by that point the connection
        # should already be closing, but guard against ever sending
        # 'Hello 4' just in case it is not.
        if self.i < 4:
            self.sendLine(b"Hello %d" % (self.i,))
        super().lineReceived(line)
        if self.lines[-1] == b"Hello 3":
            self.transport.loseConnection()
class LoopbackTestCaseMixin:
    """
    Generic loopback behavior tests, parameterized by the C{loopbackFunc}
    attribute that concrete subclasses supply.
    """

    def testRegularFunction(self):
        """
        A line written by one side arrives at the other side, and both
        sides observe exactly one connectionLost when the connection drops.
        """
        server = SimpleProtocol()
        client = SimpleProtocol()

        def sendALine(result):
            server.sendLine(b"THIS IS LINE ONE!")
            server.transport.loseConnection()

        server.conn.addCallback(sendALine)

        def check(ignored):
            self.assertEqual(client.lines, [b"THIS IS LINE ONE!"])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)

        finished = defer.maybeDeferred(self.loopbackFunc, server, client)
        finished.addCallback(check)
        return finished

    def testSneakyHiddenDoom(self):
        """
        Two L{DoomProtocol}s exchange numbered greetings; the conversation
        stops after "Hello 3" and each side sees exactly one connectionLost.
        """
        server = DoomProtocol()
        client = DoomProtocol()

        def sendALine(result):
            server.sendLine(b"DOOM LINE")

        server.conn.addCallback(sendALine)

        def check(ignored):
            self.assertEqual(server.lines, [b"Hello 1", b"Hello 2", b"Hello 3"])
            self.assertEqual(
                client.lines, [b"DOOM LINE", b"Hello 1", b"Hello 2", b"Hello 3"]
            )
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)

        finished = defer.maybeDeferred(self.loopbackFunc, server, client)
        finished.addCallback(check)
        return finished
class LoopbackAsyncTests(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Tests for L{loopback.loopbackAsync}, which connects two protocols
    through an in-memory, reactor-driven transport.
    """

    loopbackFunc = staticmethod(loopback.loopbackAsync)

    def test_makeConnection(self):
        """
        Test that the client and server protocol both have makeConnection
        invoked on them by loopbackAsync.
        """

        class TestProtocol(Protocol):
            transport = None

            def makeConnection(self, transport):
                self.transport = transport

        server = TestProtocol()
        client = TestProtocol()
        loopback.loopbackAsync(server, client)
        self.assertIsNotNone(client.transport)
        self.assertIsNotNone(server.transport)

    def _hostpeertest(self, get, testServer):
        """
        Test one of the permutations of client/server host/peer.

        @param get: the transport accessor name, C{"getHost"} or
            C{"getPeer"}.
        @param testServer: if true, inspect the server's transport;
            otherwise the client's.
        """

        class TestProtocol(Protocol):
            def makeConnection(self, transport):
                Protocol.makeConnection(self, transport)
                self.onConnection.callback(transport)

        if testServer:
            server = TestProtocol()
            d = server.onConnection = Deferred()
            client = Protocol()
        else:
            server = Protocol()
            client = TestProtocol()
            d = client.onConnection = Deferred()

        loopback.loopbackAsync(server, client)

        def connected(transport):
            host = getattr(transport, get)()
            self.assertTrue(IAddress.providedBy(host))

        return d.addCallback(connected)

    def test_serverHost(self):
        """
        Test that the server gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        return self._hostpeertest("getHost", True)

    def test_serverPeer(self):
        """
        Like C{test_serverHost} but for L{ITransport.getPeer}
        """
        return self._hostpeertest("getPeer", True)

    def test_clientHost(self, get="getHost"):
        """
        Test that the client gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        # NOTE(review): the ``get`` parameter is never used -- "getHost" is
        # hard-coded below.  It looks like copy/paste residue from
        # _hostpeertest; confirm and remove.
        return self._hostpeertest("getHost", False)

    def test_clientPeer(self):
        """
        Like C{test_clientHost} but for L{ITransport.getPeer}.
        """
        return self._hostpeertest("getPeer", False)

    def _greetingtest(self, write, testServer):
        """
        Test one of the permutations of write/writeSequence client/server.

        @param write: The name of the method to test, C{"write"} or
            C{"writeSequence"}.
        """

        class GreeteeProtocol(Protocol):
            bytes = b""

            def dataReceived(self, bytes):
                self.bytes += bytes
                if self.bytes == b"bytes":
                    self.received.callback(None)

        class GreeterProtocol(Protocol):
            def connectionMade(self):
                if write == "write":
                    self.transport.write(b"bytes")
                else:
                    self.transport.writeSequence([b"byt", b"es"])

        if testServer:
            server = GreeterProtocol()
            client = GreeteeProtocol()
            d = client.received = Deferred()
        else:
            server = GreeteeProtocol()
            d = server.received = Deferred()
            client = GreeterProtocol()

        loopback.loopbackAsync(server, client)
        return d

    def test_clientGreeting(self):
        """
        Test that on a connection where the client speaks first, the server
        receives the bytes sent by the client.
        """
        return self._greetingtest("write", False)

    def test_clientGreetingSequence(self):
        """
        Like C{test_clientGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", False)

    def test_serverGreeting(self, write="write"):
        """
        Test that on a connection where the server speaks first, the client
        receives the bytes sent by the server.
        """
        # NOTE(review): the ``write`` parameter is never used -- "write" is
        # hard-coded below; likely copy/paste residue from _greetingtest.
        return self._greetingtest("write", True)

    def test_serverGreetingSequence(self):
        """
        Like C{test_serverGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", True)

    def _producertest(self, producerClass):
        # Produce the byte strings b"0" through b"9"; the returned Deferred
        # fires with (client, server) once the receiver has seen all of them.
        toProduce = [b"%d" % (i,) for i in range(0, 10)]

        class ProducingProtocol(Protocol):
            def connectionMade(self):
                self.producer = producerClass(list(toProduce))
                self.producer.start(self.transport)

        class ReceivingProtocol(Protocol):
            bytes = b""

            def dataReceived(self, data):
                self.bytes += data
                if self.bytes == b"".join(toProduce):
                    self.received.callback((client, server))

        server = ProducingProtocol()
        client = ReceivingProtocol()
        client.received = Deferred()
        loopback.loopbackAsync(server, client)
        return client.received

    def test_pushProducer(self):
        """
        Test a push producer registered against a loopback transport.
        """

        @implementer(IPushProducer)
        class PushProducer:
            resumed = False

            def __init__(self, toProduce):
                self.toProduce = toProduce

            def resumeProducing(self):
                self.resumed = True

            def start(self, consumer):
                self.consumer = consumer
                consumer.registerProducer(self, True)
                self._produceAndSchedule()

            def _produceAndSchedule(self):
                # Write one chunk per reactor turn until exhausted.
                if self.toProduce:
                    self.consumer.write(self.toProduce.pop(0))
                    reactor.callLater(0, self._produceAndSchedule)
                else:
                    self.consumer.unregisterProducer()

        d = self._producertest(PushProducer)

        def finished(results):
            (client, server) = results
            # A streaming (push) producer is never explicitly resumed by the
            # loopback transport.
            self.assertFalse(
                server.producer.resumed,
                "Streaming producer should not have been resumed.",
            )

        d.addCallback(finished)
        return d

    def test_pullProducer(self):
        """
        Test a pull producer registered against a loopback transport.
        """

        @implementer(IPullProducer)
        class PullProducer:
            def __init__(self, toProduce):
                self.toProduce = toProduce

            def start(self, consumer):
                self.consumer = consumer
                self.consumer.registerProducer(self, False)

            def resumeProducing(self):
                # Pull producers only write when asked to.
                self.consumer.write(self.toProduce.pop(0))
                if not self.toProduce:
                    self.consumer.unregisterProducer()

        return self._producertest(PullProducer)

    def test_writeNotReentrant(self):
        """
        L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
        method while that protocol's transport's C{write} method is higher up
        on the stack.
        """

        class Server(Protocol):
            def dataReceived(self, bytes):
                self.transport.write(b"bytes")

        class Client(Protocol):
            ready = False

            def connectionMade(self):
                reactor.callLater(0, self.go)

            def go(self):
                # If delivery were reentrant, dataReceived could run before
                # this method finishes and self.ready is set.
                self.transport.write(b"foo")
                self.ready = True

            def dataReceived(self, bytes):
                self.wasReady = self.ready
                self.transport.loseConnection()

        server = Server()
        client = Client()
        d = loopback.loopbackAsync(client, server)

        def cbFinished(ignored):
            self.assertTrue(client.wasReady)

        d.addCallback(cbFinished)
        return d

    def test_pumpPolicy(self):
        """
        The callable passed as the value for the C{pumpPolicy} parameter to
        L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
        and a protocol to which they should be delivered.
        """
        pumpCalls = []

        def dummyPolicy(queue, target):
            bytes = []
            while queue:
                bytes.append(queue.get())
            pumpCalls.append((target, bytes))

        client = Protocol()
        server = Protocol()

        finished = loopback.loopbackAsync(server, client, dummyPolicy)
        # The policy is only invoked once data is pending, not at setup.
        self.assertEqual(pumpCalls, [])

        client.transport.write(b"foo")
        client.transport.write(b"bar")
        server.transport.write(b"baz")
        server.transport.write(b"quux")
        server.transport.loseConnection()

        def cbComplete(ignored):
            self.assertEqual(
                pumpCalls,
                # The order here is somewhat arbitrary. The implementation
                # happens to always deliver data to the client first.
                [(client, [b"baz", b"quux", None]), (server, [b"foo", b"bar"])],
            )

        finished.addCallback(cbComplete)
        return finished

    def test_identityPumpPolicy(self):
        """
        L{identityPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} method one for each string in the queue passed to it.
        """
        bytes = []
        client = Protocol()
        client.dataReceived = bytes.append
        queue = loopback._LoopbackQueue()
        queue.put(b"foo")
        queue.put(b"bar")
        queue.put(None)

        loopback.identityPumpPolicy(queue, client)

        # The terminating None marker is not delivered as data.
        self.assertEqual(bytes, [b"foo", b"bar"])

    def test_collapsingPumpPolicy(self):
        """
        L{collapsingPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} only once with all of the strings in the queue passed
        to it joined together.
        """
        bytes = []
        client = Protocol()
        client.dataReceived = bytes.append
        queue = loopback._LoopbackQueue()
        queue.put(b"foo")
        queue.put(b"bar")
        queue.put(None)

        loopback.collapsingPumpPolicy(queue, client)

        self.assertEqual(bytes, [b"foobar"])
class LoopbackTCPTests(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the generic loopback tests over a real TCP connection.
    """

    loopbackFunc = staticmethod(loopback.loopbackTCP)
class LoopbackUNIXTests(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the generic loopback tests over a UNIX domain socket.
    """

    loopbackFunc = staticmethod(loopback.loopbackUNIX)

    # Skip on reactors (e.g. on Windows) that cannot create UNIX sockets.
    if interfaces.IReactorUNIX(reactor, None) is None:
        skip = "Current reactor does not support UNIX sockets"
class LoopbackRelayTest(unittest.TestCase):
    """
    Tests for L{twisted.protocols.loopback.LoopbackRelay}.
    """

    class Receiver(Protocol):
        """
        Trivial protocol that accumulates whatever bytes it is handed, so
        the tests can inspect what the relay eventually delivered.
        """

        data = b""

        def dataReceived(self, data):
            """
            Accumulate received data for verification.
            """
            self.data += data

    def test_write(self):
        """
        Bytes passed to L{LoopbackRelay.write} are buffered and only
        delivered to the wrapped protocol when the buffer is cleared.
        """
        target = self.Receiver()
        relay = loopback.LoopbackRelay(target)
        relay.write(b"abc")
        relay.write(b"def")
        # Nothing arrives until the buffer is explicitly flushed.
        self.assertEqual(target.data, b"")
        relay.clearBuffer()
        self.assertEqual(target.data, b"abcdef")

    def test_writeSequence(self):
        """
        L{LoopbackRelay.writeSequence} buffers every chunk of each sequence
        and delivers them all, in order, on L{LoopbackRelay.clearBuffer}.
        """
        target = self.Receiver()
        relay = loopback.LoopbackRelay(target)
        relay.writeSequence([b"The ", b"quick ", b"brown ", b"fox "])
        relay.writeSequence([b"jumps ", b"over ", b"the lazy dog"])
        # Nothing arrives until the buffer is explicitly flushed.
        self.assertEqual(target.data, b"")
        relay.clearBuffer()
        self.assertEqual(target.data, b"The quick brown fox jumps over the lazy dog")

View File

@@ -0,0 +1,74 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test that twisted scripts can be invoked as modules.
"""
import sys
from io import StringIO
from twisted.internet import defer, reactor
from twisted.test.test_process import Accumulator
from twisted.trial.unittest import TestCase
class MainTests(TestCase):
    """Test that twisted scripts can be invoked as modules."""

    def test_twisted(self):
        """Invoking python -m twisted should execute twist."""
        cmd = sys.executable
        p = Accumulator()
        d = p.endedDeferred = defer.Deferred()
        # Run ``python -m twisted --help`` as a child process; the
        # Accumulator collects its stdout in p.outF.
        reactor.spawnProcess(p, cmd, [cmd, "-m", "twisted", "--help"], env=None)
        p.transport.closeStdin()

        # Fix up our sys args to match the command we issued
        from twisted import __main__

        self.patch(sys, "argv", [__main__.__file__, "--help"])

        def processEnded(ign):
            f = p.outF
            output = f.getvalue()
            self.assertTrue(
                b"-m twisted [options] plugin [plugin_options]" in output, output
            )

        return d.addCallback(processEnded)

    def test_trial(self):
        """Invoking python -m twisted.trial should execute trial."""
        cmd = sys.executable
        p = Accumulator()
        d = p.endedDeferred = defer.Deferred()
        # Run ``python -m twisted.trial --help`` as a child process.
        reactor.spawnProcess(p, cmd, [cmd, "-m", "twisted.trial", "--help"], env=None)
        p.transport.closeStdin()

        # Fix up our sys args to match the command we issued
        from twisted.trial import __main__

        self.patch(sys, "argv", [__main__.__file__, "--help"])

        def processEnded(ign):
            f = p.outF
            output = f.getvalue()
            # A distinctive trial option proves trial's help was printed.
            self.assertTrue(b"-j, --jobs= " in output, output)

        return d.addCallback(processEnded)

    def test_twisted_import(self):
        """Importing twisted.__main__ does not execute twist."""
        output = StringIO()
        # Capture stdout for the duration of the import.
        monkey = self.patch(sys, "stdout", output)

        import twisted.__main__

        self.assertTrue(twisted.__main__)  # Appease pyflakes

        monkey.restore()
        # Nothing must have been printed as a side effect of the import.
        self.assertEqual(output.getvalue(), "")

View File

@@ -0,0 +1,714 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test the memcache client protocol.
"""
from twisted.internet.defer import Deferred, DeferredList, TimeoutError, gatherResults
from twisted.internet.error import ConnectionDone
from twisted.internet.task import Clock
from twisted.internet.testing import StringTransportWithDisconnection
from twisted.protocols.memcache import (
ClientError,
MemCacheProtocol,
NoSuchCommand,
ServerError,
)
from twisted.trial.unittest import TestCase
class CommandMixin:
    """
    Setup and tests for basic invocation of L{MemCacheProtocol} commands.

    Subclasses provide C{self.proto} and an implementation of C{_test}
    defining what "success" means in their environment (see
    L{MemCacheTests} and L{CommandFailureTests}).
    """

    def _test(self, d, send, recv, result):
        """
        Helper test method to test the resulting C{Deferred} of a
        L{MemCacheProtocol} command.
        """
        raise NotImplementedError()

    def test_get(self):
        """
        L{MemCacheProtocol.get} returns a L{Deferred} which is called back with
        the value and the flag associated with the given key if the server
        returns a successful result.
        """
        return self._test(
            self.proto.get(b"foo"),
            b"get foo\r\n",
            b"VALUE foo 0 3\r\nbar\r\nEND\r\n",
            (0, b"bar"),
        )

    def test_emptyGet(self):
        """
        Test getting a non-available key: it succeeds but return L{None} as
        value and C{0} as flag.
        """
        return self._test(self.proto.get(b"foo"), b"get foo\r\n", b"END\r\n", (0, None))

    def test_getMultiple(self):
        """
        L{MemCacheProtocol.getMultiple} returns a L{Deferred} which is called
        back with a dictionary of flag, value for each given key.
        """
        return self._test(
            self.proto.getMultiple([b"foo", b"cow"]),
            b"get foo cow\r\n",
            b"VALUE foo 0 3\r\nbar\r\nVALUE cow 0 7\r\nchicken\r\nEND\r\n",
            {b"cow": (0, b"chicken"), b"foo": (0, b"bar")},
        )

    def test_getMultipleWithEmpty(self):
        """
        When L{MemCacheProtocol.getMultiple} is called with non-available keys,
        the corresponding tuples are (0, None).
        """
        return self._test(
            self.proto.getMultiple([b"foo", b"cow"]),
            b"get foo cow\r\n",
            b"VALUE cow 1 3\r\nbar\r\nEND\r\n",
            {b"cow": (1, b"bar"), b"foo": (0, None)},
        )

    def test_set(self):
        """
        L{MemCacheProtocol.set} returns a L{Deferred} which is called back with
        C{True} when the operation succeeds.
        """
        return self._test(
            self.proto.set(b"foo", b"bar"),
            b"set foo 0 0 3\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_add(self):
        """
        L{MemCacheProtocol.add} returns a L{Deferred} which is called back with
        C{True} when the operation succeeds.
        """
        return self._test(
            self.proto.add(b"foo", b"bar"),
            b"add foo 0 0 3\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_replace(self):
        """
        L{MemCacheProtocol.replace} returns a L{Deferred} which is called back
        with C{True} when the operation succeeds.
        """
        return self._test(
            self.proto.replace(b"foo", b"bar"),
            b"replace foo 0 0 3\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_errorAdd(self):
        """
        Test an erroneous add: if a L{MemCacheProtocol.add} is called but the
        key already exists on the server, it returns a B{NOT STORED} answer,
        which calls back the resulting L{Deferred} with C{False}.
        """
        return self._test(
            self.proto.add(b"foo", b"bar"),
            b"add foo 0 0 3\r\nbar\r\n",
            b"NOT STORED\r\n",
            False,
        )

    def test_errorReplace(self):
        """
        Test an erroneous replace: if a L{MemCacheProtocol.replace} is called
        but the key doesn't exist on the server, it returns a B{NOT STORED}
        answer, which calls back the resulting L{Deferred} with C{False}.
        """
        return self._test(
            self.proto.replace(b"foo", b"bar"),
            b"replace foo 0 0 3\r\nbar\r\n",
            b"NOT STORED\r\n",
            False,
        )

    def test_delete(self):
        """
        L{MemCacheProtocol.delete} returns a L{Deferred} which is called back
        with C{True} when the server notifies a success.
        """
        return self._test(
            self.proto.delete(b"bar"), b"delete bar\r\n", b"DELETED\r\n", True
        )

    def test_errorDelete(self):
        """
        Test an error during a delete: if key doesn't exist on the server, it
        returns a B{NOT FOUND} answer which calls back the resulting
        L{Deferred} with C{False}.
        """
        return self._test(
            self.proto.delete(b"bar"), b"delete bar\r\n", b"NOT FOUND\r\n", False
        )

    def test_increment(self):
        """
        Test incrementing a variable: L{MemCacheProtocol.increment} returns a
        L{Deferred} which is called back with the incremented value of the
        given key.
        """
        return self._test(self.proto.increment(b"foo"), b"incr foo 1\r\n", b"4\r\n", 4)

    def test_decrement(self):
        """
        Test decrementing a variable: L{MemCacheProtocol.decrement} returns a
        L{Deferred} which is called back with the decremented value of the
        given key.
        """
        return self._test(self.proto.decrement(b"foo"), b"decr foo 1\r\n", b"5\r\n", 5)

    def test_incrementVal(self):
        """
        L{MemCacheProtocol.increment} takes an optional argument C{value} which
        replaces the default value of 1 when specified.
        """
        return self._test(
            self.proto.increment(b"foo", 8), b"incr foo 8\r\n", b"4\r\n", 4
        )

    def test_decrementVal(self):
        """
        L{MemCacheProtocol.decrement} takes an optional argument C{value} which
        replaces the default value of 1 when specified.
        """
        return self._test(
            self.proto.decrement(b"foo", 3), b"decr foo 3\r\n", b"5\r\n", 5
        )

    def test_stats(self):
        """
        Test retrieving server statistics via the L{MemCacheProtocol.stats}
        command: it parses the data sent by the server and calls back the
        resulting L{Deferred} with a dictionary of the received statistics.
        """
        return self._test(
            self.proto.stats(),
            b"stats\r\n",
            b"STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
            {b"foo": b"bar", b"egg": b"spam"},
        )

    def test_statsWithArgument(self):
        """
        L{MemCacheProtocol.stats} takes an optional C{bytes} argument which,
        if specified, is sent along with the I{STAT} command.  The I{STAT}
        responses from the server are parsed as key/value pairs and returned
        as a C{dict} (as in the case where the argument is not specified).
        """
        return self._test(
            self.proto.stats(b"blah"),
            b"stats blah\r\n",
            b"STAT foo bar\r\nSTAT egg spam\r\nEND\r\n",
            {b"foo": b"bar", b"egg": b"spam"},
        )

    def test_version(self):
        """
        Test version retrieval via the L{MemCacheProtocol.version} command: it
        returns a L{Deferred} which is called back with the version sent by the
        server.
        """
        return self._test(
            self.proto.version(), b"version\r\n", b"VERSION 1.1\r\n", b"1.1"
        )

    def test_flushAll(self):
        """
        L{MemCacheProtocol.flushAll} returns a L{Deferred} which is called back
        with C{True} if the server acknowledges success.
        """
        return self._test(self.proto.flushAll(), b"flush_all\r\n", b"OK\r\n", True)
class MemCacheTests(CommandMixin, TestCase):
    """
    Test client protocol class L{MemCacheProtocol}.
    """

    def setUp(self):
        """
        Create a memcache client, connect it to a string protocol, and make it
        use a deterministic clock.
        """
        self.proto = MemCacheProtocol()
        self.clock = Clock()
        # Route the protocol's timeout scheduling through the fake clock so
        # the tests below can advance time deterministically.
        self.proto.callLater = self.clock.callLater
        self.transport = StringTransportWithDisconnection()
        self.transport.protocol = self.proto
        self.proto.makeConnection(self.transport)

    def _test(self, d, send, recv, result):
        """
        Implementation of C{_test} which checks that the command sends C{send}
        data, and that upon reception of C{recv} the result is C{result}.

        @param d: the resulting deferred from the memcache command.
        @type d: C{Deferred}

        @param send: the expected data to be sent.
        @type send: C{bytes}

        @param recv: the data to simulate as reception.
        @type recv: C{bytes}

        @param result: the expected result.
        @type result: C{any}
        """

        def cb(res):
            self.assertEqual(res, result)

        # The request bytes must be on the wire before the response arrives.
        self.assertEqual(self.transport.value(), send)
        d.addCallback(cb)
        self.proto.dataReceived(recv)
        return d

    def test_invalidGetResponse(self):
        """
        If the value returned doesn't match the expected key of the current
        C{get} command, an error is raised in L{MemCacheProtocol.dataReceived}.
        """
        self.proto.get(b"foo")
        self.assertRaises(
            RuntimeError,
            self.proto.dataReceived,
            b"VALUE bar 0 7\r\nspamegg\r\nEND\r\n",
        )

    def test_invalidMultipleGetResponse(self):
        """
        If the value returned doesn't match one the expected keys of the
        current multiple C{get} command, an error is raised error in
        L{MemCacheProtocol.dataReceived}.
        """
        self.proto.getMultiple([b"foo", b"bar"])
        self.assertRaises(
            RuntimeError,
            self.proto.dataReceived,
            b"VALUE egg 0 7\r\nspamegg\r\nEND\r\n",
        )

    def test_invalidEndResponse(self):
        """
        If an END is received in response to an operation that isn't C{get},
        C{gets}, or C{stats}, an error is raised in
        L{MemCacheProtocol.dataReceived}.
        """
        self.proto.set(b"key", b"value")
        self.assertRaises(RuntimeError, self.proto.dataReceived, b"END\r\n")

    def test_timeOut(self):
        """
        Test the timeout on outgoing requests: when timeout is detected, all
        current commands fail with a L{TimeoutError}, and the connection is
        closed.
        """
        d1 = self.proto.get(b"foo")
        d2 = self.proto.get(b"bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)

        def checkMessage(error):
            self.assertEqual(str(error), "Connection timeout")

        d1.addCallback(checkMessage)
        self.assertFailure(d3, ConnectionDone)
        return gatherResults([d1, d2, d3])

    def test_timeoutRemoved(self):
        """
        When a request gets a response, no pending timeout call remains around.
        """
        d = self.proto.get(b"foo")

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived(b"VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, b"bar"))
            # No timeout call is left scheduled on the fake clock.
            self.assertEqual(len(self.clock.calls), 0)

        d.addCallback(check)
        return d

    def test_timeOutRaw(self):
        """
        Test the timeout when raw mode was started: the timeout is not reset
        until all the data has been received, so we can have a L{TimeoutError}
        when waiting for raw data.
        """
        d1 = self.proto.get(b"foo")
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        # Deliver only part of the promised 10-byte value so the protocol
        # is still in raw mode when the clock advances.
        self.proto.dataReceived(b"VALUE foo 0 10\r\n12345")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, ConnectionDone)
        return gatherResults([d1, d2])

    def test_timeOutStat(self):
        """
        Test the timeout when stat command has started: the timeout is not
        reset until the final B{END} is received.
        """
        d1 = self.proto.stats()
        d2 = Deferred()
        self.proto.connectionLost = d2.callback

        self.proto.dataReceived(b"STAT foo bar\r\n")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, ConnectionDone)
        return gatherResults([d1, d2])

    def test_timeoutPipelining(self):
        """
        When two requests are sent, a timeout call remains around for the
        second request, and its timeout time is correct.
        """
        d1 = self.proto.get(b"foo")
        d2 = self.proto.get(b"bar")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        self.proto.dataReceived(b"VALUE foo 0 3\r\nbar\r\nEND\r\n")

        def check(result):
            self.assertEqual(result, (0, b"bar"))
            self.assertEqual(len(self.clock.calls), 1)
            for i in range(self.proto.persistentTimeOut):
                self.clock.advance(1)
            # checkTime is defined below, but the name is looked up from the
            # enclosing scope only when this callback runs, so the forward
            # reference is fine.
            return self.assertFailure(d2, TimeoutError).addCallback(checkTime)

        def checkTime(ignored):
            # Check that the timeout happened C{self.proto.persistentTimeOut}
            # after the last response
            self.assertEqual(self.clock.seconds(), 2 * self.proto.persistentTimeOut - 1)

        d1.addCallback(check)
        self.assertFailure(d3, ConnectionDone)
        return d1

    def test_timeoutNotReset(self):
        """
        Check that timeout is not resetted for every command, but keep the
        timeout from the first command without response.
        """
        d1 = self.proto.get(b"foo")
        d3 = Deferred()
        self.proto.connectionLost = d3.callback

        self.clock.advance(self.proto.persistentTimeOut - 1)
        d2 = self.proto.get(b"bar")
        self.clock.advance(1)
        self.assertFailure(d1, TimeoutError)
        self.assertFailure(d2, TimeoutError)
        self.assertFailure(d3, ConnectionDone)
        return gatherResults([d1, d2, d3])

    def test_timeoutCleanDeferreds(self):
        """
        C{timeoutConnection} cleans the list of commands that it fires with
        C{TimeoutError}: C{connectionLost} doesn't try to fire them again, but
        sets the disconnected state so that future commands fail with a
        C{RuntimeError}.
        """
        d1 = self.proto.get(b"foo")
        self.clock.advance(self.proto.persistentTimeOut)
        self.assertFailure(d1, TimeoutError)
        d2 = self.proto.get(b"bar")
        self.assertFailure(d2, RuntimeError)
        return gatherResults([d1, d2])

    def test_connectionLost(self):
        """
        When disconnection occurs while commands are still outstanding, the
        commands fail.
        """
        d1 = self.proto.get(b"foo")
        d2 = self.proto.get(b"bar")
        self.transport.loseConnection()
        done = DeferredList([d1, d2], consumeErrors=True)

        def checkFailures(results):
            for success, result in results:
                self.assertFalse(success)
                result.trap(ConnectionDone)

        return done.addCallback(checkFailures)

    def test_tooLongKey(self):
        """
        An error is raised when trying to use a too long key: the called
        command returns a L{Deferred} which fails with a L{ClientError}.
        """
        d1 = self.assertFailure(self.proto.set(b"a" * 500, b"bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment(b"a" * 500), ClientError)
        d3 = self.assertFailure(self.proto.get(b"a" * 500), ClientError)
        d4 = self.assertFailure(self.proto.append(b"a" * 500, b"bar"), ClientError)
        d5 = self.assertFailure(self.proto.prepend(b"a" * 500, b"bar"), ClientError)
        d6 = self.assertFailure(
            self.proto.getMultiple([b"foo", b"a" * 500]), ClientError
        )
        return gatherResults([d1, d2, d3, d4, d5, d6])

    def test_invalidCommand(self):
        """
        When an unknown command is sent directly (not through public API), the
        server answers with an B{ERROR} token, and the command fails with
        L{NoSuchCommand}.
        """
        d = self.proto._set(b"egg", b"foo", b"bar", 0, 0, b"")
        self.assertEqual(self.transport.value(), b"egg foo 0 0 3\r\nbar\r\n")
        self.assertFailure(d, NoSuchCommand)
        self.proto.dataReceived(b"ERROR\r\n")
        return d

    def test_clientError(self):
        """
        Test the L{ClientError} error: when the server sends a B{CLIENT_ERROR}
        token, the originating command fails with L{ClientError}, and the error
        contains the text sent by the server.
        """
        a = b"eggspamm"
        d = self.proto.set(b"foo", a)
        self.assertEqual(self.transport.value(), b"set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ClientError)

        def check(err):
            self.assertEqual(str(err), repr(b"We don't like egg and spam"))

        d.addCallback(check)
        self.proto.dataReceived(b"CLIENT_ERROR We don't like egg and spam\r\n")
        return d

    def test_serverError(self):
        """
        Test the L{ServerError} error: when the server sends a B{SERVER_ERROR}
        token, the originating command fails with L{ServerError}, and the error
        contains the text sent by the server.
        """
        a = b"eggspamm"
        d = self.proto.set(b"foo", a)
        self.assertEqual(self.transport.value(), b"set foo 0 0 8\r\neggspamm\r\n")
        self.assertFailure(d, ServerError)

        def check(err):
            self.assertEqual(str(err), repr(b"zomg"))

        d.addCallback(check)
        self.proto.dataReceived(b"SERVER_ERROR zomg\r\n")
        return d

    def test_unicodeKey(self):
        """
        Using a non-string key as argument to commands raises an error.
        """
        d1 = self.assertFailure(self.proto.set("foo", b"bar"), ClientError)
        d2 = self.assertFailure(self.proto.increment("egg"), ClientError)
        d3 = self.assertFailure(self.proto.get(1), ClientError)
        d4 = self.assertFailure(self.proto.delete("bar"), ClientError)
        d5 = self.assertFailure(self.proto.append("foo", b"bar"), ClientError)
        d6 = self.assertFailure(self.proto.prepend("foo", b"bar"), ClientError)
        d7 = self.assertFailure(self.proto.getMultiple([b"egg", 1]), ClientError)
        return gatherResults([d1, d2, d3, d4, d5, d6, d7])

    def test_unicodeValue(self):
        """
        Using a non-string value raises an error.
        """
        return self.assertFailure(self.proto.set(b"foo", "bar"), ClientError)

    def test_pipelining(self):
        """
        Multiple requests can be sent subsequently to the server, and the
        protocol orders the responses correctly and dispatch to the
        corresponding client command.
        """
        d1 = self.proto.get(b"foo")
        d1.addCallback(self.assertEqual, (0, b"bar"))
        d2 = self.proto.set(b"bar", b"spamspamspam")
        d2.addCallback(self.assertEqual, True)
        d3 = self.proto.get(b"egg")
        d3.addCallback(self.assertEqual, (0, b"spam"))
        self.assertEqual(
            self.transport.value(),
            b"get foo\r\nset bar 0 0 12\r\nspamspamspam\r\nget egg\r\n",
        )
        # All three responses arrive in a single chunk; the protocol must
        # match them to the commands in issue order.
        self.proto.dataReceived(
            b"VALUE foo 0 3\r\nbar\r\nEND\r\n"
            b"STORED\r\n"
            b"VALUE egg 0 4\r\nspam\r\nEND\r\n"
        )
        return gatherResults([d1, d2, d3])

    def test_getInChunks(self):
        """
        If the value retrieved by a C{get} arrive in chunks, the protocol
        is able to reconstruct it and to produce the good value.
        """
        d = self.proto.get(b"foo")
        d.addCallback(self.assertEqual, (0, b"0123456789"))
        self.assertEqual(self.transport.value(), b"get foo\r\n")
        # The value and terminator arrive split across four deliveries.
        self.proto.dataReceived(b"VALUE foo 0 10\r\n0123456")
        self.proto.dataReceived(b"789")
        self.proto.dataReceived(b"\r\nEND")
        self.proto.dataReceived(b"\r\n")
        return d

    def test_append(self):
        """
        L{MemCacheProtocol.append} behaves like a L{MemCacheProtocol.set}
        method: it returns a L{Deferred} which is called back with C{True} when
        the operation succeeds.
        """
        return self._test(
            self.proto.append(b"foo", b"bar"),
            b"append foo 0 0 3\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_prepend(self):
        """
        L{MemCacheProtocol.prepend} behaves like a L{MemCacheProtocol.set}
        method: it returns a L{Deferred} which is called back with C{True} when
        the operation succeeds.
        """
        return self._test(
            self.proto.prepend(b"foo", b"bar"),
            b"prepend foo 0 0 3\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_gets(self):
        """
        L{MemCacheProtocol.get} handles an additional cas result when
        C{withIdentifier} is C{True} and forward it in the resulting
        L{Deferred}.
        """
        return self._test(
            self.proto.get(b"foo", True),
            b"gets foo\r\n",
            b"VALUE foo 0 3 1234\r\nbar\r\nEND\r\n",
            (0, b"1234", b"bar"),
        )

    def test_emptyGets(self):
        """
        Test getting a non-available key with gets: it succeeds but return
        L{None} as value, C{0} as flag and an empty cas value.
        """
        return self._test(
            self.proto.get(b"foo", True), b"gets foo\r\n", b"END\r\n", (0, b"", None)
        )

    def test_getsMultiple(self):
        """
        L{MemCacheProtocol.getMultiple} handles an additional cas field in the
        returned tuples if C{withIdentifier} is C{True}.
        """
        return self._test(
            self.proto.getMultiple([b"foo", b"bar"], True),
            b"gets foo bar\r\n",
            b"VALUE foo 0 3 1234\r\negg\r\n" b"VALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
            {b"bar": (0, b"2345", b"spam"), b"foo": (0, b"1234", b"egg")},
        )

    def test_getsMultipleIterableKeys(self):
        """
        L{MemCacheProtocol.getMultiple} accepts any iterable of keys.
        """
        return self._test(
            self.proto.getMultiple(iter([b"foo", b"bar"]), True),
            b"gets foo bar\r\n",
            b"VALUE foo 0 3 1234\r\negg\r\n" b"VALUE bar 0 4 2345\r\nspam\r\nEND\r\n",
            {b"bar": (0, b"2345", b"spam"), b"foo": (0, b"1234", b"egg")},
        )

    def test_getsMultipleWithEmpty(self):
        """
        When getting a non-available key with L{MemCacheProtocol.getMultiple}
        when C{withIdentifier} is C{True}, the other keys are retrieved
        correctly, and the non-available key gets a tuple of C{0} as flag,
        L{None} as value, and an empty cas value.
        """
        return self._test(
            self.proto.getMultiple([b"foo", b"bar"], True),
            b"gets foo bar\r\n",
            b"VALUE foo 0 3 1234\r\negg\r\nEND\r\n",
            {b"bar": (0, b"", None), b"foo": (0, b"1234", b"egg")},
        )

    def test_checkAndSet(self):
        """
        L{MemCacheProtocol.checkAndSet} passes an additional cas identifier
        that the server handles to check if the data has to be updated.
        """
        return self._test(
            self.proto.checkAndSet(b"foo", b"bar", cas=b"1234"),
            b"cas foo 0 0 3 1234\r\nbar\r\n",
            b"STORED\r\n",
            True,
        )

    def test_casUnknowKey(self):
        """
        When L{MemCacheProtocol.checkAndSet} response is C{EXISTS}, the
        resulting L{Deferred} fires with C{False}.
        """
        return self._test(
            self.proto.checkAndSet(b"foo", b"bar", cas=b"1234"),
            b"cas foo 0 0 3 1234\r\nbar\r\n",
            b"EXISTS\r\n",
            False,
        )
class CommandFailureTests(CommandMixin, TestCase):
    """
    Tests for correct failure of commands on a disconnected
    L{MemCacheProtocol}.
    """

    def setUp(self):
        """
        Create a disconnected memcache client, using a deterministic clock.
        """
        self.clock = Clock()
        self.proto = MemCacheProtocol()
        self.proto.callLater = self.clock.callLater
        self.transport = StringTransportWithDisconnection()
        self.transport.protocol = self.proto
        self.proto.makeConnection(self.transport)
        # Drop the connection immediately so every command exercised by
        # CommandMixin runs against a dead transport.
        self.transport.loseConnection()

    def _test(self, d, send, recv, result):
        """
        Implementation of C{_test} which checks that the command fails with
        C{RuntimeError} because the transport is disconnected. All the
        parameters except C{d} are ignored.
        """
        return self.assertFailure(d, RuntimeError)

View File

@@ -0,0 +1,503 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for twisted.python.modules, abstract access to imported or importable
objects.
"""
from __future__ import annotations
import compileall
import itertools
import sys
import zipfile
from importlib.abc import PathEntryFinder
from types import ModuleType
from typing import Any, Generator
from typing_extensions import Protocol
import twisted
from twisted.python import modules
from twisted.python.compat import networkString
from twisted.python.filepath import FilePath
from twisted.python.reflect import namedAny
from twisted.python.test.modules_helpers import TwistedModulesMixin
from twisted.python.test.test_zippath import zipit
from twisted.trial.unittest import TestCase
class _SupportsWalkModules(Protocol):
    # Structural type for "anything with a walkModules iterator" -- satisfied
    # both by the L{modules} module itself and by L{modules.PythonModule},
    # which is why findByIteration below can take either.
    def walkModules(
        self, importPackages: bool
    ) -> Generator[modules.PythonModule, None, None]:
        ...
class TwistedModulesTestCase(TwistedModulesMixin, TestCase):
    """
    Base class for L{modules} test cases.
    """

    def findByIteration(
        self,
        modname: str,
        where: _SupportsWalkModules = modules,
        importPackages: bool = False,
    ) -> modules.PythonModule:
        """
        You don't ever actually want to do this, so it's not in the public
        API, but sometimes we want to compare the result of an iterative call
        with a lookup call and make sure they're the same for test purposes.
        """
        # Lazily scan the walker's output for the first module whose name
        # matches; walking stops as soon as a match is found.
        found = next(
            (
                candidate
                for candidate in where.walkModules(importPackages=importPackages)
                if candidate.name == modname
            ),
            None,
        )
        if found is None:
            self.fail(f"Unable to find module {modname!r} through iteration.")
        return found
class BasicTests(TwistedModulesTestCase):
    """
    Tests for the fundamental L{modules} APIs: walking, lookup, and
    interaction with C{sys.path} / C{sys.modules}.
    """

    def test_namespacedPackages(self) -> None:
        """
        Duplicate packages are not yielded when iterating over namespace
        packages.
        """
        # Force pkgutil to be loaded already, since the probe package being
        # created depends on it, and the replaceSysPath call below will make
        # pretty much everything unimportable.
        __import__("pkgutil")
        namespaceBoilerplate = (
            b"import pkgutil; " b"__path__ = pkgutil.extend_path(__path__, __name__)"
        )

        # Create two temporary directories with packages:
        #
        #   entry:
        #       test_package/
        #           __init__.py
        #           nested_package/
        #               __init__.py
        #               module.py
        #
        #   anotherEntry:
        #       test_package/
        #           __init__.py
        #           nested_package/
        #               __init__.py
        #               module2.py
        #
        # test_package and test_package.nested_package are namespace packages,
        # and when both of these are in sys.path, test_package.nested_package
        # should become a virtual package containing both "module" and
        # "module2"

        entry = self.pathEntryWithOnePackage()
        testPackagePath = entry.child("test_package")
        testPackagePath.child("__init__.py").setContent(namespaceBoilerplate)

        nestedEntry = testPackagePath.child("nested_package")
        nestedEntry.makedirs()
        nestedEntry.child("__init__.py").setContent(namespaceBoilerplate)
        nestedEntry.child("module.py").setContent(b"")

        anotherEntry = self.pathEntryWithOnePackage()
        anotherPackagePath = anotherEntry.child("test_package")
        anotherPackagePath.child("__init__.py").setContent(namespaceBoilerplate)

        anotherNestedEntry = anotherPackagePath.child("nested_package")
        anotherNestedEntry.makedirs()
        anotherNestedEntry.child("__init__.py").setContent(namespaceBoilerplate)
        anotherNestedEntry.child("module2.py").setContent(b"")

        self.replaceSysPath([entry.path, anotherEntry.path])

        module = modules.getModule("test_package")

        # We have to use importPackages=True in order to resolve the namespace
        # packages, so we remove the imported packages from sys.modules after
        # walking
        try:
            walkedNames = [mod.name for mod in module.walkModules(importPackages=True)]
        finally:
            # Snapshot the keys so removal during iteration is safe; use a
            # distinct loop variable so "module" above is not shadowed.
            for moduleName in list(sys.modules.keys()):
                if moduleName.startswith("test_package"):
                    del sys.modules[moduleName]

        expected = [
            "test_package",
            "test_package.nested_package",
            "test_package.nested_package.module",
            "test_package.nested_package.module2",
        ]

        self.assertEqual(walkedNames, expected)

    def test_unimportablePackageGetItem(self) -> None:
        """
        If a package has been explicitly forbidden from importing by setting a
        L{None} key in sys.modules under its name,
        L{modules.PythonPath.__getitem__} should still be able to retrieve an
        unloaded L{modules.PythonModule} for that package.
        """
        shouldNotLoad: list[str] = []
        path = modules.PythonPath(
            sysPath=[self.pathEntryWithOnePackage().path],
            moduleLoader=shouldNotLoad.append,
            importerCache={},
            sysPathHooks={},
            moduleDict={"test_package": None},
        )
        self.assertEqual(shouldNotLoad, [])
        self.assertFalse(path["test_package"].isLoaded())

    def test_unimportablePackageWalkModules(self) -> None:
        """
        If a package has been explicitly forbidden from importing by setting a
        L{None} key in sys.modules under its name, L{modules.walkModules} should
        still be able to retrieve an unloaded L{modules.PythonModule} for that
        package.
        """
        existentPath = self.pathEntryWithOnePackage()
        self.replaceSysPath([existentPath.path])
        self.replaceSysModules({"test_package": None})  # type: ignore[dict-item]

        walked = list(modules.walkModules())
        self.assertEqual([m.name for m in walked], ["test_package"])
        self.assertFalse(walked[0].isLoaded())

    def test_nonexistentPaths(self) -> None:
        """
        Verify that L{modules.walkModules} ignores entries in sys.path which
        do not exist in the filesystem.
        """
        existentPath = self.pathEntryWithOnePackage()

        nonexistentPath = FilePath(self.mktemp())
        self.assertFalse(nonexistentPath.exists())

        self.replaceSysPath([existentPath.path])

        expected = [modules.getModule("test_package")]

        beforeModules = list(modules.walkModules())
        sys.path.append(nonexistentPath.path)
        afterModules = list(modules.walkModules())

        self.assertEqual(beforeModules, expected)
        self.assertEqual(afterModules, expected)

    def test_nonDirectoryPaths(self) -> None:
        """
        Verify that L{modules.walkModules} ignores entries in sys.path which
        refer to regular files in the filesystem.
        """
        existentPath = self.pathEntryWithOnePackage()

        nonDirectoryPath = FilePath(self.mktemp())
        self.assertFalse(nonDirectoryPath.exists())
        nonDirectoryPath.setContent(b"zip file or whatever\n")

        self.replaceSysPath([existentPath.path])

        beforeModules = list(modules.walkModules())
        sys.path.append(nonDirectoryPath.path)
        afterModules = list(modules.walkModules())

        self.assertEqual(beforeModules, afterModules)

    def test_twistedShowsUp(self) -> None:
        """
        Scrounge around in the top-level module namespace and make sure that
        Twisted shows up, and that the module thusly obtained is the same as
        the module that we find when we look for it explicitly by name.
        """
        self.assertEqual(modules.getModule("twisted"), self.findByIteration("twisted"))

    def test_dottedNames(self) -> None:
        """
        Verify that the walkModules APIs will give us back subpackages, not
        just top-level packages.
        """
        self.assertEqual(
            modules.getModule("twisted.python"),
            self.findByIteration("twisted.python", where=modules.getModule("twisted")),
        )

    def test_onlyTopModules(self) -> None:
        """
        Verify that the iterModules API will only return top-level modules and
        packages, not submodules or subpackages.
        """
        for module in modules.iterModules():
            self.assertFalse(
                "." in module.name,
                "no nested modules should be returned from iterModules: %r"
                % (module.filePath),
            )

    def test_loadPackagesAndModules(self) -> None:
        """
        Verify that we can locate and load packages, modules, submodules, and
        subpackages.
        """
        for n in ["os", "twisted", "twisted.python", "twisted.python.reflect"]:
            m = namedAny(n)
            # assertIs replaces the deprecated failUnlessIdentical alias.
            self.assertIs(modules.getModule(n).load(), m)
            self.assertIs(self.findByIteration(n).load(), m)

    def test_pathEntriesOnPath(self) -> None:
        """
        Verify that path entries discovered via module loading are, in fact, on
        sys.path somewhere.
        """
        for n in ["os", "twisted", "twisted.python", "twisted.python.reflect"]:
            # assertIn replaces the deprecated failUnlessIn alias.
            self.assertIn(modules.getModule(n).pathEntry.filePath.path, sys.path)

    def test_alwaysPreferPy(self) -> None:
        """
        Verify that .py files will always be preferred to .pyc files, regardless of
        directory listing order.
        """
        mypath = FilePath(self.mktemp())
        mypath.createDirectory()
        pp = modules.PythonPath(sysPath=[mypath.path])
        originalSmartPath = pp._smartPath

        def _evilSmartPath(pathName: str) -> Any:
            o = originalSmartPath(pathName)
            originalChildren = o.children

            def evilChildren() -> Any:
                # normally this order is random; let's make sure it always
                # comes up .pyc-first.
                x = list(originalChildren())
                x.sort()
                x.reverse()
                return x

            o.children = evilChildren
            return o

        mypath.child("abcd.py").setContent(b"\n")
        compileall.compile_dir(mypath.path, quiet=True)
        # sanity check
        self.assertEqual(len(list(mypath.children())), 2)
        pp._smartPath = _evilSmartPath  # type: ignore[method-assign]
        self.assertEqual(pp["abcd"].filePath, mypath.child("abcd.py"))

    def test_packageMissingPath(self) -> None:
        """
        A package can delete its __path__ for some reasons,
        C{modules.PythonPath} should be able to deal with it.
        """
        mypath = FilePath(self.mktemp())
        mypath.createDirectory()
        pp = modules.PythonPath(sysPath=[mypath.path])
        subpath = mypath.child("abcd")
        subpath.createDirectory()
        subpath.child("__init__.py").setContent(b"del __path__\n")
        sys.path.append(mypath.path)
        __import__("abcd")
        try:
            walked = list(pp.walkModules())
            self.assertEqual(len(walked), 1)
            self.assertEqual(walked[0].name, "abcd")
        finally:
            del sys.modules["abcd"]
            sys.path.remove(mypath.path)
class PathModificationTests(TwistedModulesTestCase):
    """
    These tests share setup/cleanup behavior of creating a dummy package and
    stuffing some code in it.
    """

    # used to generate serial numbers for package names, so each test gets a
    # fresh, never-before-imported package
    _serialnum = itertools.count()

    def setUp(self) -> None:
        # Build a throwaway path entry containing one package with a few
        # modules; tests append it to sys.path via _setupSysPath.
        self.pathExtensionName = self.mktemp()
        self.pathExtension = FilePath(self.pathExtensionName)
        self.pathExtension.createDirectory()
        self.packageName = "pyspacetests%d" % (next(self._serialnum),)
        self.packagePath = self.pathExtension.child(self.packageName)
        self.packagePath.createDirectory()
        self.packagePath.child("__init__.py").setContent(b"")
        self.packagePath.child("a.py").setContent(b"")
        self.packagePath.child("b.py").setContent(b"")
        self.packagePath.child("c__init__.py").setContent(b"")
        self.pathSetUp = False

    def _setupSysPath(self) -> None:
        # Append (at most once) the temporary entry; tearDown pops it again.
        assert not self.pathSetUp
        self.pathSetUp = True
        sys.path.append(self.pathExtensionName)

    def _underUnderPathTest(self, doImport: bool = True) -> None:
        # Point the package's __path__ at a second directory and verify the
        # module there is discoverable both by iteration and by lookup.
        moddir2 = self.mktemp()
        fpmd = FilePath(moddir2)
        fpmd.createDirectory()
        fpmd.child("foozle.py").setContent(b"x = 123\n")
        self.packagePath.child("__init__.py").setContent(
            networkString(f"__path__.append({repr(moddir2)})\n")
        )
        # Cut here
        self._setupSysPath()
        modinfo = modules.getModule(self.packageName)
        self.assertEqual(
            self.findByIteration(
                self.packageName + ".foozle", modinfo, importPackages=doImport
            ),
            modinfo["foozle"],
        )
        self.assertEqual(modinfo["foozle"].load().x, 123)

    def test_underUnderPathAlreadyImported(self) -> None:
        """
        Verify that iterModules will honor the __path__ of already-loaded packages.
        """
        self._underUnderPathTest()

    def _listModules(self) -> None:
        # c__init__.py is deliberately named to look like an __init__ module;
        # it must still be listed as an ordinary module.
        pkginfo = modules.getModule(self.packageName)
        nfni = [modinfo.name.split(".")[-1] for modinfo in pkginfo.iterModules()]
        nfni.sort()
        self.assertEqual(nfni, ["a", "b", "c__init__"])

    def test_listingModules(self) -> None:
        """
        Make sure the module list comes back as we expect from iterModules on a
        package, whether zipped or not.
        """
        self._setupSysPath()
        self._listModules()

    def test_listingModulesAlreadyImported(self) -> None:
        """
        Make sure the module list comes back as we expect from iterModules on a
        package, whether zipped or not, even if the package has already been
        imported.
        """
        self._setupSysPath()
        namedAny(self.packageName)
        self._listModules()

    def tearDown(self) -> None:
        # Intentionally using 'assert' here, this is not a test assertion, this
        # is just an "oh fuck what is going ON" assertion. -glyph
        if self.pathSetUp:
            HORK = "path cleanup failed: don't be surprised if other tests break"
            assert sys.path.pop() is self.pathExtensionName, HORK + ", 1"
            assert self.pathExtensionName not in sys.path, HORK + ", 2"
class RebindingTests(PathModificationTests):
    """
    These tests verify that the default path interrogation API works properly
    even when sys.path has been rebound to a different object.
    """

    def _setupSysPath(self) -> None:
        assert not self.pathSetUp
        self.pathSetUp = True
        # Remember the original list object, then rebind sys.path to a fresh
        # copy and extend that copy instead.
        self.savedSysPath = sys.path
        sys.path = self.savedSysPath[:]
        sys.path.append(self.pathExtensionName)

    def tearDown(self) -> None:
        """
        Clean up sys.path by re-binding our original object.
        """
        if self.pathSetUp:
            sys.path = self.savedSysPath
class ZipPathModificationTests(PathModificationTests):
    # Re-run the PathModificationTests suite with the probe package zipped up,
    # exercising the zipimport path-entry handling.
    def _setupSysPath(self) -> None:
        assert not self.pathSetUp
        zipit(self.pathExtensionName, self.pathExtensionName + ".zip")
        self.pathExtensionName += ".zip"
        assert zipfile.is_zipfile(self.pathExtensionName)
        PathModificationTests._setupSysPath(self)
class PythonPathTests(TestCase):
    """
    Tests for the class which provides the implementation for all of the
    public API of L{twisted.python.modules}, L{PythonPath}.
    """

    def test_unhandledImporter(self) -> None:
        """
        Make sure that the behavior when encountering an unknown importer
        type is not catastrophic failure.
        """

        class SecretImporter:
            pass

        def hook(name: object) -> SecretImporter:
            return SecretImporter()

        syspath = ["example/path"]
        sysmodules: dict[str, ModuleType] = {}
        syshooks = [hook]
        syscache: dict[str, PathEntryFinder | None] = {}

        def sysloader(name: object) -> None:
            return None

        space = modules.PythonPath(syspath, sysmodules, syshooks, syscache, sysloader)
        entries = list(space.iterEntries())
        self.assertEqual(len(entries), 1)
        # An unrecognized importer yields an entry with no modules in it.
        self.assertRaises(KeyError, lambda: entries[0]["module"])

    def test_inconsistentImporterCache(self) -> None:
        """
        If the path a module loaded with L{PythonPath.__getitem__} is not
        present in the path importer cache, a warning is emitted, but the
        L{PythonModule} is returned as usual.
        """
        space = modules.PythonPath([], sys.modules, [], {})
        thisModule = space[__name__]
        warnings = self.flushWarnings([self.test_inconsistentImporterCache])
        self.assertEqual(warnings[0]["category"], UserWarning)
        self.assertEqual(
            warnings[0]["message"],
            FilePath(twisted.__file__).parent().dirname()
            + " (for module "
            + __name__
            + ") not in path importer cache "
            "(PEP 302 violation - check your local configuration).",
        )
        self.assertEqual(len(warnings), 1)
        self.assertEqual(thisModule.name, __name__)

    def test_containsModule(self) -> None:
        """
        L{PythonPath} implements the C{in} operator so that when it is the
        right-hand argument and the name of a module which exists on that
        L{PythonPath} is the left-hand argument, the result is C{True}.
        """
        thePath = modules.PythonPath()
        self.assertIn("os", thePath)

    def test_doesntContainModule(self) -> None:
        """
        L{PythonPath} implements the C{in} operator so that when it is the
        right-hand argument and the name of a module which does not exist on
        that L{PythonPath} is the left-hand argument, the result is C{False}.
        """
        thePath = modules.PythonPath()
        self.assertNotIn("bogusModule", thePath)
# Explicit public API of this test module.
__all__ = [
    "BasicTests",
    "PathModificationTests",
    "RebindingTests",
    "ZipPathModificationTests",
    "PythonPathTests",
]

View File

@@ -0,0 +1,175 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.monkey}.
"""
from __future__ import annotations
from typing_extensions import NoReturn
from twisted.python.monkey import MonkeyPatcher
from twisted.trial import unittest
class TestObj:
    """
    A trivial object carrying three known attribute values for the
    monkey-patching tests to patch and restore.
    """

    def __init__(self) -> None:
        # Each attribute's value is derived from its own name.
        for attribute in ("foo", "bar", "baz"):
            setattr(self, attribute, f"{attribute} value")
class MonkeyPatcherTests(unittest.SynchronousTestCase):
    """
    Tests for L{MonkeyPatcher} monkey-patching class.
    """

    def setUp(self) -> None:
        # One object to patch, and an untouched twin to compare against.
        self.testObject = TestObj()
        self.originalObject = TestObj()
        self.monkeyPatcher = MonkeyPatcher()

    def test_empty(self) -> None:
        """
        A monkey patcher without patches shouldn't change a thing.
        """
        self.monkeyPatcher.patch()
        # We can't assert that all state is unchanged, but at least we can
        # check our test object.
        self.assertEqual(self.originalObject.foo, self.testObject.foo)
        self.assertEqual(self.originalObject.bar, self.testObject.bar)
        self.assertEqual(self.originalObject.baz, self.testObject.baz)

    def test_constructWithPatches(self) -> None:
        """
        Constructing a L{MonkeyPatcher} with patches should add all of the
        given patches to the patch list.
        """
        patcher = MonkeyPatcher(
            (self.testObject, "foo", "haha"), (self.testObject, "bar", "hehe")
        )
        patcher.patch()
        self.assertEqual("haha", self.testObject.foo)
        self.assertEqual("hehe", self.testObject.bar)
        self.assertEqual(self.originalObject.baz, self.testObject.baz)

    def test_patchExisting(self) -> None:
        """
        Patching an attribute that exists sets it to the value defined in the
        patch.
        """
        self.monkeyPatcher.addPatch(self.testObject, "foo", "haha")
        self.monkeyPatcher.patch()
        self.assertEqual(self.testObject.foo, "haha")

    def test_patchNonExisting(self) -> None:
        """
        Patching a non-existing attribute fails with an C{AttributeError}.
        """
        self.monkeyPatcher.addPatch(self.testObject, "nowhere", "blow up please")
        self.assertRaises(AttributeError, self.monkeyPatcher.patch)

    def test_patchAlreadyPatched(self) -> None:
        """
        Adding a patch for an object and attribute that already have a patch
        overrides the existing patch.
        """
        self.monkeyPatcher.addPatch(self.testObject, "foo", "blah")
        self.monkeyPatcher.addPatch(self.testObject, "foo", "BLAH")
        self.monkeyPatcher.patch()
        self.assertEqual(self.testObject.foo, "BLAH")
        # restore() must bring back the pre-patch value, not "blah".
        self.monkeyPatcher.restore()
        self.assertEqual(self.testObject.foo, self.originalObject.foo)

    def test_restoreTwiceIsANoOp(self) -> None:
        """
        Restoring an already-restored monkey patch is a no-op.
        """
        self.monkeyPatcher.addPatch(self.testObject, "foo", "blah")
        self.monkeyPatcher.patch()
        self.monkeyPatcher.restore()
        self.assertEqual(self.testObject.foo, self.originalObject.foo)
        self.monkeyPatcher.restore()
        self.assertEqual(self.testObject.foo, self.originalObject.foo)

    def test_runWithPatchesDecoration(self) -> None:
        """
        runWithPatches should run the given callable, passing in all arguments
        and keyword arguments, and return the return value of the callable.
        """
        log: list[tuple[int, int, int | None]] = []

        def f(a: int, b: int, c: int | None = None) -> str:
            log.append((a, b, c))
            return "foo"

        result = self.monkeyPatcher.runWithPatches(f, 1, 2, c=10)
        self.assertEqual("foo", result)
        self.assertEqual([(1, 2, 10)], log)

    def test_repeatedRunWithPatches(self) -> None:
        """
        We should be able to call the same function with runWithPatches more
        than once. All patches should apply for each call.
        """

        def f() -> tuple[str, str, str]:
            return (self.testObject.foo, self.testObject.bar, self.testObject.baz)

        self.monkeyPatcher.addPatch(self.testObject, "foo", "haha")
        result = self.monkeyPatcher.runWithPatches(f)
        self.assertEqual(
            ("haha", self.originalObject.bar, self.originalObject.baz), result
        )
        result = self.monkeyPatcher.runWithPatches(f)
        self.assertEqual(
            ("haha", self.originalObject.bar, self.originalObject.baz), result
        )

    def test_runWithPatchesRestores(self) -> None:
        """
        C{runWithPatches} should restore the original values after the function
        has executed.
        """
        self.monkeyPatcher.addPatch(self.testObject, "foo", "haha")
        self.assertEqual(self.originalObject.foo, self.testObject.foo)
        self.monkeyPatcher.runWithPatches(lambda: None)
        self.assertEqual(self.originalObject.foo, self.testObject.foo)

    def test_runWithPatchesRestoresOnException(self) -> None:
        """
        Test runWithPatches restores the original values even when the function
        raises an exception.
        """

        def _() -> NoReturn:
            # Patches must be active while the callable runs...
            self.assertEqual(self.testObject.foo, "haha")
            self.assertEqual(self.testObject.bar, "blahblah")
            raise RuntimeError("Something went wrong!")

        self.monkeyPatcher.addPatch(self.testObject, "foo", "haha")
        self.monkeyPatcher.addPatch(self.testObject, "bar", "blahblah")

        self.assertRaises(RuntimeError, self.monkeyPatcher.runWithPatches, _)
        # ...and restored afterwards even though the callable raised.
        self.assertEqual(self.testObject.foo, self.originalObject.foo)
        self.assertEqual(self.testObject.bar, self.originalObject.bar)

    def test_contextManager(self) -> None:
        """
        L{MonkeyPatcher} is a context manager that applies its patches on
        entry and restore original values on exit.
        """
        self.monkeyPatcher.addPatch(self.testObject, "foo", "patched value")
        with self.monkeyPatcher:
            self.assertEqual(self.testObject.foo, "patched value")
        self.assertEqual(self.testObject.foo, self.originalObject.foo)

    def test_contextManagerPropagatesExceptions(self) -> None:
        """
        Exceptions propagate through the L{MonkeyPatcher} context-manager
        exit method.
        """
        with self.assertRaises(RuntimeError):
            with self.monkeyPatcher:
                raise RuntimeError("something")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,382 @@
# -*- Python -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
__version__ = "$Revision: 1.5 $"[11:-2]
from twisted.protocols import pcp
from twisted.trial import unittest
# Goal:
# Take a Protocol instance. Own all outgoing data - anything that
# would go to p.transport.write. Own all incoming data - anything
# that comes to p.dataReceived.
# I need:
# Something with the AbstractFileDescriptor interface.
# That is:
# - acts as a Transport
# - has a method write()
# - which buffers
# - acts as a Consumer
# - has a registerProducer, unregisterProducer
# - tells the Producer to back off (pauseProducing) when its buffer is full.
# - tells the Producer to resumeProducing when its buffer is not so full.
# - acts as a Producer
# - calls registerProducer
# - calls write() on consumers
# - honors requests to pause/resume producing
# - honors stopProducing, and passes it along to upstream Producers
class DummyTransport:
    """
    A dumb transport to wrap around: records every write so tests can
    inspect exactly what was sent.
    """

    def __init__(self):
        # Chronological record of write() payloads.  Tests reach into this
        # attribute directly, so its name is part of the fixture interface.
        self._writes = []

    def write(self, data):
        self._writes.append(data)

    def getvalue(self):
        """Return everything written so far, joined into one string."""
        return "".join(self._writes)
class DummyProducer:
    # Producer stand-in that records pause/resume/stop requests via flags.
    # Defaults live on the class; instances shadow them on first call.
    resumed = False
    stopped = False
    paused = False

    def __init__(self, consumer):
        self.consumer = consumer

    def resumeProducing(self):
        self.resumed = True
        # Resuming implicitly cancels a pause.
        self.paused = False

    def pauseProducing(self):
        self.paused = True

    def stopProducing(self):
        self.stopped = True
class DummyConsumer(DummyTransport):
    """
    Consumer stand-in that records producer registration and lifecycle
    calls, in addition to the writes recorded by L{DummyTransport}.
    """

    producer = None
    finished = False
    # Bug fix: this previously defaulted to True, which made assertions such
    # as assertTrue(consumer.unregistered) pass even when unregisterProducer()
    # was never called.  Starting at False makes the flag meaningful.
    unregistered = False

    def registerProducer(self, producer, streaming):
        self.producer = (producer, streaming)

    def unregisterProducer(self):
        self.unregistered = True

    def finish(self):
        self.finished = True
class TransportInterfaceTests(unittest.TestCase):
    # Smoke test: the proxy must at least accept write() like a transport.
    proxyClass = pcp.BasicProducerConsumerProxy

    def setUp(self):
        self.underlying = DummyConsumer()
        self.transport = self.proxyClass(self.underlying)

    def testWrite(self):
        # Must not raise; delivery semantics are covered elsewhere.
        self.transport.write("some bytes")
class ConsumerInterfaceTest:
    """Test ProducerConsumerProxy as a Consumer.

    Normally we have ProducingServer -> ConsumingTransport.

    If I am to go between (Server -> Shaper -> Transport), I have to
    play the role of Consumer convincingly for the ProducingServer.
    """

    # NOTE: proxyClass is supplied by the concrete TestCase subclasses below.

    def setUp(self):
        self.underlying = DummyConsumer()
        self.consumer = self.proxyClass(self.underlying)
        self.producer = DummyProducer(self.consumer)

    def testRegisterPush(self):
        self.consumer.registerProducer(self.producer, True)
        ## Consumer should NOT have called PushProducer.resumeProducing
        self.assertFalse(self.producer.resumed)

    ## If I'm just a proxy, should I only do resumeProducing when
    ## I get poked myself?
    # def testRegisterPull(self):
    #    self.consumer.registerProducer(self.producer, False)
    #    ## Consumer SHOULD have called PushProducer.resumeProducing
    #    self.assertTrue(self.producer.resumed)

    def testUnregister(self):
        self.consumer.registerProducer(self.producer, False)
        self.consumer.unregisterProducer()
        # Now when the consumer would ordinarily want more data, it
        # shouldn't ask producer for it.
        # The most succinct way to trigger "want more data" is to proxy for
        # a PullProducer and have someone ask me for data.
        self.producer.resumed = False
        self.consumer.resumeProducing()
        self.assertFalse(self.producer.resumed)

    def testFinish(self):
        self.consumer.registerProducer(self.producer, False)
        self.consumer.finish()
        # I guess finish should behave like unregister?
        self.producer.resumed = False
        self.consumer.resumeProducing()
        self.assertFalse(self.producer.resumed)
class ProducerInterfaceTest:
    """Test ProducerConsumerProxy as a Producer.

    Normally we have ProducingServer -> ConsumingTransport.

    If I am to go between (Server -> Shaper -> Transport), I have to
    play the role of Producer convincingly for the ConsumingTransport.
    """

    # NOTE: proxyClass is supplied by the concrete TestCase subclasses below.

    def setUp(self):
        self.consumer = DummyConsumer()
        self.producer = self.proxyClass(self.consumer)

    def testRegistersProducer(self):
        # Constructing the proxy must register it with its consumer.
        self.assertEqual(self.consumer.producer[0], self.producer)

    def testPause(self):
        self.producer.pauseProducing()
        self.producer.write("yakkity yak")
        self.assertFalse(
            self.consumer.getvalue(), "Paused producer should not have sent data."
        )

    def testResume(self):
        self.producer.pauseProducing()
        self.producer.resumeProducing()
        self.producer.write("yakkity yak")
        self.assertEqual(self.consumer.getvalue(), "yakkity yak")

    def testResumeNoEmptyWrite(self):
        # Resuming with an empty buffer must not produce a zero-length write.
        self.producer.pauseProducing()
        self.producer.resumeProducing()
        self.assertEqual(
            len(self.consumer._writes), 0, "Resume triggered an empty write."
        )

    def testResumeBuffer(self):
        # Data buffered while paused is flushed on resume.
        self.producer.pauseProducing()
        self.producer.write("buffer this")
        self.producer.resumeProducing()
        self.assertEqual(self.consumer.getvalue(), "buffer this")

    def testStop(self):
        self.producer.stopProducing()
        self.producer.write("yakkity yak")
        self.assertFalse(
            self.consumer.getvalue(), "Stopped producer should not have sent data."
        )
class PCP_ConsumerInterfaceTests(ConsumerInterfaceTest, unittest.TestCase):
    # Consumer-side conformance for the basic (unbuffered) proxy.
    proxyClass = pcp.BasicProducerConsumerProxy
class PCPII_ConsumerInterfaceTests(ConsumerInterfaceTest, unittest.TestCase):
    # Consumer-side conformance for the buffering proxy.
    proxyClass = pcp.ProducerConsumerProxy
class PCP_ProducerInterfaceTests(ProducerInterfaceTest, unittest.TestCase):
    # Producer-side conformance for the basic (unbuffered) proxy.
    proxyClass = pcp.BasicProducerConsumerProxy
class PCPII_ProducerInterfaceTests(ProducerInterfaceTest, unittest.TestCase):
    # Producer-side conformance for the buffering proxy.
    proxyClass = pcp.ProducerConsumerProxy
class ProducerProxyTests(unittest.TestCase):
    """Producer methods on me should be relayed to the Producer I proxy."""

    proxyClass = pcp.BasicProducerConsumerProxy

    def setUp(self):
        # No downstream consumer needed; we only observe the upstream relay.
        self.proxy = self.proxyClass(None)
        self.parentProducer = DummyProducer(self.proxy)
        self.proxy.registerProducer(self.parentProducer, True)

    def testStop(self):
        self.proxy.stopProducing()
        self.assertTrue(self.parentProducer.stopped)
class ConsumerProxyTests(unittest.TestCase):
    """Consumer methods on me should be relayed to the Consumer I proxy."""

    proxyClass = pcp.BasicProducerConsumerProxy

    def setUp(self):
        self.underlying = DummyConsumer()
        self.consumer = self.proxyClass(self.underlying)

    def testWrite(self):
        # NOTE: This test only valid for streaming (Push) systems.
        self.consumer.write("some bytes")
        self.assertEqual(self.underlying.getvalue(), "some bytes")

    def testFinish(self):
        self.consumer.finish()
        self.assertTrue(self.underlying.finished)

    def testUnregister(self):
        self.consumer.unregisterProducer()
        self.assertTrue(self.underlying.unregistered)
class PullProducerTest:
    # Mixin exercising non-streaming (pull) behavior; concrete subclasses
    # supply a proxyClass with iAmStreaming = False.
    def setUp(self):
        self.underlying = DummyConsumer()
        self.proxy = self.proxyClass(self.underlying)
        self.parentProducer = DummyProducer(self.proxy)
        self.proxy.registerProducer(self.parentProducer, True)

    def testHoldWrites(self):
        self.proxy.write("hello")
        # Consumer should get no data before it says resumeProducing.
        self.assertFalse(
            self.underlying.getvalue(), "Pulling Consumer got data before it pulled."
        )

    def testPull(self):
        self.proxy.write("hello")
        self.proxy.resumeProducing()
        self.assertEqual(self.underlying.getvalue(), "hello")

    def testMergeWrites(self):
        # Buffered chunks are coalesced into a single downstream write.
        self.proxy.write("hello ")
        self.proxy.write("sunshine")
        self.proxy.resumeProducing()
        nwrites = len(self.underlying._writes)
        self.assertEqual(
            nwrites, 1, "Pull resulted in %d writes instead " "of 1." % (nwrites,)
        )
        self.assertEqual(self.underlying.getvalue(), "hello sunshine")

    def testLateWrite(self):
        # consumer sends its initial pull before we have data
        self.proxy.resumeProducing()
        self.proxy.write("data")
        # This data should answer that pull request.
        self.assertEqual(self.underlying.getvalue(), "data")
class PCP_PullProducerTests(PullProducerTest, unittest.TestCase):
    # Basic proxy configured as a non-streaming (pull) producer.
    class proxyClass(pcp.BasicProducerConsumerProxy):
        iAmStreaming = False
class PCPII_PullProducerTests(PullProducerTest, unittest.TestCase):
    # Buffering proxy configured as a non-streaming (pull) producer.
    class proxyClass(pcp.ProducerConsumerProxy):
        iAmStreaming = False
# Buffering!
class BufferedConsumerTests(unittest.TestCase):
    """As a consumer, ask the producer to pause after too much data."""

    proxyClass = pcp.ProducerConsumerProxy

    def setUp(self):
        self.underlying = DummyConsumer()
        self.proxy = self.proxyClass(self.underlying)
        # Small buffer so the pause threshold is easy to cross in tests.
        self.proxy.bufferSize = 100

        self.parentProducer = DummyProducer(self.proxy)
        self.proxy.registerProducer(self.parentProducer, True)

    def testRegisterPull(self):
        self.proxy.registerProducer(self.parentProducer, False)
        ## Consumer SHOULD have called PushProducer.resumeProducing
        self.assertTrue(self.parentProducer.resumed)

    def testPauseIntercept(self):
        # A downstream pause is absorbed by the buffer, not relayed upstream.
        self.proxy.pauseProducing()
        self.assertFalse(self.parentProducer.paused)

    def testResumeIntercept(self):
        self.proxy.pauseProducing()
        self.proxy.resumeProducing()
        # With a streaming producer, just because the proxy was resumed is
        # not necessarily a reason to resume the parent producer. The state
        # of the buffer should decide that.
        self.assertFalse(self.parentProducer.resumed)

    def testTriggerPause(self):
        """Make sure I say \"when.\" """

        # Pause the proxy so data sent to it builds up in its buffer.
        self.proxy.pauseProducing()
        self.assertFalse(self.parentProducer.paused, "don't pause yet")
        self.proxy.write("x" * 51)
        self.assertFalse(self.parentProducer.paused, "don't pause yet")
        self.proxy.write("x" * 51)
        # 102 bytes buffered > bufferSize of 100: upstream must be paused now.
        self.assertTrue(self.parentProducer.paused)

    def testTriggerResume(self):
        """Make sure I resumeProducing when my buffer empties."""
        self.proxy.pauseProducing()
        self.proxy.write("x" * 102)
        self.assertTrue(self.parentProducer.paused, "should be paused")
        self.proxy.resumeProducing()
        # Resuming should have emptied my buffer, so I should tell my
        # parent to resume too.
        self.assertFalse(self.parentProducer.paused, "Producer should have resumed.")
        self.assertFalse(self.proxy.producerPaused)
class BufferedPullTests(unittest.TestCase):
    class proxyClass(pcp.ProducerConsumerProxy):
        iAmStreaming = False

        def _writeSomeData(self, data):
            # Accept at most 100 bytes per pull so leftovers stay buffered.
            pcp.ProducerConsumerProxy._writeSomeData(self, data[:100])
            return min(len(data), 100)

    def setUp(self):
        self.underlying = DummyConsumer()
        self.proxy = self.proxyClass(self.underlying)
        self.proxy.bufferSize = 100

        self.parentProducer = DummyProducer(self.proxy)
        self.proxy.registerProducer(self.parentProducer, False)

    def testResumePull(self):
        # If proxy has no data to send on resumeProducing, it had better pull
        # some from its PullProducer.
        self.parentProducer.resumed = False
        self.proxy.resumeProducing()
        self.assertTrue(self.parentProducer.resumed)

    def testLateWriteBuffering(self):
        # consumer sends its initial pull before we have data
        self.proxy.resumeProducing()
        self.proxy.write("datum" * 21)
        # This data should answer that pull request.
        self.assertEqual(self.underlying.getvalue(), "datum" * 20)
        # but there should be some left over
        self.assertEqual(self.proxy._buffer, ["datum"])
# TODO:
# Test the web-request-finishing bug: when we proxied finish but did not
# proxy unregisterProducer, web file transfers would hang on the last block.
# Test what happens if writeSomeBytes decides to write zero bytes.

View File

@@ -0,0 +1,536 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import annotations
# System Imports
import copyreg
import io
import pickle
import sys
import textwrap
from typing import Any, Callable, List, Tuple
from typing_extensions import NoReturn
# Twisted Imports
from twisted.persisted import aot, crefutil, styles
from twisted.trial import unittest
from twisted.trial.unittest import TestCase
class VersionTests(TestCase):
    """
    Tests for the L{styles.Versioned} upgrade-on-unpickle machinery.

    Classes under test are declared C{global} so that pickle can resolve
    them by qualified name when loading.
    """

    def test_nullVersionUpgrade(self) -> None:
        """
        An instance pickled before its class became L{styles.Versioned} is
        still upgraded to version 1 when loaded and L{styles.doUpgrade} runs.
        """
        global NullVersioned

        class NullVersioned:
            def __init__(self) -> None:
                self.ok = 0

        pkcl = pickle.dumps(NullVersioned())

        class NullVersioned(styles.Versioned):  # type: ignore[no-redef]
            persistenceVersion = 1

            def upgradeToVersion1(self) -> None:
                self.ok = 1

        mnv = pickle.loads(pkcl)
        styles.doUpgrade()
        assert mnv.ok, "initial upgrade not run!"

    def test_versionUpgrade(self) -> None:
        """
        Each C{upgradeToVersionN} hook runs exactly once per instance, even
        across repeated pickle round-trips.
        """
        global MyVersioned

        class MyVersioned(styles.Versioned):
            persistenceVersion = 2
            # persistenceForgets should be a tuple
            persistenceForgets = ["garbagedata"]  # type: ignore[assignment]
            v3 = 0
            v4 = 0

            def __init__(self) -> None:
                self.somedata = "xxx"
                self.garbagedata = lambda q: "cant persist"

            def upgradeToVersion3(self) -> None:
                self.v3 += 1

            def upgradeToVersion4(self) -> None:
                self.v4 += 1

        mv = MyVersioned()
        assert not (mv.v3 or mv.v4), "hasn't been upgraded yet"
        pickl = pickle.dumps(mv)
        MyVersioned.persistenceVersion = 4
        obj = pickle.loads(pickl)
        styles.doUpgrade()
        assert obj.v3, "didn't do version 3 upgrade"
        assert obj.v4, "didn't do version 4 upgrade"
        # A second round-trip must not re-run the upgrade hooks.
        pickl = pickle.dumps(obj)
        obj = pickle.loads(pickl)
        styles.doUpgrade()
        assert obj.v3 == 1, "upgraded unnecessarily"
        assert obj.v4 == 1, "upgraded unnecessarily"

    def test_nonIdentityHash(self) -> None:
        """
        Upgrades still apply to every instance even when instances have
        colliding custom hashes.
        """
        global ClassWithCustomHash

        class ClassWithCustomHash(styles.Versioned):
            def __init__(self, unique: str, hash: int) -> None:
                self.unique = unique
                self.hash = hash

            def __hash__(self) -> int:
                return self.hash

        v1 = ClassWithCustomHash("v1", 0)
        v2 = ClassWithCustomHash("v2", 0)

        pkl = pickle.dumps((v1, v2))
        del v1, v2
        ClassWithCustomHash.persistenceVersion = 1
        ClassWithCustomHash.upgradeToVersion1 = lambda self: setattr(  # type: ignore[attr-defined]
            self, "upgraded", True
        )
        v1, v2 = pickle.loads(pkl)
        styles.doUpgrade()
        self.assertEqual(v1.unique, "v1")
        self.assertEqual(v2.unique, "v2")
        self.assertTrue(v1.upgraded)  # type: ignore[attr-defined]
        self.assertTrue(v2.upgraded)  # type: ignore[attr-defined]

    # NOTE: method name restored here; an automated redaction pass had
    # mangled it into an invalid identifier.
    def test_upgradeDeserializesObjectsRequiringUpgrade(self) -> None:
        """
        An upgrade hook may itself unpickle other L{styles.Versioned}
        objects; those nested objects are upgraded too.
        """
        global ToyClassA, ToyClassB

        class ToyClassA(styles.Versioned):
            pass

        class ToyClassB(styles.Versioned):
            pass

        x = ToyClassA()
        y = ToyClassB()
        pklA, pklB = pickle.dumps(x), pickle.dumps(y)
        del x, y
        ToyClassA.persistenceVersion = 1

        def upgradeToVersion1(self: Any) -> None:
            self.y = pickle.loads(pklB)
            styles.doUpgrade()

        ToyClassA.upgradeToVersion1 = upgradeToVersion1  # type: ignore[attr-defined]
        ToyClassB.persistenceVersion = 1

        def setUpgraded(self: object) -> None:
            setattr(self, "upgraded", True)

        ToyClassB.upgradeToVersion1 = setUpgraded  # type: ignore[attr-defined]

        x = pickle.loads(pklA)
        styles.doUpgrade()
        self.assertTrue(x.y.upgraded)  # type: ignore[attr-defined]
class VersionedSubClass(styles.Versioned):
    """
    A direct L{styles.Versioned} subclass; a fixture for L{AybabtuTests}.
    """

    pass
class SecondVersionedSubClass(styles.Versioned):
    """
    Another direct L{styles.Versioned} subclass, used to build the diamond
    inheritance graph exercised by L{AybabtuTests}.
    """

    pass
class VersionedSubSubClass(VersionedSubClass):
    """
    A second-level subclass of L{styles.Versioned}; a fixture for
    L{AybabtuTests}.
    """

    pass
class VersionedDiamondSubClass(VersionedSubSubClass, SecondVersionedSubClass):
    """
    A diamond-shaped inheritance graph rooted (twice) at
    L{styles.Versioned}; a fixture for L{AybabtuTests}.
    """

    pass
class AybabtuTests(TestCase):
    """
    L{styles._aybabtu} gets all of the classes in the inheritance hierarchy
    of its argument that are strictly between L{Versioned} and the class
    itself.
    """

    def test_aybabtuStrictEmpty(self) -> None:
        """
        L{styles._aybabtu} of L{Versioned} itself is an empty list.
        """
        self.assertEqual(styles._aybabtu(styles.Versioned), [])

    def test_aybabtuStrictSubclass(self) -> None:
        """
        There are no classes I{between} L{VersionedSubClass} and L{Versioned},
        so L{styles._aybabtu} returns an empty list.
        """
        self.assertEqual(styles._aybabtu(VersionedSubClass), [])

    def test_aybabtuSubsubclass(self) -> None:
        """
        With a sub-sub-class of L{Versioned}, L{styles._aybabtu} returns a
        list containing the intervening subclass.
        """
        self.assertEqual(styles._aybabtu(VersionedSubSubClass), [VersionedSubClass])

    def test_aybabtuStrict(self) -> None:
        """
        For a diamond-shaped inheritance graph, L{styles._aybabtu} returns a
        list containing I{both} intermediate subclasses.
        """
        self.assertEqual(
            styles._aybabtu(VersionedDiamondSubClass),
            [VersionedSubSubClass, VersionedSubClass, SecondVersionedSubClass],
        )
class MyEphemeral(styles.Ephemeral):
    """
    An L{styles.Ephemeral} subclass carrying some state; used by
    L{EphemeralTests} to show that the state does not survive pickling.
    """

    def __init__(self, x: int) -> None:
        self.x = x
class EphemeralTests(TestCase):
    """
    Tests for L{styles.Ephemeral} pickling behavior.
    """

    def test_ephemeral(self) -> None:
        """
        Pickling an L{styles.Ephemeral} subclass round-trips to a bare
        L{styles.Ephemeral} instance carrying none of the original state.
        """
        o = MyEphemeral(3)
        self.assertEqual(o.__class__, MyEphemeral)
        self.assertEqual(o.x, 3)

        pickl = pickle.dumps(o)
        o = pickle.loads(pickl)

        # The subclass and its attributes are deliberately not persisted.
        self.assertEqual(o.__class__, styles.Ephemeral)
        self.assertFalse(hasattr(o, "x"))
class Pickleable:
    """
    A trivially picklable value holder.

    Its bound method C{getX} is itself picklable (by reference to the
    instance and the method name), which the pickling tests rely on.
    """

    def __init__(self, x: int) -> None:
        self.x = x

    def getX(self) -> int:
        """
        Return the value supplied at construction time.
        """
        return self.x
class NotPickleable:
    """
    A class that refuses to be pickled.

    Serialization is blocked by raising from C{__reduce__}, which pickle
    consults before falling back to default instance handling.
    """

    def __reduce__(self) -> NoReturn:
        """
        Always raise rather than produce a reduction tuple.
        """
        raise TypeError("Not serializable.")
class CopyRegistered:
    """
    A class that is pickleable only because it is registered with the
    C{copyreg} module (see the C{copyreg.pickle} call below).
    """

    def __init__(self) -> None:
        """
        Ensure that this object is normally not pickleable by storing an
        attribute whose C{__reduce__} raises.
        """
        self.notPickleable = NotPickleable()
class CopyRegisteredLoaded:
    """
    L{CopyRegistered} after unserialization: the copyreg-registered reducer
    substitutes this type when a L{CopyRegistered} is loaded.
    """
def reduceCopyRegistered(cr: object) -> tuple[type[CopyRegisteredLoaded], tuple[()]]:
    """
    Externally implement C{__reduce__} for L{CopyRegistered}.

    @param cr: The L{CopyRegistered} instance.

    @return: a 2-tuple of callable and argument list, in this case
        L{CopyRegisteredLoaded} and no arguments.
    """
    # ``cr`` itself is ignored: unpickling always yields a fresh
    # CopyRegisteredLoaded.
    return CopyRegisteredLoaded, ()
# Register the external reducer so that pickle (and anything honoring the
# copy registry) can serialize CopyRegistered despite its unpicklable
# attribute.
copyreg.pickle(CopyRegistered, reduceCopyRegistered)  # type: ignore[arg-type]
class A:
    """
    dummy class
    """

    # Assigned per-instance in the AOT tests: a bound method of a B
    # instance, so jellying must preserve method identity.
    bmethod: Callable[[], None]

    def amethod(self) -> None:
        pass
class B:
    """
    dummy class
    """

    # Assigned per-instance in the AOT tests to create a cross-reference
    # between an A and a B instance.
    a: A

    def bmethod(self) -> None:
        pass
def funktion() -> None:
    """
    A module-level function; referenced by name from the AOT round-trip
    tests (functions jelly by reference, not by value).
    """
    pass
class PicklingTests(TestCase):
    """Test pickling of extra object types."""

    def test_module(self) -> None:
        """
        A module object round-trips through pickle (by name).
        """
        pickl = pickle.dumps(styles)
        o = pickle.loads(pickl)
        self.assertEqual(o, styles)

    def test_instanceMethod(self) -> None:
        """
        A bound method pickles and unpickles to an equivalent bound method
        of the same type.
        """
        obj = Pickleable(4)
        pickl = pickle.dumps(obj.getX)
        o = pickle.loads(pickl)
        self.assertEqual(o(), 4)
        self.assertEqual(type(o), type(obj.getX))
class StringIOTransitionTests(TestCase):
    """
    When pickling a cStringIO in Python 2, it should unpickle as a BytesIO or a
    StringIO in Python 3, depending on the type of its contents.
    """

    def test_unpickleBytesIO(self) -> None:
        """
        A cStringIO pickled on Python 2 whose payload unpickles as native
        text (here C{'test'}) yields an L{io.StringIO} on Python 3, with the
        contents preserved.
        """
        # Protocol-0 pickle of twisted.persisted.styles.unpickleStringI with
        # a str payload, as Python 2 would have produced it.
        pickledStringIWithText = (
            b"ctwisted.persisted.styles\nunpickleStringI\np0\n"
            b"(S'test'\np1\nI0\ntp2\nRp3\n."
        )
        loaded = pickle.loads(pickledStringIWithText)
        self.assertIsInstance(loaded, io.StringIO)
        self.assertEqual(loaded.getvalue(), "test")
class EvilSourceror:
    """
    A deliberately self-referential object: both C{a} and C{b} alias the
    instance itself, and C{c} holds the constructor argument. Used to
    exercise circular-reference handling in serializers.
    """

    a: "EvilSourceror"
    b: "EvilSourceror"
    c: object

    def __init__(self, x: object) -> None:
        # The original wrote ``self.a.b = self`` and ``self.a.b.c = x``;
        # since ``self.a`` is ``self``, these direct assignments are
        # equivalent.
        self.a = self
        self.b = self
        self.c = x
class NonDictState:
    """
    An object whose pickle state is a bare string rather than the usual
    instance C{__dict__}.
    """

    state: str

    def __getstate__(self) -> str:
        """
        Serialize as just the C{state} string.
        """
        return self.state

    def __setstate__(self, state: str) -> None:
        """
        Restore from the bare string produced by C{__getstate__}.
        """
        self.state = state
# A list whose elements are tuples that refer back to the list itself; used
# to type the circular-reference fixture in AOTTests.test_circularTuple.
_CircularTupleType = List[Tuple["_CircularTupleType", int]]
class AOTTests(TestCase):
    """
    Tests for abstract-object-tree serialization (L{twisted.persisted.aot}):
    jellying objects to Python source and unjellying them back.
    """

    def test_simpleTypes(self) -> None:
        """
        Builtin scalar types round-trip through jelly/unjelly unchanged.
        """
        obj = (
            1,
            2.0,
            3j,
            True,
            slice(1, 2, 3),
            "hello",
            "world",
            sys.maxsize + 1,
            None,
            Ellipsis,
        )
        rtObj = aot.unjellyFromSource(aot.jellyToSource(obj))
        self.assertEqual(obj, rtObj)

    def test_methodSelfIdentity(self) -> None:
        """
        Unjellying a bound method yields a method whose C{self} is an
        instance of the method's own class.
        """
        a = A()
        b = B()
        a.bmethod = b.bmethod
        b.a = a
        im_ = aot.unjellyFromSource(aot.jellyToSource(b)).a.bmethod
        self.assertEqual(aot._selfOfMethod(im_).__class__, aot._classOfMethod(im_))

    def test_methodNotSelfIdentity(self) -> None:
        """
        If a class changes after an instance has been created,
        L{aot.unjellyFromSource} should raise a C{TypeError} when trying to
        unjelly the instance.
        """
        a = A()
        b = B()
        a.bmethod = b.bmethod
        b.a = a
        savedbmethod = B.bmethod
        del B.bmethod
        try:
            self.assertRaises(TypeError, aot.unjellyFromSource, aot.jellyToSource(b))
        finally:
            B.bmethod = savedbmethod  # type: ignore[method-assign]

    def test_unsupportedType(self) -> None:
        """
        L{aot.jellyToSource} should raise a C{TypeError} when trying to jelly
        an unknown type without a C{__dict__} property or C{__getstate__}
        method.
        """

        class UnknownType:
            @property
            def __dict__(self) -> NoReturn:  # type: ignore[override]
                raise AttributeError()

            @property
            def __getstate__(self) -> NoReturn:
                raise AttributeError()

        self.assertRaises(TypeError, aot.jellyToSource, UnknownType())

    def test_basicIdentity(self) -> None:
        """
        A nested structure with repeated and self-references survives a
        source round-trip with the identity of shared parts preserved.
        """
        # Anyone wanting to make this datastructure more complex, and thus this
        # test more comprehensive, is welcome to do so.
        aj = aot.AOTJellier().jellyToAO
        d = {"hello": "world", "method": aj}
        l = [
            1,
            2,
            3,
            'he\tllo\n\n"x world!',
            "goodbye \n\t\u1010 world!",
            1,
            1.0,
            100**100,
            unittest,
            aot.AOTJellier,
            d,
            funktion,
        ]
        t = tuple(l)
        l.append(l)
        l.append(t)
        l.append(t)
        uj = aot.unjellyFromSource(aot.jellyToSource([l, l]))
        assert uj[0] is uj[1]
        assert uj[1][0:5] == l[0:5]

    def test_nonDictState(self) -> None:
        """
        Objects whose state is a bare string (via __getstate__/__setstate__)
        round-trip through source.
        """
        a = NonDictState()
        a.state = "meringue!"
        assert aot.unjellyFromSource(aot.jellyToSource(a)).state == a.state

    def test_copyReg(self) -> None:
        """
        L{aot.jellyToSource} and L{aot.unjellyFromSource} honor functions
        registered in the pickle copy registry.
        """
        uj = aot.unjellyFromSource(aot.jellyToSource(CopyRegistered()))
        self.assertIsInstance(uj, CopyRegisteredLoaded)

    def test_funkyReferences(self) -> None:
        """
        Self-referential object graphs are preserved through an AOT
        round-trip.
        """
        o = EvilSourceror(EvilSourceror([]))
        j1 = aot.jellyToAOT(o)
        oj = aot.unjellyFromAOT(j1)

        assert oj.a is oj
        assert oj.a.b is oj.b
        assert oj.c is not oj.c.c

    def test_circularTuple(self) -> None:
        """
        L{aot.jellyToAOT} can persist circular references through tuples.
        """
        l: _CircularTupleType = []
        t = (l, 4321)
        l.append(t)
        j1 = aot.jellyToAOT(l)
        oj = aot.unjellyFromAOT(j1)
        self.assertIsInstance(oj[0], tuple)
        self.assertIs(oj[0][0], oj)
        self.assertEqual(oj[0][1], 4321)

    def testIndentify(self) -> None:
        """
        The generated serialization is indented.
        """
        self.assertEqual(
            aot.jellyToSource({"hello": {"world": []}}),
            textwrap.dedent(
                """\
                app={
                  'hello':{
                    'world':[],
                    },
                  }""",
            ),
        )
class CrefUtilTests(TestCase):
    """
    Tests for L{crefutil}.
    """

    def test_dictUnknownKey(self) -> None:
        """
        L{crefutil._DictKeyAndValue} only support keys C{0} and C{1}.
        """
        d = crefutil._DictKeyAndValue({})
        self.assertRaises(RuntimeError, d.__setitem__, 2, 3)

    def test_deferSetMultipleTimes(self) -> None:
        """
        L{crefutil._Defer} can be assigned a key only one time.
        """
        d = crefutil._Defer()
        d[0] = 1
        self.assertRaises(RuntimeError, d.__setitem__, 0, 1)

    def test_containerWhereAllElementsAreKnown(self) -> None:
        """
        A L{crefutil._Container} where all of its elements are known at
        construction time is nonsensical and will result in errors in any call
        to addDependant.
        """
        container = crefutil._Container([1, 2, 3], list)
        self.assertRaises(AssertionError, container.addDependant, {}, "ignore-me")

    # NOTE: method name restored here; an automated redaction pass had
    # mangled it into an invalid identifier.
    def test_dontPutCircularReferencesInDictKeys(self) -> None:
        """
        If a dictionary key contains a circular reference (which is probably a
        bad practice anyway) it will be resolved by a
        L{crefutil._DictKeyAndValue}, not by placing a L{crefutil.NotKnown}
        into a dictionary key.
        """
        self.assertRaises(
            AssertionError, dict().__setitem__, crefutil.NotKnown(), "value"
        )

    def test_dontCallInstanceMethodsThatArentReady(self) -> None:
        """
        L{crefutil._InstanceMethod} raises L{AssertionError} to indicate it
        should not be called. This should not be possible with any of its API
        clients, but is provided for helping to debug.
        """
        self.assertRaises(
            AssertionError,
            crefutil._InstanceMethod("no_name", crefutil.NotKnown(), type),
        )
# NOTE(review): appears to be a legacy explicit suite list predating trial's
# automatic discovery — confirm whether anything still consumes it.
testCases = [VersionTests, EphemeralTests, PicklingTests]

View File

@@ -0,0 +1,744 @@
# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for Twisted plugin system.
"""
from __future__ import annotations
import compileall
import errno
import functools
import os
import sys
import time
from importlib import invalidate_caches as invalidateImportCaches
from types import ModuleType
from typing import Callable, TypedDict, TypeVar
from zope.interface import Interface
from twisted import plugin
from twisted.python.filepath import FilePath
from twisted.python.log import EventDict, addObserver, removeObserver, textFromEventDict
from twisted.trial import unittest
_T = TypeVar("_T")
class ITestPlugin(Interface):
    """
    A plugin interface used only by the plugin system's unit tests.

    Do not use this.
    """
class ITestPlugin2(Interface):
    """
    A second test-only plugin interface, distinct from L{ITestPlugin} so the
    tests can select plugins by interface.

    See L{ITestPlugin}.
    """
class PluginTests(unittest.TestCase):
    """
    Tests which verify the behavior of the current, active Twisted plugins
    directory.
    """

    def setUp(self) -> None:
        """
        Save C{sys.path} and C{sys.modules}, and create a package for tests.
        """
        self.originalPath = sys.path[:]
        self.savedModules = sys.modules.copy()

        self.root = FilePath(self.mktemp())
        self.root.createDirectory()
        self.package = self.root.child("mypackage")
        self.package.createDirectory()
        self.package.child("__init__.py").setContent(b"")

        FilePath(__file__).sibling("plugin_basic.py").copyTo(
            self.package.child("testplugin.py")
        )

        self.originalPlugin = "testplugin"

        sys.path.insert(0, self.root.path)
        import mypackage  # type: ignore[import-not-found]

        self.module = mypackage

    def tearDown(self) -> None:
        """
        Restore C{sys.path} and C{sys.modules} to their original values.
        """
        sys.path[:] = self.originalPath
        sys.modules.clear()
        sys.modules.update(self.savedModules)

    def _unimportPythonModule(
        self, module: ModuleType, deleteSource: bool = False
    ) -> None:
        """
        Remove C{module} from C{sys.modules} and from its parent package,
        deleting its compiled artifacts (and, optionally, its source) so a
        later import starts from scratch.
        """
        assert module.__file__ is not None
        modulePath = module.__name__.split(".")
        packageName = ".".join(modulePath[:-1])
        moduleName = modulePath[-1]

        delattr(sys.modules[packageName], moduleName)
        del sys.modules[module.__name__]
        # "c"/"o" cover legacy .pyc/.pyo files alongside the source.
        for ext in ["c", "o"] + (deleteSource and [""] or []):
            try:
                os.remove(module.__file__ + ext)
            except FileNotFoundError:
                pass

    def _clearCache(self) -> None:
        """
        Remove the plugins B{dropin.cache} file.
        """
        self.package.child("dropin.cache").remove()

    def _withCacheness(
        meth: Callable[[PluginTests], object]
    ) -> Callable[[PluginTests], None]:
        """
        This is a paranoid test wrapper, that calls C{meth} 2 times, clears the
        cache, and calls it 2 other times. It's supposed to ensure that the
        plugin system behaves correctly no matter what the state of the cache
        is.
        """

        @functools.wraps(meth)
        def wrapped(self: PluginTests) -> None:
            meth(self)
            meth(self)
            self._clearCache()
            meth(self)
            meth(self)

        return wrapped

    @_withCacheness
    def test_cache(self) -> None:
        """
        Check that the cache returned by L{plugin.getCache} holds the plugin
        B{testplugin}, and that this plugin has the properties we expect:
        provide L{TestPlugin}, has the good name and description, and can be
        loaded successfully.
        """
        cache = plugin.getCache(self.module)

        dropin = cache[self.originalPlugin]
        self.assertEqual(dropin.moduleName, f"mypackage.{self.originalPlugin}")
        self.assertIn("I'm a test drop-in.", dropin.description)

        # Note, not the preferred way to get a plugin by its interface.
        p1 = [p for p in dropin.plugins if ITestPlugin in p.provided][0]
        self.assertIs(p1.dropin, dropin)
        self.assertEqual(p1.name, "TestPlugin")

        # Check the content of the description comes from the plugin module
        # docstring
        self.assertEqual(
            p1.description.strip(), "A plugin used solely for testing purposes."
        )
        self.assertEqual(p1.provided, [ITestPlugin, plugin.IPlugin])
        realPlugin = p1.load()
        # The plugin should match the class present in sys.modules
        self.assertIs(
            realPlugin,
            sys.modules[f"mypackage.{self.originalPlugin}"].TestPlugin,
        )

        # And it should also match if we import it classically
        import mypackage.testplugin as tp  # type: ignore[import-not-found]

        self.assertIs(realPlugin, tp.TestPlugin)

    def test_cacheRepr(self) -> None:
        """
        L{CachedPlugin} has a helpful C{repr} which contains relevant
        information about it.
        """
        cachedDropin = plugin.getCache(self.module)[self.originalPlugin]
        cachedPlugin = list(p for p in cachedDropin.plugins if p.name == "TestPlugin")[
            0
        ]
        self.assertEqual(
            repr(cachedPlugin),
            "<CachedPlugin 'TestPlugin'/'mypackage.testplugin' "
            "(provides 'ITestPlugin, IPlugin')>",
        )

    @_withCacheness
    def test_plugins(self) -> None:
        """
        L{plugin.getPlugins} should return the list of plugins matching the
        specified interface (here, L{ITestPlugin2}), and these plugins
        should be instances of classes with a C{test} method, to be sure
        L{plugin.getPlugins} load classes correctly.
        """
        plugins = list(plugin.getPlugins(ITestPlugin2, self.module))

        self.assertEqual(len(plugins), 2)

        names = ["AnotherTestPlugin", "ThirdTestPlugin"]
        for p in plugins:
            names.remove(p.__name__)  # type: ignore[attr-defined]
            p.test()  # type: ignore[attr-defined]

    @_withCacheness
    def test_detectNewFiles(self) -> None:
        """
        Check that L{plugin.getPlugins} is able to detect plugins added at
        runtime.
        """
        FilePath(__file__).sibling("plugin_extra1.py").copyTo(
            self.package.child("pluginextra.py")
        )
        try:
            # Check that the current situation is clean
            self.failIfIn("mypackage.pluginextra", sys.modules)
            self.assertFalse(
                hasattr(sys.modules["mypackage"], "pluginextra"),
                "mypackage still has pluginextra module",
            )

            plgs = list(plugin.getPlugins(ITestPlugin, self.module))

            # We should find 2 plugins: the one in testplugin, and the one in
            # pluginextra
            self.assertEqual(len(plgs), 2)

            names = ["TestPlugin", "FourthTestPlugin"]
            for p in plgs:
                names.remove(p.__name__)  # type: ignore[attr-defined]
                p.test1()  # type: ignore[attr-defined]
        finally:
            self._unimportPythonModule(sys.modules["mypackage.pluginextra"], True)

    @_withCacheness
    def test_detectFilesChanged(self) -> None:
        """
        Check that if the content of a plugin change, L{plugin.getPlugins} is
        able to detect the new plugins added.
        """
        FilePath(__file__).sibling("plugin_extra1.py").copyTo(
            self.package.child("pluginextra.py")
        )
        try:
            plgs = list(plugin.getPlugins(ITestPlugin, self.module))
            # Sanity check
            self.assertEqual(len(plgs), 2)

            FilePath(__file__).sibling("plugin_extra2.py").copyTo(
                self.package.child("pluginextra.py")
            )

            # Fake out Python.
            self._unimportPythonModule(sys.modules["mypackage.pluginextra"])

            # Make sure additions are noticed
            plgs = list(plugin.getPlugins(ITestPlugin, self.module))

            self.assertEqual(len(plgs), 3)

            names = ["TestPlugin", "FourthTestPlugin", "FifthTestPlugin"]
            for p in plgs:
                names.remove(p.__name__)  # type: ignore[attr-defined]
                p.test1()  # type: ignore[attr-defined]
        finally:
            self._unimportPythonModule(sys.modules["mypackage.pluginextra"], True)

    @_withCacheness
    def test_detectFilesRemoved(self) -> None:
        """
        Check that when a dropin file is removed, L{plugin.getPlugins} doesn't
        return it anymore.
        """
        FilePath(__file__).sibling("plugin_extra1.py").copyTo(
            self.package.child("pluginextra.py")
        )
        try:
            # Generate a cache with pluginextra in it.
            list(plugin.getPlugins(ITestPlugin, self.module))

        finally:
            self._unimportPythonModule(sys.modules["mypackage.pluginextra"], True)
        plgs = list(plugin.getPlugins(ITestPlugin, self.module))
        self.assertEqual(1, len(plgs))

    @_withCacheness
    def test_nonexistentPathEntry(self) -> None:
        """
        Test that getCache skips over any entries in a plugin package's
        C{__path__} which do not exist.
        """
        path = self.mktemp()
        self.assertFalse(os.path.exists(path))
        # Add the test directory to the plugins path
        self.module.__path__.append(path)
        try:
            plgs = list(plugin.getPlugins(ITestPlugin, self.module))
            self.assertEqual(len(plgs), 1)
        finally:
            self.module.__path__.remove(path)

    @_withCacheness
    def test_nonDirectoryChildEntry(self) -> None:
        """
        Test that getCache skips over any entries in a plugin package's
        C{__path__} which refer to children of paths which are not directories.
        """
        path = FilePath(self.mktemp())
        self.assertFalse(path.exists())
        path.touch()
        child = path.child("test_package").path
        self.module.__path__.append(child)
        try:
            plgs = list(plugin.getPlugins(ITestPlugin, self.module))
            self.assertEqual(len(plgs), 1)
        finally:
            self.module.__path__.remove(child)

    def test_deployedMode(self) -> None:
        """
        The C{dropin.cache} file may not be writable: the cache should still be
        attainable, but an error should be logged to show that the cache
        couldn't be updated.
        """
        # Generate the cache
        plugin.getCache(self.module)

        cachepath = self.package.child("dropin.cache")

        # Add a new plugin
        FilePath(__file__).sibling("plugin_extra1.py").copyTo(
            self.package.child("pluginextra.py")
        )

        invalidateImportCaches()
        os.chmod(self.package.path, 0o500)
        # Change the right of dropin.cache too for windows
        os.chmod(cachepath.path, 0o400)
        self.addCleanup(os.chmod, self.package.path, 0o700)
        self.addCleanup(os.chmod, cachepath.path, 0o700)

        # Start observing log events to see the warning
        events: list[EventDict] = []
        addObserver(events.append)
        self.addCleanup(removeObserver, events.append)

        cache = plugin.getCache(self.module)
        # The new plugin should be reported
        self.assertIn("pluginextra", cache)
        self.assertIn(self.originalPlugin, cache)

        # Make sure something was logged about the cache.
        expected = "Unable to write to plugin cache %s: error number %d" % (
            cachepath.path,
            errno.EPERM,
        )
        for event in events:
            maybeText = textFromEventDict(event)
            assert maybeText is not None
            if expected in maybeText:  # pragma: no branch
                break
        else:  # pragma: no cover
            self.fail(
                "Did not observe unwriteable cache warning in log "
                "events: %r" % (events,)
            )
# This is something like the Twisted plugins file: it is written verbatim
# into the dummy plugin packages' plugins/__init__.py so that they extend
# their __path__ the same way twisted.plugins does.
pluginInitFile = b"""
from twisted.plugin import pluginPackagePaths
__path__.extend(pluginPackagePaths(__name__))
__all__ = []
"""
def pluginFileContents(name: str) -> bytes:
    """
    Return the ASCII-encoded source of a plugin module defining a single
    class named C{name} which provides L{IPlugin} and L{ITestPlugin}.
    """
    template = (
        "from zope.interface import provider\n"
        "from twisted.plugin import IPlugin\n"
        "from twisted.test.test_plugin import ITestPlugin\n"
        "\n"
        "@provider(IPlugin, ITestPlugin)\n"
        "class {}:\n"
        "    pass\n"
    )
    return template.format(name).encode("ascii")
def _createPluginDummy(
    entrypath: FilePath[str], pluginContent: bytes, real: bool, pluginModule: str
) -> FilePath[str]:
    """
    Create a plugindummy package.

    @param entrypath: the directory (created here) that will hold the
        package; intended to be appended to C{sys.path}.
    @param pluginContent: the source to write into the plugin module.
    @param real: if true, write C{__init__.py} files so the package and its
        C{plugins} subpackage are importable; if false, leave them out.
    @param pluginModule: basename (without C{.py}) of the plugin module to
        create inside the C{plugins} directory.

    @return: the C{plugins} directory inside the new package.
    """
    entrypath.createDirectory()
    pkg = entrypath.child("plugindummy")
    pkg.createDirectory()
    if real:
        pkg.child("__init__.py").setContent(b"")
    plugs = pkg.child("plugins")
    plugs.createDirectory()
    if real:
        plugs.child("__init__.py").setContent(pluginInitFile)
    plugs.child(pluginModule + ".py").setContent(pluginContent)
    return plugs
class _HasBoolLegacyKey(TypedDict):
    """
    Keyword-argument shape for C{compileall.compile_dir}: carries only the
    C{legacy} flag, passed via C{**} in the tests below.
    """

    # True means write legacy-location .pyc files instead of __pycache__.
    legacy: bool
class DeveloperSetupTests(unittest.TestCase):
    """
    These tests verify things about the plugin system without actually
    interacting with the deployed 'twisted.plugins' package, instead creating a
    temporary package.
    """

    def setUp(self) -> None:
        """
        Create a complex environment with multiple entries on sys.path, akin to
        a developer's environment who has a development (trunk) checkout of
        Twisted, a system installed version of Twisted (for their operating
        system's tools) and a project which provides Twisted plugins.
        """
        self.savedPath = sys.path[:]
        self.savedModules = sys.modules.copy()
        self.fakeRoot = FilePath(self.mktemp())
        self.fakeRoot.createDirectory()
        self.systemPath = self.fakeRoot.child("system_path")
        self.devPath = self.fakeRoot.child("development_path")
        self.appPath = self.fakeRoot.child("application_path")
        self.systemPackage = _createPluginDummy(
            self.systemPath, pluginFileContents("system"), True, "plugindummy_builtin"
        )
        self.devPackage = _createPluginDummy(
            self.devPath, pluginFileContents("dev"), True, "plugindummy_builtin"
        )
        self.appPackage = _createPluginDummy(
            self.appPath, pluginFileContents("app"), False, "plugindummy_app"
        )

        # Now we're going to do the system installation.
        sys.path.extend([x.path for x in [self.systemPath, self.appPath]])
        # Run all the way through the plugins list to cause the
        # L{plugin.getPlugins} generator to write cache files for the system
        # installation.
        self.getAllPlugins()
        self.sysplug = self.systemPath.child("plugindummy").child("plugins")
        self.syscache = self.sysplug.child("dropin.cache")
        # Make sure there's a nice big difference in modification times so that
        # we won't re-build the system cache.
        now = time.time()
        os.utime(self.sysplug.child("plugindummy_builtin.py").path, (now - 5000,) * 2)
        os.utime(self.syscache.path, (now - 2000,) * 2)
        # For extra realism, let's make sure that the system path is no longer
        # writable.
        self.lockSystem()
        self.resetEnvironment()

    def lockSystem(self) -> None:
        """
        Lock the system directories, as if they were unwritable by this user.
        """
        os.chmod(self.sysplug.path, 0o555)
        os.chmod(self.syscache.path, 0o555)

    def unlockSystem(self) -> None:
        """
        Unlock the system directories, as if they were writable by this user.
        """
        os.chmod(self.sysplug.path, 0o777)
        os.chmod(self.syscache.path, 0o777)

    def getAllPlugins(self) -> list[str]:
        """
        Get all the plugins loadable from our dummy package, and return their
        short names.
        """
        # Import the module we just added to our path.  (Local scope because
        # this package doesn't exist outside of this test.)
        import plugindummy.plugins  # type: ignore[import-not-found]

        x = list(plugin.getPlugins(ITestPlugin, plugindummy.plugins))
        return [plug.__name__ for plug in x]  # type: ignore[attr-defined]

    def resetEnvironment(self) -> None:
        """
        Change the environment to what it should be just as the test is
        starting.
        """
        self.unsetEnvironment()
        sys.path.extend([x.path for x in [self.devPath, self.systemPath, self.appPath]])

    def unsetEnvironment(self) -> None:
        """
        Change the Python environment back to what it was before the test was
        started.
        """
        invalidateImportCaches()
        sys.modules.clear()
        sys.modules.update(self.savedModules)
        sys.path[:] = self.savedPath

    def tearDown(self) -> None:
        """
        Reset the Python environment to what it was before this test ran, and
        restore permissions on files which were marked read-only so that the
        directory may be cleanly cleaned up.
        """
        self.unsetEnvironment()
        # Normally we wouldn't "clean up" the filesystem like this (leaving
        # things for post-test inspection), but if we left the permissions the
        # way they were, we'd be leaving files around that the buildbots
        # couldn't delete, and that would be bad.
        self.unlockSystem()

    def test_developmentPluginAvailability(self) -> None:
        """
        Plugins added in the development path should be loadable, even when
        the (now non-importable) system path contains its own idea of the
        list of plugins for a package. Inversely, plugins added in the
        system path should not be available.
        """
        # Run 3 times: uncached, cached, and then cached again to make sure we
        # didn't overwrite / corrupt the cache on the cached try.
        for x in range(3):
            names = self.getAllPlugins()
            names.sort()
            self.assertEqual(names, ["app", "dev"])

    def test_freshPyReplacesStalePyc(self) -> None:
        """
        Verify that if a stale .pyc file on the PYTHONPATH is replaced by a
        fresh .py file, the plugins in the new .py are picked up rather than
        the stale .pyc, even if the .pyc is still around.
        """
        mypath = self.appPackage.child("stale.py")
        mypath.setContent(pluginFileContents("one"))
        # Make it super stale
        x = time.time() - 1000
        os.utime(mypath.path, (x, x))
        pyc = mypath.sibling("stale.pyc")
        # Compile it to a legacy-location .pyc.
        # On python 3, don't use the __pycache__ directory; the intention
        # of scanning for .pyc files is for configurations where you want
        # to intentionally include them, which means we _don't_ scan for
        # them inside cache directories.
        extra = _HasBoolLegacyKey(legacy=True)
        compileall.compile_dir(self.appPackage.path, quiet=1, **extra)
        os.utime(pyc.path, (x, x))
        # Eliminate the other option.
        mypath.remove()
        # Make sure it's the .pyc path getting cached.
        self.resetEnvironment()
        # Sanity check.
        self.assertIn("one", self.getAllPlugins())
        self.failIfIn("two", self.getAllPlugins())
        self.resetEnvironment()
        mypath.setContent(pluginFileContents("two"))
        self.failIfIn("one", self.getAllPlugins())
        self.assertIn("two", self.getAllPlugins())

    def test_newPluginsOnReadOnlyPath(self) -> None:
        """
        Verify that a failure to write the dropin.cache file on a read-only
        path will not affect the list of plugins returned.

        Note: this test should pass on both Linux and Windows, but may not
        provide useful coverage on Windows due to the different meaning of
        "read-only directory".
        """
        self.unlockSystem()
        self.sysplug.child("newstuff.py").setContent(pluginFileContents("one"))
        self.lockSystem()

        # Take the developer path out, so that the system plugins are actually
        # examined.
        sys.path.remove(self.devPath.path)

        # Start observing log events to see the warning
        events: list[EventDict] = []
        addObserver(events.append)
        self.addCleanup(removeObserver, events.append)

        self.assertIn("one", self.getAllPlugins())

        # Make sure something was logged about the cache.
        expected = "Unable to write to plugin cache %s: error number %d" % (
            self.syscache.path,
            errno.EPERM,
        )
        for event in events:
            maybeText = textFromEventDict(event)
            assert maybeText is not None
            if expected in maybeText:  # pragma: no branch
                break
        else:  # pragma: no cover
            self.fail(
                "Did not observe unwriteable cache warning in log "
                "events: %r" % (events,)
            )
class AdjacentPackageTests(unittest.TestCase):
    """
    Tests for the behavior of the plugin system when there are multiple
    installed copies of the package containing the plugins being loaded.
    """

    def setUp(self) -> None:
        """
        Save the elements of C{sys.path} and the items of C{sys.modules}.
        """
        self.originalPath = sys.path[:]
        self.savedModules = sys.modules.copy()

    def tearDown(self) -> None:
        """
        Restore C{sys.path} and C{sys.modules} to their original values.
        """
        sys.path[:] = self.originalPath
        sys.modules.clear()
        sys.modules.update(self.savedModules)

    def createDummyPackage(
        self, root: FilePath[str], name: str, pluginName: str
    ) -> FilePath[str]:
        """
        Create a directory containing a Python package named I{dummy} with a
        I{plugins} subpackage.

        @type root: L{FilePath}
        @param root: The directory in which to create the hierarchy.

        @type name: C{str}
        @param name: The name of the directory to create which will contain
            the package.

        @type pluginName: C{str}
        @param pluginName: The name of a module to create in the
            I{dummy.plugins} package.

        @rtype: L{FilePath}
        @return: The directory which was created to contain the I{dummy}
            package.
        """
        directory = root.child(name)
        package = directory.child("dummy")
        package.makedirs()
        package.child("__init__.py").setContent(b"")
        plugins = package.child("plugins")
        plugins.makedirs()
        plugins.child("__init__.py").setContent(pluginInitFile)
        pluginModule = plugins.child(pluginName + ".py")
        pluginModule.setContent(pluginFileContents(name))
        return directory

    # NOTE: the next two method names were restored; an automated redaction
    # pass had mangled them into invalid identifiers.
    def test_hiddenPackageSamePluginModuleNameObscured(self) -> None:
        """
        Only plugins from the first package in sys.path should be returned by
        getPlugins in the case where there are two Python packages by the same
        name installed, each with a plugin module by a single name.
        """
        root = FilePath(self.mktemp())
        root.makedirs()

        firstDirectory = self.createDummyPackage(root, "first", "someplugin")
        secondDirectory = self.createDummyPackage(root, "second", "someplugin")

        sys.path.append(firstDirectory.path)
        sys.path.append(secondDirectory.path)

        import dummy.plugins  # type: ignore[import-not-found]

        plugins = list(plugin.getPlugins(ITestPlugin, dummy.plugins))
        self.assertEqual(["first"], [p.__name__ for p in plugins])  # type: ignore[attr-defined]

    def test_hiddenPackageDifferentPluginModuleNameObscured(self) -> None:
        """
        Plugins from the first package in sys.path should be returned by
        getPlugins in the case where there are two Python packages by the same
        name installed, each with a plugin module by a different name.
        """
        root = FilePath(self.mktemp())
        root.makedirs()

        firstDirectory = self.createDummyPackage(root, "first", "thisplugin")
        secondDirectory = self.createDummyPackage(root, "second", "thatplugin")

        sys.path.append(firstDirectory.path)
        sys.path.append(secondDirectory.path)

        import dummy.plugins

        plugins = list(plugin.getPlugins(ITestPlugin, dummy.plugins))
        self.assertEqual(["first"], [p.__name__ for p in plugins])  # type: ignore[attr-defined]
class PackagePathTests(unittest.TestCase):
    """
    Tests for L{plugin.pluginPackagePaths} which constructs search paths for
    plugin packages.
    """

    def setUp(self) -> None:
        """
        Save the elements of C{sys.path}.
        """
        self.originalPath = sys.path[:]

    def tearDown(self) -> None:
        """
        Restore C{sys.path} to its original value.
        """
        sys.path[:] = self.originalPath

    def test_pluginDirectories(self) -> None:
        """
        L{plugin.pluginPackagePaths} should return a list containing each
        directory in C{sys.path} with a suffix based on the supplied package
        name.
        """
        foo = FilePath("foo")
        bar = FilePath("bar")
        sys.path = [foo.path, bar.path]
        self.assertEqual(
            plugin.pluginPackagePaths("dummy.plugins"),
            [
                foo.child("dummy").child("plugins").path,
                bar.child("dummy").child("plugins").path,
            ],
        )

    def test_pluginPackagesExcluded(self) -> None:
        """
        L{plugin.pluginPackagePaths} should exclude directories which are
        Python packages. The only allowed plugin package (the only one
        associated with a I{dummy} package which Python will allow to be
        imported) will already be known to the caller of
        L{plugin.pluginPackagePaths} and will most commonly already be in
        the C{__path__} they are about to mutate.
        """
        root = FilePath(self.mktemp())
        foo = root.child("foo").child("dummy").child("plugins")
        foo.makedirs()
        # Giving foo/dummy/plugins an __init__.py makes it a real package,
        # which disqualifies it from the returned search path.
        foo.child("__init__.py").setContent(b"")
        sys.path = [root.child("foo").path, root.child("bar").path]
        self.assertEqual(
            plugin.pluginPackagePaths("dummy.plugins"),
            [root.child("bar").child("dummy").child("plugins").path],
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,140 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.protocols.postfix module.
"""
from typing import Dict, List, Tuple
from twisted.internet.testing import StringTransport
from twisted.protocols import postfix
from twisted.trial import unittest
class PostfixTCPMapQuoteTests(unittest.TestCase):
    """
    Tests for the postfix map quoting/unquoting round trip.
    """

    data = [
        # (raw, quoted, [aliasQuotedForms]),
        (b"foo", b"foo"),
        (b"foo bar", b"foo%20bar"),
        (b"foo\tbar", b"foo%09bar"),
        (b"foo\nbar", b"foo%0Abar", b"foo%0abar"),
        (
            b"foo\r\nbar",
            b"foo%0D%0Abar",
            b"foo%0D%0abar",
            b"foo%0d%0Abar",
            b"foo%0d%0abar",
        ),
        (b"foo ", b"foo%20"),
        (b" foo", b"%20foo"),
    ]

    def testData(self):
        """
        Each raw value quotes to its canonical (first listed) form, and every
        listed quoted form unquotes back to the raw value.
        """
        for raw, *quotedForms in self.data:
            self.assertEqual(postfix.quote(raw), quotedForms[0])
            for quotedForm in quotedForms:
                self.assertEqual(postfix.unquote(quotedForm), raw)
class PostfixTCPMapServerTestCase:
    """
    Mixin defining a map fixture (C{data}) and a scripted conversation
    (C{chat}) to run against L{postfix.PostfixTCPMapServer}.

    Subclasses combine this with a C{TestCase} and fill in C{data} and
    C{chat}.
    """

    data: Dict[bytes, bytes] = {
        # 'key': 'value',
    }
    chat: List[Tuple[bytes, bytes]] = [
        # (input, expected_output),
    ]

    def _chatWithFactory(self, factory):
        """
        Connect a L{postfix.PostfixTCPMapServer} to C{factory} over a string
        transport and verify each (input, expected output) pair in C{chat}.

        @param factory: the server factory supplying the key/value lookups.
        """
        transport = StringTransport()
        protocol = postfix.PostfixTCPMapServer()
        protocol.service = factory
        protocol.factory = factory
        protocol.makeConnection(transport)
        for input, expected_output in self.chat:
            protocol.lineReceived(input)
            self.assertEqual(
                transport.value(),
                expected_output,
                "For {!r}, expected {!r} but got {!r}".format(
                    input, expected_output, transport.value()
                ),
            )
            transport.clear()
        # Cancel the idle timeout so the reactor is left clean.
        protocol.setTimeout(None)

    def test_chat(self):
        """
        Test that I{get} and I{put} commands are responded to correctly by
        L{postfix.PostfixTCPMapServer} when its factory is an instance of
        L{postfix.PostfixTCPMapDictServerFactory}.
        """
        self._chatWithFactory(postfix.PostfixTCPMapDictServerFactory(self.data))

    def test_deferredChat(self):
        """
        Test that I{get} and I{put} commands are responded to correctly by
        L{postfix.PostfixTCPMapServer} when its factory is an instance of
        L{postfix.PostfixTCPMapDeferringDictServerFactory}.
        """
        self._chatWithFactory(
            postfix.PostfixTCPMapDeferringDictServerFactory(self.data)
        )

    def test_getException(self):
        """
        If the factory throws an exception,
        error code 400 must be returned.
        """

        class ErrorFactory:
            """
            Factory that raises an error on key lookup.
            """

            def get(self, key):
                raise Exception("This is a test error")

        server = postfix.PostfixTCPMapServer()
        server.factory = ErrorFactory()
        server.transport = StringTransport()
        server.lineReceived(b"get example")
        self.assertEqual(server.transport.value(), b"400 This is a test error\n")
class ValidTests(PostfixTCPMapServerTestCase, unittest.TestCase):
    """
    Run the shared chat script against a server backed by a populated
    dictionary map.
    """

    data = {
        b"foo": b"ThisIs Foo",
        b"bar": b" bar really is found\r\n",
    }
    # Each (input line, expected raw response) pair; responses include the
    # postfix status code and quoted payload.
    chat = [
        (b"get", b"400 Command 'get' takes 1 parameters.\n"),
        (b"get foo bar", b"500 \n"),
        (b"put", b"400 Command 'put' takes 2 parameters.\n"),
        (b"put foo", b"400 Command 'put' takes 2 parameters.\n"),
        (b"put foo bar baz", b"500 put is not implemented yet.\n"),
        (b"put foo bar", b"500 put is not implemented yet.\n"),
        (b"get foo", b"200 ThisIs%20Foo\n"),
        (b"get bar", b"200 %20bar%20really%20is%20found%0D%0A\n"),
        (b"get baz", b"500 \n"),
        (b"foo", b"400 unknown command\n"),
    ]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,227 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.protocols package.
"""
from twisted.internet import address, defer, protocol, reactor
from twisted.protocols import portforward, wire
from twisted.python.compat import iterbytes
from twisted.test import proto_helpers
from twisted.trial import unittest
class WireTests(unittest.TestCase):
    """
    Tests for the trivial protocols in L{twisted.protocols.wire}.
    """

    def _makeConnected(self, protocolInstance):
        """
        Hook C{protocolInstance} up to a fresh string transport and return
        the transport so the test can inspect what was written.
        """
        transport = proto_helpers.StringTransport()
        protocolInstance.makeConnection(transport)
        return transport

    def test_echo(self):
        """
        L{wire.Echo} writes back every chunk of bytes it receives.
        """
        echo = wire.Echo()
        transport = self._makeConnected(echo)
        for chunk in (b"hello", b"world", b"how", b"are", b"you"):
            echo.dataReceived(chunk)
        self.assertEqual(transport.value(), b"helloworldhowareyou")

    def test_who(self):
        """
        L{wire.Who} immediately writes a fixed user name on connection.
        """
        transport = self._makeConnected(wire.Who())
        self.assertEqual(transport.value(), b"root\r\n")

    def test_QOTD(self):
        """
        L{wire.QOTD} immediately writes its quote of the day on connection.
        """
        transport = self._makeConnected(wire.QOTD())
        self.assertEqual(
            transport.value(), b"An apple a day keeps the doctor away.\r\n"
        )

    def test_discard(self):
        """
        L{wire.Discard} never writes anything, whatever it receives.
        """
        discard = wire.Discard()
        transport = self._makeConnected(discard)
        for chunk in (b"hello", b"world", b"how", b"are", b"you"):
            discard.dataReceived(chunk)
        self.assertEqual(transport.value(), b"")
class TestableProxyClientFactory(portforward.ProxyClientFactory):
    """
    Test proxy client factory that keeps the last created protocol instance.

    @ivar protoInstance: the last instance of the protocol.
    @type protoInstance: L{portforward.ProxyClient}
    """

    def buildProtocol(self, addr):
        """
        Create the protocol instance, remember it, and hand it back.
        """
        self.protoInstance = portforward.ProxyClientFactory.buildProtocol(
            self, addr
        )
        return self.protoInstance
class TestableProxyFactory(portforward.ProxyFactory):
    """
    Test proxy factory that keeps the last created protocol instance.

    @ivar protoInstance: the last instance of the protocol.
    @type protoInstance: L{portforward.ProxyServer}

    @ivar clientFactoryInstance: client factory used by C{protoInstance} to
        create forward connections.
    @type clientFactoryInstance: L{TestableProxyClientFactory}
    """

    def buildProtocol(self, addr):
        """
        Create the protocol instance, keep track of it, and make it use
        C{clientFactoryInstance} as its client factory.
        """
        self.clientFactoryInstance = TestableProxyClientFactory()
        instance = portforward.ProxyFactory.buildProtocol(self, addr)
        # Force the use of this specific factory instance so the test can
        # later reach the client-side protocol it builds.
        instance.clientProtocolFactory = lambda: self.clientFactoryInstance
        self.protoInstance = instance
        return instance
class PortforwardingTests(unittest.TestCase):
    """
    Test port forwarding.
    """

    def setUp(self) -> None:
        """
        Create the endpoint protocols and the list of ports to clean up.
        """
        self.serverProtocol = wire.Echo()
        self.clientProtocol = protocol.Protocol()
        self.openPorts = []

    def tearDown(self):
        """
        Best-effort teardown: close every connection the test managed to
        open, then stop listening on every port.  A missing attribute just
        means that stage of the test was never reached.
        """
        try:
            self.proxyServerFactory.protoInstance.transport.loseConnection()
        except AttributeError:
            pass
        try:
            pi = self.proxyServerFactory.clientFactoryInstance.protoInstance
            pi.transport.loseConnection()
        except AttributeError:
            pass
        try:
            self.clientProtocol.transport.loseConnection()
        except AttributeError:
            pass
        try:
            self.serverProtocol.transport.loseConnection()
        except AttributeError:
            pass
        return defer.gatherResults(
            [defer.maybeDeferred(p.stopListening) for p in self.openPorts]
        )

    def test_portforward(self):
        """
        Test port forwarding through Echo protocol.
        """
        # Real backend server: an Echo protocol on an ephemeral port.
        realServerFactory = protocol.ServerFactory()
        realServerFactory.protocol = lambda: self.serverProtocol
        realServerPort = reactor.listenTCP(0, realServerFactory, interface="127.0.0.1")
        self.openPorts.append(realServerPort)
        # Proxy in front of it, also on an ephemeral port.
        self.proxyServerFactory = TestableProxyFactory(
            "127.0.0.1", realServerPort.getHost().port
        )
        proxyServerPort = reactor.listenTCP(
            0, self.proxyServerFactory, interface="127.0.0.1"
        )
        self.openPorts.append(proxyServerPort)
        nBytes = 1000
        received = []
        d = defer.Deferred()

        # Fires the Deferred once all the echoed bytes have come back
        # through the proxy.
        def testDataReceived(data):
            received.extend(iterbytes(data))
            if len(received) >= nBytes:
                self.assertEqual(b"".join(received), b"x" * nBytes)
                d.callback(None)

        self.clientProtocol.dataReceived = testDataReceived

        # Sends the payload as soon as the client connection is up.
        def testConnectionMade():
            self.clientProtocol.transport.write(b"x" * nBytes)

        self.clientProtocol.connectionMade = testConnectionMade
        clientFactory = protocol.ClientFactory()
        clientFactory.protocol = lambda: self.clientProtocol
        reactor.connectTCP("127.0.0.1", proxyServerPort.getHost().port, clientFactory)
        return d

    def test_registerProducers(self):
        """
        The proxy client registers itself as a producer of the proxy server and
        vice versa.
        """
        # create a ProxyServer instance
        addr = address.IPv4Address("TCP", "127.0.0.1", 0)
        server = portforward.ProxyFactory("127.0.0.1", 0).buildProtocol(addr)
        # set the reactor for this test
        reactor = proto_helpers.MemoryReactor()
        server.reactor = reactor
        # make the connection
        serverTransport = proto_helpers.StringTransport()
        server.makeConnection(serverTransport)
        # check that the ProxyClientFactory is connecting to the backend
        self.assertEqual(len(reactor.tcpClients), 1)
        # get the factory instance and check it's the one we expect
        host, port, clientFactory, timeout, _ = reactor.tcpClients[0]
        self.assertIsInstance(clientFactory, portforward.ProxyClientFactory)
        # Connect it
        client = clientFactory.buildProtocol(addr)
        clientTransport = proto_helpers.StringTransport()
        client.makeConnection(clientTransport)
        # check that the producers are registered
        self.assertIs(clientTransport.producer, serverTransport)
        self.assertIs(serverTransport.producer, clientTransport)
        # check the streaming attribute in both transports
        self.assertTrue(clientTransport.streaming)
        self.assertTrue(serverTransport.streaming)
class StringTransportTests(unittest.TestCase):
    """
    Test L{proto_helpers.StringTransport} helper behaviour.
    """

    def test_noUnicode(self):
        """
        L{proto_helpers.StringTransport.write} rejects text strings; only
        bytes may be written.
        """
        transport = proto_helpers.StringTransport()
        self.assertRaises(TypeError, transport.write, "foo")

View File

@@ -0,0 +1,126 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{twisted.python.randbytes}.
"""
from __future__ import annotations
from typing import Callable
from typing_extensions import NoReturn, Protocol
from twisted.python import randbytes
from twisted.trial import unittest
class _SupportsAssertions(Protocol):
    # Structural type for the assertion methods SecureRandomTestCaseBase._check
    # relies on; any TestCase subclass satisfies it.
    def assertEqual(self, a: object, b: object) -> object:
        ...

    def assertNotEqual(self, a: object, b: object) -> object:
        ...
class SecureRandomTestCaseBase:
    """
    Base class for secureRandom test cases.
    """

    def _check(self: _SupportsAssertions, source: Callable[[int], bytes]) -> None:
        """
        Verify that C{source} returns exactly the number of bytes requested
        each call, and that two consecutive calls (almost certainly) do not
        return the same bytes.  A repeat is a legitimate-but-astronomically
        unlikely event, so rejecting it may produce a rare spurious failure.
        """
        for nbytes in range(17, 25):
            first = source(nbytes)
            second = source(nbytes)
            self.assertEqual(len(first), nbytes)
            self.assertEqual(len(second), nbytes)
            # Crude non-repetition check; see docstring for the caveat.
            self.assertNotEqual(second, first)
class SecureRandomTests(SecureRandomTestCaseBase, unittest.TestCase):
    """
    Test secureRandom under normal conditions.
    """

    def test_normal(self) -> None:
        """
        L{randbytes.secureRandom} should return a string of the requested
        length and make some effort to make its result otherwise unpredictable.
        """
        # Exercised via the shared length/non-repetition checks in _check.
        self._check(randbytes.secureRandom)
class ConditionalSecureRandomTests(
    SecureRandomTestCaseBase, unittest.SynchronousTestCase
):
    """
    Test each random source individually, then remove the secure sources to
    exercise the error and fallback behaviour.
    """

    def setUp(self) -> None:
        """
        Create a L{randbytes.RandomFactory} to use in the tests.
        """
        self.factory = randbytes.RandomFactory()

    def errorFactory(self, nbytes: object) -> NoReturn:
        """
        A factory raising an error when a source is not available.
        """
        raise randbytes.SourceNotAvailable()

    def test_osUrandom(self) -> None:
        """
        L{RandomFactory._osUrandom} should work as a random source whenever
        L{os.urandom} is available.
        """
        self._check(self.factory._osUrandom)

    def test_withoutAnything(self) -> None:
        """
        Remove all secure sources and assert it raises a failure. Then try the
        fallback parameter.
        """
        # Make the only secure source unavailable.
        self.factory._osUrandom = self.errorFactory  # type: ignore[method-assign]
        self.assertRaises(
            randbytes.SecureRandomNotAvailable, self.factory.secureRandom, 18
        )

        def wrapper() -> bytes:
            return self.factory.secureRandom(18, fallback=True)

        # With fallback=True the insecure source is used instead, but a
        # RuntimeWarning with this exact message must be emitted.
        s = self.assertWarns(
            RuntimeWarning,
            "urandom unavailable - "
            "proceeding with non-cryptographically secure random source",
            __file__,
            wrapper,
        )
        self.assertEqual(len(s), 18)
class RandomBaseTests(SecureRandomTestCaseBase, unittest.SynchronousTestCase):
    """
    'Normal' (insecure) random test cases.
    """

    def test_normal(self) -> None:
        """
        Test basic case.
        """
        self._check(randbytes.insecureRandom)

    def test_withoutGetrandbits(self) -> None:
        """
        Test C{insecureRandom} without C{random.getrandbits}.
        """
        factory = randbytes.RandomFactory()
        # Simulate the absence of random.getrandbits on the factory.
        factory.getrandbits = None
        self._check(factory.insecureRandom)

View File

@@ -0,0 +1,266 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import annotations
import os
import sys
import types
from typing_extensions import NoReturn
from twisted.python import rebuild
from twisted.trial.unittest import TestCase
from . import crash_test_dummy
# Keep a reference to the original function object; test_Rebuild verifies
# that rebuilding an unchanged module leaves it equal to this.
f = crash_test_dummy.foo
# Small class hierarchy used as rebuild fodder (see
# RebuildTests.test_UpdateInstance, which redefines Foo and Buz).
class Foo:
    pass


class Bar(Foo):
    pass


class Baz:
    pass


class Buz(Bar, Baz):
    pass
class HashRaisesRuntimeError:
    """
    Things that don't hash (raise an Exception) should be ignored by the
    rebuilder.

    @ivar hashCalled: C{bool} set to True when __hash__ is called.
    """

    def __init__(self) -> None:
        # Flag flipped by __hash__ so tests can observe that hashing was
        # attempted.
        self.hashCalled = False

    def __hash__(self) -> NoReturn:
        self.hashCalled = True
        # Deliberately not a TypeError: the rebuilder must tolerate
        # arbitrary exceptions from __hash__.
        raise RuntimeError("not a TypeError!")
# Module-level slot assigned (and reset to None in cleanup) by
# RebuildTests.test_hashException.
unhashableObject = None
class RebuildTests(TestCase):
    """
    Simple testcase for rebuilding, to at least exercise the code.
    """

    def setUp(self) -> None:
        """
        Create a throwaway package directory, make it importable by
        prepending it to C{sys.path}.
        """
        self.libPath = self.mktemp()
        os.mkdir(self.libPath)
        self.fakelibPath = os.path.join(self.libPath, "twisted_rebuild_fakelib")
        os.mkdir(self.fakelibPath)
        open(os.path.join(self.fakelibPath, "__init__.py"), "w").close()
        sys.path.insert(0, self.libPath)

    def tearDown(self) -> None:
        """
        Undo the C{sys.path} change made by L{setUp}.
        """
        sys.path.remove(self.libPath)

    def test_FileRebuild(self) -> None:
        """
        Rebuilding a module after its source file changed on disk updates
        both new and pre-existing instances.
        """
        import shutil
        import time

        from twisted.python.util import sibpath

        shutil.copyfile(
            sibpath(__file__, "myrebuilder1.py"),
            os.path.join(self.fakelibPath, "myrebuilder.py"),
        )
        from twisted_rebuild_fakelib import (  # type: ignore[import-not-found]
            myrebuilder,
        )

        a = myrebuilder.A()
        b = myrebuilder.B()
        i = myrebuilder.Inherit()
        self.assertEqual(a.a(), "a")
        # Necessary because the file has not "changed" if a second has not gone
        # by in unix. This sucks, but it's not often that you'll be doing more
        # than one reload per second.
        time.sleep(1.1)
        shutil.copyfile(
            sibpath(__file__, "myrebuilder2.py"),
            os.path.join(self.fakelibPath, "myrebuilder.py"),
        )
        rebuild.rebuild(myrebuilder)
        # Instances created both after and before the rebuild see the new
        # behaviour.
        b2 = myrebuilder.B()
        self.assertEqual(b2.b(), "c")
        self.assertEqual(b.b(), "c")
        self.assertEqual(i.a(), "d")
        self.assertEqual(a.a(), "b")

    def test_Rebuild(self) -> None:
        """
        Rebuilding an unchanged module.
        """
        # This test would actually pass if rebuild was a no-op, but it
        # ensures rebuild doesn't break stuff while being a less
        # complex test than testFileRebuild.
        x = crash_test_dummy.X("a")
        rebuild.rebuild(crash_test_dummy, doLog=False)
        # Instance rebuilding is triggered by attribute access.
        x.do()
        self.assertEqual(x.__class__, crash_test_dummy.X)
        self.assertEqual(f, crash_test_dummy.foo)

    def test_ComponentInteraction(self) -> None:
        """
        Rebuilding a module preserves its registered component adapters.
        """
        x = crash_test_dummy.XComponent()
        x.setAdapter(crash_test_dummy.IX, crash_test_dummy.XA)
        x.getComponent(crash_test_dummy.IX)
        rebuild.rebuild(crash_test_dummy, 0)
        newComponent = x.getComponent(crash_test_dummy.IX)
        newComponent.method()
        self.assertEqual(newComponent.__class__, crash_test_dummy.XA)
        # Test that a duplicate registerAdapter is not allowed
        from twisted.python import components

        self.assertRaises(
            ValueError,
            components.registerAdapter,
            crash_test_dummy.XA,
            crash_test_dummy.X,
            crash_test_dummy.IX,
        )

    def test_UpdateInstance(self) -> None:
        """
        L{rebuild.updateInstance} repoints an existing instance at freshly
        redefined classes.
        """
        global Foo, Buz
        b = Buz()

        # Redefine the module-level Foo and Buz with new members.
        class Foo:
            def foo(self) -> None:
                """
                Dummy method
                """

        class Buz(Bar, Baz):
            x = 10

        rebuild.updateInstance(b)
        assert hasattr(b, "foo"), "Missing method on rebuilt instance"
        assert hasattr(b, "x"), "Missing class attribute on rebuilt instance"

    def test_BananaInteraction(self) -> None:
        """
        L{rebuild.latestClass} works on a class from twisted.spread.banana.
        """
        from twisted.python import rebuild
        from twisted.spread import banana

        rebuild.latestClass(banana.Banana)

    def test_hashException(self) -> None:
        """
        Rebuilding something that has a __hash__ that raises a non-TypeError
        shouldn't cause rebuild to die.
        """
        global unhashableObject
        unhashableObject = HashRaisesRuntimeError()

        def _cleanup() -> None:
            global unhashableObject
            unhashableObject = None

        self.addCleanup(_cleanup)
        rebuild.rebuild(rebuild)
        self.assertTrue(unhashableObject.hashCalled)

    def test_Sensitive(self) -> None:
        """
        L{twisted.python.rebuild.Sensitive}
        """
        from twisted.python import rebuild
        from twisted.python.rebuild import Sensitive

        class TestSensitive(Sensitive):
            def test_method(self) -> None:
                """
                Dummy method
                """

        testSensitive = TestSensitive()
        testSensitive.rebuildUpToDate()
        self.assertFalse(testSensitive.needRebuildUpdate())
        # Test rebuilding a builtin class
        newException = rebuild.latestClass(Exception)
        self.assertEqual(repr(Exception), repr(newException))
        self.assertEqual(newException, testSensitive.latestVersionOf(newException))
        # Test types.MethodType on method in class
        self.assertEqual(
            TestSensitive.test_method,
            testSensitive.latestVersionOf(TestSensitive.test_method),
        )
        # Test types.MethodType on method in instance of class
        self.assertEqual(
            testSensitive.test_method,
            testSensitive.latestVersionOf(testSensitive.test_method),
        )
        # Test a class
        self.assertEqual(TestSensitive, testSensitive.latestVersionOf(TestSensitive))

        def myFunction() -> None:
            """
            Dummy method
            """

        # Test types.FunctionType
        self.assertEqual(myFunction, testSensitive.latestVersionOf(myFunction))
class NewStyleTests(TestCase):
    """
    Tests for rebuilding new-style classes of various sorts.
    """

    def setUp(self) -> None:
        """
        Install a fresh, empty module to define test classes in.
        """
        self.m = types.ModuleType("whipping")
        sys.modules["whipping"] = self.m

    def tearDown(self) -> None:
        """
        Remove the scratch module installed by L{setUp}.
        """
        del sys.modules["whipping"]
        del self.m

    def test_slots(self) -> None:
        """
        Try to rebuild a new style class with slots defined.
        """
        classDefinition = "class SlottedClass:\n" " __slots__ = ['a']\n"
        exec(classDefinition, self.m.__dict__)
        inst = self.m.SlottedClass()
        inst.a = 7
        # Re-exec the same source to create a "new version" of the class.
        exec(classDefinition, self.m.__dict__)
        rebuild.updateInstance(inst)
        self.assertEqual(inst.a, 7)
        self.assertIs(type(inst), self.m.SlottedClass)

    def test_typeSubclass(self) -> None:
        """
        Try to rebuild a base type subclass.
        """
        classDefinition = "class ListSubclass(list):\n" " pass\n"
        exec(classDefinition, self.m.__dict__)
        inst = self.m.ListSubclass()
        inst.append(2)
        exec(classDefinition, self.m.__dict__)
        rebuild.updateInstance(inst)
        self.assertEqual(inst[0], 2)
        self.assertIs(type(inst), self.m.ListSubclass)

View File

@@ -0,0 +1,825 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for the L{twisted.python.reflect} module.
"""
import os
import weakref
from collections import deque
from twisted.python import reflect
from twisted.python.reflect import (
accumulateMethods,
addMethodNamesToDict,
fullyQualifiedName,
prefixedMethodNames,
prefixedMethods,
)
from twisted.trial.unittest import SynchronousTestCase as TestCase
class Base:
    """
    A no-op class which can be used to verify the behavior of
    method-discovering APIs such as L{accumulateMethods}.
    """

    def method(self):
        """
        A no-op method which can be discovered.
        """
class Sub(Base):
    """
    A subclass of a class with a method which can be discovered, used to
    verify that inherited methods are found too.
    """
class Separate:
    """
    A no-op class with methods with differing prefixes.
    """

    def good_method(self):
        """
        A no-op method with a matching prefix to be discovered.
        """

    def bad_method(self):
        """
        A no-op method with a mismatched prefix to not be discovered.
        """
class AccumulateMethodsTests(TestCase):
    """
    Tests for L{accumulateMethods} which finds methods on a class hierarchy and
    adds them to a dictionary.
    """

    def test_ownClass(self):
        """
        If x is an instance of Base and Base defines a method named method,
        L{accumulateMethods} adds an item to the given dictionary with
        C{"method"} as the key and a bound method object for Base.method as
        the value.
        """
        x = Base()
        output = {}
        accumulateMethods(x, output)
        self.assertEqual({"method": x.method}, output)

    def test_baseClass(self):
        """
        If x is an instance of Sub and Sub is a subclass of Base and Base
        defines a method named method, L{accumulateMethods} adds an item to the
        given dictionary with C{"method"} as the key and a bound method object
        for Base.method as the value.
        """
        x = Sub()
        output = {}
        accumulateMethods(x, output)
        self.assertEqual({"method": x.method}, output)

    def test_prefix(self):
        """
        If a prefix is given, L{accumulateMethods} limits its results to
        methods beginning with that prefix. Keys in the resulting dictionary
        also have the prefix removed from them.
        """
        x = Separate()
        output = {}
        accumulateMethods(x, output, "good_")
        self.assertEqual({"method": x.good_method}, output)
class PrefixedMethodsTests(TestCase):
    """
    Tests for L{prefixedMethods} which finds methods on an object and
    returns them as a list.
    """

    def test_onlyObject(self):
        """
        L{prefixedMethods} returns a list of the methods discovered on an
        object.
        """
        instance = Base()
        self.assertEqual([instance.method], prefixedMethods(instance))

    def test_prefix(self):
        """
        If a prefix is given, L{prefixedMethods} returns only methods named
        with that prefix.
        """
        instance = Separate()
        self.assertEqual(
            [instance.good_method], prefixedMethods(instance, "good_")
        )
class PrefixedMethodNamesTests(TestCase):
    """
    Tests for L{prefixedMethodNames}.
    """

    def test_method(self):
        """
        L{prefixedMethodNames} returns a list including methods with the given
        prefix defined on the class passed to it.
        """
        self.assertEqual(["method"], prefixedMethodNames(Separate, "good_"))

    def test_inheritedMethod(self):
        """
        L{prefixedMethodNames} returns a list including methods with the given
        prefix defined on base classes of the class passed to it.
        """

        class Child(Separate):
            pass

        self.assertEqual(["method"], prefixedMethodNames(Child, "good_"))
class AddMethodNamesToDictTests(TestCase):
    """
    Tests for L{addMethodNamesToDict}.
    """

    def test_baseClass(self):
        """
        If C{baseClass} is passed to L{addMethodNamesToDict}, only methods
        defined on classes which are subclasses of C{baseClass} are added to
        the result dictionary.
        """

        class Alternate:
            pass

        class Child(Separate, Alternate):
            def good_alternate(self):
                pass

        result = {}
        addMethodNamesToDict(Child, result, "good_", Alternate)
        # Only good_alternate qualifies: good_method comes from Separate,
        # which is not an Alternate subclass.
        self.assertEqual({"alternate": 1}, result)
class Summer:
    """
    A class we look up as part of the LookupsTests.
    """

    def reallySet(self):
        """
        A no-op method which the lookup tests resolve by its dotted name.
        """
class LookupsTests(TestCase):
    """
    Tests for L{namedClass}, L{namedModule}, and L{namedAny}.
    """

    def test_namedClassLookup(self):
        """
        L{namedClass} should return the class object for the name it is passed.
        """
        self.assertIs(reflect.namedClass("twisted.test.test_reflect.Summer"), Summer)

    def test_namedModuleLookup(self):
        """
        L{namedModule} should return the module object for the name it is
        passed.
        """
        from twisted.python import monkey

        self.assertIs(reflect.namedModule("twisted.python.monkey"), monkey)

    def test_namedAnyPackageLookup(self):
        """
        L{namedAny} should return the package object for the name it is passed.
        """
        import twisted.python

        self.assertIs(reflect.namedAny("twisted.python"), twisted.python)

    def test_namedAnyModuleLookup(self):
        """
        L{namedAny} should return the module object for the name it is passed.
        """
        from twisted.python import monkey

        self.assertIs(reflect.namedAny("twisted.python.monkey"), monkey)

    def test_namedAnyClassLookup(self):
        """
        L{namedAny} should return the class object for the name it is passed.
        """
        self.assertIs(reflect.namedAny("twisted.test.test_reflect.Summer"), Summer)

    def test_namedAnyAttributeLookup(self):
        """
        L{namedAny} should return the object an attribute of a non-module,
        non-package object is bound to for the name it is passed.
        """
        # Note - not assertIs because unbound method lookup creates a new
        # object every time. This is a foolishness of Python's object
        # implementation, not a bug in Twisted.
        self.assertEqual(
            reflect.namedAny("twisted.test.test_reflect.Summer.reallySet"),
            Summer.reallySet,
        )

    def test_namedAnySecondAttributeLookup(self):
        """
        L{namedAny} should return the object an attribute of an object which
        itself was an attribute of a non-module, non-package object is bound to
        for the name it is passed.
        """
        self.assertIs(
            reflect.namedAny("twisted.test.test_reflect." "Summer.reallySet.__doc__"),
            Summer.reallySet.__doc__,
        )

    def test_importExceptions(self):
        """
        Exceptions raised by modules which L{namedAny} causes to be imported
        should pass through L{namedAny} to the caller.
        """
        self.assertRaises(
            ZeroDivisionError, reflect.namedAny, "twisted.test.reflect_helper_ZDE"
        )
        # Make sure that there is post-failed-import cleanup
        self.assertRaises(
            ZeroDivisionError, reflect.namedAny, "twisted.test.reflect_helper_ZDE"
        )
        self.assertRaises(
            ValueError, reflect.namedAny, "twisted.test.reflect_helper_VE"
        )
        # Modules which themselves raise ImportError when imported should
        # result in an ImportError
        self.assertRaises(
            ImportError, reflect.namedAny, "twisted.test.reflect_helper_IE"
        )

    def test_attributeExceptions(self):
        """
        If segments on the end of a fully-qualified Python name represents
        attributes which aren't actually present on the object represented by
        the earlier segments, L{namedAny} should raise an L{AttributeError}.
        """
        self.assertRaises(
            AttributeError, reflect.namedAny, "twisted.nosuchmoduleintheworld"
        )
        # ImportError behaves somewhat differently between "import
        # extant.nonextant" and "import extant.nonextant.nonextant", so test
        # the latter as well.
        self.assertRaises(
            AttributeError, reflect.namedAny, "twisted.nosuch.modulein.theworld"
        )
        self.assertRaises(
            AttributeError,
            reflect.namedAny,
            "twisted.test.test_reflect.Summer.nosuchattribute",
        )

    def test_invalidNames(self):
        """
        Passing a name which isn't a fully-qualified Python name to L{namedAny}
        should result in one of the following exceptions:
        - L{InvalidName}: the name is not a dot-separated list of Python
          objects
        - L{ObjectNotFound}: the object doesn't exist
        - L{ModuleNotFound}: the object doesn't exist and there is only one
          component in the name
        """
        err = self.assertRaises(
            reflect.ModuleNotFound, reflect.namedAny, "nosuchmoduleintheworld"
        )
        self.assertEqual(str(err), "No module named 'nosuchmoduleintheworld'")
        # This is a dot-separated list, but it isn't valid!
        err = self.assertRaises(
            reflect.ObjectNotFound, reflect.namedAny, "@#$@(#.!@(#!@#"
        )
        self.assertEqual(str(err), "'@#$@(#.!@(#!@#' does not name an object")
        err = self.assertRaises(
            reflect.ObjectNotFound, reflect.namedAny, "tcelfer.nohtyp.detsiwt"
        )
        self.assertEqual(str(err), "'tcelfer.nohtyp.detsiwt' does not name an object")
        err = self.assertRaises(reflect.InvalidName, reflect.namedAny, "")
        self.assertEqual(str(err), "Empty module name")
        # Leading, trailing, and doubled dots are all invalid.
        for invalidName in [".twisted", "twisted.", "twisted..python"]:
            err = self.assertRaises(reflect.InvalidName, reflect.namedAny, invalidName)
            self.assertEqual(
                str(err),
                "name must be a string giving a '.'-separated list of Python "
                "identifiers, not %r" % (invalidName,),
            )

    def test_requireModuleImportError(self):
        """
        When module import fails with ImportError it returns the specified
        default value.
        """
        for name in ["nosuchmtopodule", "no.such.module"]:
            default = object()
            result = reflect.requireModule(name, default=default)
            self.assertIs(result, default)

    def test_requireModuleDefaultNone(self):
        """
        When module import fails it returns L{None} by default.
        """
        result = reflect.requireModule("no.such.module")
        self.assertIsNone(result)

    def test_requireModuleRequestedImport(self):
        """
        When module import succeed it returns the module and not the default
        value.
        """
        from twisted.python import monkey

        default = object()
        self.assertIs(
            reflect.requireModule("twisted.python.monkey", default=default),
            monkey,
        )
class Breakable:
    """
    A class whose string conversions can be made to fail on demand.

    Tests flip the class-level flags on an instance (or subclass) to make
    the corresponding conversion raise L{RuntimeError}.
    """

    breakRepr = False
    breakStr = False

    def __str__(self) -> str:
        if self.breakStr:
            raise RuntimeError("str!")
        return "<Breakable>"

    def __repr__(self) -> str:
        if self.breakRepr:
            raise RuntimeError("repr!")
        return "Breakable()"
class BrokenType(Breakable, type):
    # Metaclass whose __name__ access can be made to raise, for exercising
    # the safe_repr/safe_str error paths that read a class's name.
    breakName = False

    @property
    def __name__(self):
        if self.breakName:
            raise RuntimeError("no name")
        return "BrokenType"


# A Breakable subclass built via the BrokenType metaclass, with str and repr
# broken from the start.
BTBase = BrokenType("BTBase", (Breakable,), {"breakRepr": True, "breakStr": True})
class NoClassAttr(Breakable):
    # Reading __class__ evaluates the property, which looks up the never-set
    # ``not_class`` attribute and therefore raises AttributeError; this forces
    # safe_repr/safe_str to fall back to type() (see the tests below).
    __class__ = property(lambda x: x.not_class)  # type: ignore[assignment]
class SafeReprTests(TestCase):
    """
    Tests for L{reflect.safe_repr} function.
    """

    def test_workingRepr(self):
        """
        L{reflect.safe_repr} produces the same output as C{repr} on a working
        object.
        """
        xs = ([1, 2, 3], b"a")
        self.assertEqual(list(map(reflect.safe_repr, xs)), list(map(repr, xs)))

    def test_brokenRepr(self):
        """
        L{reflect.safe_repr} returns a string with class name, address, and
        traceback when the repr call failed.
        """
        b = Breakable()
        b.breakRepr = True
        bRepr = reflect.safe_repr(b)
        self.assertIn("Breakable instance at 0x", bRepr)
        # Check that the file is in the repr, but without the extension as it
        # can be .py/.pyc
        self.assertIn(os.path.splitext(__file__)[0], bRepr)
        self.assertIn("RuntimeError: repr!", bRepr)

    def test_brokenStr(self):
        """
        L{reflect.safe_repr} isn't affected by a broken C{__str__} method.
        """
        b = Breakable()
        b.breakStr = True
        self.assertEqual(reflect.safe_repr(b), repr(b))

    def test_brokenClassRepr(self):
        """
        L{reflect.safe_repr} does not raise even when both the class and its
        instances have broken C{__repr__}.
        """

        class X(BTBase):
            breakRepr = True

        reflect.safe_repr(X)
        reflect.safe_repr(X())

    def test_brokenReprIncludesID(self):
        """
        C{id} is used to print the ID of the object in case of an error.

        L{safe_repr} includes a traceback after a newline, so we only check
        against the first line of the repr.
        """

        class X(BTBase):
            breakRepr = True

        xRepr = reflect.safe_repr(X)
        xReprExpected = f"<BrokenType instance at 0x{id(X):x} with repr error:"
        self.assertEqual(xReprExpected, xRepr.split("\n")[0])

    def test_brokenClassStr(self):
        """
        L{reflect.safe_repr} does not raise even when both the class and its
        instances have broken C{__str__}.
        """

        class X(BTBase):
            breakStr = True

        reflect.safe_repr(X)
        reflect.safe_repr(X())

    def test_brokenClassAttribute(self):
        """
        If an object raises an exception when accessing its C{__class__}
        attribute, L{reflect.safe_repr} uses C{type} to retrieve the class
        object.
        """
        b = NoClassAttr()
        b.breakRepr = True
        bRepr = reflect.safe_repr(b)
        self.assertIn("NoClassAttr instance at 0x", bRepr)
        self.assertIn(os.path.splitext(__file__)[0], bRepr)
        self.assertIn("RuntimeError: repr!", bRepr)

    def test_brokenClassNameAttribute(self):
        """
        If a class raises an exception when accessing its C{__name__} attribute
        B{and} when calling its C{__str__} implementation, L{reflect.safe_repr}
        returns 'BROKEN CLASS' instead of the class name.
        """

        class X(BTBase):
            breakName = True

        xRepr = reflect.safe_repr(X())
        self.assertIn("<BROKEN CLASS AT 0x", xRepr)
        self.assertIn(os.path.splitext(__file__)[0], xRepr)
        self.assertIn("RuntimeError: repr!", xRepr)
class SafeStrTests(TestCase):
    """
    Tests for L{reflect.safe_str} function.
    """

    def test_workingStr(self):
        """
        L{reflect.safe_str} produces the same output as C{str} for an
        ordinary, working object.
        """
        x = [1, 2, 3]
        self.assertEqual(reflect.safe_str(x), str(x))

    def test_brokenStr(self):
        """
        L{reflect.safe_str} does not raise for an object whose C{__str__}
        raises.
        """
        b = Breakable()
        b.breakStr = True
        reflect.safe_str(b)

    def test_workingAscii(self):
        """
        L{safe_str} for C{str} with ascii-only data should return the
        value unchanged.
        """
        x = "a"
        self.assertEqual(reflect.safe_str(x), "a")

    def test_workingUtf8_3(self):
        """
        L{safe_str} for C{bytes} with utf-8 encoded data should return
        the value decoded into C{str}.
        """
        x = b"t\xc3\xbcst"
        self.assertEqual(reflect.safe_str(x), x.decode("utf-8"))

    def test_brokenUtf8(self):
        """
        Use str() for non-utf8 bytes: "b'non-utf8'"
        """
        x = b"\xff"
        xStr = reflect.safe_str(x)
        # Undecodable bytes fall back to the bytes repr, e.g. "b'\\xff'".
        self.assertEqual(xStr, str(x))

    def test_brokenRepr(self):
        """
        L{reflect.safe_str} does not raise for an object whose C{__repr__}
        raises.
        """
        b = Breakable()
        b.breakRepr = True
        reflect.safe_str(b)

    def test_brokenClassStr(self):
        """
        L{reflect.safe_str} does not raise when the class itself, or one of
        its instances, has a broken C{__str__}.
        """
        class X(BTBase):
            breakStr = True

        reflect.safe_str(X)
        reflect.safe_str(X())

    def test_brokenClassRepr(self):
        """
        L{reflect.safe_str} does not raise when the class itself, or one of
        its instances, has a broken C{__repr__}.
        """
        class X(BTBase):
            breakRepr = True

        reflect.safe_str(X)
        reflect.safe_str(X())

    def test_brokenClassAttribute(self):
        """
        If an object raises an exception when accessing its C{__class__}
        attribute, L{reflect.safe_str} uses C{type} to retrieve the class
        object.
        """
        b = NoClassAttr()
        b.breakStr = True
        bStr = reflect.safe_str(b)
        self.assertIn("NoClassAttr instance at 0x", bStr)
        # Extension stripped: the module may be loaded from .py or .pyc.
        self.assertIn(os.path.splitext(__file__)[0], bStr)
        self.assertIn("RuntimeError: str!", bStr)

    def test_brokenClassNameAttribute(self):
        """
        If a class raises an exception when accessing its C{__name__} attribute
        B{and} when calling its C{__str__} implementation, L{reflect.safe_str}
        returns 'BROKEN CLASS' instead of the class name.
        """
        class X(BTBase):
            breakName = True

        xStr = reflect.safe_str(X())
        self.assertIn("<BROKEN CLASS AT 0x", xStr)
        self.assertIn(os.path.splitext(__file__)[0], xStr)
        self.assertIn("RuntimeError: str!", xStr)
class FilenameToModuleTests(TestCase):
    """
    Test L{filenameToModuleName} detection.
    """

    def setUp(self):
        # Build fakepackage/test/ with an __init__.py at each level so the
        # path is recognized as a package hierarchy.
        self.path = os.path.join(self.mktemp(), "fakepackage", "test")
        os.makedirs(self.path)
        for packageDir in (self.path, os.path.dirname(self.path)):
            with open(os.path.join(packageDir, "__init__.py"), "w") as f:
                f.write("")

    def test_directory(self):
        """
        L{filenameToModuleName} returns the correct module (a package) given a
        directory, with or without a trailing separator.
        """
        for candidate in (self.path, self.path + os.path.sep):
            self.assertEqual(
                reflect.filenameToModuleName(candidate), "fakepackage.test"
            )

    def test_file(self):
        """
        L{filenameToModuleName} returns the correct module given the path to
        its file.
        """
        filePath = os.path.join(self.path, "test_reflect.py")
        self.assertEqual(
            reflect.filenameToModuleName(filePath), "fakepackage.test.test_reflect"
        )

    def test_bytes(self):
        """
        L{filenameToModuleName} returns the correct module given a C{bytes}
        path to its file.
        """
        bytesPath = os.path.join(self.path.encode("utf-8"), b"test_reflect.py")
        # Module names are always native string:
        self.assertEqual(
            reflect.filenameToModuleName(bytesPath), "fakepackage.test.test_reflect"
        )
class FullyQualifiedNameTests(TestCase):
    """
    Test for L{fullyQualifiedName}.
    """

    def _checkFullyQualifiedName(self, obj, expected):
        """
        Helper to check that fully qualified name of C{obj} results to
        C{expected}.

        @param obj: any object (module, class, function, method).
        @param expected: the dotted name C{fullyQualifiedName} must return.
        """
        self.assertEqual(fullyQualifiedName(obj), expected)

    def test_package(self):
        """
        L{fullyQualifiedName} returns the full name of a package and a
        subpackage.
        """
        import twisted

        self._checkFullyQualifiedName(twisted, "twisted")
        import twisted.python

        self._checkFullyQualifiedName(twisted.python, "twisted.python")

    def test_module(self):
        """
        L{fullyQualifiedName} returns the name of a module inside a package.
        """
        import twisted.python.compat

        self._checkFullyQualifiedName(twisted.python.compat, "twisted.python.compat")

    def test_class(self):
        """
        L{fullyQualifiedName} returns the name of a class and its module.
        """
        self._checkFullyQualifiedName(
            FullyQualifiedNameTests, f"{__name__}.FullyQualifiedNameTests"
        )

    def test_function(self):
        """
        L{fullyQualifiedName} returns the name of a function inside its module.
        """
        self._checkFullyQualifiedName(
            fullyQualifiedName, "twisted.python.reflect.fullyQualifiedName"
        )

    def test_boundMethod(self):
        """
        L{fullyQualifiedName} returns the name of a bound method inside its
        class and its module.
        """
        self._checkFullyQualifiedName(
            self.test_boundMethod,
            f"{__name__}.{self.__class__.__name__}.test_boundMethod",
        )

    def test_unboundMethod(self):
        """
        L{fullyQualifiedName} returns the name of an unbound method inside its
        class and its module.
        """
        self._checkFullyQualifiedName(
            self.__class__.test_unboundMethod,
            f"{__name__}.{self.__class__.__name__}.test_unboundMethod",
        )
class ObjectGrepTests(TestCase):
    """
    Tests for L{reflect.objgrep}, which searches an object graph for
    references to a target and returns path strings describing each match.
    """

    def test_dictionary(self):
        """
        Test references search through a dictionary, as a key or as a value.
        """
        o = object()
        d1 = {None: o}
        d2 = {o: None}
        # "[None]" denotes a value found under key None; "{None}" denotes a
        # key whose associated value is None.
        self.assertIn("[None]", reflect.objgrep(d1, o, reflect.isSame))
        self.assertIn("{None}", reflect.objgrep(d2, o, reflect.isSame))

    def test_list(self):
        """
        Test references search through a list.
        """
        o = object()
        L = [None, o]
        self.assertIn("[1]", reflect.objgrep(L, o, reflect.isSame))

    def test_tuple(self):
        """
        Test references search through a tuple.
        """
        o = object()
        T = (o, None)
        self.assertIn("[0]", reflect.objgrep(T, o, reflect.isSame))

    def test_instance(self):
        """
        Test references search through an object attribute.
        """
        class Dummy:
            pass

        o = object()
        d = Dummy()
        d.o = o
        self.assertIn(".o", reflect.objgrep(d, o, reflect.isSame))

    def test_weakref(self):
        """
        Test references search through a weakref object.
        """
        class Dummy:
            pass

        o = Dummy()
        w1 = weakref.ref(o)
        # "()" denotes dereferencing the weakref.
        self.assertIn("()", reflect.objgrep(w1, o, reflect.isSame))

    def test_boundMethod(self):
        """
        Test references search through method special attributes.
        """
        class Dummy:
            def dummy(self):
                pass

        o = Dummy()
        m = o.dummy
        self.assertIn(".__self__", reflect.objgrep(m, m.__self__, reflect.isSame))
        self.assertIn(
            ".__self__.__class__",
            reflect.objgrep(m, m.__self__.__class__, reflect.isSame),
        )
        self.assertIn(".__func__", reflect.objgrep(m, m.__func__, reflect.isSame))

    def test_everything(self):
        """
        Test references search using complex set of objects.
        """
        class Dummy:
            def method(self):
                pass

        o = Dummy()
        D1 = {(): "baz", None: "Quux", o: "Foosh"}
        L = [None, (), D1, 3]
        T = (L, {}, Dummy())
        D2 = {0: "foo", 1: "bar", 2: T}
        i = Dummy()
        i.attr = D2
        m = i.method
        w = weakref.ref(m)
        # Path: deref weakref -> bound method's instance -> .attr dict value
        # at key 2 -> tuple index 0 -> list index 2 -> dict key whose value
        # is 'Foosh'.
        self.assertIn(
            "().__self__.attr[2][0][2]{'Foosh'}", reflect.objgrep(w, o, reflect.isSame)
        )

    def test_depthLimit(self):
        """
        Test the depth of references search.
        """
        a = []
        b = [a]
        c = [a, b]
        d = [a, c]
        # Each extra level of maxDepth exposes one more nested path to `a`.
        self.assertEqual(["[0]"], reflect.objgrep(d, a, reflect.isSame, maxDepth=1))
        self.assertEqual(
            ["[0]", "[1][0]"], reflect.objgrep(d, a, reflect.isSame, maxDepth=2)
        )
        self.assertEqual(
            ["[0]", "[1][0]", "[1][1][0]"],
            reflect.objgrep(d, a, reflect.isSame, maxDepth=3),
        )

    def test_deque(self):
        """
        Test references search through a deque object.
        """
        o = object()
        D = deque()
        D.append(None)
        D.append(o)
        self.assertIn("[1]", reflect.objgrep(D, o, reflect.isSame))
class GetClassTests(TestCase):
    """
    Tests for L{reflect.getClass}.
    """

    def test_new(self):
        """
        L{reflect.getClass} returns the metaclass for a class object and the
        class for an instance.
        """
        class NewClass:
            pass

        instance = NewClass()
        self.assertEqual("type", reflect.getClass(NewClass).__name__)
        self.assertEqual("NewClass", reflect.getClass(instance).__name__)

View File

@@ -0,0 +1,58 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.python import roots
from twisted.trial import unittest
class RootsTests(unittest.TestCase):
    """
    Tests for the abstract entity-hierarchy helpers in
    L{twisted.python.roots}.
    """

    def testExceptions(self) -> None:
        """
        L{roots.Request.write} and L{roots.Request.finish} are abstract and
        raise L{NotImplementedError}.
        """
        # assertRaises replaces the original try/except/else/self.fail()
        # pyramids; the checked behavior is identical.
        request = roots.Request()
        self.assertRaises(NotImplementedError, request.write, b"blah")
        self.assertRaises(NotImplementedError, request.finish)

    def testCollection(self) -> None:
        """
        L{roots.Collection} supports putting, getting and deleting static
        entities, while C{storeEntity} and C{removeEntity} are unimplemented.
        """
        collection = roots.Collection()
        collection.putEntity("x", "test")
        self.assertEqual(collection.getStaticEntity("x"), "test")
        collection.delEntity("x")
        self.assertIsNone(collection.getStaticEntity("x"))
        self.assertRaises(NotImplementedError, collection.storeEntity, "x", None)
        self.assertRaises(NotImplementedError, collection.removeEntity, "x", None)

    def testConstrained(self) -> None:
        """
        L{roots.Constrained.putEntity} accepts names passing
        C{nameConstraint} and raises L{roots.ConstraintViolation} otherwise.
        """
        class const(roots.Constrained):
            def nameConstraint(self, name: str) -> bool:
                return name == "x"

        c = const()
        self.assertIsNone(c.putEntity("x", "test"))
        self.assertRaises(roots.ConstraintViolation, c.putEntity, "y", "test")

    def testHomogenous(self) -> None:
        """
        L{roots.Homogenous} only admits entities of its declared
        C{entityType}.
        """
        h = roots.Homogenous()
        h.entityType = int
        h.putEntity("a", 1)
        self.assertEqual(h.getStaticEntity("a"), 1)
        self.assertRaises(roots.ConstraintViolation, h.putEntity, "x", "y")

View File

@@ -0,0 +1,64 @@
"""
Test win32 shortcut script
"""
import os.path
import sys
import tempfile
from twisted.trial import unittest
skipReason = None
try:
from win32com.shell import shell
from twisted.python import shortcut
except ImportError:
skipReason = "Only runs on Windows with win32com"
if sys.version_info[0:2] >= (3, 7):
skipReason = "Broken on Python 3.7+."
class ShortcutTests(unittest.TestCase):
    # Set by module-level import guard: non-None skips the whole class on
    # non-Windows platforms or Python >= 3.7.
    skip = skipReason

    def test_create(self) -> None:
        """
        Create a simple shortcut.
        """
        testFilename = __file__
        baseFileName = os.path.basename(testFilename)
        s1 = shortcut.Shortcut(testFilename)
        tempname = self.mktemp() + ".lnk"
        s1.save(tempname)
        self.assertTrue(os.path.exists(tempname))
        sc = shortcut.open(tempname)
        # GetPath returns (path, flags); compare case-insensitively since
        # Windows filesystems are case-insensitive.
        scPath = sc.GetPath(shell.SLGP_RAWPATH)[0]
        self.assertEqual(scPath[-len(baseFileName) :].lower(), baseFileName.lower())

    def test_createPythonShortcut(self) -> None:
        """
        Create a shortcut to the Python executable,
        and set some values.
        """
        testFilename = sys.executable
        baseFileName = os.path.basename(testFilename)
        tempDir = tempfile.gettempdir()
        s1 = shortcut.Shortcut(
            path=testFilename,
            arguments="-V",
            description="The Python executable",
            workingdir=tempDir,
            iconpath=tempDir,
            iconidx=1,
        )
        tempname = self.mktemp() + ".lnk"
        s1.save(tempname)
        self.assertTrue(os.path.exists(tempname))
        sc = shortcut.open(tempname)
        scPath = sc.GetPath(shell.SLGP_RAWPATH)[0]
        self.assertEqual(scPath[-len(baseFileName) :].lower(), baseFileName.lower())
        # Round-trip of the optional metadata set above.
        self.assertEqual(sc.GetDescription(), "The Python executable")
        self.assertEqual(sc.GetWorkingDirectory(), tempDir)
        self.assertEqual(sc.GetIconLocation(), (tempDir, 1))

View File

@@ -0,0 +1,780 @@
# -*- test-case-name: twisted.test.test_sip -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Session Initialization Protocol tests.
"""
from twisted.cred import checkers, portal
from twisted.internet import defer, reactor
from twisted.protocols import sip
from twisted.trial import unittest
try:
from twisted.internet.asyncioreactor import AsyncioSelectorReactor
except BaseException:
AsyncioSelectorReactor = None # type: ignore[assignment,misc]
from zope.interface import implementer
# SIP message fixtures.  Every message must contain a blank line between the
# headers and the body (that is what Content-Length is measured against);
# those separator lines are restored below.

# request, prefixed by random CRLFs
request1 = (
    "\n\r\n\n\r"
    + """\
INVITE sip:foo SIP/2.0
From: mo
To: joe
Content-Length: 4

abcd""".replace(
        "\n", "\r\n"
    )
)

# request, no content-length
request2 = """INVITE sip:foo SIP/2.0
From: mo
To: joe

1234""".replace(
    "\n", "\r\n"
)

# request, with garbage after
request3 = """INVITE sip:foo SIP/2.0
From: mo
To: joe
Content-Length: 4

1234

lalalal""".replace(
    "\n", "\r\n"
)

# three requests
request4 = """INVITE sip:foo SIP/2.0
From: mo
To: joe
Content-Length: 0

INVITE sip:loop SIP/2.0
From: foo
To: bar
Content-Length: 4

abcdINVITE sip:loop SIP/2.0
From: foo
To: bar
Content-Length: 4

1234""".replace(
    "\n", "\r\n"
)

# response, no content
response1 = """SIP/2.0 200 OK
From: foo
To:bar
Content-Length: 0

""".replace(
    "\n", "\r\n"
)

# short header version
request_short = """\
INVITE sip:foo SIP/2.0
f: mo
t: joe
l: 4

abcd""".replace(
    "\n", "\r\n"
)

request_natted = """\
INVITE sip:foo SIP/2.0
Via: SIP/2.0/UDP 10.0.0.1:5060;rport

""".replace(
    "\n", "\r\n"
)

# multiline headers (example from RFC 3261): continuation lines MUST begin
# with whitespace.
response_multiline = """\
SIP/2.0 200 OK
Via: SIP/2.0/UDP server10.biloxi.com
    ;branch=z9hG4bKnashds8;received=192.0.2.3
Via: SIP/2.0/UDP bigbox3.site3.atlanta.com
    ;branch=z9hG4bK77ef4c2312983.1;received=192.0.2.2
Via: SIP/2.0/UDP pc33.atlanta.com
    ;branch=z9hG4bK776asdhds ;received=192.0.2.1
To: Bob <sip:bob@biloxi.com>;tag=a6c85cf
From: Alice <sip:alice@atlanta.com>;tag=1928301774
Call-ID: a84b4c76e66710@pc33.atlanta.com
CSeq: 314159 INVITE
Contact: <sip:bob@192.0.2.4>
Content-Type: application/sdp
Content-Length: 0
\n""".replace(
    "\n", "\r\n"
)
class TestRealm:
    """
    A minimal cred realm which hands out an anonymous L{sip.IContact} avatar.
    """

    def requestAvatar(self, avatarId, mind, *interfaces):
        # The third element is the logout callable; there is no state to
        # clean up, so it is a no-op.
        return sip.IContact, None, lambda: None
class MessageParsingTests(unittest.TestCase):
    """
    Tests for L{sip.MessagesParser}: messages are fed in whole and collected
    via a callback into C{self.l}.
    """

    def setUp(self):
        self.l = []
        self.parser = sip.MessagesParser(self.l.append)

    def feedMessage(self, message):
        """
        Deliver C{message} to the parser in a single chunk, then signal that
        no more data will arrive.
        """
        self.parser.dataReceived(message)
        self.parser.dataDone()

    def validateMessage(self, m, method, uri, headers, body):
        """
        Validate Requests.
        """
        self.assertEqual(m.method, method)
        self.assertEqual(m.uri.toString(), uri)
        self.assertEqual(m.headers, headers)
        self.assertEqual(m.body, body)
        self.assertEqual(m.finished, 1)

    def testSimple(self):
        """
        A single request with a Content-Length header parses completely.
        """
        l = self.l
        self.feedMessage(request1)
        self.assertEqual(len(l), 1)
        self.validateMessage(
            l[0],
            "INVITE",
            "sip:foo",
            {"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
            "abcd",
        )

    def testTwoMessages(self):
        """
        Two messages fed separately produce two parsed messages; a request
        without Content-Length takes the rest of the data as its body.
        """
        l = self.l
        self.feedMessage(request1)
        self.feedMessage(request2)
        self.assertEqual(len(l), 2)
        self.validateMessage(
            l[0],
            "INVITE",
            "sip:foo",
            {"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
            "abcd",
        )
        self.validateMessage(
            l[1], "INVITE", "sip:foo", {"from": ["mo"], "to": ["joe"]}, "1234"
        )

    def testGarbage(self):
        """
        Data past the declared Content-Length is not part of the body.
        """
        l = self.l
        self.feedMessage(request3)
        self.assertEqual(len(l), 1)
        self.validateMessage(
            l[0],
            "INVITE",
            "sip:foo",
            {"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
            "1234",
        )

    def testThreeInOne(self):
        """
        Three back-to-back requests in one chunk are split correctly using
        their Content-Length headers.
        """
        l = self.l
        self.feedMessage(request4)
        self.assertEqual(len(l), 3)
        self.validateMessage(
            l[0],
            "INVITE",
            "sip:foo",
            {"from": ["mo"], "to": ["joe"], "content-length": ["0"]},
            "",
        )
        self.validateMessage(
            l[1],
            "INVITE",
            "sip:loop",
            {"from": ["foo"], "to": ["bar"], "content-length": ["4"]},
            "abcd",
        )
        self.validateMessage(
            l[2],
            "INVITE",
            "sip:loop",
            {"from": ["foo"], "to": ["bar"], "content-length": ["4"]},
            "1234",
        )

    def testShort(self):
        """
        Compact header names (f/t/l) are expanded to their long forms.
        """
        l = self.l
        self.feedMessage(request_short)
        self.assertEqual(len(l), 1)
        self.validateMessage(
            l[0],
            "INVITE",
            "sip:foo",
            {"from": ["mo"], "to": ["joe"], "content-length": ["4"]},
            "abcd",
        )

    def testSimpleResponse(self):
        """
        A response parses its status line, headers and (empty) body.
        """
        l = self.l
        self.feedMessage(response1)
        self.assertEqual(len(l), 1)
        m = l[0]
        self.assertEqual(m.code, 200)
        self.assertEqual(m.phrase, "OK")
        self.assertEqual(
            m.headers, {"from": ["foo"], "to": ["bar"], "content-length": ["0"]}
        )
        self.assertEqual(m.body, "")
        self.assertEqual(m.finished, 1)

    def test_multiLine(self):
        """
        A header may be split across multiple lines. Subsequent lines begin
        with C{" "} or C{"\\t"}.
        """
        l = self.l
        self.feedMessage(response_multiline)
        self.assertEqual(len(l), 1)
        m = l[0]
        self.assertEqual(
            m.headers["via"][0],
            "SIP/2.0/UDP server10.biloxi.com;"
            "branch=z9hG4bKnashds8;received=192.0.2.3",
        )
        self.assertEqual(
            m.headers["via"][1],
            "SIP/2.0/UDP bigbox3.site3.atlanta.com;"
            "branch=z9hG4bK77ef4c2312983.1;received=192.0.2.2",
        )
        self.assertEqual(
            m.headers["via"][2],
            "SIP/2.0/UDP pc33.atlanta.com;"
            "branch=z9hG4bK776asdhds ;received=192.0.2.1",
        )
class MessageParsingFeedDataCharByCharTests(MessageParsingTests):
    """
    Same as base class, but feed data char by char.
    """

    def feedMessage(self, message):
        # One character per dataReceived call exercises the parser's
        # incremental buffering across chunk boundaries.
        for index in range(len(message)):
            self.parser.dataReceived(message[index : index + 1])
        self.parser.dataDone()
class MakeMessageTests(unittest.TestCase):
    """
    Tests serializing L{sip.Request} and L{sip.Response} to wire format.
    """

    def testRequest(self):
        """
        A request serializes its request line and capitalized headers,
        terminated by an empty line.
        """
        r = sip.Request("INVITE", "sip:foo")
        r.addHeader("foo", "bar")
        self.assertEqual(r.toString(), "INVITE sip:foo SIP/2.0\r\nFoo: bar\r\n\r\n")

    def testResponse(self):
        """
        A response serializes its status line, headers and body.
        """
        r = sip.Response(200, "OK")
        r.addHeader("foo", "bar")
        r.addHeader("Content-Length", "4")
        r.bodyDataReceived("1234")
        self.assertEqual(
            r.toString(), "SIP/2.0 200 OK\r\nFoo: bar\r\nContent-Length: 4\r\n\r\n1234"
        )

    def testStatusCode(self):
        """
        A response built from just a code derives its phrase ("OK" for 200).
        """
        r = sip.Response(200)
        self.assertEqual(r.toString(), "SIP/2.0 200 OK\r\n\r\n")
class ViaTests(unittest.TestCase):
    """
    Tests for parsing and serializing SIP Via headers.
    """

    def checkRoundtrip(self, v):
        """
        Serializing C{v} and re-parsing the result must serialize back to
        the same string.
        """
        s = v.toString()
        self.assertEqual(s, sip.parseViaHeader(s).toString())

    def testExtraWhitespace(self):
        """
        Extra whitespace between the protocol and the host does not change
        the parse result.
        """
        v1 = sip.parseViaHeader("SIP/2.0/UDP 192.168.1.1:5060")
        # The second header deliberately carries extra spaces; comparing the
        # two identical strings would not test whitespace tolerance at all.
        v2 = sip.parseViaHeader("SIP/2.0/UDP     192.168.1.1:5060")
        self.assertEqual(v1.transport, v2.transport)
        self.assertEqual(v1.host, v2.host)
        self.assertEqual(v1.port, v2.port)

    def test_complex(self):
        """
        Test parsing a Via header with one of everything.
        """
        s = (
            "SIP/2.0/UDP first.example.com:4000;ttl=16;maddr=224.2.0.1"
            " ;branch=a7c6a8dlze (Example)"
        )
        v = sip.parseViaHeader(s)
        self.assertEqual(v.transport, "UDP")
        self.assertEqual(v.host, "first.example.com")
        self.assertEqual(v.port, 4000)
        self.assertIsNone(v.rport)
        self.assertIsNone(v.rportValue)
        self.assertFalse(v.rportRequested)
        self.assertEqual(v.ttl, 16)
        self.assertEqual(v.maddr, "224.2.0.1")
        self.assertEqual(v.branch, "a7c6a8dlze")
        self.assertEqual(v.hidden, 0)
        self.assertEqual(
            v.toString(),
            "SIP/2.0/UDP first.example.com:4000"
            ";ttl=16;branch=a7c6a8dlze;maddr=224.2.0.1",
        )
        self.checkRoundtrip(v)

    def test_simple(self):
        """
        Test parsing a simple Via header.
        """
        s = "SIP/2.0/UDP example.com;hidden"
        v = sip.parseViaHeader(s)
        self.assertEqual(v.transport, "UDP")
        self.assertEqual(v.host, "example.com")
        # 5060 is the default SIP port when none is given.
        self.assertEqual(v.port, 5060)
        self.assertIsNone(v.rport)
        self.assertIsNone(v.rportValue)
        self.assertFalse(v.rportRequested)
        self.assertIsNone(v.ttl)
        self.assertIsNone(v.maddr)
        self.assertIsNone(v.branch)
        self.assertTrue(v.hidden)
        self.assertEqual(v.toString(), "SIP/2.0/UDP example.com:5060;hidden")
        self.checkRoundtrip(v)

    def testSimpler(self):
        """
        A Via built from only a host round-trips.
        """
        v = sip.Via("example.com")
        self.checkRoundtrip(v)

    def test_deprecatedRPort(self):
        """
        Setting rport to True is deprecated, but still produces a Via header
        with the expected properties.
        """
        v = sip.Via("foo.bar", rport=True)
        warnings = self.flushWarnings(offendingFunctions=[self.test_deprecatedRPort])
        self.assertEqual(len(warnings), 1)
        self.assertEqual(
            warnings[0]["message"], "rport=True is deprecated since Twisted 9.0."
        )
        self.assertEqual(warnings[0]["category"], DeprecationWarning)
        self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport")
        self.assertTrue(v.rport)
        self.assertTrue(v.rportRequested)
        self.assertIsNone(v.rportValue)

    def test_rport(self):
        """
        An rport setting of None should insert the parameter with no value.
        """
        v = sip.Via("foo.bar", rport=None)
        self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport")
        self.assertTrue(v.rportRequested)
        self.assertIsNone(v.rportValue)

    def test_rportValue(self):
        """
        An rport numeric setting should insert the parameter with the number
        value given.
        """
        v = sip.Via("foo.bar", rport=1)
        self.assertEqual(v.toString(), "SIP/2.0/UDP foo.bar:5060;rport=1")
        self.assertFalse(v.rportRequested)
        self.assertEqual(v.rportValue, 1)
        self.assertEqual(v.rport, 1)

    def testNAT(self):
        """
        received and rport parameters from a NATed client are parsed and
        preserved on re-serialization.
        """
        s = "SIP/2.0/UDP 10.0.0.1:5060;received=22.13.1.5;rport=12345"
        v = sip.parseViaHeader(s)
        self.assertEqual(v.transport, "UDP")
        self.assertEqual(v.host, "10.0.0.1")
        self.assertEqual(v.port, 5060)
        self.assertEqual(v.received, "22.13.1.5")
        self.assertEqual(v.rport, 12345)
        self.assertNotEqual(v.toString().find("rport=12345"), -1)

    def test_unknownParams(self):
        """
        Parsing and serializing Via headers with unknown parameters should work.
        """
        s = "SIP/2.0/UDP example.com:5060;branch=a12345b;bogus;pie=delicious"
        v = sip.parseViaHeader(s)
        self.assertEqual(v.toString(), s)
class URLTests(unittest.TestCase):
    """
    Tests for parsing and serializing SIP URLs.
    """

    def testRoundtrip(self):
        """
        Representative URLs serialize back to their input form.
        """
        for url in [
            "sip:j.doe@big.com",
            "sip:j.doe:secret@big.com;transport=tcp",
            "sip:j.doe@big.com?subject=project",
            "sip:example.com",
        ]:
            self.assertEqual(sip.parseURL(url).toString(), url)

    def testComplex(self):
        """
        Every component of a maximal URL is exposed as an attribute of the
        parsed result.
        """
        s = (
            "sip:user:pass@hosta:123;transport=udp;user=phone;method=foo;"
            "ttl=12;maddr=1.2.3.4;blah;goo=bar?a=b&c=d"
        )
        url = sip.parseURL(s)
        for k, v in [
            ("username", "user"),
            ("password", "pass"),
            ("host", "hosta"),
            ("port", 123),
            ("transport", "udp"),
            ("usertype", "phone"),
            ("method", "foo"),
            ("ttl", 12),
            ("maddr", "1.2.3.4"),
            ("other", ["blah", "goo=bar"]),
            ("headers", {"a": "b", "c": "d"}),
        ]:
            self.assertEqual(getattr(url, k), v)
class ParseTests(unittest.TestCase):
    """
    Tests for L{sip.parseAddress}.
    """

    def testParseAddress(self):
        """
        parseAddress splits an address into display name, URL and parameters
        for quoted, unquoted, bare and bracketed forms.
        """
        for address, name, urls, params in [
            (
                '"A. G. Bell" <sip:foo@example.com>',
                "A. G. Bell",
                "sip:foo@example.com",
                {},
            ),
            ("Anon <sip:foo@example.com>", "Anon", "sip:foo@example.com", {}),
            ("sip:foo@example.com", "", "sip:foo@example.com", {}),
            ("<sip:foo@example.com>", "", "sip:foo@example.com", {}),
            (
                "foo <sip:foo@example.com>;tag=bar;foo=baz",
                "foo",
                "sip:foo@example.com",
                {"tag": "bar", "foo": "baz"},
            ),
        ]:
            gname, gurl, gparams = sip.parseAddress(address)
            self.assertEqual(name, gname)
            self.assertEqual(gurl.toString(), urls)
            self.assertEqual(gparams, params)
@implementer(sip.ILocator)
class DummyLocator:
    """
    A locator which resolves every logical URL to C{server.com:5060}.
    """

    def getAddress(self, logicalURL):
        return defer.succeed(sip.URL("server.com", port=5060))
@implementer(sip.ILocator)
class FailingLocator:
    """
    A locator which fails every lookup with L{LookupError}.
    """

    def getAddress(self, logicalURL):
        return defer.fail(LookupError())
class ProxyTests(unittest.TestCase):
    """
    Tests for L{sip.Proxy} request/response forwarding and Via handling.
    Outbound messages are intercepted into C{self.sent} instead of hitting
    the network.
    """

    def setUp(self):
        self.proxy = sip.Proxy("127.0.0.1")
        self.proxy.locator = DummyLocator()
        self.sent = []
        # Capture (destination, message) pairs rather than sending them.
        self.proxy.sendMessage = lambda dest, msg: self.sent.append((dest, msg))

    def testRequestForward(self):
        """
        A forwarded request gains the proxy's own Via at the front of the
        Via list and is sent to the locator's address.
        """
        r = sip.Request("INVITE", "sip:foo")
        r.addHeader("via", sip.Via("1.2.3.4").toString())
        r.addHeader("via", sip.Via("1.2.3.5").toString())
        r.addHeader("foo", "bar")
        r.addHeader("to", "<sip:joe@server.com>")
        r.addHeader("contact", "<sip:joe@1.2.3.5>")
        self.proxy.datagramReceived(r.toString(), ("1.2.3.4", 5060))
        self.assertEqual(len(self.sent), 1)
        dest, m = self.sent[0]
        self.assertEqual(dest.port, 5060)
        self.assertEqual(dest.host, "server.com")
        self.assertEqual(m.uri.toString(), "sip:foo")
        self.assertEqual(m.method, "INVITE")
        self.assertEqual(
            m.headers["via"],
            [
                "SIP/2.0/UDP 127.0.0.1:5060",
                "SIP/2.0/UDP 1.2.3.4:5060",
                "SIP/2.0/UDP 1.2.3.5:5060",
            ],
        )

    def testReceivedRequestForward(self):
        """
        When the sender's address differs from its Via host, a received=
        parameter recording the actual source is appended to that Via.
        """
        r = sip.Request("INVITE", "sip:foo")
        r.addHeader("via", sip.Via("1.2.3.4").toString())
        r.addHeader("foo", "bar")
        r.addHeader("to", "<sip:joe@server.com>")
        r.addHeader("contact", "<sip:joe@1.2.3.4>")
        self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
        dest, m = self.sent[0]
        self.assertEqual(
            m.headers["via"],
            ["SIP/2.0/UDP 127.0.0.1:5060", "SIP/2.0/UDP 1.2.3.4:5060;received=1.1.1.1"],
        )

    def testResponseWrongVia(self):
        # first via must match proxy's address
        r = sip.Response(200)
        r.addHeader("via", sip.Via("foo.com").toString())
        self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
        self.assertEqual(len(self.sent), 0)

    def testResponseForward(self):
        """
        A response is forwarded to the host/port named by the Via following
        the proxy's own, which is stripped.
        """
        r = sip.Response(200)
        r.addHeader("via", sip.Via("127.0.0.1").toString())
        r.addHeader("via", sip.Via("client.com", port=1234).toString())
        self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
        self.assertEqual(len(self.sent), 1)
        dest, m = self.sent[0]
        self.assertEqual((dest.host, dest.port), ("client.com", 1234))
        self.assertEqual(m.code, 200)
        self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:1234"])

    def testReceivedResponseForward(self):
        """
        A received= parameter on the next Via overrides its host as the
        forwarding destination.
        """
        r = sip.Response(200)
        r.addHeader("via", sip.Via("127.0.0.1").toString())
        r.addHeader("via", sip.Via("10.0.0.1", received="client.com").toString())
        self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
        self.assertEqual(len(self.sent), 1)
        dest, m = self.sent[0]
        self.assertEqual((dest.host, dest.port), ("client.com", 5060))

    def testResponseToUs(self):
        """
        A response whose only Via is the proxy's own is delivered to
        C{gotResponse} with that Via removed, not forwarded.
        """
        r = sip.Response(200)
        r.addHeader("via", sip.Via("127.0.0.1").toString())
        l = []
        self.proxy.gotResponse = lambda *a: l.append(a)
        self.proxy.datagramReceived(r.toString(), ("1.1.1.1", 5060))
        self.assertEqual(len(l), 1)
        m, addr = l[0]
        self.assertEqual(len(m.headers.get("via", [])), 0)
        self.assertEqual(m.code, 200)

    def testLoop(self):
        """
        A request whose Via list already contains the proxy is dropped to
        avoid a forwarding loop.
        """
        r = sip.Request("INVITE", "sip:foo")
        r.addHeader("via", sip.Via("1.2.3.4").toString())
        r.addHeader("via", sip.Via("127.0.0.1").toString())
        self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
        self.assertEqual(self.sent, [])

    def testCantForwardRequest(self):
        """
        If the locator fails, a 404 response is returned to the sender.
        """
        r = sip.Request("INVITE", "sip:foo")
        r.addHeader("via", sip.Via("1.2.3.4").toString())
        r.addHeader("to", "<sip:joe@server.com>")
        self.proxy.locator = FailingLocator()
        self.proxy.datagramReceived(r.toString(), ("1.2.3.4", 5060))
        self.assertEqual(len(self.sent), 1)
        dest, m = self.sent[0]
        self.assertEqual((dest.host, dest.port), ("1.2.3.4", 5060))
        self.assertEqual(m.code, 404)
        self.assertEqual(m.headers["via"], ["SIP/2.0/UDP 1.2.3.4:5060"])
class RegistrationTests(unittest.TestCase):
    """
    Tests for L{sip.RegisterProxy} REGISTER handling against an
    L{sip.InMemoryRegistry}.  Outbound messages are captured in C{self.sent}.
    """

    def setUp(self):
        self.proxy = sip.RegisterProxy(host="127.0.0.1")
        self.registry = sip.InMemoryRegistry("bell.example.com")
        self.proxy.registry = self.proxy.locator = self.registry
        self.sent = []
        self.proxy.sendMessage = lambda dest, msg: self.sent.append((dest, msg))

    def tearDown(self):
        # Cancel each registration's expiry DelayedCall so the reactor is
        # left clean.
        for d, uri in self.registry.users.values():
            d.cancel()
        del self.proxy

    def register(self):
        """
        Send a REGISTER for joe@bell.example.com through the proxy.
        """
        r = sip.Request("REGISTER", "sip:bell.example.com")
        r.addHeader("to", "sip:joe@bell.example.com")
        r.addHeader("contact", "sip:joe@client.com:1234")
        r.addHeader("via", sip.Via("client.com").toString())
        self.proxy.datagramReceived(r.toString(), ("client.com", 5060))

    def unregister(self):
        """
        Send a wildcard-contact REGISTER with Expires: 0, removing joe's
        registration.
        """
        r = sip.Request("REGISTER", "sip:bell.example.com")
        r.addHeader("to", "sip:joe@bell.example.com")
        r.addHeader("contact", "*")
        r.addHeader("via", sip.Via("client.com").toString())
        r.addHeader("expires", "0")
        self.proxy.datagramReceived(r.toString(), ("client.com", 5060))

    def testRegister(self):
        """
        A successful REGISTER is answered with 200 and recorded in the
        registry; the locator then resolves the registered user.
        """
        self.register()
        dest, m = self.sent[0]
        self.assertEqual((dest.host, dest.port), ("client.com", 5060))
        self.assertEqual(m.code, 200)
        self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:5060"])
        self.assertEqual(m.headers["to"], ["sip:joe@bell.example.com"])
        self.assertEqual(m.headers["contact"], ["sip:joe@client.com:5060"])
        #
        # XX: See http://tm.tl/8886
        #
        # NOTE(review): a type() comparison rather than isinstance(), since
        # AsyncioSelectorReactor may be None when its import failed above.
        if type(reactor) != AsyncioSelectorReactor:
            self.assertTrue(int(m.headers["expires"][0]) in (3600, 3601, 3599, 3598))
        self.assertEqual(len(self.registry.users), 1)
        dc, uri = self.registry.users["joe"]
        self.assertEqual(uri.toString(), "sip:joe@client.com:5060")
        d = self.proxy.locator.getAddress(
            sip.URL(username="joe", host="bell.example.com")
        )
        d.addCallback(lambda desturl: (desturl.host, desturl.port))
        d.addCallback(self.assertEqual, ("client.com", 5060))
        return d

    def testUnregister(self):
        """
        Unregistering produces a 200 with Expires: 0 and empties the
        registry.
        """
        self.register()
        self.unregister()
        dest, m = self.sent[1]
        self.assertEqual((dest.host, dest.port), ("client.com", 5060))
        self.assertEqual(m.code, 200)
        self.assertEqual(m.headers["via"], ["SIP/2.0/UDP client.com:5060"])
        self.assertEqual(m.headers["to"], ["sip:joe@bell.example.com"])
        self.assertEqual(m.headers["contact"], ["sip:joe@client.com:5060"])
        self.assertEqual(m.headers["expires"], ["0"])
        self.assertEqual(self.registry.users, {})

    def addPortal(self):
        """
        Attach a cred portal so that registration requires authentication.
        """
        r = TestRealm()
        p = portal.Portal(r)
        c = checkers.InMemoryUsernamePasswordDatabaseDontUse()
        c.addUser("userXname@127.0.0.1", "passXword")
        p.registerChecker(c)
        self.proxy.portal = p

    def testFailedAuthentication(self):
        """
        An unauthenticated REGISTER against a portal gets a 401 and no
        registry entry.
        """
        self.addPortal()
        self.register()
        self.assertEqual(len(self.registry.users), 0)
        self.assertEqual(len(self.sent), 1)
        dest, m = self.sent[0]
        self.assertEqual(m.code, 401)

    def testWrongDomainRegister(self):
        """
        A REGISTER for a domain the registry does not serve is ignored.
        """
        r = sip.Request("REGISTER", "sip:wrong.com")
        r.addHeader("to", "sip:joe@bell.example.com")
        r.addHeader("contact", "sip:joe@client.com:1234")
        r.addHeader("via", sip.Via("client.com").toString())
        self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
        self.assertEqual(len(self.sent), 0)

    def testWrongToDomainRegister(self):
        """
        A REGISTER whose To: domain is foreign is ignored.
        """
        r = sip.Request("REGISTER", "sip:bell.example.com")
        r.addHeader("to", "sip:joe@foo.com")
        r.addHeader("contact", "sip:joe@client.com:1234")
        r.addHeader("via", sip.Via("client.com").toString())
        self.proxy.datagramReceived(r.toString(), ("client.com", 5060))
        self.assertEqual(len(self.sent), 0)

    def testWrongDomainLookup(self):
        """
        Looking up a user in a foreign domain fails with L{LookupError}.
        """
        self.register()
        url = sip.URL(username="joe", host="foo.com")
        d = self.proxy.locator.getAddress(url)
        self.assertFailure(d, LookupError)
        return d

    def testNoContactLookup(self):
        """
        Looking up an unregistered user fails with L{LookupError}.
        """
        self.register()
        url = sip.URL(username="jane", host="bell.example.com")
        d = self.proxy.locator.getAddress(url)
        self.assertFailure(d, LookupError)
        return d
class Client(sip.Base):
    """
    A SIP endpoint which records received responses and fires a Deferred
    when the first one arrives.
    """

    def __init__(self):
        sip.Base.__init__(self)
        self.received = []
        self.deferred = defer.Deferred()

    def handle_response(self, response, addr):
        # NOTE(review): fires self.deferred unconditionally, so a second
        # response would raise AlreadyCalledError; the tests using this
        # client expect exactly one response.
        self.received.append(response)
        self.deferred.callback(self.received)
class LiveTests(unittest.TestCase):
    """
    End-to-end tests running a RegisterProxy and a Client over real UDP
    ports on the loopback interface.
    """

    def setUp(self):
        self.proxy = sip.RegisterProxy(host="127.0.0.1")
        self.registry = sip.InMemoryRegistry("bell.example.com")
        self.proxy.registry = self.proxy.locator = self.registry
        # Port 0: let the OS pick free ports for both ends.
        self.serverPort = reactor.listenUDP(0, self.proxy, interface="127.0.0.1")
        self.client = Client()
        self.clientPort = reactor.listenUDP(0, self.client, interface="127.0.0.1")
        self.serverAddress = (
            self.serverPort.getHost().host,
            self.serverPort.getHost().port,
        )

    def tearDown(self):
        # Cancel registration expiry timers, then close both UDP ports.
        for d, uri in self.registry.users.values():
            d.cancel()
        d1 = defer.maybeDeferred(self.clientPort.stopListening)
        d2 = defer.maybeDeferred(self.serverPort.stopListening)
        return defer.gatherResults([d1, d2])

    def testRegister(self):
        """
        A REGISTER sent over the wire is answered with a 200 response.
        """
        p = self.clientPort.getHost().port
        r = sip.Request("REGISTER", "sip:bell.example.com")
        r.addHeader("to", "sip:joe@bell.example.com")
        r.addHeader("contact", "sip:joe@127.0.0.1:%d" % p)
        r.addHeader("via", sip.Via("127.0.0.1", port=p).toString())
        self.client.sendMessage(
            sip.URL(host="127.0.0.1", port=self.serverAddress[1]), r
        )
        d = self.client.deferred

        def check(received):
            self.assertEqual(len(received), 1)
            r = received[0]
            self.assertEqual(r.code, 200)

        d.addCallback(check)
        return d

    def test_amoralRPort(self):
        """
        rport is allowed without a value, apparently because server
        implementors might be too stupid to check the received port
        against 5060 and see if they're equal, and because client
        implementors might be too stupid to bind to port 5060, or set a
        value on the rport parameter they send if they bind to another
        port.
        """
        p = self.clientPort.getHost().port
        r = sip.Request("REGISTER", "sip:bell.example.com")
        r.addHeader("to", "sip:joe@bell.example.com")
        r.addHeader("contact", "sip:joe@127.0.0.1:%d" % p)
        # rport=True triggers a deprecation warning but must still work.
        r.addHeader("via", sip.Via("127.0.0.1", port=p, rport=True).toString())
        warnings = self.flushWarnings(offendingFunctions=[self.test_amoralRPort])
        self.assertEqual(len(warnings), 1)
        self.assertEqual(
            warnings[0]["message"], "rport=True is deprecated since Twisted 9.0."
        )
        self.assertEqual(warnings[0]["category"], DeprecationWarning)
        self.client.sendMessage(
            sip.URL(host="127.0.0.1", port=self.serverAddress[1]), r
        )
        d = self.client.deferred

        def check(received):
            self.assertEqual(len(received), 1)
            r = received[0]
            self.assertEqual(r.code, 200)

        d.addCallback(check)
        return d

View File

@@ -0,0 +1,176 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import os
import sys
from textwrap import dedent
from twisted.persisted import sob
from twisted.persisted.styles import Ephemeral
from twisted.python import components
from twisted.trial import unittest
class Dummy(components.Componentized):
    """
    A componentized object onto which an L{sob.IPersistable} can be set.
    """

    pass
# Representative values — a scalar, a string, and each basic container —
# which must round-trip through every persistence style.
objects = [
    1,
    "hello",
    (1, "hello"),
    [1, "hello"],
    {1: "hello"},
]
class FakeModule:
    """
    A bare namespace standing in for a module object in tests.
    """

    pass
class PersistTests(unittest.TestCase):
    def testStyles(self):
        """
        Every sample object round-trips through both the 'source' and
        'pickle' persistence styles.
        """
        for o in objects:
            p = sob.Persistent(o, "")
            for style in "source pickle".split():
                p.setStyle(style)
                p.save(filename="persisttest." + style)
                o1 = sob.load("persisttest." + style, style)
                self.assertEqual(o, o1)
    def testStylesBeingSet(self):
        """
        A Persistent attached as a component retains its style across a
        save/load round-trip, along with the object's state.
        """
        o = Dummy()
        o.foo = 5
        o.setComponent(sob.IPersistable, sob.Persistent(o, "lala"))
        for style in "source pickle".split():
            sob.IPersistable(o).setStyle(style)
            sob.IPersistable(o).save(filename="lala." + style)
            o1 = sob.load("lala." + style, style)
            self.assertEqual(o.foo, o1.foo)
            self.assertEqual(sob.IPersistable(o1).style, style)
def testPassphraseError(self):
"""
Calling save() with a passphrase is an error.
"""
p = sob.Persistant(None, "object")
self.assertRaises(TypeError, p.save, "filename.pickle", passphrase="abc")
def testNames(self):
o = [1, 2, 3]
p = sob.Persistent(o, "object")
for style in "source pickle".split():
p.setStyle(style)
p.save()
o1 = sob.load("object.ta" + style[0], style)
self.assertEqual(o, o1)
for tag in "lala lolo".split():
p.save(tag)
o1 = sob.load("object-" + tag + ".ta" + style[0], style)
self.assertEqual(o, o1)
def testPython(self):
with open("persisttest.python", "w") as f:
f.write("foo=[1,2,3] ")
o = sob.loadValueFromFile("persisttest.python", "foo")
self.assertEqual(o, [1, 2, 3])
def testTypeGuesser(self):
self.assertRaises(KeyError, sob.guessType, "file.blah")
self.assertEqual("python", sob.guessType("file.py"))
self.assertEqual("python", sob.guessType("file.tac"))
self.assertEqual("python", sob.guessType("file.etac"))
self.assertEqual("pickle", sob.guessType("file.tap"))
self.assertEqual("pickle", sob.guessType("file.etap"))
self.assertEqual("source", sob.guessType("file.tas"))
self.assertEqual("source", sob.guessType("file.etas"))
def testEverythingEphemeralGetattr(self):
"""
L{_EverythingEphermal.__getattr__} will proxy the __main__ module as an
L{Ephemeral} object, and during load will be transparent, but after
load will return L{Ephemeral} objects from any accessed attributes.
"""
self.fakeMain.testMainModGetattr = 1
dirname = self.mktemp()
os.mkdir(dirname)
filename = os.path.join(dirname, "persisttest.ee_getattr")
global mainWhileLoading
mainWhileLoading = None
with open(filename, "w") as f:
f.write(
dedent(
"""
app = []
import __main__
app.append(__main__.testMainModGetattr == 1)
try:
__main__.somethingElse
except AttributeError:
app.append(True)
else:
app.append(False)
from twisted.test import test_sob
test_sob.mainWhileLoading = __main__
"""
)
)
loaded = sob.load(filename, "source")
self.assertIsInstance(loaded, list)
self.assertTrue(loaded[0], "Expected attribute not set.")
self.assertTrue(loaded[1], "Unexpected attribute set.")
self.assertIsInstance(mainWhileLoading, Ephemeral)
self.assertIsInstance(mainWhileLoading.somethingElse, Ephemeral)
del mainWhileLoading
def testEverythingEphemeralSetattr(self):
"""
Verify that _EverythingEphemeral.__setattr__ won't affect __main__.
"""
self.fakeMain.testMainModSetattr = 1
dirname = self.mktemp()
os.mkdir(dirname)
filename = os.path.join(dirname, "persisttest.ee_setattr")
with open(filename, "w") as f:
f.write("import __main__\n")
f.write("__main__.testMainModSetattr = 2\n")
f.write("app = None\n")
sob.load(filename, "source")
self.assertEqual(self.fakeMain.testMainModSetattr, 1)
def testEverythingEphemeralException(self):
"""
Test that an exception during load() won't cause _EE to mask __main__
"""
dirname = self.mktemp()
os.mkdir(dirname)
filename = os.path.join(dirname, "persisttest.ee_exception")
with open(filename, "w") as f:
f.write("raise ValueError\n")
self.assertRaises(ValueError, sob.load, filename, "source")
self.assertEqual(type(sys.modules["__main__"]), FakeModule)
def setUp(self):
"""
Replace the __main__ module with a fake one, so that it can be mutated
in tests
"""
self.realMain = sys.modules["__main__"]
self.fakeMain = sys.modules["__main__"] = FakeModule()
def tearDown(self):
"""
Restore __main__ to its original value
"""
sys.modules["__main__"] = self.realMain

View File

@@ -0,0 +1,498 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.protocol.socks}, an implementation of the SOCKSv4 and
SOCKSv4a protocols.
"""
import socket
import struct
from twisted.internet import address, defer
from twisted.internet.error import DNSLookupError
from twisted.protocols import socks
from twisted.python.compat import iterbytes
from twisted.test import proto_helpers
from twisted.trial import unittest
class StringTCPTransport(proto_helpers.StringTransport):
    """
    A L{proto_helpers.StringTransport} that looks enough like a TCP transport
    for the SOCKS tests: it exposes a peer address, a fixed host address, and
    records (rather than performs) connection shutdown.
    """

    # Set to True once loseConnection() has been called.
    stringTCPTransport_closing = False
    # Peer IAddress; assigned by SOCKSv4Driver/tests after construction.
    peer = None

    def getPeer(self):
        # Return whatever peer address the test assigned.
        return self.peer

    def getHost(self):
        # Fixed local address; the concrete value is irrelevant to the tests.
        return address.IPv4Address("TCP", "2.3.4.5", 42)

    def loseConnection(self):
        # Record the close request instead of actually disconnecting, so
        # tests can assert on it.
        self.stringTCPTransport_closing = True
class FakeResolverReactor:
    """
    Bare-bones reactor double whose C{resolve} method answers from a fixed
    hostname table instead of performing real DNS lookups.
    """

    def __init__(self, names):
        """
        @type names: L{dict} containing L{str} keys and L{str} values.
        @param names: A hostname to IP address mapping. The IP addresses are
            stringified dotted quads.
        """
        self.names = names

    def resolve(self, hostname):
        """
        Return an already-fired Deferred: the mapped address on success, or a
        failure wrapping L{DNSLookupError} when C{hostname} is unknown.
        """
        if hostname not in self.names:
            message = "FakeResolverReactor couldn't find " + hostname.decode("utf-8")
            return defer.fail(DNSLookupError(message))
        return defer.succeed(self.names[hostname])
class SOCKSv4Driver(socks.SOCKSv4):
    """
    L{socks.SOCKSv4} subclass whose outgoing connections and listening ports
    are faked in memory, so tests can inspect what the protocol created.
    """

    # Most recent protocol instance produced by connectClass().
    driver_outgoing = None
    # Most recent factory produced by listenClass().
    driver_listen = None

    def connectClass(self, host, port, klass, *args):
        """Pretend to connect: build the protocol on a string transport."""
        protocolInstance = klass(*args)
        fakeTransport = StringTCPTransport()
        fakeTransport.peer = address.IPv4Address("TCP", host, port)
        protocolInstance.transport = fakeTransport
        protocolInstance.connectionMade()
        self.driver_outgoing = protocolInstance
        return defer.succeed(protocolInstance)

    def listenClass(self, port, klass, *args):
        """Pretend to listen: remember the factory, report a fake address."""
        factory = klass(*args)
        self.driver_listen = factory
        # An ephemeral-port request (0) gets a deterministic fake port.
        chosenPort = port or 1234
        return defer.succeed(("6.7.8.9", chosenPort))
class ConnectTests(unittest.TestCase):
    """
    Tests for SOCKS and SOCKSv4a connect requests using the L{SOCKSv4} protocol.
    """

    def setUp(self):
        """
        Create a SOCKSv4 driver on a string transport, with a fake reactor
        that can resolve only C{b"localhost"}.
        """
        self.sock = SOCKSv4Driver()
        self.sock.transport = StringTCPTransport()
        self.sock.connectionMade()
        self.sock.reactor = FakeResolverReactor({b"localhost": "127.0.0.1"})

    def tearDown(self):
        """
        Any outgoing connection a test created must have been closed.
        """
        outgoing = self.sock.driver_outgoing
        if outgoing is not None:
            self.assertTrue(
                outgoing.transport.stringTCPTransport_closing,
                "Outgoing SOCKS connections need to be closed.",
            )

    def test_simple(self):
        """
        A plain SOCKSv4 CONNECT (version 4, command 1) to a literal address is
        granted (reply code 90) and data is relayed in both directions.
        """
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 1, 34)
            + socket.inet_aton("1.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        # Reply: version 0, code 90 (granted), original port and address.
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 90, 34) + socket.inet_aton("1.2.3.4")
        )
        self.assertFalse(self.sock.transport.stringTCPTransport_closing)
        self.assertIsNotNone(self.sock.driver_outgoing)
        # pass some data through
        self.sock.dataReceived(b"hello, world")
        self.assertEqual(self.sock.driver_outgoing.transport.value(), b"hello, world")
        # the other way around
        self.sock.driver_outgoing.dataReceived(b"hi there")
        self.assertEqual(self.sock.transport.value(), b"hi there")
        self.sock.connectionLost("fake reason")

    def test_socks4aSuccessfulResolution(self):
        """
        If the destination IP address has zeros for the first three octets and
        non-zero for the fourth octet, the client is attempting a v4a
        connection. A hostname is specified after the user ID string and the
        server connects to the address that hostname resolves to.

        @see: U{http://en.wikipedia.org/wiki/SOCKS#SOCKS_4a_protocol}
        """
        # send the domain name "localhost" to be resolved
        clientRequest = (
            struct.pack("!BBH", 4, 1, 34)
            + socket.inet_aton("0.0.0.1")
            + b"fooBAZ\0"
            + b"localhost\0"
        )
        # Deliver the bytes one by one to exercise the protocol's buffering
        # logic. FakeResolverReactor's resolve method is invoked to "resolve"
        # the hostname.
        for byte in iterbytes(clientRequest):
            self.sock.dataReceived(byte)
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        # Verify that the server responded with the address which will be
        # connected to.
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 90, 34) + socket.inet_aton("127.0.0.1")
        )
        self.assertFalse(self.sock.transport.stringTCPTransport_closing)
        self.assertIsNotNone(self.sock.driver_outgoing)
        # Pass some data through and verify it is forwarded to the outgoing
        # connection.
        self.sock.dataReceived(b"hello, world")
        self.assertEqual(self.sock.driver_outgoing.transport.value(), b"hello, world")
        # Deliver some data from the output connection and verify it is
        # passed along to the incoming side.
        self.sock.driver_outgoing.dataReceived(b"hi there")
        self.assertEqual(self.sock.transport.value(), b"hi there")
        self.sock.connectionLost("fake reason")

    def test_socks4aFailedResolution(self):
        """
        Failed hostname resolution on a SOCKSv4a packet results in a 91 error
        response and the connection getting closed.
        """
        # send the domain name "failinghost" to be resolved
        clientRequest = (
            struct.pack("!BBH", 4, 1, 34)
            + socket.inet_aton("0.0.0.1")
            + b"fooBAZ\0"
            + b"failinghost\0"
        )
        # Deliver the bytes one by one to exercise the protocol's buffering
        # logic. FakeResolverReactor's resolve method is invoked to "resolve"
        # the hostname.
        for byte in iterbytes(clientRequest):
            self.sock.dataReceived(byte)
        # Verify that the server responds with a 91 error.
        sent = self.sock.transport.value()
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0")
        )
        # A failed resolution causes the transport to drop the connection.
        self.assertTrue(self.sock.transport.stringTCPTransport_closing)
        self.assertIsNone(self.sock.driver_outgoing)

    def test_accessDenied(self):
        """
        When authorize() rejects the request the server replies with code 91
        and closes the connection without creating an outgoing connection.
        """
        self.sock.authorize = lambda code, server, port, user: 0
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 1, 4242)
            + socket.inet_aton("10.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        self.assertEqual(
            self.sock.transport.value(),
            struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0"),
        )
        self.assertTrue(self.sock.transport.stringTCPTransport_closing)
        self.assertIsNone(self.sock.driver_outgoing)

    def test_eofRemote(self):
        """
        The server side of the relayed connection closing is handled cleanly.
        """
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 1, 34)
            + socket.inet_aton("1.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        self.sock.transport.clear()
        # pass some data through
        self.sock.dataReceived(b"hello, world")
        self.assertEqual(self.sock.driver_outgoing.transport.value(), b"hello, world")
        # now close it from the server side
        self.sock.driver_outgoing.transport.loseConnection()
        self.sock.driver_outgoing.connectionLost("fake reason")

    def test_eofLocal(self):
        """
        The client side of the relayed connection closing is handled cleanly.
        """
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 1, 34)
            + socket.inet_aton("1.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        self.sock.transport.clear()
        # pass some data through
        self.sock.dataReceived(b"hello, world")
        self.assertEqual(self.sock.driver_outgoing.transport.value(), b"hello, world")
        # now close it from the client side
        self.sock.connectionLost("fake reason")
class BindTests(unittest.TestCase):
    """
    Tests for SOCKS and SOCKSv4a bind requests using the L{SOCKSv4} protocol.
    """

    def setUp(self):
        """
        Create a SOCKSv4 driver on a string transport, with a fake reactor
        that can resolve only C{b"localhost"}.
        """
        self.sock = SOCKSv4Driver()
        self.sock.transport = StringTCPTransport()
        self.sock.connectionMade()
        self.sock.reactor = FakeResolverReactor({b"localhost": "127.0.0.1"})

    ## def tearDown(self):
    ##     # TODO ensure the listen port is closed
    ##     listen = self.sock.driver_listen
    ##     if listen is not None:
    ##         self.assert_(incoming.transport.stringTCPTransport_closing,
    ##                      "Incoming SOCKS connections need to be closed.")

    def test_simple(self):
        """
        A plain SOCKSv4 BIND (version 4, command 2) is granted (code 90), a
        matching incoming connection is accepted, and data is relayed in both
        directions.
        """
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 2, 34)
            + socket.inet_aton("1.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        # First reply reports the fake listening address from listenClass().
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 90, 1234) + socket.inet_aton("6.7.8.9")
        )
        self.assertFalse(self.sock.transport.stringTCPTransport_closing)
        self.assertIsNotNone(self.sock.driver_listen)
        # connect
        incoming = self.sock.driver_listen.buildProtocol(("1.2.3.4", 5345))
        self.assertIsNotNone(incoming)
        incoming.transport = StringTCPTransport()
        incoming.connectionMade()
        # now we should have the second reply packet
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 90, 0) + socket.inet_aton("0.0.0.0")
        )
        self.assertFalse(self.sock.transport.stringTCPTransport_closing)
        # pass some data through
        self.sock.dataReceived(b"hello, world")
        self.assertEqual(incoming.transport.value(), b"hello, world")
        # the other way around
        incoming.dataReceived(b"hi there")
        self.assertEqual(self.sock.transport.value(), b"hi there")
        self.sock.connectionLost("fake reason")

    def test_socks4a(self):
        """
        If the destination IP address has zeros for the first three octets and
        non-zero for the fourth octet, the client is attempting a v4a
        connection. A hostname is specified after the user ID string and the
        server connects to the address that hostname resolves to.

        @see: U{http://en.wikipedia.org/wiki/SOCKS#SOCKS_4a_protocol}
        """
        # send the domain name "localhost" to be resolved
        clientRequest = (
            struct.pack("!BBH", 4, 2, 34)
            + socket.inet_aton("0.0.0.1")
            + b"fooBAZ\0"
            + b"localhost\0"
        )
        # Deliver the bytes one by one to exercise the protocol's buffering
        # logic. FakeResolverReactor's resolve method is invoked to "resolve"
        # the hostname.
        for byte in iterbytes(clientRequest):
            self.sock.dataReceived(byte)
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        # Verify that the server responded with the address which will be
        # connected to.
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 90, 1234) + socket.inet_aton("6.7.8.9")
        )
        self.assertFalse(self.sock.transport.stringTCPTransport_closing)
        self.assertIsNotNone(self.sock.driver_listen)
        # connect
        incoming = self.sock.driver_listen.buildProtocol(("127.0.0.1", 5345))
        self.assertIsNotNone(incoming)
        incoming.transport = StringTCPTransport()
        incoming.connectionMade()
        # now we should have the second reply packet
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 90, 0) + socket.inet_aton("0.0.0.0")
        )
        # The transport must still be open here.  (Fixed: the original used
        # assertIsNot(..., None), which is vacuous because the flag is a
        # bool and can never be None.)
        self.assertFalse(self.sock.transport.stringTCPTransport_closing)
        # Data from the SOCKS client side is relayed to the incoming (bound)
        # connection.
        self.sock.dataReceived(b"hi there")
        self.assertEqual(incoming.transport.value(), b"hi there")
        # the other way around
        incoming.dataReceived(b"hi there")
        self.assertEqual(self.sock.transport.value(), b"hi there")
        self.sock.connectionLost("fake reason")

    def test_socks4aFailedResolution(self):
        """
        Failed hostname resolution on a SOCKSv4a packet results in a 91 error
        response and the connection getting closed.
        """
        # send the domain name "failinghost" to be resolved
        clientRequest = (
            struct.pack("!BBH", 4, 2, 34)
            + socket.inet_aton("0.0.0.1")
            + b"fooBAZ\0"
            + b"failinghost\0"
        )
        # Deliver the bytes one by one to exercise the protocol's buffering
        # logic. FakeResolverReactor's resolve method is invoked to "resolve"
        # the hostname.
        for byte in iterbytes(clientRequest):
            self.sock.dataReceived(byte)
        # Verify that the server responds with a 91 error.
        sent = self.sock.transport.value()
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0")
        )
        # A failed resolution causes the transport to drop the connection.
        self.assertTrue(self.sock.transport.stringTCPTransport_closing)
        self.assertIsNone(self.sock.driver_outgoing)
        # A failed BIND must not leave a listener behind either.
        self.assertIsNone(self.sock.driver_listen)

    def test_accessDenied(self):
        """
        When authorize() rejects the request the server replies with code 91
        and closes the connection without creating a listener.
        """
        self.sock.authorize = lambda code, server, port, user: 0
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 2, 4242)
            + socket.inet_aton("10.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        self.assertEqual(
            self.sock.transport.value(),
            struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0"),
        )
        self.assertTrue(self.sock.transport.stringTCPTransport_closing)
        self.assertIsNone(self.sock.driver_listen)

    def test_eofRemote(self):
        """
        The bound (incoming) side of the relay closing is handled cleanly.
        """
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 2, 34)
            + socket.inet_aton("1.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        # connect
        incoming = self.sock.driver_listen.buildProtocol(("1.2.3.4", 5345))
        self.assertIsNotNone(incoming)
        incoming.transport = StringTCPTransport()
        incoming.connectionMade()
        # now we should have the second reply packet
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 90, 0) + socket.inet_aton("0.0.0.0")
        )
        self.assertFalse(self.sock.transport.stringTCPTransport_closing)
        # pass some data through
        self.sock.dataReceived(b"hello, world")
        self.assertEqual(incoming.transport.value(), b"hello, world")
        # now close it from the server side
        incoming.transport.loseConnection()
        incoming.connectionLost("fake reason")

    def test_eofLocal(self):
        """
        The SOCKS client side of the relay closing is handled cleanly.
        """
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 2, 34)
            + socket.inet_aton("1.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        # connect
        incoming = self.sock.driver_listen.buildProtocol(("1.2.3.4", 5345))
        self.assertIsNotNone(incoming)
        incoming.transport = StringTCPTransport()
        incoming.connectionMade()
        # now we should have the second reply packet
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 90, 0) + socket.inet_aton("0.0.0.0")
        )
        self.assertFalse(self.sock.transport.stringTCPTransport_closing)
        # pass some data through
        self.sock.dataReceived(b"hello, world")
        self.assertEqual(incoming.transport.value(), b"hello, world")
        # now close it from the client side
        self.sock.connectionLost("fake reason")

    def test_badSource(self):
        """
        An incoming connection from an address other than the one named in
        the BIND request is refused with a 91 reply and the connection is
        closed.
        """
        self.sock.dataReceived(
            struct.pack("!BBH", 4, 2, 34)
            + socket.inet_aton("1.2.3.4")
            + b"fooBAR"
            + b"\0"
        )
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        # connect from WRONG address
        incoming = self.sock.driver_listen.buildProtocol(("1.6.6.6", 666))
        self.assertIsNone(incoming)
        # Now we should have the second reply packet and it should
        # be a failure. The connection should be closing.
        sent = self.sock.transport.value()
        self.sock.transport.clear()
        self.assertEqual(
            sent, struct.pack("!BBH", 0, 91, 0) + socket.inet_aton("0.0.0.0")
        )
        self.assertTrue(self.sock.transport.stringTCPTransport_closing)

View File

@@ -0,0 +1,713 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for twisted SSL support.
"""
import os
import hamcrest
from twisted.internet import defer, interfaces, protocol, reactor
from twisted.internet.error import ConnectionDone
from twisted.internet.testing import waitUntilAllDisconnected
from twisted.protocols import basic
from twisted.python.filepath import FilePath
from twisted.python.runtime import platform
from twisted.test.test_tcp import ProperlyCloseFilesMixin
from twisted.trial.unittest import TestCase
# pyOpenSSL is an optional dependency.  When it (or the SSL-dependent
# helpers) cannot be imported, SSL and ssl are bound to None; the test
# classes below detect the missing support and set a ``skip`` attribute,
# and the ServerTLSContext definition is guarded on ``SSL is not None``.
try:
    from OpenSSL import SSL, crypto

    from twisted.internet import ssl
    from twisted.test.ssl_helpers import ClientTLSContext, certPath
except ImportError:

    def _noSSL():
        # ugh, make pyflakes happy.
        global SSL
        global ssl
        SSL = ssl = None

    _noSSL()
from zope.interface import implementer
class UnintelligentProtocol(basic.LineReceiver):
    """
    @ivar deferred: a deferred that will fire at connection lost.
    @type deferred: L{defer.Deferred}

    @cvar pretext: text sent before TLS is set up.
    @type pretext: C{bytes}

    @cvar posttext: text sent after TLS is set up.
    @type posttext: C{bytes}
    """

    pretext = [b"first line", b"last thing before tls starts", b"STARTTLS"]
    posttext = [b"first thing after tls started", b"last thing ever"]

    def __init__(self):
        self.deferred = defer.Deferred()

    def connectionMade(self):
        # Send the cleartext preamble, ending with the STARTTLS request.
        for l in self.pretext:
            self.sendLine(l)

    def lineReceived(self, line):
        # The peer answers READY when it has switched to TLS; do the same,
        # send the post-TLS lines, then hang up.
        if line == b"READY":
            # self.factory.client is set by TLSTests._runTest.
            self.transport.startTLS(ClientTLSContext(), self.factory.client)
            for l in self.posttext:
                self.sendLine(l)
            self.transport.loseConnection()

    def connectionLost(self, reason):
        self.deferred.callback(None)
class LineCollector(basic.LineReceiver):
    """
    @ivar deferred: a deferred that will fire at connection lost.
    @type deferred: L{defer.Deferred}

    @ivar doTLS: whether the protocol is initiate TLS or not.
    @type doTLS: C{bool}

    @ivar fillBuffer: if set to True, it will send lots of data once
        C{STARTTLS} is received.
    @type fillBuffer: C{bool}
    """

    def __init__(self, doTLS, fillBuffer=False):
        self.doTLS = doTLS
        self.fillBuffer = fillBuffer
        self.deferred = defer.Deferred()

    def connectionMade(self):
        # Collected results live on the factory so the test can read them
        # after the connection is gone.
        self.factory.rawdata = b""
        self.factory.lines = []

    def lineReceived(self, line):
        self.factory.lines.append(line)
        if line == b"STARTTLS":
            if self.fillBuffer:
                # Stuff the send buffer before switching to TLS, to exercise
                # buffered-data handling across startTLS.
                for x in range(500):
                    self.sendLine(b"X" * 1000)
            self.sendLine(b"READY")
            if self.doTLS:
                ctx = ServerTLSContext(
                    # For tests, certPath contains both private key and
                    # certificate.
                    privateKeyFileName=certPath,
                    certificateFileName=certPath,
                )
                # self.factory.server is set by TLSTests._runTest.
                self.transport.startTLS(ctx, self.factory.server)
            else:
                # Without TLS the peer's encrypted bytes arrive as raw data.
                self.setRawMode()

    def rawDataReceived(self, data):
        self.factory.rawdata += data
        self.transport.loseConnection()

    def connectionLost(self, reason):
        self.deferred.callback(None)
class SingleLineServerProtocol(protocol.Protocol):
    """
    A protocol that sends a single line of data at C{connectionMade}.
    """

    def connectionMade(self):
        self.transport.write(b"+OK <some crap>\r\n")
        # Exercise certificate access on the transport right after writing;
        # the return value is deliberately ignored.
        self.transport.getPeerCertificate()
class RecordingClientProtocol(protocol.Protocol):
    """
    @ivar deferred: a deferred that will fire with first received content.
    @type deferred: L{defer.Deferred}
    """

    def __init__(self):
        self.deferred = defer.Deferred()

    def connectionMade(self):
        # Exercise certificate access at connection time; result ignored.
        self.transport.getPeerCertificate()

    def dataReceived(self, data):
        # Fires with the first chunk only; subsequent chunks would error the
        # already-fired deferred, but the tests disconnect before that.
        self.deferred.callback(data)
@implementer(interfaces.IHandshakeListener)
class ImmediatelyDisconnectingProtocol(protocol.Protocol):
    """
    A protocol that disconnects immediately after the TLS handshake.  It
    fires the C{connectionDisconnected} deferred of its factory on
    connection lost.
    """

    def handshakeCompleted(self):
        self.transport.loseConnection()

    def connectionLost(self, reason):
        self.factory.connectionDisconnected.callback(None)
def generateCertificateObjects(organization, organizationalUnit):
    """
    Create a certificate for given C{organization} and C{organizationalUnit}.

    @return: a tuple of (key, request, certificate) objects.
    """
    pkey = crypto.PKey()
    pkey.generate_key(crypto.TYPE_RSA, 2048)
    req = crypto.X509Req()
    subject = req.get_subject()
    subject.O = organization
    subject.OU = organizationalUnit
    req.set_pubkey(pkey)
    # NOTE(review): MD5 signatures are rejected by some modern OpenSSL
    # builds/security policies; sha256 may be required here — confirm
    # against the supported pyOpenSSL/OpenSSL versions.
    req.sign(pkey, "md5")
    # Here comes the actual certificate; self-signed (issuer == subject).
    cert = crypto.X509()
    cert.set_serial_number(1)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(60)  # Testing certificates need not be long lived
    cert.set_issuer(req.get_subject())
    cert.set_subject(req.get_subject())
    cert.set_pubkey(req.get_pubkey())
    cert.sign(pkey, "md5")
    return pkey, req, cert
def generateCertificateFiles(basename, organization, organizationalUnit):
    """
    Create certificate files key, req and cert prefixed by C{basename} for
    given C{organization} and C{organizationalUnit}.
    """
    key, request, certificate = generateCertificateObjects(
        organization, organizationalUnit
    )
    # Map each file extension to the object it holds and the pyOpenSSL
    # serializer that dumps it in PEM form.
    artifacts = {
        "key": (key, crypto.dump_privatekey),
        "req": (request, crypto.dump_certificate_request),
        "cert": (certificate, crypto.dump_certificate),
    }
    for extension, (obj, serialize) in artifacts.items():
        path = os.extsep.join((basename, extension)).encode("utf-8")
        FilePath(path).setContent(serialize(crypto.FILETYPE_PEM, obj))
class ContextGeneratingMixin:
    """
    Offer methods to create L{ssl.DefaultOpenSSLContextFactory} for both client
    and server.

    @ivar clientBase: prefix of client certificate files.
    @type clientBase: C{str}

    @ivar serverBase: prefix of server certificate files.
    @type serverBase: C{str}

    @ivar clientCtxFactory: a generated context factory to be used in
        L{IReactorSSL.connectSSL}.
    @type clientCtxFactory: L{ssl.DefaultOpenSSLContextFactory}

    @ivar serverCtxFactory: a generated context factory to be used in
        L{IReactorSSL.listenSSL}.
    @type serverCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
    """

    def makeContextFactory(self, org, orgUnit, *args, **kwArgs):
        """
        Generate fresh certificate files in a temporary location and build a
        context factory from them.  Extra arguments are forwarded to
        L{ssl.DefaultOpenSSLContextFactory}.

        @return: a (base filename prefix, context factory) pair.
        """
        base = self.mktemp()
        generateCertificateFiles(base, org, orgUnit)
        serverCtxFactory = ssl.DefaultOpenSSLContextFactory(
            os.extsep.join((base, "key")),
            os.extsep.join((base, "cert")),
            *args,
            **kwArgs,
        )
        return base, serverCtxFactory

    def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs, serverKwArgs):
        """
        Populate clientBase/clientCtxFactory and serverBase/serverCtxFactory
        with independently generated certificates.
        """
        self.clientBase, self.clientCtxFactory = self.makeContextFactory(
            *clientArgs, **clientKwArgs
        )
        self.serverBase, self.serverCtxFactory = self.makeContextFactory(
            *serverArgs, **serverKwArgs
        )
# Only definable when pyOpenSSL imported successfully (see the guard above).
if SSL is not None:

    class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
        """
        A context factory with a default method set to
        L{OpenSSL.SSL.SSLv23_METHOD}.
        """

        # Marks this context as the server side (cf. ClientTLSContext from
        # twisted.test.ssl_helpers).
        isClient = False

        def __init__(self, *args, **kw):
            # Force the negotiation method regardless of any caller-supplied
            # ``sslmethod``.
            kw["sslmethod"] = SSL.SSLv23_METHOD
            ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw)
class StolenTCPTests(ProperlyCloseFilesMixin, TestCase):
    """
    For SSL transports, test many of the same things which are tested for
    TCP transports.
    """

    if interfaces.IReactorSSL(reactor, None) is None:
        skip = "Reactor does not support SSL, cannot run SSL tests"

    def createServer(self, address, portNumber, factory):
        """
        Create an SSL server with a certificate using L{IReactorSSL.listenSSL}.
        """
        cert = ssl.PrivateCertificate.loadPEM(FilePath(certPath).getContent())
        contextFactory = cert.options()
        return reactor.listenSSL(portNumber, factory, contextFactory, interface=address)

    def connectClient(self, address, portNumber, clientCreator):
        """
        Create an SSL client using L{IReactorSSL.connectSSL}.
        """
        contextFactory = ssl.CertificateOptions()
        return clientCreator.connectSSL(address, portNumber, contextFactory)

    def getHandleExceptionType(self):
        """
        Return L{OpenSSL.SSL.Error} as the expected error type which will be
        raised by a write to the L{OpenSSL.SSL.Connection} object after it has
        been closed.
        """
        return SSL.Error

    def getHandleErrorCodeMatcher(self):
        """
        Return a L{hamcrest.core.matcher.Matcher} for the argument
        L{OpenSSL.SSL.Error} will be constructed with for this case.

        This is basically just a random OpenSSL implementation detail.
        It would be better if this test worked in a way which did not
        require this.
        """
        # We expect an error about how we tried to write to a shutdown
        # connection.  This is terribly implementation-specific.
        # NOTE(review): hamcrest.contains is deprecated in PyHamcrest 2.x in
        # favor of contains_exactly — confirm the supported hamcrest version.
        return hamcrest.contains(
            hamcrest.contains(
                hamcrest.equal_to("SSL routines"),
                hamcrest.any_of(
                    hamcrest.equal_to("SSL_write"),
                    hamcrest.equal_to("ssl_write_internal"),
                    hamcrest.equal_to(""),
                ),
                hamcrest.equal_to("protocol is shutdown"),
            ),
        )
class TLSTests(TestCase):
    """
    Tests for startTLS support.

    @ivar fillBuffer: forwarded to L{LineCollector.fillBuffer}
    @type fillBuffer: C{bool}
    """

    if interfaces.IReactorSSL(reactor, None) is None:
        skip = "Reactor does not support SSL, cannot run SSL tests"

    fillBuffer = False

    # Protocols created by _runTest; cleaned up in tearDown.
    clientProto = None
    serverProto = None

    def tearDown(self):
        """
        Drop any connection left open by the test.
        """
        if self.clientProto.transport is not None:
            self.clientProto.transport.loseConnection()
        if self.serverProto.transport is not None:
            self.serverProto.transport.loseConnection()

    def _runTest(self, clientProto, serverProto, clientIsServer=False):
        """
        Helper method to run TLS tests.

        @param clientProto: protocol instance attached to the client
            connection.
        @param serverProto: protocol instance attached to the server
            connection.
        @param clientIsServer: flag indicated if client should initiate
            startTLS instead of server.

        @return: a L{defer.Deferred} that will fire when both connections are
            lost.
        """
        self.clientProto = clientProto
        cf = self.clientFactory = protocol.ClientFactory()
        cf.protocol = lambda: clientProto
        # The client/server flags on the factories are consumed by the
        # protocols' startTLS calls (see UnintelligentProtocol/LineCollector).
        if clientIsServer:
            cf.server = False
        else:
            cf.client = True
        self.serverProto = serverProto
        sf = self.serverFactory = protocol.ServerFactory()
        sf.protocol = lambda: serverProto
        if clientIsServer:
            sf.client = False
        else:
            sf.server = True
        port = reactor.listenTCP(0, sf, interface="127.0.0.1")
        self.addCleanup(port.stopListening)
        reactor.connectTCP("127.0.0.1", port.getHost().port, cf)
        return defer.gatherResults([clientProto.deferred, serverProto.deferred])

    def test_TLS(self):
        """
        Test for server and client startTLS: client should received data both
        before and after the startTLS.
        """

        def check(ignore):
            self.assertEqual(
                self.serverFactory.lines,
                UnintelligentProtocol.pretext + UnintelligentProtocol.posttext,
            )

        d = self._runTest(UnintelligentProtocol(), LineCollector(True, self.fillBuffer))
        return d.addCallback(check)

    def test_unTLS(self):
        """
        Test for server startTLS not followed by a startTLS in client: the data
        received after server startTLS should be received as raw.
        """

        def check(ignored):
            self.assertEqual(self.serverFactory.lines, UnintelligentProtocol.pretext)
            self.assertTrue(self.serverFactory.rawdata, "No encrypted bytes received")

        d = self._runTest(
            UnintelligentProtocol(), LineCollector(False, self.fillBuffer)
        )
        return d.addCallback(check)

    def test_backwardsTLS(self):
        """
        Test startTLS first initiated by client.
        """

        def check(ignored):
            self.assertEqual(
                self.clientFactory.lines,
                UnintelligentProtocol.pretext + UnintelligentProtocol.posttext,
            )

        d = self._runTest(
            LineCollector(True, self.fillBuffer), UnintelligentProtocol(), True
        )
        return d.addCallback(check)
class SpammyTLSTests(TLSTests):
    """
    Test TLS features with bytes sitting in the out buffer.
    """

    if interfaces.IReactorSSL(reactor, None) is None:
        skip = "Reactor does not support SSL, cannot run SSL tests"

    # Re-run every TLSTests test, but have LineCollector flood the send
    # buffer before switching to TLS.
    fillBuffer = True
class BufferingTests(TestCase):
    """
    Test that data written by the server at connection time over SSL is
    delivered intact to the client.
    """

    if interfaces.IReactorSSL(reactor, None) is None:
        skip = "Reactor does not support SSL, cannot run SSL tests"

    serverProto = None
    clientProto = None

    def tearDown(self):
        """
        Close both ends and wait until the reactor has seen both disconnects.
        """
        if self.serverProto.transport is not None:
            self.serverProto.transport.loseConnection()
        if self.clientProto.transport is not None:
            self.clientProto.transport.loseConnection()
        return waitUntilAllDisconnected(reactor, [self.serverProto, self.clientProto])

    def test_openSSLBuffering(self):
        serverProto = self.serverProto = SingleLineServerProtocol()
        clientProto = self.clientProto = RecordingClientProtocol()
        server = protocol.ServerFactory()
        client = self.client = protocol.ClientFactory()
        server.protocol = lambda: serverProto
        client.protocol = lambda: clientProto
        sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
        cCTX = ssl.ClientContextFactory()
        port = reactor.listenSSL(0, server, sCTX, interface="127.0.0.1")
        self.addCleanup(port.stopListening)
        clientConnector = reactor.connectSSL(
            "127.0.0.1", port.getHost().port, client, cCTX
        )
        self.addCleanup(clientConnector.disconnect)
        # The client's deferred fires with the first received chunk, which
        # must be exactly the line the server wrote at connectionMade.
        return clientProto.deferred.addCallback(
            self.assertEqual, b"+OK <some crap>\r\n"
        )
class ConnectionLostTests(TestCase, ContextGeneratingMixin):
"""
SSL connection closing tests.
"""
if interfaces.IReactorSSL(reactor, None) is None:
skip = "Reactor does not support SSL, cannot run SSL tests"
def testImmediateDisconnect(self):
org = "twisted.test.test_ssl"
self.setupServerAndClient(
(org, org + ", client"), {}, (org, org + ", server"), {}
)
# Set up a server, connect to it with a client, which should work since our verifiers
# allow anything, then disconnect.
serverProtocolFactory = protocol.ServerFactory()
serverProtocolFactory.protocol = protocol.Protocol
self.serverPort = serverPort = reactor.listenSSL(
0, serverProtocolFactory, self.serverCtxFactory
)
clientProtocolFactory = protocol.ClientFactory()
clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol
clientProtocolFactory.connectionDisconnected = defer.Deferred()
reactor.connectSSL(
"127.0.0.1",
serverPort.getHost().port,
clientProtocolFactory,
self.clientCtxFactory,
)
return clientProtocolFactory.connectionDisconnected.addCallback(
lambda ignoredResult: self.serverPort.stopListening()
)
def test_bothSidesLoseConnection(self):
"""
Both sides of SSL connection close connection; the connections should
close cleanly, and only after the underlying TCP connection has
disconnected.
"""
@implementer(interfaces.IHandshakeListener)
class CloseAfterHandshake(protocol.Protocol):
gotData = False
def __init__(self):
self.done = defer.Deferred()
def handshakeCompleted(self):
self.transport.loseConnection()
def connectionLost(self, reason):
self.done.errback(reason)
del self.done
org = "twisted.test.test_ssl"
self.setupServerAndClient(
(org, org + ", client"), {}, (org, org + ", server"), {}
)
serverProtocol = CloseAfterHandshake()
serverProtocolFactory = protocol.ServerFactory()
serverProtocolFactory.protocol = lambda: serverProtocol
serverPort = reactor.listenSSL(0, serverProtocolFactory, self.serverCtxFactory)
self.addCleanup(serverPort.stopListening)
clientProtocol = CloseAfterHandshake()
clientProtocolFactory = protocol.ClientFactory()
clientProtocolFactory.protocol = lambda: clientProtocol
reactor.connectSSL(
"127.0.0.1",
serverPort.getHost().port,
clientProtocolFactory,
self.clientCtxFactory,
)
def checkResult(failure):
failure.trap(ConnectionDone)
return defer.gatherResults(
[
clientProtocol.done.addErrback(checkResult),
serverProtocol.done.addErrback(checkResult),
]
)
def testFailedVerify(self):
    """
    When the client's certificate-verification callback rejects the
    server's certificate, both sides of the connection are lost; the
    resulting failures are checked in L{_cbLostConns}.
    """
    org = "twisted.test.test_ssl"
    self.setupServerAndClient(
        (org, org + ", client"), {}, (org, org + ", server"), {}
    )

    def verify(*a):
        # Reject every certificate the peer presents.
        return False

    self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify)

    serverConnLost = defer.Deferred()
    serverProtocol = protocol.Protocol()
    serverProtocol.connectionLost = serverConnLost.callback
    serverProtocolFactory = protocol.ServerFactory()
    serverProtocolFactory.protocol = lambda: serverProtocol
    self.serverPort = serverPort = reactor.listenSSL(
        0, serverProtocolFactory, self.serverCtxFactory
    )

    clientConnLost = defer.Deferred()
    clientProtocol = protocol.Protocol()
    clientProtocol.connectionLost = clientConnLost.callback
    clientProtocolFactory = protocol.ClientFactory()
    clientProtocolFactory.protocol = lambda: clientProtocol
    reactor.connectSSL(
        "127.0.0.1",
        serverPort.getHost().port,
        clientProtocolFactory,
        self.clientCtxFactory,
    )

    # consumeErrors so the per-side failures are inspected (and trapped)
    # by _cbLostConns instead of being logged as test errors.
    dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=True)
    return dl.addCallback(self._cbLostConns)
def _cbLostConns(self, results):
    """
    Verify that both entries of the C{DeferredList} failed, that each
    failure is an acceptable SSL-related error, and then stop the
    listening port.
    """
    serverOutcome, clientOutcome = results
    for succeeded, _ in (serverOutcome, clientOutcome):
        self.assertFalse(succeeded)

    allowed = [SSL.Error]
    # Rather than getting a verification failure on Windows, we are getting
    # a connection failure. Without something like sslverify proxying
    # in-between we can't fix up the platform's errors, so let's just
    # specifically say it is only OK in this one case to keep the tests
    # passing. Normally we'd like to be as strict as possible here, so
    # we're not going to allow this to report errors incorrectly on any
    # other platforms.
    if platform.isWindows():
        from twisted.internet.error import ConnectionLost

        allowed.append(ConnectionLost)

    serverOutcome[1].trap(*allowed)
    clientOutcome[1].trap(*allowed)

    return self.serverPort.stopListening()
class FakeContext:
    """
    L{OpenSSL.SSL.Context} double which can more easily be inspected:
    it records the negotiation method and accumulates option flags
    instead of configuring real TLS state.
    """

    def __init__(self, method):
        # Remember which TLS method the factory selected.
        self._method = method
        # Bitmask accumulating every flag passed to set_options.
        self._options = 0

    def set_options(self, options):
        # OR the new flags into the mask, mirroring the real context.
        self._options = self._options | options

    def use_certificate_file(self, fileName):
        # Certificate loading is irrelevant to these tests; ignore it.
        pass

    def use_privatekey_file(self, fileName):
        # Private key loading is irrelevant to these tests; ignore it.
        pass
class DefaultOpenSSLContextFactoryTests(TestCase):
    """
    Tests for L{ssl.DefaultOpenSSLContextFactory}.
    """

    # Evaluated at class-definition time: skip the whole class when the
    # installed reactor cannot do SSL.
    if interfaces.IReactorSSL(reactor, None) is None:
        skip = "Reactor does not support SSL, cannot run SSL tests"

    def setUp(self):
        # pyOpenSSL Context objects aren't introspectable enough. Pass in
        # an alternate context factory so we can inspect what is done to it.
        self.contextFactory = ssl.DefaultOpenSSLContextFactory(
            certPath, certPath, _contextFactory=FakeContext
        )
        self.context = self.contextFactory.getContext()

    def test_method(self):
        """
        L{ssl.DefaultOpenSSLContextFactory.getContext} returns an SSL context
        which can use SSLv3 or TLSv1 but not SSLv2.
        """
        # TLS_METHOD allows for negotiating multiple versions of TLS
        self.assertEqual(self.context._method, SSL.TLS_METHOD)
        # OP_NO_SSLv2 disables SSLv2 support
        self.assertEqual(self.context._options & SSL.OP_NO_SSLv2, SSL.OP_NO_SSLv2)
        # Make sure TLSv1.2 isn't disabled though.
        self.assertFalse(self.context._options & SSL.OP_NO_TLSv1_2)

    def test_missingCertificateFile(self):
        """
        Instantiating L{ssl.DefaultOpenSSLContextFactory} with a certificate
        filename which does not identify an existing file results in the
        initializer raising L{OpenSSL.SSL.Error}.
        """
        self.assertRaises(
            SSL.Error, ssl.DefaultOpenSSLContextFactory, certPath, self.mktemp()
        )

    def test_missingPrivateKeyFile(self):
        """
        Instantiating L{ssl.DefaultOpenSSLContextFactory} with a private key
        filename which does not identify an existing file results in the
        initializer raising L{OpenSSL.SSL.Error}.
        """
        self.assertRaises(
            SSL.Error, ssl.DefaultOpenSSLContextFactory, self.mktemp(), certPath
        )
class ClientContextFactoryTests(TestCase):
    """
    Tests for L{ssl.ClientContextFactory}.
    """

    # Skip the whole class when the installed reactor cannot do SSL.
    if interfaces.IReactorSSL(reactor, None) is None:
        skip = "Reactor does not support SSL, cannot run SSL tests"

    def setUp(self):
        # Swap in the inspectable FakeContext so the test can examine the
        # method and option flags the factory configures.
        self.contextFactory = ssl.ClientContextFactory()
        self.contextFactory._contextFactory = FakeContext
        self.context = self.contextFactory.getContext()

    def test_method(self):
        """
        L{ssl.ClientContextFactory.getContext} returns a context which can use
        TLSv1.2 or 1.3 but nothing earlier.
        """
        self.assertEqual(self.context._method, SSL.TLS_METHOD)
        # SSLv2, SSLv3 and TLSv1.0 are each explicitly disabled.
        self.assertEqual(self.context._options & SSL.OP_NO_SSLv2, SSL.OP_NO_SSLv2)
        self.assertTrue(self.context._options & SSL.OP_NO_SSLv3)
        self.assertTrue(self.context._options & SSL.OP_NO_TLSv1)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,81 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.protocols.stateful
"""
from struct import calcsize, pack, unpack
from twisted.protocols.stateful import StatefulProtocol
from twisted.protocols.test import test_basic
from twisted.trial.unittest import TestCase
class MyInt32StringReceiver(StatefulProtocol):
    """
    A stateful Int32StringReceiver.

    Messages are a big-endian unsigned 32-bit length prefix followed by
    that many bytes of payload.
    """

    # Longest payload accepted before the connection is dropped.
    MAX_LENGTH = 99999
    # Big-endian unsigned 32-bit length prefix.
    structFormat = "!I"
    prefixLength = calcsize(structFormat)

    def getInitialState(self):
        # Start by reading one complete length prefix (was a hardcoded 4;
        # derive it from structFormat so the two can never disagree).
        return self._getHeader, self.prefixLength

    def lengthLimitExceeded(self, length):
        self.transport.loseConnection()

    def _getHeader(self, msg):
        # Decode with the same *unsigned* format used by sendString.  The
        # previous "!i" (signed) unpack let a length >= 2**31 wrap to a
        # negative number and slip past the MAX_LENGTH check.
        (length,) = unpack(self.structFormat, msg)
        if length > self.MAX_LENGTH:
            self.lengthLimitExceeded(length)
            return
        return self._getString, length

    def _getString(self, msg):
        self.stringReceived(msg)
        return self._getHeader, self.prefixLength

    def stringReceived(self, msg):
        """
        Override this.
        """
        raise NotImplementedError

    def sendString(self, data):
        """
        Send an int32-prefixed string to the other end of the connection.
        """
        self.transport.write(pack(self.structFormat, len(data)) + data)
class TestInt32(MyInt32StringReceiver):
    """
    Concrete receiver used by L{Int32Tests}: collects received strings
    and records whether the connection was lost.
    """

    def connectionMade(self):
        # Fresh list of collected payloads for this connection.
        self.received = []

    def stringReceived(self, s):
        self.received.append(s)

    # Small limit so the length-limit path is easy to exercise.
    MAX_LENGTH = 50
    # Flipped to 1 by connectionLost.
    closed = 0

    def connectionLost(self, reason):
        self.closed = 1
class Int32Tests(TestCase, test_basic.IntNTestCaseMixin):
    """
    Exercise the stateful int32-prefixed receiver through the shared
    L{test_basic.IntNTestCaseMixin} scenarios.
    """

    protocol = TestInt32
    # Payloads expected to round-trip intact.
    strings = [b"a", b"b" * 16]
    # Prefix decodes to a length far beyond MAX_LENGTH (50).
    illegalStrings = [b"\x10\x00\x00\x00aaaaaa"]
    # Fragments that must not yield a complete string.
    partialStrings = [b"\x00\x00\x00", b"hello there", b""]

    def test_bigReceive(self):
        """
        Many length-prefixed strings delivered in a single C{dataReceived}
        call are all parsed.
        """
        r = self.getProtocol()
        big = b""
        for s in self.strings * 4:
            big += pack("!i", len(s)) + s
        r.dataReceived(big)
        self.assertEqual(r.received, self.strings * 4)

View File

@@ -0,0 +1,406 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.internet.stdio}.
@var properEnv: A copy of L{os.environ} which has L{bytes} keys/values on POSIX
platforms and native L{str} keys/values on Windows.
"""
from __future__ import annotations
import itertools
import os
import sys
from typing import Any, Callable
from unittest import skipIf
from twisted.internet import defer, error, protocol, reactor, stdio
from twisted.internet.interfaces import IProcessTransport, IReactorProcess
from twisted.internet.protocol import ProcessProtocol
from twisted.python import filepath, log
from twisted.python.failure import Failure
from twisted.python.reflect import requireModule
from twisted.python.runtime import platform
from twisted.test.test_tcp import ConnectionLostNotifyingProtocol
from twisted.trial.unittest import SkipTest, TestCase
# A short string which is intended to appear here and nowhere else,
# particularly not in any random garbage output CPython unavoidable
# generates (such as in warning text and so forth). This is searched
# for in the output from stdio_test_lastwrite and if it is found at
# the end, the functionality works.
UNIQUE_LAST_WRITE_STRING = b"xyz123abc Twisted is great!"
properEnv = dict(os***REMOVED***iron)
properEnv["PYTHONPATH"] = os.pathsep.join(sys.path)
class StandardIOTestProcessProtocol(protocol.ProcessProtocol):
    """
    Test helper for collecting output from a child process and notifying
    something when it exits.

    @ivar onConnection: A L{defer.Deferred} which will be called back with
        L{None} when the connection to the child process is established.

    @ivar onCompletion: A L{defer.Deferred} which will be called back (not
        errbacked) with the failure associated with the child process
        exiting when it exits.

    @ivar onDataReceived: A L{defer.Deferred} which will be called back with
        this instance whenever C{childDataReceived} is called, or L{None} to
        suppress these callbacks.

    @ivar data: A C{dict} mapping child file descriptor numbers to strings
        containing all bytes received from the child process on each file
        descriptor.
    """

    onDataReceived: defer.Deferred[None] | None = None
    transport: IProcessTransport

    def __init__(self) -> None:
        self.onConnection: defer.Deferred[None] = defer.Deferred()
        # Fired via .callback with the Failure wrapping the exit reason;
        # see processEnded below.
        self.onCompletion: defer.Deferred[Failure] = defer.Deferred()
        # Keys are child FD numbers (e.g. 1 for stdout); callers index it
        # as data[1], so dict[int, bytes] rather than dict[str, bytes].
        self.data: dict[int, bytes] = {}

    def connectionMade(self):
        self.onConnection.callback(None)

    def childDataReceived(self, name, bytes):
        """
        Record all bytes received from the child process in the C{data}
        dictionary. Fire C{onDataReceived} if it is not L{None}.
        """
        self.data[name] = self.data.get(name, b"") + bytes
        if self.onDataReceived is not None:
            # One-shot: clear the Deferred before firing so a re-entrant
            # delivery cannot fire it twice.
            d, self.onDataReceived = self.onDataReceived, None
            d.callback(self)

    def processEnded(self, reason):
        # reason is a Failure, but it is delivered via callback so the
        # waiter can inspect it without triggering the errback chain.
        self.onCompletion.callback(reason)
class StandardInputOutputTests(TestCase):
    """
    Integration tests which spawn a child Python process that uses
    L{twisted.internet.stdio} and verify the parent's view of the child's
    standard I/O behavior.
    """

    if platform.isWindows() and requireModule("win32process") is None:
        skip = (
            "On windows, spawnProcess is not available in the "
            "absence of win32process."
        )

    def _spawnProcess(
        self, proto: ProcessProtocol, sibling: str | bytes, *args: str, **kw: Any
    ) -> IProcessTransport:
        """
        Launch a child Python process and communicate with it using the given
        ProcessProtocol.

        @param proto: A L{ProcessProtocol} instance which will be connected to
            the child process.

        @param sibling: The basename of a file containing the Python program to
            run in the child process.

        @param *args: strings which will be passed to the child process on the
            command line as C{argv[2:]}.

        @param **kw: additional arguments to pass to L{reactor.spawnProcess}.

        @return: The L{IProcessTransport} provider for the spawned process.
        """
        if isinstance(sibling, bytes):
            sibling = sibling.decode()
        # Run the helper with -m so PYTHONPATH (set in properEnv) resolves
        # it; the first child argument names the reactor module in use.
        procargs = [
            sys.executable,
            "-m",
            "twisted.test." + sibling,
            reactor.__class__.__module__,
        ] + list(args)
        return IReactorProcess(reactor).spawnProcess(
            proto, sys.executable, procargs, env=properEnv, **kw
        )

    def _requireFailure(
        self, d: defer.Deferred[None], callback: Callable[[Failure], object]
    ) -> defer.Deferred[None]:
        """
        Chain C{callback} as an errback of C{d}; fail the test if C{d}
        instead fires successfully.
        """

        def cb(result):
            self.fail(f"Process terminated with non-Failure: {result!r}")

        def eb(err):
            return callback(err)

        return d.addCallbacks(cb, eb)

    def test_loseConnection(self):
        """
        Verify that a protocol connected to L{StandardIO} can disconnect
        itself using C{transport.loseConnection}.
        """
        errorLogFile = self.mktemp()
        log.msg("Child process logging to " + errorLogFile)
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion
        self._spawnProcess(p, b"stdio_test_loseconn", errorLogFile)

        def processEnded(reason):
            # Copy the child's log to ours so it's more visible.
            with open(errorLogFile) as f:
                for line in f:
                    log.msg("Child logged: " + line.rstrip())

            self.failIfIn(1, p.data)
            reason.trap(error.ProcessDone)

        return self._requireFailure(d, processEnded)

    def exampleOutputsAndZeroExitCode(
        self, example: str, out: bool = False
    ) -> defer.Deferred[None]:
        """
        Run the named child helper, close one side of its standard I/O
        (stdout when C{out} is true, stdin otherwise) after the first bytes
        arrive, and require that the child then exits cleanly.
        """
        errorLogFile = self.mktemp()
        p = StandardIOTestProcessProtocol()
        p.onDataReceived = defer.Deferred()

        def cbBytes(ignored: None) -> defer.Deferred[None]:
            d = p.onCompletion
            if out:
                p.transport.closeStdout()
            else:
                p.transport.closeStdin()
            return d

        p.onDataReceived.addCallback(cbBytes)

        def processEnded(reason):
            # A clean exit is reported as ProcessDone.
            reason.trap(error.ProcessDone)

        d = self._requireFailure(p.onDataReceived, processEnded)
        self._spawnProcess(p, example, errorLogFile)
        return d

    def test_readConnectionLost(self) -> defer.Deferred[None]:
        """
        When stdin is closed and the protocol connected to it implements
        L{IHalfCloseableProtocol}, the protocol's C{readConnectionLost} method
        is called.
        """
        return self.exampleOutputsAndZeroExitCode("stdio_test_halfclose")

    def test_buggyReadConnectionLost(self) -> defer.Deferred[None]:
        """
        When stdin is closed and the protocol connected to it implements
        L{IHalfCloseableProtocol} but its C{readConnectionLost} method raises
        an exception its regular C{connectionLost} method will be called.
        """
        return self.exampleOutputsAndZeroExitCode("stdio_test_halfclose_buggy")

    def test_buggyWriteConnectionLost(self) -> defer.Deferred[None]:
        """
        When stdout is closed and the protocol connected to it implements
        L{IHalfCloseableProtocol} but its C{writeConnectionLost} method raises
        an exception its regular C{connectionLost} method will be called.
        """
        return self.exampleOutputsAndZeroExitCode(
            "stdio_test_halfclose_buggy_write", out=True
        )

    def test_lastWriteReceived(self):
        """
        Verify that a write made directly to stdout using L{os.write}
        after StandardIO has finished is reliably received by the
        process reading that stdout.
        """
        p = StandardIOTestProcessProtocol()

        # Note: the macOS bug which prompted the addition of this test
        # is an apparent race condition involving non-blocking PTYs.
        # Delaying the parent process significantly increases the
        # likelihood of the race going the wrong way.  If you need to
        # fiddle with this code at all, uncommenting the next line
        # will likely make your life much easier.  It is commented out
        # because it makes the test quite slow.
        # p.onConnection.addCallback(lambda ign: __import__('time').sleep(5))

        try:
            self._spawnProcess(
                p, b"stdio_test_lastwrite", UNIQUE_LAST_WRITE_STRING, usePTY=True
            )
        except ValueError as e:
            # Some platforms don't work with usePTY=True
            raise SkipTest(str(e))

        def processEnded(reason):
            """
            Asserts that the parent received the bytes written by the child
            immediately after the child starts.
            """
            self.assertTrue(
                p.data[1].endswith(UNIQUE_LAST_WRITE_STRING),
                f"Received {p.data!r} from child, did not find expected bytes.",
            )
            reason.trap(error.ProcessDone)

        return self._requireFailure(p.onCompletion, processEnded)

    def test_hostAndPeer(self):
        """
        Verify that the transport of a protocol connected to L{StandardIO}
        has C{getHost} and C{getPeer} methods.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion
        self._spawnProcess(p, b"stdio_test_hostpeer")

        def processEnded(reason):
            # The child prints one line each for getHost() and getPeer().
            host, peer = p.data[1].splitlines()
            self.assertTrue(host)
            self.assertTrue(peer)
            reason.trap(error.ProcessDone)

        return self._requireFailure(d, processEnded)

    def test_write(self):
        """
        Verify that the C{write} method of the transport of a protocol
        connected to L{StandardIO} sends bytes to standard out.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion
        self._spawnProcess(p, b"stdio_test_write")

        def processEnded(reason):
            self.assertEqual(p.data[1], b"ok!")
            reason.trap(error.ProcessDone)

        return self._requireFailure(d, processEnded)

    def test_writeSequence(self):
        """
        Verify that the C{writeSequence} method of the transport of a
        protocol connected to L{StandardIO} sends bytes to standard out.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion
        self._spawnProcess(p, b"stdio_test_writeseq")

        def processEnded(reason):
            self.assertEqual(p.data[1], b"ok!")
            reason.trap(error.ProcessDone)

        return self._requireFailure(d, processEnded)

    def _junkPath(self):
        """
        Create a temporary file of 1024 numbered lines and return its path.
        """
        junkPath = self.mktemp()
        with open(junkPath, "wb") as junkFile:
            for i in range(1024):
                junkFile.write(b"%d\n" % (i,))
        return junkPath

    def test_producer(self):
        """
        Verify that the transport of a protocol connected to L{StandardIO}
        is a working L{IProducer} provider.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion

        written = []
        toWrite = list(range(100))

        def connectionMade(ign):
            # Trickle one chunk per reactor turn so the child's producer
            # logic is genuinely exercised.
            if toWrite:
                written.append(b"%d\n" % (toWrite.pop(),))
                proc.write(written[-1])
                reactor.callLater(0.01, connectionMade, None)

        proc = self._spawnProcess(p, b"stdio_test_producer")
        p.onConnection.addCallback(connectionMade)

        def processEnded(reason):
            self.assertEqual(p.data[1], b"".join(written))
            self.assertFalse(
                toWrite, "Connection lost with %d writes left to go." % (len(toWrite),)
            )
            reason.trap(error.ProcessDone)

        return self._requireFailure(d, processEnded)

    def test_consumer(self):
        """
        Verify that the transport of a protocol connected to L{StandardIO}
        is a working L{IConsumer} provider.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion

        junkPath = self._junkPath()
        self._spawnProcess(p, b"stdio_test_consumer", junkPath)

        def processEnded(reason):
            # The child must have echoed the junk file back verbatim.
            with open(junkPath, "rb") as f:
                self.assertEqual(p.data[1], f.read())
            reason.trap(error.ProcessDone)

        return self._requireFailure(d, processEnded)

    @skipIf(
        platform.isWindows(),
        "StandardIO does not accept stdout as an argument to Windows. "
        "Testing redirection to a file is therefore harder.",
    )
    def test_normalFileStandardOut(self):
        """
        If L{StandardIO} is created with a file descriptor which refers to a
        normal file (ie, a file from the filesystem), L{StandardIO.write}
        writes bytes to that file. In particular, it does not immediately
        consider the file closed or call its protocol's C{connectionLost}
        method.
        """
        onConnLost = defer.Deferred()
        proto = ConnectionLostNotifyingProtocol(onConnLost)
        path = filepath.FilePath(self.mktemp())
        self.normal = normal = path.open("wb")
        self.addCleanup(normal.close)

        kwargs = dict(stdout=normal.fileno())
        if not platform.isWindows():
            # Make a fake stdin so that StandardIO doesn't mess with the *real*
            # stdin.
            r, w = os.pipe()
            self.addCleanup(os.close, r)
            self.addCleanup(os.close, w)
            kwargs["stdin"] = r
        connection = stdio.StandardIO(proto, **kwargs)

        # The reactor needs to spin a bit before it might have incorrectly
        # decided stdout is closed.  Use this counter to keep track of how
        # much we've let it spin.  If it closes before we expected, this
        # counter will have a value that's too small and we'll know.
        howMany = 5
        count = itertools.count()

        def spin():
            # Write one digit per reactor turn, then disconnect.
            for value in count:
                if value == howMany:
                    connection.loseConnection()
                    return
                connection.write(b"%d" % (value,))
                break
            reactor.callLater(0, spin)

        reactor.callLater(0, spin)

        # Once the connection is lost, make sure the counter is at the
        # appropriate value.
        def cbLost(reason):
            self.assertEqual(next(count), howMany + 1)
            self.assertEqual(
                path.getContent(), b"".join(b"%d" % (i,) for i in range(howMany))
            )

        onConnLost.addCallback(cbLost)
        return onConnLost

View File

@@ -0,0 +1,157 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test strerror
"""
import os
import socket
from unittest import skipIf
from twisted.internet.tcp import ECONNABORTED
from twisted.python.runtime import platform
from twisted.python.win32 import _ErrorFormatter, formatError
from twisted.trial.unittest import TestCase
class _MyWindowsException(OSError):
    """
    An exception type like L{ctypes.WinError}, but available on all platforms.

    Instances are constructed with C{(errorCode, message)}, mirroring the
    shape the winError formatting path expects.
    """
class ErrorFormatingTests(TestCase):
    """
    Tests for C{_ErrorFormatter.formatError}.
    """

    # An errno available on every supported platform.
    probeErrorCode = ECONNABORTED
    # Sentinel message proving which lookup mechanism was used.
    probeMessage = "correct message value"

    def test_strerrorFormatting(self):
        """
        L{_ErrorFormatter.formatError} should use L{os.strerror} to format
        error messages if it is constructed without any better mechanism.
        """
        formatter = _ErrorFormatter(None, None, None)
        message = formatter.formatError(self.probeErrorCode)
        self.assertEqual(message, os.strerror(self.probeErrorCode))

    def test_emptyErrorTab(self):
        """
        L{_ErrorFormatter.formatError} should use L{os.strerror} to format
        error messages if it is constructed with only an error tab which does
        not contain the error code it is called with.
        """
        error = 1
        # Sanity check
        self.assertNotEqual(self.probeErrorCode, error)
        formatter = _ErrorFormatter(None, None, {error: "wrong message"})
        message = formatter.formatError(self.probeErrorCode)
        self.assertEqual(message, os.strerror(self.probeErrorCode))

    def test_errorTab(self):
        """
        L{_ErrorFormatter.formatError} should use C{errorTab} if it is supplied
        and contains the requested error code.
        """
        formatter = _ErrorFormatter(
            None, None, {self.probeErrorCode: self.probeMessage}
        )
        message = formatter.formatError(self.probeErrorCode)
        self.assertEqual(message, self.probeMessage)

    def test_formatMessage(self):
        """
        L{_ErrorFormatter.formatError} should return the return value of
        C{formatMessage} if it is supplied.
        """
        formatCalls = []

        def formatMessage(errorCode):
            formatCalls.append(errorCode)
            return self.probeMessage

        formatter = _ErrorFormatter(
            None, formatMessage, {self.probeErrorCode: "wrong message"}
        )
        message = formatter.formatError(self.probeErrorCode)
        self.assertEqual(message, self.probeMessage)
        # formatMessage wins over the error tab and is called exactly once
        # with the probed code.
        self.assertEqual(formatCalls, [self.probeErrorCode])

    def test_winError(self):
        """
        L{_ErrorFormatter.formatError} should return the message argument from
        the exception L{winError} returns, if L{winError} is supplied.
        """
        winCalls = []

        def winError(errorCode):
            winCalls.append(errorCode)
            return _MyWindowsException(errorCode, self.probeMessage)

        # winError outranks both formatMessage and the error tab.
        formatter = _ErrorFormatter(
            winError,
            lambda error: "formatMessage: wrong message",
            {self.probeErrorCode: "errorTab: wrong message"},
        )
        message = formatter.formatError(self.probeErrorCode)
        self.assertEqual(message, self.probeMessage)

    @skipIf(platform.getType() != "win32", "Test will run only on Windows.")
    def test_fromEnvironment(self):
        """
        L{_ErrorFormatter.fromEnvironment} should create an L{_ErrorFormatter}
        instance with attributes populated from available modules.
        """
        formatter = _ErrorFormatter.fromEnvironment()

        # Check each mechanism in priority order, disabling it afterwards
        # so the next one can be observed.
        if formatter.winError is not None:
            from ctypes import WinError

            self.assertEqual(
                formatter.formatError(self.probeErrorCode),
                WinError(self.probeErrorCode).strerror,
            )
            formatter.winError = None

        if formatter.formatMessage is not None:
            from win32api import FormatMessage

            self.assertEqual(
                formatter.formatError(self.probeErrorCode),
                FormatMessage(self.probeErrorCode),
            )
            formatter.formatMessage = None

        if formatter.errorTab is not None:
            from socket import errorTab

            self.assertEqual(
                formatter.formatError(self.probeErrorCode),
                errorTab[self.probeErrorCode],
            )

    @skipIf(platform.getType() != "win32", "Test will run only on Windows.")
    def test_correctLookups(self):
        """
        Given a known-good errno, make sure that formatMessage gives results
        matching either C{socket.errorTab}, C{ctypes.WinError}, or
        C{win32api.FormatMessage}.
        """
        acceptable = [socket.errorTab[ECONNABORTED]]
        try:
            from ctypes import WinError

            acceptable.append(WinError(ECONNABORTED).strerror)
        except ImportError:
            pass
        try:
            from win32api import FormatMessage

            acceptable.append(FormatMessage(ECONNABORTED))
        except ImportError:
            pass

        self.assertIn(formatError(ECONNABORTED), acceptable)

View File

@@ -0,0 +1,49 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.strports}.
"""
from twisted.application import internet, strports
from twisted.internet.endpoints import TCP4ServerEndpoint
from twisted.internet.protocol import Factory
from twisted.trial.unittest import TestCase
class ServiceTests(TestCase):
    """
    Tests for L{strports.service}.
    """

    def test_service(self):
        """
        L{strports.service} returns a L{StreamServerEndpointService}
        constructed with an endpoint produced from
        L{endpoint.serverFromString}, using the same syntax.
        """
        # Any object works here: the reactor is only stored, never used.
        reactor = object()  # the cake is a lie
        aFactory = Factory()
        aGoodPort = 1337
        svc = strports.service("tcp:" + str(aGoodPort), aFactory, reactor=reactor)
        self.assertIsInstance(svc, internet.StreamServerEndpointService)

        # See twisted.application.test.test_internet.EndpointServiceTests.
        # test_synchronousRaiseRaisesSynchronously
        self.assertTrue(svc._raiseSynchronously)
        self.assertIsInstance(svc.endpoint, TCP4ServerEndpoint)
        # Maybe we should implement equality for endpoints.
        self.assertEqual(svc.endpoint._port, aGoodPort)
        self.assertIs(svc.factory, aFactory)
        self.assertIs(svc.endpoint._reactor, reactor)

    def test_serviceDefaultReactor(self):
        """
        L{strports.service} will use the default reactor when none is provided
        as an argument.
        """
        from twisted.internet import reactor as globalReactor

        aService = strports.service("tcp:80", None)
        self.assertIs(aService.endpoint._reactor, globalReactor)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,380 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Whitebox tests for TCP APIs.
"""
import errno
import os
import socket
try:
import resource
except ImportError:
resource = None # type: ignore[assignment]
from unittest import skipIf
from twisted.internet import interfaces, reactor
from twisted.internet.defer import gatherResults, maybeDeferred
from twisted.internet.protocol import Protocol, ServerFactory
from twisted.internet.tcp import (
_ACCEPT_ERRORS,
EAGAIN,
ECONNABORTED,
EINPROGRESS,
EMFILE,
ENFILE,
ENOBUFS,
ENOMEM,
EPERM,
EWOULDBLOCK,
Port,
)
from twisted.python import log
from twisted.python.runtime import platform
from twisted.trial.unittest import TestCase
@skipIf(
    not interfaces.IReactorFDSet.providedBy(reactor),
    "This test only applies to reactors that implement IReactorFDset",
)
class PlatformAssumptionsTests(TestCase):
    """
    Test assumptions about platform behaviors.
    """

    # Upper bound on sockets opened while exhausting the fd limit; lowered
    # in setUp when RLIMIT_NOFILE can be shrunk.
    socketLimit = 8192

    def setUp(self):
        self.openSockets = []
        if resource is not None:
            # On some buggy platforms we might leak FDs, and the test will
            # fail creating the initial two sockets we *do* want to
            # succeed. So, we make the soft limit the current number of fds
            # plus two more (for the two sockets we want to succeed). If we've
            # leaked too many fds for that to work, there's nothing we can
            # do.
            from twisted.internet.process import _listOpenFDs

            newLimit = len(_listOpenFDs()) + 2
            self.originalFileLimit = resource.getrlimit(resource.RLIMIT_NOFILE)
            resource.setrlimit(
                resource.RLIMIT_NOFILE, (newLimit, self.originalFileLimit[1])
            )
            self.socketLimit = newLimit + 100

    def tearDown(self):
        # Close any sockets opened via self.socket().
        while self.openSockets:
            self.openSockets.pop().close()
        if resource is not None:
            # `macOS` implicitly lowers the hard limit in the setrlimit call
            # above. Retrieve the new hard limit to pass in to this
            # setrlimit call, so that it doesn't give us a permission denied
            # error.
            currentHardLimit = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
            newSoftLimit = min(self.originalFileLimit[0], currentHardLimit)
            resource.setrlimit(resource.RLIMIT_NOFILE, (newSoftLimit, currentHardLimit))

    def socket(self):
        """
        Create and return a new socket object, also tracking it so it can be
        closed in the test tear down.
        """
        s = socket.socket()
        self.openSockets.append(s)
        return s

    @skipIf(
        platform.getType() == "win32",
        "Windows requires an unacceptably large amount of resources to "
        "provoke this behavior in the naive manner.",
    )
    def test_acceptOutOfFiles(self):
        """
        Test that the platform accept(2) call fails with either L{EMFILE} or
        L{ENOBUFS} when there are too many file descriptors open.
        """
        # Make a server to which to connect
        port = self.socket()
        port.bind(("127.0.0.1", 0))
        serverPortNumber = port.getsockname()[1]
        port.listen(5)

        # Make a client to use to connect to the server
        client = self.socket()
        client.setblocking(False)

        # Use up all the rest of the file descriptors.
        for i in range(self.socketLimit):
            try:
                self.socket()
            except OSError as e:
                if e.args[0] in (EMFILE, ENOBUFS):
                    # The desired state has been achieved.
                    break
                else:
                    # Some unexpected error occurred.
                    raise
        else:
            self.fail("Could provoke neither EMFILE nor ENOBUFS from platform.")

        # Non-blocking connect is supposed to fail, but this is not true
        # everywhere (e.g. freeBSD)
        self.assertIn(
            client.connect_ex(("127.0.0.1", serverPortNumber)), (0, EINPROGRESS)
        )

        # Make sure that the accept call fails in the way we expect.
        exc = self.assertRaises(socket.error, port.accept)
        self.assertIn(exc.args[0], (EMFILE, ENOBUFS))
@skipIf(
not interfaces.IReactorFDSet.providedBy(reactor),
"This test only applies to reactors that implement IReactorFDset",
)
class SelectReactorTests(TestCase):
"""
Tests for select-specific failure conditions.
"""
def setUp(self):
    # Ports started via self.port(); stopped in tearDown.
    self.ports = []
    # Capture log events so tests can assert on what was logged.
    self.messages = []
    log.addObserver(self.messages.append)
def tearDown(self):
    # Stop observing the log, then shut down every port the test started.
    log.removeObserver(self.messages.append)
    return gatherResults([maybeDeferred(p.stopListening) for p in self.ports])
def port(self, portNumber, factory, interface):
    """
    Create, start, and return a new L{Port}, also tracking it so it can
    be stopped in the test tear down.
    """
    p = Port(portNumber, factory, interface=interface)
    p.startListening()
    # Remembered so tearDown can stop it.
    self.ports.append(p)
    return p
def _acceptFailureTest(self, socketErrorNumber):
    """
    Test behavior in the face of an exception from C{accept(2)}.

    On any exception which indicates the platform is unable or unwilling
    to allocate further resources to us, the existing port should remain
    listening, a message should be logged, and the exception should not
    propagate outward from doRead.

    @param socketErrorNumber: The errno to simulate from accept.
    """

    class FakeSocket:
        """
        Pretend to be a socket in an overloaded system.
        """

        def accept(self):
            raise OSError(socketErrorNumber, os.strerror(socketErrorNumber))

    factory = ServerFactory()
    port = self.port(0, factory, interface="127.0.0.1")
    self.patch(port, "socket", FakeSocket())

    # Must not raise even though accept() fails.
    port.doRead()

    expectedFormat = "Could not accept new connection ({acceptError})"
    expectedErrorCode = errno.errorcode[socketErrorNumber]
    # Keep only events that actually match.  The previous version built a
    # list of booleans and asserted len(...) > 0, which equals the total
    # number of log events and therefore passed even when no matching
    # event had been logged.
    matchingMessages = [
        msg
        for msg in self.messages
        if msg.get("log_format") == expectedFormat
        and msg.get("acceptError") == expectedErrorCode
    ]
    self.assertGreater(
        len(matchingMessages),
        0,
        "Log event for failed accept not found in " "%r" % (self.messages,),
    )
def test_tooManyFilesFromAccept(self):
    """
    C{accept(2)} can fail with C{EMFILE} when there are too many open file
    descriptors in the process. Test that this doesn't negatively impact
    any other existing connections.

    C{EMFILE} mainly occurs on Linux when the open file rlimit is
    encountered.
    """
    # Delegate to the shared accept-failure scenario with errno EMFILE.
    return self._acceptFailureTest(EMFILE)
def test_noBufferSpaceFromAccept(self):
    """
    Similar to L{test_tooManyFilesFromAccept}, but test the case where
    C{accept(2)} fails with C{ENOBUFS}.

    This mainly occurs on Windows and FreeBSD, but may be possible on
    Linux and other platforms as well.
    """
    return self._acceptFailureTest(ENOBUFS)
def test_connectionAbortedFromAccept(self):
    """
    Similar to L{test_tooManyFilesFromAccept}, but test the case where
    C{accept(2)} fails with C{ECONNABORTED}.

    It is not clear whether this is actually possible for TCP
    connections on modern versions of Linux.
    """
    return self._acceptFailureTest(ECONNABORTED)
@skipIf(platform.getType() == "win32", "Windows accept(2) cannot generate ENFILE")
def test_noFilesFromAccept(self):
    """
    Similar to L{test_tooManyFilesFromAccept}, but test the case where
    C{accept(2)} fails with C{ENFILE}.

    This can occur on Linux when the system has exhausted (!) its supply
    of inodes.
    """
    return self._acceptFailureTest(ENFILE)
@skipIf(platform.getType() == "win32", "Windows accept(2) cannot generate ENOMEM")
def test_noMemoryFromAccept(self):
    """
    Similar to L{test_tooManyFilesFromAccept}, but test the case where
    C{accept(2)} fails with C{ENOMEM}.

    On Linux at least, this can sensibly occur, even in a Python program
    (which eats memory like no ones business), when memory has become
    fragmented or low memory has been filled (d_alloc calls
    kmem_cache_alloc calls kmalloc - kmalloc only allocates out of low
    memory).
    """
    return self._acceptFailureTest(ENOMEM)
@skipIf(
    # `os.environ` restored here; it had been mangled by an automated scrub.
    os.environ.get("INFRASTRUCTURE") == "AZUREPIPELINES",
    "Hangs on Azure Pipelines due to firewall",
)
def test_acceptScaling(self):
    """
    L{tcp.Port.doRead} increases the number of consecutive
    C{accept} calls it performs if all of the previous C{accept}
    calls succeed; otherwise, it reduces the number to the amount
    of successful calls.
    """
    factory = ServerFactory()
    factory.protocol = Protocol
    port = self.port(0, factory, interface="127.0.0.1")
    self.addCleanup(port.stopListening)
    clients = []

    def closeAll():
        # Close every client socket the test opened.
        for client in clients:
            client.close()

    self.addCleanup(closeAll)

    def connect():
        # Open a blocking TCP connection to the listening port.
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        client.connect(("127.0.0.1", port.getHost().port))
        return client

    clients.append(connect())
    port.numberAccepts = 1
    port.doRead()
    # A fully successful batch of accepts causes scaling up.
    self.assertGreater(port.numberAccepts, 1)

    clients.append(connect())
    port.doRead()
    # There was only one outstanding client connection, so only
    # one accept(2) was possible.
    self.assertEqual(port.numberAccepts, 1)

    port.doRead()
    # There were no outstanding client connections, so only one
    # accept should be tried next.
    self.assertEqual(port.numberAccepts, 1)
@skipIf(platform.getType() == "win32", "Windows accept(2) cannot generate EPERM")
def test_permissionFailure(self):
    """
    C{accept(2)} returning C{EPERM} is treated as a transient
    failure and the call retried no more than the maximum number
    of consecutive C{accept(2)} calls.
    """
    maximumNumberOfAccepts = 123
    acceptCalls = [0]

    class FakeSocketWithAcceptLimit:
        """
        Pretend to be a socket in an overloaded system whose
        C{accept} method can only be called
        C{maximumNumberOfAccepts} times.
        """

        def accept(oself):
            acceptCalls[0] += 1
            if acceptCalls[0] > maximumNumberOfAccepts:
                self.fail("Maximum number of accept calls exceeded.")
            raise OSError(EPERM, os.strerror(EPERM))

    # Verify that FakeSocketWithAcceptLimit.accept() fails the
    # test if the number of accept calls exceeds the maximum.
    for _ in range(maximumNumberOfAccepts):
        self.assertRaises(socket.error, FakeSocketWithAcceptLimit().accept)
    self.assertRaises(self.failureException, FakeSocketWithAcceptLimit().accept)

    # Reset the counter before exercising the port itself.
    acceptCalls = [0]

    factory = ServerFactory()
    port = self.port(0, factory, interface="127.0.0.1")
    port.numberAccepts = 123
    self.patch(port, "socket", FakeSocketWithAcceptLimit())

    # This should not loop infinitely.
    port.doRead()

    # This is scaled down to 1 because no accept(2)s returned
    # successfully.
    self.assertEqual(port.numberAccepts, 1)
def test_unknownSocketErrorRaise(self):
    """
    A C{socket.error} raised by C{accept(2)} whose C{errno} is
    unknown to the recovery logic is logged.
    """
    knownErrors = list(_ACCEPT_ERRORS)
    knownErrors.extend([EAGAIN, EPERM, EWOULDBLOCK])
    # Windows has object()s stubs for some errnos.  Pick an integer errno
    # one greater than the largest known value, so the recovery logic
    # cannot possibly recognize it.
    unknownAcceptError = (
        max(error for error in knownErrors if isinstance(error, int)) + 1
    )

    class FakeSocketWithUnknownAcceptError:
        """
        Pretend to be a socket whose C{accept} method always fails with
        an errno that the port's error-recovery logic does not know.
        """

        def accept(oself):
            raise OSError(unknownAcceptError, "unknown socket error message")

    factory = ServerFactory()
    port = self.port(0, factory, interface="127.0.0.1")
    self.patch(port, "socket", FakeSocketWithUnknownAcceptError())
    port.doRead()
    # The unknown errno must have been logged exactly once.
    failures = self.flushLoggedErrors(socket.error)
    self.assertEqual(1, len(failures))
    self.assertEqual(failures[0].value.args[0], unknownAcceptError)

View File

@@ -0,0 +1,235 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.text}.
"""
from io import StringIO
from twisted.python import text
from twisted.trial import unittest
sampleText = """Every attempt to employ mathematical methods in the study of chemical
questions must be considered profoundly irrational and contrary to the
spirit of chemistry ... If mathematical analysis should ever hold a
prominent place in chemistry - an aberration which is happily almost
impossible - it would occasion a rapid and widespread degeneration of that
science.
-- Auguste Comte, Philosophie Positive, Paris, 1838
"""
class WrapTests(unittest.TestCase):
    """
    Tests for L{text.wordWrap} (which drives L{text.greedyWrap}).
    """

    def setUp(self) -> None:
        # Wrap the module-level sample once; every test inspects the result.
        self.lineWidth = 72
        self.sampleSplitText = sampleText.split()
        self.output = text.wordWrap(sampleText, self.lineWidth)

    def test_wordCount(self) -> None:
        """
        Compare the number of words.
        """
        words = []
        for line in self.output:
            words.extend(line.split())
        wordCount = len(words)
        sampleTextWordCount = len(self.sampleSplitText)
        self.assertEqual(wordCount, sampleTextWordCount)

    def test_wordMatch(self) -> None:
        """
        Compare the lists of words.
        """
        words = []
        for line in self.output:
            words.extend(line.split())
        # Using assertEqual here prints out some
        # rather too long lists.
        self.assertTrue(self.sampleSplitText == words)

    def test_lineLength(self) -> None:
        """
        Check the length of the lines.
        """
        failures = []
        for line in self.output:
            if not len(line) <= self.lineWidth:
                failures.append(len(line))
        if failures:
            self.fail(
                "%d of %d lines were too long.\n"
                "%d < %s" % (len(failures), len(self.output), self.lineWidth, failures)
            )

    def test_doubleNewline(self) -> None:
        r"""
        Allow paragraphs delimited by two \ns.
        """
        sampleText = "et\n\nphone\nhome."
        result = text.wordWrap(sampleText, self.lineWidth)
        self.assertEqual(result, ["et", "", "phone home.", ""])
class LineTests(unittest.TestCase):
    """
    Tests for L{text.isMultiline} and L{text.endsInNewline}.
    """

    def test_isMultiline(self) -> None:
        """
        L{text.isMultiline} returns C{True} if the string has a newline in it.
        """
        self.assertTrue(text.isMultiline('This code\n "breaks."'))
        self.assertFalse(text.isMultiline('This code does not "break."'))

    def test_endsInNewline(self) -> None:
        """
        L{text.endsInNewline} returns C{True} if the string ends in a newline.
        """
        self.assertTrue(text.endsInNewline("newline\n"))
        self.assertFalse(text.endsInNewline("oldline"))
class StringyStringTests(unittest.TestCase):
    """
    Tests for L{text.stringyString}.
    """

    def test_tuple(self) -> None:
        """
        Tuple elements are displayed on separate lines.
        """
        rendered = text.stringyString(("a", "b"))
        self.assertEqual(rendered, "(a,\n b,)\n")

    def test_dict(self) -> None:
        """
        Dicts elements are displayed using C{str()}.
        """
        rendered = text.stringyString({"a": 0})
        self.assertEqual(rendered, "{a: 0}")

    def test_list(self) -> None:
        """
        List elements are displayed on separate lines using C{str()}.
        """
        rendered = text.stringyString(["a", "b"])
        self.assertEqual(rendered, "[a,\n b,]\n")
class SplitTests(unittest.TestCase):
    """
    Tests for L{text.splitQuoted}.
    """

    def test_oneWord(self) -> None:
        """
        Splitting strings with one-word phrases.
        """
        s = 'This code "works."'
        r = text.splitQuoted(s)
        self.assertEqual(["This", "code", "works."], r)

    def test_multiWord(self) -> None:
        """
        A quoted multi-word phrase is kept together as a single token.
        """
        s = 'The "hairy monkey" likes pie.'
        r = text.splitQuoted(s)
        self.assertEqual(["The", "hairy monkey", "likes", "pie."], r)

    # Some of the many tests that would fail:
    #
    # def test_preserveWhitespace(self):
    #     phrase = '"MANY SPACES"'
    #     s = 'With %s between.' % (phrase,)
    #     r = text.splitQuoted(s)
    #     self.assertEqual(['With', phrase, 'between.'], r)
    #
    # def test_escapedSpace(self):
    #     s = r"One\ Phrase"
    #     r = text.splitQuoted(s)
    #     self.assertEqual(["One Phrase"], r)
class StrFileTests(unittest.TestCase):
    """
    Tests for L{text.strFile}, which reports whether a substring occurs in
    a file-like object.
    """

    def setUp(self) -> None:
        # The haystack searched by every test below.
        self.io = StringIO("this is a test string")

    def test_1_f(self) -> None:
        self.assertFalse(text.strFile("x", self.io))

    def test_1_1(self) -> None:
        self.assertTrue(text.strFile("t", self.io))

    def test_1_2(self) -> None:
        self.assertTrue(text.strFile("h", self.io))

    def test_1_3(self) -> None:
        self.assertTrue(text.strFile("i", self.io))

    def test_1_4(self) -> None:
        self.assertTrue(text.strFile("s", self.io))

    def test_1_5(self) -> None:
        self.assertTrue(text.strFile("n", self.io))

    def test_1_6(self) -> None:
        self.assertTrue(text.strFile("g", self.io))

    def test_3_1(self) -> None:
        self.assertTrue(text.strFile("thi", self.io))

    def test_3_2(self) -> None:
        self.assertTrue(text.strFile("his", self.io))

    def test_3_3(self) -> None:
        self.assertTrue(text.strFile("is ", self.io))

    def test_3_4(self) -> None:
        self.assertTrue(text.strFile("ing", self.io))

    def test_3_f(self) -> None:
        self.assertFalse(text.strFile("bla", self.io))

    def test_large_1(self) -> None:
        self.assertTrue(text.strFile("this is a test", self.io))

    def test_large_2(self) -> None:
        self.assertTrue(text.strFile("is a test string", self.io))

    def test_large_f(self) -> None:
        self.assertFalse(text.strFile("ds jhfsa k fdas", self.io))

    def test_overlarge_f(self) -> None:
        self.assertFalse(
            text.strFile("djhsakj dhsa fkhsa s,mdbnfsauiw bndasdf hreew", self.io)
        )

    def test_self(self) -> None:
        self.assertTrue(text.strFile("this is a test string", self.io))

    def test_insensitive(self) -> None:
        # Third argument False requests case-insensitive matching.
        self.assertTrue(text.strFile("ThIs is A test STRING", self.io, False))

View File

@@ -0,0 +1,119 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.threadable}.
"""
import pickle
import sys
from unittest import skipIf
try:
import threading
except ImportError:
threadingSkip = True
else:
threadingSkip = False
from twisted.python import threadable
from twisted.trial.unittest import FailTest, SynchronousTestCase
class TestObject:
    # Names of the methods which threadable.synchronize() should wrap in a
    # per-class lock.
    synchronized = ["aMethod"]

    x = -1
    y = 1

    def aMethod(self):
        """
        Repeatedly swap C{x} and C{y}, storing their sum in C{z}.  The sum
        is zero unless another thread interleaved mid-swap.
        """
        for _ in range(10):
            self.x, self.y = self.y, self.x
            self.z = self.x + self.y
        assert self.z == 0, "z == %d, not 0 as expected" % (self.z,)
threadable.synchronize(TestObject)
class SynchronizationTests(SynchronousTestCase):
    """
    Tests for the locking behavior of L{threadable.synchronize}.
    """

    def setUp(self):
        """
        Reduce the CPython thread switch interval so that thread switches
        happen much more often, hopefully exercising more possible race
        conditions.
        """
        self.addCleanup(sys.setswitchinterval, sys.getswitchinterval())
        sys.setswitchinterval(0.0000001)

    def test_synchronizedName(self):
        """
        The name of a synchronized method is unaffected by the synchronization
        decorator.
        """
        self.assertEqual("aMethod", TestObject.aMethod.__name__)

    @skipIf(threadingSkip, "Platform does not support threads")
    def test_isInIOThread(self):
        """
        L{threadable.isInIOThread} returns C{True} if and only if it is called
        in the same thread as L{threadable.registerAsIOThread}.
        """
        threadable.registerAsIOThread()
        foreignResult = []
        t = threading.Thread(
            target=lambda: foreignResult.append(threadable.isInIOThread())
        )
        t.start()
        t.join()
        self.assertFalse(foreignResult[0], "Non-IO thread reported as IO thread")
        self.assertTrue(
            threadable.isInIOThread(), "IO thread reported as not IO thread"
        )

    @skipIf(threadingSkip, "Platform does not support threads")
    def testThreadedSynchronization(self):
        """
        Concurrent calls to a synchronized method never observe the
        intermediate, inconsistent state of L{TestObject}.
        """
        o = TestObject()
        errors = []

        def callMethodLots():
            try:
                for i in range(1000):
                    o.aMethod()
            except AssertionError as e:
                errors.append(str(e))

        threads = []
        for x in range(5):
            t = threading.Thread(target=callMethodLots)
            threads.append(t)
            t.start()
        for t in threads:
            t.join()
        if errors:
            raise FailTest(errors)

    def testUnthreadedSynchronization(self):
        """
        A synchronized method still works when called from a single thread.
        """
        o = TestObject()
        for i in range(1000):
            o.aMethod()
class SerializationTests(SynchronousTestCase):
    """
    Tests for pickling of L{threadable.XLock}.
    """

    @skipIf(threadingSkip, "Platform does not support threads")
    def testPickling(self):
        """
        An L{threadable.XLock} survives a pickle round-trip, producing a new
        instance of the same type.
        """
        lock = threadable.XLock()
        lockType = type(lock)
        lockPickle = pickle.dumps(lock)
        newLock = pickle.loads(lockPickle)
        self.assertIsInstance(newLock, lockType)

    def testUnpickling(self):
        """
        A protocol-0 pickle (as produced by older Twisted, via
        C{unpickle_lock}) still loads, and the result can be re-pickled.
        """
        # Hand-crafted legacy pickle bytes; must not be altered.
        lockPickle = b"ctwisted.python.threadable\nunpickle_lock\np0\n(tp1\nRp2\n."
        lock = pickle.loads(lockPickle)
        newPickle = pickle.dumps(lock, 2)
        pickle.loads(newPickle)

View File

@@ -0,0 +1,710 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.python.threadpool}
"""
import gc
import pickle
import threading
import time
import weakref
from twisted._threads import Team, createMemoryWorker
from twisted.python import context, failure, threadable, threadpool
from twisted.trial import unittest
class Synchronization:
    # Number of times run() detected that it was entered concurrently.
    failures = 0

    def __init__(self, N, waiting):
        self.N = N
        self.waiting = waiting
        self.lock = threading.Lock()
        self.runs = []

    def run(self):
        """
        Record one invocation.  If the pool ever runs this from two threads
        at once, the non-blocking acquire below fails and a failure is
        counted.
        """
        if not self.lock.acquire(False):
            # Another thread holds the lock: serialization was violated.
            self.failures += 1
        else:
            if not len(self.runs) % 5:
                # Constant selected based on empirical data to maximize the
                # chance of a quick failure if this code is broken.
                time.sleep(0.0002)
            self.lock.release()

        # This is just the only way I can think of to wake up the test
        # method.  It doesn't actually have anything to do with the test.
        self.lock.acquire()
        self.runs.append(None)
        if len(self.runs) == self.N:
            self.waiting.release()
        self.lock.release()

    synchronized = ["run"]
threadable.synchronize(Synchronization)
class ThreadPoolTests(unittest.SynchronousTestCase):
    """
    Test threadpools.
    """

    def getTimeout(self):
        """
        Return number of seconds to wait before giving up.
        """
        return 5  # Really should be order of magnitude less

    def _waitForLock(self, lock):
        # Spin (with a tiny sleep) until the lock becomes available,
        # failing the test if it never does.
        items = range(1000000)
        for i in items:
            if lock.acquire(False):
                break
            time.sleep(1e-5)
        else:
            self.fail("A long time passed without succeeding")

    def test_attributes(self):
        """
        L{ThreadPool.min} and L{ThreadPool.max} are set to the values passed to
        L{ThreadPool.__init__}.
        """
        pool = threadpool.ThreadPool(12, 22)
        self.assertEqual(pool.min, 12)
        self.assertEqual(pool.max, 22)

    def test_start(self):
        """
        L{ThreadPool.start} creates the minimum number of threads specified.
        """
        pool = threadpool.ThreadPool(0, 5)
        pool.start()
        self.addCleanup(pool.stop)
        self.assertEqual(len(pool.threads), 0)

        pool = threadpool.ThreadPool(3, 10)
        self.assertEqual(len(pool.threads), 0)
        pool.start()
        self.addCleanup(pool.stop)
        self.assertEqual(len(pool.threads), 3)

    def test_adjustingWhenPoolStopped(self):
        """
        L{ThreadPool.adjustPoolsize} only modifies the pool size and does not
        start new workers while the pool is not running.
        """
        pool = threadpool.ThreadPool(0, 5)
        pool.start()
        pool.stop()
        pool.adjustPoolsize(2)
        self.assertEqual(len(pool.threads), 0)

    def test_threadCreationArguments(self):
        """
        Test that creating threads in the threadpool with application-level
        objects as arguments doesn't result in those objects never being
        freed, with the thread maintaining a reference to them as long as it
        exists.
        """
        tp = threadpool.ThreadPool(0, 1)
        tp.start()
        self.addCleanup(tp.stop)

        # Sanity check - no threads should have been started yet.
        self.assertEqual(tp.threads, [])

        # Here's our function
        def worker(arg):
            pass

        # weakref needs an object subclass
        class Dumb:
            pass

        # And here's the unique object
        unique = Dumb()

        workerRef = weakref.ref(worker)
        uniqueRef = weakref.ref(unique)

        # Put some work in
        tp.callInThread(worker, unique)

        # Add an event to wait completion
        event = threading.Event()
        tp.callInThread(event.set)
        event.wait(self.getTimeout())

        del worker
        del unique
        gc.collect()
        self.assertIsNone(uniqueRef())
        self.assertIsNone(workerRef())

    def test_threadCreationArgumentsCallInThreadWithCallback(self):
        """
        As C{test_threadCreationArguments} above, but for
        callInThreadWithCallback.
        """
        tp = threadpool.ThreadPool(0, 1)
        tp.start()
        self.addCleanup(tp.stop)

        # Sanity check - no threads should have been started yet.
        self.assertEqual(tp.threads, [])

        # this holds references obtained in onResult
        refdict = {}  # name -> ref value

        onResultWait = threading.Event()
        onResultDone = threading.Event()

        resultRef = []

        # result callback
        def onResult(success, result):
            # Spin the GC, which should now delete worker and unique if it's
            # not held on to by callInThreadWithCallback after it is complete
            gc.collect()
            onResultWait.wait(self.getTimeout())
            refdict["workerRef"] = workerRef()
            refdict["uniqueRef"] = uniqueRef()
            onResultDone.set()
            resultRef.append(weakref.ref(result))

        # Here's our function
        def worker(arg, test):
            return Dumb()

        # weakref needs an object subclass
        class Dumb:
            pass

        # And here's the unique object
        unique = Dumb()

        onResultRef = weakref.ref(onResult)
        workerRef = weakref.ref(worker)
        uniqueRef = weakref.ref(unique)

        # Put some work in
        tp.callInThreadWithCallback(onResult, worker, unique, test=unique)

        del worker
        del unique

        # let onResult collect the refs
        onResultWait.set()
        # wait for onResult
        onResultDone.wait(self.getTimeout())
        gc.collect()

        self.assertIsNone(uniqueRef())
        self.assertIsNone(workerRef())

        # XXX There's a race right here - has onResult in the worker thread
        # returned and the locals in _worker holding it and the result been
        # deleted yet?

        del onResult
        gc.collect()
        self.assertIsNone(onResultRef())
        self.assertIsNone(resultRef[0]())

        # The callback shouldn't have been able to resolve the references.
        self.assertEqual(list(refdict.values()), [None, None])

    def test_persistence(self):
        """
        Threadpools can be pickled and unpickled, which should preserve the
        number of threads and other parameters.
        """
        pool = threadpool.ThreadPool(7, 20)

        self.assertEqual(pool.min, 7)
        self.assertEqual(pool.max, 20)

        # check that unpickled threadpool has same number of threads
        copy = pickle.loads(pickle.dumps(pool))

        self.assertEqual(copy.min, 7)
        self.assertEqual(copy.max, 20)

    def _threadpoolTest(self, method):
        """
        Test synchronization of calls made with C{method}, which should be
        one of the mechanisms of the threadpool to execute work in threads.
        """
        # This is a schizophrenic test: it seems to be trying to test
        # both the callInThread()/dispatch() behavior of the ThreadPool as well
        # as the serialization behavior of threadable.synchronize().  It
        # would probably make more sense as two much simpler tests.
        N = 10

        tp = threadpool.ThreadPool()
        tp.start()
        self.addCleanup(tp.stop)

        waiting = threading.Lock()
        waiting.acquire()
        actor = Synchronization(N, waiting)

        for i in range(N):
            method(tp, actor)

        self._waitForLock(waiting)

        self.assertFalse(actor.failures, f"run() re-entered {actor.failures} times")

    def test_callInThread(self):
        """
        Call C{_threadpoolTest} with C{callInThread}.
        """
        return self._threadpoolTest(lambda tp, actor: tp.callInThread(actor.run))

    def test_callInThreadException(self):
        """
        L{ThreadPool.callInThread} logs exceptions raised by the callable it
        is passed.
        """

        class NewError(Exception):
            pass

        def raiseError():
            raise NewError()

        tp = threadpool.ThreadPool(0, 1)
        tp.callInThread(raiseError)
        tp.start()
        tp.stop()

        errors = self.flushLoggedErrors(NewError)
        self.assertEqual(len(errors), 1)

    def test_callInThreadWithCallback(self):
        """
        L{ThreadPool.callInThreadWithCallback} calls C{onResult} with a
        two-tuple of C{(True, result)} where C{result} is the value returned
        by the callable supplied.
        """
        waiter = threading.Lock()
        waiter.acquire()

        results = []

        def onResult(success, result):
            waiter.release()
            results.append(success)
            results.append(result)

        tp = threadpool.ThreadPool(0, 1)
        tp.callInThreadWithCallback(onResult, lambda: "test")
        tp.start()

        try:
            self._waitForLock(waiter)
        finally:
            tp.stop()

        self.assertTrue(results[0])
        self.assertEqual(results[1], "test")

    def test_callInThreadWithCallbackExceptionInCallback(self):
        """
        L{ThreadPool.callInThreadWithCallback} calls C{onResult} with a
        two-tuple of C{(False, failure)} where C{failure} represents the
        exception raised by the callable supplied.
        """

        class NewError(Exception):
            pass

        def raiseError():
            raise NewError()

        waiter = threading.Lock()
        waiter.acquire()

        results = []

        def onResult(success, result):
            waiter.release()
            results.append(success)
            results.append(result)

        tp = threadpool.ThreadPool(0, 1)
        tp.callInThreadWithCallback(onResult, raiseError)
        tp.start()

        try:
            self._waitForLock(waiter)
        finally:
            tp.stop()

        self.assertFalse(results[0])
        self.assertIsInstance(results[1], failure.Failure)
        self.assertTrue(issubclass(results[1].type, NewError))

    def test_callInThreadWithCallbackExceptionInOnResult(self):
        """
        L{ThreadPool.callInThreadWithCallback} logs the exception raised by
        C{onResult}.
        """

        class NewError(Exception):
            pass

        waiter = threading.Lock()
        waiter.acquire()

        results = []

        def onResult(success, result):
            results.append(success)
            results.append(result)
            raise NewError()

        tp = threadpool.ThreadPool(0, 1)
        tp.callInThreadWithCallback(onResult, lambda: None)
        tp.callInThread(waiter.release)
        tp.start()

        try:
            self._waitForLock(waiter)
        finally:
            tp.stop()

        errors = self.flushLoggedErrors(NewError)
        self.assertEqual(len(errors), 1)

        self.assertTrue(results[0])
        self.assertIsNone(results[1])

    def test_callbackThread(self):
        """
        L{ThreadPool.callInThreadWithCallback} calls the function it is
        given and the C{onResult} callback in the same thread.
        """
        threadIds = []

        event = threading.Event()

        def onResult(success, result):
            threadIds.append(threading.current_thread().ident)
            event.set()

        def func():
            threadIds.append(threading.current_thread().ident)

        tp = threadpool.ThreadPool(0, 1)
        tp.callInThreadWithCallback(onResult, func)
        tp.start()
        self.addCleanup(tp.stop)

        event.wait(self.getTimeout())
        self.assertEqual(len(threadIds), 2)
        self.assertEqual(threadIds[0], threadIds[1])

    def test_callbackContext(self):
        """
        The context L{ThreadPool.callInThreadWithCallback} is invoked in is
        shared by the context the callable and C{onResult} callback are
        invoked in.
        """
        myctx = context.theContextTracker.currentContext().contexts[-1]
        myctx["testing"] = "this must be present"

        contexts = []

        event = threading.Event()

        def onResult(success, result):
            ctx = context.theContextTracker.currentContext().contexts[-1]
            contexts.append(ctx)
            event.set()

        def func():
            ctx = context.theContextTracker.currentContext().contexts[-1]
            contexts.append(ctx)

        tp = threadpool.ThreadPool(0, 1)
        tp.callInThreadWithCallback(onResult, func)
        tp.start()
        self.addCleanup(tp.stop)

        event.wait(self.getTimeout())

        self.assertEqual(len(contexts), 2)
        self.assertEqual(myctx, contexts[0])
        self.assertEqual(myctx, contexts[1])

    def test_existingWork(self):
        """
        Work added to the threadpool before its start should be executed once
        the threadpool is started: this is ensured by trying to release a lock
        previously acquired.
        """
        waiter = threading.Lock()
        waiter.acquire()

        tp = threadpool.ThreadPool(0, 1)
        tp.callInThread(waiter.release)  # Before start()
        tp.start()

        try:
            self._waitForLock(waiter)
        finally:
            tp.stop()

    def test_workerStateTransition(self):
        """
        As the worker receives and completes work, it transitions between
        the working and waiting states.
        """
        pool = threadpool.ThreadPool(0, 1)
        pool.start()
        self.addCleanup(pool.stop)

        # Sanity check
        self.assertEqual(pool.workers, 0)
        self.assertEqual(len(pool.waiters), 0)
        self.assertEqual(len(pool.working), 0)

        # Fire up a worker and give it some 'work'
        threadWorking = threading.Event()
        threadFinish = threading.Event()

        def _thread():
            threadWorking.set()
            threadFinish.wait(10)

        pool.callInThread(_thread)
        threadWorking.wait(10)
        self.assertEqual(pool.workers, 1)
        self.assertEqual(len(pool.waiters), 0)
        self.assertEqual(len(pool.working), 1)

        # Finish work, and spin until state changes
        threadFinish.set()
        while not len(pool.waiters):
            time.sleep(0.0005)

        # Make sure state changed correctly
        self.assertEqual(len(pool.waiters), 1)
        self.assertEqual(len(pool.working), 0)

    def test_q(self) -> None:
        """
        There is a property '_queue' for legacy purposes
        """
        pool = threadpool.ThreadPool(0, 1)
        self.assertEqual(pool._queue.qsize(), 0)
class RaceConditionTests(unittest.SynchronousTestCase):
    """
    Stress tests for scheduling races in L{threadpool.ThreadPool}.
    """

    def setUp(self):
        # A fresh pool and event per test; the cleanup drops the pool
        # reference so nothing lingers between tests.
        self.threadpool = threadpool.ThreadPool(0, 10)
        self.event = threading.Event()
        self.threadpool.start()

        def done():
            self.threadpool.stop()
            del self.threadpool

        self.addCleanup(done)

    def getTimeout(self):
        """
        A reasonable number of seconds to time out.
        """
        return 5

    def test_synchronization(self):
        """
        If multiple threads are waiting on an event (via blocking on something
        in a callable passed to L{threadpool.ThreadPool.callInThread}), and
        there is spare capacity in the threadpool, sending another callable
        which will cause those to un-block to
        L{threadpool.ThreadPool.callInThread} will reliably run that callable
        and un-block the blocked threads promptly.

        @note: This is not really a unit test, it is a stress-test.  You may
            need to run it with C{trial -u} to fail reliably if there is a
            problem.  It is very hard to regression-test for this particular
            bug - one where the thread pool may consider itself as having
            "enough capacity" when it really needs to spin up a new thread if
            it possibly can - in a deterministic way, since the bug can only be
            provoked by subtle race conditions.
        """
        timeout = self.getTimeout()
        self.threadpool.callInThread(self.event.set)
        self.event.wait(timeout)
        self.event.clear()
        for i in range(3):
            self.threadpool.callInThread(self.event.wait)
        self.threadpool.callInThread(self.event.set)
        self.event.wait(timeout)
        if not self.event.is_set():
            self.event.set()
            self.fail("'set' did not run in thread; timed out waiting on 'wait'.")
class MemoryPool(threadpool.ThreadPool):
    """
    A deterministic threadpool that uses in-memory data structures to queue
    work rather than threads to execute work.
    """

    def __init__(self, coordinator, failTest, newWorker, *args, **kwargs):
        """
        Initialize this L{MemoryPool} with a test case.

        @param coordinator: a worker used to coordinate work in the L{Team}
            underlying this threadpool.
        @type coordinator: L{twisted._threads.IExclusiveWorker}

        @param failTest: A 1-argument callable taking an exception and raising
            a test-failure exception.
        @type failTest: 1-argument callable taking (L{Failure}) and raising
            L{unittest.FailTest}.

        @param newWorker: a 0-argument callable that produces a new
            L{twisted._threads.IWorker} provider on each invocation.
        @type newWorker: 0-argument callable returning
            L{twisted._threads.IWorker}.
        """
        self._coordinator = coordinator
        self._failTest = failTest
        self._newWorker = newWorker
        threadpool.ThreadPool.__init__(self, *args, **kwargs)

    def _pool(self, currentLimit, threadFactory):
        """
        Override testing hook to create a deterministic threadpool.

        @param currentLimit: A 1-argument callable which returns the current
            threadpool size limit.

        @param threadFactory: ignored in this invocation; a 0-argument
            callable that would produce a thread.

        @return: a L{Team} backed by the coordinator and worker passed to
            L{MemoryPool.__init__}.
        """

        def limitedWorkerFactory():
            # The limit expression is copied and pasted from
            # twisted.threads._pool, which is unfortunately bound up with
            # lots of actual-threading stuff: only hand out a new worker
            # while the team is below the current size limit.
            stats = team.statistics()
            live = stats.busyWorkerCount + stats.idleWorkerCount
            if live >= currentLimit():
                return None
            return self._newWorker()

        team = Team(
            coordinator=self._coordinator,
            createWorker=limitedWorkerFactory,
            logException=self._failTest,
        )
        return team
class PoolHelper:
    """
    A L{PoolHelper} constructs a L{threadpool.ThreadPool} that doesn't actually
    use threads, by using the internal interfaces in L{twisted._threads}.

    @ivar performCoordination: a 0-argument callable that will perform one unit
        of "coordination" - work involved in delegating work to other threads -
        and return L{True} if it did any work, L{False} otherwise.

    @ivar workers: the workers which represent the threads within the pool -
        the workers other than the coordinator.
    @type workers: L{list} of 2-tuple of (L{IWorker}, C{workPerformer}) where
        C{workPerformer} is a 0-argument callable like C{performCoordination}.

    @ivar threadpool: a modified L{threadpool.ThreadPool} to test.
    @type threadpool: L{MemoryPool}
    """

    def __init__(self, testCase, *args, **kwargs):
        """
        Create a L{PoolHelper}.

        @param testCase: a test case attached to this helper.

        @type args: The arguments passed to a L{threadpool.ThreadPool}.

        @type kwargs: The arguments passed to a L{threadpool.ThreadPool}
        """
        coordinator, self.performCoordination = createMemoryWorker()
        self.workers = []

        def newWorker():
            # Record each created (worker, performer) pair so tests can
            # drive the fake workers by hand.
            pair = createMemoryWorker()
            self.workers.append(pair)
            return pair[0]

        self.threadpool = MemoryPool(
            coordinator, testCase.fail, newWorker, *args, **kwargs
        )

    def performAllCoordination(self):
        """
        Perform all currently scheduled "coordination", which is the work
        involved in delegating work to other threads.
        """
        while self.performCoordination():
            pass
class MemoryBackedTests(unittest.SynchronousTestCase):
    """
    Tests using L{PoolHelper} to deterministically test properties of the
    threadpool implementation.
    """

    def _workersAfterStart(self, queued):
        """
        Enqueue C{queued} no-op tasks on a not-yet-started pool (max size
        10), verify nothing runs before C{start}, start it, and return the
        number of workers that were created.
        """
        helper = PoolHelper(self, 0, 10)
        for x in range(queued):
            helper.threadpool.callInThread(lambda: None)
        helper.performAllCoordination()
        # No workers may be created before the pool is started.
        self.assertEqual(helper.workers, [])
        helper.threadpool.start()
        helper.performAllCoordination()
        return len(helper.workers)

    def test_workBeforeStarting(self):
        """
        If a threadpool is told to do work before starting, then upon starting
        up, it will start enough workers to handle all of the enqueued work
        that it's been given.
        """
        self.assertEqual(self._workersAfterStart(5), 5)

    def test_tooMuchWorkBeforeStarting(self):
        """
        If the amount of work before starting exceeds the maximum number of
        threads allowed to the threadpool, only the maximum count will be
        started.
        """
        # The pool is created with max=10, so 50 queued tasks still yield
        # only 10 workers.
        self.assertEqual(self._workersAfterStart(50), 10)

View File

@@ -0,0 +1,431 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test methods in twisted.internet.threads and reactor thread APIs.
"""
import os
import sys
import time
from unittest import skipIf
from twisted.internet import defer, error, interfaces, protocol, reactor, threads
from twisted.python import failure, log, threadable, threadpool
from twisted.trial.unittest import TestCase
try:
import threading
except ImportError:
pass
@skipIf(
not interfaces.IReactorThreads(reactor, None),
"No thread support, nothing to test here.",
)
class ReactorThreadsTests(TestCase):
"""
Tests for the reactor threading API.
"""
def test_suggestThreadPoolSize(self):
"""
Try to change maximum number of threads.
"""
reactor.suggestThreadPoolSize(34)
self.assertEqual(reactor.threadpool.max, 34)
reactor.suggestThreadPoolSize(4)
self.assertEqual(reactor.threadpool.max, 4)
def _waitForThread(self):
"""
The reactor's threadpool is only available when the reactor is running,
so to have a sane behavior during the tests we make a dummy
L{threads.deferToThread} call.
"""
return threads.deferToThread(time.sleep, 0)
def test_callInThread(self):
"""
Test callInThread functionality: set a C{threading.Event}, and check
that it's not in the main thread.
"""
def cb(ign):
waiter = threading.Event()
result = []
def threadedFunc():
result.append(threadable.isInIOThread())
waiter.set()
reactor.callInThread(threadedFunc)
waiter.wait(120)
if not waiter.isSet():
self.fail("Timed out waiting for event.")
else:
self.assertEqual(result, [False])
return self._waitForThread().addCallback(cb)
def test_callFromThread(self):
"""
Test callFromThread functionality: from the main thread, and from
another thread.
"""
def cb(ign):
firedByReactorThread = defer.Deferred()
firedByOtherThread = defer.Deferred()
def threadedFunc():
reactor.callFromThread(firedByOtherThread.callback, None)
reactor.callInThread(threadedFunc)
reactor.callFromThread(firedByReactorThread.callback, None)
return defer.DeferredList(
[firedByReactorThread, firedByOtherThread], fireOnOneErrback=True
)
return self._waitForThread().addCallback(cb)
def test_wakerOverflow(self):
"""
Try to make an overflow on the reactor waker using callFromThread.
"""
def cb(ign):
self.failure = None
waiter = threading.Event()
def threadedFunction():
# Hopefully a hundred thousand queued calls is enough to
# trigger the error condition
for i in range(100000):
try:
reactor.callFromThread(lambda: None)
except BaseException:
self.failure = failure.Failure()
break
waiter.set()
reactor.callInThread(threadedFunction)
waiter.wait(120)
if not waiter.isSet():
self.fail("Timed out waiting for event")
if self.failure is not None:
return defer.fail(self.failure)
return self._waitForThread().addCallback(cb)
def _testBlockingCallFromThread(self, reactorFunc):
"""
Utility method to test L{threads.blockingCallFromThread}.
"""
waiter = threading.Event()
results = []
errors = []
def cb1(ign):
def threadedFunc():
try:
r = threads.blockingCallFromThread(reactor, reactorFunc)
except Exception as e:
errors.append(e)
else:
results.append(r)
waiter.set()
reactor.callInThread(threadedFunc)
return threads.deferToThread(waiter.wait, self.getTimeout())
def cb2(ign):
if not waiter.isSet():
self.fail("Timed out waiting for event")
return results, errors
return self._waitForThread().addCallback(cb1).addBoth(cb2)
def test_blockingCallFromThread(self):
"""
Test blockingCallFromThread facility: create a thread, call a function
in the reactor using L{threads.blockingCallFromThread}, and verify the
result returned.
"""
def reactorFunc():
return defer.succeed("foo")
def cb(res):
self.assertEqual(res[0][0], "foo")
return self._testBlockingCallFromThread(reactorFunc).addCallback(cb)
def test_asyncBlockingCallFromThread(self):
"""
Test blockingCallFromThread as above, but be sure the resulting
Deferred is not already fired.
"""
def reactorFunc():
d = defer.Deferred()
reactor.callLater(0.1, d.callback, "egg")
return d
def cb(res):
self.assertEqual(res[0][0], "egg")
return self._testBlockingCallFromThread(reactorFunc).addCallback(cb)
def test_errorBlockingCallFromThread(self):
    """
    Test error report for blockingCallFromThread.
    """

    def check(outcome):
        results, errors = outcome
        caught = errors[0]
        self.assertIsInstance(caught, RuntimeError)
        self.assertEqual(caught.args[0], "bar")

    d = self._testBlockingCallFromThread(lambda: defer.fail(RuntimeError("bar")))
    return d.addCallback(check)
def test_asyncErrorBlockingCallFromThread(self):
    """
    Test error report for blockingCallFromThread as above, but be sure the
    resulting Deferred is not already fired.
    """

    def delayedFailure():
        # Errback only after a reactor delay so the Deferred is still
        # pending when blockingCallFromThread receives it.
        pending = defer.Deferred()
        reactor.callLater(0.1, pending.errback, RuntimeError("spam"))
        return pending

    def check(outcome):
        results, errors = outcome
        caught = errors[0]
        self.assertIsInstance(caught, RuntimeError)
        self.assertEqual(caught.args[0], "spam")

    return self._testBlockingCallFromThread(delayedFailure).addCallback(check)
class Counter:
    """
    A deliberately thread-unsafe counter used to demonstrate race
    conditions between worker threads.

    @ivar index: the current count.
    @ivar problem: set to 1 if a concurrent increment was detected.
    """

    index = 0
    problem = 0

    def add(self):
        """
        Increment C{index} by one in a non thread-safe way.

        The read-check-write sequence deliberately leaves a window in
        which another thread may also increment C{index}; if such an
        interleaving is observed, C{problem} is set.

        @raise ValueError: if another thread incremented C{index} between
            the read and the check.
        """
        # Renamed from `next`, which shadowed the builtin next().
        updated = self.index + 1
        # another thread could jump in here and increment self.index on us
        if updated != self.index + 1:
            self.problem = 1
            raise ValueError
        # or here, same issue but we wouldn't catch it. We'd overwrite
        # their results, and the index will have lost a count. If
        # several threads get in here, we will actually make the count
        # go backwards when we overwrite it.
        self.index = updated
@skipIf(
    not interfaces.IReactorThreads(reactor, None),
    "No thread support, nothing to test here.",
)
class DeferredResultTests(TestCase):
    """
    Test twisted.internet.threads.
    """

    def setUp(self):
        reactor.suggestThreadPoolSize(8)

    def tearDown(self):
        reactor.suggestThreadPoolSize(0)

    def test_callMultiple(self):
        """
        L{threads.callMultipleInThread} calls multiple functions in a thread.
        """
        collected = []
        count = 10
        done = defer.Deferred()

        def finished():
            self.assertEqual(collected, list(range(count)))
            done.callback(None)

        calls = [(collected.append, (value,), {}) for value in range(count)]
        calls.append((reactor.callFromThread, (finished,), {}))
        threads.callMultipleInThread(calls)
        return done

    def test_deferredResult(self):
        """
        L{threads.deferToThread} executes the function passed, and correctly
        handles the positional and keyword arguments given.
        """

        def add(x, y=5):
            return x + y

        deferred = threads.deferToThread(add, 3, y=4)
        deferred.addCallback(self.assertEqual, 7)
        return deferred

    def test_deferredFailure(self):
        """
        Check that L{threads.deferToThread} return a failure object
        with an appropriate exception instance when the called
        function raises an exception.
        """

        class NewError(Exception):
            pass

        def raiseError():
            raise NewError()

        return self.assertFailure(threads.deferToThread(raiseError), NewError)

    def test_deferredFailureAfterSuccess(self):
        """
        Check that a successful L{threads.deferToThread} followed by a one
        that raises an exception correctly result as a failure.
        """
        # These two calls reproduce, in a single precise test case, hang
        # conditions in cReactor that otherwise only appear as side effects
        # of running other tests first (test_flow.FlowTest.testThreaded
        # followed by test_internet.ReactorCoreTestCase.testStop).
        #
        # alas, this test appears to flunk the default reactor too
        deferred = threads.deferToThread(lambda: None)
        deferred.addCallback(lambda ignored: threads.deferToThread(lambda: 1 // 0))
        return self.assertFailure(deferred, ZeroDivisionError)
class DeferToThreadPoolTests(TestCase):
    """
    Test L{twisted.internet.threads.deferToThreadPool}.
    """

    def setUp(self):
        self.tp = threadpool.ThreadPool(0, 8)
        self.tp.start()

    def tearDown(self):
        self.tp.stop()

    def test_deferredResult(self):
        """
        L{threads.deferToThreadPool} executes the function passed, and
        correctly handles the positional and keyword arguments given.
        """

        def add(x, y=5):
            return x + y

        deferred = threads.deferToThreadPool(reactor, self.tp, add, 3, y=4)
        return deferred.addCallback(self.assertEqual, 7)

    def test_deferredFailure(self):
        """
        Check that L{threads.deferToThreadPool} return a failure object with an
        appropriate exception instance when the called function raises an
        exception.
        """

        class NewError(Exception):
            pass

        def raiseError():
            raise NewError()

        deferred = threads.deferToThreadPool(reactor, self.tp, raiseError)
        return self.assertFailure(deferred, NewError)
_callBeforeStartupProgram = """
import time
import %(reactor)s
%(reactor)s.install()
from twisted.internet import reactor
def threadedCall():
print('threaded call')
reactor.callInThread(threadedCall)
# Spin very briefly to try to give the thread a chance to run, if it
# is going to. Is there a better way to achieve this behavior?
for i in range(100):
time.sleep(0.0)
"""
class ThreadStartupProcessProtocol(protocol.ProcessProtocol):
    """
    Collects a child process's stdout and stderr chunks and fires the
    supplied Deferred with C{(out, err, reason)} when the process ends.
    """

    def __init__(self, finished):
        self.finished = finished
        self.out = []
        self.err = []

    def outReceived(self, data):
        self.out.append(data)

    def errReceived(self, data):
        self.err.append(data)

    def processEnded(self, reason):
        self.finished.callback((self.out, self.err, reason))
@skipIf(
    not interfaces.IReactorThreads(reactor, None),
    "No thread support, nothing to test here.",
)
@skipIf(
    not interfaces.IReactorProcess(reactor, None),
    "No process support, cannot run subprocess thread tests.",
)
class StartupBehaviorTests(TestCase):
    """
    Test cases for the behavior of the reactor threadpool near startup
    boundary conditions.

    In particular, this asserts that no threaded calls are attempted
    until the reactor starts up, that calls attempted before it starts
    are in fact executed once it has started, and that in both cases,
    the reactor properly cleans itself up (which is tested for
    somewhat implicitly, by requiring a child process be able to exit,
    something it cannot do unless the threadpool has been properly
    torn down).
    """

    def testCallBeforeStartupUnexecuted(self):
        """
        A threaded call queued before the reactor starts is not executed when
        the reactor never runs, and the child process still exits cleanly
        producing no output.
        """
        progname = self.mktemp()
        with open(progname, "w") as progfile:
            progfile.write(_callBeforeStartupProgram % {"reactor": reactor.__module__})

        def programFinished(result):
            (out, err, reason) = result
            if reason.check(error.ProcessTerminated):
                self.fail(f"Process did not exit cleanly (out: {out} err: {err})")
            if err:
                log.msg(f"Unexpected output on standard error: {err}")
            self.assertFalse(out, f"Expected no output, instead received:\n{out}")

        def programTimeout(err):
            err.trap(error.TimeoutError)
            proto.signalProcess("KILL")
            return err

        # Fixed: this line was corrupted to "os***REMOVED***iron.copy()" by a
        # secret-scrubbing pass; the child needs a copy of the parent
        # environment with PYTHONPATH propagated so it can import Twisted.
        env = os.environ.copy()
        env["PYTHONPATH"] = os.pathsep.join(sys.path)
        d = defer.Deferred().addCallbacks(programFinished, programTimeout)
        proto = ThreadStartupProcessProtocol(d)
        reactor.spawnProcess(proto, sys.executable, ("python", progname), env)
        return d

View File

@@ -0,0 +1,56 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from io import BytesIO
from twisted.internet import abstract, defer, protocol
from twisted.protocols import basic, loopback
from twisted.trial import unittest
class BufferingServer(protocol.Protocol):
    """
    Accumulates every byte received into C{buffer}.
    """

    buffer = b""

    def dataReceived(self, data: bytes) -> None:
        self.buffer = self.buffer + data
class FileSendingClient(protocol.Protocol):
    """
    On connection, streams the wrapped file to the transport with
    L{basic.FileSender} and drops the connection once the transfer
    completes.
    """

    def __init__(self, f: BytesIO) -> None:
        self.f = f

    def connectionMade(self) -> None:
        assert self.transport is not None
        sender = basic.FileSender()
        completed = sender.beginFileTransfer(self.f, self.transport, lambda x: x)
        completed.addCallback(lambda result: self.transport.loseConnection())
class FileSenderTests(unittest.TestCase):
    """
    Tests for L{basic.FileSender}.
    """

    def testSendingFile(self) -> defer.Deferred[None]:
        """
        The entire contents of the file arrive at the peer.
        """
        payload = b"xyz" * 100 + b"abc" * 100 + b"123" * 100
        server = BufferingServer()
        client = FileSendingClient(BytesIO(payload))
        finished: defer.Deferred[None] = loopback.loopbackTCP(server, client)

        def verify(ignored: object) -> None:
            self.assertEqual(server.buffer, payload)

        return finished.addCallback(verify)

    def testSendingEmptyFile(self) -> None:
        """
        Sending an empty file finishes (and unregisters the producer)
        synchronously.
        """
        fileSender = basic.FileSender()
        consumer = abstract.FileDescriptor()
        consumer.connected = 1
        d = fileSender.beginFileTransfer(BytesIO(b""), consumer, lambda x: x)
        # The producer will be immediately exhausted, and so immediately
        # unregistered
        self.assertIsNone(consumer.producer)
        # Which means the Deferred from FileSender should have been called
        self.assertTrue(d.called, "producer unregistered with deferred being called")

Some files were not shown because too many files have changed in this diff Show More