okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 387c4740e7
commit 27f3326e22
10020 changed files with 1935769 additions and 2364 deletions

View File

@@ -0,0 +1,53 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import platform
import autobahn
# WebSocket protocol support
from autobahn.asyncio.websocket import \
WebSocketServerProtocol, \
WebSocketClientProtocol, \
WebSocketServerFactory, \
WebSocketClientFactory
# WAMP support
from autobahn.asyncio.wamp import ApplicationSession
__all__ = (
'WebSocketServerProtocol',
'WebSocketClientProtocol',
'WebSocketServerFactory',
'WebSocketClientFactory',
'ApplicationSession',
)
__ident__ = 'Autobahn/{}-asyncio-{}/{}'.format(autobahn.__version__, platform.python_implementation(), platform.python_version())
"""
AutobahnPython library implementation (eg. "Autobahn/0.13.0-asyncio-CPython/3.5.1")
"""

View File

@@ -0,0 +1,417 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import asyncio
import ssl
import signal
from functools import wraps
import txaio
from autobahn.asyncio.websocket import WampWebSocketClientFactory
from autobahn.asyncio.rawsocket import WampRawSocketClientFactory
from autobahn.wamp import component
from autobahn.wamp.exception import TransportLost
from autobahn.asyncio.wamp import Session
from autobahn.wamp.serializer import create_transport_serializers, create_transport_serializer
__all__ = ('Component', 'run')
def _unique_list(seq):
"""
Return a list with unique elements from sequence, preserving order.
"""
seen = set()
return [x for x in seq if x not in seen and not seen.add(x)]
def _camel_case_from_snake_case(snake):
parts = snake.split('_')
return parts[0] + ''.join(s.capitalize() for s in parts[1:])
def _create_transport_factory(loop, transport, session_factory):
"""
Create a WAMP-over-XXX transport factory.
"""
if transport.type == 'websocket':
serializers = create_transport_serializers(transport)
factory = WampWebSocketClientFactory(
session_factory,
url=transport.url,
serializers=serializers,
proxy=transport.proxy, # either None or a dict with host, port
)
elif transport.type == 'rawsocket':
serializer = create_transport_serializer(transport.serializers[0])
factory = WampRawSocketClientFactory(session_factory, serializer=serializer)
else:
assert(False), 'should not arrive here'
# set the options one at a time so we can give user better feedback
for k, v in transport.options.items():
try:
factory.setProtocolOptions(**{k: v})
except (TypeError, KeyError):
# this allows us to document options as snake_case
# until everything internally is upgraded from
# camelCase
try:
factory.setProtocolOptions(
**{_camel_case_from_snake_case(k): v}
)
except (TypeError, KeyError):
raise ValueError(
"Unknown {} transport option: {}={}".format(transport.type, k, v)
)
return factory
class Component(component.Component):
"""
A component establishes a transport and attached a session
to a realm using the transport for communication.
The transports a component tries to use can be configured,
as well as the auto-reconnect strategy.
"""
log = txaio.make_logger()
session_factory = Session
"""
The factory of the session we will instantiate.
"""
def _is_ssl_error(self, e):
"""
Internal helper.
"""
return isinstance(e, ssl.SSLError)
def _check_native_endpoint(self, endpoint):
if isinstance(endpoint, dict):
if 'tls' in endpoint:
tls = endpoint['tls']
if isinstance(tls, (dict, bool)):
pass
elif isinstance(tls, ssl.SSLContext):
pass
else:
raise ValueError(
"'tls' configuration must be a dict, bool or "
"SSLContext instance"
)
else:
raise ValueError(
"'endpoint' configuration must be a dict or IStreamClientEndpoint"
" provider"
)
# async function
def _connect_transport(self, loop, transport, session_factory, done):
"""
Create and connect a WAMP-over-XXX transport.
"""
factory = _create_transport_factory(loop, transport, session_factory)
# XXX the rest of this should probably be factored into its
# own method (or three!)...
if transport.proxy:
timeout = transport.endpoint.get('timeout', 10) # in seconds
if type(timeout) != int:
raise ValueError('invalid type {} for timeout in client endpoint configuration'.format(type(timeout)))
# do we support HTTPS proxies?
f = loop.create_connection(
protocol_factory=factory,
host=transport.proxy['host'],
port=transport.proxy['port'],
)
time_f = asyncio.ensure_future(asyncio.wait_for(f, timeout=timeout))
return self._wrap_connection_future(transport, done, time_f)
elif transport.endpoint['type'] == 'tcp':
version = transport.endpoint.get('version', 4)
if version not in [4, 6]:
raise ValueError('invalid IP version {} in client endpoint configuration'.format(version))
host = transport.endpoint['host']
if type(host) != str:
raise ValueError('invalid type {} for host in client endpoint configuration'.format(type(host)))
port = transport.endpoint['port']
if type(port) != int:
raise ValueError('invalid type {} for port in client endpoint configuration'.format(type(port)))
timeout = transport.endpoint.get('timeout', 10) # in seconds
if type(timeout) != int:
raise ValueError('invalid type {} for timeout in client endpoint configuration'.format(type(timeout)))
tls = transport.endpoint.get('tls', None)
tls_hostname = None
# create a TLS enabled connecting TCP socket
if tls:
if isinstance(tls, dict):
for k in tls.keys():
if k not in ["hostname", "trust_root"]:
raise ValueError("Invalid key '{}' in 'tls' config".format(k))
hostname = tls.get('hostname', host)
if type(hostname) != str:
raise ValueError('invalid type {} for hostname in TLS client endpoint configuration'.format(hostname))
cert_fname = tls.get('trust_root', None)
tls_hostname = hostname
tls = True
if cert_fname is not None:
tls = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH,
cafile=cert_fname,
)
elif isinstance(tls, ssl.SSLContext):
# tls=<an SSLContext> is valid
tls_hostname = host
elif tls in [False, True]:
if tls:
tls_hostname = host
else:
raise RuntimeError('unknown type {} for "tls" configuration in transport'.format(type(tls)))
f = loop.create_connection(
protocol_factory=factory,
host=host,
port=port,
ssl=tls,
server_hostname=tls_hostname,
)
time_f = asyncio.ensure_future(asyncio.wait_for(f, timeout=timeout))
return self._wrap_connection_future(transport, done, time_f)
elif transport.endpoint['type'] == 'unix':
path = transport.endpoint['path']
timeout = int(transport.endpoint.get('timeout', 10)) # in seconds
f = loop.create_unix_connection(
protocol_factory=factory,
path=path,
)
time_f = asyncio.ensure_future(asyncio.wait_for(f, timeout=timeout))
return self._wrap_connection_future(transport, done, time_f)
else:
assert(False), 'should not arrive here'
def _wrap_connection_future(self, transport, done, conn_f):
def on_connect_success(result):
# async connect call returns a 2-tuple
transport, proto = result
# in the case where we .abort() the transport / connection
# during setup, we still get on_connect_success but our
# transport is already closed (this will happen if
# e.g. there's an "open handshake timeout") -- I don't
# know if there's a "better" way to detect this? #python
# doesn't know of one, anyway
if transport.is_closing():
if not txaio.is_called(done):
reason = getattr(proto, "_onclose_reason", "Connection already closed")
txaio.reject(done, TransportLost(reason))
return
# if e.g. an SSL handshake fails, we will have
# successfully connected (i.e. get here) but need to
# 'listen' for the "connection_lost" from the underlying
# protocol in case of handshake failure .. so we wrap
# it. Also, we don't increment transport.success_count
# here on purpose (because we might not succeed).
# XXX double-check that asyncio behavior on TLS handshake
# failures is in fact as described above
orig = proto.connection_lost
@wraps(orig)
def lost(fail):
rtn = orig(fail)
if not txaio.is_called(done):
# asyncio will call connection_lost(None) in case of
# a transport failure, in which case we create an
# appropriate exception
if fail is None:
fail = TransportLost("failed to complete connection")
txaio.reject(done, fail)
return rtn
proto.connection_lost = lost
def on_connect_failure(err):
transport.connect_failures += 1
# failed to establish a connection in the first place
txaio.reject(done, err)
txaio.add_callbacks(conn_f, on_connect_success, None)
# the errback is added as a second step so it gets called if
# there as an error in on_connect_success itself.
txaio.add_callbacks(conn_f, None, on_connect_failure)
return conn_f
# async function
def start(self, loop=None):
"""
This starts the Component, which means it will start connecting
(and re-connecting) to its configured transports. A Component
runs until it is "done", which means one of:
- There was a "main" function defined, and it completed successfully;
- Something called ``.leave()`` on our session, and we left successfully;
- ``.stop()`` was called, and completed successfully;
- none of our transports were able to connect successfully (failure);
:returns: a Future which will resolve (to ``None``) when we are
"done" or with an error if something went wrong.
"""
if loop is None:
self.log.warn("Using default loop")
loop = asyncio.get_event_loop()
return self._start(loop=loop)
def run(components, start_loop=True, log_level='info'):
"""
High-level API to run a series of components.
This will only return once all the components have stopped
(including, possibly, after all re-connections have failed if you
have re-connections enabled). Under the hood, this calls
XXX fixme for asyncio
-- if you wish to manage the loop yourself, use the
:meth:`autobahn.asyncio.component.Component.start` method to start
each component yourself.
:param components: the Component(s) you wish to run
:type components: instance or list of :class:`autobahn.asyncio.component.Component`
:param start_loop: When ``True`` (the default) this method
start a new asyncio loop.
:type start_loop: bool
:param log_level: a valid log-level (or None to avoid calling start_logging)
:type log_level: string
"""
# actually, should we even let people "not start" the logging? I'm
# not sure that's wise... (double-check: if they already called
# txaio.start_logging() what happens if we call it again?)
if log_level is not None:
txaio.start_logging(level=log_level)
loop = asyncio.get_event_loop()
if loop.is_closed():
asyncio.set_event_loop(asyncio.new_event_loop())
loop = asyncio.get_event_loop()
txaio.config.loop = loop
log = txaio.make_logger()
# see https://github.com/python/asyncio/issues/341 asyncio has
# "odd" handling of KeyboardInterrupt when using Tasks (as
# run_until_complete does). Another option is to just resture
# default SIGINT handling, which is to exit:
# import signal
# signal.signal(signal.SIGINT, signal.SIG_DFL)
async def nicely_exit(signal):
log.info("Shutting down due to {signal}", signal=signal)
try:
tasks = asyncio.Task.all_tasks()
except AttributeError:
# this changed with python >= 3.7
tasks = asyncio.all_tasks()
for task in tasks:
# Do not cancel the current task.
try:
current_task = asyncio.Task.current_task()
except AttributeError:
current_task = asyncio.current_task()
if task is not current_task:
task.cancel()
def cancel_all_callback(fut):
try:
fut.result()
except asyncio.CancelledError:
log.debug("All task cancelled")
except Exception as e:
log.error("Error while shutting down: {exception}", exception=e)
finally:
loop.stop()
fut = asyncio.gather(*tasks)
fut.add_done_callback(cancel_all_callback)
try:
loop.add_signal_handler(signal.SIGINT, lambda: asyncio.ensure_future(nicely_exit("SIGINT")))
loop.add_signal_handler(signal.SIGTERM, lambda: asyncio.ensure_future(nicely_exit("SIGTERM")))
except NotImplementedError:
# signals are not available on Windows
pass
def done_callback(loop, arg):
loop.stop()
# returns a future; could run_until_complete() but see below
component._run(loop, components, done_callback)
if start_loop:
try:
loop.run_forever()
# this is probably more-correct, but then you always get
# "Event loop stopped before Future completed":
# loop.run_until_complete(f)
except asyncio.CancelledError:
pass
# finally:
# signal.signal(signal.SIGINT, signal.SIG_DFL)
# signal.signal(signal.SIGTERM, signal.SIG_DFL)
# Close the event loop at the end, otherwise an exception is
# thrown. https://bugs.python.org/issue23548
loop.close()

View File

@@ -0,0 +1,517 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import asyncio
import struct
import math
import copy
from typing import Optional
import txaio
from autobahn.util import public, _LazyHexFormatter, hltype
from autobahn.wamp.exception import ProtocolError, SerializationError, TransportLost
from autobahn.wamp.types import TransportDetails
from autobahn.asyncio.util import get_serializers, create_transport_details, transport_channel_id
__all__ = (
'WampRawSocketServerProtocol',
'WampRawSocketClientProtocol',
'WampRawSocketServerFactory',
'WampRawSocketClientFactory'
)
FRAME_TYPE_DATA = 0
FRAME_TYPE_PING = 1
FRAME_TYPE_PONG = 2
MAGIC_BYTE = 0x7F
class PrefixProtocol(asyncio.Protocol):
prefix_format = '!L'
prefix_length = struct.calcsize(prefix_format)
max_length = 16 * 1024 * 1024
max_length_send = max_length
log = txaio.make_logger() # @UndefinedVariable
peer: Optional[str] = None
is_server: Optional[bool] = None
@property
def transport_details(self) -> Optional[TransportDetails]:
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.transport_details`
"""
return self._transport_details
def connection_made(self, transport):
# asyncio networking framework entry point, called by asyncio
# when the connection is established (either a client or a server)
self.log.debug('RawSocker Asyncio: Connection made with peer {peer}', peer=self.peer)
self.transport = transport
# determine preliminary transport details (what is known at this point)
self._transport_details = create_transport_details(self.transport, self.is_server)
self._transport_details.channel_framing = TransportDetails.CHANNEL_FRAMING_RAWSOCKET
# backward compatibility
self.peer = self._transport_details.peer
self._buffer = b''
self._header = None
self._wait_closed = txaio.create_future()
@property
def is_closed(self):
if hasattr(self, '_wait_closed'):
return self._wait_closed
else:
f = txaio.create_future()
f.set_result(True)
return f
def connection_lost(self, exc):
self.log.debug('RawSocker Asyncio: Connection lost')
self.transport = None
self._wait_closed.set_result(True)
self._on_connection_lost(exc)
def _on_connection_lost(self, exc):
pass
def protocol_error(self, msg):
self.log.error(msg)
self.transport.close()
def sendString(self, data):
l = len(data)
if l > self.max_length_send:
raise ValueError('Data too big')
header = struct.pack(self.prefix_format, len(data))
self.transport.write(header)
self.transport.write(data)
def ping(self, data):
raise NotImplementedError()
def pong(self, data):
raise NotImplementedError()
def data_received(self, data):
self._buffer += data
pos = 0
remaining = len(self._buffer)
while remaining >= self.prefix_length:
# do not recalculate header if available from previous call
if self._header:
frame_type, frame_length = self._header
else:
header = self._buffer[pos:pos + self.prefix_length]
frame_type = ord(header[0:1]) & 0b00000111
if frame_type > FRAME_TYPE_PONG:
self.protocol_error('Invalid frame type')
return
frame_length = struct.unpack(self.prefix_format, b'\0' + header[1:])[0]
if frame_length > self.max_length:
self.protocol_error('Frame too big')
return
if remaining - self.prefix_length >= frame_length:
self._header = None
pos += self.prefix_length
remaining -= self.prefix_length
data = self._buffer[pos:pos + frame_length]
pos += frame_length
remaining -= frame_length
if frame_type == FRAME_TYPE_DATA:
self.stringReceived(data)
elif frame_type == FRAME_TYPE_PING:
self.ping(data)
elif frame_type == FRAME_TYPE_PONG:
self.pong(data)
else:
# save heaader
self._header = frame_type, frame_length
break
self._buffer = self._buffer[pos:]
def stringReceived(self, data):
raise NotImplementedError()
class RawSocketProtocol(PrefixProtocol):
def __init__(self):
max_size = None
if max_size:
exp = int(math.ceil(math.log(max_size, 2))) - 9
if exp > 15:
raise ValueError('Maximum length is 16M')
self.max_length = 2**(exp + 9)
self._length_exp = exp
else:
self._length_exp = 15
self.max_length = 2**24
def connection_made(self, transport):
PrefixProtocol.connection_made(self, transport)
self._handshake_done = False
def _on_handshake_complete(self):
raise NotImplementedError()
def parse_handshake(self):
buf = bytearray(self._buffer[:4])
if buf[0] != MAGIC_BYTE:
raise HandshakeError('Invalid magic byte in handshake')
ser = buf[1] & 0x0F
lexp = buf[1] >> 4
self.max_length_send = 2**(lexp + 9)
if buf[2] != 0 or buf[3] != 0:
raise HandshakeError('Reserved bytes must be zero')
return ser, lexp
def process_handshake(self):
raise NotImplementedError()
def data_received(self, data):
self.log.debug('RawSocker Asyncio: data received {data}', data=_LazyHexFormatter(data))
if self._handshake_done:
return PrefixProtocol.data_received(self, data)
else:
self._buffer += data
if len(self._buffer) >= 4:
try:
self.process_handshake()
except HandshakeError as e:
self.protocol_error('Handshake error : {err}'.format(err=e))
return
self._handshake_done = True
self._on_handshake_complete()
data = self._buffer[4:]
self._buffer = b''
if data:
PrefixProtocol.data_received(self, data)
ERR_SERIALIZER_UNSUPPORTED = 1
ERRMAP = {
0: "illegal (must not be used)",
1: "serializer unsupported",
2: "maximum message length unacceptable",
3: "use of reserved bits (unsupported feature)",
4: "maximum connection count reached"
}
class HandshakeError(Exception):
def __init__(self, msg, code=0):
Exception.__init__(self, msg if not code else msg + ' : %s' % ERRMAP.get(code))
class RawSocketClientProtocol(RawSocketProtocol):
is_server = False
def check_serializer(self, ser_id):
return True
def process_handshake(self):
ser_id, err = self.parse_handshake()
if ser_id == 0:
raise HandshakeError('Server returned handshake error', err)
if self.serializer_id != ser_id:
raise HandshakeError('Server returned different serializer {0} then requested {1}'
.format(ser_id, self.serializer_id))
@property
def serializer_id(self):
raise NotImplementedError()
def connection_made(self, transport):
RawSocketProtocol.connection_made(self, transport)
# start handshake
hs = bytes(bytearray([MAGIC_BYTE,
self._length_exp << 4 | self.serializer_id,
0, 0]))
transport.write(hs)
self.log.debug('RawSocket Asyncio: Client handshake sent')
class RawSocketServerProtocol(RawSocketProtocol):
is_server = True
def supports_serializer(self, ser_id):
raise NotImplementedError()
def process_handshake(self):
def send_response(lexp, ser_id):
b2 = lexp << 4 | (ser_id & 0x0f)
self.transport.write(bytes(bytearray([MAGIC_BYTE, b2, 0, 0])))
ser_id, _lexp = self.parse_handshake()
if not self.supports_serializer(ser_id):
send_response(ERR_SERIALIZER_UNSUPPORTED, 0)
raise HandshakeError('Serializer unsupported : {ser_id}'.format(ser_id=ser_id))
send_response(self._length_exp, ser_id)
# this is transport independent part of WAMP protocol
class WampRawSocketMixinGeneral(object):
def _on_handshake_complete(self):
self.log.debug("WampRawSocketProtocol: Handshake complete")
# RawSocket connection established. Now let the user WAMP session factory
# create a new WAMP session and fire off session open callback.
try:
if self._transport_details.is_secure:
# now that the TLS opening handshake is complete, the actual TLS channel ID
# will be available. make sure to set it!
channel_id = {
'tls-unique': transport_channel_id(self.transport, self._transport_details.is_server, 'tls-unique'),
}
self._transport_details.channel_id = channel_id
self._session = self.factory._factory()
self._session.onOpen(self)
except Exception as e:
# Exceptions raised in onOpen are fatal ..
self.log.warn("WampRawSocketProtocol: ApplicationSession constructor / onOpen raised ({err})", err=e)
self.abort()
else:
self.log.info("ApplicationSession started.")
def stringReceived(self, payload):
self.log.debug("WampRawSocketProtocol: RX octets: {octets}", octets=_LazyHexFormatter(payload))
try:
for msg in self._serializer.unserialize(payload):
self.log.debug("WampRawSocketProtocol: RX WAMP message: {msg}", msg=msg)
self._session.onMessage(msg)
except ProtocolError as e:
self.log.warn("WampRawSocketProtocol: WAMP Protocol Error ({err}) - aborting connection", err=e)
self.abort()
except Exception as e:
self.log.warn("WampRawSocketProtocol: WAMP Internal Error ({err}) - aborting connection", err=e)
self.abort()
def send(self, msg):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.send`
"""
if self.isOpen():
self.log.debug('{func}: TX WAMP message: {msg}', func=hltype(self.send), msg=msg)
try:
payload, _ = self._serializer.serialize(msg)
except Exception as e:
# all exceptions raised from above should be serialization errors ..
raise SerializationError("WampRawSocketProtocol: unable to serialize WAMP application payload ({0})"
.format(e))
else:
self.sendString(payload)
self.log.debug("WampRawSocketProtocol: TX octets: {octets}", octets=_LazyHexFormatter(payload))
else:
raise TransportLost()
def isOpen(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.isOpen`
"""
return hasattr(self, '_session') and self._session is not None
# this is asyncio dependent part of WAMP protocol
class WampRawSocketMixinAsyncio(object):
"""
Base class for asyncio-based WAMP-over-RawSocket protocols.
"""
def _on_connection_lost(self, exc):
try:
wasClean = exc is None
self._session.onClose(wasClean)
except Exception as e:
# silently ignore exceptions raised here ..
self.log.warn("WampRawSocketProtocol: ApplicationSession.onClose raised ({err})", err=e)
self._session = None
def close(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.close`
"""
if self.isOpen():
self.transport.close()
else:
raise TransportLost()
def abort(self):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.abort`
"""
if self.isOpen():
if hasattr(self.transport, 'abort'):
# ProcessProtocol lacks abortConnection()
self.transport.abort()
else:
self.transport.close()
else:
raise TransportLost()
@public
class WampRawSocketServerProtocol(WampRawSocketMixinGeneral, WampRawSocketMixinAsyncio, RawSocketServerProtocol):
"""
asyncio-based WAMP-over-RawSocket server protocol.
Implements:
* :class:`autobahn.wamp.interfaces.ITransport`
"""
def supports_serializer(self, ser_id):
if ser_id in self.factory._serializers:
self._serializer = copy.copy(self.factory._serializers[ser_id])
self.log.debug(
"WampRawSocketProtocol: client wants to use serializer '{serializer}'",
serializer=ser_id,
)
return True
else:
self.log.debug(
"WampRawSocketProtocol: opening handshake - no suitable serializer found (client requested {serializer}, and we have {serializers}",
serializer=ser_id,
serializers=self.factory._serializers.keys(),
)
self.abort()
return False
@public
class WampRawSocketClientProtocol(WampRawSocketMixinGeneral, WampRawSocketMixinAsyncio, RawSocketClientProtocol):
"""
asyncio-based WAMP-over-RawSocket client protocol.
Implements:
* :class:`autobahn.wamp.interfaces.ITransport`
"""
@property
def serializer_id(self):
if not hasattr(self, '_serializer'):
self._serializer = copy.copy(self.factory._serializer)
return self._serializer.RAWSOCKET_SERIALIZER_ID
class WampRawSocketFactory(object):
"""
Adapter class for asyncio-based WebSocket client and server factories.def dataReceived(self, data):
"""
log = txaio.make_logger()
@public
def __call__(self):
proto = self.protocol()
proto.factory = self
return proto
@public
class WampRawSocketServerFactory(WampRawSocketFactory):
"""
asyncio-based WAMP-over-RawSocket server protocol factory.
"""
protocol = WampRawSocketServerProtocol
def __init__(self, factory, serializers=None):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializers: A list of WAMP serializers to use (or ``None``
for all available serializers).
:type serializers: list of objects implementing
:class:`autobahn.wamp.interfaces.ISerializer`
"""
if callable(factory):
self._factory = factory
else:
self._factory = lambda: factory
# when no serializers were requested specifically, then support
# all that are available
if serializers is None:
serializers = [serializer() for serializer in get_serializers()]
if not serializers:
raise Exception("could not import any WAMP serializers")
self._serializers = {ser.RAWSOCKET_SERIALIZER_ID: ser for ser in serializers}
@public
class WampRawSocketClientFactory(WampRawSocketFactory):
"""
asyncio-based WAMP-over-RawSocket client factory.
"""
protocol = WampRawSocketClientProtocol
def __init__(self, factory, serializer=None):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializer: The WAMP serializer to use (or ``None`` for
"best" serializer, chosen as the first serializer available from
this list: CBOR, MessagePack, UBJSON, JSON).
:type serializer: object implementing :class:`autobahn.wamp.interfaces.ISerializer`
"""
if callable(factory):
self._factory = factory
else:
self._factory = lambda: factory
# when no serializer was requested specifically, use the first
# one available
if serializer is None:
serializers = get_serializers()
if serializers:
serializer = serializers[0]()
if serializer is None:
raise Exception("could not import any WAMP serializer")
self._serializer = serializer

View File

@@ -0,0 +1,26 @@
**DO NOT ADD a __init__.py file in this directory**
"Why not?" you ask; read on!
1. If we're running asyncio tests, we can't ever call txaio.use_twisted()
2. If we're running twisted tests, we can't ever call txaio.use_asycnio()...
3. ...and these are decided/called at import time
4. so: we can't *import* any of the autobahn.asyncio.* modules if we're
running twisted tests (or vice versa)
5. ...but test-runners (py.test and trial) import things automagically
(to "discover" tests)
6. We use py.test to run asyncio tests; see "setup.cfg" where we tell
it "norecursedirs = autobahn/twisted/*" so it doesn't ipmort twisted
stuff (and hence call txaio.use_twisted())
7. We use trial to run twisted tests; the lack of __init__ in here
stops it from trying to import this (and hence the parent
package). (The only files matching test_*.py are in this
directory.)
*Therefore*, we don't put a __init__ file in this directory.

View File

@@ -0,0 +1,235 @@
import pytest
import os
from unittest.mock import Mock, call
from autobahn.asyncio.rawsocket import PrefixProtocol, RawSocketClientProtocol, RawSocketServerProtocol, \
WampRawSocketClientFactory, WampRawSocketServerFactory
from autobahn.asyncio.util import get_serializers
from autobahn.wamp import message
from autobahn.wamp.types import TransportDetails
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_sers(event_loop):
serializers = get_serializers()
assert len(serializers) > 0
m = serializers[0]().serialize(message.Abort('close'))
assert m
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_prefix(event_loop):
p = PrefixProtocol()
transport = Mock()
receiver = Mock()
p.stringReceived = receiver
p.connection_made(transport)
small_msg = b'\x00\x00\x00\x04abcd'
p.data_received(small_msg)
receiver.assert_called_once_with(b'abcd')
assert len(p._buffer) == 0
p.sendString(b'abcd')
# print(transport.write.call_args_list)
transport.write.assert_has_calls([call(b'\x00\x00\x00\x04'), call(b'abcd')])
transport.reset_mock()
receiver.reset_mock()
big_msg = b'\x00\x00\x00\x0C' + b'0123456789AB'
p.data_received(big_msg[0:2])
assert not receiver.called
p.data_received(big_msg[2:6])
assert not receiver.called
p.data_received(big_msg[6:11])
assert not receiver.called
p.data_received(big_msg[11:16])
receiver.assert_called_once_with(b'0123456789AB')
transport.reset_mock()
receiver.reset_mock()
two_messages = b'\x00\x00\x00\x04' + b'abcd' + b'\x00\x00\x00\x05' + b'12345' + b'\x00'
p.data_received(two_messages)
receiver.assert_has_calls([call(b'abcd'), call(b'12345')])
assert p._buffer == b'\x00'
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_is_closed(event_loop):
class CP(RawSocketClientProtocol):
@property
def serializer_id(self):
return 1
client = CP()
on_hs = Mock()
transport = Mock()
receiver = Mock()
client.stringReceived = receiver
client._on_handshake_complete = on_hs
assert client.is_closed.done()
client.connection_made(transport)
assert not client.is_closed.done()
client.connection_lost(None)
assert client.is_closed.done()
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_raw_socket_server1(event_loop):
server = RawSocketServerProtocol()
ser = Mock(return_value=True)
on_hs = Mock()
transport = Mock()
receiver = Mock()
server.supports_serializer = ser
server.stringReceived = receiver
server._on_handshake_complete = on_hs
server.stringReceived = receiver
server.connection_made(transport)
hs = b'\x7F\xF1\x00\x00' + b'\x00\x00\x00\x04abcd'
server.data_received(hs)
ser.assert_called_once_with(1)
on_hs.assert_called_once_with()
assert transport.write.called
transport.write.assert_called_once_with(b'\x7F\xF1\x00\x00')
assert not transport.close.called
receiver.assert_called_once_with(b'abcd')
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_raw_socket_server_errors(event_loop):
server = RawSocketServerProtocol()
ser = Mock(return_value=True)
on_hs = Mock()
transport = Mock()
receiver = Mock()
server.supports_serializer = ser
server.stringReceived = receiver
server._on_handshake_complete = on_hs
server.stringReceived = receiver
server.connection_made(transport)
server.data_received(b'abcdef')
transport.close.assert_called_once_with()
server = RawSocketServerProtocol()
ser = Mock(return_value=False)
on_hs = Mock()
transport = Mock(spec_set=('close', 'write', 'get_extra_info'))
receiver = Mock()
server.supports_serializer = ser
server.stringReceived = receiver
server._on_handshake_complete = on_hs
server.stringReceived = receiver
server.connection_made(transport)
server.data_received(b'\x7F\xF1\x00\x00')
transport.close.assert_called_once_with()
transport.write.assert_called_once_with(b'\x7F\x10\x00\x00')
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_raw_socket_client1(event_loop):
class CP(RawSocketClientProtocol):
@property
def serializer_id(self):
return 1
client = CP()
on_hs = Mock()
transport = Mock()
receiver = Mock()
client.stringReceived = receiver
client._on_handshake_complete = on_hs
client.connection_made(transport)
client.data_received(b'\x7F\xF1\x00\x00' + b'\x00\x00\x00\x04abcd')
on_hs.assert_called_once_with()
assert transport.write.called
transport.write.called_one_with(b'\x7F\xF1\x00\x00')
assert not transport.close.called
receiver.assert_called_once_with(b'abcd')
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_raw_socket_client_error(event_loop):
class CP(RawSocketClientProtocol):
@property
def serializer_id(self):
return 1
client = CP()
on_hs = Mock()
transport = Mock(spec_set=('close', 'write', 'get_extra_info'))
receiver = Mock()
client.stringReceived = receiver
client._on_handshake_complete = on_hs
client.connection_made(transport)
client.data_received(b'\x7F\xF1\x00\x01')
transport.close.assert_called_once_with()
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_wamp_server(event_loop):
transport = Mock(spec_set=('abort', 'close', 'write', 'get_extra_info'))
transport.write = Mock(side_effect=lambda m: messages.append(m))
server = Mock(spec=['onOpen', 'onMessage'])
def fact_server():
return server
messages = []
proto = WampRawSocketServerFactory(fact_server)()
proto.connection_made(transport)
assert proto.transport_details.is_server is True
assert proto.transport_details.channel_framing == TransportDetails.CHANNEL_FRAMING_RAWSOCKET
assert proto.factory._serializers
s = proto.factory._serializers[1].RAWSOCKET_SERIALIZER_ID
proto.data_received(bytes(bytearray([0x7F, 0xF0 | s, 0, 0])))
assert proto._serializer
server.onOpen.assert_called_once_with(proto)
proto.send(message.Abort('close'))
for d in messages[1:]:
proto.data_received(d)
assert server.onMessage.called
assert isinstance(server.onMessage.call_args[0][0], message.Abort)
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_wamp_client(event_loop):
transport = Mock(spec_set=('abort', 'close', 'write', 'get_extra_info'))
transport.write = Mock(side_effect=lambda m: messages.append(m))
client = Mock(spec=['onOpen', 'onMessage'])
def fact_client():
return client
messages = []
proto = WampRawSocketClientFactory(fact_client)()
proto.connection_made(transport)
assert proto.transport_details.is_server is False
assert proto.transport_details.channel_framing == TransportDetails.CHANNEL_FRAMING_RAWSOCKET
assert proto._serializer
s = proto._serializer.RAWSOCKET_SERIALIZER_ID
proto.data_received(bytes(bytearray([0x7F, 0xF0 | s, 0, 0])))
client.onOpen.assert_called_once_with(proto)
proto.send(message.Abort('close'))
for d in messages[1:]:
proto.data_received(d)
assert client.onMessage.called
assert isinstance(client.onMessage.call_args[0][0], message.Abort)

View File

@@ -0,0 +1,71 @@
import os
import asyncio
import pytest
import txaio
# because py.test tries to collect it as a test-case
from unittest.mock import Mock
from autobahn.asyncio.websocket import WebSocketServerFactory
async def echo_async(what, when):
await asyncio.sleep(when)
return what
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
@pytest.mark.asyncio
async def test_echo_async():
assert 'Hello!' == await echo_async('Hello!', 0)
# @pytest.mark.asyncio(forbid_global_loop=True)
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_websocket_custom_loop(event_loop):
factory = WebSocketServerFactory(loop=event_loop)
server = factory()
transport = Mock()
server.connection_made(transport)
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
@pytest.mark.asyncio
async def test_async_on_connect_server(event_loop):
num = 42
done = txaio.create_future()
values = []
async def foo(x):
await asyncio.sleep(1)
return x * x
async def on_connect(req):
v = await foo(num)
values.append(v)
txaio.resolve(done, req)
factory = WebSocketServerFactory()
server = factory()
server.onConnect = on_connect
transport = Mock()
server.connection_made(transport)
server.data = b'\r\n'.join([
b'GET /ws HTTP/1.1',
b'Host: www.example.com',
b'Sec-WebSocket-Version: 13',
b'Origin: http://www.example.com.malicious.com',
b'Sec-WebSocket-Extensions: permessage-deflate',
b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
b'Connection: keep-alive, Upgrade',
b'Upgrade: websocket',
b'\r\n', # last string doesn't get a \r\n from join()
])
server.processHandshake()
await done
assert len(values) == 1
assert values[0] == num * num

View File

@@ -0,0 +1,134 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import unittest
from txaio.testutil import replace_loop
import asyncio
from unittest.mock import patch, Mock
from autobahn.asyncio.wamp import ApplicationRunner
class TestApplicationRunner(unittest.TestCase):
"""
Test the autobahn.asyncio.wamp.ApplicationRunner class.
"""
def _assertRaisesRegex(self, exception, error, *args, **kw):
try:
self.assertRaisesRegex
except AttributeError:
f = self.assertRaisesRegexp
else:
f = self.assertRaisesRegex
f(exception, error, *args, **kw)
def test_explicit_SSLContext(self):
"""
Ensure that loop.create_connection is called with the exact SSL
context object that is passed (as ssl) to the __init__ method of
ApplicationRunner.
"""
with replace_loop(Mock()) as loop:
with patch.object(asyncio, 'get_event_loop', return_value=loop):
loop.run_until_complete = Mock(return_value=(Mock(), Mock()))
ssl = {}
runner = ApplicationRunner('ws://127.0.0.1:8080/ws', 'realm',
ssl=ssl)
runner.run('_unused_')
self.assertIs(ssl, loop.create_connection.call_args[1]['ssl'])
def test_omitted_SSLContext_insecure(self):
"""
Ensure that loop.create_connection is called with ssl=False
if no ssl argument is passed to the __init__ method of
ApplicationRunner and the websocket URL starts with "ws:".
"""
with replace_loop(Mock()) as loop:
with patch.object(asyncio, 'get_event_loop', return_value=loop):
loop.run_until_complete = Mock(return_value=(Mock(), Mock()))
runner = ApplicationRunner('ws://127.0.0.1:8080/ws', 'realm')
runner.run('_unused_')
self.assertIs(False, loop.create_connection.call_args[1]['ssl'])
def test_omitted_SSLContext_secure(self):
"""
Ensure that loop.create_connection is called with ssl=True
if no ssl argument is passed to the __init__ method of
ApplicationRunner and the websocket URL starts with "wss:".
"""
with replace_loop(Mock()) as loop:
with patch.object(asyncio, 'get_event_loop', return_value=loop):
loop.run_until_complete = Mock(return_value=(Mock(), Mock()))
runner = ApplicationRunner('wss://127.0.0.1:8080/wss', 'realm')
runner.run(self.fail)
self.assertIs(True, loop.create_connection.call_args[1]['ssl'])
def test_conflict_SSL_True_with_ws_url(self):
"""
ApplicationRunner must raise an exception if given an ssl value of True
but only a "ws:" URL.
"""
with replace_loop(Mock()) as loop:
loop.run_until_complete = Mock(return_value=(Mock(), Mock()))
runner = ApplicationRunner('ws://127.0.0.1:8080/wss', 'realm',
ssl=True)
error = (r'^ssl argument value passed to ApplicationRunner '
r'conflicts with the "ws:" prefix of the url '
r'argument\. Did you mean to use "wss:"\?$')
self._assertRaisesRegex(Exception, error, runner.run, '_unused_')
def test_conflict_SSLContext_with_ws_url(self):
"""
ApplicationRunner must raise an exception if given an ssl value that is
an instance of SSLContext, but only a "ws:" URL.
"""
import ssl
try:
# Try to create an SSLContext, to be as rigorous as we can be
# by avoiding making assumptions about the ApplicationRunner
# implementation. If we happen to be on a Python that has no
# SSLContext, we pass ssl=True, which will simply cause this
# test to degenerate to the behavior of
# test_conflict_SSL_True_with_ws_url (above). In fact, at the
# moment (2015-05-10), none of this matters because the
# ApplicationRunner implementation does not check to require
# that its ssl argument is either a bool or an SSLContext. But
# that may change, so we should be careful.
ssl.create_default_context
except AttributeError:
context = True
else:
context = ssl.create_default_context()
with replace_loop(Mock()) as loop:
loop.run_until_complete = Mock(return_value=(Mock(), Mock()))
runner = ApplicationRunner('ws://127.0.0.1:8080/wss', 'realm',
ssl=context)
error = (r'^ssl argument value passed to ApplicationRunner '
r'conflicts with the "ws:" prefix of the url '
r'argument\. Did you mean to use "wss:"\?$')
self._assertRaisesRegex(Exception, error, runner.run, '_unused_')

View File

@@ -0,0 +1,147 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import hashlib
from subprocess import Popen
from typing import Optional
import asyncio
from asyncio import sleep # noqa
from autobahn.wamp.types import TransportDetails
__all = (
'sleep',
'peer2str',
'transport_channel_id',
'create_transport_details',
)
def transport_channel_id(transport, is_server: bool, channel_id_type: Optional[str] = None) -> bytes:
"""
Application-layer user authentication protocols are vulnerable to generic
credential forwarding attacks, where an authentication credential sent by
a client C to a server M may then be used by M to impersonate C at another
server S. To prevent such credential forwarding attacks, modern authentication
protocols rely on channel bindings. For example, WAMP-cryptosign can use
the tls-unique channel identifier provided by the TLS layer to strongly bind
authentication credentials to the underlying channel, so that a credential
received on one TLS channel cannot be forwarded on another.
:param transport: The asyncio TLS transport to extract the TLS channel ID from.
:param is_server: Flag indicating the transport is for a server.
:param channel_id_type: TLS channel ID type, currently only "tls-unique" is supported.
:returns: The TLS channel id (32 bytes).
"""
if channel_id_type is None:
return b'\x00' * 32
# ssl.CHANNEL_BINDING_TYPES
if channel_id_type not in ['tls-unique']:
raise Exception("invalid channel ID type {}".format(channel_id_type))
ssl_obj = transport.get_extra_info('ssl_object')
if ssl_obj is None:
raise Exception("TLS transport channel_id for tls-unique requested, but ssl_obj not found on transport")
if not hasattr(ssl_obj, 'get_channel_binding'):
raise Exception("TLS transport channel_id for tls-unique requested, but get_channel_binding not found on ssl_obj")
# https://python.readthedocs.io/en/latest/library/ssl.html#ssl.SSLSocket.get_channel_binding
# https://tools.ietf.org/html/rfc5929.html
tls_finished_msg: bytes = ssl_obj.get_channel_binding(cb_type='tls-unique')
if type(tls_finished_msg) != bytes:
return b'\x00' * 32
else:
m = hashlib.sha256()
m.update(tls_finished_msg)
channel_id = m.digest()
return channel_id
def peer2str(transport: asyncio.transports.BaseTransport) -> str:
# https://docs.python.org/3.9/library/asyncio-protocol.html?highlight=get_extra_info#asyncio.BaseTransport.get_extra_info
# https://docs.python.org/3.9/library/socket.html#socket.socket.getpeername
try:
peer = transport.get_extra_info('peername')
if isinstance(peer, tuple):
ip_ver = 4 if len(peer) == 2 else 6
return "tcp{2}:{0}:{1}".format(peer[0], peer[1], ip_ver)
elif isinstance(peer, str):
return "unix:{0}".format(peer)
else:
return "?:{0}".format(peer)
except:
pass
try:
proc: Popen = transport.get_extra_info('subprocess')
# return 'process:{}'.format(transport.pid)
return 'process:{}'.format(proc.pid)
except:
pass
try:
pipe = transport.get_extra_info('pipe')
return 'pipe:{}'.format(pipe)
except:
pass
# gracefully fallback if we can't map the peer's transport
return 'unknown'
def get_serializers():
from autobahn.wamp import serializer
serializers = ['CBORSerializer', 'MsgPackSerializer', 'UBJSONSerializer', 'JsonSerializer']
serializers = list(filter(lambda x: x, map(lambda s: getattr(serializer, s) if hasattr(serializer, s)
else None, serializers)))
return serializers
def create_transport_details(transport, is_server: bool) -> TransportDetails:
# Internal helper. Base class calls this to create a TransportDetails
peer = peer2str(transport)
# https://docs.python.org/3.9/library/asyncio-protocol.html?highlight=get_extra_info#asyncio.BaseTransport.get_extra_info
is_secure = transport.get_extra_info('peercert', None) is not None
if is_secure:
channel_id = {
'tls-unique': transport_channel_id(transport, is_server, 'tls-unique'),
}
channel_type = TransportDetails.CHANNEL_TYPE_TLS
peer_cert = None
else:
channel_id = {}
channel_type = TransportDetails.CHANNEL_TYPE_TCP
peer_cert = None
channel_framing = TransportDetails.CHANNEL_FRAMING_WEBSOCKET
return TransportDetails(channel_type=channel_type, channel_framing=channel_framing,
peer=peer, is_server=is_server, is_secure=is_secure,
channel_id=channel_id, peer_cert=peer_cert)

View File

@@ -0,0 +1,309 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import asyncio
import signal
import txaio
txaio.use_asyncio() # noqa
from autobahn.util import public
from autobahn.wamp import protocol
from autobahn.wamp.types import ComponentConfig
from autobahn.websocket.util import parse_url as parse_ws_url
from autobahn.rawsocket.util import parse_url as parse_rs_url
from autobahn.asyncio.websocket import WampWebSocketClientFactory
from autobahn.asyncio.rawsocket import WampRawSocketClientFactory
from autobahn.websocket.compress import PerMessageDeflateOffer, \
PerMessageDeflateResponse, PerMessageDeflateResponseAccept
from autobahn.wamp.interfaces import ITransportHandler, ISession
__all__ = (
'ApplicationSession',
'ApplicationSessionFactory',
'ApplicationRunner'
)
@public
class ApplicationSession(protocol.ApplicationSession):
"""
WAMP application session for asyncio-based applications.
Implements:
* ``autobahn.wamp.interfaces.ITransportHandler``
* ``autobahn.wamp.interfaces.ISession``
"""
log = txaio.make_logger()
ITransportHandler.register(ApplicationSession)
# ISession.register collides with the abc.ABCMeta.register method
ISession.abc_register(ApplicationSession)
class ApplicationSessionFactory(protocol.ApplicationSessionFactory):
"""
WAMP application session factory for asyncio-based applications.
"""
session: ApplicationSession = ApplicationSession
"""
The application session class this application session factory will use.
Defaults to :class:`autobahn.asyncio.wamp.ApplicationSession`.
"""
log = txaio.make_logger()
@public
class ApplicationRunner(object):
"""
This class is a convenience tool mainly for development and quick hosting
of WAMP application components.
It can host a WAMP application component in a WAMP-over-WebSocket client
connecting to a WAMP router.
"""
log = txaio.make_logger()
def __init__(self,
url,
realm=None,
extra=None,
serializers=None,
ssl=None,
proxy=None,
headers=None):
"""
:param url: The WebSocket URL of the WAMP router to connect to (e.g. `ws://somehost.com:8090/somepath`)
:type url: str
:param realm: The WAMP realm to join the application session to.
:type realm: str
:param extra: Optional extra configuration to forward to the application component.
:type extra: dict
:param serializers: A list of WAMP serializers to use (or None for default serializers).
Serializers must implement :class:`autobahn.wamp.interfaces.ISerializer`.
:type serializers: list
:param ssl: An (optional) SSL context instance or a bool. See
the documentation for the `loop.create_connection` asyncio
method, to which this value is passed as the ``ssl``
keyword parameter.
:type ssl: :class:`ssl.SSLContext` or bool
:param proxy: Explicit proxy server to use; a dict with ``host`` and ``port`` keys
:type proxy: dict or None
:param headers: Additional headers to send (only applies to WAMP-over-WebSocket).
:type headers: dict
"""
assert(type(url) == str)
assert(realm is None or type(realm) == str)
assert(extra is None or type(extra) == dict)
assert(headers is None or type(headers) == dict)
assert(proxy is None or type(proxy) == dict)
self.url = url
self.realm = realm
self.extra = extra or dict()
self.serializers = serializers
self.ssl = ssl
self.proxy = proxy
self.headers = headers
@public
def stop(self):
"""
Stop reconnecting, if auto-reconnecting was enabled.
"""
raise NotImplementedError()
@public
def run(self, make, start_loop=True, log_level='info'):
"""
Run the application component. Under the hood, this runs the event
loop (unless `start_loop=False` is passed) so won't return
until the program is done.
:param make: A factory that produces instances of :class:`autobahn.asyncio.wamp.ApplicationSession`
when called with an instance of :class:`autobahn.wamp.types.ComponentConfig`.
:type make: callable
:param start_loop: When ``True`` (the default) this method
start a new asyncio loop.
:type start_loop: bool
:returns: None is returned, unless you specify
`start_loop=False` in which case the coroutine from calling
`loop.create_connection()` is returned. This will yield the
(transport, protocol) pair.
"""
if callable(make):
def create():
cfg = ComponentConfig(self.realm, self.extra)
try:
session = make(cfg)
except Exception as e:
self.log.error('ApplicationSession could not be instantiated: {}'.format(e))
loop = asyncio.get_event_loop()
if loop.is_running():
loop.stop()
raise
else:
return session
else:
create = make
if self.url.startswith('rs'):
# try to parse RawSocket URL ..
isSecure, host, port = parse_rs_url(self.url)
# use the first configured serializer if any (which means, auto-choose "best")
serializer = self.serializers[0] if self.serializers else None
# create a WAMP-over-RawSocket transport client factory
transport_factory = WampRawSocketClientFactory(create, serializer=serializer)
else:
# try to parse WebSocket URL ..
isSecure, host, port, resource, path, params = parse_ws_url(self.url)
# create a WAMP-over-WebSocket transport client factory
transport_factory = WampWebSocketClientFactory(create, url=self.url, serializers=self.serializers, proxy=self.proxy, headers=self.headers)
# client WebSocket settings - similar to:
# - http://crossbar.io/docs/WebSocket-Compression/#production-settings
# - http://crossbar.io/docs/WebSocket-Options/#production-settings
# The permessage-deflate extensions offered to the server ..
offers = [PerMessageDeflateOffer()]
# Function to accept permessage_delate responses from the server ..
def accept(response):
if isinstance(response, PerMessageDeflateResponse):
return PerMessageDeflateResponseAccept(response)
# set WebSocket options for all client connections
transport_factory.setProtocolOptions(maxFramePayloadSize=1048576,
maxMessagePayloadSize=1048576,
autoFragmentSize=65536,
failByDrop=False,
openHandshakeTimeout=2.5,
closeHandshakeTimeout=1.,
tcpNoDelay=True,
autoPingInterval=10.,
autoPingTimeout=5.,
autoPingSize=12,
perMessageCompressionOffers=offers,
perMessageCompressionAccept=accept)
# SSL context for client connection
if self.ssl is None:
ssl = isSecure
else:
if self.ssl and not isSecure:
raise RuntimeError(
'ssl argument value passed to %s conflicts with the "ws:" '
'prefix of the url argument. Did you mean to use "wss:"?' %
self.__class__.__name__)
ssl = self.ssl
# start the client connection
loop = asyncio.get_event_loop()
if loop.is_closed() and start_loop:
asyncio.set_event_loop(asyncio.new_event_loop())
loop = asyncio.get_event_loop()
if hasattr(transport_factory, 'loop'):
transport_factory.loop = loop
# assure we are using asyncio
# txaio.use_asyncio()
assert txaio._explicit_framework == 'asyncio'
txaio.config.loop = loop
coro = loop.create_connection(transport_factory, host, port, ssl=ssl)
# start a asyncio loop
if not start_loop:
return coro
else:
(transport, protocol) = loop.run_until_complete(coro)
# start logging
txaio.start_logging(level=log_level)
try:
loop.add_signal_handler(signal.SIGTERM, loop.stop)
except NotImplementedError:
# signals are not available on Windows
pass
# 4) now enter the asyncio event loop
try:
loop.run_forever()
except KeyboardInterrupt:
# wait until we send Goodbye if user hit ctrl-c
# (done outside this except so SIGTERM gets the same handling)
pass
# give Goodbye message a chance to go through, if we still
# have an active session
if protocol._session:
loop.run_until_complete(protocol._session.leave())
loop.close()
# new API
class Session(protocol._SessionShim):
# XXX these methods are redundant, but put here for possibly
# better clarity; maybe a bad idea.
def on_welcome(self, welcome_msg):
pass
def on_join(self, details):
pass
def on_leave(self, details):
self.disconnect()
def on_connect(self):
self.join(self.config.realm)
def on_disconnect(self):
pass

View File

@@ -0,0 +1,386 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
import asyncio
from asyncio import iscoroutine
from asyncio import Future
from collections import deque
from typing import Optional
import txaio
txaio.use_asyncio() # noqa
from autobahn.util import public, hltype
from autobahn.asyncio.util import create_transport_details, transport_channel_id
from autobahn.wamp import websocket
from autobahn.websocket import protocol
__all__ = (
'WebSocketServerProtocol',
'WebSocketClientProtocol',
'WebSocketServerFactory',
'WebSocketClientFactory',
'WampWebSocketServerProtocol',
'WampWebSocketClientProtocol',
'WampWebSocketServerFactory',
'WampWebSocketClientFactory',
)
def yields(value):
"""
Returns ``True`` iff the value yields.
.. seealso:: http://stackoverflow.com/questions/20730248/maybedeferred-analog-with-asyncio
"""
return isinstance(value, Future) or iscoroutine(value)
class WebSocketAdapterProtocol(asyncio.Protocol):
"""
Adapter class for asyncio-based WebSocket client and server protocols.
"""
log = txaio.make_logger()
peer: Optional[str] = None
is_server: Optional[bool] = None
def connection_made(self, transport):
# asyncio networking framework entry point, called by asyncio
# when the connection is established (either a client or a server)
self.log.debug('{func}(transport={transport})', func=hltype(self.connection_made),
transport=transport)
self.transport = transport
# determine preliminary transport details (what is know at this point)
self._transport_details = create_transport_details(self.transport, self.is_server)
# backward compatibility
self.peer = self._transport_details.peer
self.receive_queue = deque()
self._consume()
self._connectionMade()
def connection_lost(self, exc):
self._connectionLost(exc)
# according to asyncio docs, connection_lost(None) is called
# if something else called transport.close()
if exc is not None:
self.transport.close()
self.transport = None
def _consume(self):
self.waiter = Future(loop=self.factory.loop or txaio.config.loop)
def process(_):
while self.receive_queue:
data = self.receive_queue.popleft()
if self.transport:
self._dataReceived(data)
self._consume()
self.waiter.add_done_callback(process)
def data_received(self, data):
self.receive_queue.append(data)
if not self.waiter.done():
self.waiter.set_result(None)
def _closeConnection(self, abort=False):
if abort and hasattr(self.transport, 'abort'):
self.transport.abort()
else:
self.transport.close()
def _onOpen(self):
if self._transport_details.is_secure:
# now that the TLS opening handshake is complete, the actual TLS channel ID
# will be available. make sure to set it!
channel_id = {
'tls-unique': transport_channel_id(self.transport, self._transport_details.is_server, 'tls-unique'),
}
self._transport_details.channel_id = channel_id
res = self.onOpen()
if yields(res):
asyncio.ensure_future(res)
def _onMessageBegin(self, isBinary):
res = self.onMessageBegin(isBinary)
if yields(res):
asyncio.ensure_future(res)
def _onMessageFrameBegin(self, length):
res = self.onMessageFrameBegin(length)
if yields(res):
asyncio.ensure_future(res)
def _onMessageFrameData(self, payload):
res = self.onMessageFrameData(payload)
if yields(res):
asyncio.ensure_future(res)
def _onMessageFrameEnd(self):
res = self.onMessageFrameEnd()
if yields(res):
asyncio.ensure_future(res)
def _onMessageFrame(self, payload):
res = self.onMessageFrame(payload)
if yields(res):
asyncio.ensure_future(res)
def _onMessageEnd(self):
res = self.onMessageEnd()
if yields(res):
asyncio.ensure_future(res)
def _onMessage(self, payload, isBinary):
res = self.onMessage(payload, isBinary)
if yields(res):
asyncio.ensure_future(res)
def _onPing(self, payload):
res = self.onPing(payload)
if yields(res):
asyncio.ensure_future(res)
def _onPong(self, payload):
res = self.onPong(payload)
if yields(res):
asyncio.ensure_future(res)
def _onClose(self, wasClean, code, reason):
res = self.onClose(wasClean, code, reason)
if yields(res):
asyncio.ensure_future(res)
def registerProducer(self, producer, streaming):
raise Exception("not implemented")
def unregisterProducer(self):
# note that generic websocket/protocol.py code calls
# .unregisterProducer whenever we dropConnection -- that's
# correct behavior on Twisted so either we'd have to
# try/except there, or special-case Twisted, ..or just make
# this "not an error"
pass
@public
class WebSocketServerProtocol(WebSocketAdapterProtocol, protocol.WebSocketServerProtocol):
"""
Base class for asyncio-based WebSocket server protocols.
Implements:
* :class:`autobahn.websocket.interfaces.IWebSocketChannel`
"""
log = txaio.make_logger()
@public
class WebSocketClientProtocol(WebSocketAdapterProtocol, protocol.WebSocketClientProtocol):
"""
Base class for asyncio-based WebSocket client protocols.
Implements:
* :class:`autobahn.websocket.interfaces.IWebSocketChannel`
"""
log = txaio.make_logger()
def _onConnect(self, response):
res = self.onConnect(response)
self.log.debug('{func}: {res}', func=hltype(self._onConnect), res=res)
if yields(res):
asyncio.ensure_future(res)
def startTLS(self):
raise Exception("WSS over explicit proxies not implemented")
class WebSocketAdapterFactory(object):
"""
Adapter class for asyncio-based WebSocket client and server factories.
"""
log = txaio.make_logger()
def __call__(self):
proto = self.protocol()
proto.factory = self
return proto
@public
class WebSocketServerFactory(WebSocketAdapterFactory, protocol.WebSocketServerFactory):
"""
Base class for asyncio-based WebSocket server factories.
Implements:
* :class:`autobahn.websocket.interfaces.IWebSocketServerChannelFactory`
"""
log = txaio.make_logger()
protocol = WebSocketServerProtocol
def __init__(self, *args, **kwargs):
"""
.. note::
In addition to all arguments to the constructor of
:meth:`autobahn.websocket.interfaces.IWebSocketServerChannelFactory`,
you can supply a ``loop`` keyword argument to specify the
asyncio event loop to be used.
"""
loop = kwargs.pop('loop', None)
self.loop = loop or asyncio.get_event_loop()
protocol.WebSocketServerFactory.__init__(self, *args, **kwargs)
@public
class WebSocketClientFactory(WebSocketAdapterFactory, protocol.WebSocketClientFactory):
"""
Base class for asyncio-based WebSocket client factories.
Implements:
* :class:`autobahn.websocket.interfaces.IWebSocketClientChannelFactory`
"""
log = txaio.make_logger()
def __init__(self, *args, **kwargs):
"""
.. note::
In addition to all arguments to the constructor of
:meth:`autobahn.websocket.interfaces.IWebSocketClientChannelFactory`,
you can supply a ``loop`` keyword argument to specify the
asyncio event loop to be used.
"""
loop = kwargs.pop('loop', None)
self.loop = loop or asyncio.get_event_loop()
protocol.WebSocketClientFactory.__init__(self, *args, **kwargs)
@public
class WampWebSocketServerProtocol(websocket.WampWebSocketServerProtocol, WebSocketServerProtocol):
"""
asyncio-based WAMP-over-WebSocket server protocol.
Implements:
* :class:`autobahn.wamp.interfaces.ITransport`
"""
log = txaio.make_logger()
@public
class WampWebSocketServerFactory(websocket.WampWebSocketServerFactory, WebSocketServerFactory):
"""
asyncio-based WAMP-over-WebSocket server factory.
"""
log = txaio.make_logger()
protocol = WampWebSocketServerProtocol
def __init__(self, factory, *args, **kwargs):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializers: A list of WAMP serializers to use (or ``None``
for all available serializers).
:type serializers: list of objects implementing
:class:`autobahn.wamp.interfaces.ISerializer`
"""
serializers = kwargs.pop('serializers', None)
websocket.WampWebSocketServerFactory.__init__(self, factory, serializers)
kwargs['protocols'] = self._protocols
# noinspection PyCallByClass
WebSocketServerFactory.__init__(self, *args, **kwargs)
@public
class WampWebSocketClientProtocol(websocket.WampWebSocketClientProtocol, WebSocketClientProtocol):
"""
asyncio-based WAMP-over-WebSocket client protocols.
Implements:
* :class:`autobahn.wamp.interfaces.ITransport`
"""
log = txaio.make_logger()
@public
class WampWebSocketClientFactory(websocket.WampWebSocketClientFactory, WebSocketClientFactory):
"""
asyncio-based WAMP-over-WebSocket client factory.
"""
log = txaio.make_logger()
protocol = WampWebSocketClientProtocol
def __init__(self, factory, *args, **kwargs):
"""
:param factory: A callable that produces instances that implement
:class:`autobahn.wamp.interfaces.ITransportHandler`
:type factory: callable
:param serializer: The WAMP serializer to use (or ``None`` for
"best" serializer, chosen as the first serializer available from
this list: CBOR, MessagePack, UBJSON, JSON).
:type serializer: object implementing :class:`autobahn.wamp.interfaces.ISerializer`
"""
serializers = kwargs.pop('serializers', None)
websocket.WampWebSocketClientFactory.__init__(self, factory, serializers)
kwargs['protocols'] = self._protocols
WebSocketClientFactory.__init__(self, *args, **kwargs)

View File

@@ -0,0 +1,93 @@
###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) typedef int GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################
try:
from autobahn import xbr
HAS_XBR = True
except:
HAS_XBR = False
if HAS_XBR:
import uuid
import asyncio
import txaio
from autobahn.util import hl
from autobahn.xbr._interfaces import IProvider, ISeller, IConsumer, IBuyer
def run_in_executor(*args, **kwargs):
return asyncio.get_running_loop().run_in_executor(None, *args, **kwargs)
class SimpleBlockchain(xbr.SimpleBlockchain):
backgroundCaller = run_in_executor
class KeySeries(xbr.KeySeries):
log = txaio.make_logger()
def __init__(self, api_id, price, interval, on_rotate=None):
super().__init__(api_id, price, interval, on_rotate)
self.running = False
async def start(self):
"""
Start offering and selling data encryption keys in the background.
"""
assert not self.running
self.log.info('Starting key rotation every {interval} seconds for api_id="{api_id}" ..',
interval=hl(self._interval), api_id=hl(uuid.UUID(bytes=self._api_id)))
self.running = True
async def rotate_with_interval():
while self.running:
await self._rotate()
await asyncio.sleep(self._interval)
asyncio.create_task(rotate_with_interval())
def stop(self):
"""
Stop offering/selling data encryption keys.
"""
if not self.running:
raise RuntimeError('cannot stop {} - not currently running'.format(self.__class__.__name__))
self.running = False
class SimpleSeller(xbr.SimpleSeller):
"""
Simple XBR seller component. This component can be used by a XBR seller delegate to
handle the automated selling of data encryption keys to the XBR market maker.
"""
xbr.SimpleSeller.KeySeries = KeySeries
class SimpleBuyer(xbr.SimpleBuyer):
pass
ISeller.register(SimpleSeller)
IProvider.register(SimpleSeller)
IBuyer.register(SimpleBuyer)
IConsumer.register(SimpleBuyer)