okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 387c4740e7
commit 27f3326e22
10020 changed files with 1935769 additions and 2364 deletions

View File

@@ -0,0 +1,85 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import platform
import twisted
import autobahn
# Twisted specific utilities (these should really be in Twisted, but
# they aren't, and we use these in example code, so it must be part of
# the public API)
from autobahn.twisted.util import sleep
from autobahn.twisted.choosereactor import install_reactor
# WebSocket protocol support
from autobahn.twisted.websocket import \
WebSocketServerProtocol, \
WebSocketClientProtocol, \
WebSocketServerFactory, \
WebSocketClientFactory
# support for running Twisted stream protocols over WebSocket
from autobahn.twisted.websocket import WrappingWebSocketServerFactory, \
WrappingWebSocketClientFactory
# Twisted Web support - FIXME: these imports trigger import of Twisted reactor!
# from autobahn.twisted.resource import WebSocketResource, WSGIRootResource
# WAMP support
from autobahn.twisted.wamp import ApplicationSession
__all__ = (
# this should really be in Twisted
'sleep',
'install_reactor',
# WebSocket
'WebSocketServerProtocol',
'WebSocketClientProtocol',
'WebSocketServerFactory',
'WebSocketClientFactory',
# wrapping stream protocols in WebSocket
'WrappingWebSocketServerFactory',
'WrappingWebSocketClientFactory',
# Twisted Web - FIXME: see comment for import above
# 'WebSocketResource',
# this should really be in Twisted - FIXME: see comment for import above
# 'WSGIRootResource',
# WAMP support
'ApplicationSession',
)
__ident__ = 'Autobahn/{}-Twisted/{}-{}/{}'.format(autobahn.__version__, twisted.__version__, platform.python_implementation(), platform.python_version())
"""
AutobahnPython library implementation (eg. "Autobahn/0.13.0-Twisted/15.5.0-CPython/3.5.1")
"""

View File

@@ -0,0 +1,226 @@
########################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software'), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
########################################
import sys
import traceback
import txaio
txaio.use_twisted()
from twisted.python import reflect
from twisted.internet.error import ReactorAlreadyInstalledError
__all__ = (
'install_optimal_reactor',
'install_reactor',
'current_reactor_klass'
)
def current_reactor_klass():
"""
Return class name of currently installed Twisted reactor or None.
"""
if 'twisted.internet.reactor' in sys.modules:
current_reactor = reflect.qual(sys.modules['twisted.internet.reactor'].__class__).split('.')[-1]
else:
current_reactor = None
return current_reactor
def install_optimal_reactor(require_optimal_reactor=True):
"""
Try to install the optimal Twisted reactor for this platform:
- Linux: epoll
- BSD/OSX: kqueue
- Windows: iocp
- Other: select
Notes:
- This function exists, because the reactor types selected based on platform
in `twisted.internet.default` are different from here.
- The imports are inlined, because the Twisted code base is notorious for
importing the reactor as a side-effect of merely importing. Hence we postpone
all importing.
See: http://twistedmatrix.com/documents/current/core/howto/choosing-reactor.html#reactor-functionality
:param require_optimal_reactor: If ``True`` and the desired reactor could not be installed,
raise ``ReactorAlreadyInstalledError``, else fallback to another reactor.
:type require_optimal_reactor: bool
:returns: The Twisted reactor in place (`twisted.internet.reactor`).
"""
log = txaio.make_logger()
# determine currently installed reactor, if any
#
current_reactor = current_reactor_klass()
# depending on platform, install optimal reactor
#
if 'bsd' in sys.platform or sys.platform.startswith('darwin'):
# *BSD and MacOSX
#
if current_reactor != 'KQueueReactor':
if current_reactor is None:
try:
from twisted.internet import kqreactor
kqreactor.install()
except:
log.warn('Running on *BSD or MacOSX, but cannot install kqueue Twisted reactor: {tb}', tb=traceback.format_exc())
else:
log.debug('Running on *BSD or MacOSX and optimal reactor (kqueue) was installed.')
else:
log.warn('Running on *BSD or MacOSX, but cannot install kqueue Twisted reactor, because another reactor ({klass}) is already installed.', klass=current_reactor)
if require_optimal_reactor:
raise ReactorAlreadyInstalledError()
else:
log.debug('Running on *BSD or MacOSX and optimal reactor (kqueue) already installed.')
elif sys.platform in ['win32']:
# Windows
#
if current_reactor != 'IOCPReactor':
if current_reactor is None:
try:
from twisted.internet.iocpreactor import reactor as iocpreactor
iocpreactor.install()
except:
log.warn('Running on Windows, but cannot install IOCP Twisted reactor: {tb}', tb=traceback.format_exc())
else:
log.debug('Running on Windows and optimal reactor (ICOP) was installed.')
else:
log.warn('Running on Windows, but cannot install IOCP Twisted reactor, because another reactor ({klass}) is already installed.', klass=current_reactor)
if require_optimal_reactor:
raise ReactorAlreadyInstalledError()
else:
log.debug('Running on Windows and optimal reactor (ICOP) already installed.')
elif sys.platform.startswith('linux'):
# Linux
#
if current_reactor != 'EPollReactor':
if current_reactor is None:
try:
from twisted.internet import epollreactor
epollreactor.install()
except:
log.warn('Running on Linux, but cannot install Epoll Twisted reactor: {tb}', tb=traceback.format_exc())
else:
log.debug('Running on Linux and optimal reactor (epoll) was installed.')
else:
log.warn('Running on Linux, but cannot install Epoll Twisted reactor, because another reactor ({klass}) is already installed.', klass=current_reactor)
if require_optimal_reactor:
raise ReactorAlreadyInstalledError()
else:
log.debug('Running on Linux and optimal reactor (epoll) already installed.')
else:
# Other platform
#
if current_reactor != 'SelectReactor':
if current_reactor is None:
try:
from twisted.internet import selectreactor
selectreactor.install()
# from twisted.internet import default as defaultreactor
# defaultreactor.install()
except:
log.warn('Running on "{platform}", but cannot install Select Twisted reactor: {tb}', tb=traceback.format_exc(), platform=sys.platform)
else:
log.debug('Running on "{platform}" and optimal reactor (Select) was installed.', platform=sys.platform)
else:
log.warn('Running on "{platform}", but cannot install Select Twisted reactor, because another reactor ({klass}) is already installed.', klass=current_reactor, platform=sys.platform)
if require_optimal_reactor:
raise ReactorAlreadyInstalledError()
else:
log.debug('Running on "{platform}" and optimal reactor (Select) already installed.', platform=sys.platform)
from twisted.internet import reactor
txaio.config.loop = reactor
return reactor
def install_reactor(explicit_reactor=None, verbose=False, log=None, require_optimal_reactor=True):
"""
Install Twisted reactor.
:param explicit_reactor: If provided, install this reactor. Else, install
the optimal reactor.
:type explicit_reactor: obj
:param verbose: If ``True``, log (at level "info") the reactor that is
in place afterwards.
:type verbose: bool
:param log: Explicit logging to this txaio logger object.
:type log: obj
:param require_optimal_reactor: If ``True`` and the desired reactor could not be installed,
raise ``ReactorAlreadyInstalledError``, else fallback to another reactor.
:type require_optimal_reactor: bool
:returns: The Twisted reactor in place (`twisted.internet.reactor`).
"""
if not log:
log = txaio.make_logger()
if explicit_reactor:
# install explicitly given reactor
#
from twisted.application.reactors import installReactor
if verbose:
log.info('Trying to install explicitly specified Twisted reactor "{reactor}" ..', reactor=explicit_reactor)
try:
installReactor(explicit_reactor)
except:
log.failure('Could not install Twisted reactor {reactor}\n{log_failure.value}',
reactor=explicit_reactor)
sys.exit(1)
else:
# automatically choose optimal reactor
#
if verbose:
log.info('Automatically choosing optimal Twisted reactor ..')
install_optimal_reactor(require_optimal_reactor)
# now the reactor is installed, import it
from twisted.internet import reactor
txaio.config.loop = reactor
if verbose:
from twisted.python.reflect import qual
log.info('Running on Twisted reactor {reactor}', reactor=qual(reactor.__class__))
return reactor

View File

@@ -0,0 +1,380 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
from functools import wraps
from typing import List
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.internet.endpoints import UNIXClientEndpoint
from twisted.internet.endpoints import TCP4ClientEndpoint
from twisted.python.failure import Failure
from twisted.internet.error import ReactorNotRunning
try:
_TLS = True
from twisted.internet.endpoints import SSL4ClientEndpoint
from twisted.internet.ssl import optionsForClientTLS, CertificateOptions, Certificate
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from OpenSSL import SSL
except ImportError:
_TLS = False
# there's no optionsForClientTLS in older Twisteds or we might be
# missing OpenSSL entirely.
import txaio
from autobahn.twisted.websocket import WampWebSocketClientFactory
from autobahn.twisted.rawsocket import WampRawSocketClientFactory
from autobahn.wamp import component
from autobahn.twisted.wamp import Session
from autobahn.wamp.serializer import create_transport_serializers, create_transport_serializer
__all__ = ('Component', 'run')
def _unique_list(seq):
"""
Return a list with unique elements from sequence, preserving order.
"""
seen = set()
return [x for x in seq if x not in seen and not seen.add(x)]
def _camel_case_from_snake_case(snake):
parts = snake.split('_')
return parts[0] + ''.join(s.capitalize() for s in parts[1:])
def _create_transport_factory(reactor, transport, session_factory):
"""
Create a WAMP-over-XXX transport factory.
"""
if transport.type == 'websocket':
serializers = create_transport_serializers(transport)
factory = WampWebSocketClientFactory(
session_factory,
url=transport.url,
serializers=serializers,
proxy=transport.proxy, # either None or a dict with host, port
)
elif transport.type == 'rawsocket':
serializer = create_transport_serializer(transport.serializers[0])
factory = WampRawSocketClientFactory(session_factory, serializer=serializer)
else:
assert(False), 'should not arrive here'
# set the options one at a time so we can give user better feedback
for k, v in transport.options.items():
try:
factory.setProtocolOptions(**{k: v})
except (TypeError, KeyError):
# this allows us to document options as snake_case
# until everything internally is upgraded from
# camelCase
try:
factory.setProtocolOptions(
**{_camel_case_from_snake_case(k): v}
)
except (TypeError, KeyError):
raise ValueError(
"Unknown {} transport option: {}={}".format(transport.type, k, v)
)
return factory
def _create_transport_endpoint(reactor, endpoint_config):
"""
Create a Twisted client endpoint for a WAMP-over-XXX transport.
"""
if IStreamClientEndpoint.providedBy(endpoint_config):
endpoint = IStreamClientEndpoint(endpoint_config)
else:
# create a connecting TCP socket
if endpoint_config['type'] == 'tcp':
version = endpoint_config.get('version', 4)
if version not in [4, 6]:
raise ValueError('invalid IP version {} in client endpoint configuration'.format(version))
host = endpoint_config['host']
if type(host) != str:
raise ValueError('invalid type {} for host in client endpoint configuration'.format(type(host)))
port = endpoint_config['port']
if type(port) != int:
raise ValueError('invalid type {} for port in client endpoint configuration'.format(type(port)))
timeout = endpoint_config.get('timeout', 10) # in seconds
if type(timeout) != int:
raise ValueError('invalid type {} for timeout in client endpoint configuration'.format(type(timeout)))
tls = endpoint_config.get('tls', None)
# create a TLS enabled connecting TCP socket
if tls:
if not _TLS:
raise RuntimeError('TLS configured in transport, but TLS support is not installed (eg OpenSSL?)')
# FIXME: create TLS context from configuration
if IOpenSSLClientConnectionCreator.providedBy(tls):
# eg created from twisted.internet.ssl.optionsForClientTLS()
context = IOpenSSLClientConnectionCreator(tls)
elif isinstance(tls, dict):
for k in tls.keys():
if k not in ["hostname", "trust_root"]:
raise ValueError("Invalid key '{}' in 'tls' config".format(k))
hostname = tls.get('hostname', host)
if type(hostname) != str:
raise ValueError('invalid type {} for hostname in TLS client endpoint configuration'.format(hostname))
trust_root = None
cert_fname = tls.get("trust_root", None)
if cert_fname is not None:
trust_root = Certificate.loadPEM(open(cert_fname, 'r').read())
context = optionsForClientTLS(hostname, trustRoot=trust_root)
elif isinstance(tls, CertificateOptions):
context = tls
elif tls is True:
context = optionsForClientTLS(host)
else:
raise RuntimeError('unknown type {} for "tls" configuration in transport'.format(type(tls)))
if version == 4:
endpoint = SSL4ClientEndpoint(reactor, host, port, context, timeout=timeout)
elif version == 6:
# there is no SSL6ClientEndpoint!
raise RuntimeError('TLS on IPv6 not implemented')
else:
assert(False), 'should not arrive here'
# create a non-TLS connecting TCP socket
else:
if host.endswith(".onion"):
# hmm, can't log here?
# self.log.info("{host} appears to be a Tor endpoint", host=host)
try:
import txtorcon
endpoint = txtorcon.TorClientEndpoint(host, port)
except ImportError:
raise RuntimeError(
"{} appears to be a Tor Onion service, but txtorcon is not installed".format(
host,
)
)
elif version == 4:
endpoint = TCP4ClientEndpoint(reactor, host, port, timeout=timeout)
elif version == 6:
try:
from twisted.internet.endpoints import TCP6ClientEndpoint
except ImportError:
raise RuntimeError('IPv6 is not supported (please upgrade Twisted)')
endpoint = TCP6ClientEndpoint(reactor, host, port, timeout=timeout)
else:
assert(False), 'should not arrive here'
# create a connecting Unix domain socket
elif endpoint_config['type'] == 'unix':
path = endpoint_config['path']
timeout = int(endpoint_config.get('timeout', 10)) # in seconds
endpoint = UNIXClientEndpoint(reactor, path, timeout=timeout)
else:
assert(False), 'should not arrive here'
return endpoint
class Component(component.Component):
"""
A component establishes a transport and attached a session
to a realm using the transport for communication.
The transports a component tries to use can be configured,
as well as the auto-reconnect strategy.
"""
log = txaio.make_logger()
session_factory = Session
"""
The factory of the session we will instantiate.
"""
def _is_ssl_error(self, e):
"""
Internal helper.
This is so we can just return False if we didn't import any
TLS/SSL libraries. Otherwise, returns True if this is an
OpenSSL.SSL.Error
"""
if _TLS:
return isinstance(e, SSL.Error)
return False
def _check_native_endpoint(self, endpoint):
if IStreamClientEndpoint.providedBy(endpoint):
pass
elif isinstance(endpoint, dict):
if 'tls' in endpoint:
tls = endpoint['tls']
if isinstance(tls, (dict, bool)):
pass
elif IOpenSSLClientConnectionCreator.providedBy(tls):
pass
elif isinstance(tls, CertificateOptions):
pass
else:
raise ValueError(
"'tls' configuration must be a dict, CertificateOptions or"
" IOpenSSLClientConnectionCreator provider"
)
else:
raise ValueError(
"'endpoint' configuration must be a dict or IStreamClientEndpoint"
" provider"
)
def _connect_transport(self, reactor, transport, session_factory, done):
"""
Create and connect a WAMP-over-XXX transport.
:param done: is a Deferred/Future from the parent which we
should signal upon error if it is not done yet (XXX maybe an
"on_error" callable instead?)
"""
transport_factory = _create_transport_factory(reactor, transport, session_factory)
if transport.proxy:
transport_endpoint = _create_transport_endpoint(
reactor,
{
"type": "tcp",
"host": transport.proxy["host"],
"port": transport.proxy["port"],
}
)
else:
transport_endpoint = _create_transport_endpoint(reactor, transport.endpoint)
d = transport_endpoint.connect(transport_factory)
def on_connect_success(proto):
# if e.g. an SSL handshake fails, we will have
# successfully connected (i.e. get here) but need to
# 'listen' for the "connectionLost" from the underlying
# protocol in case of handshake failure .. so we wrap
# it. Also, we don't increment transport.success_count
# here on purpose (because we might not succeed).
orig = proto.connectionLost
@wraps(orig)
def lost(fail):
rtn = orig(fail)
if not txaio.is_called(done):
txaio.reject(done, fail)
return rtn
proto.connectionLost = lost
def on_connect_failure(err):
transport.connect_failures += 1
# failed to establish a connection in the first place
txaio.reject(done, err)
txaio.add_callbacks(d, on_connect_success, None)
txaio.add_callbacks(d, None, on_connect_failure)
return d
def start(self, reactor=None):
"""
This starts the Component, which means it will start connecting
(and re-connecting) to its configured transports. A Component
runs until it is "done", which means one of:
- There was a "main" function defined, and it completed successfully;
- Something called ``.leave()`` on our session, and we left successfully;
- ``.stop()`` was called, and completed successfully;
- none of our transports were able to connect successfully (failure);
:returns: a Deferred that fires (with ``None``) when we are
"done" or with a Failure if something went wrong.
"""
if reactor is None:
self.log.warn("Using default reactor")
from twisted.internet import reactor
return self._start(loop=reactor)
def run(components: List[Component], log_level: str = 'info', stop_at_close: bool = True):
"""
High-level API to run a series of components.
This will only return once all the components have stopped
(including, possibly, after all re-connections have failed if you
have re-connections enabled). Under the hood, this calls
:meth:`twisted.internet.reactor.run` -- if you wish to manage the
reactor loop yourself, use the
:meth:`autobahn.twisted.component.Component.start` method to start
each component yourself.
:param components: the Component(s) you wish to run
:param log_level: a valid log-level (or None to avoid calling start_logging)
:param stop_at_close: Flag to control whether to stop the reactor when done.
"""
# only for Twisted > 12
# ...so this isn't in all Twisted versions we test against -- need
# to do "something else" if we can't import .. :/ (or drop some
# support)
from twisted.internet.task import react
# actually, should we even let people "not start" the logging? I'm
# not sure that's wise... (double-check: if they already called
# txaio.start_logging() what happens if we call it again?)
if log_level is not None:
txaio.start_logging(level=log_level)
log = txaio.make_logger()
if stop_at_close:
def done_callback(reactor, arg):
if isinstance(arg, Failure):
log.error('Something went wrong: {log_failure}', failure=arg)
try:
log.warn('Stopping reactor ..')
reactor.stop()
except ReactorNotRunning:
pass
else:
done_callback = None
react(component._run, (components, done_callback))

View File

@@ -0,0 +1,152 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
from autobahn.wamp.cryptosign import HAS_CRYPTOSIGN, CryptosignKey
from twisted.internet.defer import inlineCallbacks, returnValue
__all__ = [
'HAS_CRYPTOSIGN_SSHAGENT'
]
if HAS_CRYPTOSIGN:
try:
# WAMP-cryptosign support for SSH agent is currently
# only available on Twisted (on Python 2)
from twisted.internet.protocol import Factory
from twisted.internet.endpoints import UNIXClientEndpoint
from twisted.conch.ssh.agent import SSHAgentClient
except ImportError:
# twisted.conch is not yet fully ported to Python 3
HAS_CRYPTOSIGN_SSHAGENT = False
else:
HAS_CRYPTOSIGN_SSHAGENT = True
__all__.append('SSHAgentCryptosignKey')
if HAS_CRYPTOSIGN_SSHAGENT:
import os
from nacl import signing
from autobahn.wamp.cryptosign import _read_ssh_ed25519_pubkey, _unpack, _pack
class SSHAgentCryptosignKey(CryptosignKey):
"""
A WAMP-cryptosign signing key that is a proxy to a private Ed25510 key
actually held in SSH agent.
An instance of this class must be create via the class method new().
The instance only holds the public key part, whereas the private key
counterpart is held in SSH agent.
"""
def __init__(self, key, comment=None, reactor=None):
CryptosignKey.__init__(self, key, comment)
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
@classmethod
def new(cls, pubkey=None, reactor=None):
"""
Create a proxy for a key held in SSH agent.
:param pubkey: A string with a public Ed25519 key in SSH format.
:type pubkey: unicode
"""
if not HAS_CRYPTOSIGN_SSHAGENT:
raise Exception("SSH agent integration is not supported on this platform")
pubkey, _ = _read_ssh_ed25519_pubkey(pubkey)
if not reactor:
from twisted.internet import reactor
if "SSH_AUTH_SOCK" not in os.environ:
raise Exception("no ssh-agent is running!")
factory = Factory()
factory.noisy = False
factory.protocol = SSHAgentClient
endpoint = UNIXClientEndpoint(reactor, os.environ["SSH_AUTH_SOCK"])
d = endpoint.connect(factory)
@inlineCallbacks
def on_connect(agent):
keys = yield agent.requestIdentities()
# if the key is found in ssh-agent, the raw public key (32 bytes), and the
# key comment as returned from ssh-agent
key_data = None
key_comment = None
for blob, comment in keys:
raw = _unpack(blob)
algo = raw[0].decode('utf8')
if algo == 'ssh-ed25519':
algo, _pubkey = raw
if _pubkey == pubkey:
key_data = _pubkey
key_comment = comment.decode('utf8')
break
agent.transport.loseConnection()
if key_data:
key = signing.VerifyKey(key_data)
returnValue(cls(key, key_comment, reactor))
else:
raise Exception("Ed25519 key not held in ssh-agent")
return d.addCallback(on_connect)
def sign(self, challenge):
if "SSH_AUTH_SOCK" not in os.environ:
raise Exception("no ssh-agent is running!")
factory = Factory()
factory.noisy = False
factory.protocol = SSHAgentClient
endpoint = UNIXClientEndpoint(self._reactor, os.environ["SSH_AUTH_SOCK"])
d = endpoint.connect(factory)
@inlineCallbacks
def on_connect(agent):
# we are now connected to the locally running ssh-agent
# that agent might be the openssh-agent, or eg on Ubuntu 14.04 by
# default the gnome-keyring / ssh-askpass-gnome application
blob = _pack(['ssh-ed25519'.encode(), self.public_key(binary=True)])
# now ask the agent
signature_blob = yield agent.signData(blob, challenge)
algo, signature = _unpack(signature_blob)
agent.transport.loseConnection()
returnValue(signature)
return d.addCallback(on_connect)

View File

@@ -0,0 +1,128 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import txaio
txaio.use_twisted()
from twisted.python import usage
from twisted.internet.defer import inlineCallbacks
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.endpoints import clientFromString, serverFromString
from twisted.application import service
class DestEndpointForwardingProtocol(Protocol):
log = txaio.make_logger()
def connectionMade(self):
self.log.debug("DestEndpointForwardingProtocol.connectionMade")
pass
def dataReceived(self, data):
self.log.debug(
"DestEndpointForwardingProtocol.dataReceived: {data}",
data=data,
)
if self.factory._sourceProtocol:
self.factory._sourceProtocol.transport.write(data)
def connectionLost(self, reason):
self.log.debug("DestEndpointForwardingProtocol.connectionLost")
if self.factory._sourceProtocol:
self.factory._sourceProtocol.transport.loseConnection()
class DestEndpointForwardingFactory(Factory):
def __init__(self, sourceProtocol):
self._sourceProtocol = sourceProtocol
self._proto = None
def buildProtocol(self, addr):
self._proto = DestEndpointForwardingProtocol()
self._proto.factory = self
return self._proto
class EndpointForwardingProtocol(Protocol):
log = txaio.make_logger()
@inlineCallbacks
def connectionMade(self):
self.log.debug("EndpointForwardingProtocol.connectionMade")
self._destFactory = DestEndpointForwardingFactory(self)
self._destEndpoint = clientFromString(self.factory.service._reactor,
self.factory.service._destEndpointDescriptor)
self._destEndpointPort = yield self._destEndpoint.connect(self._destFactory)
def dataReceived(self, data):
self.log.debug(
"EndpointForwardingProtocol.dataReceived: {data}",
data=data,
)
if self._destFactory._proto:
self._destFactory._proto.transport.write(data)
def connectionLost(self, reason):
self.log.debug("EndpointForwardingProtocol.connectionLost")
if self._destFactory._proto:
self._destFactory._proto.transport.loseConnection()
class EndpointForwardingService(service.Service):
def __init__(self, endpointDescriptor, destEndpointDescriptor, reactor=None):
if reactor is None:
from twisted.internet import reactor
self._reactor = reactor
self._endpointDescriptor = endpointDescriptor
self._destEndpointDescriptor = destEndpointDescriptor
@inlineCallbacks
def startService(self):
factory = Factory.forProtocol(EndpointForwardingProtocol)
factory.service = self
self._endpoint = serverFromString(self._reactor, self._endpointDescriptor)
self._endpointPort = yield self._endpoint.listen(factory)
def stopService(self):
return self._endpointPort.stopListening()
class Options(usage.Options):
synopsis = "[options]"
longdesc = 'Endpoint Forwarder.'
optParameters = [
["endpoint", "e", None, "Source endpoint."],
["dest_endpoint", "d", None, "Destination endpoint."]
]
def makeService(config):
service = EndpointForwardingService(config['endpoint'], config['dest_endpoint'])
return service

View File

@@ -0,0 +1,604 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import copy
import math
from typing import Optional
import txaio
from twisted.internet.protocol import Factory
from twisted.protocols.basic import Int32StringReceiver
from twisted.internet.error import ConnectionDone
from twisted.internet.defer import CancelledError
from autobahn.util import public, _LazyHexFormatter
from autobahn.twisted.util import create_transport_details, transport_channel_id
from autobahn.wamp.types import TransportDetails
from autobahn.wamp.exception import ProtocolError, SerializationError, TransportLost, InvalidUriError
from autobahn.exception import PayloadExceededError
__all__ = (
'WampRawSocketServerProtocol',
'WampRawSocketClientProtocol',
'WampRawSocketServerFactory',
'WampRawSocketClientFactory'
)
class WampRawSocketProtocol(Int32StringReceiver):
"""
Base class for Twisted-based WAMP-over-RawSocket protocols.
"""
log = txaio.make_logger()
peer: Optional[str] = None
is_server: Optional[bool] = None
def __init__(self):
# set the RawSocket maximum message size by default
self._max_message_size = 2**24
self._transport_details = None
@property
def transport_details(self) -> Optional[TransportDetails]:
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.transport_details`
"""
return self._transport_details
def lengthLimitExceeded(self, length):
# override hook in Int32StringReceiver base class that is fired when a message is (to be) received
# that is larger than what we agreed to handle (by negotiation in the RawSocket opening handshake)
emsg = 'RawSocket connection: length of received message exceeded (message was {} bytes, but current maximum is {} bytes)'.format(length, self.MAX_LENGTH)
raise PayloadExceededError(emsg)
def connectionMade(self):
# Twisted networking framework entry point, called by Twisted
# when the connection is established (either a client or a server)
# determine preliminary transport details (what is know at this point)
self._transport_details = create_transport_details(self.transport, self.is_server)
self._transport_details.channel_framing = TransportDetails.CHANNEL_FRAMING_RAWSOCKET
# backward compatibility
self.peer = self._transport_details.peer
# a Future/Deferred that fires when we hit STATE_CLOSED
self.is_closed = txaio.create_future()
# this will hold an ApplicationSession object
# once the RawSocket opening handshake has been
# completed
#
self._session = None
# Will hold the negotiated serializer once the opening handshake is complete
#
self._serializer = None
# Will be set to True once the opening handshake is complete
#
self._handshake_complete = False
# Buffer for opening handshake received bytes.
#
self._handshake_bytes = b''
# Peer requested to _receive_ this maximum length of serialized messages - hence we must not send larger msgs!
#
self._max_len_send = None
def _on_handshake_complete(self):
# RawSocket connection established. Now let the user WAMP session factory
# create a new WAMP session and fire off session open callback.
try:
if self._transport_details.is_secure:
# now that the TLS opening handshake is complete, the actual TLS channel ID
# will be available. make sure to set it!
channel_id = {
'tls-unique': transport_channel_id(self.transport, self._transport_details.is_server, 'tls-unique'),
}
self._transport_details.channel_id = channel_id
self._session = self.factory._factory()
self.log.debug('{klass}._on_handshake_complete(): calling {method}', session=self._session,
klass=self.__class__.__name__, method=self._session.onOpen)
res = self._session.onOpen(self)
except Exception as e:
# Exceptions raised in onOpen are fatal ..
self.log.warn("{klass}._on_handshake_complete(): ApplicationSession constructor / onOpen raised ({err})",
klass=self.__class__.__name__, err=e)
self.abort()
else:
self.log.debug('{klass}._on_handshake_complete(): {session} started (res={res}).', klass=self.__class__.__name__,
session=self._session, res=res)
def connectionLost(self, reason):
self.log.debug('{klass}.connectionLost(reason="{reason}"', klass=self.__class__.__name__, reason=reason)
txaio.resolve(self.is_closed, self)
try:
wasClean = isinstance(reason.value, ConnectionDone)
if self._session:
self._session.onClose(wasClean)
except Exception as e:
# silently ignore exceptions raised here ..
self.log.warn('{klass}.connectionLost(): ApplicationSession.onClose raised "{err}"',
klass=self.__class__.__name__, err=e)
self._session = None
def stringReceived(self, payload):
self.log.trace('{klass}.stringReceived(): RX {octets} octets',
klass=self.__class__.__name__, octets=_LazyHexFormatter(payload))
try:
for msg in self._serializer.unserialize(payload):
self.log.trace("{klass}.stringReceived: RX WAMP message: {msg}",
klass=self.__class__.__name__, msg=msg)
self._session.onMessage(msg)
except CancelledError as e:
self.log.debug("{klass}.stringReceived: WAMP CancelledError - connection will continue!\n{err}",
klass=self.__class__.__name__,
err=e)
except InvalidUriError as e:
self.log.warn("{klass}.stringReceived: WAMP InvalidUriError - aborting connection!\n{err}",
klass=self.__class__.__name__,
err=e)
self.abort()
except ProtocolError as e:
self.log.warn("{klass}.stringReceived: WAMP ProtocolError - aborting connection!\n{err}",
klass=self.__class__.__name__,
err=e)
self.abort()
except PayloadExceededError as e:
self.log.warn("{klass}.stringReceived: WAMP PayloadExceededError - aborting connection!\n{err}",
klass=self.__class__.__name__,
err=e)
self.abort()
except SerializationError as e:
self.log.warn("{klass}.stringReceived: WAMP SerializationError - aborting connection!\n{err}",
klass=self.__class__.__name__,
err=e)
self.abort()
except Exception as e:
self.log.failure()
self.log.warn("{klass}.stringReceived: WAMP Exception - aborting connection!\n{err}",
klass=self.__class__.__name__,
err=e)
self.abort()
def send(self, msg):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.send`
"""
if self.isOpen():
self.log.trace('{klass}.send() (serializer={serializer}): TX WAMP message: "{msg}"',
klass=self.__class__.__name__, msg=msg, serializer=self._serializer)
try:
payload, _ = self._serializer.serialize(msg)
except SerializationError as e:
# all exceptions raised from above should be serialization errors ..
raise SerializationError("WampRawSocketProtocol: unable to serialize WAMP application payload ({0})".format(e))
else:
payload_len = len(payload)
if 0 < self._max_len_send < payload_len:
emsg = 'tried to send RawSocket message with size {} exceeding payload limit of {} octets'.format(
payload_len, self._max_len_send)
self.log.warn(emsg)
raise PayloadExceededError(emsg)
else:
self.sendString(payload)
self.log.trace('{klass}.send(): TX {octets} octets',
klass=self.__class__.__name__, octets=_LazyHexFormatter(payload))
else:
raise TransportLost()
def isOpen(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.isOpen`
"""
return self._session is not None
def close(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.close`
"""
if self.isOpen():
self.transport.loseConnection()
else:
raise TransportLost()
def abort(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.abort`
"""
if self.isOpen():
if hasattr(self.transport, 'abortConnection'):
# ProcessProtocol lacks abortConnection()
self.transport.abortConnection()
else:
self.transport.loseConnection()
else:
raise TransportLost()
@public
class WampRawSocketServerProtocol(WampRawSocketProtocol):
"""
Twisted-based WAMP-over-RawSocket server protocol.
Implements:
* :class:`autobahn.wamp.interfaces.ITransport`
"""
def dataReceived(self, data):
if self._handshake_complete:
WampRawSocketProtocol.dataReceived(self, data)
else:
remaining = 4 - len(self._handshake_bytes)
self._handshake_bytes += data[:remaining]
if len(self._handshake_bytes) == 4:
self.log.debug(
"WampRawSocketServerProtocol: opening handshake received - 0x{octets}",
octets=_LazyHexFormatter(self._handshake_bytes),
)
# first octet must be magic octet 0x7f
#
_magic = ord(self._handshake_bytes[0:1])
if _magic != 127:
self.log.warn(
"WampRawSocketServerProtocol: invalid magic byte (octet 1) in"
" opening handshake: was {magic}, but expected 127",
magic=_magic,
)
self.abort()
else:
self.log.debug('WampRawSocketServerProtocol: correct magic byte received')
# peer requests us to send messages of maximum length 2**max_len_exp
#
self._max_len_send = 2 ** (9 + (ord(self._handshake_bytes[1:2]) >> 4))
self.log.debug(
"WampRawSocketServerProtocol: client requests us to send out most {max_bytes} bytes per message",
max_bytes=self._max_len_send,
)
# client wants to speak this serialization format
#
ser_id = ord(self._handshake_bytes[1:2]) & 0x0F
if ser_id in self.factory._serializers:
self._serializer = copy.copy(self.factory._serializers[ser_id])
self.log.debug(
"WampRawSocketServerProtocol: client wants to use serializer '{serializer}'",
serializer=ser_id,
)
else:
self.log.warn(
"WampRawSocketServerProtocol: opening handshake - no suitable serializer found (client requested {serializer}, and we have {serializers}",
serializer=ser_id,
serializers=self.factory._serializers.keys(),
)
self.abort()
# we request the client to send message of maximum length 2**reply_max_len_exp
#
reply_max_len_exp = int(math.ceil(math.log(self._max_message_size, 2)))
# this is an instance attribute on the Twisted base class for maximum size
# of _received_ messages
self.MAX_LENGTH = 2**reply_max_len_exp
# send out handshake reply
#
reply_octet2 = bytes(bytearray([
((reply_max_len_exp - 9) << 4) | self._serializer.RAWSOCKET_SERIALIZER_ID]))
self.transport.write(b'\x7F') # magic byte
self.transport.write(reply_octet2) # max length / serializer
self.transport.write(b'\x00\x00') # reserved octets
self._handshake_complete = True
self._on_handshake_complete()
self.log.debug(
"WampRawSocketServerProtocol: opening handshake completed: {serializer}",
serializer=self._serializer,
)
# consume any remaining data received already ..
#
data = data[remaining:]
if data:
self.dataReceived(data)
@public
class WampRawSocketClientProtocol(WampRawSocketProtocol):
"""
Twisted-based WAMP-over-RawSocket client protocol.
Implements:
* :class:`autobahn.wamp.interfaces.ITransport`
"""
def connectionMade(self):
WampRawSocketProtocol.connectionMade(self)
self._serializer = copy.copy(self.factory._serializer)
# we request the peer to send messages of maximum length 2**reply_max_len_exp
request_max_len_exp = int(math.ceil(math.log(self._max_message_size, 2)))
# this is an instance attribute on the Twisted base class for maximum size
# of _received_ messages
self.MAX_LENGTH = 2**request_max_len_exp
# send out handshake request
#
request_octet2 = bytes(bytearray([
((request_max_len_exp - 9) << 4) | self._serializer.RAWSOCKET_SERIALIZER_ID]))
self.transport.write(b'\x7F') # magic byte
self.transport.write(request_octet2) # max length / serializer
self.transport.write(b'\x00\x00') # reserved octets
def dataReceived(self, data):
if self._handshake_complete:
WampRawSocketProtocol.dataReceived(self, data)
else:
remaining = 4 - len(self._handshake_bytes)
self._handshake_bytes += data[:remaining]
if len(self._handshake_bytes) == 4:
self.log.debug(
"WampRawSocketClientProtocol: opening handshake received - {handshake}",
handshake=_LazyHexFormatter(self._handshake_bytes),
)
if ord(self._handshake_bytes[0:1]) != 0x7f:
self.log.debug(
"WampRawSocketClientProtocol: invalid magic byte (octet 1) in opening handshake: was 0x{magic}, but expected 0x7f",
magic=_LazyHexFormatter(self._handshake_bytes[0]),
)
self.abort()
# peer requests us to _send_ messages of maximum length 2**max_len_exp
#
self._max_len_send = 2 ** (9 + (ord(self._handshake_bytes[1:2]) >> 4))
self.log.debug(
"WampRawSocketClientProtocol: server requests us to send out most {max} bytes per message",
max=self._max_len_send,
)
# client wants to speak this serialization format
#
ser_id = ord(self._handshake_bytes[1:2]) & 0x0F
if ser_id != self._serializer.RAWSOCKET_SERIALIZER_ID:
self.log.error(
"WampRawSocketClientProtocol: opening handshake - no suitable serializer found (server replied {serializer}, and we requested {serializers})",
serializer=ser_id,
serializers=self._serializer.RAWSOCKET_SERIALIZER_ID,
)
self.abort()
self._handshake_complete = True
self._on_handshake_complete()
self.log.debug(
"WampRawSocketClientProtocol: opening handshake completed (using serializer {serializer})",
serializer=self._serializer,
)
# consume any remaining data received already ..
#
data = data[remaining:]
if data:
self.dataReceived(data)
class WampRawSocketFactory(Factory):
"""
Base class for Twisted-based WAMP-over-RawSocket factories.
"""
log = txaio.make_logger()
def __init__(self, factory):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
"""
if callable(factory):
self._factory = factory
else:
self._factory = lambda: factory
# RawSocket max payload size is 16M (https://wamp-proto.org/_static/gen/wamp_latest_ietf.html#handshake)
self._max_message_size = 2**24
def resetProtocolOptions(self):
self._max_message_size = 2**24
def setProtocolOptions(self, maxMessagePayloadSize=None):
self.log.debug('{klass}.setProtocolOptions(maxMessagePayloadSize={maxMessagePayloadSize})',
klass=self.__class__.__name__, maxMessagePayloadSize=maxMessagePayloadSize)
assert maxMessagePayloadSize is None or (type(maxMessagePayloadSize) == int and maxMessagePayloadSize >= 512 and maxMessagePayloadSize <= 2**24)
if maxMessagePayloadSize is not None and maxMessagePayloadSize != self._max_message_size:
self._max_message_size = maxMessagePayloadSize
def buildProtocol(self, addr):
self.log.debug('{klass}.buildProtocol(addr={addr})', klass=self.__class__.__name__, addr=addr)
p = self.protocol()
p.factory = self
p.MAX_LENGTH = self._max_message_size
p._max_message_size = self._max_message_size
self.log.debug('{klass}.buildProtocol() -> proto={proto}, max_message_size={max_message_size}, MAX_LENGTH={MAX_LENGTH}',
klass=self.__class__.__name__, proto=p, max_message_size=p._max_message_size, MAX_LENGTH=p.MAX_LENGTH)
return p
@public
class WampRawSocketServerFactory(WampRawSocketFactory):
"""
Twisted-based WAMP-over-RawSocket server protocol factory.
"""
protocol = WampRawSocketServerProtocol
def __init__(self, factory, serializers=None):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializers: A list of WAMP serializers to use (or ``None``
for all available serializers).
:type serializers: list of objects implementing
:class:`autobahn.wamp.interfaces.ISerializer`
"""
WampRawSocketFactory.__init__(self, factory)
if serializers is None:
serializers = []
# try CBOR WAMP serializer
try:
from autobahn.wamp.serializer import CBORSerializer
serializers.append(CBORSerializer(batched=True))
serializers.append(CBORSerializer())
except ImportError:
pass
# try MsgPack WAMP serializer
try:
from autobahn.wamp.serializer import MsgPackSerializer
serializers.append(MsgPackSerializer(batched=True))
serializers.append(MsgPackSerializer())
except ImportError:
pass
# try UBJSON WAMP serializer
try:
from autobahn.wamp.serializer import UBJSONSerializer
serializers.append(UBJSONSerializer(batched=True))
serializers.append(UBJSONSerializer())
except ImportError:
pass
# try JSON WAMP serializer
try:
from autobahn.wamp.serializer import JsonSerializer
serializers.append(JsonSerializer(batched=True))
serializers.append(JsonSerializer())
except ImportError:
pass
if not serializers:
raise Exception("could not import any WAMP serializers")
self._serializers = {}
for ser in serializers:
self._serializers[ser.RAWSOCKET_SERIALIZER_ID] = ser
@public
class WampRawSocketClientFactory(WampRawSocketFactory):
"""
Twisted-based WAMP-over-RawSocket client protocol factory.
"""
protocol = WampRawSocketClientProtocol
def __init__(self, factory, serializer=None):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializer: The WAMP serializer to use (or ``None`` for
"best" serializer, chosen as the first serializer available from
this list: CBOR, MessagePack, UBJSON, JSON).
:type serializer: object implementing :class:`autobahn.wamp.interfaces.ISerializer`
"""
WampRawSocketFactory.__init__(self, factory)
# Reduce the factory logs noise
self.noisy = False
if serializer is None:
# try CBOR WAMP serializer
try:
from autobahn.wamp.serializer import CBORSerializer
serializer = CBORSerializer()
except ImportError:
pass
if serializer is None:
# try MsgPack WAMP serializer
try:
from autobahn.wamp.serializer import MsgPackSerializer
serializer = MsgPackSerializer()
except ImportError:
pass
if serializer is None:
# try UBJSON WAMP serializer
try:
from autobahn.wamp.serializer import UBJSONSerializer
serializer = UBJSONSerializer()
except ImportError:
pass
if serializer is None:
# try JSON WAMP serializer
try:
from autobahn.wamp.serializer import JsonSerializer
serializer = JsonSerializer()
except ImportError:
pass
if serializer is None:
raise Exception("could not import any WAMP serializer")
self._serializer = serializer

View File

@@ -0,0 +1,182 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
from zope.interface import implementer
from twisted.protocols.policies import ProtocolWrapper
try:
# starting from Twisted 22.10.0 we have `notFound`
from twisted.web.pages import notFound
except ImportError:
try:
# In Twisted < 22.10.0 && > 12.2 this was called `NoResource`
from twisted.web.resource import NoResource as notFound
except ImportError:
# And in Twisted < 12.2 this was in a different place
from twisted.web.error import NoResource as notFound
from twisted.web.resource import IResource, Resource
# The following triggers an import of reactor at module level!
#
from twisted.web.server import NOT_DONE_YET
__all__ = (
'WebSocketResource',
'WSGIRootResource',
)
class WSGIRootResource(Resource):
"""
Root resource when you want a WSGI resource be the default serving
resource for a Twisted Web site, but have subpaths served by
different resources.
This is a hack needed since
`twisted.web.wsgi.WSGIResource <http://twistedmatrix.com/documents/current/api/twisted.web.wsgi.WSGIResource.html>`_.
does not provide a ``putChild()`` method.
.. seealso::
* `Autobahn Twisted Web WSGI example <https://github.com/crossbario/autobahn-python/tree/master/examples/twisted/websocket/echo_wsgi>`_
* `Original hack <http://blog.vrplumber.com/index.php?/archives/2426-Making-your-Twisted-resources-a-url-sub-tree-of-your-WSGI-resource....html>`_
"""
def __init__(self, wsgiResource, children):
"""
:param wsgiResource: The WSGI to serve as root resource.
:type wsgiResource: Instance of `twisted.web.wsgi.WSGIResource <http://twistedmatrix.com/documents/current/api/twisted.web.wsgi.WSGIResource.html>`_.
:param children: A dictionary with string keys constituting URL subpaths, and Twisted Web resources as values.
:type children: dict
"""
Resource.__init__(self)
self._wsgiResource = wsgiResource
self.children = children
def getChild(self, path, request):
request.prepath.pop()
request.postpath.insert(0, path)
return self._wsgiResource
@implementer(IResource)
class WebSocketResource(object):
"""
A Twisted Web resource for WebSocket.
"""
isLeaf = True
def __init__(self, factory):
"""
:param factory: An instance of :class:`autobahn.twisted.websocket.WebSocketServerFactory`.
:type factory: obj
"""
self._factory = factory
# noinspection PyUnusedLocal
def getChildWithDefault(self, name, request):
"""
This resource cannot have children, hence this will always fail.
"""
return notFound(message="No such child resource.")
def putChild(self, path, child):
"""
This resource cannot have children, hence this is always ignored.
"""
def render(self, request):
"""
Render the resource. This will takeover the transport underlying
the request, create a :class:`autobahn.twisted.websocket.WebSocketServerProtocol`
and let that do any subsequent communication.
"""
# for reasons unknown, the transport is already None when the
# request is over HTTP2. request.channel.getPeer() is valid at
# this point however
if request.channel.transport is None:
# render an "error, yo're doing HTTPS over WSS" webpage
from autobahn.websocket import protocol
request.setResponseCode(426, b"Upgrade required")
# RFC says MUST set upgrade along with 426 code:
# https://tools.ietf.org/html/rfc7231#section-6.5.15
request.setHeader(b"Upgrade", b"WebSocket")
html = protocol._SERVER_STATUS_TEMPLATE % ("", protocol.__version__)
return html.encode('utf8')
# Create Autobahn WebSocket protocol.
#
protocol = self._factory.buildProtocol(request.transport.getPeer())
if not protocol:
# If protocol creation fails, we signal "internal server error"
request.setResponseCode(500)
return b""
# Take over the transport from Twisted Web
#
transport, request.channel.transport = request.channel.transport, None
# Connect the transport to our protocol. Once #3204 is fixed, there
# may be a cleaner way of doing this.
# http://twistedmatrix.com/trac/ticket/3204
#
if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol
transport.wrappedProtocol = protocol
elif isinstance(transport.protocol, ProtocolWrapper):
# this happens in new-TLS
transport.protocol.wrappedProtocol = protocol
else:
transport.protocol = protocol
protocol.makeConnection(transport)
# On Twisted 16+, the transport is paused whilst the existing
# request is served; there won't be any requests after us so
# we can just resume this ourselves.
# 17.1 version
if hasattr(transport, "_networkProducer"):
transport._networkProducer.resumeProducing()
# 16.x version
elif hasattr(transport, "resumeProducing"):
transport.resumeProducing()
# We recreate the request and forward the raw data. This is somewhat
# silly (since Twisted Web already did the HTTP request parsing
# which we will do a 2nd time), but it's totally non-invasive to our
# code. Maybe improve this.
#
data = request.method + b' ' + request.uri + b' HTTP/1.1\x0d\x0a'
for h in request.requestHeaders.getAllRawHeaders():
data += h[0] + b': ' + b",".join(h[1]) + b'\x0d\x0a'
data += b"\x0d\x0a"
data += request.content.read()
protocol.dataReceived(data)
return NOT_DONE_YET

View File

@@ -0,0 +1,25 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################

View File

@@ -0,0 +1,124 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
# t.i.reactor doesn't exist until we've imported it once, but we
# need it to exist so we can @patch it out in the tests ...
from twisted.internet import reactor # noqa
from unittest.mock import patch, Mock
from twisted.internet.defer import inlineCallbacks, succeed
from twisted.trial import unittest
from autobahn.twisted.wamp import ApplicationRunner
def raise_error(*args, **kw):
raise RuntimeError("we always fail")
class TestApplicationRunner(unittest.TestCase):
@patch('twisted.internet.reactor')
def test_runner_default(self, fakereactor):
fakereactor.connectTCP = Mock(side_effect=raise_error)
runner = ApplicationRunner('ws://fake:1234/ws', 'dummy realm')
# we should get "our" RuntimeError when we call run
self.assertRaises(RuntimeError, runner.run, raise_error)
# both reactor.run and reactor.stop should have been called
self.assertEqual(fakereactor.run.call_count, 1)
self.assertEqual(fakereactor.stop.call_count, 1)
@patch('twisted.internet.reactor')
@inlineCallbacks
def test_runner_no_run(self, fakereactor):
fakereactor.connectTCP = Mock(side_effect=raise_error)
runner = ApplicationRunner('ws://fake:1234/ws', 'dummy realm')
try:
yield runner.run(raise_error, start_reactor=False)
self.fail() # should have raise an exception, via Deferred
except RuntimeError as e:
# make sure it's "our" exception
self.assertEqual(e.args[0], "we always fail")
# neither reactor.run() NOR reactor.stop() should have been called
# (just connectTCP() will have been called)
self.assertEqual(fakereactor.run.call_count, 0)
self.assertEqual(fakereactor.stop.call_count, 0)
@patch('twisted.internet.reactor')
def test_runner_no_run_happypath(self, fakereactor):
proto = Mock()
fakereactor.connectTCP = Mock(return_value=succeed(proto))
runner = ApplicationRunner('ws://fake:1234/ws', 'dummy realm')
d = runner.run(Mock(), start_reactor=False)
# shouldn't have actually connected to anything
# successfully, and the run() call shouldn't have inserted
# any of its own call/errbacks. (except the cleanup handler)
self.assertFalse(d.called)
self.assertEqual(1, len(d.callbacks))
# neither reactor.run() NOR reactor.stop() should have been called
# (just connectTCP() will have been called)
self.assertEqual(fakereactor.run.call_count, 0)
self.assertEqual(fakereactor.stop.call_count, 0)
@patch('twisted.internet.reactor')
def test_runner_bad_proxy(self, fakereactor):
proxy = 'myproxy'
self.assertRaises(
AssertionError,
ApplicationRunner,
'ws://fake:1234/ws', 'dummy realm',
proxy=proxy
)
@patch('twisted.internet.reactor')
def test_runner_proxy(self, fakereactor):
proto = Mock()
fakereactor.connectTCP = Mock(return_value=succeed(proto))
proxy = {'host': 'myproxy', 'port': 3128}
runner = ApplicationRunner('ws://fake:1234/ws', 'dummy realm', proxy=proxy)
d = runner.run(Mock(), start_reactor=False)
# shouldn't have actually connected to anything
# successfully, and the run() call shouldn't have inserted
# any of its own call/errbacks. (except the cleanup handler)
self.assertFalse(d.called)
self.assertEqual(1, len(d.callbacks))
# neither reactor.run() NOR reactor.stop() should have been called
# (just connectTCP() will have been called)
self.assertEqual(fakereactor.run.call_count, 0)
self.assertEqual(fakereactor.stop.call_count, 0)

View File

@@ -0,0 +1,139 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import sys
from unittest.mock import Mock
import twisted.internet
from twisted.trial import unittest
from autobahn.twisted import choosereactor
class ChooseReactorTests(unittest.TestCase):
def patch_reactor(self, name, new_reactor):
"""
Patch ``name`` so that Twisted will grab a fake reactor instead of
a real one.
"""
if hasattr(twisted.internet, name):
self.patch(twisted.internet, name, new_reactor)
else:
def _cleanup():
delattr(twisted.internet, name)
setattr(twisted.internet, name, new_reactor)
def patch_modules(self):
"""
Patch ``sys.modules`` so that Twisted believes there is no
installed reactor.
"""
old_modules = dict(sys.modules)
new_modules = dict(sys.modules)
del new_modules["twisted.internet.reactor"]
def _cleanup():
sys.modules = old_modules
self.addCleanup(_cleanup)
sys.modules = new_modules
def test_unknown(self):
"""
``install_optimal_reactor`` will use the default reactor if it is
unable to detect the platform it is running on.
"""
reactor_mock = Mock()
self.patch_reactor("selectreactor", reactor_mock)
self.patch(sys, "platform", "unknown")
# Emulate that a reactor has not been installed
self.patch_modules()
choosereactor.install_optimal_reactor()
reactor_mock.install.assert_called_once_with()
def test_mac(self):
"""
``install_optimal_reactor`` will install KQueueReactor on
Darwin (OS X).
"""
reactor_mock = Mock()
self.patch_reactor("kqreactor", reactor_mock)
self.patch(sys, "platform", "darwin")
# Emulate that a reactor has not been installed
self.patch_modules()
choosereactor.install_optimal_reactor()
reactor_mock.install.assert_called_once_with()
def test_win(self):
"""
``install_optimal_reactor`` will install IOCPReactor on Windows.
"""
if sys.platform != 'win32':
raise unittest.SkipTest('unit test requires Windows')
reactor_mock = Mock()
self.patch_reactor("iocpreactor", reactor_mock)
self.patch(sys, "platform", "win32")
# Emulate that a reactor has not been installed
self.patch_modules()
choosereactor.install_optimal_reactor()
reactor_mock.install.assert_called_once_with()
def test_bsd(self):
"""
``install_optimal_reactor`` will install KQueueReactor on BSD.
"""
reactor_mock = Mock()
self.patch_reactor("kqreactor", reactor_mock)
self.patch(sys, "platform", "freebsd11")
# Emulate that a reactor has not been installed
self.patch_modules()
choosereactor.install_optimal_reactor()
reactor_mock.install.assert_called_once_with()
def test_linux(self):
"""
``install_optimal_reactor`` will install EPollReactor on Linux.
"""
reactor_mock = Mock()
self.patch_reactor("epollreactor", reactor_mock)
self.patch(sys, "platform", "linux")
# Emulate that a reactor has not been installed
self.patch_modules()
choosereactor.install_optimal_reactor()
reactor_mock.install.assert_called_once_with()

View File

@@ -0,0 +1,438 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import os
from unittest.mock import Mock, patch
if os.environ.get('USE_TWISTED', False):
from autobahn.twisted.component import Component
from zope.interface import directlyProvides
from autobahn.wamp.message import Welcome, Goodbye, Hello, Abort
from autobahn.wamp.serializer import JsonSerializer
from autobahn.testutil import FakeTransport
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.internet.defer import inlineCallbacks, succeed, Deferred
from twisted.internet.task import Clock
from twisted.trial import unittest
from txaio.testutil import replace_loop
class ConnectionTests(unittest.TestCase):
def setUp(self):
pass
@patch('txaio.sleep', return_value=succeed(None))
@inlineCallbacks
def test_successful_connect(self, fake_sleep):
endpoint = Mock()
joins = []
def joined(session, details):
joins.append((session, details))
return session.leave()
directlyProvides(endpoint, IStreamClientEndpoint)
component = Component(
transports={
"type": "websocket",
"url": "ws://127.0.0.1/ws",
"endpoint": endpoint,
}
)
component.on('join', joined)
def connect(factory, **kw):
proto = factory.buildProtocol('ws://localhost/')
transport = FakeTransport()
proto.makeConnection(transport)
from autobahn.websocket.protocol import WebSocketProtocol
from base64 import b64encode
from hashlib import sha1
key = proto.websocket_key + WebSocketProtocol._WS_MAGIC
proto.data = (
b"HTTP/1.1 101 Switching Protocols\x0d\x0a"
b"Upgrade: websocket\x0d\x0a"
b"Connection: upgrade\x0d\x0a"
b"Sec-Websocket-Protocol: wamp.2.json\x0d\x0a"
b"Sec-Websocket-Accept: " + b64encode(sha1(key).digest()) + b"\x0d\x0a\x0d\x0a"
)
proto.processHandshake()
from autobahn.wamp import role
features = role.RoleBrokerFeatures(
publisher_identification=True,
pattern_based_subscription=True,
session_meta_api=True,
subscription_meta_api=True,
subscriber_blackwhite_listing=True,
publisher_exclusion=True,
subscription_revocation=True,
payload_transparency=True,
payload_encryption_cryptobox=True,
)
msg = Welcome(123456, dict(broker=features), realm='realm')
serializer = JsonSerializer()
data, is_binary = serializer.serialize(msg)
proto.onMessage(data, is_binary)
msg = Goodbye()
proto.onMessage(*serializer.serialize(msg))
proto.onClose(True, 100, "some old reason")
return succeed(proto)
endpoint.connect = connect
# XXX it would actually be nicer if we *could* support
# passing a reactor in here, but the _batched_timer =
# make_batched_timer() stuff (slash txaio in general)
# makes this "hard".
reactor = Clock()
with replace_loop(reactor):
yield component.start(reactor=reactor)
self.assertTrue(len(joins), 1)
# make sure we fire all our time-outs
reactor.advance(3600)
@patch('txaio.sleep', return_value=succeed(None))
def test_successful_proxy_connect(self, fake_sleep):
endpoint = Mock()
directlyProvides(endpoint, IStreamClientEndpoint)
component = Component(
transports={
"type": "websocket",
"url": "ws://127.0.0.1/ws",
"endpoint": endpoint,
"proxy": {
"host": "10.0.0.0",
"port": 65000,
},
"max_retries": 0,
},
is_fatal=lambda _: True,
)
@component.on_join
def joined(session, details):
return session.leave()
def connect(factory, **kw):
return succeed(Mock())
endpoint.connect = connect
# XXX it would actually be nicer if we *could* support
# passing a reactor in here, but the _batched_timer =
# make_batched_timer() stuff (slash txaio in general)
# makes this "hard".
reactor = Clock()
got_proxy_connect = Deferred()
def _tcp(host, port, factory, **kw):
self.assertEqual("10.0.0.0", host)
self.assertEqual(port, 65000)
got_proxy_connect.callback(None)
return endpoint.connect(factory._wrappedFactory)
reactor.connectTCP = _tcp
with replace_loop(reactor):
d = component.start(reactor=reactor)
def done(x):
if not got_proxy_connect.called:
got_proxy_connect.callback(x)
# make sure we fire all our time-outs
d.addCallbacks(done, done)
reactor.advance(3600)
return got_proxy_connect
@patch('txaio.sleep', return_value=succeed(None))
@inlineCallbacks
def test_cancel(self, fake_sleep):
"""
if we start a component but call .stop before it connects, ever,
it should still exit properly
"""
endpoint = Mock()
directlyProvides(endpoint, IStreamClientEndpoint)
component = Component(
transports={
"type": "websocket",
"url": "ws://127.0.0.1/ws",
"endpoint": endpoint,
}
)
def connect(factory, **kw):
return Deferred()
endpoint.connect = connect
# XXX it would actually be nicer if we *could* support
# passing a reactor in here, but the _batched_timer =
# make_batched_timer() stuff (slash txaio in general)
# makes this "hard".
reactor = Clock()
with replace_loop(reactor):
d = component.start(reactor=reactor)
component.stop()
yield d
@inlineCallbacks
def test_cancel_while_waiting(self):
"""
if we start a component but call .stop before it connects, ever,
it should still exit properly -- even if we're 'between'
connection attempts
"""
endpoint = Mock()
directlyProvides(endpoint, IStreamClientEndpoint)
component = Component(
transports={
"type": "websocket",
"url": "ws://127.0.0.1/ws",
"endpoint": endpoint,
"max_retries": 0,
"max_retry_delay": 5,
"initial_retry_delay": 5,
},
)
# XXX it would actually be nicer if we *could* support
# passing a reactor in here, but the _batched_timer =
# make_batched_timer() stuff (slash txaio in general)
# makes this "hard".
reactor = Clock()
with replace_loop(reactor):
def connect(factory, **kw):
d = Deferred()
reactor.callLater(10, d.errback(RuntimeError("no connect for yo")))
return d
endpoint.connect = connect
d0 = component.start(reactor=reactor)
assert component._delay_f is not None
assert not component._done_f.called
d1 = component.stop()
assert component._done_f is None
assert d0.called
yield d1
yield d0
@patch('txaio.sleep', return_value=succeed(None))
@inlineCallbacks
def test_connect_no_auth_method(self, fake_sleep):
endpoint = Mock()
directlyProvides(endpoint, IStreamClientEndpoint)
component = Component(
transports={
"type": "websocket",
"url": "ws://127.0.0.1/ws",
"endpoint": endpoint,
},
is_fatal=lambda e: True,
)
def connect(factory, **kw):
proto = factory.buildProtocol('boom')
proto.makeConnection(Mock())
from autobahn.websocket.protocol import WebSocketProtocol
from base64 import b64encode
from hashlib import sha1
key = proto.websocket_key + WebSocketProtocol._WS_MAGIC
proto.data = (
b"HTTP/1.1 101 Switching Protocols\x0d\x0a"
b"Upgrade: websocket\x0d\x0a"
b"Connection: upgrade\x0d\x0a"
b"Sec-Websocket-Protocol: wamp.2.json\x0d\x0a"
b"Sec-Websocket-Accept: " + b64encode(sha1(key).digest()) + b"\x0d\x0a\x0d\x0a"
)
proto.processHandshake()
from autobahn.wamp import role
subrole = role.RoleSubscriberFeatures()
msg = Hello("realm", roles=dict(subscriber=subrole), authmethods=["anonymous"])
serializer = JsonSerializer()
data, is_binary = serializer.serialize(msg)
proto.onMessage(data, is_binary)
msg = Abort(reason="wamp.error.no_auth_method")
proto.onMessage(*serializer.serialize(msg))
proto.onClose(False, 100, "wamp.error.no_auth_method")
return succeed(proto)
endpoint.connect = connect
# XXX it would actually be nicer if we *could* support
# passing a reactor in here, but the _batched_timer =
# make_batched_timer() stuff (slash txaio in general)
# makes this "hard".
reactor = Clock()
with replace_loop(reactor):
with self.assertRaises(RuntimeError) as ctx:
d = component.start(reactor=reactor)
# make sure we fire all our time-outs
reactor.advance(3600)
yield d
self.assertIn(
"Exhausted all transport",
str(ctx.exception)
)
class InvalidTransportConfigs(unittest.TestCase):
def test_invalid_key(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=dict(
foo='bar', # totally invalid key
),
)
self.assertIn("'foo' is not", str(ctx.exception))
def test_invalid_key_transport_list(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=[
dict(type='websocket', url='ws://127.0.0.1/ws'),
dict(foo='bar'), # totally invalid key
]
)
self.assertIn("'foo' is not a valid configuration item", str(ctx.exception))
def test_invalid_serializer_key(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=[
{
"url": "ws://127.0.0.1/ws",
"serializer": ["quux"],
}
]
)
self.assertIn("only for rawsocket", str(ctx.exception))
def test_invalid_serializer(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=[
{
"url": "ws://127.0.0.1/ws",
"serializers": ["quux"],
}
]
)
self.assertIn("Invalid serializer", str(ctx.exception))
def test_invalid_serializer_type_0(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=[
{
"url": "ws://127.0.0.1/ws",
"serializers": [1, 2],
}
]
)
self.assertIn("must be a list", str(ctx.exception))
def test_invalid_serializer_type_1(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=[
{
"url": "ws://127.0.0.1/ws",
"serializers": 1,
}
]
)
self.assertIn("must be a list", str(ctx.exception))
def test_invalid_type_key(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=[
{
"type": "bad",
}
]
)
self.assertIn("Invalid transport type", str(ctx.exception))
def test_invalid_type(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=[
"foo"
]
)
self.assertIn("invalid WebSocket URL", str(ctx.exception))
def test_no_url(self):
with self.assertRaises(ValueError) as ctx:
Component(
transports=[
{
"type": "websocket",
}
]
)
self.assertIn("Transport requires 'url'", str(ctx.exception))
def test_endpoint_bogus_object(self):
with self.assertRaises(ValueError) as ctx:
Component(
main=lambda r, s: None,
transports=[
{
"type": "websocket",
"url": "ws://example.com/ws",
"endpoint": ("not", "a", "dict"),
}
]
)
self.assertIn("'endpoint' configuration must be", str(ctx.exception))
def test_endpoint_valid(self):
Component(
main=lambda r, s: None,
transports=[
{
"type": "websocket",
"url": "ws://example.com/ws",
"endpoint": {
"type": "tcp",
"host": "1.2.3.4",
"port": "4321",
}
}
]
)

View File

@@ -0,0 +1,49 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
from twisted.trial.unittest import TestCase
class PluginTests(TestCase):
if True:
skip = "Plugins don't work under Python3 yet"
def test_import(self):
from twisted.plugins import autobahn_endpoints
self.assertTrue(hasattr(autobahn_endpoints, 'AutobahnClientParser'))
def test_parse_client_basic(self):
from twisted.plugins import autobahn_endpoints
self.assertTrue(hasattr(autobahn_endpoints, 'AutobahnClientParser'))
from twisted.internet.endpoints import clientFromString, quoteStringArgument
from twisted.internet import reactor
ep_string = "autobahn:{0}:url={1}".format(
quoteStringArgument('tcp:localhost:9000'),
quoteStringArgument('ws://localhost:9000'),
)
# we're just testing that this doesn't fail entirely
clientFromString(reactor, ep_string)

View File

@@ -0,0 +1,447 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
from unittest.mock import Mock
import txaio
txaio.use_twisted()
from autobahn.util import wildcards2patterns
from autobahn.twisted.websocket import WebSocketServerFactory
from autobahn.twisted.websocket import WebSocketServerProtocol
from autobahn.twisted.websocket import WebSocketClientProtocol
from autobahn.wamp.types import TransportDetails
from autobahn.websocket.types import ConnectingRequest
from twisted.python.failure import Failure
from twisted.internet.error import ConnectionDone, ConnectionAborted, \
ConnectionLost
from twisted.trial import unittest
try:
from twisted.internet.testing import StringTransport
except ImportError:
from twisted.test.proto_helpers import StringTransport
from autobahn.testutil import FakeTransport
class ExceptionHandlingTests(unittest.TestCase):
"""
Tests that we format various exception variations properly during
connectionLost
"""
def setUp(self):
self.factory = WebSocketServerFactory()
self.proto = WebSocketServerProtocol()
self.proto.factory = self.factory
self.proto.log = Mock()
def tearDown(self):
for call in [
self.proto.autoPingPendingCall,
self.proto.autoPingTimeoutCall,
self.proto.openHandshakeTimeoutCall,
self.proto.closeHandshakeTimeoutCall,
]:
if call is not None:
call.cancel()
def test_connection_done(self):
# pretend we connected
self.proto._connectionMade()
self.proto.connectionLost(Failure(ConnectionDone()))
messages = ' '.join([str(x[1]) for x in self.proto.log.mock_calls])
self.assertTrue('closed cleanly' in messages)
def test_connection_aborted(self):
# pretend we connected
self.proto._connectionMade()
self.proto.connectionLost(Failure(ConnectionAborted()))
messages = ' '.join([str(x[1]) for x in self.proto.log.mock_calls])
self.assertTrue(' aborted ' in messages)
def test_connection_lost(self):
# pretend we connected
self.proto._connectionMade()
self.proto.connectionLost(Failure(ConnectionLost()))
messages = ' '.join([str(x[1]) for x in self.proto.log.mock_calls])
self.assertTrue(' was lost ' in messages)
def test_connection_lost_arg(self):
# pretend we connected
self.proto._connectionMade()
self.proto.connectionLost(Failure(ConnectionLost("greetings")))
messages = ' '.join([str(x[1]) + str(x[2]) for x in self.proto.log.mock_calls])
self.assertTrue(' was lost ' in messages)
self.assertTrue('greetings' in messages)
class Hixie76RejectionTests(unittest.TestCase):
"""
Hixie-76 should not be accepted by an Autobahn server.
"""
def test_handshake_fails(self):
"""
A handshake from a client only supporting Hixie-76 will fail.
"""
t = FakeTransport()
f = WebSocketServerFactory()
p = WebSocketServerProtocol()
p.factory = f
p.transport = t
# from http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
http_request = b"GET /demo HTTP/1.1\r\nHost: example.com\r\nConnection: Upgrade\r\nSec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\nSec-WebSocket-Protocol: sample\r\nUpgrade: WebSocket\r\nSec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\nOrigin: http://example.com\r\n\r\n^n:ds[4U"
p.openHandshakeTimeout = 0
p._connectionMade()
p.data = http_request
p.processHandshake()
self.assertIn(b"HTTP/1.1 400", t._written)
self.assertIn(b"Hixie76 protocol not supported", t._written)
class WebSocketOriginMatching(unittest.TestCase):
"""
Test that we match Origin: headers properly, when asked to
"""
def setUp(self):
self.factory = WebSocketServerFactory()
self.factory.setProtocolOptions(
allowedOrigins=['127.0.0.1:*', '*.example.com:*']
)
self.proto = WebSocketServerProtocol()
self.proto.transport = StringTransport()
self.proto.factory = self.factory
self.proto.failHandshake = Mock()
self.proto._connectionMade()
def tearDown(self):
for call in [
self.proto.autoPingPendingCall,
self.proto.autoPingTimeoutCall,
self.proto.openHandshakeTimeoutCall,
self.proto.closeHandshakeTimeoutCall,
]:
if call is not None:
call.cancel()
def test_match_full_origin(self):
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: http://www.example.com.malicious.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertTrue(self.proto.failHandshake.called, "Handshake should have failed")
arg = self.proto.failHandshake.mock_calls[0][1][0]
self.assertTrue('not allowed' in arg)
def test_match_wrong_scheme_origin(self):
# some monkey-business since we already did this in setUp, but
# we want a different set of matching origins
self.factory.setProtocolOptions(
allowedOrigins=['http://*.example.com:*']
)
self.proto.allowedOriginsPatterns = self.factory.allowedOriginsPatterns
self.proto.allowedOrigins = self.factory.allowedOrigins
# the actual test
self.factory.isSecure = False
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: https://www.example.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertTrue(self.proto.failHandshake.called, "Handshake should have failed")
arg = self.proto.failHandshake.mock_calls[0][1][0]
self.assertTrue('not allowed' in arg)
def test_match_origin_secure_scheme(self):
self.factory.isSecure = True
self.factory.port = 443
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: https://www.example.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")
def test_match_origin_documentation_example(self):
"""
Test the examples from the docs
"""
self.factory.setProtocolOptions(
allowedOrigins=['*://*.example.com:*']
)
self.factory.isSecure = True
self.factory.port = 443
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: http://www.example.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")
def test_match_origin_examples(self):
"""
All the example origins from RFC6454 (3.2.1)
"""
# we're just testing the low-level function here...
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
policy = wildcards2patterns(['*example.com:*'])
# should parametrize test ...
for url in ['http://example.com/', 'http://example.com:80/',
'http://example.com/path/file',
'http://example.com/;semi=true',
# 'http://example.com./',
'//example.com/',
'http://@example.com']:
self.assertTrue(_is_same_origin(_url_to_origin(url), 'http', 80, policy), url)
def test_match_origin_counter_examples(self):
"""
All the example 'not-same' origins from RFC6454 (3.2.1)
"""
# we're just testing the low-level function here...
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
policy = wildcards2patterns(['example.com'])
for url in ['http://ietf.org/', 'http://example.org/',
'https://example.com/', 'http://example.com:8080/',
'http://www.example.com/']:
self.assertFalse(_is_same_origin(_url_to_origin(url), 'http', 80, policy))
def test_match_origin_edge(self):
# we're just testing the low-level function here...
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
policy = wildcards2patterns(['http://*example.com:80'])
self.assertTrue(
_is_same_origin(_url_to_origin('http://example.com:80'), 'http', 80, policy)
)
self.assertFalse(
_is_same_origin(_url_to_origin('http://example.com:81'), 'http', 81, policy)
)
self.assertFalse(
_is_same_origin(_url_to_origin('https://example.com:80'), 'http', 80, policy)
)
def test_origin_from_url(self):
from autobahn.websocket.protocol import _url_to_origin
# basic function
self.assertEqual(
_url_to_origin('http://example.com'),
('http', 'example.com', 80)
)
# should lower-case scheme
self.assertEqual(
_url_to_origin('hTTp://example.com'),
('http', 'example.com', 80)
)
def test_origin_file(self):
from autobahn.websocket.protocol import _url_to_origin
self.assertEqual('null', _url_to_origin('file:///etc/passwd'))
def test_origin_null(self):
from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
self.assertEqual('null', _url_to_origin('null'))
self.assertFalse(
_is_same_origin(_url_to_origin('null'), 'http', 80, [])
)
self.assertFalse(
_is_same_origin(_url_to_origin('null'), 'https', 80, [])
)
self.assertFalse(
_is_same_origin(_url_to_origin('null'), '', 80, [])
)
self.assertFalse(
_is_same_origin(_url_to_origin('null'), None, 80, [])
)
class WebSocketXForwardedFor(unittest.TestCase):
"""
Test that (only) a trusted X-Forwarded-For can replace the peer address.
"""
def setUp(self):
self.factory = WebSocketServerFactory()
self.factory.setProtocolOptions(
trustXForwardedFor=2
)
self.proto = WebSocketServerProtocol()
self.proto.transport = StringTransport()
self.proto.factory = self.factory
self.proto.failHandshake = Mock()
self.proto._connectionMade()
def tearDown(self):
for call in [
self.proto.autoPingPendingCall,
self.proto.autoPingTimeoutCall,
self.proto.openHandshakeTimeoutCall,
self.proto.closeHandshakeTimeoutCall,
]:
if call is not None:
call.cancel()
def test_trusted_addresses(self):
self.proto.data = b"\r\n".join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Origin: http://www.example.com',
b'Sec-WebSocket-Version: 13',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'X-Forwarded-For: 1.2.3.4, 2.3.4.5, 111.222.33.44',
b'\r\n', # last string doesn't get a \r\n from join()
])
self.proto.consumeData()
self.assertEquals(
self.proto.peer, "2.3.4.5",
"The second address in X-Forwarded-For should have been picked as the peer address")
class OnConnectingTests(unittest.TestCase):
"""
Tests related to onConnecting callback
These tests are testing generic behavior, but are somewhat tied to
'a framework' so we're just testing using Twisted-specifics here.
"""
def test_on_connecting_client_fails(self):
MAGIC_STR = 'bad stuff'
class TestProto(WebSocketClientProtocol):
state = None
wasClean = True
log = Mock()
def onConnecting(self, transport_details):
raise RuntimeError(MAGIC_STR)
proto = TestProto()
proto.transport = FakeTransport()
d = proto.startHandshake()
self.successResultOf(d) # error is ignored
# ... but error should be logged
self.assertTrue(len(proto.log.mock_calls) > 0)
magic_found = False
for i in range(len(proto.log.mock_calls)):
if MAGIC_STR in str(proto.log.mock_calls[i]):
magic_found = True
self.assertTrue(magic_found, 'MAGIC_STR not found when expected')
def test_on_connecting_client_success(self):
class TestProto(WebSocketClientProtocol):
state = None
wasClean = True
perMessageCompressionOffers = []
version = 18
openHandshakeTimeout = 5
log = Mock()
def onConnecting(self, transport_details):
return ConnectingRequest(
host="example.com",
port=443,
resource="/ws",
)
proto = TestProto()
proto.transport = FakeTransport()
proto.factory = Mock()
proto._connectionMade()
d = proto.startHandshake()
req = self.successResultOf(d)
self.assertEqual("example.com", req.host)
self.assertEqual(443, req.port)
self.assertEqual("/ws", req.resource)
def test_str_transport(self):
details = TransportDetails(
channel_type=TransportDetails.CHANNEL_TYPE_FUNCTION,
peer="example.com",
is_secure=False,
channel_id={},
)
# we can str() this and it doesn't fail
str(details)
def test_str_connecting(self):
req = ConnectingRequest(host="example.com", port="1234", resource="/ws")
# we can str() this and it doesn't fail
str(req)

View File

@@ -0,0 +1,70 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import unittest
from unittest.mock import Mock
from autobahn.twisted.rawsocket import (WampRawSocketServerFactory,
WampRawSocketServerProtocol,
WampRawSocketClientFactory,
WampRawSocketClientProtocol)
from autobahn.testutil import FakeTransport
class RawSocketHandshakeTests(unittest.TestCase):
def test_handshake_succeeds(self):
"""
A client can connect to a server.
"""
session_mock = Mock()
t = FakeTransport()
f = WampRawSocketClientFactory(lambda: session_mock)
p = WampRawSocketClientProtocol()
p.transport = t
p.factory = f
server_session_mock = Mock()
st = FakeTransport()
sf = WampRawSocketServerFactory(lambda: server_session_mock)
sp = WampRawSocketServerProtocol()
sp.transport = st
sp.factory = sf
sp.connectionMade()
p.connectionMade()
# Send the server the client handshake
sp.dataReceived(t._written[0:1])
sp.dataReceived(t._written[1:4])
# Send the client the server handshake
p.dataReceived(st._written)
# The handshake succeeds, a session on each end is created
# onOpen is called on the session
session_mock.onOpen.assert_called_once_with(p)
server_session_mock.onOpen.assert_called_once_with(sp)

View File

@@ -0,0 +1,81 @@
from twisted.trial import unittest
try:
from autobahn.twisted.testing import create_memory_agent, MemoryReactorClockResolver, create_pumper
HAVE_TESTING = True
except ImportError:
HAVE_TESTING = False
from twisted.internet.defer import inlineCallbacks
from autobahn.twisted.websocket import WebSocketServerProtocol
class TestAgent(unittest.TestCase):
skip = not HAVE_TESTING
def setUp(self):
self.pumper = create_pumper()
self.reactor = MemoryReactorClockResolver()
return self.pumper.start()
def tearDown(self):
return self.pumper.stop()
@inlineCallbacks
def test_echo_server(self):
class EchoServer(WebSocketServerProtocol):
def onMessage(self, msg, is_binary):
self.sendMessage(msg)
agent = create_memory_agent(self.reactor, self.pumper, EchoServer)
proto = yield agent.open("ws://localhost:1234/ws", dict())
messages = []
def got(msg, is_binary):
messages.append(msg)
proto.on("message", got)
proto.sendMessage(b"hello")
if True:
# clean close
proto.sendClose()
else:
# unclean close
proto.transport.loseConnection()
yield proto.is_closed
self.assertEqual([b"hello"], messages)
# FIXME:
# /twisted/util.py", line 162, in transport_channel_id channel_id_type, type(transport)))
# builtins.RuntimeError: cannot determine TLS channel ID of type "tls-unique" when TLS is not
# available on this transport <class 'twisted.test.iosim.FakeTransport'>
# @inlineCallbacks
# def test_secure_echo_server(self):
# class EchoServer(WebSocketServerProtocol):
# def onMessage(self, msg, is_binary):
# self.sendMessage(msg)
# agent = create_memory_agent(self.reactor, self.pumper, EchoServer)
# proto = yield agent.open("wss://localhost:1234/ws", dict())
# messages = []
# def got(msg, is_binary):
# messages.append(msg)
# proto.on("message", got)
# proto.sendMessage(b"hello")
# if True:
# # clean close
# proto.sendClose()
# else:
# # unclean close
# proto.transport.loseConnection()
# yield proto.is_closed
# self.assertEqual([b"hello"], messages)

View File

@@ -0,0 +1,90 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import unittest
from unittest.mock import patch
from zope.interface import implementer
from twisted.internet.interfaces import IReactorTime
@implementer(IReactorTime)
class FakeReactor(object):
"""
This just fakes out enough reactor methods so .run() can work.
"""
stop_called = False
def __init__(self, to_raise):
self.stop_called = False
self.to_raise = to_raise
self.delayed = []
def run(self, *args, **kw):
raise self.to_raise
def stop(self):
self.stop_called = True
def callLater(self, delay, func, *args, **kwargs):
self.delayed.append((delay, func, args, kwargs))
def connectTCP(self, *args, **kw):
raise RuntimeError("ConnectTCP shouldn't get called")
class TestWampTwistedRunner(unittest.TestCase):
# XXX should figure out *why* but the test_protocol timeout
# tests fail if we *don't* patch out this txaio stuff. So,
# presumably it's messing up some global state that both tests
# implicitly depend on ...
@patch('txaio.use_twisted')
@patch('txaio.start_logging')
@patch('txaio.config')
def test_connect_error(self, *args):
"""
Ensure the runner doesn't swallow errors and that it exits the
reactor properly if there is one.
"""
try:
from autobahn.twisted.wamp import ApplicationRunner
from twisted.internet.error import ConnectionRefusedError
# the 'reactor' member doesn't exist until we import it
from twisted.internet import reactor # noqa: F401
except ImportError:
raise unittest.SkipTest('No twisted')
runner = ApplicationRunner('ws://localhost:1', 'realm')
exception = ConnectionRefusedError("It's a trap!")
with patch('twisted.internet.reactor', FakeReactor(exception)) as mockreactor:
self.assertRaises(
ConnectionRefusedError,
# pass a no-op session-creation method
runner.run, lambda _: None, start_reactor=True
)
self.assertTrue(mockreactor.stop_called)

View File

@@ -0,0 +1,293 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
# IHostnameResolver et al. were added in Twisted 17.1.0 .. before
# that, it was IResolverSimple only.
try:
from twisted.internet.interfaces import IHostnameResolver
except ImportError:
raise ImportError(
"Twisted 17.1.0 or later required for autobahn.twisted.testing"
)
from twisted.internet.defer import Deferred
from twisted.internet.address import IPv4Address
from twisted.internet._resolver import HostResolution # "internal" class, but it's simple
from twisted.internet.interfaces import ISSLTransport, IReactorPluggableNameResolver
try:
from twisted.internet.testing import MemoryReactorClock
except ImportError:
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.test import iosim
from zope.interface import directlyProvides, implementer
from autobahn.websocket.interfaces import IWebSocketClientAgent
from autobahn.twisted.websocket import _TwistedWebSocketClientAgent
from autobahn.twisted.websocket import WebSocketServerProtocol
from autobahn.twisted.websocket import WebSocketServerFactory
__all__ = (
'create_pumper',
'create_memory_agent',
'MemoryReactorClockResolver',
)
@implementer(IHostnameResolver)
class _StaticTestResolver(object):
def resolveHostName(self, receiver, hostName, portNumber=0):
"""
Implement IHostnameResolver which always returns 127.0.0.1:31337
"""
resolution = HostResolution(hostName)
receiver.resolutionBegan(resolution)
receiver.addressResolved(
IPv4Address('TCP', '127.0.0.1', 31337 if portNumber == 0 else portNumber)
)
receiver.resolutionComplete()
@implementer(IReactorPluggableNameResolver)
class _TestNameResolver(object):
"""
A test version of IReactorPluggableNameResolver
"""
_resolver = None
@property
def nameResolver(self):
if self._resolver is None:
self._resolver = _StaticTestResolver()
return self._resolver
def installNameResolver(self, resolver):
old = self._resolver
self._resolver = resolver
return old
class MemoryReactorClockResolver(MemoryReactorClock, _TestNameResolver):
"""
Combine MemoryReactor, Clock and an IReactorPluggableNameResolver
together.
"""
pass
class _TwistedWebMemoryAgent(IWebSocketClientAgent):
"""
A testing agent which will hook up an instance of
`server_protocol` for every client that is created via the `open`
API call.
:param reactor: the reactor to use for tests (usually an instance
of MemoryReactorClockResolver)
:param pumper: an implementation IPumper (e.g. as returned by
`create_pumper`)
:param server_protocol: the server-side WebSocket protocol class
to instantiate (e.g. a subclass of `WebSocketServerProtocol`
"""
def __init__(self, reactor, pumper, server_protocol):
self._reactor = reactor
self._server_protocol = server_protocol
self._pumper = pumper
# our "real" underlying agent under test
self._agent = _TwistedWebSocketClientAgent(self._reactor)
self._pumps = set()
self._servers = dict() # client -> server
def open(self, transport_config, options, protocol_class=None):
"""
Implement IWebSocketClientAgent with in-memory transports.
:param transport_config: a string starting with 'wss://' or
'ws://'
:param options: a dict containing options
:param protocol_class: the client protocol class to
instantiate (or `None` for defaults, which is to use
`WebSocketClientProtocol`)
"""
is_secure = transport_config.startswith("wss://")
# call our "real" agent
real_client_protocol = self._agent.open(
transport_config, options,
protocol_class=protocol_class,
)
if is_secure:
host, port, factory, context_factory, timeout, bindAddress = self._reactor.sslClients[-1]
else:
host, port, factory, timeout, bindAddress = self._reactor.tcpClients[-1]
server_address = IPv4Address('TCP', '127.0.0.1', port)
client_address = IPv4Address('TCP', '127.0.0.1', 31337)
server_protocol = self._server_protocol()
# the protocol could already have a factory
if getattr(server_protocol, "factory", None) is None:
server_protocol.factory = WebSocketServerFactory()
server_transport = iosim.FakeTransport(
server_protocol, isServer=True,
hostAddress=server_address, peerAddress=client_address)
clientProtocol = factory.buildProtocol(None)
client_transport = iosim.FakeTransport(
clientProtocol, isServer=False,
hostAddress=client_address, peerAddress=server_address)
if is_secure:
directlyProvides(server_transport, ISSLTransport)
directlyProvides(client_transport, ISSLTransport)
pump = iosim.connect(
server_protocol, server_transport, clientProtocol, client_transport)
self._pumper.add(pump)
def add_mapping(proto):
self._servers[proto] = server_protocol
return proto
real_client_protocol.addCallback(add_mapping)
return real_client_protocol
class _Kalamazoo(object):
"""
Feeling whimsical about class names, see https://en.wikipedia.org/wiki/Handcar
This is 'an IOPump pumper', an object which causes a series of
IOPumps it is monitoring to do their I/O operations
periodically. This needs the 'real' reactor which trial drives,
because reasons:
- so @inlineCallbacks / async-def functions work
(if I could explain exactly why here, I would)
- we need to 'break the loop' of synchronous calls somewhere and
polluting the tests themselves with that is bad
- get rid of e.g. .flush() calls in tests themselves (thus
'teaching' the tests about details of I/O scheduling that they
shouldn't know).
"""
def __init__(self):
self._pumps = set()
self._pumping = False
self._waiting_for_stop = []
from twisted.internet import reactor as global_reactor
self._global_reactor = global_reactor
def add(self, p):
"""
Add a new IOPump. It will be removed when both its client and
server are disconnected.
"""
self._pumps.add(p)
def start(self):
"""
Begin triggering I/O in all IOPump instances we have. We will keep
periodically 'pumping' our IOPumps until `.stop()` is
called. Call from `setUp()` for example.
"""
if self._pumping:
return
self._pumping = True
self._global_reactor.callLater(0, self._pump_once)
def stop(self):
"""
:returns: a Deferred that fires when we have stopped pump()-ing
Call from `tearDown()`, for example.
"""
if self._pumping or len(self._waiting_for_stop):
d = Deferred()
self._waiting_for_stop.append(d)
self._pumping = False
return d
d = Deferred()
d.callback(None)
return d
def _pump_once(self):
"""
flush all data from all our IOPump instances and schedule another
iteration on the global reactor
"""
if self._pumping:
self._flush()
self._global_reactor.callLater(0.1, self._pump_once)
else:
for d in self._waiting_for_stop:
d.callback(None)
self._waiting_for_stop = []
def _flush(self):
"""
Flush all data between pending client/server pairs.
"""
old_pumps = self._pumps
new_pumps = self._pumps = set()
for p in old_pumps:
p.flush()
if p.clientIO.disconnected and p.serverIO.disconnected:
continue
new_pumps.add(p)
def create_pumper():
"""
return a new instance implementing IPumper
"""
return _Kalamazoo()
def create_memory_agent(reactor, pumper, server_protocol):
"""
return a new instance implementing `IWebSocketClientAgent`.
connection attempts will be satisfied by traversing the Upgrade
request path starting at `resource` to find a `WebSocketResource`
and then exchange data between client and server using purely
in-memory buffers.
"""
# Note, we currently don't actually do any "resource traversing"
# and basically accept any path at all to our websocket resource
if server_protocol is None:
server_protocol = WebSocketServerProtocol
return _TwistedWebMemoryAgent(reactor, pumper, server_protocol)

View File

@@ -0,0 +1,304 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import os
import hashlib
import threading
from typing import Optional, Union, Dict, Any
from twisted.internet.defer import Deferred
from twisted.internet.address import IPv4Address, UNIXAddress
from twisted.internet.interfaces import ITransport, IProcessTransport
from autobahn.wamp.types import TransportDetails
try:
from twisted.internet.stdio import PipeAddress
except ImportError:
# stdio.PipeAddress is only avail on Twisted 13.0+
PipeAddress = type(None)
try:
from twisted.internet.address import IPv6Address
_HAS_IPV6 = True
except ImportError:
_HAS_IPV6 = False
IPv6Address = type(None)
try:
from twisted.internet.interfaces import ISSLTransport
from twisted.protocols.tls import TLSMemoryBIOProtocol
from OpenSSL.SSL import Connection
_HAS_TLS = True
except ImportError:
_HAS_TLS = False
__all = (
'sleep',
'peer2str',
'transport_channel_id',
'extract_peer_certificate',
'create_transport_details',
)
def sleep(delay, reactor=None):
"""
Inline sleep for use in co-routines (Twisted ``inlineCallback`` decorated functions).
.. seealso::
* `twisted.internet.defer.inlineCallbacks <http://twistedmatrix.com/documents/current/api/twisted.internet.defer.html#inlineCallbacks>`__
* `twisted.internet.interfaces.IReactorTime <http://twistedmatrix.com/documents/current/api/twisted.internet.interfaces.IReactorTime.html>`__
:param delay: Time to sleep in seconds.
:type delay: float
:param reactor: The Twisted reactor to use.
:type reactor: None or provider of ``IReactorTime``.
"""
if not reactor:
from twisted.internet import reactor
d = Deferred()
reactor.callLater(delay, d.callback, None)
return d
def peer2str(transport: Union[ITransport, IProcessTransport]) -> str:
"""
Return a *peer descriptor* given a Twisted transport, for example:
* ``tcp4:127.0.0.1:52914``: a TCPv4 socket
* ``unix:/tmp/server.sock``: a Unix domain socket
* ``process:142092``: a Pipe originating from a spawning (parent) process
* ``pipe``: a Pipe terminating in a spawned (child) process
:returns: Returns a string representation of the peer of the Twisted transport.
"""
# IMPORTANT: we need to _first_ test for IProcessTransport
if IProcessTransport.providedBy(transport):
# note the PID of the forked process in the peer descriptor
res = "process:{}".format(transport.pid)
elif ITransport.providedBy(transport):
addr: Union[IPv4Address, IPv6Address, UNIXAddress, PipeAddress] = transport.getPeer()
if isinstance(addr, IPv4Address):
res = "tcp4:{0}:{1}".format(addr.host, addr.port)
elif _HAS_IPV6 and isinstance(addr, IPv6Address):
res = "tcp6:{0}:{1}".format(addr.host, addr.port)
elif isinstance(addr, UNIXAddress):
if addr.name:
res = "unix:{0}".format(addr.name)
else:
res = "unix"
elif isinstance(addr, PipeAddress):
# sadly, we don't have a way to get at the PID of the other side of the pipe
# res = "pipe"
res = "process:{0}".format(os.getppid())
else:
# gracefully fallback if we can't map the peer's address
res = "unknown"
else:
# gracefully fallback if we can't map the peer's transport
res = "unknown"
return res
if not _HAS_TLS:
def transport_channel_id(transport: object, is_server: bool, channel_id_type: Optional[str] = None) -> Optional[bytes]:
if channel_id_type is None:
return b'\x00' * 32
else:
raise RuntimeError('cannot determine TLS channel ID of type "{}" when TLS is not available on this system'.format(channel_id_type))
else:
def transport_channel_id(transport: object, is_server: bool, channel_id_type: Optional[str] = None) -> Optional[bytes]:
"""
Return TLS channel ID of WAMP transport of the given TLS channel ID type.
Application-layer user authentication protocols are vulnerable to generic credential forwarding attacks,
where an authentication credential sent by a client C to a server M may then be used by M to impersonate C at
another server S.
To prevent such credential forwarding attacks, modern authentication protocols rely on channel bindings.
For example, WAMP-cryptosign can use the tls-unique channel identifier provided by the TLS layer to strongly
bind authentication credentials to the underlying channel, so that a credential received on one TLS channel
cannot be forwarded on another.
:param transport: The Twisted TLS transport to extract the TLS channel ID from. If the transport isn't
TLS based, and non-empty ``channel_id_type`` is requested, ``None`` will be returned. If the transport
is indeed TLS based, an empty ``channel_id_type`` of ``None`` is requested, 32 NUL bytes will be returned.
:param is_server: Flag indicating that the transport is a server transport.
:param channel_id_type: TLS channel ID type, if set currently only ``"tls-unique"`` is supported.
:returns: The TLS channel ID (32 bytes).
"""
if channel_id_type is None:
return b'\x00' * 32
if channel_id_type not in ['tls-unique']:
raise RuntimeError('invalid TLS channel ID type "{}" requested'.format(channel_id_type))
if not isinstance(transport, TLSMemoryBIOProtocol):
raise RuntimeError(
'cannot determine TLS channel ID of type "{}" when TLS is not available on this transport {}'.format(
channel_id_type, type(transport)))
# get access to the OpenSSL connection underlying the Twisted protocol
# https://twistedmatrix.com/documents/current/api/twisted.protocols.tls.TLSMemoryBIOProtocol.html#getHandle
connection: Connection = transport.getHandle()
assert connection and isinstance(connection, Connection)
# Obtain latest TLS Finished message that we expected from peer, or None if handshake is not completed.
# http://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Connection.get_peer_finished
is_not_resumed = True
if channel_id_type == 'tls-unique':
# see also: https://bugs.python.org/file22646/tls_channel_binding.patch
if is_server != is_not_resumed:
# for routers (=servers) XOR new sessions, the channel ID is based on the TLS Finished message we
# expected to receive from the client: contents of the message or None if the TLS handshake has
# not yet completed.
tls_finished_msg = connection.get_peer_finished()
else:
# for clients XOR resumed sessions, the channel ID is based on the TLS Finished message we sent
# to the router (=server): contents of the message or None if the TLS handshake has not yet completed.
tls_finished_msg = connection.get_finished()
if tls_finished_msg is None:
# this can occur when:
# 1. we made a successful connection (in a TCP sense) but something failed with
# the TLS handshake (e.g. invalid certificate)
# 2. the TLS handshake has not yet completed
return b'\x00' * 32
else:
m = hashlib.sha256()
m.update(tls_finished_msg)
return m.digest()
else:
raise NotImplementedError('should not arrive here (unhandled channel_id_type "{}")'.format(channel_id_type))
if not _HAS_TLS:
def extract_peer_certificate(transport: object) -> Optional[Dict[str, Any]]:
"""
Dummy when no TLS is available.
:param transport: Ignored.
:return: Always return ``None``.
"""
return None
else:
def extract_peer_certificate(transport: TLSMemoryBIOProtocol) -> Optional[Dict[str, Any]]:
"""
Extract TLS x509 client certificate information from a Twisted stream transport, and
return a dict with x509 TLS client certificate information (if the client provided a
TLS client certificate).
:param transport: The secure transport from which to extract the peer certificate (if present).
:returns: If the peer provided a certificate, the parsed certificate information set.
"""
# check if the Twisted transport is a TLSMemoryBIOProtocol
if not (ISSLTransport.providedBy(transport) and hasattr(transport, 'getPeerCertificate')):
return None
cert = transport.getPeerCertificate()
if cert:
# extract x509 name components from an OpenSSL X509Name object
def maybe_bytes(_value):
if isinstance(_value, bytes):
return _value.decode('utf8')
else:
return _value
result = {
'md5': '{}'.format(maybe_bytes(cert.digest('md5'))).upper(),
'sha1': '{}'.format(maybe_bytes(cert.digest('sha1'))).upper(),
'sha256': '{}'.format(maybe_bytes(cert.digest('sha256'))).upper(),
'expired': bool(cert.has_expired()),
'hash': maybe_bytes(cert.subject_name_hash()),
'serial': int(cert.get_serial_number()),
'signature_algorithm': maybe_bytes(cert.get_signature_algorithm()),
'version': int(cert.get_version()),
'not_before': maybe_bytes(cert.get_notBefore()),
'not_after': maybe_bytes(cert.get_notAfter()),
'extensions': []
}
for i in range(cert.get_extension_count()):
ext = cert.get_extension(i)
ext_info = {
'name': '{}'.format(maybe_bytes(ext.get_short_name())),
'value': '{}'.format(maybe_bytes(ext)),
'critical': ext.get_critical() != 0
}
result['extensions'].append(ext_info)
for entity, name in [('subject', cert.get_subject()), ('issuer', cert.get_issuer())]:
result[entity] = {}
for key, value in name.get_components():
key = maybe_bytes(key)
value = maybe_bytes(value)
result[entity]['{}'.format(key).lower()] = '{}'.format(value)
return result
def create_transport_details(transport: Union[ITransport, IProcessTransport], is_server: bool) -> TransportDetails:
"""
Create transport details from Twisted transport.
:param transport: The Twisted transport to extract information from.
:param is_server: Flag indicating whether this transport side is a "server" (as in TCP server).
:return: Transport details object filled with information from the Twisted transport.
"""
peer = peer2str(transport)
own_pid = os.getpid()
if hasattr(threading, 'get_native_id'):
# New in Python 3.8
# https://docs.python.org/3/library/threading.html?highlight=get_native_id#threading.get_native_id
own_tid = threading.get_native_id()
else:
own_tid = threading.get_ident()
own_fd = -1
if _HAS_TLS and ISSLTransport.providedBy(transport):
channel_id = {
# this will only be filled when the TLS opening handshake is complete (!)
'tls-unique': transport_channel_id(transport, is_server, 'tls-unique'),
}
channel_type = TransportDetails.CHANNEL_TYPE_TLS
peer_cert = extract_peer_certificate(transport)
is_secure = True
else:
channel_id = {}
channel_type = TransportDetails.CHANNEL_TYPE_TCP
peer_cert = None
is_secure = False
# FIXME: really set a default (websocket)?
channel_framing = TransportDetails.CHANNEL_FRAMING_WEBSOCKET
td = TransportDetails(channel_type=channel_type, channel_framing=channel_framing, peer=peer,
is_server=is_server, own_pid=own_pid, own_tid=own_tid, own_fd=own_fd,
is_secure=is_secure, channel_id=channel_id, peer_cert=peer_cert)
return td

View File

@@ -0,0 +1,902 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import inspect
import binascii
import random
from typing import Optional, Dict, Any, List, Union
import txaio
from autobahn.websocket.protocol import WebSocketProtocol
txaio.use_twisted() # noqa
from twisted.internet.defer import inlineCallbacks, succeed, Deferred
from twisted.application import service
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
try:
from twisted.internet.ssl import CertificateOptions
except ImportError:
# PyOpenSSL / TLS not available
CertificateOptions = Any
from autobahn.util import public
from autobahn.websocket.util import parse_url as parse_ws_url
from autobahn.rawsocket.util import parse_url as parse_rs_url
from autobahn.twisted.websocket import WampWebSocketClientFactory
from autobahn.twisted.rawsocket import WampRawSocketClientFactory
from autobahn.websocket.compress import PerMessageDeflateOffer, \
PerMessageDeflateResponse, PerMessageDeflateResponseAccept
from autobahn.wamp import protocol, auth
from autobahn.wamp.interfaces import ITransportHandler, ISession, IAuthenticator, ISerializer
from autobahn.wamp.types import ComponentConfig
__all__ = [
'ApplicationSession',
'ApplicationSessionFactory',
'ApplicationRunner',
'Application',
'Service',
# new API
'Session',
# 'run', # should probably move this method to here? instead of component
]
@public
class ApplicationSession(protocol.ApplicationSession):
"""
WAMP application session for Twisted-based applications.
Implements:
* :class:`autobahn.wamp.interfaces.ITransportHandler`
* :class:`autobahn.wamp.interfaces.ISession`
"""
log = txaio.make_logger()
ITransportHandler.register(ApplicationSession)
# ISession.register collides with the abc.ABCMeta.register method
ISession.abc_register(ApplicationSession)
class ApplicationSessionFactory(protocol.ApplicationSessionFactory):
"""
WAMP application session factory for Twisted-based applications.
"""
session: ApplicationSession = ApplicationSession
"""
The application session class this application session factory will use. Defaults to :class:`autobahn.twisted.wamp.ApplicationSession`.
"""
log = txaio.make_logger()
@public
class ApplicationRunner(object):
"""
This class is a convenience tool mainly for development and quick hosting
of WAMP application components.
It can host a WAMP application component in a WAMP-over-WebSocket client
connecting to a WAMP router.
"""
log = txaio.make_logger()
def __init__(self,
url: str,
realm: Optional[str] = None,
extra: Optional[Dict[str, Any]] = None,
serializers: Optional[List[ISerializer]] = None,
ssl: Optional[CertificateOptions] = None,
proxy: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
websocket_options: Optional[Dict[str, Any]] = None,
max_retries: Optional[int] = None,
initial_retry_delay: Optional[float] = None,
max_retry_delay: Optional[float] = None,
retry_delay_growth: Optional[float] = None,
retry_delay_jitter: Optional[float] = None):
"""
:param url: The WebSocket URL of the WAMP router to connect to (e.g. `ws://example.com:8080/mypath`)
:param realm: The WAMP realm to join the application session to.
:param extra: Optional extra configuration to forward to the application component.
:param serializers: A list of WAMP serializers to use (or None for default serializers).
Serializers must implement :class:`autobahn.wamp.interfaces.ISerializer`.
:type serializers: list
:param ssl: (Optional). If specified this should be an
instance suitable to pass as ``sslContextFactory`` to
:class:`twisted.internet.endpoints.SSL4ClientEndpoint`` such
as :class:`twisted.internet.ssl.CertificateOptions`. Leaving
it as ``None`` will use the result of calling Twisted
:meth:`twisted.internet.ssl.platformTrust` which tries to use
your distribution's CA certificates.
:param proxy: Explicit proxy server to use; a dict with ``host`` and ``port`` keys.
:param headers: Additional headers to send (only applies to WAMP-over-WebSocket).
:param websocket_options: Specific WebSocket options to set (only applies to WAMP-over-WebSocket).
If not provided, conservative and practical default are chosen.
:param max_retries: Maximum number of reconnection attempts. Unlimited if set to -1.
:param initial_retry_delay: Initial delay for reconnection attempt in seconds (Default: 1.0s).
:param max_retry_delay: Maximum delay for reconnection attempts in seconds (Default: 60s).
:param retry_delay_growth: The growth factor applied to the retry delay between reconnection
attempts (Default 1.5).
:param retry_delay_jitter: A 0-argument callable that introduces noise into the
delay (Default ``random.random``).
"""
# IMPORTANT: keep this, as it is tested in
# autobahn.twisted.test.test_tx_application_runner.TestApplicationRunner.test_runner_bad_proxy
assert (proxy is None or type(proxy) == dict)
self.url = url
self.realm = realm
self.extra = extra or dict()
self.serializers = serializers
self.ssl = ssl
self.proxy = proxy
self.headers = headers
self.websocket_options = websocket_options
self.max_retries = max_retries
self.initial_retry_delay = initial_retry_delay
self.max_retry_delay = max_retry_delay
self.retry_delay_growth = retry_delay_growth
self.retry_delay_jitter = retry_delay_jitter
# this if for auto-reconnection when Twisted ClientService is avail
self._client_service = None
# total number of successful connections
self._connect_successes = 0
@public
def stop(self):
"""
Stop reconnecting, if auto-reconnecting was enabled.
"""
self.log.debug('{klass}.stop()', klass=self.__class__.__name__)
if self._client_service:
return self._client_service.stopService()
else:
return succeed(None)
@public
def run(self, make, start_reactor: bool = True, auto_reconnect: bool = False,
log_level: str = 'info', endpoint: Optional[IStreamClientEndpoint] = None,
reactor: Optional[IReactorCore] = None) -> Union[type(None), Deferred]:
"""
Run the application component.
:param make: A factory that produces instances of :class:`autobahn.twisted.wamp.ApplicationSession`
when called with an instance of :class:`autobahn.wamp.types.ComponentConfig`.
:param start_reactor: When ``True`` (the default) this method starts
the Twisted reactor and doesn't return until the reactor
stops. If there are any problems starting the reactor or
connect()-ing, we stop the reactor and raise the exception
back to the caller.
:param auto_reconnect:
:param log_level:
:param endpoint:
:param reactor:
:return: None is returned, unless you specify
``start_reactor=False`` in which case the Deferred that
connect() returns is returned; this will callback() with
an IProtocol instance, which will actually be an instance
of :class:`WampWebSocketClientProtocol`
"""
self.log.debug('{klass}.run()', klass=self.__class__.__name__)
if start_reactor:
# only select framework, set loop and start logging when we are asked
# start the reactor - otherwise we are running in a program that likely
# already tool care of all this.
from twisted.internet import reactor
txaio.use_twisted()
txaio.config.loop = reactor
txaio.start_logging(level=log_level)
if callable(make):
# factory for use ApplicationSession
def create():
cfg = ComponentConfig(self.realm, self.extra, runner=self)
try:
session = make(cfg)
except Exception:
self.log.failure('ApplicationSession could not be instantiated: {log_failure.value}')
if start_reactor and reactor.running:
reactor.stop()
raise
else:
return session
else:
create = make
if self.url.startswith('rs'):
# try to parse RawSocket URL
isSecure, host, port = parse_rs_url(self.url)
# use the first configured serializer if any (which means, auto-choose "best")
serializer = self.serializers[0] if self.serializers else None
# create a WAMP-over-RawSocket transport client factory
transport_factory = WampRawSocketClientFactory(create, serializer=serializer)
else:
# try to parse WebSocket URL
isSecure, host, port, resource, path, params = parse_ws_url(self.url)
# create a WAMP-over-WebSocket transport client factory
transport_factory = WampWebSocketClientFactory(create, url=self.url, serializers=self.serializers, proxy=self.proxy, headers=self.headers)
# client WebSocket settings - similar to:
# - http://crossbar.io/docs/WebSocket-Compression/#production-settings
# - http://crossbar.io/docs/WebSocket-Options/#production-settings
# The permessage-deflate extensions offered to the server
offers = [PerMessageDeflateOffer()]
# Function to accept permessage-deflate responses from the server
def accept(response):
if isinstance(response, PerMessageDeflateResponse):
return PerMessageDeflateResponseAccept(response)
# default WebSocket options for all client connections
protocol_options = {
'version': WebSocketProtocol.DEFAULT_SPEC_VERSION,
'utf8validateIncoming': True,
'acceptMaskedServerFrames': False,
'maskClientFrames': True,
'applyMask': True,
'maxFramePayloadSize': 1048576,
'maxMessagePayloadSize': 1048576,
'autoFragmentSize': 65536,
'failByDrop': True,
'echoCloseCodeReason': False,
'serverConnectionDropTimeout': 1.,
'openHandshakeTimeout': 2.5,
'closeHandshakeTimeout': 1.,
'tcpNoDelay': True,
'perMessageCompressionOffers': offers,
'perMessageCompressionAccept': accept,
'autoPingInterval': 10.,
'autoPingTimeout': 5.,
'autoPingSize': 12,
# see: https://github.com/crossbario/autobahn-python/issues/1327 and
# _cancelAutoPingTimeoutCall
'autoPingRestartOnAnyTraffic': True,
}
# let user override above default options
if self.websocket_options:
protocol_options.update(self.websocket_options)
# set websocket protocol options on Autobahn/Twisted protocol factory, from where it will
# be applied for every Autobahn/Twisted protocol instance from the factory
transport_factory.setProtocolOptions(**protocol_options)
# supress pointless log noise
transport_factory.noisy = False
if endpoint:
client = endpoint
else:
# if user passed ssl= but isn't using isSecure, we'll never
# use the ssl argument which makes no sense.
context_factory = None
if self.ssl is not None:
if not isSecure:
raise RuntimeError(
'ssl= argument value passed to %s conflicts with the "ws:" '
'prefix of the url argument. Did you mean to use "wss:"?' %
self.__class__.__name__)
context_factory = self.ssl
elif isSecure:
from twisted.internet.ssl import optionsForClientTLS
context_factory = optionsForClientTLS(host)
from twisted.internet import reactor
if self.proxy is not None:
from twisted.internet.endpoints import TCP4ClientEndpoint
client = TCP4ClientEndpoint(reactor, self.proxy['host'], self.proxy['port'])
transport_factory.contextFactory = context_factory
elif isSecure:
from twisted.internet.endpoints import SSL4ClientEndpoint
assert context_factory is not None
client = SSL4ClientEndpoint(reactor, host, port, context_factory)
else:
from twisted.internet.endpoints import TCP4ClientEndpoint
client = TCP4ClientEndpoint(reactor, host, port)
# as the reactor shuts down, we wish to wait until we've sent
# out our "Goodbye" message; leave() returns a Deferred that
# fires when the transport gets to STATE_CLOSED
def cleanup(proto):
if hasattr(proto, '_session') and proto._session is not None:
if proto._session.is_attached():
return proto._session.leave()
elif proto._session.is_connected():
return proto._session.disconnect()
# when our proto was created and connected, make sure it's cleaned
# up properly later on when the reactor shuts down for whatever reason
def init_proto(proto):
self._connect_successes += 1
reactor.addSystemEventTrigger('before', 'shutdown', cleanup, proto)
return proto
use_service = False
if auto_reconnect:
try:
# since Twisted 16.1.0
from twisted.application.internet import ClientService
from twisted.application.internet import backoffPolicy
use_service = True
except ImportError:
use_service = False
if use_service:
# this code path is automatically reconnecting ..
self.log.debug('using t.a.i.ClientService')
if (self.max_retries is not None or self.initial_retry_delay is not None or self.max_retry_delay is not None or self.retry_delay_growth is not None or self.retry_delay_jitter is not None):
if self.max_retry_delay > 0:
kwargs = {}
def _jitter():
j = 1 if self.retry_delay_jitter is None else self.retry_delay_jitter
return random.random() * j
for key, val in [('initialDelay', self.initial_retry_delay),
('maxDelay', self.max_retry_delay),
('factor', self.retry_delay_growth),
('jitter', _jitter)]:
if val is not None:
kwargs[key] = val
# retry policy that will only try to reconnect if we connected
# successfully at least once before (so it fails on host unreachable etc ..)
def retry(failed_attempts):
if self._connect_successes > 0 and (self.max_retries == -1 or failed_attempts < self.max_retries):
return backoffPolicy(**kwargs)(failed_attempts)
else:
print('hit stop')
self.stop()
return 100000000000000
else:
# immediately reconnect (zero delay)
def retry(_):
return 0
else:
retry = backoffPolicy()
# https://twistedmatrix.com/documents/current/api/twisted.application.internet.ClientService.html
self._client_service = ClientService(client, transport_factory, retryPolicy=retry)
self._client_service.startService()
d = self._client_service.whenConnected()
else:
# this code path is only connecting once!
self.log.debug('using t.i.e.connect()')
d = client.connect(transport_factory)
# if we connect successfully, the arg is a WampWebSocketClientProtocol
d.addCallback(init_proto)
# if the user didn't ask us to start the reactor, then they
# get to deal with any connect errors themselves.
if start_reactor:
# if an error happens in the connect(), we save the underlying
# exception so that after the event-loop exits we can re-raise
# it to the caller.
class ErrorCollector(object):
exception = None
def __call__(self, failure):
self.exception = failure.value
reactor.stop()
connect_error = ErrorCollector()
d.addErrback(connect_error)
# now enter the Twisted reactor loop
reactor.run()
# if the ApplicationSession sets an "error" key on the self.config.extra dictionary, which
# has been set to the self.extra dictionary, extract the Exception from that and re-raise
# it as the very last one (see below) exciting back to the caller of self.run()
app_error = self.extra.get('error', None)
# if we exited due to a connection error, raise that to the caller
if connect_error.exception:
raise connect_error.exception
elif app_error:
raise app_error
else:
# let the caller handle any errors
return d
class _ApplicationSession(ApplicationSession):
"""
WAMP application session class used internally with :class:`autobahn.twisted.app.Application`.
"""
def __init__(self, config, app):
"""
:param config: The component configuration.
:type config: Instance of :class:`autobahn.wamp.types.ComponentConfig`
:param app: The application this session is for.
:type app: Instance of :class:`autobahn.twisted.wamp.Application`.
"""
# noinspection PyArgumentList
ApplicationSession.__init__(self, config)
self.app = app
@inlineCallbacks
def onConnect(self):
"""
Implements :meth:`autobahn.wamp.interfaces.ISession.onConnect`
"""
yield self.app._fire_signal('onconnect')
self.join(self.config.realm)
@inlineCallbacks
def onJoin(self, details):
"""
Implements :meth:`autobahn.wamp.interfaces.ISession.onJoin`
"""
for uri, proc in self.app._procs:
yield self.register(proc, uri)
for uri, handler in self.app._handlers:
yield self.subscribe(handler, uri)
yield self.app._fire_signal('onjoined')
@inlineCallbacks
def onLeave(self, details):
"""
Implements :meth:`autobahn.wamp.interfaces.ISession.onLeave`
"""
yield self.app._fire_signal('onleave')
self.disconnect()
@inlineCallbacks
def onDisconnect(self):
"""
Implements :meth:`autobahn.wamp.interfaces.ISession.onDisconnect`
"""
yield self.app._fire_signal('ondisconnect')
class Application(object):
"""
A WAMP application. The application object provides a simple way of
creating, debugging and running WAMP application components.
"""
log = txaio.make_logger()
def __init__(self, prefix=None):
"""
:param prefix: The application URI prefix to use for procedures and topics,
e.g. ``"com.example.myapp"``.
:type prefix: unicode
"""
self._prefix = prefix
# procedures to be registered once the app session has joined the router/realm
self._procs = []
# event handler to be subscribed once the app session has joined the router/realm
self._handlers = []
# app lifecycle signal handlers
self._signals = {}
# once an app session is connected, this will be here
self.session = None
def __call__(self, config):
"""
Factory creating a WAMP application session for the application.
:param config: Component configuration.
:type config: Instance of :class:`autobahn.wamp.types.ComponentConfig`
:returns: obj -- An object that derives of
:class:`autobahn.twisted.wamp.ApplicationSession`
"""
assert(self.session is None)
self.session = _ApplicationSession(config, self)
return self.session
def run(self, url="ws://localhost:8080/ws", realm="realm1", start_reactor=True):
"""
Run the application.
:param url: The URL of the WAMP router to connect to.
:type url: unicode
:param realm: The realm on the WAMP router to join.
:type realm: unicode
"""
runner = ApplicationRunner(url, realm)
return runner.run(self.__call__, start_reactor)
def register(self, uri=None):
"""
Decorator exposing a function as a remote callable procedure.
The first argument of the decorator should be the URI of the procedure
to register under.
:Example:
.. code-block:: python
@app.register('com.myapp.add2')
def add2(a, b):
return a + b
Above function can then be called remotely over WAMP using the URI `com.myapp.add2`
the function was registered under.
If no URI is given, the URI is constructed from the application URI prefix
and the Python function name.
:Example:
.. code-block:: python
app = Application('com.myapp')
# implicit URI will be 'com.myapp.add2'
@app.register()
def add2(a, b):
return a + b
If the function `yields` (is a co-routine), the `@inlineCallbacks` decorator
will be applied automatically to it. In that case, if you wish to return something,
you should use `returnValue`:
:Example:
.. code-block:: python
from twisted.internet.defer import returnValue
@app.register('com.myapp.add2')
def add2(a, b):
res = yield stuff(a, b)
returnValue(res)
:param uri: The URI of the procedure to register under.
:type uri: unicode
"""
def decorator(func):
if uri:
_uri = uri
else:
assert(self._prefix is not None)
_uri = "{0}.{1}".format(self._prefix, func.__name__)
if inspect.isgeneratorfunction(func):
func = inlineCallbacks(func)
self._procs.append((_uri, func))
return func
return decorator
def subscribe(self, uri=None):
"""
Decorator attaching a function as an event handler.
The first argument of the decorator should be the URI of the topic
to subscribe to. If no URI is given, the URI is constructed from
the application URI prefix and the Python function name.
If the function yield, it will be assumed that it's an asynchronous
process and inlineCallbacks will be applied to it.
:Example:
.. code-block:: python
@app.subscribe('com.myapp.topic1')
def onevent1(x, y):
print("got event on topic1", x, y)
:param uri: The URI of the topic to subscribe to.
:type uri: unicode
"""
def decorator(func):
if uri:
_uri = uri
else:
assert(self._prefix is not None)
_uri = "{0}.{1}".format(self._prefix, func.__name__)
if inspect.isgeneratorfunction(func):
func = inlineCallbacks(func)
self._handlers.append((_uri, func))
return func
return decorator
def signal(self, name):
"""
Decorator attaching a function as handler for application signals.
Signals are local events triggered internally and exposed to the
developer to be able to react to the application lifecycle.
If the function yield, it will be assumed that it's an asynchronous
coroutine and inlineCallbacks will be applied to it.
Current signals :
- `onjoined`: Triggered after the application session has joined the
realm on the router and registered/subscribed all procedures
and event handlers that were setup via decorators.
- `onleave`: Triggered when the application session leaves the realm.
.. code-block:: python
@app.signal('onjoined')
def _():
# do after the app has join a realm
:param name: The name of the signal to watch.
:type name: unicode
"""
def decorator(func):
if inspect.isgeneratorfunction(func):
func = inlineCallbacks(func)
self._signals.setdefault(name, []).append(func)
return func
return decorator
@inlineCallbacks
def _fire_signal(self, name, *args, **kwargs):
"""
Utility method to call all signal handlers for a given signal.
:param name: The signal name.
:type name: str
"""
for handler in self._signals.get(name, []):
try:
# FIXME: what if the signal handler is not a coroutine?
# Why run signal handlers synchronously?
yield handler(*args, **kwargs)
except Exception as e:
# FIXME
self.log.info("Warning: exception in signal handler swallowed: {err}", err=e)
class Service(service.MultiService):
"""
A WAMP application as a twisted service.
The application object provides a simple way of creating, debugging and running WAMP application
components inside a traditional twisted application
This manages application lifecycle of the wamp connection using startService and stopService
Using services also allows to create integration tests that properly terminates their connections
It can host a WAMP application component in a WAMP-over-WebSocket client
connecting to a WAMP router.
"""
factory = WampWebSocketClientFactory
def __init__(self, url, realm, make, extra=None, context_factory=None):
"""
:param url: The WebSocket URL of the WAMP router to connect to (e.g. `ws://somehost.com:8090/somepath`)
:type url: unicode
:param realm: The WAMP realm to join the application session to.
:type realm: unicode
:param make: A factory that produces instances of :class:`autobahn.asyncio.wamp.ApplicationSession`
when called with an instance of :class:`autobahn.wamp.types.ComponentConfig`.
:type make: callable
:param extra: Optional extra configuration to forward to the application component.
:type extra: dict
:param context_factory: optional, only for secure connections. Passed as contextFactory to
the ``listenSSL()`` call; see https://twistedmatrix.com/documents/current/api/twisted.internet.interfaces.IReactorSSL.connectSSL.html
:type context_factory: twisted.internet.ssl.ClientContextFactory or None
You can replace the attribute factory in order to change connectionLost or connectionFailed behaviour.
The factory attribute must return a WampWebSocketClientFactory object
"""
self.url = url
self.realm = realm
self.extra = extra or dict()
self.make = make
self.context_factory = context_factory
service.MultiService.__init__(self)
self.setupService()
def setupService(self):
"""
Setup the application component.
"""
is_secure, host, port, resource, path, params = parse_ws_url(self.url)
# factory for use ApplicationSession
def create():
cfg = ComponentConfig(self.realm, self.extra)
session = self.make(cfg)
return session
# create a WAMP-over-WebSocket transport client factory
transport_factory = self.factory(create, url=self.url)
# setup the client from a Twisted endpoint
if is_secure:
from twisted.application.internet import SSLClient
ctx = self.context_factory
if ctx is None:
from twisted.internet.ssl import optionsForClientTLS
ctx = optionsForClientTLS(host)
client = SSLClient(host, port, transport_factory, contextFactory=ctx)
else:
if self.context_factory is not None:
raise Exception("context_factory specified on non-secure URI")
from twisted.application.internet import TCPClient
client = TCPClient(host, port, transport_factory)
client.setServiceParent(self)
# new API
class Session(protocol._SessionShim):
# XXX these methods are redundant, but put here for possibly
# better clarity; maybe a bad idea.
def on_welcome(self, welcome_msg):
pass
def on_join(self, details):
pass
def on_leave(self, details):
self.disconnect()
def on_connect(self):
self.join(self.config.realm)
def on_disconnect(self):
pass
# experimental authentication API
class AuthCryptoSign(object):
def __init__(self, **kw):
# should put in checkconfig or similar
for key in kw.keys():
if key not in ['authextra', 'authid', 'authrole', 'privkey']:
raise ValueError(
"Unexpected key '{}' for {}".format(key, self.__class__.__name__)
)
for key in ['privkey']:
if key not in kw:
raise ValueError(
"Must provide '{}' for cryptosign".format(key)
)
for key in kw.get('authextra', dict()):
if key not in ['pubkey', 'channel_binding', 'trustroot', 'challenge']:
raise ValueError(
"Unexpected key '{}' in 'authextra'".format(key)
)
from autobahn.wamp.cryptosign import CryptosignKey
self._privkey = CryptosignKey.from_bytes(
binascii.a2b_hex(kw['privkey'])
)
if 'pubkey' in kw.get('authextra', dict()):
pubkey = kw['authextra']['pubkey']
if pubkey != self._privkey.public_key():
raise ValueError(
"Public key doesn't correspond to private key"
)
else:
kw['authextra'] = kw.get('authextra', dict())
kw['authextra']['pubkey'] = self._privkey.public_key()
self._args = kw
def on_challenge(self, session, challenge):
# sign the challenge with our private key.
channel_id_type = self._args['authextra'].get('channel_binding', None)
channel_id = self.transport.transport_details.channel_id.get(channel_id_type, None)
signed_challenge = self._privkey.sign_challenge(challenge, channel_id=channel_id,
channel_id_type=channel_id_type)
return signed_challenge
IAuthenticator.register(AuthCryptoSign)
class AuthWampCra(object):
def __init__(self, **kw):
# should put in checkconfig or similar
for key in kw.keys():
if key not in ['authextra', 'authid', 'authrole', 'secret']:
raise ValueError(
"Unexpected key '{}' for {}".format(key, self.__class__.__name__)
)
for key in ['secret', 'authid']:
if key not in kw:
raise ValueError(
"Must provide '{}' for wampcra".format(key)
)
self._args = kw
self._secret = kw.pop('secret')
if not isinstance(self._secret, str):
self._secret = self._secret.decode('utf8')
def on_challenge(self, session, challenge):
key = self._secret.encode('utf8')
if 'salt' in challenge.extra:
key = auth.derive_key(
key,
challenge.extra['salt'],
challenge.extra['iterations'],
challenge.extra['keylen']
)
signature = auth.compute_wcs(
key,
challenge.extra['challenge'].encode('utf8')
)
return signature.decode('ascii')
IAuthenticator.register(AuthWampCra)

View File

@@ -0,0 +1,907 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
from base64 import b64encode, b64decode
from typing import Optional
from zope.interface import implementer
import txaio
txaio.use_twisted()
import twisted.internet.protocol
from twisted.internet import endpoints
from twisted.internet.interfaces import ITransport
from twisted.internet.error import ConnectionDone, ConnectionAborted, \
ConnectionLost
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.internet.protocol import connectionDone
from autobahn.util import public, hltype, hlval
from autobahn.util import _is_tls_error, _maybe_tls_reason
from autobahn.wamp import websocket
from autobahn.wamp.types import TransportDetails
from autobahn.websocket.types import ConnectionRequest, ConnectionResponse, ConnectionDeny
from autobahn.websocket import protocol
from autobahn.websocket.interfaces import IWebSocketClientAgent
from autobahn.twisted.util import create_transport_details, transport_channel_id
from autobahn.websocket.compress import PerMessageDeflateOffer, \
PerMessageDeflateOfferAccept, \
PerMessageDeflateResponse, \
PerMessageDeflateResponseAccept
__all__ = (
'create_client_agent',
'WebSocketAdapterProtocol',
'WebSocketServerProtocol',
'WebSocketClientProtocol',
'WebSocketAdapterFactory',
'WebSocketServerFactory',
'WebSocketClientFactory',
'WrappingWebSocketAdapter',
'WrappingWebSocketServerProtocol',
'WrappingWebSocketClientProtocol',
'WrappingWebSocketServerFactory',
'WrappingWebSocketClientFactory',
'listenWS',
'connectWS',
'WampWebSocketServerProtocol',
'WampWebSocketServerFactory',
'WampWebSocketClientProtocol',
'WampWebSocketClientFactory',
)
def create_client_agent(reactor):
"""
:returns: an instance implementing IWebSocketClientAgent
"""
return _TwistedWebSocketClientAgent(reactor)
def check_transport_config(transport_config):
"""
raises a ValueError if `transport_config` is invalid
"""
# XXX move me to "autobahn.websocket.util"
if not isinstance(transport_config, str):
raise ValueError(
"'transport_config' must be a string, found {}".format(type(transport_config))
)
# XXX also accept everything Crossbar has in client transport configs? e.g like:
# { "type": "websocket", "endpoint": {"type": "tcp", "host": "example.com", ...}}
# XXX what about TLS options? (the above point would address that too)
if not transport_config.startswith("ws://") and \
not transport_config.startswith("wss://"):
raise ValueError(
"'transport_config' must start with 'ws://' or 'wss://'"
)
return None
def check_client_options(options):
"""
raises a ValueError if `options` is invalid
"""
# XXX move me to "autobahn.websocket.util"
if not isinstance(options, dict):
raise ValueError(
"'options' must be a dict"
)
# anything that WebSocketClientFactory accepts (at least)
valid_keys = [
"origin",
"protocols",
"useragent",
"headers",
"proxy",
]
for actual_k in options.keys():
if actual_k not in valid_keys:
raise ValueError(
"'options' may not contain '{}'".format(actual_k)
)
def _endpoint_from_config(reactor, factory, transport_config, options):
# XXX might want some Crossbar code here? e.g. if we allow
# "transport_config" to be a dict etc.
# ... passing in the Factory is weird, but that's what parses all
# the options and the URL currently
if factory.isSecure:
# create default client SSL context factory when none given
from twisted.internet import ssl
context_factory = ssl.optionsForClientTLS(factory.host)
if factory.proxy is not None:
factory.contextFactory = context_factory
endpoint = endpoints.HostnameEndpoint(
reactor,
factory.proxy['host'],
factory.proxy['port'],
# timeout, option?
)
else:
if factory.isSecure:
from twisted.internet import ssl
endpoint = endpoints.SSL4ClientEndpoint(
reactor,
factory.host,
factory.port,
context_factory,
# timeout, option?
)
else:
endpoint = endpoints.HostnameEndpoint( # XXX right? not TCP4ClientEndpoint
reactor,
factory.host,
factory.port,
# timeout, option?
# attemptDelay, option?
)
return endpoint
class _TwistedWebSocketClientAgent(IWebSocketClientAgent):
"""
This agent creates connections using Twisted
"""
def __init__(self, reactor):
self._reactor = reactor
def open(self, transport_config, options, protocol_class=None):
"""
Open a new connection.
:param dict transport_config: valid transport configuration
:param dict options: additional options for the factory
:param protocol_class: a callable that returns an instance of
the protocol (WebSocketClientProtocol if the default None
is passed in)
:returns: a Deferred that fires with an instance of
`protocol_class` (or WebSocketClientProtocol by default)
that has successfully shaken hands (completed the
handshake).
"""
check_transport_config(transport_config)
check_client_options(options)
factory = WebSocketClientFactory(
url=transport_config,
reactor=self._reactor,
**options
)
factory.protocol = WebSocketClientProtocol if protocol_class is None else protocol_class
# XXX might want "contextFactory" for TLS ...? (or e.g. CA etc options?)
endpoint = _endpoint_from_config(self._reactor, factory, transport_config, options)
rtn_d = Deferred()
proto_d = endpoint.connect(factory)
def failed(f):
rtn_d.errback(f)
def got_proto(proto):
def handshake_completed(arg):
rtn_d.callback(proto)
return arg
proto.is_open.addCallbacks(handshake_completed, failed)
return proto
proto_d.addCallbacks(got_proto, failed)
return rtn_d
class WebSocketAdapterProtocol(twisted.internet.protocol.Protocol):
"""
Adapter class for Twisted WebSocket client and server protocols.
Called from Twisted:
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol.connectionMade`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol.connectionLost`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol.dataReceived`
Called from Network-independent Code (WebSocket implementation):
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onOpen`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onMessageBegin`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onMessageFrameData`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onMessageFrameEnd`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onMessageEnd`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onMessage`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onPing`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onPong`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._onClose`
FIXME:
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._closeConnection`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol._create_transport_details`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol.registerProducer`
* :meth:`autobahn.twisted.websocket.WebSocketAdapterProtocol.unregisterProducer`
"""
log = txaio.make_logger()
peer: Optional[str] = None
is_server: Optional[bool] = None
def connectionMade(self):
# Twisted networking framework entry point, called by Twisted
# when the connection is established (either a client or a server)
# determine preliminary transport details (what is know at this point)
self._transport_details = create_transport_details(self.transport, self.is_server)
self._transport_details.channel_framing = TransportDetails.CHANNEL_FRAMING_WEBSOCKET
# backward compatibility
self.peer = self._transport_details.peer
# try to set "Nagle" option for TCP sockets
try:
self.transport.setTcpNoDelay(self.tcpNoDelay)
except: # don't touch this! does not work: AttributeError, OSError
# eg Unix Domain sockets throw Errno 22 on this
pass
# ok, now forward to the networking framework independent code for websocket
self._connectionMade()
# ok, done!
self.log.debug('{func} connection established for peer="{peer}"',
func=hltype(self.connectionMade),
peer=hlval(self.peer))
def connectionLost(self, reason: Failure = connectionDone):
# Twisted networking framework entry point, called by Twisted
# when the connection is lost (either a client or a server)
was_clean = False
if isinstance(reason.value, ConnectionDone):
self.log.debug("Connection to/from {peer} was closed cleanly",
peer=self.peer)
was_clean = True
elif _is_tls_error(reason.value):
self.log.error(_maybe_tls_reason(reason.value))
elif isinstance(reason.value, ConnectionAborted):
self.log.debug("Connection to/from {peer} was aborted locally",
peer=self.peer)
elif isinstance(reason.value, ConnectionLost):
message = str(reason.value)
if hasattr(reason.value, 'message'):
message = reason.value.message
self.log.debug(
"Connection to/from {peer} was lost in a non-clean fashion: {message}",
peer=self.peer,
message=message,
)
# at least: FileDescriptorOverrun, ConnectionFdescWentAway - but maybe others as well?
else:
self.log.debug("Connection to/from {peer} lost ({error_type}): {error})",
peer=self.peer, error_type=type(reason.value), error=reason.value)
# ok, now forward to the networking framework independent code for websocket
self._connectionLost(reason)
# ok, done!
if was_clean:
self.log.debug('{func} connection lost for peer="{peer}", closed cleanly',
func=hltype(self.connectionLost),
peer=hlval(self.peer))
else:
self.log.debug('{func} connection lost for peer="{peer}", closed with error {reason}',
func=hltype(self.connectionLost),
peer=hlval(self.peer),
reason=reason)
def dataReceived(self, data: bytes):
self.log.debug('{func} received {data_len} bytes for peer="{peer}"',
func=hltype(self.dataReceived),
peer=hlval(self.peer),
data_len=hlval(len(data)))
# bytes received from Twisted, forward to the networking framework independent code for websocket
self._dataReceived(data)
def _closeConnection(self, abort=False):
if abort and hasattr(self.transport, 'abortConnection'):
self.transport.abortConnection()
else:
# e.g. ProcessProtocol lacks abortConnection()
self.transport.loseConnection()
def _onOpen(self):
if self._transport_details.is_secure:
# now that the TLS opening handshake is complete, the actual TLS channel ID
# will be available. make sure to set it!
channel_id = {
'tls-unique': transport_channel_id(self.transport, self._transport_details.is_server, 'tls-unique'),
}
self._transport_details.channel_id = channel_id
self.onOpen()
def _onMessageBegin(self, isBinary):
self.onMessageBegin(isBinary)
def _onMessageFrameBegin(self, length):
self.onMessageFrameBegin(length)
def _onMessageFrameData(self, payload):
self.onMessageFrameData(payload)
def _onMessageFrameEnd(self):
self.onMessageFrameEnd()
def _onMessageFrame(self, payload):
self.onMessageFrame(payload)
def _onMessageEnd(self):
self.onMessageEnd()
def _onMessage(self, payload, isBinary):
self.onMessage(payload, isBinary)
def _onPing(self, payload):
self.onPing(payload)
def _onPong(self, payload):
self.onPong(payload)
def _onClose(self, wasClean, code, reason):
self.onClose(wasClean, code, reason)
def registerProducer(self, producer, streaming):
"""
Register a Twisted producer with this protocol.
:param producer: A Twisted push or pull producer.
:type producer: object
:param streaming: Producer type.
:type streaming: bool
"""
self.transport.registerProducer(producer, streaming)
def unregisterProducer(self):
"""
Unregister Twisted producer with this protocol.
"""
self.transport.unregisterProducer()
@public
class WebSocketServerProtocol(WebSocketAdapterProtocol, protocol.WebSocketServerProtocol):
"""
Base class for Twisted-based WebSocket server protocols.
Implements :class:`autobahn.websocket.interfaces.IWebSocketChannel`.
"""
log = txaio.make_logger()
is_server = True
# def onConnect(self, request: ConnectionRequest) -> Union[Optional[str], Tuple[Optional[str], Dict[str, str]]]:
# pass
@public
class WebSocketClientProtocol(WebSocketAdapterProtocol, protocol.WebSocketClientProtocol):
"""
Base class for Twisted-based WebSocket client protocols.
Implements :class:`autobahn.websocket.interfaces.IWebSocketChannel`.
"""
log = txaio.make_logger()
is_server = False
def _onConnect(self, response: ConnectionResponse):
self.log.debug('{meth}(response={response})', meth=hltype(self._onConnect), response=response)
return self.onConnect(response)
def startTLS(self):
self.log.debug("Starting TLS upgrade")
self.transport.startTLS(self.factory.contextFactory)
class WebSocketAdapterFactory(object):
"""
Adapter class for Twisted-based WebSocket client and server factories.
"""
@public
class WebSocketServerFactory(WebSocketAdapterFactory, protocol.WebSocketServerFactory, twisted.internet.protocol.ServerFactory):
"""
Base class for Twisted-based WebSocket server factories.
Implements :class:`autobahn.websocket.interfaces.IWebSocketServerChannelFactory`
"""
log = txaio.make_logger()
def __init__(self, *args, **kwargs):
"""
.. note::
In addition to all arguments to the constructor of
:meth:`autobahn.websocket.interfaces.IWebSocketServerChannelFactory`,
you can supply a ``reactor`` keyword argument to specify the
Twisted reactor to be used.
"""
# lazy import to avoid reactor install upon module import
reactor = kwargs.pop('reactor', None)
if reactor is None:
from twisted.internet import reactor
self.reactor = reactor
protocol.WebSocketServerFactory.__init__(self, *args, **kwargs)
@public
class WebSocketClientFactory(WebSocketAdapterFactory, protocol.WebSocketClientFactory, twisted.internet.protocol.ClientFactory):
"""
Base class for Twisted-based WebSocket client factories.
Implements :class:`autobahn.websocket.interfaces.IWebSocketClientChannelFactory`
"""
log = txaio.make_logger()
def __init__(self, *args, **kwargs):
"""
.. note::
In addition to all arguments to the constructor of
:func:`autobahn.websocket.interfaces.IWebSocketClientChannelFactory`,
you can supply a ``reactor`` keyword argument to specify the
Twisted reactor to be used.
"""
# lazy import to avoid reactor install upon module import
reactor = kwargs.pop('reactor', None)
if reactor is None:
from twisted.internet import reactor
self.reactor = reactor
protocol.WebSocketClientFactory.__init__(self, *args, **kwargs)
# we must up-call *before* we set up the contextFactory
# because we need self.host etc to be set properly.
if self.isSecure and self.proxy is not None:
# if we have a proxy, then our factory will be used to
# create the connection after CONNECT and if it's doing
# TLS it needs a contextFactory
from twisted.internet import ssl
self.contextFactory = ssl.optionsForClientTLS(self.host)
# NOTE: there's thus no way to send in our own
# context-factory, nor any TLS options.
# Possibly we should allow 'proxy' to contain an actual
# IStreamClientEndpoint instance instead of configuration for
# how to make one
@implementer(ITransport)
class WrappingWebSocketAdapter(object):
"""
An adapter for stream-based transport over WebSocket.
This follows `websockify <https://github.com/kanaka/websockify>`_
and should be compatible with that.
It uses WebSocket subprotocol negotiation and supports the
following WebSocket subprotocols:
- ``binary`` (or a compatible subprotocol)
- ``base64``
Octets are either transmitted as the payload of WebSocket binary
messages when using the ``binary`` subprotocol (or an alternative
binary compatible subprotocol), or encoded with Base64 and then
transmitted as the payload of WebSocket text messages when using
the ``base64`` subprotocol.
"""
def onConnect(self, requestOrResponse):
# Negotiate either the 'binary' or the 'base64' WebSocket subprotocol
if isinstance(requestOrResponse, ConnectionRequest):
request = requestOrResponse
for p in request.protocols:
if p in self.factory._subprotocols:
self._binaryMode = (p != 'base64')
return p
raise ConnectionDeny(ConnectionDeny.NOT_ACCEPTABLE, 'this server only speaks {0} WebSocket subprotocols'.format(self.factory._subprotocols))
elif isinstance(requestOrResponse, ConnectionResponse):
response = requestOrResponse
if response.protocol not in self.factory._subprotocols:
self._fail_connection(protocol.WebSocketProtocol.CLOSE_STATUS_CODE_PROTOCOL_ERROR, 'this client only speaks {0} WebSocket subprotocols'.format(self.factory._subprotocols))
self._binaryMode = (response.protocol != 'base64')
else:
# should not arrive here
raise Exception("logic error")
def onOpen(self):
self._proto.connectionMade()
def onMessage(self, payload, isBinary):
if isBinary != self._binaryMode:
self._fail_connection(protocol.WebSocketProtocol.CLOSE_STATUS_CODE_UNSUPPORTED_DATA, 'message payload type does not match the negotiated subprotocol')
else:
if not isBinary:
try:
payload = b64decode(payload)
except Exception as e:
self._fail_connection(protocol.WebSocketProtocol.CLOSE_STATUS_CODE_INVALID_PAYLOAD, 'message payload base64 decoding error: {0}'.format(e))
self._proto.dataReceived(payload)
# noinspection PyUnusedLocal
def onClose(self, wasClean, code, reason):
self._proto.connectionLost(None)
def write(self, data):
# part of ITransport
assert(type(data) == bytes)
if self._binaryMode:
self.sendMessage(data, isBinary=True)
else:
data = b64encode(data)
self.sendMessage(data, isBinary=False)
def writeSequence(self, data):
# part of ITransport
for d in data:
self.write(d)
def loseConnection(self):
# part of ITransport
self.sendClose()
def getPeer(self):
# part of ITransport
return self.transport.getPeer()
def getHost(self):
# part of ITransport
return self.transport.getHost()
class WrappingWebSocketServerProtocol(WrappingWebSocketAdapter, WebSocketServerProtocol):
"""
Server protocol for stream-based transport over WebSocket.
"""
class WrappingWebSocketClientProtocol(WrappingWebSocketAdapter, WebSocketClientProtocol):
"""
Client protocol for stream-based transport over WebSocket.
"""
class WrappingWebSocketServerFactory(WebSocketServerFactory):
"""
Wrapping server factory for stream-based transport over WebSocket.
"""
def __init__(self,
factory,
url,
reactor=None,
enableCompression=True,
autoFragmentSize=0,
subprotocol=None):
"""
:param factory: Stream-based factory to be wrapped.
:type factory: A subclass of ``twisted.internet.protocol.Factory``
:param url: WebSocket URL of the server this server factory will work for.
:type url: unicode
"""
self._factory = factory
self._subprotocols = ['binary', 'base64']
if subprotocol:
self._subprotocols.append(subprotocol)
WebSocketServerFactory.__init__(self,
url=url,
reactor=reactor,
protocols=self._subprotocols)
# automatically fragment outgoing traffic into WebSocket frames
# of this size
self.setProtocolOptions(autoFragmentSize=autoFragmentSize)
# play nice and perform WS closing handshake
self.setProtocolOptions(failByDrop=False)
if enableCompression:
# Enable WebSocket extension "permessage-deflate".
# Function to accept offers from the client ..
def accept(offers):
for offer in offers:
if isinstance(offer, PerMessageDeflateOffer):
return PerMessageDeflateOfferAccept(offer)
self.setProtocolOptions(perMessageCompressionAccept=accept)
def buildProtocol(self, addr):
proto = WrappingWebSocketServerProtocol()
proto.factory = self
proto._proto = self._factory.buildProtocol(addr)
proto._proto.transport = proto
return proto
def startFactory(self):
self._factory.startFactory()
WebSocketServerFactory.startFactory(self)
def stopFactory(self):
self._factory.stopFactory()
WebSocketServerFactory.stopFactory(self)
class WrappingWebSocketClientFactory(WebSocketClientFactory):
"""
Wrapping client factory for stream-based transport over WebSocket.
"""
def __init__(self,
factory,
url,
reactor=None,
enableCompression=True,
autoFragmentSize=0,
subprotocol=None):
"""
:param factory: Stream-based factory to be wrapped.
:type factory: A subclass of ``twisted.internet.protocol.Factory``
:param url: WebSocket URL of the server this client factory will connect to.
:type url: unicode
"""
self._factory = factory
self._subprotocols = ['binary', 'base64']
if subprotocol:
self._subprotocols.append(subprotocol)
WebSocketClientFactory.__init__(self,
url=url,
reactor=reactor,
protocols=self._subprotocols)
# automatically fragment outgoing traffic into WebSocket frames
# of this size
self.setProtocolOptions(autoFragmentSize=autoFragmentSize)
# play nice and perform WS closing handshake
self.setProtocolOptions(failByDrop=False)
if enableCompression:
# Enable WebSocket extension "permessage-deflate".
# The extensions offered to the server ..
offers = [PerMessageDeflateOffer()]
self.setProtocolOptions(perMessageCompressionOffers=offers)
# Function to accept responses from the server ..
def accept(response):
if isinstance(response, PerMessageDeflateResponse):
return PerMessageDeflateResponseAccept(response)
self.setProtocolOptions(perMessageCompressionAccept=accept)
def buildProtocol(self, addr):
proto = WrappingWebSocketClientProtocol()
proto.factory = self
proto._proto = self._factory.buildProtocol(addr)
proto._proto.transport = proto
return proto
@public
def connectWS(factory, contextFactory=None, timeout=30, bindAddress=None):
"""
Establish WebSocket connection to a server. The connection parameters like target
host, port, resource and others are provided via the factory.
:param factory: The WebSocket protocol factory to be used for creating client protocol instances.
:type factory: An :class:`autobahn.websocket.WebSocketClientFactory` instance.
:param contextFactory: SSL context factory, required for secure WebSocket connections ("wss").
:type contextFactory: A `twisted.internet.ssl.ClientContextFactory <http://twistedmatrix.com/documents/current/api/twisted.internet.ssl.ClientContextFactory.html>`_ instance.
:param timeout: Number of seconds to wait before assuming the connection has failed.
:type timeout: int
:param bindAddress: A (host, port) tuple of local address to bind to, or None.
:type bindAddress: tuple
:returns: The connector.
:rtype: An object which implements `twisted.interface.IConnector <http://twistedmatrix.com/documents/current/api/twisted.internet.interfaces.IConnector.html>`_.
"""
# lazy import to avoid reactor install upon module import
if hasattr(factory, 'reactor'):
reactor = factory.reactor
else:
from twisted.internet import reactor
if factory.isSecure:
if contextFactory is None:
# create default client SSL context factory when none given
from twisted.internet import ssl
contextFactory = ssl.ClientContextFactory()
if factory.proxy is not None:
factory.contextFactory = contextFactory
conn = reactor.connectTCP(factory.proxy['host'], factory.proxy['port'], factory, timeout, bindAddress)
else:
if factory.isSecure:
conn = reactor.connectSSL(factory.host, factory.port, factory, contextFactory, timeout, bindAddress)
else:
conn = reactor.connectTCP(factory.host, factory.port, factory, timeout, bindAddress)
return conn
@public
def listenWS(factory, contextFactory=None, backlog=50, interface=''):
"""
Listen for incoming WebSocket connections from clients. The connection parameters like
listening port and others are provided via the factory.
:param factory: The WebSocket protocol factory to be used for creating server protocol instances.
:type factory: An :class:`autobahn.websocket.WebSocketServerFactory` instance.
:param contextFactory: SSL context factory, required for secure WebSocket connections ("wss").
:type contextFactory: A twisted.internet.ssl.ContextFactory.
:param backlog: Size of the listen queue.
:type backlog: int
:param interface: The interface (derived from hostname given) to bind to, defaults to '' (all).
:type interface: str
:returns: The listening port.
:rtype: An object that implements `twisted.interface.IListeningPort <http://twistedmatrix.com/documents/current/api/twisted.internet.interfaces.IListeningPort.html>`_.
"""
# lazy import to avoid reactor install upon module import
if hasattr(factory, 'reactor'):
reactor = factory.reactor
else:
from twisted.internet import reactor
if factory.isSecure:
if contextFactory is None:
raise Exception("Secure WebSocket listen requested, but no SSL context factory given")
listener = reactor.listenSSL(factory.port, factory, contextFactory, backlog, interface)
else:
listener = reactor.listenTCP(factory.port, factory, backlog, interface)
return listener
@public
class WampWebSocketServerProtocol(websocket.WampWebSocketServerProtocol, WebSocketServerProtocol):
"""
Twisted-based WAMP-over-WebSocket server protocol.
Implements:
* :class:`autobahn.wamp.interfaces.ITransport`
"""
@public
class WampWebSocketServerFactory(websocket.WampWebSocketServerFactory, WebSocketServerFactory):
"""
Twisted-based WAMP-over-WebSocket server protocol factory.
"""
protocol = WampWebSocketServerProtocol
def __init__(self, factory, *args, **kwargs):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializers: A list of WAMP serializers to use (or ``None``
for all available serializers).
:type serializers: list of objects implementing
:class:`autobahn.wamp.interfaces.ISerializer`
"""
serializers = kwargs.pop('serializers', None)
websocket.WampWebSocketServerFactory.__init__(self, factory, serializers)
kwargs['protocols'] = self._protocols
# noinspection PyCallByClass
WebSocketServerFactory.__init__(self, *args, **kwargs)
@public
class WampWebSocketClientProtocol(websocket.WampWebSocketClientProtocol, WebSocketClientProtocol):
"""
Twisted-based WAMP-over-WebSocket client protocol.
Implements:
* :class:`autobahn.wamp.interfaces.ITransport`
"""
@public
class WampWebSocketClientFactory(websocket.WampWebSocketClientFactory, WebSocketClientFactory):
"""
Twisted-based WAMP-over-WebSocket client protocol factory.
"""
protocol = WampWebSocketClientProtocol
def __init__(self, factory, *args, **kwargs):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializer: The WAMP serializer to use (or ``None`` for
"best" serializer, chosen as the first serializer available from
this list: CBOR, MessagePack, UBJSON, JSON).
:type serializer: object implementing :class:`autobahn.wamp.interfaces.ISerializer`
"""
serializers = kwargs.pop('serializers', None)
websocket.WampWebSocketClientFactory.__init__(self, factory, serializers)
kwargs['protocols'] = self._protocols
WebSocketClientFactory.__init__(self, *args, **kwargs)
# Reduce the factory logs noise
self.noisy = False

View File

@@ -0,0 +1,107 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import sys
try:
from autobahn import xbr # noqa
HAS_XBR = True
except ImportError as e:
sys.stderr.write('WARNING: could not import autobahn.xbr - {}\n'.format(e))
HAS_XBR = False
if HAS_XBR:
import txaio
txaio.use_twisted()
from twisted.internet.threads import deferToThread
from twisted.internet.task import LoopingCall
from twisted.internet.defer import ensureDeferred
import uuid
from autobahn.util import hl
from autobahn.xbr._interfaces import IProvider, ISeller, IConsumer, IBuyer, IDelegate
from autobahn.xbr import _seller, _buyer, _blockchain
class SimpleBlockchain(_blockchain.SimpleBlockchain):
log = txaio.make_logger()
backgroundCaller = deferToThread
class KeySeries(_seller.KeySeries):
log = txaio.make_logger()
def __init__(self, api_id, price, interval=None, count=None, on_rotate=None):
super().__init__(api_id, price, interval, count, on_rotate)
self.running = False
self._run_loop = None
self._started = None
async def start(self):
"""
Start offering and selling data encryption keys in the background.
"""
assert self._run_loop is None
self.log.info('Starting key rotation every {interval} seconds for api_id="{api_id}" ..',
interval=hl(self._interval), api_id=hl(uuid.UUID(bytes=self._api_id)))
self.running = True
self._run_loop = LoopingCall(lambda: ensureDeferred(self._rotate()))
self._started = self._run_loop.start(self._interval)
return self._started
def stop(self):
"""
Stop offering/selling data encryption keys.
"""
if not self._run_loop:
raise RuntimeError('cannot stop {} - not currently running'.format(self.__class__.__name__))
self._run_loop.stop()
self._run_loop = None
return self._started
class SimpleSeller(_seller.SimpleSeller):
"""
Simple XBR seller component. This component can be used by a XBR seller delegate to
handle the automated selling of data encryption keys to the XBR market maker.
"""
log = txaio.make_logger()
KeySeries = KeySeries
class SimpleBuyer(_buyer.SimpleBuyer):
log = txaio.make_logger()
ISeller.register(SimpleSeller)
IProvider.register(SimpleSeller)
IDelegate.register(SimpleSeller)
IBuyer.register(SimpleBuyer)
IConsumer.register(SimpleBuyer)
IDelegate.register(SimpleBuyer)