okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 387c4740e7
commit 27f3326e22
10020 changed files with 1935769 additions and 2364 deletions

View File

@@ -0,0 +1,7 @@
# -*- test-case-name: twisted.conch.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Conch: The Twisted Shell. Terminal emulation, SSHv2 and telnet.
"""

View File

@@ -0,0 +1,56 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
from zope.interface import implementer
from twisted.conch.error import ConchError
from twisted.conch.interfaces import IConchUser
from twisted.conch.ssh.connection import OPEN_UNKNOWN_CHANNEL_TYPE
from twisted.logger import Logger
from twisted.python.compat import nativeString
@implementer(IConchUser)
class ConchUser:
_log = Logger()
def __init__(self):
self.channelLookup = {}
self.subsystemLookup = {}
@property
def conn(self):
return self._conn
@conn.setter
def conn(self, value):
self._conn = value
def lookupChannel(self, channelType, windowSize, maxPacket, data):
klass = self.channelLookup.get(channelType, None)
if not klass:
raise ConchError(OPEN_UNKNOWN_CHANNEL_TYPE, "unknown channel")
else:
return klass(
remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
data=data,
avatar=self,
)
def lookupSubsystem(self, subsystem, data):
self._log.debug(
"Subsystem lookup: {subsystem!r}", subsystem=self.subsystemLookup
)
klass = self.subsystemLookup.get(subsystem, None)
if not klass:
return False
return klass(data, avatar=self)
def gotGlobalRequest(self, requestType, data):
# XXX should this use method dispatch?
requestType = nativeString(requestType.replace(b"-", b"_"))
f = getattr(self, "global_%s" % requestType, None)
if not f:
return 0
return f(data)

View File

@@ -0,0 +1,640 @@
# -*- test-case-name: twisted.conch.test.test_checkers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Provide L{ICredentialsChecker} implementations to be used in Conch protocols.
"""
import binascii
import errno
import sys
from base64 import decodebytes
from typing import IO, Any, Callable, Iterable, Iterator, Mapping, Optional, Tuple, cast
from zope.interface import Interface, implementer, providedBy
from incremental import Version
from typing_extensions import Literal, Protocol
from twisted.conch import error
from twisted.conch.ssh import keys
from twisted.cred.checkers import ICredentialsChecker
from twisted.cred.credentials import ISSHPrivateKey, IUsernamePassword
from twisted.cred.error import UnauthorizedLogin, UnhandledCredentials
from twisted.internet import defer
from twisted.logger import Logger
from twisted.plugins.cred_unix import verifyCryptedPassword
from twisted.python import failure, reflect
from twisted.python.deprecate import deprecatedModuleAttribute
from twisted.python.filepath import FilePath
from twisted.python.util import runAsEffectiveUser
_log = Logger()
class UserRecord(Tuple[str, str, int, int, str, str, str]):
"""
A record in a UNIX-style password database. See L{pwd} for field details.
This corresponds to the undocumented type L{pwd.struct_passwd}, but lacks named
field accessors.
"""
@property
def pw_dir(self) -> str: # type: ignore[empty-body]
...
class UserDB(Protocol):
"""
A database of users by name, like the stdlib L{pwd} module.
See L{twisted.python.fakepwd} for an in-memory implementation.
"""
def getpwnam(self, username: str) -> UserRecord:
"""
Lookup a user record by name.
@raises KeyError: when no such user exists
"""
pwd: Optional[UserDB]
try:
import pwd as _pwd
except ImportError:
pwd = None
else:
pwd = cast(UserDB, _pwd)
try:
import spwd as _spwd
except ImportError:
spwd = None
else:
spwd = _spwd
class CryptedPasswordRecord(Protocol):
"""
A sequence where the item at index 1 may be a crypted password.
Both L{pwd.struct_passwd} and L{spwd.struct_spwd} conform to this protocol.
"""
def __getitem__(self, index: Literal[1]) -> str:
"""
Get the crypted password.
"""
def _lookupUser(userdb: UserDB, username: bytes) -> UserRecord:
"""
Lookup a user by name in a L{pwd}-style database.
@param userdb: The user database.
@param username: Identifying name in bytes. This will be decoded according
to the filesystem encoding, as the L{pwd} module does internally.
@raises KeyError: when the user doesn't exist
"""
return userdb.getpwnam(username.decode(sys.getfilesystemencoding()))
def _pwdGetByName(username: str) -> Optional[CryptedPasswordRecord]:
"""
Look up a user in the /etc/passwd database using the pwd module. If the
pwd module is not available, return None.
@param username: the username of the user to return the passwd database
information for.
@returns: A L{pwd.struct_passwd}, where field 1 may contain a crypted
password, or L{None} when the L{pwd} database is unavailable.
@raises KeyError: when no such user exists
"""
if pwd is None:
return None
return cast(CryptedPasswordRecord, pwd.getpwnam(username))
def _shadowGetByName(username: str) -> Optional[CryptedPasswordRecord]:
"""
Look up a user in the /etc/shadow database using the spwd module. If it is
not available, return L{None}.
@param username: the username of the user to return the shadow database
information for.
@type username: L{str}
@returns: A L{spwd.struct_spwd}, where field 1 may contain a crypted
password, or L{None} when the L{spwd} database is unavailable.
@raises KeyError: when no such user exists
"""
if spwd is not None:
f = spwd.getspnam
else:
return None
return cast(CryptedPasswordRecord, runAsEffectiveUser(0, 0, f, username))
@implementer(ICredentialsChecker)
class UNIXPasswordDatabase:
"""
A checker which validates users out of the UNIX password databases, or
databases of a compatible format.
@ivar _getByNameFunctions: a C{list} of functions which are called in order
to validate a user. The default value is such that the C{/etc/passwd}
database will be tried first, followed by the C{/etc/shadow} database.
"""
credentialInterfaces = (IUsernamePassword,)
def __init__(self, getByNameFunctions=None):
if getByNameFunctions is None:
getByNameFunctions = [_pwdGetByName, _shadowGetByName]
self._getByNameFunctions = getByNameFunctions
def requestAvatarId(self, credentials):
# We get bytes, but the Py3 pwd module uses str. So attempt to decode
# it using the same method that CPython does for the file on disk.
username = credentials.username.decode(sys.getfilesystemencoding())
password = credentials.password.decode(sys.getfilesystemencoding())
for func in self._getByNameFunctions:
try:
pwnam = func(username)
except KeyError:
return defer.fail(UnauthorizedLogin("invalid username"))
else:
if pwnam is not None:
crypted = pwnam[1]
if crypted == "":
continue
if verifyCryptedPassword(crypted, password):
return defer.succeed(credentials.username)
# fallback
return defer.fail(UnauthorizedLogin("unable to verify password"))
@implementer(ICredentialsChecker)
class SSHPublicKeyDatabase:
"""
Checker that authenticates SSH public keys, based on public keys listed in
authorized_keys and authorized_keys2 files in user .ssh/ directories.
"""
credentialInterfaces = (ISSHPrivateKey,)
_userdb: UserDB = cast(UserDB, pwd)
def requestAvatarId(self, credentials):
d = defer.maybeDeferred(self.checkKey, credentials)
d.addCallback(self._cbRequestAvatarId, credentials)
d.addErrback(self._ebRequestAvatarId)
return d
def _cbRequestAvatarId(self, validKey, credentials):
"""
Check whether the credentials themselves are valid, now that we know
if the key matches the user.
@param validKey: A boolean indicating whether or not the public key
matches a key in the user's authorized_keys file.
@param credentials: The credentials offered by the user.
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: (as a failure) if the key does not match the
user in C{credentials}. Also raised if the user provides an invalid
signature.
@raise ValidPublicKey: (as a failure) if the key matches the user but
the credentials do not include a signature. See
L{error.ValidPublicKey} for more information.
@return: The user's username, if authentication was successful.
"""
if not validKey:
return failure.Failure(UnauthorizedLogin("invalid key"))
if not credentials.signature:
return failure.Failure(error.ValidPublicKey())
else:
try:
pubKey = keys.Key.fromString(credentials.blob)
if pubKey.verify(credentials.signature, credentials.sigData):
return credentials.username
except Exception: # any error should be treated as a failed login
_log.failure("Error while verifying key")
return failure.Failure(UnauthorizedLogin("error while verifying key"))
return failure.Failure(UnauthorizedLogin("unable to verify key"))
def getAuthorizedKeysFiles(self, credentials):
"""
Return a list of L{FilePath} instances for I{authorized_keys} files
which might contain information about authorized keys for the given
credentials.
On OpenSSH servers, the default location of the file containing the
list of authorized public keys is
U{$HOME/.ssh/authorized_keys<http://www.openbsd.org/cgi-bin/man.cgi?query=sshd_config>}.
I{$HOME/.ssh/authorized_keys2} is also returned, though it has been
U{deprecated by OpenSSH since
2001<http://marc.info/?m=100508718416162>}.
@return: A list of L{FilePath} instances to files with the authorized keys.
"""
pwent = _lookupUser(self._userdb, credentials.username)
root = FilePath(pwent.pw_dir).child(".ssh")
files = ["authorized_keys", "authorized_keys2"]
return [root.child(f) for f in files]
def checkKey(self, credentials):
"""
Retrieve files containing authorized keys and check against user
credentials.
"""
ouid, ogid = _lookupUser(self._userdb, credentials.username)[2:4]
for filepath in self.getAuthorizedKeysFiles(credentials):
if not filepath.exists():
continue
try:
lines = filepath.open()
except OSError as e:
if e.errno == errno.EACCES:
lines = runAsEffectiveUser(ouid, ogid, filepath.open)
else:
raise
with lines:
for l in lines:
l2 = l.split()
if len(l2) < 2:
continue
try:
if decodebytes(l2[1]) == credentials.blob:
return True
except binascii.Error:
continue
return False
def _ebRequestAvatarId(self, f):
if not f.check(UnauthorizedLogin):
_log.error(
"Unauthorized login due to internal error: {error}", error=f.value
)
return failure.Failure(UnauthorizedLogin("unable to get avatar id"))
return f
@implementer(ICredentialsChecker)
class SSHProtocolChecker:
"""
SSHProtocolChecker is a checker that requires multiple authentications
to succeed. To add a checker, call my registerChecker method with
the checker and the interface.
After each successful authenticate, I call my areDone method with the
avatar id. To get a list of the successful credentials for an avatar id,
use C{SSHProcotolChecker.successfulCredentials[avatarId]}. If L{areDone}
returns True, the authentication has succeeded.
"""
def __init__(self):
self.checkers = {}
self.successfulCredentials = {}
@property
def credentialInterfaces(self):
return list(self.checkers.keys())
def registerChecker(self, checker, *credentialInterfaces):
if not credentialInterfaces:
credentialInterfaces = checker.credentialInterfaces
for credentialInterface in credentialInterfaces:
self.checkers[credentialInterface] = checker
def requestAvatarId(self, credentials):
"""
Part of the L{ICredentialsChecker} interface. Called by a portal with
some credentials to check if they'll authenticate a user. We check the
interfaces that the credentials provide against our list of acceptable
checkers. If one of them matches, we ask that checker to verify the
credentials. If they're valid, we call our L{_cbGoodAuthentication}
method to continue.
@param credentials: the credentials the L{Portal} wants us to verify
"""
ifac = providedBy(credentials)
for i in ifac:
c = self.checkers.get(i)
if c is not None:
d = defer.maybeDeferred(c.requestAvatarId, credentials)
return d.addCallback(self._cbGoodAuthentication, credentials)
return defer.fail(
UnhandledCredentials(
"No checker for %s" % ", ".join(map(reflect.qual, ifac))
)
)
def _cbGoodAuthentication(self, avatarId, credentials):
"""
Called if a checker has verified the credentials. We call our
L{areDone} method to see if the whole of the successful authentications
are enough. If they are, we return the avatar ID returned by the first
checker.
"""
if avatarId not in self.successfulCredentials:
self.successfulCredentials[avatarId] = []
self.successfulCredentials[avatarId].append(credentials)
if self.areDone(avatarId):
del self.successfulCredentials[avatarId]
return avatarId
else:
raise error.NotEnoughAuthentication()
def areDone(self, avatarId):
"""
Override to determine if the authentication is finished for a given
avatarId.
@param avatarId: the avatar returned by the first checker. For
this checker to function correctly, all the checkers must
return the same avatar ID.
"""
return True
deprecatedModuleAttribute(
Version("Twisted", 15, 0, 0),
(
"Please use twisted.conch.checkers.SSHPublicKeyChecker, "
"initialized with an instance of "
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead."
),
__name__,
"SSHPublicKeyDatabase",
)
class IAuthorizedKeysDB(Interface):
"""
An object that provides valid authorized ssh keys mapped to usernames.
@since: 15.0
"""
def getAuthorizedKeys(avatarId):
"""
Gets an iterable of authorized keys that are valid for the given
C{avatarId}.
@param avatarId: the ID of the avatar
@type avatarId: valid return value of
L{twisted.cred.checkers.ICredentialsChecker.requestAvatarId}
@return: an iterable of L{twisted.conch.ssh.keys.Key}
"""
def readAuthorizedKeyFile(
fileobj: IO[bytes], parseKey: Callable[[bytes], keys.Key] = keys.Key.fromString
) -> Iterator[keys.Key]:
"""
Reads keys from an authorized keys file. Any non-comment line that cannot
be parsed as a key will be ignored, although that particular line will
be logged.
@param fileobj: something from which to read lines which can be parsed
as keys
@param parseKey: a callable that takes bytes and returns a
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
default is L{twisted.conch.ssh.keys.Key.fromString}.
@return: an iterable of L{twisted.conch.ssh.keys.Key}
@since: 15.0
"""
for line in fileobj:
line = line.strip()
if line and not line.startswith(b"#"): # for comments
try:
yield parseKey(line)
except keys.BadKeyError as e:
_log.error(
"Unable to parse line {line!r} as a key: {error!s}",
line=line,
error=e,
)
def _keysFromFilepaths(
filepaths: Iterable[FilePath[Any]], parseKey: Callable[[bytes], keys.Key]
) -> Iterable[keys.Key]:
"""
Helper function that turns an iterable of filepaths into a generator of
keys. If any file cannot be read, a message is logged but it is
otherwise ignored.
@param filepaths: iterable of L{twisted.python.filepath.FilePath}.
@type filepaths: iterable
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}
@type parseKey: L{callable}
@return: generator of L{twisted.conch.ssh.keys.Key}
@since: 15.0
"""
for fp in filepaths:
if fp.exists():
try:
with fp.open() as f:
yield from readAuthorizedKeyFile(f, parseKey)
except OSError as e:
_log.error("Unable to read {path!r}: {error!s}", path=fp.path, error=e)
@implementer(IAuthorizedKeysDB)
class InMemorySSHKeyDB:
"""
Object that provides SSH public keys based on a dictionary of usernames
mapped to L{twisted.conch.ssh.keys.Key}s.
@since: 15.0
"""
def __init__(self, mapping: Mapping[bytes, Iterable[keys.Key]]) -> None:
"""
Initializes a new L{InMemorySSHKeyDB}.
@param mapping: mapping of usernames to iterables of
L{twisted.conch.ssh.keys.Key}s
"""
self._mapping = mapping
def getAuthorizedKeys(self, username: bytes) -> Iterable[keys.Key]:
"""
Look up the authorized keys for a user.
@param username: Name of the user
"""
return self._mapping.get(username, [])
@implementer(IAuthorizedKeysDB)
class UNIXAuthorizedKeysFiles:
"""
Object that provides SSH public keys based on public keys listed in
authorized_keys and authorized_keys2 files in UNIX user .ssh/ directories.
If any of the files cannot be read, a message is logged but that file is
otherwise ignored.
@since: 15.0
"""
_userdb: UserDB
def __init__(
self,
userdb: Optional[UserDB] = None,
parseKey: Callable[[bytes], keys.Key] = keys.Key.fromString,
):
"""
Initializes a new L{UNIXAuthorizedKeysFiles}.
@param userdb: access to the Unix user account and password database
(default is the Python module L{pwd}, if available)
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
default is L{twisted.conch.ssh.keys.Key.fromString}.
"""
if userdb is not None:
self._userdb = userdb
elif pwd is not None:
self._userdb = pwd
else:
raise ValueError("No pwd module found, and no userdb argument passed.")
self._parseKey = parseKey
def getAuthorizedKeys(self, username: bytes) -> Iterable[keys.Key]:
try:
passwd = _lookupUser(self._userdb, username)
except KeyError:
return ()
root = FilePath(passwd.pw_dir).child(".ssh")
files = ["authorized_keys", "authorized_keys2"]
return _keysFromFilepaths((root.child(f) for f in files), self._parseKey)
@implementer(ICredentialsChecker)
class SSHPublicKeyChecker:
"""
Checker that authenticates SSH public keys, based on public keys listed in
authorized_keys and authorized_keys2 files in user .ssh/ directories.
Initializing this checker with a L{UNIXAuthorizedKeysFiles} should be
used instead of L{twisted.conch.checkers.SSHPublicKeyDatabase}.
@since: 15.0
"""
credentialInterfaces = (ISSHPrivateKey,)
def __init__(self, keydb: IAuthorizedKeysDB) -> None:
"""
Initializes a L{SSHPublicKeyChecker}.
@param keydb: a provider of L{IAuthorizedKeysDB}
"""
self._keydb = keydb
def requestAvatarId(self, credentials):
d = defer.execute(self._sanityCheckKey, credentials)
d.addCallback(self._checkKey, credentials)
d.addCallback(self._verifyKey, credentials)
return d
def _sanityCheckKey(self, credentials):
"""
Checks whether the provided credentials are a valid SSH key with a
signature (does not actually verify the signature).
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise ValidPublicKey: the credentials do not include a signature. See
L{error.ValidPublicKey} for more information.
@raise BadKeyError: The key included with the credentials is not
recognized as a key.
@return: the key in the credentials
@rtype: L{twisted.conch.ssh.keys.Key}
"""
if not credentials.signature:
raise error.ValidPublicKey()
return keys.Key.fromString(credentials.blob)
def _checkKey(self, pubKey, credentials):
"""
Checks the public key against all authorized keys (if any) for the
user.
@param pubKey: the key in the credentials (just to prevent it from
having to be calculated again)
@type pubKey:
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: If the key is not authorized, or if there
was any error obtaining a list of authorized keys for the user.
@return: C{pubKey} if the key is authorized
@rtype: L{twisted.conch.ssh.keys.Key}
"""
if any(
key == pubKey for key in self._keydb.getAuthorizedKeys(credentials.username)
):
return pubKey
raise UnauthorizedLogin("Key not authorized")
def _verifyKey(self, pubKey, credentials):
"""
Checks whether the credentials themselves are valid, now that we know
if the key matches the user.
@param pubKey: the key in the credentials (just to prevent it from
having to be calculated again)
@type pubKey: L{twisted.conch.ssh.keys.Key}
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: If the key signature is invalid or there
was any error verifying the signature.
@return: The user's username, if authentication was successful
@rtype: L{bytes}
"""
try:
if pubKey.verify(credentials.signature, credentials.sigData):
return credentials.username
except Exception as e: # Any error should be treated as a failed login
raise UnauthorizedLogin("Error while verifying key") from e
raise UnauthorizedLogin("Key signature invalid.")

View File

@@ -0,0 +1,9 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Client support code for Conch.
Maintainer: Paul Swartz
"""

View File

@@ -0,0 +1,65 @@
# -*- test-case-name: twisted.conch.test.test_default -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Accesses the key agent for user authentication.
Maintainer: Paul Swartz
"""
import os
from twisted.conch.ssh import agent, channel, keys
from twisted.internet import protocol, reactor
from twisted.logger import Logger
class SSHAgentClient(agent.SSHAgentClient):
_log = Logger()
def __init__(self):
agent.SSHAgentClient.__init__(self)
self.blobs = []
def getPublicKeys(self):
return self.requestIdentities().addCallback(self._cbPublicKeys)
def _cbPublicKeys(self, blobcomm):
self._log.debug("got {num_keys} public keys", num_keys=len(blobcomm))
self.blobs = [x[0] for x in blobcomm]
def getPublicKey(self):
"""
Return a L{Key} from the first blob in C{self.blobs}, if any, or
return L{None}.
"""
if self.blobs:
return keys.Key.fromString(self.blobs.pop(0))
return None
class SSHAgentForwardingChannel(channel.SSHChannel):
def channelOpen(self, specificData):
cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal)
d = cc.connectUNIX(os.environ["SSH_AUTH_SOCK"])
d.addCallback(self._cbGotLocal)
d.addErrback(lambda x: self.loseConnection())
self.buf = ""
def _cbGotLocal(self, local):
self.local = local
self.dataReceived = self.local.transport.write
self.local.dataReceived = self.write
def dataReceived(self, data):
self.buf += data
def closed(self):
if self.local:
self.local.loseConnection()
self.local = None
class SSHAgentForwardingLocal(protocol.Protocol):
pass

View File

@@ -0,0 +1,24 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
from twisted.conch.client import direct
connectTypes = {"direct": direct.connect}
def connect(host, port, options, verifyHostKey, userAuthObject):
useConnects = ["direct"]
return _ebConnect(
None, useConnects, host, port, options, verifyHostKey, userAuthObject
)
def _ebConnect(f, useConnects, host, port, options, vhk, uao):
if not useConnects:
return f
connectType = useConnects.pop(0)
f = connectTypes[connectType]
d = f(host, port, options, vhk, uao)
d.addErrback(_ebConnect, useConnects, host, port, options, vhk, uao)
return d

View File

@@ -0,0 +1,331 @@
# -*- test-case-name: twisted.conch.test.test_knownhosts,twisted.conch.test.test_default -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Various classes and functions for implementing user-interaction in the
command-line conch client.
You probably shouldn't use anything in this module directly, since it assumes
you are sitting at an interactive terminal. For example, to programmatically
interact with a known_hosts database, use L{twisted.conch.client.knownhosts}.
"""
import contextlib
import getpass
import io
import os
import sys
from base64 import decodebytes
from twisted.conch.client import agent
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
from twisted.conch.error import ConchError
from twisted.conch.ssh import common, keys, userauth
from twisted.internet import defer, protocol, reactor
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
# The default location of the known hosts file (probably should be parsed out
# of an ssh config file someday).
_KNOWN_HOSTS = "~/.ssh/known_hosts"
# This name is bound so that the unit tests can use 'patch' to override it.
_open = open
_input = input
def verifyHostKey(transport, host, pubKey, fingerprint):
"""
Verify a host's key.
This function is a gross vestige of some bad factoring in the client
internals. The actual implementation, and a better signature of this logic
is in L{KnownHostsFile.verifyHostKey}. This function is not deprecated yet
because the callers have not yet been rehabilitated, but they should
eventually be changed to call that method instead.
However, this function does perform two functions not implemented by
L{KnownHostsFile.verifyHostKey}. It determines the path to the user's
known_hosts file based on the options (which should really be the options
object's job), and it provides an opener to L{ConsoleUI} which opens
'/dev/tty' so that the user will be prompted on the tty of the process even
if the input and output of the process has been redirected. This latter
part is, somewhat obviously, not portable, but I don't know of a portable
equivalent that could be used.
@param host: Due to a bug in L{SSHClientTransport.verifyHostKey}, this is
always the dotted-quad IP address of the host being connected to.
@type host: L{str}
@param transport: the client transport which is attempting to connect to
the given host.
@type transport: L{SSHClientTransport}
@param fingerprint: the fingerprint of the given public key, in
xx:xx:xx:... format. This is ignored in favor of getting the fingerprint
from the key itself.
@type fingerprint: L{str}
@param pubKey: The public key of the server being connected to.
@type pubKey: L{str}
@return: a L{Deferred} which fires with C{1} if the key was successfully
verified, or fails if the key could not be successfully verified. Failure
types may include L{HostKeyChanged}, L{UserRejectedKey}, L{IOError} or
L{KeyboardInterrupt}.
"""
actualHost = transport.factory.options["host"]
actualKey = keys.Key.fromString(pubKey)
kh = KnownHostsFile.fromPath(
FilePath(
transport.factory.options["known-hosts"] or os.path.expanduser(_KNOWN_HOSTS)
)
)
ui = ConsoleUI(lambda: _open("/dev/tty", "r+b", buffering=0))
return kh.verifyHostKey(ui, actualHost, host, actualKey)
def isInKnownHosts(host, pubKey, options):
"""
Checks to see if host is in the known_hosts file for the user.
@return: 0 if it isn't, 1 if it is and is the same, 2 if it's changed.
@rtype: L{int}
"""
keyType = common.getNS(pubKey)[0]
retVal = 0
if not options["known-hosts"] and not os.path.exists(os.path.expanduser("~/.ssh/")):
print("Creating ~/.ssh directory...")
os.mkdir(os.path.expanduser("~/.ssh"))
kh_file = options["known-hosts"] or _KNOWN_HOSTS
try:
known_hosts = open(os.path.expanduser(kh_file), "rb")
except OSError:
return 0
with known_hosts:
for line in known_hosts.readlines():
split = line.split()
if len(split) < 3:
continue
hosts, hostKeyType, encodedKey = split[:3]
if host not in hosts.split(b","): # incorrect host
continue
if hostKeyType != keyType: # incorrect type of key
continue
try:
decodedKey = decodebytes(encodedKey)
except BaseException:
continue
if decodedKey == pubKey:
return 1
else:
retVal = 2
return retVal
def getHostKeyAlgorithms(host, options):
"""
Look in known_hosts for a key corresponding to C{host}.
This can be used to change the order of supported key types
in the KEXINIT packet.
@type host: L{str}
@param host: the host to check in known_hosts
@type options: L{twisted.conch.client.options.ConchOptions}
@param options: options passed to client
@return: L{list} of L{str} representing key types or L{None}.
"""
knownHosts = KnownHostsFile.fromPath(
FilePath(options["known-hosts"] or os.path.expanduser(_KNOWN_HOSTS))
)
keyTypes = []
for entry in knownHosts.iterentries():
if entry.matchesHost(host):
if entry.keyType not in keyTypes:
keyTypes.append(entry.keyType)
return keyTypes or None
class SSHUserAuthClient(userauth.SSHUserAuthClient):
def __init__(self, user, options, *args):
userauth.SSHUserAuthClient.__init__(self, user, *args)
self.keyAgent = None
self.options = options
self.usedFiles = []
if not options.identitys:
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
def serviceStarted(self):
if "SSH_AUTH_SOCK" in os.environ and not self.options["noagent"]:
self._log.debug(
"using SSH agent {authSock!r}", authSock=os.environ["SSH_AUTH_SOCK"]
)
cc = protocol.ClientCreator(reactor, agent.SSHAgentClient)
d = cc.connectUNIX(os.environ["SSH_AUTH_SOCK"])
d.addCallback(self._setAgent)
d.addErrback(self._ebSetAgent)
else:
userauth.SSHUserAuthClient.serviceStarted(self)
def serviceStopped(self):
if self.keyAgent:
self.keyAgent.transport.loseConnection()
self.keyAgent = None
def _setAgent(self, a):
self.keyAgent = a
d = self.keyAgent.getPublicKeys()
d.addBoth(self._ebSetAgent)
return d
def _ebSetAgent(self, f):
userauth.SSHUserAuthClient.serviceStarted(self)
def _getPassword(self, prompt):
"""
Prompt for a password using L{getpass.getpass}.
@param prompt: Written on tty to ask for the input.
@type prompt: L{str}
@return: The input.
@rtype: L{str}
"""
with self._replaceStdoutStdin():
try:
p = getpass.getpass(prompt)
return p
except (KeyboardInterrupt, OSError):
print()
raise ConchError("PEBKAC")
def getPassword(self, prompt=None):
if prompt:
prompt = nativeString(prompt)
else:
prompt = "{}@{}'s password: ".format(
nativeString(self.user),
self.transport.transport.getPeer().host,
)
try:
# We don't know the encoding the other side is using,
# signaling that is not part of the SSH protocol. But
# using our defaultencoding is better than just going for
# ASCII.
p = self._getPassword(prompt).encode(sys.getdefaultencoding())
return defer.succeed(p)
except ConchError:
return defer.fail()
def getPublicKey(self):
"""
Get a public key from the key agent if possible, otherwise look in
the next configured identity file for one.
"""
if self.keyAgent:
key = self.keyAgent.getPublicKey()
if key is not None:
return key
files = [x for x in self.options.identitys if x not in self.usedFiles]
self._log.debug(
"public key identities: {identities}\n{files}",
identities=self.options.identitys,
files=files,
)
if not files:
return None
file = files[0]
self.usedFiles.append(file)
file = os.path.expanduser(file)
file += ".pub"
if not os.path.exists(file):
return self.getPublicKey() # try again
try:
return keys.Key.fromFile(file)
except keys.BadKeyError:
return self.getPublicKey() # try again
def signData(self, publicKey, signData):
"""
Extend the base signing behavior by using an SSH agent to sign the
data, if one is available.
@type publicKey: L{Key}
@type signData: L{bytes}
"""
if not self.usedFiles: # agent key
return self.keyAgent.signData(publicKey.blob(), signData)
else:
return userauth.SSHUserAuthClient.signData(self, publicKey, signData)
def getPrivateKey(self):
"""
Try to load the private key from the last used file identified by
C{getPublicKey}, potentially asking for the passphrase if the key is
encrypted.
"""
file = os.path.expanduser(self.usedFiles[-1])
if not os.path.exists(file):
return None
try:
return defer.succeed(keys.Key.fromFile(file))
except keys.EncryptedKeyError:
for i in range(3):
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
try:
p = self._getPassword(prompt).encode(sys.getfilesystemencoding())
return defer.succeed(keys.Key.fromFile(file, passphrase=p))
except (keys.BadKeyError, ConchError):
pass
return defer.fail(ConchError("bad password"))
raise
except KeyboardInterrupt:
print()
reactor.stop()
def getGenericAnswers(self, name, instruction, prompts):
responses = []
with self._replaceStdoutStdin():
if name:
print(name.decode("utf-8"))
if instruction:
print(instruction.decode("utf-8"))
for prompt, echo in prompts:
prompt = prompt.decode("utf-8")
if echo:
responses.append(_input(prompt))
else:
responses.append(getpass.getpass(prompt))
return defer.succeed(responses)
@classmethod
def _openTty(cls):
"""
Open /dev/tty as two streams one in read, one in write mode,
and return them.
@return: File objects for reading and writing to /dev/tty,
corresponding to standard input and standard output.
@rtype: A L{tuple} of L{io.TextIOWrapper} on Python 3.
"""
stdin = io.TextIOWrapper(open("/dev/tty", "rb"))
stdout = io.TextIOWrapper(open("/dev/tty", "wb"))
return stdin, stdout
@classmethod
@contextlib.contextmanager
def _replaceStdoutStdin(cls):
"""
Contextmanager that replaces stdout and stdin with /dev/tty
and resets them when it is done.
"""
oldout, oldin = sys.stdout, sys.stdin
sys.stdin, sys.stdout = cls._openTty()
try:
yield
finally:
sys.stdout.close()
sys.stdin.close()
sys.stdout, sys.stdin = oldout, oldin

View File

@@ -0,0 +1,98 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.conch import error
from twisted.conch.ssh import transport
from twisted.internet import defer, protocol, reactor
class SSHClientFactory(protocol.ClientFactory):
def __init__(self, d, options, verifyHostKey, userAuthObject):
self.d = d
self.options = options
self.verifyHostKey = verifyHostKey
self.userAuthObject = userAuthObject
def clientConnectionLost(self, connector, reason):
if self.options["reconnect"]:
connector.connect()
def clientConnectionFailed(self, connector, reason):
if self.d is None:
return
d, self.d = self.d, None
d.errback(reason)
def buildProtocol(self, addr):
trans = SSHClientTransport(self)
if self.options["ciphers"]:
trans.supportedCiphers = self.options["ciphers"]
if self.options["macs"]:
trans.supportedMACs = self.options["macs"]
if self.options["compress"]:
trans.supportedCompressions[0:1] = ["zlib"]
if self.options["host-key-algorithms"]:
trans.supportedPublicKeys = self.options["host-key-algorithms"]
return trans
class SSHClientTransport(transport.SSHClientTransport):
def __init__(self, factory):
self.factory = factory
self.unixServer = None
def connectionLost(self, reason):
if self.unixServer:
d = self.unixServer.stopListening()
self.unixServer = None
else:
d = defer.succeed(None)
d.addCallback(
lambda x: transport.SSHClientTransport.connectionLost(self, reason)
)
def receiveError(self, code, desc):
if self.factory.d is None:
return
d, self.factory.d = self.factory.d, None
d.errback(error.ConchError(desc, code))
def sendDisconnect(self, code, reason):
if self.factory.d is None:
return
d, self.factory.d = self.factory.d, None
transport.SSHClientTransport.sendDisconnect(self, code, reason)
d.errback(error.ConchError(reason, code))
def receiveDebug(self, alwaysDisplay, message, lang):
self._log.debug(
"Received Debug Message: {message}",
message=message,
alwaysDisplay=alwaysDisplay,
lang=lang,
)
if alwaysDisplay: # XXX what should happen here?
print(message)
def verifyHostKey(self, pubKey, fingerprint):
return self.factory.verifyHostKey(
self, self.transport.getPeer().host, pubKey, fingerprint
)
def setService(self, service):
self._log.info("setting client server to {service}", service=service)
transport.SSHClientTransport.setService(self, service)
if service.name != "ssh-userauth" and self.factory.d is not None:
d, self.factory.d = self.factory.d, None
d.callback(None)
def connectionSecure(self):
self.requestService(self.factory.userAuthObject)
def connect(host, port, options, verifyHostKey, userAuthObject):
d = defer.Deferred()
factory = SSHClientFactory(d, options, verifyHostKey, userAuthObject)
reactor.connectTCP(host, port, factory)
return d

View File

@@ -0,0 +1,622 @@
# -*- test-case-name: twisted.conch.test.test_knownhosts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An implementation of the OpenSSH known_hosts database.
@since: 8.2
"""
from __future__ import annotations
import hmac
import sys
from binascii import Error as DecodeError, a2b_base64, b2a_base64
from contextlib import closing
from hashlib import sha1
from typing import IO, Callable, Literal
from zope.interface import implementer
from twisted.conch.error import HostKeyChanged, InvalidEntry, UserRejectedKey
from twisted.conch.interfaces import IKnownHostEntry
from twisted.conch.ssh.keys import BadKeyError, FingerprintFormats, Key
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.logger import Logger
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
from twisted.python.randbytes import secureRandom
from twisted.python.util import FancyEqMixin
log = Logger()
def _b64encode(s):
"""
Encode a binary string as base64 with no trailing newline.
@param s: The string to encode.
@type s: L{bytes}
@return: The base64-encoded string.
@rtype: L{bytes}
"""
return b2a_base64(s).strip()
def _extractCommon(string):
"""
Extract common elements of base64 keys from an entry in a hosts file.
@param string: A known hosts file entry (a single line).
@type string: L{bytes}
@return: a 4-tuple of hostname data (L{bytes}), ssh key type (L{bytes}), key
(L{Key}), and comment (L{bytes} or L{None}). The hostname data is
simply the beginning of the line up to the first occurrence of
whitespace.
@rtype: L{tuple}
"""
elements = string.split(None, 2)
if len(elements) != 3:
raise InvalidEntry()
hostnames, keyType, keyAndComment = elements
splitkey = keyAndComment.split(None, 1)
if len(splitkey) == 2:
keyString, comment = splitkey
comment = comment.rstrip(b"\n")
else:
keyString = splitkey[0]
comment = None
key = Key.fromString(a2b_base64(keyString))
return hostnames, keyType, key, comment
class _BaseEntry:
"""
Abstract base of both hashed and non-hashed entry objects, since they
represent keys and key types the same way.
@ivar keyType: The type of the key; either ssh-dss or ssh-rsa.
@type keyType: L{bytes}
@ivar publicKey: The server public key indicated by this line.
@type publicKey: L{twisted.conch.ssh.keys.Key}
@ivar comment: Trailing garbage after the key line.
@type comment: L{bytes}
"""
def __init__(self, keyType, publicKey, comment):
self.keyType = keyType
self.publicKey = publicKey
self.comment = comment
def matchesKey(self, keyObject):
"""
Check to see if this entry matches a given key object.
@param keyObject: A public key object to check.
@type keyObject: L{Key}
@return: C{True} if this entry's key matches C{keyObject}, C{False}
otherwise.
@rtype: L{bool}
"""
return self.publicKey == keyObject
@implementer(IKnownHostEntry)
class PlainEntry(_BaseEntry):
"""
A L{PlainEntry} is a representation of a plain-text entry in a known_hosts
file.
@ivar _hostnames: the list of all host-names associated with this entry.
"""
def __init__(
self, hostnames: list[bytes], keyType: bytes, publicKey: Key, comment: bytes
):
self._hostnames: list[bytes] = hostnames
super().__init__(keyType, publicKey, comment)
@classmethod
def fromString(cls, string: bytes) -> PlainEntry:
"""
Parse a plain-text entry in a known_hosts file, and return a
corresponding L{PlainEntry}.
@param string: a space-separated string formatted like "hostname
key-type base64-key-data comment".
@raise DecodeError: if the key is not valid encoded as valid base64.
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: an IKnownHostEntry representing the hostname and key in the
input line.
@rtype: L{PlainEntry}
"""
hostnames, keyType, key, comment = _extractCommon(string)
self = cls(hostnames.split(b","), keyType, key, comment)
return self
def matchesHost(self, hostname: bytes | str) -> bool:
"""
Check to see if this entry matches a given hostname.
@param hostname: A hostname or IP address literal to check against this
entry.
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
"""
if isinstance(hostname, str):
hostname = hostname.encode("utf-8")
return hostname in self._hostnames
def toString(self) -> bytes:
"""
Implement L{IKnownHostEntry.toString} by recording the comma-separated
hostnames, key type, and base-64 encoded key.
@return: The string representation of this entry, with unhashed hostname
information.
"""
fields = [
b",".join(self._hostnames),
self.keyType,
_b64encode(self.publicKey.blob()),
]
if self.comment is not None:
fields.append(self.comment)
return b" ".join(fields)
@implementer(IKnownHostEntry)
class UnparsedEntry:
"""
L{UnparsedEntry} is an entry in a L{KnownHostsFile} which can't actually be
parsed; therefore it matches no keys and no hosts.
"""
def __init__(self, string):
"""
Create an unparsed entry from a line in a known_hosts file which cannot
otherwise be parsed.
"""
self._string = string
def matchesHost(self, hostname):
"""
Always returns False.
"""
return False
def matchesKey(self, key):
"""
Always returns False.
"""
return False
def toString(self):
"""
Returns the input line, without its newline if one was given.
@return: The string representation of this entry, almost exactly as was
used to initialize this entry but without a trailing newline.
@rtype: L{bytes}
"""
return self._string.rstrip(b"\n")
def _hmacedString(key, string):
"""
Return the SHA-1 HMAC hash of the given key and string.
@param key: The HMAC key.
@type key: L{bytes}
@param string: The string to be hashed.
@type string: L{bytes}
@return: The keyed hash value.
@rtype: L{bytes}
"""
hash = hmac.HMAC(key, digestmod=sha1)
if isinstance(string, str):
string = string.encode("utf-8")
hash.update(string)
return hash.digest()
@implementer(IKnownHostEntry)
class HashedEntry(_BaseEntry, FancyEqMixin):
"""
A L{HashedEntry} is a representation of an entry in a known_hosts file
where the hostname has been hashed and salted.
@ivar _hostSalt: the salt to combine with a hostname for hashing.
@ivar _hostHash: the hashed representation of the hostname.
@cvar MAGIC: the 'hash magic' string used to identify a hashed line in a
known_hosts file as opposed to a plaintext one.
"""
MAGIC = b"|1|"
compareAttributes = ("_hostSalt", "_hostHash", "keyType", "publicKey", "comment")
def __init__(
self,
hostSalt: bytes,
hostHash: bytes,
keyType: bytes,
publicKey: Key,
comment: bytes | None,
) -> None:
self._hostSalt = hostSalt
self._hostHash = hostHash
super().__init__(keyType, publicKey, comment)
@classmethod
def fromString(cls, string: bytes) -> HashedEntry:
"""
Load a hashed entry from a string representing a line in a known_hosts
file.
@param string: A complete single line from a I{known_hosts} file,
formatted as defined by OpenSSH.
@raise DecodeError: if the key, the hostname, or the is not valid
encoded as valid base64
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid, or the host/hash portion
contains more items than just the host and hash.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: The newly created L{HashedEntry} instance, initialized with
the information from C{string}.
"""
stuff, keyType, key, comment = _extractCommon(string)
saltAndHash = stuff[len(cls.MAGIC) :].split(b"|")
if len(saltAndHash) != 2:
raise InvalidEntry()
hostSalt, hostHash = saltAndHash
self = cls(a2b_base64(hostSalt), a2b_base64(hostHash), keyType, key, comment)
return self
def matchesHost(self, hostname):
"""
Implement L{IKnownHostEntry.matchesHost} to compare the hash of the
input to the stored hash.
@param hostname: A hostname or IP address literal to check against this
entry.
@type hostname: L{bytes}
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
@rtype: L{bool}
"""
return hmac.compare_digest(
_hmacedString(self._hostSalt, hostname), self._hostHash
)
def toString(self):
"""
Implement L{IKnownHostEntry.toString} by base64-encoding the salt, host
hash, and key.
@return: The string representation of this entry, with the hostname part
hashed.
@rtype: L{bytes}
"""
fields = [
self.MAGIC
+ b"|".join([_b64encode(self._hostSalt), _b64encode(self._hostHash)]),
self.keyType,
_b64encode(self.publicKey.blob()),
]
if self.comment is not None:
fields.append(self.comment)
return b" ".join(fields)
class KnownHostsFile:
"""
A structured representation of an OpenSSH-format ~/.ssh/known_hosts file.
@ivar _added: A list of L{IKnownHostEntry} providers which have been added
to this instance in memory but not yet saved.
@ivar _clobber: A flag indicating whether the current contents of the save
path will be disregarded and potentially overwritten or not. If
C{True}, this will be done. If C{False}, entries in the save path will
be read and new entries will be saved by appending rather than
overwriting.
@type _clobber: L{bool}
@ivar _savePath: See C{savePath} parameter of L{__init__}.
"""
def __init__(self, savePath: FilePath[str]) -> None:
"""
Create a new, empty KnownHostsFile.
Unless you want to erase the current contents of C{savePath}, you want
to use L{KnownHostsFile.fromPath} instead.
@param savePath: The L{FilePath} to which to save new entries.
@type savePath: L{FilePath}
"""
self._added: list[IKnownHostEntry] = []
self._savePath = savePath
self._clobber = True
@property
def savePath(self) -> FilePath[str]:
"""
@see: C{savePath} parameter of L{__init__}
"""
return self._savePath
def iterentries(self):
"""
Iterate over the host entries in this file.
@return: An iterable the elements of which provide L{IKnownHostEntry}.
There is an element for each entry in the file as well as an element
for each added but not yet saved entry.
@rtype: iterable of L{IKnownHostEntry} providers
"""
for entry in self._added:
yield entry
if self._clobber:
return
try:
fp = self._savePath.open()
except OSError:
return
with fp:
for line in fp:
try:
if line.startswith(HashedEntry.MAGIC):
entry = HashedEntry.fromString(line)
else:
entry = PlainEntry.fromString(line)
except (DecodeError, InvalidEntry, BadKeyError):
entry = UnparsedEntry(line)
yield entry
def hasHostKey(self, hostname, key):
"""
Check for an entry with matching hostname and key.
@param hostname: A hostname or IP address literal to check for.
@type hostname: L{bytes}
@param key: The public key to check for.
@type key: L{Key}
@return: C{True} if the given hostname and key are present in this file,
C{False} if they are not.
@rtype: L{bool}
@raise HostKeyChanged: if the host key found for the given hostname
does not match the given key.
"""
for lineidx, entry in enumerate(self.iterentries(), -len(self._added)):
if entry.matchesHost(hostname) and entry.keyType == key.sshType():
if entry.matchesKey(key):
return True
else:
# Notice that lineidx is 0-based but HostKeyChanged.lineno
# is 1-based.
if lineidx < 0:
line = None
path = None
else:
line = lineidx + 1
path = self._savePath
raise HostKeyChanged(entry, path, line)
return False
def verifyHostKey(
self, ui: ConsoleUI, hostname: bytes, ip: bytes, key: Key
) -> Deferred[bool]:
"""
Verify the given host key for the given IP and host, asking for
confirmation from, and notifying, the given UI about changes to this
file.
@param ui: The user interface to request an IP address from.
@param hostname: The hostname that the user requested to connect to.
@param ip: The string representation of the IP address that is actually
being connected to.
@param key: The public key of the server.
@return: a L{Deferred} that fires with True when the key has been
verified, or fires with an errback when the key either cannot be
verified or has changed.
@rtype: L{Deferred}
"""
hhk = defer.execute(self.hasHostKey, hostname, key)
def gotHasKey(result: bool) -> bool | Deferred[bool]:
if result:
if not self.hasHostKey(ip, key):
addMessage = (
f"Warning: Permanently added the {key.type()} host key"
f" for IP address '{ip.decode()}' to the list of known"
" hosts.\n"
)
ui.warn(addMessage.encode("utf-8"))
self.addHostKey(ip, key)
self.save()
return result
else:
def promptResponse(response: bool) -> bool:
if response:
self.addHostKey(hostname, key)
self.addHostKey(ip, key)
self.save()
return response
else:
raise UserRejectedKey()
keytype: str = key.type()
if keytype == "EC":
keytype = "ECDSA"
prompt = (
"The authenticity of host '%s (%s)' "
"can't be established.\n"
"%s key fingerprint is SHA256:%s.\n"
"Are you sure you want to continue connecting (yes/no)? "
% (
nativeString(hostname),
nativeString(ip),
keytype,
key.fingerprint(format=FingerprintFormats.SHA256_BASE64),
)
)
proceed = ui.prompt(prompt.encode(sys.getdefaultencoding()))
return proceed.addCallback(promptResponse)
return hhk.addCallback(gotHasKey)
def addHostKey(self, hostname: bytes, key: Key) -> HashedEntry:
"""
Add a new L{HashedEntry} to the key database.
Note that you still need to call L{KnownHostsFile.save} if you wish
these changes to be persisted.
@param hostname: A hostname or IP address literal to associate with the
new entry.
@type hostname: L{bytes}
@param key: The public key to associate with the new entry.
@type key: L{Key}
@return: The L{HashedEntry} that was added.
@rtype: L{HashedEntry}
"""
salt = secureRandom(20)
keyType = key.sshType()
entry = HashedEntry(salt, _hmacedString(salt, hostname), keyType, key, None)
self._added.append(entry)
return entry
def save(self) -> None:
"""
Save this L{KnownHostsFile} to the path it was loaded from.
"""
p = self._savePath.parent()
if not p.isdir():
p.makedirs()
mode: Literal["a", "w"] = "w" if self._clobber else "a"
with self._savePath.open(mode) as hostsFileObj:
if self._added:
hostsFileObj.write(
b"\n".join([entry.toString() for entry in self._added]) + b"\n"
)
self._added = []
self._clobber = False
@classmethod
def fromPath(cls, path: FilePath[str]) -> KnownHostsFile:
"""
Create a new L{KnownHostsFile}, potentially reading existing known
hosts information from the given file.
@param path: A path object to use for both reading contents from and
later saving to. If no file exists at this path, it is not an
error; a L{KnownHostsFile} with no entries is returned.
@return: A L{KnownHostsFile} initialized with entries from C{path}.
"""
knownHosts = cls(path)
knownHosts._clobber = False
return knownHosts
class ConsoleUI:
"""
A UI object that can ask true/false questions and post notifications on the
console, to be used during key verification.
"""
def __init__(self, opener: Callable[[], IO[bytes]]):
"""
@param opener: A no-argument callable which should open a console
binary-mode file-like object to be used for reading and writing.
This initializes the C{opener} attribute.
@type opener: callable taking no arguments and returning a read/write
file-like object
"""
self.opener = opener
def prompt(self, text: bytes) -> Deferred[bool]:
"""
Write the given text as a prompt to the console output, then read a
result from the console input.
@param text: Something to present to a user to solicit a yes or no
response.
@type text: L{bytes}
@return: a L{Deferred} which fires with L{True} when the user answers
'yes' and L{False} when the user answers 'no'. It may errback if
there were any I/O errors.
"""
d = defer.succeed(None)
def body(ignored):
with closing(self.opener()) as f:
f.write(text)
while True:
answer = f.readline().strip().lower()
if answer == b"yes":
return True
elif answer in {b"no", b""}:
return False
else:
f.write(b"Please type 'yes' or 'no': ")
return d.addCallback(body)
def warn(self, text: bytes) -> None:
"""
Notify the user (non-interactively) of the provided text, by writing it
to the console.
@param text: Some information the user is to be made aware of.
"""
try:
with closing(self.opener()) as f:
f.write(text)
except Exception:
log.failure("Failed to write to console")

View File

@@ -0,0 +1,109 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys
from typing import List, Optional, Union
#
from twisted.conch.ssh.transport import SSHCiphers, SSHClientTransport
from twisted.python import usage
class ConchOptions(usage.Options):
optParameters: List[List[Optional[Union[str, int]]]] = [
["user", "l", None, "Log in using this user name."],
["identity", "i", None],
["ciphers", "c", None],
["macs", "m", None],
["port", "p", None, "Connect to this port. Server must be on the same port."],
["option", "o", None, "Ignored OpenSSH options"],
["host-key-algorithms", "", None],
["known-hosts", "", None, "File to check for host keys"],
["user-authentications", "", None, "Types of user authentications to use."],
["logfile", "", None, "File to log to, or - for stdout"],
]
optFlags = [
["version", "V", "Display version number only."],
["compress", "C", "Enable compression."],
["log", "v", "Enable logging (defaults to stderr)"],
["nox11", "x", "Disable X11 connection forwarding (default)"],
["agent", "A", "Enable authentication agent forwarding"],
["noagent", "a", "Disable authentication agent forwarding (default)"],
["reconnect", "r", "Reconnect to the server if the connection is lost."],
]
compData = usage.Completions(
mutuallyExclusive=[("agent", "noagent")],
optActions={
"user": usage.CompleteUsernames(),
"ciphers": usage.CompleteMultiList(
[v.decode() for v in SSHCiphers.cipherMap.keys()],
descr="ciphers to choose from",
),
"macs": usage.CompleteMultiList(
[v.decode() for v in SSHCiphers.macMap.keys()],
descr="macs to choose from",
),
"host-key-algorithms": usage.CompleteMultiList(
[v.decode() for v in SSHClientTransport.supportedPublicKeys],
descr="host key algorithms to choose from",
),
# "user-authentications": usage.CompleteMultiList(?
# descr='user authentication types' ),
},
extraActions=[
usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True),
],
)
def __init__(self, *args, **kw):
usage.Options.__init__(self, *args, **kw)
self.identitys = []
self.conns = None
def opt_identity(self, i):
"""Identity for public-key authentication"""
self.identitys.append(i)
def opt_ciphers(self, ciphers):
"Select encryption algorithms"
ciphers = ciphers.split(",")
for cipher in ciphers:
if cipher not in SSHCiphers.cipherMap:
sys.exit("Unknown cipher type '%s'" % cipher)
self["ciphers"] = ciphers
def opt_macs(self, macs):
"Specify MAC algorithms"
if isinstance(macs, str):
macs = macs.encode("utf-8")
macs = macs.split(b",")
for mac in macs:
if mac not in SSHCiphers.macMap:
sys.exit("Unknown mac type '%r'" % mac)
self["macs"] = macs
def opt_host_key_algorithms(self, hkas):
"Select host key algorithms"
if isinstance(hkas, str):
hkas = hkas.encode("utf-8")
hkas = hkas.split(b",")
for hka in hkas:
if hka not in SSHClientTransport.supportedPublicKeys:
sys.exit("Unknown host key type '%r'" % hka)
self["host-key-algorithms"] = hkas
def opt_user_authentications(self, uas):
"Choose how to authenticate to the remote server"
if isinstance(uas, str):
uas = uas.encode("utf-8")
self["user-authentications"] = uas.split(b",")
# def opt_compress(self):
# "Enable compression"
# self.enableCompression = 1
# SSHClientTransport.supportedCompressions[0:1] = ['zlib']

View File

@@ -0,0 +1,845 @@
# -*- test-case-name: twisted.conch.test.test_endpoints -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Endpoint implementations of various SSH interactions.
"""
from __future__ import annotations
__all__ = [
"AuthenticationFailed",
"SSHCommandAddress",
"SSHCommandClientEndpoint",
]
import signal
from io import BytesIO
from os.path import expanduser
from struct import unpack
from typing import IO, Any
from zope.interface import Interface, implementer
from twisted.conch.client.agent import SSHAgentClient
from twisted.conch.client.default import _KNOWN_HOSTS
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
from twisted.conch.ssh.channel import SSHChannel
from twisted.conch.ssh.common import NS, getNS
from twisted.conch.ssh.connection import SSHConnection
from twisted.conch.ssh.keys import Key
from twisted.conch.ssh.transport import SSHClientTransport
from twisted.conch.ssh.userauth import SSHUserAuthClient
from twisted.internet.defer import CancelledError, Deferred, succeed
from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol
from twisted.internet.error import ConnectionDone, ProcessTerminated
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.internet.protocol import Factory
from twisted.logger import Logger
from twisted.python.compat import nativeString, networkString
from twisted.python.failure import Failure
from twisted.python.filepath import FilePath
class AuthenticationFailed(Exception):
"""
An SSH session could not be established because authentication was not
successful.
"""
# This should be public. See #6541.
class _ISSHConnectionCreator(Interface):
"""
An L{_ISSHConnectionCreator} knows how to create SSH connections somehow.
"""
def secureConnection():
"""
Return a new, connected, secured, but not yet authenticated instance of
L{twisted.conch.ssh.transport.SSHServerTransport} or
L{twisted.conch.ssh.transport.SSHClientTransport}.
"""
def cleanupConnection(connection, immediate):
"""
Perform cleanup necessary for a connection object previously returned
from this creator's C{secureConnection} method.
@param connection: An L{twisted.conch.ssh.transport.SSHServerTransport}
or L{twisted.conch.ssh.transport.SSHClientTransport} returned by a
previous call to C{secureConnection}. It is no longer needed by
the caller of that method and may be closed or otherwise cleaned up
as necessary.
@param immediate: If C{True} don't wait for any network communication,
just close the connection immediately and as aggressively as
necessary.
"""
class SSHCommandAddress:
"""
An L{SSHCommandAddress} instance represents the address of an SSH server, a
username which was used to authenticate with that server, and a command
which was run there.
@ivar server: See L{__init__}
@ivar username: See L{__init__}
@ivar command: See L{__init__}
"""
def __init__(self, server, username, command):
"""
@param server: The address of the SSH server on which the command is
running.
@type server: L{IAddress} provider
@param username: An authentication username which was used to
authenticate against the server at the given address.
@type username: L{bytes}
@param command: A command which was run in a session channel on the
server at the given address.
@type command: L{bytes}
"""
self.server = server
self.username = username
self.command = command
class _CommandChannel(SSHChannel):
"""
A L{_CommandChannel} executes a command in a session channel and connects
its input and output to an L{IProtocol} provider.
@ivar _creator: See L{__init__}
@ivar _command: See L{__init__}
@ivar _protocolFactory: See L{__init__}
@ivar _commandConnected: See L{__init__}
@ivar _protocol: An L{IProtocol} provider created using C{_protocolFactory}
which is hooked up to the running command's input and output streams.
"""
name = b"session"
_log = Logger()
def __init__(self, creator, command, protocolFactory, commandConnected):
"""
@param creator: The L{_ISSHConnectionCreator} provider which was used
to get the connection which this channel exists on.
@type creator: L{_ISSHConnectionCreator} provider
@param command: The command to be executed.
@type command: L{bytes}
@param protocolFactory: A client factory to use to build a L{IProtocol}
provider to use to associate with the running command.
@param commandConnected: A L{Deferred} to use to signal that execution
of the command has failed or that it has succeeded and the command
is now running.
@type commandConnected: L{Deferred}
"""
SSHChannel.__init__(self)
self._creator = creator
self._command = command
self._protocolFactory = protocolFactory
self._commandConnected = commandConnected
self._reason = None
def openFailed(self, reason):
"""
When the request to open a new channel to run this command in fails,
fire the C{commandConnected} deferred with a failure indicating that.
"""
self._commandConnected.errback(reason)
def channelOpen(self, ignored):
"""
When the request to open a new channel to run this command in succeeds,
issue an C{"exec"} request to run the command.
"""
command = self.conn.sendRequest(
self, b"exec", NS(self._command), wantReply=True
)
command.addCallbacks(self._execSuccess, self._execFailure)
def _execFailure(self, reason):
"""
When the request to execute the command in this channel fails, fire the
C{commandConnected} deferred with a failure indicating this.
@param reason: The cause of the command execution failure.
@type reason: L{Failure}
"""
self._commandConnected.errback(reason)
def _execSuccess(self, ignored):
"""
When the request to execute the command in this channel succeeds, use
C{protocolFactory} to build a protocol to handle the command's input
and output and connect the protocol to a transport representing those
streams.
Also fire C{commandConnected} with the created protocol after it is
connected to its transport.
@param ignored: The (ignored) result of the execute request
"""
self._protocol = self._protocolFactory.buildProtocol(
SSHCommandAddress(
self.conn.transport.transport.getPeer(),
self.conn.transport.creator.username,
self.conn.transport.creator.command,
)
)
self._protocol.makeConnection(self)
self._commandConnected.callback(self._protocol)
def dataReceived(self, data):
"""
When the command's stdout data arrives over the channel, deliver it to
the protocol instance.
@param data: The bytes from the command's stdout.
@type data: L{bytes}
"""
self._protocol.dataReceived(data)
def request_exit_status(self, data):
"""
When the server sends the command's exit status, record it for later
delivery to the protocol.
@param data: The network-order four byte representation of the exit
status of the command.
@type data: L{bytes}
"""
(status,) = unpack(">L", data)
if status != 0:
self._reason = ProcessTerminated(status, None, None)
def request_exit_signal(self, data):
"""
When the server sends the command's exit status, record it for later
delivery to the protocol.
@param data: The network-order four byte representation of the exit
signal of the command.
@type data: L{bytes}
"""
shortSignalName, data = getNS(data)
coreDumped, data = bool(ord(data[0:1])), data[1:]
errorMessage, data = getNS(data)
languageTag, data = getNS(data)
signalName = f"SIG{nativeString(shortSignalName)}"
signalID = getattr(signal, signalName, -1)
self._log.info(
"Process exited with signal {shortSignalName!r};"
" core dumped: {coreDumped};"
" error message: {errorMessage};"
" language: {languageTag!r}",
shortSignalName=shortSignalName,
coreDumped=coreDumped,
errorMessage=errorMessage.decode("utf-8"),
languageTag=languageTag,
)
self._reason = ProcessTerminated(None, signalID, None)
def closed(self):
"""
When the channel closes, deliver disconnection notification to the
protocol.
"""
self._creator.cleanupConnection(self.conn, False)
if self._reason is None:
reason = ConnectionDone("ssh channel closed")
else:
reason = self._reason
self._protocol.connectionLost(Failure(reason))
class _ConnectionReady(SSHConnection):
"""
L{_ConnectionReady} is an L{SSHConnection} (an SSH service) which only
propagates the I{serviceStarted} event to a L{Deferred} to be handled
elsewhere.
"""
def __init__(self, ready):
"""
@param ready: A L{Deferred} which should be fired when
I{serviceStarted} happens.
"""
SSHConnection.__init__(self)
self._ready = ready
def serviceStarted(self):
"""
When the SSH I{connection} I{service} this object represents is ready
to be used, fire the C{connectionReady} L{Deferred} to publish that
event to some other interested party.
"""
self._ready.callback(self)
del self._ready
class _UserAuth(SSHUserAuthClient):
"""
L{_UserAuth} implements the client part of SSH user authentication in the
convenient way a user might expect if they are familiar with the
interactive I{ssh} command line client.
L{_UserAuth} supports key-based authentication, password-based
authentication, and delegating authentication to an agent.
"""
password = None
keys = None
agent = None
def getPublicKey(self):
"""
Retrieve the next public key object to offer to the server, possibly
delegating to an authentication agent if there is one.
@return: The public part of a key pair that could be used to
authenticate with the server, or L{None} if there are no more
public keys to try.
@rtype: L{twisted.conch.ssh.keys.Key} or L{None}
"""
if self.agent is not None:
return self.agent.getPublicKey()
if self.keys:
self.key = self.keys.pop(0)
else:
self.key = None
return self.key.public()
def signData(self, publicKey, signData):
"""
Extend the base signing behavior by using an SSH agent to sign the
data, if one is available.
@type publicKey: L{Key}
@type signData: L{str}
"""
if self.agent is not None:
return self.agent.signData(publicKey.blob(), signData)
else:
return SSHUserAuthClient.signData(self, publicKey, signData)
def getPrivateKey(self):
"""
Get the private part of a key pair to use for authentication. The key
corresponds to the public part most recently returned from
C{getPublicKey}.
@return: A L{Deferred} which fires with the private key.
@rtype: L{Deferred}
"""
return succeed(self.key)
def getPassword(self):
"""
Get the password to use for authentication.
@return: A L{Deferred} which fires with the password, or L{None} if the
password was not specified.
"""
if self.password is None:
return
return succeed(self.password)
def ssh_USERAUTH_SUCCESS(self, packet):
"""
Handle user authentication success in the normal way, but also make a
note of the state change on the L{_CommandTransport}.
"""
self.transport._state = b"CHANNELLING"
return SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
def connectToAgent(self, endpoint):
"""
Set up a connection to the authentication agent and trigger its
initialization.
@param endpoint: An endpoint which can be used to connect to the
authentication agent.
@type endpoint: L{IStreamClientEndpoint} provider
@return: A L{Deferred} which fires when the agent connection is ready
for use.
"""
factory = Factory()
factory.protocol = SSHAgentClient
d = endpoint.connect(factory)
def connected(agent):
self.agent = agent
return agent.getPublicKeys()
d.addCallback(connected)
return d
def loseAgentConnection(self):
"""
Disconnect the agent.
"""
if self.agent is None:
return
self.agent.transport.loseConnection()
class _CommandTransport(SSHClientTransport):
"""
L{_CommandTransport} is an SSH client I{transport} which includes a host
key verification step before it will proceed to secure the connection.
L{_CommandTransport} also knows how to set up a connection to an
authentication agent if it is told where it can connect to one.
@ivar _userauth: The L{_UserAuth} instance which is in charge of the
overall authentication process or L{None} if the SSH connection has not
reach yet the C{user-auth} service.
@type _userauth: L{_UserAuth}
"""
# STARTING -> SECURING -> AUTHENTICATING -> CHANNELLING -> RUNNING
_state = b"STARTING"
_hostKeyFailure = None
_userauth = None
def __init__(self, creator):
"""
@param creator: The L{_NewConnectionHelper} that created this
connection.
@type creator: L{_NewConnectionHelper}.
"""
self.connectionReady = Deferred(lambda d: self.transport.abortConnection())
# Clear the reference to that deferred to help the garbage collector
# and to signal to other parts of this implementation (in particular
# connectionLost) that it has already been fired and does not need to
# be fired again.
def readyFired(result):
self.connectionReady = None
return result
self.connectionReady.addBoth(readyFired)
self.creator = creator
def verifyHostKey(self, hostKey, fingerprint):
"""
Ask the L{KnownHostsFile} provider available on the factory which
created this protocol this protocol to verify the given host key.
@return: A L{Deferred} which fires with the result of
L{KnownHostsFile.verifyHostKey}.
"""
hostname = self.creator.hostname
ip = networkString(self.transport.getPeer().host)
self._state = b"SECURING"
d = self.creator.knownHosts.verifyHostKey(
self.creator.ui, hostname, ip, Key.fromString(hostKey)
)
d.addErrback(self._saveHostKeyFailure)
return d
def _saveHostKeyFailure(self, reason):
"""
When host key verification fails, record the reason for the failure in
order to fire a L{Deferred} with it later.
@param reason: The cause of the host key verification failure.
@type reason: L{Failure}
@return: C{reason}
@rtype: L{Failure}
"""
self._hostKeyFailure = reason
return reason
def connectionSecure(self):
"""
When the connection is secure, start the authentication process.
"""
self._state = b"AUTHENTICATING"
command = _ConnectionReady(self.connectionReady)
self._userauth = _UserAuth(self.creator.username, command)
self._userauth.password = self.creator.password
if self.creator.keys:
self._userauth.keys = list(self.creator.keys)
if self.creator.agentEndpoint is not None:
d = self._userauth.connectToAgent(self.creator.agentEndpoint)
else:
d = succeed(None)
def maybeGotAgent(ignored):
self.requestService(self._userauth)
d.addBoth(maybeGotAgent)
def connectionLost(self, reason):
"""
When the underlying connection to the SSH server is lost, if there were
any connection setup errors, propagate them. Also, clean up the
connection to the ssh agent if one was created.
"""
if self._userauth:
self._userauth.loseAgentConnection()
if self._state == b"RUNNING" or self.connectionReady is None:
return
if self._state == b"SECURING" and self._hostKeyFailure is not None:
reason = self._hostKeyFailure
elif self._state == b"AUTHENTICATING":
reason = Failure(
AuthenticationFailed("Connection lost while authenticating")
)
self.connectionReady.errback(reason)
@implementer(IStreamClientEndpoint)
class SSHCommandClientEndpoint:
"""
L{SSHCommandClientEndpoint} exposes the command-executing functionality of
SSH servers.
L{SSHCommandClientEndpoint} can set up a new SSH connection, authenticate
it in any one of a number of different ways (keys, passwords, agents),
launch a command over that connection and then associate its input and
output with a protocol.
It can also re-use an existing, already-authenticated SSH connection
(perhaps one which already has some SSH channels being used for other
purposes). In this case it creates a new SSH channel to use to execute the
command. Notably this means it supports multiplexing several different
command invocations over a single SSH connection.
"""
def __init__(self, creator, command):
"""
@param creator: An L{_ISSHConnectionCreator} provider which will be
used to set up the SSH connection which will be used to run a
command.
@type creator: L{_ISSHConnectionCreator} provider
@param command: The command line to execute on the SSH server. This
byte string is interpreted by a shell on the SSH server, so it may
have a value like C{"ls /"}. Take care when trying to run a
command like C{"/Volumes/My Stuff/a-program"} - spaces (and other
special bytes) may require escaping.
@type command: L{bytes}
"""
self._creator = creator
self._command = command
@classmethod
def newConnection(
cls,
reactor,
command,
username,
hostname,
port=None,
keys=None,
password=None,
agentEndpoint=None,
knownHosts=None,
ui=None,
):
"""
Create and return a new endpoint which will try to create a new
connection to an SSH server and run a command over it. It will also
close the connection if there are problems leading up to the command
being executed, after the command finishes, or if the connection
L{Deferred} is cancelled.
@param reactor: The reactor to use to establish the connection.
@type reactor: L{IReactorTCP} provider
@param command: See L{__init__}'s C{command} argument.
@param username: The username with which to authenticate to the SSH
server.
@type username: L{bytes}
@param hostname: The hostname of the SSH server.
@type hostname: L{bytes}
@param port: The port number of the SSH server. By default, the
standard SSH port number is used.
@type port: L{int}
@param keys: Private keys with which to authenticate to the SSH server,
if key authentication is to be attempted (otherwise L{None}).
@type keys: L{list} of L{Key}
@param password: The password with which to authenticate to the SSH
server, if password authentication is to be attempted (otherwise
L{None}).
@type password: L{bytes} or L{None}
@param agentEndpoint: An L{IStreamClientEndpoint} provider which may be
used to connect to an SSH agent, if one is to be used to help with
authentication.
@type agentEndpoint: L{IStreamClientEndpoint} provider
@param knownHosts: The currently known host keys, used to check the
host key presented by the server we actually connect to.
@type knownHosts: L{KnownHostsFile}
@param ui: An object for interacting with users to make decisions about
whether to accept the server host keys. If L{None}, a L{ConsoleUI}
connected to /dev/tty will be used; if /dev/tty is unavailable, an
object which answers C{b"no"} to all prompts will be used.
@type ui: L{None} or L{ConsoleUI}
@return: A new instance of C{cls} (probably
L{SSHCommandClientEndpoint}).
"""
helper = _NewConnectionHelper(
reactor,
hostname,
port,
command,
username,
keys,
password,
agentEndpoint,
knownHosts,
ui,
)
return cls(helper, command)
@classmethod
def existingConnection(cls, connection, command):
"""
Create and return a new endpoint which will try to open a new channel
on an existing SSH connection and run a command over it. It will
B{not} close the connection if there is a problem executing the command
or after the command finishes.
@param connection: An existing connection to an SSH server.
@type connection: L{SSHConnection}
@param command: See L{SSHCommandClientEndpoint.newConnection}'s
C{command} parameter.
@type command: L{bytes}
@return: A new instance of C{cls} (probably
L{SSHCommandClientEndpoint}).
"""
helper = _ExistingConnectionHelper(connection)
return cls(helper, command)
def connect(self, protocolFactory):
"""
Set up an SSH connection, use a channel from that connection to launch
a command, and hook the stdin and stdout of that command up as a
transport for a protocol created by the given factory.
@param protocolFactory: A L{Factory} to use to create the protocol
which will be connected to the stdin and stdout of the command on
the SSH server.
@return: A L{Deferred} which will fire with an error if the connection
cannot be set up for any reason or with the protocol instance
created by C{protocolFactory} once it has been connected to the
command.
"""
d = self._creator.secureConnection()
d.addCallback(self._executeCommand, protocolFactory)
return d
def _executeCommand(self, connection, protocolFactory):
"""
Given a secured SSH connection, try to execute a command in a new
channel created on it and associate the result with a protocol from the
given factory.
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
C{connection} parameter.
@param protocolFactory: See L{SSHCommandClientEndpoint.connect}'s
C{protocolFactory} parameter.
@return: See L{SSHCommandClientEndpoint.connect}'s return value.
"""
commandConnected = Deferred()
def disconnectOnFailure(passthrough):
# Close the connection immediately in case of cancellation, since
# that implies user wants it gone immediately (e.g. a timeout):
immediate = passthrough.check(CancelledError)
self._creator.cleanupConnection(connection, immediate)
return passthrough
commandConnected.addErrback(disconnectOnFailure)
channel = _CommandChannel(
self._creator, self._command, protocolFactory, commandConnected
)
connection.openChannel(channel)
return commandConnected
@implementer(_ISSHConnectionCreator)
class _NewConnectionHelper:
"""
L{_NewConnectionHelper} implements L{_ISSHConnectionCreator} by
establishing a brand new SSH connection, securing it, and authenticating.
"""
_KNOWN_HOSTS = _KNOWN_HOSTS
port = 22
def __init__(
self,
reactor: Any,
hostname: str,
port: int,
command: str,
username: str,
keys: str,
password: str,
agentEndpoint: str,
knownHosts: str | None,
ui: ConsoleUI | None,
tty: FilePath[bytes] | FilePath[str] = FilePath(b"/dev/tty"),
):
"""
@param tty: The path of the tty device to use in case C{ui} is L{None}.
@type tty: L{FilePath}
@see: L{SSHCommandClientEndpoint.newConnection}
"""
self.reactor = reactor
self.hostname = hostname
if port is not None:
self.port = port
self.command = command
self.username = username
self.keys = keys
self.password = password
self.agentEndpoint = agentEndpoint
if knownHosts is None:
knownHosts = self._knownHosts()
self.knownHosts = knownHosts
if ui is None:
ui = ConsoleUI(self._opener)
self.ui = ui
self.tty: FilePath[bytes] | FilePath[str] = tty
def _opener(self) -> IO[bytes]:
"""
Open the tty if possible, otherwise give back a file-like object from
which C{b"no"} can be read.
For use as the opener argument to L{ConsoleUI}.
"""
try:
return self.tty.open("r+")
except BaseException:
# Give back a file-like object from which can be read a byte string
# that KnownHostsFile recognizes as rejecting some option (b"no").
return BytesIO(b"no")
@classmethod
def _knownHosts(cls):
"""
@return: A L{KnownHostsFile} instance pointed at the user's personal
I{known hosts} file.
@rtype: L{KnownHostsFile}
"""
return KnownHostsFile.fromPath(FilePath(expanduser(cls._KNOWN_HOSTS)))
def secureConnection(self):
"""
Create and return a new SSH connection which has been secured and on
which authentication has already happened.
@return: A L{Deferred} which fires with the ready-to-use connection or
with a failure if something prevents the connection from being
setup, secured, or authenticated.
"""
protocol = _CommandTransport(self)
ready = protocol.connectionReady
sshClient = TCP4ClientEndpoint(
self.reactor, nativeString(self.hostname), self.port
)
d = connectProtocol(sshClient, protocol)
d.addCallback(lambda ignored: ready)
return d
def cleanupConnection(self, connection, immediate):
"""
Clean up the connection by closing it. The command running on the
endpoint has ended so the connection is no longer needed.
@param connection: The L{SSHConnection} to close.
@type connection: L{SSHConnection}
@param immediate: Whether to close connection immediately.
@type immediate: L{bool}.
"""
if immediate:
# We're assuming the underlying connection is an ITCPTransport,
# which is what the current implementation is restricted to:
connection.transport.transport.abortConnection()
else:
connection.transport.loseConnection()
@implementer(_ISSHConnectionCreator)
class _ExistingConnectionHelper:
"""
L{_ExistingConnectionHelper} implements L{_ISSHConnectionCreator} by
handing out an existing SSH connection which is supplied to its
initializer.
"""
def __init__(self, connection):
"""
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
C{connection} parameter.
"""
self.connection = connection
def secureConnection(self):
"""
@return: A L{Deferred} that fires synchronously with the
already-established connection object.
"""
return succeed(self.connection)
def cleanupConnection(self, connection, immediate):
"""
Do not do any cleanup on the connection. Leave that responsibility to
whatever code created it in the first place.
@param connection: The L{SSHConnection} which will not be modified in
any way.
@type connection: L{SSHConnection}
@param immediate: An argument which will be ignored.
@type immediate: L{bool}.
"""

View File

@@ -0,0 +1,96 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An error to represent bad things happening in Conch.
Maintainer: Paul Swartz
"""
from twisted.cred.error import UnauthorizedLogin
class ConchError(Exception):
def __init__(self, value, data=None):
Exception.__init__(self, value, data)
self.value = value
self.data = data
class NotEnoughAuthentication(Exception):
"""
This is thrown if the authentication is valid, but is not enough to
successfully verify the user. i.e. don't retry this type of
authentication, try another one.
"""
class ValidPublicKey(UnauthorizedLogin):
"""
Raised by public key checkers when they receive public key credentials
that don't contain a signature at all, but are valid in every other way.
(e.g. the public key matches one in the user's authorized_keys file).
Protocol code (eg
L{SSHUserAuthServer<twisted.conch.ssh.userauth.SSHUserAuthServer>}) which
attempts to log in using
L{ISSHPrivateKey<twisted.cred.credentials.ISSHPrivateKey>} credentials
should be prepared to handle a failure of this type by telling the user to
re-authenticate using the same key and to include a signature with the new
attempt.
See U{http://www.ietf.org/rfc/rfc4252.txt} section 7 for more details.
"""
class IgnoreAuthentication(Exception):
"""
This is thrown to let the UserAuthServer know it doesn't need to handle the
authentication anymore.
"""
class MissingKeyStoreError(Exception):
"""
Raised if an SSHAgentServer starts receiving data without its factory
providing a keys dict on which to read/write key data.
"""
class UserRejectedKey(Exception):
"""
The user interactively rejected a key.
"""
class InvalidEntry(Exception):
"""
An entry in a known_hosts file could not be interpreted as a valid entry.
"""
class HostKeyChanged(Exception):
"""
The host key of a remote host has changed.
@ivar offendingEntry: The entry which contains the persistent host key that
disagrees with the given host key.
@type offendingEntry: L{twisted.conch.interfaces.IKnownHostEntry}
@ivar path: a reference to the known_hosts file that the offending entry
was loaded from
@type path: L{twisted.python.filepath.FilePath}
@ivar lineno: The line number of the offending entry in the given path.
@type lineno: L{int}
"""
def __init__(self, offendingEntry, path, lineno):
Exception.__init__(self)
self.offendingEntry = offendingEntry
self.path = path
self.lineno = lineno

View File

@@ -0,0 +1,4 @@
"""
Insults: a replacement for Curses/S-Lang.
Very basic at the moment."""

View File

@@ -0,0 +1,556 @@
# -*- test-case-name: twisted.conch.test.test_helper -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Partial in-memory terminal emulator
@author: Jp Calderone
"""
import re
import string
from zope.interface import implementer
from incremental import Version
from twisted.conch.insults import insults
from twisted.internet import defer, protocol, reactor
from twisted.logger import Logger
from twisted.python import _textattributes
from twisted.python.compat import iterbytes
from twisted.python.deprecate import deprecated, deprecatedModuleAttribute
FOREGROUND = 30
BACKGROUND = 40
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, N_COLORS = range(9)
class _FormattingState(_textattributes._FormattingStateMixin):
"""
Represents the formatting state/attributes of a single character.
Character set, intensity, underlinedness, blinkitude, video
reversal, as well as foreground and background colors made up a
character's attributes.
"""
compareAttributes = (
"charset",
"bold",
"underline",
"blink",
"reverseVideo",
"foreground",
"background",
"_subtracting",
)
def __init__(
self,
charset=insults.G0,
bold=False,
underline=False,
blink=False,
reverseVideo=False,
foreground=WHITE,
background=BLACK,
_subtracting=False,
):
self.charset = charset
self.bold = bold
self.underline = underline
self.blink = blink
self.reverseVideo = reverseVideo
self.foreground = foreground
self.background = background
self._subtracting = _subtracting
@deprecated(Version("Twisted", 13, 1, 0))
def wantOne(self, **kw):
"""
Add a character attribute to a copy of this formatting state.
@param kw: An optional attribute name and value can be provided with
a keyword argument.
@return: A formatting state instance with the new attribute.
@see: L{DefaultFormattingState._withAttribute}.
"""
k, v = kw.popitem()
return self._withAttribute(k, v)
def toVT102(self):
# Spit out a vt102 control sequence that will set up
# all the attributes set here. Except charset.
attrs = []
if self._subtracting:
attrs.append(0)
if self.bold:
attrs.append(insults.BOLD)
if self.underline:
attrs.append(insults.UNDERLINE)
if self.blink:
attrs.append(insults.BLINK)
if self.reverseVideo:
attrs.append(insults.REVERSE_VIDEO)
if self.foreground != WHITE:
attrs.append(FOREGROUND + self.foreground)
if self.background != BLACK:
attrs.append(BACKGROUND + self.background)
if attrs:
return "\x1b[" + ";".join(map(str, attrs)) + "m"
return ""
CharacterAttribute = _FormattingState
deprecatedModuleAttribute(
Version("Twisted", 13, 1, 0),
"Use twisted.conch.insults.text.assembleFormattedText instead.",
"twisted.conch.insults.helper",
"CharacterAttribute",
)
# XXX - need to support scroll regions and scroll history
@implementer(insults.ITerminalTransport)
class TerminalBuffer(protocol.Protocol):
"""
An in-memory terminal emulator.
"""
for keyID in (
b"UP_ARROW",
b"DOWN_ARROW",
b"RIGHT_ARROW",
b"LEFT_ARROW",
b"HOME",
b"INSERT",
b"DELETE",
b"END",
b"PGUP",
b"PGDN",
b"F1",
b"F2",
b"F3",
b"F4",
b"F5",
b"F6",
b"F7",
b"F8",
b"F9",
b"F10",
b"F11",
b"F12",
):
execBytes = keyID + b" = object()"
execStr = execBytes.decode("ascii")
exec(execStr)
TAB = b"\t"
BACKSPACE = b"\x7f"
width = 80
height = 24
fill = b" "
void = object()
_log = Logger()
def getCharacter(self, x, y):
return self.lines[y][x]
def connectionMade(self):
self.reset()
def write(self, data):
"""
Add the given printable bytes to the terminal.
Line feeds in L{bytes} will be replaced with carriage return / line
feed pairs.
"""
for b in iterbytes(data.replace(b"\n", b"\r\n")):
self.insertAtCursor(b)
def _currentFormattingState(self):
return _FormattingState(self.activeCharset, **self.graphicRendition)
def insertAtCursor(self, b):
"""
Add one byte to the terminal at the cursor and make consequent state
updates.
If b is a carriage return, move the cursor to the beginning of the
current row.
If b is a line feed, move the cursor to the next row or scroll down if
the cursor is already in the last row.
Otherwise, if b is printable, put it at the cursor position (inserting
or overwriting as dictated by the current mode) and move the cursor.
"""
if b == b"\r":
self.x = 0
elif b == b"\n":
self._scrollDown()
elif b in string.printable.encode("ascii"):
if self.x >= self.width:
self.nextLine()
ch = (b, self._currentFormattingState())
if self.modes.get(insults.modes.IRM):
self.lines[self.y][self.x : self.x] = [ch]
self.lines[self.y].pop()
else:
self.lines[self.y][self.x] = ch
self.x += 1
def _emptyLine(self, width):
return [(self.void, self._currentFormattingState()) for i in range(width)]
def _scrollDown(self):
self.y += 1
if self.y >= self.height:
self.y -= 1
del self.lines[0]
self.lines.append(self._emptyLine(self.width))
def _scrollUp(self):
self.y -= 1
if self.y < 0:
self.y = 0
del self.lines[-1]
self.lines.insert(0, self._emptyLine(self.width))
def cursorUp(self, n=1):
self.y = max(0, self.y - n)
def cursorDown(self, n=1):
self.y = min(self.height - 1, self.y + n)
def cursorBackward(self, n=1):
self.x = max(0, self.x - n)
def cursorForward(self, n=1):
self.x = min(self.width, self.x + n)
def cursorPosition(self, column, line):
self.x = column
self.y = line
def cursorHome(self):
self.x = self.home.x
self.y = self.home.y
def index(self):
self._scrollDown()
def reverseIndex(self):
self._scrollUp()
def nextLine(self):
"""
Update the cursor position attributes and scroll down if appropriate.
"""
self.x = 0
self._scrollDown()
def saveCursor(self):
self._savedCursor = (self.x, self.y)
def restoreCursor(self):
self.x, self.y = self._savedCursor
del self._savedCursor
def setModes(self, modes):
for m in modes:
self.modes[m] = True
def resetModes(self, modes):
for m in modes:
try:
del self.modes[m]
except KeyError:
pass
def setPrivateModes(self, modes):
"""
Enable the given modes.
Track which modes have been enabled so that the implementations of
other L{insults.ITerminalTransport} methods can be properly implemented
to respect these settings.
@see: L{resetPrivateModes}
@see: L{insults.ITerminalTransport.setPrivateModes}
"""
for m in modes:
self.privateModes[m] = True
def resetPrivateModes(self, modes):
"""
Disable the given modes.
@see: L{setPrivateModes}
@see: L{insults.ITerminalTransport.resetPrivateModes}
"""
for m in modes:
try:
del self.privateModes[m]
except KeyError:
pass
def applicationKeypadMode(self):
self.keypadMode = "app"
def numericKeypadMode(self):
self.keypadMode = "num"
def selectCharacterSet(self, charSet, which):
self.charsets[which] = charSet
def shiftIn(self):
self.activeCharset = insults.G0
def shiftOut(self):
self.activeCharset = insults.G1
def singleShift2(self):
oldActiveCharset = self.activeCharset
self.activeCharset = insults.G2
f = self.insertAtCursor
def insertAtCursor(b):
f(b)
del self.insertAtCursor
self.activeCharset = oldActiveCharset
self.insertAtCursor = insertAtCursor
def singleShift3(self):
oldActiveCharset = self.activeCharset
self.activeCharset = insults.G3
f = self.insertAtCursor
def insertAtCursor(b):
f(b)
del self.insertAtCursor
self.activeCharset = oldActiveCharset
self.insertAtCursor = insertAtCursor
def selectGraphicRendition(self, *attributes):
for a in attributes:
if a == insults.NORMAL:
self.graphicRendition = {
"bold": False,
"underline": False,
"blink": False,
"reverseVideo": False,
"foreground": WHITE,
"background": BLACK,
}
elif a == insults.BOLD:
self.graphicRendition["bold"] = True
elif a == insults.UNDERLINE:
self.graphicRendition["underline"] = True
elif a == insults.BLINK:
self.graphicRendition["blink"] = True
elif a == insults.REVERSE_VIDEO:
self.graphicRendition["reverseVideo"] = True
else:
try:
v = int(a)
except ValueError:
self._log.error(
"Unknown graphic rendition attribute: {attr!r}", attr=a
)
else:
if FOREGROUND <= v <= FOREGROUND + N_COLORS:
self.graphicRendition["foreground"] = v - FOREGROUND
elif BACKGROUND <= v <= BACKGROUND + N_COLORS:
self.graphicRendition["background"] = v - BACKGROUND
else:
self._log.error(
"Unknown graphic rendition attribute: {attr!r}", attr=a
)
def eraseLine(self):
self.lines[self.y] = self._emptyLine(self.width)
def eraseToLineEnd(self):
width = self.width - self.x
self.lines[self.y][self.x :] = self._emptyLine(width)
def eraseToLineBeginning(self):
self.lines[self.y][: self.x + 1] = self._emptyLine(self.x + 1)
def eraseDisplay(self):
self.lines = [self._emptyLine(self.width) for i in range(self.height)]
def eraseToDisplayEnd(self):
self.eraseToLineEnd()
height = self.height - self.y - 1
self.lines[self.y + 1 :] = [self._emptyLine(self.width) for i in range(height)]
def eraseToDisplayBeginning(self):
self.eraseToLineBeginning()
self.lines[: self.y] = [self._emptyLine(self.width) for i in range(self.y)]
def deleteCharacter(self, n=1):
del self.lines[self.y][self.x : self.x + n]
self.lines[self.y].extend(self._emptyLine(min(self.width - self.x, n)))
def insertLine(self, n=1):
self.lines[self.y : self.y] = [self._emptyLine(self.width) for i in range(n)]
del self.lines[self.height :]
def deleteLine(self, n=1):
del self.lines[self.y : self.y + n]
self.lines.extend([self._emptyLine(self.width) for i in range(n)])
def reportCursorPosition(self):
return (self.x, self.y)
def reset(self):
self.home = insults.Vector(0, 0)
self.x = self.y = 0
self.modes = {}
self.privateModes = {}
self.setPrivateModes(
[insults.privateModes.AUTO_WRAP, insults.privateModes.CURSOR_MODE]
)
self.numericKeypad = "app"
self.activeCharset = insults.G0
self.graphicRendition = {
"bold": False,
"underline": False,
"blink": False,
"reverseVideo": False,
"foreground": WHITE,
"background": BLACK,
}
self.charsets = {
insults.G0: insults.CS_US,
insults.G1: insults.CS_US,
insults.G2: insults.CS_ALTERNATE,
insults.G3: insults.CS_ALTERNATE_SPECIAL,
}
self.eraseDisplay()
def unhandledControlSequence(self, buf):
print("Could not handle", repr(buf))
def __bytes__(self):
lines = []
for L in self.lines:
buf = []
length = 0
for ch, attr in L:
if ch is not self.void:
buf.append(ch)
length = len(buf)
else:
buf.append(self.fill)
lines.append(b"".join(buf[:length]))
return b"\n".join(lines)
def getHost(self):
# ITransport.getHost
raise NotImplementedError("Unimplemented: TerminalBuffer.getHost")
def getPeer(self):
# ITransport.getPeer
raise NotImplementedError("Unimplemented: TerminalBuffer.getPeer")
def loseConnection(self):
# ITransport.loseConnection
raise NotImplementedError("Unimplemented: TerminalBuffer.loseConnection")
def writeSequence(self, data):
# ITransport.writeSequence
raise NotImplementedError("Unimplemented: TerminalBuffer.writeSequence")
def horizontalTabulationSet(self):
# ITerminalTransport.horizontalTabulationSet
raise NotImplementedError(
"Unimplemented: TerminalBuffer.horizontalTabulationSet"
)
def tabulationClear(self):
# TerminalTransport.tabulationClear
raise NotImplementedError("Unimplemented: TerminalBuffer.tabulationClear")
def tabulationClearAll(self):
# TerminalTransport.tabulationClearAll
raise NotImplementedError("Unimplemented: TerminalBuffer.tabulationClearAll")
def doubleHeightLine(self, top=True):
# ITerminalTransport.doubleHeightLine
raise NotImplementedError("Unimplemented: TerminalBuffer.doubleHeightLine")
def singleWidthLine(self):
# ITerminalTransport.singleWidthLine
raise NotImplementedError("Unimplemented: TerminalBuffer.singleWidthLine")
def doubleWidthLine(self):
# ITerminalTransport.doubleWidthLine
raise NotImplementedError("Unimplemented: TerminalBuffer.doubleWidthLine")
class ExpectationTimeout(Exception):
pass
class ExpectableBuffer(TerminalBuffer):
_mark = 0
def connectionMade(self):
TerminalBuffer.connectionMade(self)
self._expecting = []
def write(self, data):
TerminalBuffer.write(self, data)
self._checkExpected()
def cursorHome(self):
TerminalBuffer.cursorHome(self)
self._mark = 0
def _timeoutExpected(self, d):
d.errback(ExpectationTimeout())
self._checkExpected()
def _checkExpected(self):
s = self.__bytes__()[self._mark :]
while self._expecting:
expr, timer, deferred = self._expecting[0]
if timer and not timer.active():
del self._expecting[0]
continue
for match in expr.finditer(s):
if timer:
timer.cancel()
del self._expecting[0]
self._mark += match.end()
s = s[match.end() :]
deferred.callback(match)
break
else:
return
def expect(self, expression, timeout=None, scheduler=reactor):
d = defer.Deferred()
timer = None
if timeout:
timer = scheduler.callLater(timeout, self._timeoutExpected, d)
self._expecting.append((re.compile(expression), timer, d))
self._checkExpected()
return d
__all__ = ["CharacterAttribute", "TerminalBuffer", "ExpectableBuffer"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,176 @@
# -*- test-case-name: twisted.conch.test.test_text -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Character attribute manipulation API.
This module provides a domain-specific language (using Python syntax)
for the creation of text with additional display attributes associated
with it. It is intended as an alternative to manually building up
strings containing ECMA 48 character attribute control codes. It
currently supports foreground and background colors (black, red,
green, yellow, blue, magenta, cyan, and white), intensity selection,
underlining, blinking and reverse video. Character set selection
support is planned.
Character attributes are specified by using two Python operations:
attribute lookup and indexing. For example, the string \"Hello
world\" with red foreground and all other attributes set to their
defaults, assuming the name twisted.conch.insults.text.attributes has
been imported and bound to the name \"A\" (with the statement C{from
twisted.conch.insults.text import attributes as A}, for example) one
uses this expression::
A.fg.red[\"Hello world\"]
Other foreground colors are set by substituting their name for
\"red\". To set both a foreground and a background color, this
expression is used::
A.fg.red[A.bg.green[\"Hello world\"]]
Note that either A.bg.green can be nested within A.fg.red or vice
versa. Also note that multiple items can be nested within a single
index operation by separating them with commas::
A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]
Other character attributes are set in a similar fashion. To specify a
blinking version of the previous expression::
A.blink[A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]]
C{A.reverseVideo}, C{A.underline}, and C{A.bold} are also valid.
A third operation is actually supported: unary negation. This turns
off an attribute when an enclosing expression would otherwise have
caused it to be on. For example::
A.underline[A.fg.red[\"Hello\", -A.underline[\" world\"]]]
A formatting structure can then be serialized into a string containing the
necessary VT102 control codes with L{assembleFormattedText}.
@see: L{twisted.conch.insults.text._CharacterAttributes}
@author: Jp Calderone
"""
from incremental import Version
from twisted.conch.insults import helper, insults
from twisted.python import _textattributes
from twisted.python.deprecate import deprecatedModuleAttribute
flatten = _textattributes.flatten
deprecatedModuleAttribute(
Version("Twisted", 13, 1, 0),
"Use twisted.conch.insults.text.assembleFormattedText instead.",
"twisted.conch.insults.text",
"flatten",
)
_TEXT_COLORS = {
"black": helper.BLACK,
"red": helper.RED,
"green": helper.GREEN,
"yellow": helper.YELLOW,
"blue": helper.BLUE,
"magenta": helper.MAGENTA,
"cyan": helper.CYAN,
"white": helper.WHITE,
}
class _CharacterAttributes(_textattributes.CharacterAttributesMixin):
"""
Factory for character attributes, including foreground and background color
and non-color attributes such as bold, reverse video and underline.
Character attributes are applied to actual text by using object
indexing-syntax (C{obj['abc']}) after accessing a factory attribute, for
example::
attributes.bold['Some text']
These can be nested to mix attributes::
attributes.bold[attributes.underline['Some text']]
And multiple values can be passed::
attributes.normal[attributes.bold['Some'], ' text']
Non-color attributes can be accessed by attribute name, available
attributes are:
- bold
- blink
- reverseVideo
- underline
Available colors are:
0. black
1. red
2. green
3. yellow
4. blue
5. magenta
6. cyan
7. white
@ivar fg: Foreground colors accessed by attribute name, see above
for possible names.
@ivar bg: Background colors accessed by attribute name, see above
for possible names.
"""
fg = _textattributes._ColorAttribute(
_textattributes._ForegroundColorAttr, _TEXT_COLORS
)
bg = _textattributes._ColorAttribute(
_textattributes._BackgroundColorAttr, _TEXT_COLORS
)
attrs = {
"bold": insults.BOLD,
"blink": insults.BLINK,
"underline": insults.UNDERLINE,
"reverseVideo": insults.REVERSE_VIDEO,
}
def assembleFormattedText(formatted):
"""
Assemble formatted text from structured information.
Currently handled formatting includes: bold, blink, reverse, underline and
color codes.
For example::
from twisted.conch.insults.text import attributes as A
assembleFormattedText(
A.normal[A.bold['Time: '], A.fg.lightRed['Now!']])
Would produce "Time: " in bold formatting, followed by "Now!" with a
foreground color of light red and without any additional formatting.
@param formatted: Structured text and attributes.
@rtype: L{str}
@return: String containing VT102 control sequences that mimic those
specified by C{formatted}.
@see: L{twisted.conch.insults.text._CharacterAttributes}
@since: 13.1
"""
return _textattributes.flatten(formatted, helper._FormattingState(), "toVT102")
attributes = _CharacterAttributes()
__all__ = ["attributes", "flatten"]

View File

@@ -0,0 +1,936 @@
# -*- test-case-name: twisted.conch.test.test_window -*-
"""
Simple insults-based widget library
@author: Jp Calderone
"""
from __future__ import annotations
import array
from twisted.conch.insults import helper, insults
from twisted.python import text as tptext
class YieldFocus(Exception):
"""
Input focus manipulation exception
"""
class BoundedTerminalWrapper:
def __init__(self, terminal, width, height, xoff, yoff):
self.width = width
self.height = height
self.xoff = xoff
self.yoff = yoff
self.terminal = terminal
self.cursorForward = terminal.cursorForward
self.selectCharacterSet = terminal.selectCharacterSet
self.selectGraphicRendition = terminal.selectGraphicRendition
self.saveCursor = terminal.saveCursor
self.restoreCursor = terminal.restoreCursor
def cursorPosition(self, x, y):
return self.terminal.cursorPosition(
self.xoff + min(self.width, x), self.yoff + min(self.height, y)
)
def cursorHome(self):
return self.terminal.cursorPosition(self.xoff, self.yoff)
def write(self, data):
return self.terminal.write(data)
class Widget:
focused = False
parent = None
dirty = False
width: int | None = None
height: int | None = None
def repaint(self):
if not self.dirty:
self.dirty = True
if self.parent is not None and not self.parent.dirty:
self.parent.repaint()
def filthy(self):
self.dirty = True
def redraw(self, width, height, terminal):
self.filthy()
self.draw(width, height, terminal)
def draw(self, width, height, terminal):
if width != self.width or height != self.height or self.dirty:
self.width = width
self.height = height
self.dirty = False
self.render(width, height, terminal)
def render(self, width, height, terminal):
pass
def sizeHint(self):
return None
def keystrokeReceived(self, keyID, modifier):
if keyID == b"\t":
self.tabReceived(modifier)
elif keyID == b"\x7f":
self.backspaceReceived()
elif keyID in insults.FUNCTION_KEYS:
self.functionKeyReceived(keyID, modifier)
else:
self.characterReceived(keyID, modifier)
def tabReceived(self, modifier):
# XXX TODO - Handle shift+tab
raise YieldFocus()
def focusReceived(self):
"""
Called when focus is being given to this widget.
May raise YieldFocus is this widget does not want focus.
"""
self.focused = True
self.repaint()
def focusLost(self):
self.focused = False
self.repaint()
def backspaceReceived(self):
pass
def functionKeyReceived(self, keyID, modifier):
name = keyID
if not isinstance(keyID, str):
name = name.decode("utf-8")
# Peel off the square brackets added by the computed definition of
# twisted.conch.insults.insults.FUNCTION_KEYS.
methodName = "func_" + name[1:-1]
func = getattr(self, methodName, None)
if func is not None:
func(modifier)
def characterReceived(self, keyID, modifier):
pass
class ContainerWidget(Widget):
"""
@ivar focusedChild: The contained widget which currently has
focus, or None.
"""
focusedChild = None
focused = False
def __init__(self):
Widget.__init__(self)
self.children = []
def addChild(self, child):
assert child.parent is None
child.parent = self
self.children.append(child)
if self.focusedChild is None and self.focused:
try:
child.focusReceived()
except YieldFocus:
pass
else:
self.focusedChild = child
self.repaint()
def remChild(self, child):
assert child.parent is self
child.parent = None
self.children.remove(child)
self.repaint()
def filthy(self):
for ch in self.children:
ch.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
for ch in self.children:
ch.draw(width, height, terminal)
def changeFocus(self):
self.repaint()
if self.focusedChild is not None:
self.focusedChild.focusLost()
focusedChild = self.focusedChild
self.focusedChild = None
try:
curFocus = self.children.index(focusedChild) + 1
except ValueError:
raise YieldFocus()
else:
curFocus = 0
while curFocus < len(self.children):
try:
self.children[curFocus].focusReceived()
except YieldFocus:
curFocus += 1
else:
self.focusedChild = self.children[curFocus]
return
# None of our children wanted focus
raise YieldFocus()
def focusReceived(self):
self.changeFocus()
self.focused = True
def keystrokeReceived(self, keyID, modifier):
if self.focusedChild is not None:
try:
self.focusedChild.keystrokeReceived(keyID, modifier)
except YieldFocus:
self.changeFocus()
self.repaint()
else:
Widget.keystrokeReceived(self, keyID, modifier)
class TopWindow(ContainerWidget):
"""
A top-level container object which provides focus wrap-around and paint
scheduling.
@ivar painter: A no-argument callable which will be invoked when this
widget needs to be redrawn.
@ivar scheduler: A one-argument callable which will be invoked with a
no-argument callable and should arrange for it to invoked at some point in
the near future. The no-argument callable will cause this widget and all
its children to be redrawn. It is typically beneficial for the no-argument
callable to be invoked at the end of handling for whatever event is
currently active; for example, it might make sense to call it at the end of
L{twisted.conch.insults.insults.ITerminalProtocol.keystrokeReceived}.
Note, however, that since calls to this may also be made in response to no
apparent event, arrangements should be made for the function to be called
even if an event handler such as C{keystrokeReceived} is not on the call
stack (eg, using
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
with a short timeout).
"""
focused = True
def __init__(self, painter, scheduler):
ContainerWidget.__init__(self)
self.painter = painter
self.scheduler = scheduler
_paintCall = None
def repaint(self):
if self._paintCall is None:
self._paintCall = object()
self.scheduler(self._paint)
ContainerWidget.repaint(self)
def _paint(self):
self._paintCall = None
self.painter()
def changeFocus(self):
try:
ContainerWidget.changeFocus(self)
except YieldFocus:
try:
ContainerWidget.changeFocus(self)
except YieldFocus:
pass
def keystrokeReceived(self, keyID, modifier):
try:
ContainerWidget.keystrokeReceived(self, keyID, modifier)
except YieldFocus:
self.changeFocus()
class AbsoluteBox(ContainerWidget):
def moveChild(self, child, x, y):
for n in range(len(self.children)):
if self.children[n][0] is child:
self.children[n] = (child, x, y)
break
else:
raise ValueError("No such child", child)
def render(self, width, height, terminal):
for ch, x, y in self.children:
wrap = BoundedTerminalWrapper(terminal, width - x, height - y, x, y)
ch.draw(width, height, wrap)
class _Box(ContainerWidget):
TOP, CENTER, BOTTOM = range(3)
def __init__(self, gravity=CENTER):
ContainerWidget.__init__(self)
self.gravity = gravity
def sizeHint(self):
height = 0
width = 0
for ch in self.children:
hint = ch.sizeHint()
if hint is None:
hint = (None, None)
if self.variableDimension == 0:
if hint[0] is None:
width = None
elif width is not None:
width += hint[0]
if hint[1] is None:
height = None
elif height is not None:
height = max(height, hint[1])
else:
if hint[0] is None:
width = None
elif width is not None:
width = max(width, hint[0])
if hint[1] is None:
height = None
elif height is not None:
height += hint[1]
return width, height
def render(self, width, height, terminal):
if not self.children:
return
greedy = 0
wants = []
for ch in self.children:
hint = ch.sizeHint()
if hint is None:
hint = (None, None)
if hint[self.variableDimension] is None:
greedy += 1
wants.append(hint[self.variableDimension])
length = (width, height)[self.variableDimension]
totalWant = sum(w for w in wants if w is not None)
if greedy:
leftForGreedy = int((length - totalWant) / greedy)
widthOffset = heightOffset = 0
for want, ch in zip(wants, self.children):
if want is None:
want = leftForGreedy
subWidth, subHeight = width, height
if self.variableDimension == 0:
subWidth = want
else:
subHeight = want
wrap = BoundedTerminalWrapper(
terminal,
subWidth,
subHeight,
widthOffset,
heightOffset,
)
ch.draw(subWidth, subHeight, wrap)
if self.variableDimension == 0:
widthOffset += want
else:
heightOffset += want
class HBox(_Box):
variableDimension = 0
class VBox(_Box):
variableDimension = 1
class Packer(ContainerWidget):
def render(self, width, height, terminal):
if not self.children:
return
root = int(len(self.children) ** 0.5 + 0.5)
boxes = [VBox() for n in range(root)]
for n, ch in enumerate(self.children):
boxes[n % len(boxes)].addChild(ch)
h = HBox()
map(h.addChild, boxes)
h.render(width, height, terminal)
class Canvas(Widget):
focused = False
contents = None
def __init__(self):
Widget.__init__(self)
self.resize(1, 1)
def resize(self, width, height):
contents = array.array("B", b" " * width * height)
if self.contents is not None:
for x in range(min(width, self._width)):
for y in range(min(height, self._height)):
contents[width * y + x] = self[x, y]
self.contents = contents
self._width = width
self._height = height
if self.x >= width:
self.x = width - 1
if self.y >= height:
self.y = height - 1
def __getitem__(self, index):
(x, y) = index
return self.contents[(self._width * y) + x]
def __setitem__(self, index, value):
(x, y) = index
self.contents[(self._width * y) + x] = value
def clear(self):
self.contents = array.array("B", b" " * len(self.contents))
def render(self, width, height, terminal):
if not width or not height:
return
if width != self._width or height != self._height:
self.resize(width, height)
for i in range(height):
terminal.cursorPosition(0, i)
text = self.contents[
self._width * i : self._width * i + self._width
].tobytes()
text = text[:width]
terminal.write(text)
def horizontalLine(terminal, y, left, right):
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
terminal.cursorPosition(left, y)
terminal.write(b"\161" * (right - left))
terminal.selectCharacterSet(insults.CS_US, insults.G0)
def verticalLine(terminal, x, top, bottom):
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
for n in range(top, bottom):
terminal.cursorPosition(x, n)
terminal.write(b"\170")
terminal.selectCharacterSet(insults.CS_US, insults.G0)
def rectangle(terminal, position, dimension):
"""
Draw a rectangle
@type position: L{tuple}
@param position: A tuple of the (top, left) coordinates of the rectangle.
@type dimension: L{tuple}
@param dimension: A tuple of the (width, height) size of the rectangle.
"""
(top, left) = position
(width, height) = dimension
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
terminal.cursorPosition(top, left)
terminal.write(b"\154")
terminal.write(b"\161" * (width - 2))
terminal.write(b"\153")
for n in range(height - 2):
terminal.cursorPosition(left, top + n + 1)
terminal.write(b"\170")
terminal.cursorForward(width - 2)
terminal.write(b"\170")
terminal.cursorPosition(0, top + height - 1)
terminal.write(b"\155")
terminal.write(b"\161" * (width - 2))
terminal.write(b"\152")
terminal.selectCharacterSet(insults.CS_US, insults.G0)
class Border(Widget):
def __init__(self, containee):
Widget.__init__(self)
self.containee = containee
self.containee.parent = self
def focusReceived(self):
return self.containee.focusReceived()
def focusLost(self):
return self.containee.focusLost()
def keystrokeReceived(self, keyID, modifier):
return self.containee.keystrokeReceived(keyID, modifier)
def sizeHint(self):
hint = self.containee.sizeHint()
if hint is None:
hint = (None, None)
if hint[0] is None:
x = None
else:
x = hint[0] + 2
if hint[1] is None:
y = None
else:
y = hint[1] + 2
return x, y
def filthy(self):
self.containee.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
if self.containee.focused:
terminal.write(b"\x1b[31m")
rectangle(terminal, (0, 0), (width, height))
terminal.write(b"\x1b[0m")
wrap = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
self.containee.draw(width - 2, height - 2, wrap)
class Button(Widget):
def __init__(self, label, onPress):
Widget.__init__(self)
self.label = label
self.onPress = onPress
def sizeHint(self):
return len(self.label), 1
def characterReceived(self, keyID, modifier):
if keyID == b"\r":
self.onPress()
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
if self.focused:
terminal.write(b"\x1b[1m" + self.label + b"\x1b[0m")
else:
terminal.write(self.label)
class TextInput(Widget):
def __init__(self, maxwidth, onSubmit):
Widget.__init__(self)
self.onSubmit = onSubmit
self.maxwidth = maxwidth
self.buffer = b""
self.cursor = 0
def setText(self, text):
self.buffer = text[: self.maxwidth]
self.cursor = len(self.buffer)
self.repaint()
def func_LEFT_ARROW(self, modifier):
if self.cursor > 0:
self.cursor -= 1
self.repaint()
def func_RIGHT_ARROW(self, modifier):
if self.cursor < len(self.buffer):
self.cursor += 1
self.repaint()
def backspaceReceived(self):
if self.cursor > 0:
self.buffer = self.buffer[: self.cursor - 1] + self.buffer[self.cursor :]
self.cursor -= 1
self.repaint()
def characterReceived(self, keyID, modifier):
if keyID == b"\r":
self.onSubmit(self.buffer)
else:
if len(self.buffer) < self.maxwidth:
self.buffer = (
self.buffer[: self.cursor] + keyID + self.buffer[self.cursor :]
)
self.cursor += 1
self.repaint()
def sizeHint(self):
return self.maxwidth + 1, 1
def render(self, width, height, terminal):
currentText = self._renderText()
terminal.cursorPosition(0, 0)
if self.focused:
terminal.write(currentText[: self.cursor])
cursor(terminal, currentText[self.cursor : self.cursor + 1] or b" ")
terminal.write(currentText[self.cursor + 1 :])
terminal.write(b" " * (self.maxwidth - len(currentText) + 1))
else:
more = self.maxwidth - len(currentText)
terminal.write(currentText + b"_" * more)
def _renderText(self):
return self.buffer
class PasswordInput(TextInput):
def _renderText(self):
return "*" * len(self.buffer)
class TextOutput(Widget):
text = b""
def __init__(self, size=None):
Widget.__init__(self)
self.size = size
def sizeHint(self):
return self.size
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
text = self.text[:width]
terminal.write(text + b" " * (width - len(text)))
def setText(self, text):
self.text = text
self.repaint()
def focusReceived(self):
raise YieldFocus()
class TextOutputArea(TextOutput):
WRAP, TRUNCATE = range(2)
def __init__(self, size=None, longLines=WRAP):
TextOutput.__init__(self, size)
self.longLines = longLines
def render(self, width, height, terminal):
n = 0
inputLines = self.text.splitlines()
outputLines = []
while inputLines:
if self.longLines == self.WRAP:
line = inputLines.pop(0)
if not isinstance(line, str):
line = line.decode("utf-8")
wrappedLines = []
for wrappedLine in tptext.greedyWrap(line, width):
if not isinstance(wrappedLine, bytes):
wrappedLine = wrappedLine.encode("utf-8")
wrappedLines.append(wrappedLine)
outputLines.extend(wrappedLines or [b""])
else:
outputLines.append(inputLines.pop(0)[:width])
if len(outputLines) >= height:
break
for n, L in enumerate(outputLines[:height]):
terminal.cursorPosition(0, n)
terminal.write(L)
class Viewport(Widget):
_xOffset = 0
_yOffset = 0
@property
def xOffset(self):
return self._xOffset
@xOffset.setter
def xOffset(self, value):
if self._xOffset != value:
self._xOffset = value
self.repaint()
@property
def yOffset(self):
return self._yOffset
@yOffset.setter
def yOffset(self, value):
if self._yOffset != value:
self._yOffset = value
self.repaint()
_width = 160
_height = 24
def __init__(self, containee):
Widget.__init__(self)
self.containee = containee
self.containee.parent = self
self._buf = helper.TerminalBuffer()
self._buf.width = self._width
self._buf.height = self._height
self._buf.connectionMade()
def filthy(self):
self.containee.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
self.containee.draw(self._width, self._height, self._buf)
# XXX /Lame/
for y, line in enumerate(
self._buf.lines[self._yOffset : self._yOffset + height]
):
terminal.cursorPosition(0, y)
n = 0
for n, (ch, attr) in enumerate(line[self._xOffset : self._xOffset + width]):
if ch is self._buf.void:
ch = b" "
terminal.write(ch)
if n < width:
terminal.write(b" " * (width - n - 1))
class _Scrollbar(Widget):
def __init__(self, onScroll):
Widget.__init__(self)
self.onScroll = onScroll
self.percent = 0.0
def smaller(self):
self.percent = min(1.0, max(0.0, self.onScroll(-1)))
self.repaint()
def bigger(self):
self.percent = min(1.0, max(0.0, self.onScroll(+1)))
self.repaint()
class HorizontalScrollbar(_Scrollbar):
def sizeHint(self):
return (None, 1)
def func_LEFT_ARROW(self, modifier):
self.smaller()
def func_RIGHT_ARROW(self, modifier):
self.bigger()
_left = "\N{BLACK LEFT-POINTING TRIANGLE}"
_right = "\N{BLACK RIGHT-POINTING TRIANGLE}"
_bar = "\N{LIGHT SHADE}"
_slider = "\N{DARK SHADE}"
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
n = width - 3
before = int(n * self.percent)
after = n - before
me = (
self._left
+ (self._bar * before)
+ self._slider
+ (self._bar * after)
+ self._right
)
terminal.write(me.encode("utf-8"))
class VerticalScrollbar(_Scrollbar):
def sizeHint(self):
return (1, None)
def func_UP_ARROW(self, modifier):
self.smaller()
def func_DOWN_ARROW(self, modifier):
self.bigger()
_up = "\N{BLACK UP-POINTING TRIANGLE}"
_down = "\N{BLACK DOWN-POINTING TRIANGLE}"
_bar = "\N{LIGHT SHADE}"
_slider = "\N{DARK SHADE}"
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
knob = int(self.percent * (height - 2))
terminal.write(self._up.encode("utf-8"))
for i in range(1, height - 1):
terminal.cursorPosition(0, i)
if i != (knob + 1):
terminal.write(self._bar.encode("utf-8"))
else:
terminal.write(self._slider.encode("utf-8"))
terminal.cursorPosition(0, height - 1)
terminal.write(self._down.encode("utf-8"))
class ScrolledArea(Widget):
"""
A L{ScrolledArea} contains another widget wrapped in a viewport and
vertical and horizontal scrollbars for moving the viewport around.
"""
def __init__(self, containee):
Widget.__init__(self)
self._viewport = Viewport(containee)
self._horiz = HorizontalScrollbar(self._horizScroll)
self._vert = VerticalScrollbar(self._vertScroll)
for w in self._viewport, self._horiz, self._vert:
w.parent = self
def _horizScroll(self, n):
self._viewport.xOffset += n
self._viewport.xOffset = max(0, self._viewport.xOffset)
return self._viewport.xOffset / 25.0
def _vertScroll(self, n):
self._viewport.yOffset += n
self._viewport.yOffset = max(0, self._viewport.yOffset)
return self._viewport.yOffset / 25.0
def func_UP_ARROW(self, modifier):
self._vert.smaller()
def func_DOWN_ARROW(self, modifier):
self._vert.bigger()
def func_LEFT_ARROW(self, modifier):
self._horiz.smaller()
def func_RIGHT_ARROW(self, modifier):
self._horiz.bigger()
def filthy(self):
self._viewport.filthy()
self._horiz.filthy()
self._vert.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
wrapper = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
self._viewport.draw(width - 2, height - 2, wrapper)
if self.focused:
terminal.write(b"\x1b[31m")
horizontalLine(terminal, 0, 1, width - 1)
verticalLine(terminal, 0, 1, height - 1)
self._vert.draw(
1, height - 1, BoundedTerminalWrapper(terminal, 1, height - 1, width - 1, 0)
)
self._horiz.draw(
width, 1, BoundedTerminalWrapper(terminal, width, 1, 0, height - 1)
)
terminal.write(b"\x1b[0m")
def cursor(terminal, ch):
terminal.saveCursor()
terminal.selectGraphicRendition(str(insults.REVERSE_VIDEO))
terminal.write(ch)
terminal.restoreCursor()
terminal.cursorForward()
class Selection(Widget):
# Index into the sequence
focusedIndex = 0
# Offset into the displayed subset of the sequence
renderOffset = 0
def __init__(self, sequence, onSelect, minVisible=None):
Widget.__init__(self)
self.sequence = sequence
self.onSelect = onSelect
self.minVisible = minVisible
if minVisible is not None:
self._width = max(map(len, self.sequence))
def sizeHint(self):
if self.minVisible is not None:
return self._width, self.minVisible
def func_UP_ARROW(self, modifier):
if self.focusedIndex > 0:
self.focusedIndex -= 1
if self.renderOffset > 0:
self.renderOffset -= 1
self.repaint()
def func_PGUP(self, modifier):
if self.renderOffset != 0:
self.focusedIndex -= self.renderOffset
self.renderOffset = 0
else:
self.focusedIndex = max(0, self.focusedIndex - self.height)
self.repaint()
def func_DOWN_ARROW(self, modifier):
if self.focusedIndex < len(self.sequence) - 1:
self.focusedIndex += 1
if self.renderOffset < self.height - 1:
self.renderOffset += 1
self.repaint()
def func_PGDN(self, modifier):
if self.renderOffset != self.height - 1:
change = self.height - self.renderOffset - 1
if change + self.focusedIndex >= len(self.sequence):
change = len(self.sequence) - self.focusedIndex - 1
self.focusedIndex += change
self.renderOffset = self.height - 1
else:
self.focusedIndex = min(
len(self.sequence) - 1, self.focusedIndex + self.height
)
self.repaint()
def characterReceived(self, keyID, modifier):
if keyID == b"\r":
self.onSelect(self.sequence[self.focusedIndex])
def render(self, width, height, terminal):
self.height = height
start = self.focusedIndex - self.renderOffset
if start > len(self.sequence) - height:
start = max(0, len(self.sequence) - height)
elements = self.sequence[start : start + height]
for n, ele in enumerate(elements):
terminal.cursorPosition(0, n)
if n == self.renderOffset:
terminal.saveCursor()
if self.focused:
modes = str(insults.REVERSE_VIDEO), str(insults.BOLD)
else:
modes = (str(insults.REVERSE_VIDEO),)
terminal.selectGraphicRendition(*modes)
text = ele[:width]
terminal.write(text + (b" " * (width - len(text))))
if n == self.renderOffset:
terminal.restoreCursor()

View File

@@ -0,0 +1,456 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains interfaces defined for the L{twisted.conch} package.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from zope.interface import Attribute, Interface
if TYPE_CHECKING:
from twisted.conch.ssh.keys import Key
class IConchUser(Interface):
"""
A user who has been authenticated to Cred through Conch. This is
the interface between the SSH connection and the user.
"""
conn = Attribute("The SSHConnection object for this user.")
def lookupChannel(channelType, windowSize, maxPacket, data):
"""
The other side requested a channel of some sort.
C{channelType} is the type of channel being requested,
as an ssh connection protocol channel type.
C{data} is any other packet data (often nothing).
We return a subclass of L{SSHChannel<ssh.channel.SSHChannel>}. If
the channel type is unknown, we return C{None}.
For other failures, we raise an exception. If a
L{ConchError<error.ConchError>} is raised, the C{.value} will
be the message, and the C{.data} will be the error code.
@param channelType: The requested channel type
@type channelType: L{bytes}
@param windowSize: The initial size of the remote window
@type windowSize: L{int}
@param maxPacket: The largest packet we should send
@type maxPacket: L{int}
@param data: Additional request data
@type data: L{bytes}
@rtype: a subclass of L{SSHChannel} or L{None}
"""
def lookupSubsystem(subsystem, data):
"""
The other side requested a subsystem.
We return a L{Protocol} implementing the requested subsystem.
If the subsystem is not available, we return C{None}.
@param subsystem: The name of the subsystem being requested
@type subsystem: L{bytes}
@param data: Additional request data (often nothing)
@type data: L{bytes}
@rtype: L{Protocol} or L{None}
"""
def gotGlobalRequest(requestType, data):
"""
A global request was sent from the other side.
We return a true value on success or a false value on failure.
If we indicate success by returning a tuple, its second item
will be sent to the other side as additional response data.
@param requestType: The type of the request
@type requestType: L{bytes}
@param data: Additional request data
@type data: L{bytes}
@rtype: boolean or L{tuple}
"""
class ISession(Interface):
def getPty(term, windowSize, modes):
"""
Get a pseudo-terminal for use by a shell or command.
If a pseudo-terminal is not available, or the request otherwise
fails, raise an exception.
"""
def openShell(proto):
"""
Open a shell and connect it to proto.
@param proto: a L{ProcessProtocol} instance.
"""
def execCommand(proto, command):
"""
Execute a command.
@param proto: a L{ProcessProtocol} instance.
"""
def windowChanged(newWindowSize):
"""
Called when the size of the remote screen has changed.
"""
def eofReceived():
"""
Called when the other side has indicated no more data will be sent.
"""
def closed():
"""
Called when the session is closed.
"""
class EnvironmentVariableNotPermitted(ValueError):
"""Setting this environment variable in this session is not permitted."""
class ISessionSetEnv(Interface):
"""A session that can set environment variables."""
def setEnv(name, value):
"""
Set an environment variable for the shell or command to be started.
From U{RFC 4254, section 6.4
<https://tools.ietf.org/html/rfc4254#section-6.4>}: "Uncontrolled
setting of environment variables in a privileged process can be a
security hazard. It is recommended that implementations either
maintain a list of allowable variable names or only set environment
variables after the server process has dropped sufficient
privileges."
(OpenSSH refuses all environment variables by default, but has an
C{AcceptEnv} configuration option to select specific variables to
accept.)
@param name: The name of the environment variable to set.
@type name: L{bytes}
@param value: The value of the environment variable to set.
@type value: L{bytes}
@raise EnvironmentVariableNotPermitted: if setting this environment
variable is not permitted.
"""
class ISFTPServer(Interface):
"""
SFTP subsystem for server-side communication.
Each method should check to verify that the user has permission for
their actions.
"""
avatar = Attribute(
"""
The avatar returned by the Realm that we are authenticated with,
and represents the logged-in user.
"""
)
def gotVersion(otherVersion, extData):
"""
Called when the client sends their version info.
otherVersion is an integer representing the version of the SFTP
protocol they are claiming.
extData is a dictionary of extended_name : extended_data items.
These items are sent by the client to indicate additional features.
This method should return a dictionary of extended_name : extended_data
items. These items are the additional features (if any) supported
by the server.
"""
return {}
def openFile(filename, flags, attrs):
"""
Called when the clients asks to open a file.
@param filename: a string representing the file to open.
@param flags: an integer of the flags to open the file with, ORed
together. The flags and their values are listed at the bottom of
L{twisted.conch.ssh.filetransfer} as FXF_*.
@param attrs: a list of attributes to open the file with. It is a
dictionary, consisting of 0 or more keys. The possible keys are::
size: the size of the file in bytes
uid: the user ID of the file as an integer
gid: the group ID of the file as an integer
permissions: the permissions of the file with as an integer.
the bit representation of this field is defined by POSIX.
atime: the access time of the file as seconds since the epoch.
mtime: the modification time of the file as seconds since the epoch.
ext_*: extended attributes. The server is not required to
understand this, but it may.
NOTE: there is no way to indicate text or binary files. it is up
to the SFTP client to deal with this.
This method returns an object that meets the ISFTPFile interface.
Alternatively, it can return a L{Deferred} that will be called back
with the object.
"""
def removeFile(filename):
"""
Remove the given file.
This method returns when the remove succeeds, or a Deferred that is
called back when it succeeds.
@param filename: the name of the file as a string.
"""
def renameFile(oldpath, newpath):
"""
Rename the given file.
This method returns when the rename succeeds, or a L{Deferred} that is
called back when it succeeds. If the rename fails, C{renameFile} will
raise an implementation-dependent exception.
@param oldpath: the current location of the file.
@param newpath: the new file name.
"""
def makeDirectory(path, attrs):
"""
Make a directory.
This method returns when the directory is created, or a Deferred that
is called back when it is created.
@param path: the name of the directory to create as a string.
@param attrs: a dictionary of attributes to create the directory with.
Its meaning is the same as the attrs in the L{openFile} method.
"""
def removeDirectory(path):
"""
Remove a directory (non-recursively)
It is an error to remove a directory that has files or directories in
it.
This method returns when the directory is removed, or a Deferred that
is called back when it is removed.
@param path: the directory to remove.
"""
def openDirectory(path):
"""
Open a directory for scanning.
This method returns an iterable object that has a close() method,
or a Deferred that is called back with same.
The close() method is called when the client is finished reading
from the directory. At this point, the iterable will no longer
be used.
The iterable should return triples of the form (filename,
longname, attrs) or Deferreds that return the same. The
sequence must support __getitem__, but otherwise may be any
'sequence-like' object.
filename is the name of the file relative to the directory.
logname is an expanded format of the filename. The recommended format
is:
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
1234567890 123 12345678 12345678 12345678 123456789012
The first line is sample output, the second is the length of the field.
The fields are: permissions, link count, user owner, group owner,
size in bytes, modification time.
attrs is a dictionary in the format of the attrs argument to openFile.
@param path: the directory to open.
"""
def getAttrs(path, followLinks):
"""
Return the attributes for the given path.
This method returns a dictionary in the same format as the attrs
argument to openFile or a Deferred that is called back with same.
@param path: the path to return attributes for as a string.
@param followLinks: a boolean. If it is True, follow symbolic links
and return attributes for the real path at the base. If it is False,
return attributes for the specified path.
"""
def setAttrs(path, attrs):
"""
Set the attributes for the path.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param path: the path to set attributes for as a string.
@param attrs: a dictionary in the same format as the attrs argument to
L{openFile}.
"""
def readLink(path):
"""
Find the root of a set of symbolic links.
This method returns the target of the link, or a Deferred that
returns the same.
@param path: the path of the symlink to read.
"""
def makeLink(linkPath, targetPath):
"""
Create a symbolic link.
This method returns when the link is made, or a Deferred that
returns the same.
@param linkPath: the pathname of the symlink as a string.
@param targetPath: the path of the target of the link as a string.
"""
def realPath(path):
"""
Convert any path to an absolute path.
This method returns the absolute path as a string, or a Deferred
that returns the same.
@param path: the path to convert as a string.
"""
def extendedRequest(extendedName, extendedData):
"""
This is the extension mechanism for SFTP. The other side can send us
arbitrary requests.
If we don't implement the request given by extendedName, raise
NotImplementedError.
The return value is a string, or a Deferred that will be called
back with a string.
@param extendedName: the name of the request as a string.
@param extendedData: the data the other side sent with the request,
as a string.
"""
class IKnownHostEntry(Interface):
"""
A L{IKnownHostEntry} is an entry in an OpenSSH-formatted C{known_hosts}
file.
@since: 8.2
"""
def matchesKey(key: Key) -> bool:
"""
Return True if this entry matches the given Key object, False
otherwise.
@param key: The key object to match against.
"""
def matchesHost(hostname: bytes) -> bool:
"""
Return True if this entry matches the given hostname, False otherwise.
Note that this does no name resolution; if you want to match an IP
address, you have to resolve it yourself, and pass it in as a dotted
quad string.
@param hostname: The hostname to match against.
"""
def toString() -> bytes:
"""
@return: a serialized string representation of this entry, suitable for
inclusion in a known_hosts file. (Newline not included.)
"""
class ISFTPFile(Interface):
"""
This represents an open file on the server. An object adhering to this
interface should be returned from L{openFile}().
"""
def close():
"""
Close the file.
This method returns nothing if the close succeeds immediately, or a
Deferred that is called back when the close succeeds.
"""
def readChunk(offset, length):
"""
Read from the file.
If EOF is reached before any data is read, raise EOFError.
This method returns the data as a string, or a Deferred that is
called back with same.
@param offset: an integer that is the index to start from in the file.
@param length: the maximum length of data to return. The actual amount
returned may less than this. For normal disk files, however,
this should read the requested number (up to the end of the file).
"""
def writeChunk(offset, data):
"""
Write to the file.
This method returns when the write completes, or a Deferred that is
called when it completes.
@param offset: an integer that is the index to start from in the file.
@param data: a string that is the data to write.
"""
def getAttrs():
"""
Return the attributes for the file.
This method returns a dictionary in the same format as the attrs
argument to L{openFile} or a L{Deferred} that is called back with same.
"""
def setAttrs(attrs):
"""
Set the attributes for the file.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param attrs: a dictionary in the same format as the attrs argument to
L{openFile}.
"""

View File

@@ -0,0 +1,104 @@
# -*- test-case-name: twisted.conch.test.test_cftp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import array
import stat
from time import localtime, strftime, time
# Locale-independent month names to use instead of strftime's
_MONTH_NAMES = dict(
list(
zip(
list(range(1, 13)),
"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(),
)
)
)
def lsLine(name, s):
"""
Build an 'ls' line for a file ('file' in its generic sense, it
can be of any type).
"""
mode = s.st_mode
perms = array.array("B", b"-" * 10)
ft = stat.S_IFMT(mode)
if stat.S_ISDIR(ft):
perms[0] = ord("d")
elif stat.S_ISCHR(ft):
perms[0] = ord("c")
elif stat.S_ISBLK(ft):
perms[0] = ord("b")
elif stat.S_ISREG(ft):
perms[0] = ord("-")
elif stat.S_ISFIFO(ft):
perms[0] = ord("f")
elif stat.S_ISLNK(ft):
perms[0] = ord("l")
elif stat.S_ISSOCK(ft):
perms[0] = ord("s")
else:
perms[0] = ord("!")
# User
if mode & stat.S_IRUSR:
perms[1] = ord("r")
if mode & stat.S_IWUSR:
perms[2] = ord("w")
if mode & stat.S_IXUSR:
perms[3] = ord("x")
# Group
if mode & stat.S_IRGRP:
perms[4] = ord("r")
if mode & stat.S_IWGRP:
perms[5] = ord("w")
if mode & stat.S_IXGRP:
perms[6] = ord("x")
# Other
if mode & stat.S_IROTH:
perms[7] = ord("r")
if mode & stat.S_IWOTH:
perms[8] = ord("w")
if mode & stat.S_IXOTH:
perms[9] = ord("x")
# Suid/sgid
if mode & stat.S_ISUID:
if perms[3] == ord("x"):
perms[3] = ord("s")
else:
perms[3] = ord("S")
if mode & stat.S_ISGID:
if perms[6] == ord("x"):
perms[6] = ord("s")
else:
perms[6] = ord("S")
if isinstance(name, bytes):
name = name.decode("utf-8")
lsPerms = perms.tobytes()
lsPerms = lsPerms.decode("utf-8")
lsresult = [
lsPerms,
str(s.st_nlink).rjust(5),
" ",
str(s.st_uid).ljust(9),
str(s.st_gid).ljust(9),
str(s.st_size).rjust(8),
" ",
]
# Need to specify the month manually, as strftime depends on locale
ttup = localtime(s.st_mtime)
sixmonths = 60 * 60 * 24 * 7 * 26
if s.st_mtime + sixmonths < time(): # Last edited more than 6mo ago
strtime = strftime("%%s %d %Y ", ttup)
else:
strtime = strftime("%%s %d %H:%M ", ttup)
lsresult.append(strtime % (_MONTH_NAMES[ttup[1]],))
lsresult.append(name)
return "".join(lsresult)
__all__ = ["lsLine"]

View File

@@ -0,0 +1,392 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Line-input oriented interactive interpreter loop.
Provides classes for handling Python source input and arbitrary output
interactively from a Twisted application. Also included is syntax coloring
code with support for VT102 terminals, control code handling (^C, ^D, ^Q),
and reasonable handling of Deferreds.
@author: Jp Calderone
"""
import code
import sys
import tokenize
from io import BytesIO
from traceback import format_exception
from types import TracebackType
from typing import Type
from twisted.conch import recvline
from twisted.internet import defer
from twisted.python.htmlizer import TokenPrinter
from twisted.python.monkey import MonkeyPatcher
class FileWrapper:
"""
Minimal write-file-like object.
Writes are translated into addOutput calls on an object passed to
__init__. Newlines are also converted from network to local style.
"""
softspace = 0
state = "normal"
def __init__(self, o):
self.o = o
def flush(self):
pass
def write(self, data):
self.o.addOutput(data.replace("\r\n", "\n"))
def writelines(self, lines):
self.write("".join(lines))
class ManholeInterpreter(code.InteractiveInterpreter):
"""
Interactive Interpreter with special output and Deferred support.
Aside from the features provided by L{code.InteractiveInterpreter}, this
class captures sys.stdout output and redirects it to the appropriate
location (the Manhole protocol instance). It also treats Deferreds
which reach the top-level specially: each is formatted to the user with
a unique identifier and a new callback and errback added to it, each of
which will format the unique identifier and the result with which the
Deferred fires and then pass it on to the next participant in the
callback chain.
"""
numDeferreds = 0
def __init__(self, handler, locals=None, filename="<console>"):
code.InteractiveInterpreter.__init__(self, locals)
self._pendingDeferreds = {}
self.handler = handler
self.filename = filename
self.resetBuffer()
self.monkeyPatcher = MonkeyPatcher()
self.monkeyPatcher.addPatch(sys, "displayhook", self.displayhook)
self.monkeyPatcher.addPatch(sys, "excepthook", self.excepthook)
self.monkeyPatcher.addPatch(sys, "stdout", FileWrapper(self.handler))
def resetBuffer(self):
"""
Reset the input buffer.
"""
self.buffer = []
def push(self, line):
"""
Push a line to the interpreter.
The line should not have a trailing newline; it may have
internal newlines. The line is appended to a buffer and the
interpreter's runsource() method is called with the
concatenated contents of the buffer as source. If this
indicates that the command was executed or invalid, the buffer
is reset; otherwise, the command is incomplete, and the buffer
is left as it was after the line was appended. The return
value is 1 if more input is required, 0 if the line was dealt
with in some way (this is the same as runsource()).
@param line: line of text
@type line: L{bytes}
@return: L{bool} from L{code.InteractiveInterpreter.runsource}
"""
self.buffer.append(line)
source = b"\n".join(self.buffer)
source = source.decode("utf-8")
more = self.runsource(source, self.filename)
if not more:
self.resetBuffer()
return more
def runcode(self, *a, **kw):
with self.monkeyPatcher:
code.InteractiveInterpreter.runcode(self, *a, **kw)
def excepthook(
self,
excType: Type[BaseException],
excValue: BaseException,
excTraceback: TracebackType,
) -> None:
"""
Format exception tracebacks and write them to the output handler.
"""
code_obj = excTraceback.tb_frame.f_code
if code_obj.co_filename == code.__file__ and code_obj.co_name == "runcode":
traceback = excTraceback.tb_next
else:
# Workaround for https://github.com/python/cpython/issues/122478,
# present e.g. in Python 3.12.6:
traceback = excTraceback
lines = format_exception(excType, excValue, traceback)
self.write("".join(lines))
def displayhook(self, obj):
self.locals["_"] = obj
if isinstance(obj, defer.Deferred):
# XXX Ick, where is my "hasFired()" interface?
if hasattr(obj, "result"):
self.write(repr(obj))
elif id(obj) in self._pendingDeferreds:
self.write("<Deferred #%d>" % (self._pendingDeferreds[id(obj)][0],))
else:
d = self._pendingDeferreds
k = self.numDeferreds
d[id(obj)] = (k, obj)
self.numDeferreds += 1
obj.addCallbacks(
self._cbDisplayDeferred,
self._ebDisplayDeferred,
callbackArgs=(k, obj),
errbackArgs=(k, obj),
)
self.write("<Deferred #%d>" % (k,))
elif obj is not None:
self.write(repr(obj))
def _cbDisplayDeferred(self, result, k, obj):
self.write("Deferred #%d called back: %r" % (k, result), True)
del self._pendingDeferreds[id(obj)]
return result
def _ebDisplayDeferred(self, failure, k, obj):
self.write("Deferred #%d failed: %r" % (k, failure.getErrorMessage()), True)
del self._pendingDeferreds[id(obj)]
return failure
def write(self, data, isAsync=None):
self.handler.addOutput(data, isAsync)
CTRL_C = b"\x03"
CTRL_D = b"\x04"
CTRL_BACKSLASH = b"\x1c"
CTRL_L = b"\x0c"
CTRL_A = b"\x01"
CTRL_E = b"\x05"
class Manhole(recvline.HistoricRecvLine):
r"""
Mediator between a fancy line source and an interactive interpreter.
This accepts lines from its transport and passes them on to a
L{ManholeInterpreter}. Control commands (^C, ^D, ^\) are also handled
with something approximating their normal terminal-mode behavior. It
can optionally be constructed with a dict which will be used as the
local namespace for any code executed.
"""
namespace = None
def __init__(self, namespace=None):
recvline.HistoricRecvLine.__init__(self)
if namespace is not None:
self.namespace = namespace.copy()
def connectionMade(self):
recvline.HistoricRecvLine.connectionMade(self)
self.interpreter = ManholeInterpreter(self, self.namespace)
self.keyHandlers[CTRL_C] = self.handle_INT
self.keyHandlers[CTRL_D] = self.handle_EOF
self.keyHandlers[CTRL_L] = self.handle_FF
self.keyHandlers[CTRL_A] = self.handle_HOME
self.keyHandlers[CTRL_E] = self.handle_END
self.keyHandlers[CTRL_BACKSLASH] = self.handle_QUIT
def handle_INT(self):
"""
Handle ^C as an interrupt keystroke by resetting the current input
variables to their initial state.
"""
self.pn = 0
self.lineBuffer = []
self.lineBufferIndex = 0
self.interpreter.resetBuffer()
self.terminal.nextLine()
self.terminal.write(b"KeyboardInterrupt")
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
def handle_EOF(self):
if self.lineBuffer:
self.terminal.write(b"\a")
else:
self.handle_QUIT()
def handle_FF(self):
"""
Handle a 'form feed' byte - generally used to request a screen
refresh/redraw.
"""
self.terminal.eraseDisplay()
self.terminal.cursorHome()
self.drawInputLine()
def handle_QUIT(self):
self.terminal.loseConnection()
def _needsNewline(self):
w = self.terminal.lastWrite
return not w.endswith(b"\n") and not w.endswith(b"\x1bE")
def addOutput(self, data, isAsync=None):
if isAsync:
self.terminal.eraseLine()
self.terminal.cursorBackward(len(self.lineBuffer) + len(self.ps[self.pn]))
self.terminal.write(data)
if isAsync:
if self._needsNewline():
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
if self.lineBuffer:
oldBuffer = self.lineBuffer
self.lineBuffer = []
self.lineBufferIndex = 0
self._deliverBuffer(oldBuffer)
def lineReceived(self, line):
more = self.interpreter.push(line)
self.pn = bool(more)
if self._needsNewline():
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
class VT102Writer:
"""
Colorizer for Python tokens.
A series of tokens are written to instances of this object. Each is
colored in a particular way. The final line of the result of this is
generally added to the output.
"""
typeToColor = {
"identifier": b"\x1b[31m",
"keyword": b"\x1b[32m",
"parameter": b"\x1b[33m",
"variable": b"\x1b[1;33m",
"string": b"\x1b[35m",
"number": b"\x1b[36m",
"op": b"\x1b[37m",
}
normalColor = b"\x1b[0m"
def __init__(self):
self.written = []
def color(self, type):
r = self.typeToColor.get(type, b"")
return r
def write(self, token, type=None):
if token and token != b"\r":
c = self.color(type)
if c:
self.written.append(c)
self.written.append(token)
if c:
self.written.append(self.normalColor)
def __bytes__(self):
s = b"".join(self.written)
return s.strip(b"\n").splitlines()[-1]
def lastColorizedLine(source):
"""
Tokenize and colorize the given Python source.
Returns a VT102-format colorized version of the last line of C{source}.
@param source: Python source code
@type source: L{str} or L{bytes}
@return: L{bytes} of colorized source
"""
if not isinstance(source, bytes):
source = source.encode("utf-8")
w = VT102Writer()
p = TokenPrinter(w.write).printtoken
s = BytesIO(source)
for token in tokenize.tokenize(s.readline):
(tokenType, string, start, end, line) = token
p(tokenType, string, start, end, line)
return bytes(w)
class ColoredManhole(Manhole):
"""
A REPL which syntax colors input as users type it.
"""
def getSource(self):
"""
Return a string containing the currently entered source.
This is only the code which will be considered for execution
next.
"""
return b"\n".join(self.interpreter.buffer) + b"\n" + b"".join(self.lineBuffer)
def characterReceived(self, ch, moreCharactersComing):
if self.mode == "insert":
self.lineBuffer.insert(self.lineBufferIndex, ch)
else:
self.lineBuffer[self.lineBufferIndex : self.lineBufferIndex + 1] = [ch]
self.lineBufferIndex += 1
if moreCharactersComing:
# Skip it all, we'll get called with another character in
# like 2 femtoseconds.
return
if ch == b" ":
# Don't bother to try to color whitespace
self.terminal.write(ch)
return
source = self.getSource()
# Try to write some junk
try:
coloredLine = lastColorizedLine(source)
except tokenize.TokenError:
# We couldn't do it. Strange. Oh well, just add the character.
self.terminal.write(ch)
else:
# Success! Clear the source on this line.
self.terminal.eraseLine()
self.terminal.cursorBackward(
len(self.lineBuffer) + len(self.ps[self.pn]) - 1
)
# And write a new, colorized one.
self.terminal.write(self.ps[self.pn] + coloredLine)
# And move the cursor to where it belongs
n = len(self.lineBuffer) - self.lineBufferIndex
if n:
self.terminal.cursorBackward(n)

View File

@@ -0,0 +1,148 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
insults/SSH integration support.
@author: Jp Calderone
"""
from typing import Dict
from zope.interface import implementer
from twisted.conch import avatar, error as econch, interfaces as iconch
from twisted.conch.insults import insults
from twisted.conch.ssh import factory, session
from twisted.python import components
class _Glue:
"""
A feeble class for making one attribute look like another.
This should be replaced with a real class at some point, probably.
Try not to write new code that uses it.
"""
def __init__(self, **kw):
self.__dict__.update(kw)
def __getattr__(self, name):
raise AttributeError(self.name, "has no attribute", name)
class TerminalSessionTransport:
def __init__(self, proto, chainedProtocol, avatar, width, height):
self.proto = proto
self.avatar = avatar
self.chainedProtocol = chainedProtocol
protoSession = self.proto.session
self.proto.makeConnection(
_Glue(
write=self.chainedProtocol.dataReceived,
loseConnection=lambda: avatar.conn.sendClose(protoSession),
name="SSH Proto Transport",
)
)
def loseConnection():
self.proto.loseConnection()
self.chainedProtocol.makeConnection(
_Glue(
write=self.proto.write,
loseConnection=loseConnection,
name="Chained Proto Transport",
)
)
# XXX TODO
# chainedProtocol is supposed to be an ITerminalTransport,
# maybe. That means perhaps its terminalProtocol attribute is
# an ITerminalProtocol, it could be. So calling terminalSize
# on that should do the right thing But it'd be nice to clean
# this bit up.
self.chainedProtocol.terminalProtocol.terminalSize(width, height)
@implementer(iconch.ISession)
class TerminalSession(components.Adapter):
transportFactory = TerminalSessionTransport
chainedProtocolFactory = insults.ServerProtocol
def getPty(self, term, windowSize, attrs):
self.height, self.width = windowSize[:2]
def openShell(self, proto):
self.transportFactory(
proto,
self.chainedProtocolFactory(),
iconch.IConchUser(self.original),
self.width,
self.height,
)
def execCommand(self, proto, cmd):
raise econch.ConchError("Cannot execute commands")
def windowChanged(self, newWindowSize):
# ISession.windowChanged
raise NotImplementedError("Unimplemented: TerminalSession.windowChanged")
def eofReceived(self):
# ISession.eofReceived
raise NotImplementedError("Unimplemented: TerminalSession.eofReceived")
def closed(self):
# ISession.closed
pass
class TerminalUser(avatar.ConchUser, components.Adapter):
def __init__(self, original, avatarId):
components.Adapter.__init__(self, original)
avatar.ConchUser.__init__(self)
self.channelLookup[b"session"] = session.SSHSession
class TerminalRealm:
userFactory = TerminalUser
sessionFactory = TerminalSession
transportFactory = TerminalSessionTransport
chainedProtocolFactory = insults.ServerProtocol
def _getAvatar(self, avatarId):
comp = components.Componentized()
user = self.userFactory(comp, avatarId)
sess = self.sessionFactory(comp)
sess.transportFactory = self.transportFactory
sess.chainedProtocolFactory = self.chainedProtocolFactory
comp.setComponent(iconch.IConchUser, user)
comp.setComponent(iconch.ISession, sess)
return user
def __init__(self, transportFactory=None):
if transportFactory is not None:
self.transportFactory = transportFactory
def requestAvatar(self, avatarId, mind, *interfaces):
for i in interfaces:
if i is iconch.IConchUser:
return (iconch.IConchUser, self._getAvatar(avatarId), lambda: None)
raise NotImplementedError()
class ConchFactory(factory.SSHFactory):
publicKeys: Dict[bytes, bytes] = {}
privateKeys: Dict[bytes, bytes] = {}
def __init__(self, portal):
self.portal = portal

View File

@@ -0,0 +1,180 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
TAP plugin for creating telnet- and ssh-accessible manhole servers.
@author: Jp Calderone
"""
from zope.interface import implementer
from twisted.application import service, strports
from twisted.conch import manhole, manhole_ssh, telnet
from twisted.conch.insults import insults
from twisted.conch.ssh import keys
from twisted.cred import checkers, portal
from twisted.internet import protocol
from twisted.python import filepath, usage
class makeTelnetProtocol:
def __init__(self, portal):
self.portal = portal
def __call__(self):
auth = telnet.AuthenticatingTelnetProtocol
args = (self.portal,)
return telnet.TelnetTransport(auth, *args)
class chainedProtocolFactory:
def __init__(self, namespace):
self.namespace = namespace
def __call__(self):
return insults.ServerProtocol(manhole.ColoredManhole, self.namespace)
@implementer(portal.IRealm)
class _StupidRealm:
def __init__(self, proto, *a, **kw):
self.protocolFactory = proto
self.protocolArgs = a
self.protocolKwArgs = kw
def requestAvatar(self, avatarId, *interfaces):
if telnet.ITelnetProtocol in interfaces:
return (
telnet.ITelnetProtocol,
self.protocolFactory(*self.protocolArgs, **self.protocolKwArgs),
lambda: None,
)
raise NotImplementedError()
class Options(usage.Options):
optParameters = [
[
"telnetPort",
"t",
None,
(
"strports description of the address on which to listen for telnet "
"connections"
),
],
[
"sshPort",
"s",
None,
(
"strports description of the address on which to listen for ssh "
"connections"
),
],
[
"passwd",
"p",
"/etc/passwd",
"name of a passwd(5)-format username/password file",
],
[
"sshKeyDir",
None,
"<USER DATA DIR>",
"Directory where the autogenerated SSH key is kept.",
],
["sshKeyName", None, "server.key", "Filename of the autogenerated SSH key."],
["sshKeySize", None, 4096, "Size of the automatically generated SSH key."],
]
def __init__(self):
usage.Options.__init__(self)
self["namespace"] = None
def postOptions(self):
if self["telnetPort"] is None and self["sshPort"] is None:
raise usage.UsageError(
"At least one of --telnetPort and --sshPort must be specified"
)
def makeService(options):
"""
Create a manhole server service.
@type options: L{dict}
@param options: A mapping describing the configuration of
the desired service. Recognized key/value pairs are::
"telnetPort": strports description of the address on which
to listen for telnet connections. If None,
no telnet service will be started.
"sshPort": strports description of the address on which to
listen for ssh connections. If None, no ssh
service will be started.
"namespace": dictionary containing desired initial locals
for manhole connections. If None, an empty
dictionary will be used.
"passwd": Name of a passwd(5)-format username/password file.
"sshKeyDir": The folder that the SSH server key will be kept in.
"sshKeyName": The filename of the key.
"sshKeySize": The size of the key, in bits. Default is 4096.
@rtype: L{twisted.application.service.IService}
@return: A manhole service.
"""
svc = service.MultiService()
namespace = options["namespace"]
if namespace is None:
namespace = {}
checker = checkers.FilePasswordDB(options["passwd"])
if options["telnetPort"]:
telnetRealm = _StupidRealm(
telnet.TelnetBootstrapProtocol,
insults.ServerProtocol,
manhole.ColoredManhole,
namespace,
)
telnetPortal = portal.Portal(telnetRealm, [checker])
telnetFactory = protocol.ServerFactory()
telnetFactory.protocol = makeTelnetProtocol(telnetPortal)
telnetService = strports.service(options["telnetPort"], telnetFactory)
telnetService.setServiceParent(svc)
if options["sshPort"]:
sshRealm = manhole_ssh.TerminalRealm()
sshRealm.chainedProtocolFactory = chainedProtocolFactory(namespace)
sshPortal = portal.Portal(sshRealm, [checker])
sshFactory = manhole_ssh.ConchFactory(sshPortal)
if options["sshKeyDir"] != "<USER DATA DIR>":
keyDir = options["sshKeyDir"]
else:
from twisted.python._appdirs import getDataDirectory
keyDir = getDataDirectory()
keyLocation = filepath.FilePath(keyDir).child(options["sshKeyName"])
sshKey = keys._getPersistentRSAKey(keyLocation, int(options["sshKeySize"]))
sshFactory.publicKeys[b"ssh-rsa"] = sshKey
sshFactory.privateKeys[b"ssh-rsa"] = sshKey
sshService = strports.service(options["sshPort"], sshFactory)
sshService.setServiceParent(svc)
return svc

View File

@@ -0,0 +1,54 @@
# -*- test-case-name: twisted.conch.test.test_mixin -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Experimental optimization
This module provides a single mixin class which allows protocols to
collapse numerous small writes into a single larger one.
@author: Jp Calderone
"""
from twisted.internet import reactor
class BufferingMixin:
"""
Mixin which adds write buffering.
"""
_delayedWriteCall = None
data = None
DELAY = 0.0
def schedule(self):
return reactor.callLater(self.DELAY, self.flush)
def reschedule(self, token):
token.reset(self.DELAY)
def write(self, data):
"""
Buffer some bytes to be written soon.
Every call to this function delays the real write by C{self.DELAY}
seconds. When the delay expires, all collected bytes are written
to the underlying transport using L{ITransport.writeSequence}.
"""
if self._delayedWriteCall is None:
self.data = []
self._delayedWriteCall = self.schedule()
else:
self.reschedule(self._delayedWriteCall)
self.data.append(data)
def flush(self):
"""
Flush the buffer immediately.
"""
self._delayedWriteCall = None
self.transport.writeSequence(self.data)
self.data = None

View File

@@ -0,0 +1 @@
!.gitignore

View File

@@ -0,0 +1,10 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Support for OpenSSH configuration files.
Maintainer: Paul Swartz
"""

View File

@@ -0,0 +1,74 @@
# -*- test-case-name: twisted.conch.test.test_openssh_compat -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Factory for reading openssh configuration files: public keys, private keys, and
moduli file.
"""
import errno
import os
from typing import Dict, List, Optional, Tuple
from twisted.conch.openssh_compat import primes
from twisted.conch.ssh import common, factory, keys
from twisted.python.util import runAsEffectiveUser
class OpenSSHFactory(factory.SSHFactory):
dataRoot = "/usr/local/etc"
# For openbsd which puts moduli in a different directory from keys.
moduliRoot = "/usr/local/etc"
def getPublicKeys(self):
"""
Return the server public keys.
"""
ks = {}
for filename in os.listdir(self.dataRoot):
if filename[:9] == "ssh_host_" and filename[-8:] == "_key.pub":
try:
k = keys.Key.fromFile(os.path.join(self.dataRoot, filename))
t = common.getNS(k.blob())[0]
ks[t] = k
except Exception as e:
self._log.error(
"bad public key file {filename}: {error}",
filename=filename,
error=e,
)
return ks
def getPrivateKeys(self):
"""
Return the server private keys.
"""
privateKeys = {}
for filename in os.listdir(self.dataRoot):
if filename[:9] == "ssh_host_" and filename[-4:] == "_key":
fullPath = os.path.join(self.dataRoot, filename)
try:
key = keys.Key.fromFile(fullPath)
except OSError as e:
if e.errno == errno.EACCES:
# Not allowed, let's switch to root
key = runAsEffectiveUser(0, 0, keys.Key.fromFile, fullPath)
privateKeys[key.sshType()] = key
else:
raise
except Exception as e:
self._log.error(
"bad public key file {filename}: {error}",
filename=filename,
error=e,
)
else:
privateKeys[key.sshType()] = key
return privateKeys
def getPrimes(self) -> Optional[Dict[int, List[Tuple[int, int]]]]:
try:
return primes.parseModuliFile(self.moduliRoot + "/moduli")
except OSError:
return None

View File

@@ -0,0 +1,31 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Parsing for the moduli file, which contains Diffie-Hellman prime groups.
Maintainer: Paul Swartz
"""
from typing import Dict, List, Tuple
def parseModuliFile(filename: str) -> Dict[int, List[Tuple[int, int]]]:
with open(filename) as f:
lines = f.readlines()
primes: Dict[int, List[Tuple[int, int]]] = {}
for l in lines:
l = l.strip()
if not l or l[0] == "#":
continue
tim, typ, tst, tri, sizestr, genstr, modstr = l.split()
size = int(sizestr) + 1
gen = int(genstr)
mod = int(modstr, 16)
if size not in primes:
primes[size] = []
primes[size].append((gen, mod))
return primes

View File

@@ -0,0 +1,569 @@
# -*- test-case-name: twisted.conch.test.test_recvline -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Basic line editing support.
@author: Jp Calderone
"""
import string
from typing import Dict
from zope.interface import implementer
from twisted.conch.insults import helper, insults
from twisted.logger import Logger
from twisted.python import reflect
from twisted.python.compat import iterbytes
_counters: Dict[str, int] = {}
class Logging:
"""
Wrapper which logs attribute lookups.
This was useful in debugging something, I guess. I forget what.
It can probably be deleted or moved somewhere more appropriate.
Nothing special going on here, really.
"""
def __init__(self, original):
self.original = original
key = reflect.qual(original.__class__)
count = _counters.get(key, 0)
_counters[key] = count + 1
self._logFile = open(key + "-" + str(count), "w")
def __str__(self) -> str:
return str(super().__getattribute__("original"))
def __repr__(self) -> str:
return repr(super().__getattribute__("original"))
def __getattribute__(self, name):
original = super().__getattribute__("original")
logFile = super().__getattribute__("_logFile")
logFile.write(name + "\n")
return getattr(original, name)
@implementer(insults.ITerminalTransport)
class TransportSequence:
"""
An L{ITerminalTransport} implementation which forwards calls to
one or more other L{ITerminalTransport}s.
This is a cheap way for servers to keep track of the state they
expect the client to see, since all terminal manipulations can be
send to the real client and to a terminal emulator that lives in
the server process.
"""
for keyID in (
b"UP_ARROW",
b"DOWN_ARROW",
b"RIGHT_ARROW",
b"LEFT_ARROW",
b"HOME",
b"INSERT",
b"DELETE",
b"END",
b"PGUP",
b"PGDN",
b"F1",
b"F2",
b"F3",
b"F4",
b"F5",
b"F6",
b"F7",
b"F8",
b"F9",
b"F10",
b"F11",
b"F12",
):
execBytes = keyID + b" = object()"
execStr = execBytes.decode("ascii")
exec(execStr)
TAB = b"\t"
BACKSPACE = b"\x7f"
def __init__(self, *transports):
assert transports, "Cannot construct a TransportSequence with no transports"
self.transports = transports
for method in insults.ITerminalTransport:
exec(
"""\
def %s(self, *a, **kw):
for tpt in self.transports:
result = tpt.%s(*a, **kw)
return result
"""
% (method, method)
)
def getHost(self):
# ITransport.getHost
raise NotImplementedError("Unimplemented: TransportSequence.getHost")
def getPeer(self):
# ITransport.getPeer
raise NotImplementedError("Unimplemented: TransportSequence.getPeer")
def loseConnection(self):
# ITransport.loseConnection
raise NotImplementedError("Unimplemented: TransportSequence.loseConnection")
def write(self, data):
# ITransport.write
raise NotImplementedError("Unimplemented: TransportSequence.write")
def writeSequence(self, data):
# ITransport.writeSequence
raise NotImplementedError("Unimplemented: TransportSequence.writeSequence")
def cursorUp(self, n=1):
# ITerminalTransport.cursorUp
raise NotImplementedError("Unimplemented: TransportSequence.cursorUp")
def cursorDown(self, n=1):
# ITerminalTransport.cursorDown
raise NotImplementedError("Unimplemented: TransportSequence.cursorDown")
def cursorForward(self, n=1):
# ITerminalTransport.cursorForward
raise NotImplementedError("Unimplemented: TransportSequence.cursorForward")
def cursorBackward(self, n=1):
# ITerminalTransport.cursorBackward
raise NotImplementedError("Unimplemented: TransportSequence.cursorBackward")
def cursorPosition(self, column, line):
# ITerminalTransport.cursorPosition
raise NotImplementedError("Unimplemented: TransportSequence.cursorPosition")
def cursorHome(self):
# ITerminalTransport.cursorHome
raise NotImplementedError("Unimplemented: TransportSequence.cursorHome")
def index(self):
# ITerminalTransport.index
raise NotImplementedError("Unimplemented: TransportSequence.index")
def reverseIndex(self):
# ITerminalTransport.reverseIndex
raise NotImplementedError("Unimplemented: TransportSequence.reverseIndex")
def nextLine(self):
# ITerminalTransport.nextLine
raise NotImplementedError("Unimplemented: TransportSequence.nextLine")
def saveCursor(self):
# ITerminalTransport.saveCursor
raise NotImplementedError("Unimplemented: TransportSequence.saveCursor")
def restoreCursor(self):
# ITerminalTransport.restoreCursor
raise NotImplementedError("Unimplemented: TransportSequence.restoreCursor")
def setModes(self, modes):
# ITerminalTransport.setModes
raise NotImplementedError("Unimplemented: TransportSequence.setModes")
def resetModes(self, mode):
# ITerminalTransport.resetModes
raise NotImplementedError("Unimplemented: TransportSequence.resetModes")
def setPrivateModes(self, modes):
# ITerminalTransport.setPrivateModes
raise NotImplementedError("Unimplemented: TransportSequence.setPrivateModes")
def resetPrivateModes(self, modes):
# ITerminalTransport.resetPrivateModes
raise NotImplementedError("Unimplemented: TransportSequence.resetPrivateModes")
def applicationKeypadMode(self):
# ITerminalTransport.applicationKeypadMode
raise NotImplementedError(
"Unimplemented: TransportSequence.applicationKeypadMode"
)
def numericKeypadMode(self):
# ITerminalTransport.numericKeypadMode
raise NotImplementedError("Unimplemented: TransportSequence.numericKeypadMode")
def selectCharacterSet(self, charSet, which):
# ITerminalTransport.selectCharacterSet
raise NotImplementedError("Unimplemented: TransportSequence.selectCharacterSet")
def shiftIn(self):
# ITerminalTransport.shiftIn
raise NotImplementedError("Unimplemented: TransportSequence.shiftIn")
def shiftOut(self):
# ITerminalTransport.shiftOut
raise NotImplementedError("Unimplemented: TransportSequence.shiftOut")
def singleShift2(self):
# ITerminalTransport.singleShift2
raise NotImplementedError("Unimplemented: TransportSequence.singleShift2")
def singleShift3(self):
# ITerminalTransport.singleShift3
raise NotImplementedError("Unimplemented: TransportSequence.singleShift3")
def selectGraphicRendition(self, *attributes):
# ITerminalTransport.selectGraphicRendition
raise NotImplementedError(
"Unimplemented: TransportSequence.selectGraphicRendition"
)
def horizontalTabulationSet(self):
# ITerminalTransport.horizontalTabulationSet
raise NotImplementedError(
"Unimplemented: TransportSequence.horizontalTabulationSet"
)
def tabulationClear(self):
# ITerminalTransport.tabulationClear
raise NotImplementedError("Unimplemented: TransportSequence.tabulationClear")
def tabulationClearAll(self):
# ITerminalTransport.tabulationClearAll
raise NotImplementedError("Unimplemented: TransportSequence.tabulationClearAll")
def doubleHeightLine(self, top=True):
# ITerminalTransport.doubleHeightLine
raise NotImplementedError("Unimplemented: TransportSequence.doubleHeightLine")
def singleWidthLine(self):
# ITerminalTransport.singleWidthLine
raise NotImplementedError("Unimplemented: TransportSequence.singleWidthLine")
def doubleWidthLine(self):
# ITerminalTransport.doubleWidthLine
raise NotImplementedError("Unimplemented: TransportSequence.doubleWidthLine")
def eraseToLineEnd(self):
# ITerminalTransport.eraseToLineEnd
raise NotImplementedError("Unimplemented: TransportSequence.eraseToLineEnd")
def eraseToLineBeginning(self):
# ITerminalTransport.eraseToLineBeginning
raise NotImplementedError(
"Unimplemented: TransportSequence.eraseToLineBeginning"
)
def eraseLine(self):
# ITerminalTransport.eraseLine
raise NotImplementedError("Unimplemented: TransportSequence.eraseLine")
def eraseToDisplayEnd(self):
# ITerminalTransport.eraseToDisplayEnd
raise NotImplementedError("Unimplemented: TransportSequence.eraseToDisplayEnd")
def eraseToDisplayBeginning(self):
# ITerminalTransport.eraseToDisplayBeginning
raise NotImplementedError(
"Unimplemented: TransportSequence.eraseToDisplayBeginning"
)
def eraseDisplay(self):
# ITerminalTransport.eraseDisplay
raise NotImplementedError("Unimplemented: TransportSequence.eraseDisplay")
def deleteCharacter(self, n=1):
# ITerminalTransport.deleteCharacter
raise NotImplementedError("Unimplemented: TransportSequence.deleteCharacter")
def insertLine(self, n=1):
# ITerminalTransport.insertLine
raise NotImplementedError("Unimplemented: TransportSequence.insertLine")
def deleteLine(self, n=1):
# ITerminalTransport.deleteLine
raise NotImplementedError("Unimplemented: TransportSequence.deleteLine")
def reportCursorPosition(self):
# ITerminalTransport.reportCursorPosition
raise NotImplementedError(
"Unimplemented: TransportSequence.reportCursorPosition"
)
def reset(self):
# ITerminalTransport.reset
raise NotImplementedError("Unimplemented: TransportSequence.reset")
def unhandledControlSequence(self, seq):
# ITerminalTransport.unhandledControlSequence
raise NotImplementedError(
"Unimplemented: TransportSequence.unhandledControlSequence"
)
class LocalTerminalBufferMixin:
"""
A mixin for RecvLine subclasses which records the state of the terminal.
This is accomplished by performing all L{ITerminalTransport} operations on both
the transport passed to makeConnection and an instance of helper.TerminalBuffer.
@ivar terminalCopy: A L{helper.TerminalBuffer} instance which efforts
will be made to keep up to date with the actual terminal
associated with this protocol instance.
"""
def makeConnection(self, transport):
self.terminalCopy = helper.TerminalBuffer()
self.terminalCopy.connectionMade()
return super().makeConnection(TransportSequence(transport, self.terminalCopy))
def __str__(self) -> str:
return str(self.terminalCopy)
class RecvLine(insults.TerminalProtocol):
"""
L{TerminalProtocol} which adds line editing features.
Clients will be prompted for lines of input with all the usual
features: character echoing, left and right arrow support for
moving the cursor to different areas of the line buffer, backspace
and delete for removing characters, and insert for toggling
between typeover and insert mode. Tabs will be expanded to enough
spaces to move the cursor to the next tabstop (every four
characters by default). Enter causes the line buffer to be
cleared and the line to be passed to the lineReceived() method
which, by default, does nothing. Subclasses are responsible for
redrawing the input prompt (this will probably change).
"""
width = 80
height = 24
TABSTOP = 4
ps = (b">>> ", b"... ")
pn = 0
_printableChars = string.printable.encode("ascii")
_log = Logger()
def connectionMade(self):
# A list containing the characters making up the current line
self.lineBuffer = []
# A zero-based (wtf else?) index into self.lineBuffer.
# Indicates the current cursor position.
self.lineBufferIndex = 0
t = self.terminal
# A map of keyIDs to bound instance methods.
self.keyHandlers = {
t.LEFT_ARROW: self.handle_LEFT,
t.RIGHT_ARROW: self.handle_RIGHT,
t.TAB: self.handle_TAB,
# Both of these should not be necessary, but figuring out
# which is necessary is a huge hassle.
b"\r": self.handle_RETURN,
b"\n": self.handle_RETURN,
t.BACKSPACE: self.handle_BACKSPACE,
t.DELETE: self.handle_DELETE,
t.INSERT: self.handle_INSERT,
t.HOME: self.handle_HOME,
t.END: self.handle_END,
}
self.initializeScreen()
def initializeScreen(self):
# Hmm, state sucks. Oh well.
# For now we will just take over the whole terminal.
self.terminal.reset()
self.terminal.write(self.ps[self.pn])
# XXX Note: I would prefer to default to starting in insert
# mode, however this does not seem to actually work! I do not
# know why. This is probably of interest to implementors
# subclassing RecvLine.
# XXX XXX Note: But the unit tests all expect the initial mode
# to be insert right now. Fuck, there needs to be a way to
# query the current mode or something.
# self.setTypeoverMode()
self.setInsertMode()
def currentLineBuffer(self):
s = b"".join(self.lineBuffer)
return s[: self.lineBufferIndex], s[self.lineBufferIndex :]
def setInsertMode(self):
self.mode = "insert"
self.terminal.setModes([insults.modes.IRM])
def setTypeoverMode(self):
self.mode = "typeover"
self.terminal.resetModes([insults.modes.IRM])
def drawInputLine(self):
"""
Write a line containing the current input prompt and the current line
buffer at the current cursor position.
"""
self.terminal.write(self.ps[self.pn] + b"".join(self.lineBuffer))
def terminalSize(self, width, height):
# XXX - Clear the previous input line, redraw it at the new
# cursor position
self.terminal.eraseDisplay()
self.terminal.cursorHome()
self.width = width
self.height = height
self.drawInputLine()
def unhandledControlSequence(self, seq):
pass
def keystrokeReceived(self, keyID, modifier):
m = self.keyHandlers.get(keyID)
if m is not None:
m()
elif keyID in self._printableChars:
self.characterReceived(keyID, False)
else:
self._log.warn("Received unhandled keyID: {keyID!r}", keyID=keyID)
def characterReceived(self, ch, moreCharactersComing):
if self.mode == "insert":
self.lineBuffer.insert(self.lineBufferIndex, ch)
else:
self.lineBuffer[self.lineBufferIndex : self.lineBufferIndex + 1] = [ch]
self.lineBufferIndex += 1
self.terminal.write(ch)
def handle_TAB(self):
n = self.TABSTOP - (len(self.lineBuffer) % self.TABSTOP)
self.terminal.cursorForward(n)
self.lineBufferIndex += n
self.lineBuffer.extend(iterbytes(b" " * n))
def handle_LEFT(self):
if self.lineBufferIndex > 0:
self.lineBufferIndex -= 1
self.terminal.cursorBackward()
def handle_RIGHT(self):
if self.lineBufferIndex < len(self.lineBuffer):
self.lineBufferIndex += 1
self.terminal.cursorForward()
def handle_HOME(self):
if self.lineBufferIndex:
self.terminal.cursorBackward(self.lineBufferIndex)
self.lineBufferIndex = 0
def handle_END(self):
offset = len(self.lineBuffer) - self.lineBufferIndex
if offset:
self.terminal.cursorForward(offset)
self.lineBufferIndex = len(self.lineBuffer)
def handle_BACKSPACE(self):
if self.lineBufferIndex > 0:
self.lineBufferIndex -= 1
del self.lineBuffer[self.lineBufferIndex]
self.terminal.cursorBackward()
self.terminal.deleteCharacter()
def handle_DELETE(self):
if self.lineBufferIndex < len(self.lineBuffer):
del self.lineBuffer[self.lineBufferIndex]
self.terminal.deleteCharacter()
def handle_RETURN(self):
line = b"".join(self.lineBuffer)
self.lineBuffer = []
self.lineBufferIndex = 0
self.terminal.nextLine()
self.lineReceived(line)
def handle_INSERT(self):
assert self.mode in ("typeover", "insert")
if self.mode == "typeover":
self.setInsertMode()
else:
self.setTypeoverMode()
def lineReceived(self, line):
pass
class HistoricRecvLine(RecvLine):
"""
L{TerminalProtocol} which adds both basic line-editing features and input history.
Everything supported by L{RecvLine} is also supported by this class. In addition, the
up and down arrows traverse the input history. Each received line is automatically
added to the end of the input history.
"""
def connectionMade(self):
RecvLine.connectionMade(self)
self.historyLines = []
self.historyPosition = 0
t = self.terminal
self.keyHandlers.update(
{t.UP_ARROW: self.handle_UP, t.DOWN_ARROW: self.handle_DOWN}
)
def currentHistoryBuffer(self):
b = tuple(self.historyLines)
return b[: self.historyPosition], b[self.historyPosition :]
def _deliverBuffer(self, buf):
if buf:
for ch in iterbytes(buf[:-1]):
self.characterReceived(ch, True)
self.characterReceived(buf[-1:], False)
def handle_UP(self):
if self.lineBuffer and self.historyPosition == len(self.historyLines):
self.historyLines.append(b"".join(self.lineBuffer))
if self.historyPosition > 0:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition -= 1
self.lineBuffer = []
self._deliverBuffer(self.historyLines[self.historyPosition])
def handle_DOWN(self):
if self.historyPosition < len(self.historyLines) - 1:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition += 1
self.lineBuffer = []
self._deliverBuffer(self.historyLines[self.historyPosition])
else:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition = len(self.historyLines)
self.lineBuffer = []
self.lineBufferIndex = 0
def handle_RETURN(self):
if self.lineBuffer:
self.historyLines.append(b"".join(self.lineBuffer))
self.historyPosition = len(self.historyLines)
return RecvLine.handle_RETURN(self)

View File

@@ -0,0 +1 @@
"conch scripts"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,400 @@
# -*- test-case-name: twisted.conch.test.test_ckeygen -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the `ckeygen` command.
"""
from __future__ import annotations
import getpass
import os
import platform
import socket
import sys
from collections.abc import Callable
from functools import wraps
from importlib import reload
from typing import Any, Dict, Optional
from twisted.conch.ssh import keys
from twisted.python import failure, filepath, log, usage
if getpass.getpass == getpass.unix_getpass: # type: ignore[attr-defined]
try:
import termios # hack around broken termios
termios.tcgetattr, termios.tcsetattr
except (ImportError, AttributeError):
sys.modules["termios"] = None # type: ignore[assignment]
reload(getpass)
supportedKeyTypes = dict()
def _keyGenerator(keyType):
def assignkeygenerator(keygenerator):
@wraps(keygenerator)
def wrapper(*args, **kwargs):
return keygenerator(*args, **kwargs)
supportedKeyTypes[keyType] = wrapper
return wrapper
return assignkeygenerator
class GeneralOptions(usage.Options):
synopsis = """Usage: ckeygen [options]
"""
longdesc = "ckeygen manipulates public/private keys in various ways."
optParameters = [
["bits", "b", None, "Number of bits in the key to create."],
["filename", "f", None, "Filename of the key file."],
["type", "t", None, "Specify type of key to create."],
["comment", "C", None, "Provide new comment."],
["newpass", "N", None, "Provide new passphrase."],
["pass", "P", None, "Provide old passphrase."],
["format", "o", "sha256-base64", "Fingerprint format of key file."],
[
"private-key-subtype",
None,
None,
'OpenSSH private key subtype to write ("PEM" or "v1").',
],
]
optFlags = [
["fingerprint", "l", "Show fingerprint of key file."],
["changepass", "p", "Change passphrase of private key file."],
["quiet", "q", "Quiet."],
["no-passphrase", None, "Create the key with no passphrase."],
["showpub", "y", "Read private key file and print public key."],
]
compData = usage.Completions(
optActions={
"type": usage.CompleteList(list(supportedKeyTypes.keys())),
"private-key-subtype": usage.CompleteList(["PEM", "v1"]),
}
)
def run():
options = GeneralOptions()
try:
options.parseOptions(sys.argv[1:])
except usage.UsageError as u:
print("ERROR: %s" % u)
options.opt_help()
sys.exit(1)
log.discardLogs()
log.deferr = handleError # HACK
if options["type"]:
if options["type"].lower() in supportedKeyTypes:
print("Generating public/private %s key pair." % (options["type"]))
supportedKeyTypes[options["type"].lower()](options)
else:
sys.exit(
"Key type was %s, must be one of %s"
% (options["type"], ", ".join(supportedKeyTypes.keys()))
)
elif options["fingerprint"]:
printFingerprint(options)
elif options["changepass"]:
changePassPhrase(options)
elif options["showpub"]:
displayPublicKey(options)
else:
options.opt_help()
sys.exit(1)
def enumrepresentation(options):
if options["format"] == "md5-hex":
options["format"] = keys.FingerprintFormats.MD5_HEX
return options
elif options["format"] == "sha256-base64":
options["format"] = keys.FingerprintFormats.SHA256_BASE64
return options
else:
raise keys.BadFingerPrintFormat(
f"Unsupported fingerprint format: {options['format']}"
)
def handleError():
global exitStatus
exitStatus = 2
log.err(failure.Failure())
raise
@_keyGenerator("rsa")
def generateRSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
if not options["bits"]:
options["bits"] = 2048
keyPrimitive = rsa.generate_private_key(
key_size=int(options["bits"]),
public_exponent=65537,
backend=default_backend(),
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
@_keyGenerator("dsa")
def generateDSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import dsa
if not options["bits"]:
options["bits"] = 1024
keyPrimitive = dsa.generate_private_key(
key_size=int(options["bits"]),
backend=default_backend(),
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
@_keyGenerator("ecdsa")
def generateECDSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
if not options["bits"]:
options["bits"] = 256
# OpenSSH supports only mandatory sections of RFC5656.
# See https://www.openssh.com/txt/release-5.7
curve = b"ecdsa-sha2-nistp" + str(options["bits"]).encode("ascii")
keyPrimitive = ec.generate_private_key(
curve=keys._curveTable[curve], backend=default_backend()
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
@_keyGenerator("ed25519")
def generateEd25519key(options):
keyPrimitive = keys.Ed25519PrivateKey.generate()
key = keys.Key(keyPrimitive)
_saveKey(key, options)
def _defaultPrivateKeySubtype(keyType):
"""
Return a reasonable default private key subtype for a given key type.
@type keyType: L{str}
@param keyType: A key type, as returned by
L{twisted.conch.ssh.keys.Key.type}.
@rtype: L{str}
@return: A private OpenSSH key subtype (C{'PEM'} or C{'v1'}).
"""
if keyType == "Ed25519":
# No PEM format is defined for Ed25519 keys.
return "v1"
else:
return "PEM"
def _getKeyOrDefault(
options: Dict[Any, Any],
inputCollector: Optional[Callable[[str], str]] = None,
keyTypeName: str = "rsa",
) -> str:
"""
If C{options["filename"]} is None, prompt the user to enter a path
or attempt to set it to .ssh/id_rsa
@param options: command line options
@param inputCollector: dependency injection for testing
@param keyTypeName: key type or "rsa"
"""
if inputCollector is None:
inputCollector = input
filename = options["filename"]
if not filename:
filename = os.path.expanduser(f"~/.ssh/id_{keyTypeName}")
if platform.system() == "Windows":
filename = os.path.expanduser(Rf"%HOMEPATH %\.ssh\id_{keyTypeName}")
filename = (
inputCollector("Enter file in which the key is (%s): " % filename)
or filename
)
return str(filename)
def printFingerprint(options: Dict[Any, Any]) -> None:
filename = _getKeyOrDefault(options)
if os.path.exists(filename + ".pub"):
filename += ".pub"
options = enumrepresentation(options)
try:
key = keys.Key.fromFile(filename)
print(
"%s %s %s"
% (
key.size(),
key.fingerprint(options["format"]),
os.path.basename(filename),
)
)
except keys.BadKeyError:
sys.exit("bad key")
except FileNotFoundError:
sys.exit(f"{filename} could not be opened, please specify a file.")
def changePassPhrase(options):
filename = _getKeyOrDefault(options)
try:
key = keys.Key.fromFile(filename)
except keys.EncryptedKeyError:
# Raised if password not supplied for an encrypted key
if not options.get("pass"):
options["pass"] = getpass.getpass("Enter old passphrase: ")
try:
key = keys.Key.fromFile(filename, passphrase=options["pass"])
except keys.BadKeyError:
sys.exit("Could not change passphrase: old passphrase error")
except keys.EncryptedKeyError as e:
sys.exit(f"Could not change passphrase: {e}")
except keys.BadKeyError as e:
sys.exit(f"Could not change passphrase: {e}")
except FileNotFoundError:
sys.exit(f"{filename} could not be opened, please specify a file.")
if not options.get("newpass"):
while 1:
p1 = getpass.getpass("Enter new passphrase (empty for no passphrase): ")
p2 = getpass.getpass("Enter same passphrase again: ")
if p1 == p2:
break
print("Passphrases do not match. Try again.")
options["newpass"] = p1
if options.get("private-key-subtype") is None:
options["private-key-subtype"] = _defaultPrivateKeySubtype(key.type())
try:
newkeydata = key.toString(
"openssh",
subtype=options["private-key-subtype"],
passphrase=options["newpass"],
)
except Exception as e:
sys.exit(f"Could not change passphrase: {e}")
try:
keys.Key.fromString(newkeydata, passphrase=options["newpass"])
except (keys.EncryptedKeyError, keys.BadKeyError) as e:
sys.exit(f"Could not change passphrase: {e}")
with open(filename, "wb") as fd:
fd.write(newkeydata)
print("Your identification has been saved with the new passphrase.")
def displayPublicKey(options):
filename = _getKeyOrDefault(options)
try:
key = keys.Key.fromFile(filename)
except FileNotFoundError:
sys.exit(f"{filename} could not be opened, please specify a file.")
except keys.EncryptedKeyError:
if not options.get("pass"):
options["pass"] = getpass.getpass("Enter passphrase: ")
key = keys.Key.fromFile(filename, passphrase=options["pass"])
displayKey = key.public().toString("openssh").decode("ascii")
print(displayKey)
def _inputSaveFile(prompt: str) -> str:
"""
Ask the user where to save the key.
This needs to be a separate function so the unit test can patch it.
"""
return input(prompt)
def _saveKey(
key: keys.Key,
options: Dict[Any, Any],
inputCollector: Optional[Callable[[str], str]] = None,
) -> None:
"""
Persist a SSH key on local filesystem.
@param key: Key which is persisted on local filesystem.
@param options:
@param inputCollector: Dependency injection for testing.
"""
if inputCollector is None:
inputCollector = input
KeyTypeMapping = {"EC": "ecdsa", "Ed25519": "ed25519", "RSA": "rsa", "DSA": "dsa"}
keyTypeName = KeyTypeMapping[key.type()]
filename = options["filename"]
if not filename:
defaultPath = _getKeyOrDefault(options, inputCollector, keyTypeName)
newPath = _inputSaveFile(
f"Enter file in which to save the key ({defaultPath}): "
)
filename = newPath.strip() or defaultPath
if os.path.exists(filename):
print(f"{filename} already exists.")
yn = inputCollector("Overwrite (y/n)? ")
if yn[0].lower() != "y":
sys.exit()
if options.get("no-passphrase"):
options["pass"] = b""
elif not options["pass"]:
while 1:
p1 = getpass.getpass("Enter passphrase (empty for no passphrase): ")
p2 = getpass.getpass("Enter same passphrase again: ")
if p1 == p2:
break
print("Passphrases do not match. Try again.")
options["pass"] = p1
if options.get("private-key-subtype") is None:
options["private-key-subtype"] = _defaultPrivateKeySubtype(key.type())
comment = f"{getpass.getuser()}@{socket.gethostname()}"
fp = filepath.FilePath(filename)
fp.setContent(
key.toString(
"openssh",
subtype=options["private-key-subtype"],
passphrase=options["pass"],
)
)
fp.chmod(0o100600)
filepath.FilePath(filename + ".pub").setContent(
key.public().toString("openssh", comment=comment)
)
options = enumrepresentation(options)
print(f"Your identification has been saved in {filename}")
print(f"Your public key has been saved in {filename}.pub")
print(f"The key fingerprint in {options['format']} is:")
print(key.fingerprint(options["format"]))
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,578 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
# $Id: conch.py,v 1.65 2004/03/11 00:29:14 z3p Exp $
# Implementation module for the `conch` command.
#
import fcntl
import getpass
import os
import signal
import struct
import sys
import tty
from typing import List, Tuple
from twisted.conch.client import connect, default
from twisted.conch.client.options import ConchOptions
from twisted.conch.error import ConchError
from twisted.conch.ssh import channel, common, connection, forwarding, session
from twisted.internet import reactor, stdio, task
from twisted.python import log, usage
from twisted.python.compat import ioType, networkString
class ClientOptions(ConchOptions):
synopsis = """Usage: conch [options] host [command]
"""
longdesc = (
"conch is a SSHv2 client that allows logging into a remote "
"machine and executing commands."
)
optParameters = [
["escape", "e", "~"],
[
"localforward",
"L",
None,
"listen-port:host:port Forward local port to remote address",
],
[
"remoteforward",
"R",
None,
"listen-port:host:port Forward remote port to local address",
],
]
optFlags = [
["null", "n", "Redirect input from /dev/null."],
["fork", "f", "Fork to background after authentication."],
["tty", "t", "Tty; allocate a tty even if command is given."],
["notty", "T", "Do not allocate a tty."],
["noshell", "N", "Do not execute a shell or command."],
["subsystem", "s", "Invoke command (mandatory) as SSH2 subsystem."],
]
compData = usage.Completions(
mutuallyExclusive=[("tty", "notty")],
optActions={
"localforward": usage.Completer(descr="listen-port:host:port"),
"remoteforward": usage.Completer(descr="listen-port:host:port"),
},
extraActions=[
usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True),
],
)
localForwards: List[Tuple[int, Tuple[int, int]]] = []
remoteForwards: List[Tuple[int, Tuple[int, int]]] = []
def opt_escape(self, esc):
"""
Set escape character; ``none'' = disable
"""
if esc == "none":
self["escape"] = None
elif esc[0] == "^" and len(esc) == 2:
self["escape"] = chr(ord(esc[1]) - 64)
elif len(esc) == 1:
self["escape"] = esc
else:
sys.exit(f"Bad escape character '{esc}'.")
def opt_localforward(self, f):
"""
Forward local port to remote address (lport:host:port)
"""
localPort, remoteHost, remotePort = f.split(":") # Doesn't do v6 yet
localPort = int(localPort)
remotePort = int(remotePort)
self.localForwards.append((localPort, (remoteHost, remotePort)))
def opt_remoteforward(self, f):
"""
Forward remote port to local address (rport:host:port)
"""
remotePort, connHost, connPort = f.split(":") # Doesn't do v6 yet
remotePort = int(remotePort)
connPort = int(connPort)
self.remoteForwards.append((remotePort, (connHost, connPort)))
def parseArgs(self, host, *command):
self["host"] = host
self["command"] = " ".join(command)
# Rest of code in "run"
options = None
conn = None
exitStatus = 0
old = None
_inRawMode = 0
_savedRawMode = None
def run():
global options, old
args = sys.argv[1:]
if "-l" in args: # CVS is an idiot
i = args.index("-l")
args = args[i : i + 2] + args
del args[i + 2 : i + 4]
for arg in args[:]:
try:
i = args.index(arg)
if arg[:2] == "-o" and args[i + 1][0] != "-":
args[i : i + 2] = [] # Suck on it scp
except ValueError:
pass
options = ClientOptions()
try:
options.parseOptions(args)
except usage.UsageError as u:
print(f"ERROR: {u}")
options.opt_help()
sys.exit(1)
if options["log"]:
if options["logfile"]:
if options["logfile"] == "-":
f = sys.stdout
else:
f = open(options["logfile"], "a+")
else:
f = sys.stderr
realout = sys.stdout
log.startLogging(f)
sys.stdout = realout
else:
log.discardLogs()
doConnect()
fd = sys.stdin.fileno()
try:
old = tty.tcgetattr(fd)
except BaseException:
old = None
try:
oldUSR1 = signal.signal(
signal.SIGUSR1, lambda *a: reactor.callLater(0, reConnect)
)
except BaseException:
oldUSR1 = None
try:
reactor.run()
finally:
if old:
tty.tcsetattr(fd, tty.TCSANOW, old)
if oldUSR1:
signal.signal(signal.SIGUSR1, oldUSR1)
if (options["command"] and options["tty"]) or not options["notty"]:
signal.signal(signal.SIGWINCH, signal.SIG_DFL)
if sys.stdout.isatty() and not options["command"]:
print("Connection to {} closed.".format(options["host"]))
sys.exit(exitStatus)
def handleError():
from twisted.python import failure
global exitStatus
exitStatus = 2
reactor.callLater(0.01, _stopReactor)
log.err(failure.Failure())
raise
def _stopReactor():
try:
reactor.stop()
except BaseException:
pass
def doConnect():
if "@" in options["host"]:
options["user"], options["host"] = options["host"].split("@", 1)
if not options.identitys:
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
host = options["host"]
if not options["user"]:
options["user"] = getpass.getuser()
if not options["port"]:
options["port"] = 22
else:
options["port"] = int(options["port"])
host = options["host"]
port = options["port"]
vhk = default.verifyHostKey
if not options["host-key-algorithms"]:
options["host-key-algorithms"] = default.getHostKeyAlgorithms(host, options)
uao = default.SSHUserAuthClient(options["user"], options, SSHConnection())
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
def _ebExit(f):
global exitStatus
exitStatus = f"conch: exiting with error {f}"
reactor.callLater(0.1, _stopReactor)
def onConnect():
# if keyAgent and options['agent']:
# cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal, conn)
# cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
if hasattr(conn.transport, "sendIgnore"):
_KeepAlive(conn)
if options.localForwards:
for localPort, hostport in options.localForwards:
s = reactor.listenTCP(
localPort,
forwarding.SSHListenForwardingFactory(
conn, hostport, SSHListenClientForwardingChannel
),
)
conn.localForwards.append(s)
if options.remoteForwards:
for remotePort, hostport in options.remoteForwards:
log.msg(f"asking for remote forwarding for {remotePort}:{hostport}")
conn.requestRemoteForwarding(remotePort, hostport)
reactor.addSystemEventTrigger("before", "shutdown", beforeShutdown)
if not options["noshell"] or options["agent"]:
conn.openChannel(SSHSession())
if options["fork"]:
if os.fork():
os._exit(0)
os.setsid()
for i in range(3):
try:
os.close(i)
except OSError as e:
import errno
if e.errno != errno.EBADF:
raise
def reConnect():
beforeShutdown()
conn.transport.transport.loseConnection()
def beforeShutdown():
remoteForwards = options.remoteForwards
for remotePort, hostport in remoteForwards:
log.msg(f"cancelling {remotePort}:{hostport}")
conn.cancelRemoteForwarding(remotePort)
def stopConnection():
if not options["reconnect"]:
reactor.callLater(0.1, _stopReactor)
class _KeepAlive:
def __init__(self, conn):
self.conn = conn
self.globalTimeout = None
self.lc = task.LoopingCall(self.sendGlobal)
self.lc.start(300)
def sendGlobal(self):
d = self.conn.sendGlobalRequest(
b"conch-keep-alive@twistedmatrix.com", b"", wantReply=1
)
d.addBoth(self._cbGlobal)
self.globalTimeout = reactor.callLater(30, self._ebGlobal)
def _cbGlobal(self, res):
if self.globalTimeout:
self.globalTimeout.cancel()
self.globalTimeout = None
def _ebGlobal(self):
if self.globalTimeout:
self.globalTimeout = None
self.conn.transport.loseConnection()
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
global conn
conn = self
self.localForwards = []
self.remoteForwards = {}
onConnect()
def serviceStopped(self):
lf = self.localForwards
self.localForwards = []
for s in lf:
s.loseConnection()
stopConnection()
def requestRemoteForwarding(self, remotePort, hostport):
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
d = self.sendGlobalRequest(b"tcpip-forward", data, wantReply=1)
log.msg(f"requesting remote forwarding {remotePort}:{hostport}")
d.addCallback(self._cbRemoteForwarding, remotePort, hostport)
d.addErrback(self._ebRemoteForwarding, remotePort, hostport)
def _cbRemoteForwarding(self, result, remotePort, hostport):
log.msg(f"accepted remote forwarding {remotePort}:{hostport}")
self.remoteForwards[remotePort] = hostport
log.msg(repr(self.remoteForwards))
def _ebRemoteForwarding(self, f, remotePort, hostport):
log.msg(f"remote forwarding {remotePort}:{hostport} failed")
log.msg(f)
def cancelRemoteForwarding(self, remotePort):
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
self.sendGlobalRequest(b"cancel-tcpip-forward", data)
log.msg(f"cancelling remote forwarding {remotePort}")
try:
del self.remoteForwards[remotePort]
except Exception:
pass
log.msg(repr(self.remoteForwards))
def channel_forwarded_tcpip(self, windowSize, maxPacket, data):
log.msg(f"FTCP {data!r}")
remoteHP, origHP = forwarding.unpackOpen_forwarded_tcpip(data)
log.msg(self.remoteForwards)
log.msg(remoteHP)
if remoteHP[1] in self.remoteForwards:
connectHP = self.remoteForwards[remoteHP[1]]
log.msg(f"connect forwarding {connectHP}")
return SSHConnectForwardingChannel(
connectHP, remoteWindow=windowSize, remoteMaxPacket=maxPacket, conn=self
)
else:
raise ConchError(
connection.OPEN_CONNECT_FAILED, "don't know about that port"
)
def channelClosed(self, channel):
log.msg(f"connection closing {channel}")
log.msg(self.channels)
if len(self.channels) == 1: # Just us left
log.msg("stopping connection")
stopConnection()
else:
# Because of the unix thing
self.__class__.__bases__[0].channelClosed(self, channel)
class SSHSession(channel.SSHChannel):
name = b"session"
def channelOpen(self, foo):
log.msg(f"session {self.id} open")
if options["agent"]:
d = self.conn.sendRequest(
self, b"auth-agent-req@openssh.com", b"", wantReply=1
)
d.addBoth(lambda x: log.msg(x))
if options["noshell"]:
return
if (options["command"] and options["tty"]) or not options["notty"]:
_enterRawMode()
c = session.SSHSessionClient()
if options["escape"] and not options["notty"]:
self.escapeMode = 1
c.dataReceived = self.handleInput
else:
c.dataReceived = self.write
c.connectionLost = lambda x: self.sendEOF()
self.stdio = stdio.StandardIO(c)
fd = 0
if options["subsystem"]:
self.conn.sendRequest(self, b"subsystem", common.NS(options["command"]))
elif options["command"]:
if options["tty"]:
term = os.environ["TERM"]
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, "12345678")
winSize = struct.unpack("4H", winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, "")
self.conn.sendRequest(self, b"pty-req", ptyReqData)
signal.signal(signal.SIGWINCH, self._windowResized)
self.conn.sendRequest(self, b"exec", common.NS(options["command"]))
else:
if not options["notty"]:
term = os.environ["TERM"]
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, "12345678")
winSize = struct.unpack("4H", winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, "")
self.conn.sendRequest(self, b"pty-req", ptyReqData)
signal.signal(signal.SIGWINCH, self._windowResized)
self.conn.sendRequest(self, b"shell", b"")
# if hasattr(conn.transport, 'transport'):
# conn.transport.transport.setTcpNoDelay(1)
def handleInput(self, char):
if char in (b"\n", b"\r"):
self.escapeMode = 1
self.write(char)
elif self.escapeMode == 1 and char == options["escape"]:
self.escapeMode = 2
elif self.escapeMode == 2:
self.escapeMode = 1 # So we can chain escapes together
if char == b".": # Disconnect
log.msg("disconnecting from escape")
stopConnection()
return
elif char == b"\x1a": # ^Z, suspend
def _():
_leaveRawMode()
sys.stdout.flush()
sys.stdin.flush()
os.kill(os.getpid(), signal.SIGTSTP)
_enterRawMode()
reactor.callLater(0, _)
return
elif char == b"R": # Rekey connection
log.msg("rekeying connection")
self.conn.transport.sendKexInit()
return
elif char == b"#": # Display connections
self.stdio.write(b"\r\nThe following connections are open:\r\n")
channels = self.conn.channels.keys()
channels.sort()
for channelId in channels:
self.stdio.write(
networkString(
" #{} {}\r\n".format(
channelId, self.conn.channels[channelId]
)
)
)
return
self.write(b"~" + char)
else:
self.escapeMode = 0
self.write(char)
def dataReceived(self, data):
self.stdio.write(data)
def extReceived(self, t, data):
if t == connection.EXTENDED_DATA_STDERR:
log.msg(f"got {len(data)} stderr data")
if ioType(sys.stderr) == str:
sys.stderr.buffer.write(data)
else:
sys.stderr.write(data)
def eofReceived(self):
log.msg("got eof")
self.stdio.loseWriteConnection()
def closeReceived(self):
log.msg(f"remote side closed {self}")
self.conn.sendClose(self)
def closed(self):
global old
log.msg(f"closed {self}")
log.msg(repr(self.conn.channels))
def request_exit_status(self, data):
global exitStatus
exitStatus = int(struct.unpack(">L", data)[0])
log.msg(f"exit status: {exitStatus}")
def sendEOF(self):
self.conn.sendEOF(self)
def stopWriting(self):
self.stdio.pauseProducing()
def startWriting(self):
self.stdio.resumeProducing()
def _windowResized(self, *args):
winsz = fcntl.ioctl(0, tty.TIOCGWINSZ, "12345678")
winSize = struct.unpack("4H", winsz)
newSize = winSize[1], winSize[0], winSize[2], winSize[3]
self.conn.sendRequest(self, b"window-change", struct.pack("!4L", *newSize))
class SSHListenClientForwardingChannel(forwarding.SSHListenClientForwardingChannel):
pass
class SSHConnectForwardingChannel(forwarding.SSHConnectForwardingChannel):
pass
def _leaveRawMode():
global _inRawMode
if not _inRawMode:
return
fd = sys.stdin.fileno()
tty.tcsetattr(fd, tty.TCSANOW, _savedRawMode)
_inRawMode = 0
def _enterRawMode():
global _inRawMode, _savedRawMode
if _inRawMode:
return
fd = sys.stdin.fileno()
try:
old = tty.tcgetattr(fd)
new = old[:]
except BaseException:
log.msg("not a typewriter!")
else:
# iflage
new[0] = new[0] | tty.IGNPAR
new[0] = new[0] & ~(
tty.ISTRIP
| tty.INLCR
| tty.IGNCR
| tty.ICRNL
| tty.IXON
| tty.IXANY
| tty.IXOFF
)
if hasattr(tty, "IUCLC"):
new[0] = new[0] & ~tty.IUCLC
# lflag
new[3] = new[3] & ~(
tty.ISIG
| tty.ICANON
| tty.ECHO
| tty.ECHO
| tty.ECHOE
| tty.ECHOK
| tty.ECHONL
)
if hasattr(tty, "IEXTEN"):
new[3] = new[3] & ~tty.IEXTEN
# oflag
new[1] = new[1] & ~tty.OPOST
new[6][tty.VMIN] = 1
new[6][tty.VTIME] = 0
_savedRawMode = old
tty.tcsetattr(fd, tty.TCSANOW, new)
# tty.setraw(fd)
_inRawMode = 1
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,673 @@
# -*- test-case-name: twisted.conch.test.test_scripts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the `tkconch` command.
"""
import base64
import getpass
import os
import signal
import struct
import sys
import tkinter as Tkinter
import tkinter.filedialog as tkFileDialog
import tkinter.messagebox as tkMessageBox
from typing import List, Tuple
from twisted.conch import error
from twisted.conch.client.default import isInKnownHosts
from twisted.conch.ssh import (
channel,
common,
connection,
forwarding,
keys,
session,
transport,
userauth,
)
from twisted.conch.ui import tkvt100
from twisted.internet import defer, protocol, reactor, tksupport
from twisted.python import log, usage
class TkConchMenu(Tkinter.Frame):
def __init__(self, *args, **params):
## Standard heading: initialization
Tkinter.Frame.__init__(self, *args, **params)
self.master.title("TkConch")
self.localRemoteVar = Tkinter.StringVar()
self.localRemoteVar.set("local")
Tkinter.Label(self, anchor="w", justify="left", text="Hostname").grid(
column=1, row=1, sticky="w"
)
self.host = Tkinter.Entry(self)
self.host.grid(column=2, columnspan=2, row=1, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Port").grid(
column=1, row=2, sticky="w"
)
self.port = Tkinter.Entry(self)
self.port.grid(column=2, columnspan=2, row=2, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Username").grid(
column=1, row=3, sticky="w"
)
self.user = Tkinter.Entry(self)
self.user.grid(column=2, columnspan=2, row=3, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Command").grid(
column=1, row=4, sticky="w"
)
self.command = Tkinter.Entry(self)
self.command.grid(column=2, columnspan=2, row=4, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Identity").grid(
column=1, row=5, sticky="w"
)
self.identity = Tkinter.Entry(self)
self.identity.grid(column=2, row=5, sticky="nesw")
Tkinter.Button(self, command=self.getIdentityFile, text="Browse").grid(
column=3, row=5, sticky="nesw"
)
Tkinter.Label(self, text="Port Forwarding").grid(column=1, row=6, sticky="w")
self.forwards = Tkinter.Listbox(self, height=0, width=0)
self.forwards.grid(column=2, columnspan=2, row=6, sticky="nesw")
Tkinter.Button(self, text="Add", command=self.addForward).grid(column=1, row=7)
Tkinter.Button(self, text="Remove", command=self.removeForward).grid(
column=1, row=8
)
self.forwardPort = Tkinter.Entry(self)
self.forwardPort.grid(column=2, row=7, sticky="nesw")
Tkinter.Label(self, text="Port").grid(column=3, row=7, sticky="nesw")
self.forwardHost = Tkinter.Entry(self)
self.forwardHost.grid(column=2, row=8, sticky="nesw")
Tkinter.Label(self, text="Host").grid(column=3, row=8, sticky="nesw")
self.localForward = Tkinter.Radiobutton(
self, text="Local", variable=self.localRemoteVar, value="local"
)
self.localForward.grid(column=2, row=9)
self.remoteForward = Tkinter.Radiobutton(
self, text="Remote", variable=self.localRemoteVar, value="remote"
)
self.remoteForward.grid(column=3, row=9)
Tkinter.Label(self, text="Advanced Options").grid(
column=1, columnspan=3, row=10, sticky="nesw"
)
Tkinter.Label(self, anchor="w", justify="left", text="Cipher").grid(
column=1, row=11, sticky="w"
)
self.cipher = Tkinter.Entry(self, name="cipher")
self.cipher.grid(column=2, columnspan=2, row=11, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="MAC").grid(
column=1, row=12, sticky="w"
)
self.mac = Tkinter.Entry(self, name="mac")
self.mac.grid(column=2, columnspan=2, row=12, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Escape Char").grid(
column=1, row=13, sticky="w"
)
self.escape = Tkinter.Entry(self, name="escape")
self.escape.grid(column=2, columnspan=2, row=13, sticky="nesw")
Tkinter.Button(self, text="Connect!", command=self.doConnect).grid(
column=1, columnspan=3, row=14, sticky="nesw"
)
# Resize behavior(s)
self.grid_rowconfigure(6, weight=1, minsize=64)
self.grid_columnconfigure(2, weight=1, minsize=2)
self.master.protocol("WM_DELETE_WINDOW", sys.exit)
def getIdentityFile(self):
r = tkFileDialog.askopenfilename()
if r:
self.identity.delete(0, Tkinter.END)
self.identity.insert(Tkinter.END, r)
def addForward(self):
port = self.forwardPort.get()
self.forwardPort.delete(0, Tkinter.END)
host = self.forwardHost.get()
self.forwardHost.delete(0, Tkinter.END)
if self.localRemoteVar.get() == "local":
self.forwards.insert(Tkinter.END, f"L:{port}:{host}")
else:
self.forwards.insert(Tkinter.END, f"R:{port}:{host}")
def removeForward(self):
cur = self.forwards.curselection()
if cur:
self.forwards.remove(cur[0])
def doConnect(self):
finished = 1
options["host"] = self.host.get()
options["port"] = self.port.get()
options["user"] = self.user.get()
options["command"] = self.command.get()
cipher = self.cipher.get()
mac = self.mac.get()
escape = self.escape.get()
if cipher:
if cipher in SSHClientTransport.supportedCiphers:
SSHClientTransport.supportedCiphers = [cipher]
else:
tkMessageBox.showerror("TkConch", "Bad cipher.")
finished = 0
if mac:
if mac in SSHClientTransport.supportedMACs:
SSHClientTransport.supportedMACs = [mac]
elif finished:
tkMessageBox.showerror("TkConch", "Bad MAC.")
finished = 0
if escape:
if escape == "none":
options["escape"] = None
elif escape[0] == "^" and len(escape) == 2:
options["escape"] = chr(ord(escape[1]) - 64)
elif len(escape) == 1:
options["escape"] = escape
elif finished:
tkMessageBox.showerror("TkConch", "Bad escape character '%s'." % escape)
finished = 0
if self.identity.get():
options.identitys.append(self.identity.get())
for line in self.forwards.get(0, Tkinter.END):
if line[0] == "L":
options.opt_localforward(line[2:])
else:
options.opt_remoteforward(line[2:])
if "@" in options["host"]:
options["user"], options["host"] = options["host"].split("@", 1)
if (not options["host"] or not options["user"]) and finished:
tkMessageBox.showerror("TkConch", "Missing host or username.")
finished = 0
if finished:
self.master.quit()
self.master.destroy()
if options["log"]:
realout = sys.stdout
log.startLogging(sys.stderr)
sys.stdout = realout
else:
log.discardLogs()
log.deferr = handleError # HACK
if not options.identitys:
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
host = options["host"]
port = int(options["port"] or 22)
log.msg((host, port))
reactor.connectTCP(host, port, SSHClientFactory())
frame.master.deiconify()
frame.master.title(
"{}@{} - TkConch".format(options["user"], options["host"])
)
else:
self.focus()
class GeneralOptions(usage.Options):
synopsis = """Usage: tkconch [options] host [command]
"""
optParameters = [
["user", "l", None, "Log in using this user name."],
["identity", "i", "~/.ssh/identity", "Identity for public key authentication"],
["escape", "e", "~", "Set escape character; ``none'' = disable"],
["cipher", "c", None, "Select encryption algorithm."],
["macs", "m", None, "Specify MAC algorithms for protocol version 2."],
["port", "p", None, "Connect to this port. Server must be on the same port."],
[
"localforward",
"L",
None,
"listen-port:host:port Forward local port to remote address",
],
[
"remoteforward",
"R",
None,
"listen-port:host:port Forward remote port to local address",
],
]
optFlags = [
["tty", "t", "Tty; allocate a tty even if command is given."],
["notty", "T", "Do not allocate a tty."],
["version", "V", "Display version number only."],
["compress", "C", "Enable compression."],
["noshell", "N", "Do not execute a shell or command."],
["subsystem", "s", "Invoke command (mandatory) as SSH2 subsystem."],
["log", "v", "Log to stderr"],
["ansilog", "a", "Print the received data to stdout"],
]
_ciphers = transport.SSHClientTransport.supportedCiphers
_macs = transport.SSHClientTransport.supportedMACs
compData = usage.Completions(
mutuallyExclusive=[("tty", "notty")],
optActions={
"cipher": usage.CompleteList([v.decode() for v in _ciphers]),
"macs": usage.CompleteList([v.decode() for v in _macs]),
"localforward": usage.Completer(descr="listen-port:host:port"),
"remoteforward": usage.Completer(descr="listen-port:host:port"),
},
extraActions=[
usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True),
],
)
identitys: List[str] = []
localForwards: List[Tuple[int, Tuple[int, int]]] = []
remoteForwards: List[Tuple[int, Tuple[int, int]]] = []
def opt_identity(self, i):
self.identitys.append(i)
def opt_localforward(self, f):
localPort, remoteHost, remotePort = f.split(":") # doesn't do v6 yet
localPort = int(localPort)
remotePort = int(remotePort)
self.localForwards.append((localPort, (remoteHost, remotePort)))
def opt_remoteforward(self, f):
remotePort, connHost, connPort = f.split(":") # doesn't do v6 yet
remotePort = int(remotePort)
connPort = int(connPort)
self.remoteForwards.append((remotePort, (connHost, connPort)))
def opt_compress(self):
SSHClientTransport.supportedCompressions[0:1] = ["zlib"]
def parseArgs(self, *args):
if args:
self["host"] = args[0]
self["command"] = " ".join(args[1:])
else:
self["host"] = ""
self["command"] = ""
# Rest of code in "run"
options = None
menu = None
exitStatus = 0
frame = None
def deferredAskFrame(question, echo):
if frame.callback:
raise ValueError("can't ask 2 questions at once!")
d = defer.Deferred()
resp = []
def gotChar(ch, resp=resp):
if not ch:
return
if ch == "\x03": # C-c
reactor.stop()
if ch == "\r":
frame.write("\r\n")
stresp = "".join(resp)
del resp
frame.callback = None
d.callback(stresp)
return
elif 32 <= ord(ch) < 127:
resp.append(ch)
if echo:
frame.write(ch)
elif ord(ch) == 8 and resp: # BS
if echo:
frame.write("\x08 \x08")
resp.pop()
frame.callback = gotChar
frame.write(question)
frame.canvas.focus_force()
return d
def run():
global menu, options, frame
args = sys.argv[1:]
if "-l" in args: # cvs is an idiot
i = args.index("-l")
args = args[i : i + 2] + args
del args[i + 2 : i + 4]
for arg in args[:]:
try:
i = args.index(arg)
if arg[:2] == "-o" and args[i + 1][0] != "-":
args[i : i + 2] = [] # suck on it scp
except ValueError:
pass
root = Tkinter.Tk()
root.withdraw()
top = Tkinter.Toplevel()
menu = TkConchMenu(top)
menu.pack(side=Tkinter.TOP, fill=Tkinter.BOTH, expand=1)
options = GeneralOptions()
try:
options.parseOptions(args)
except usage.UsageError as u:
print("ERROR: %s" % u)
options.opt_help()
sys.exit(1)
for k, v in options.items():
if v and hasattr(menu, k):
getattr(menu, k).insert(Tkinter.END, v)
for p, (rh, rp) in options.localForwards:
menu.forwards.insert(Tkinter.END, f"L:{p}:{rh}:{rp}")
options.localForwards = []
for p, (rh, rp) in options.remoteForwards:
menu.forwards.insert(Tkinter.END, f"R:{p}:{rh}:{rp}")
options.remoteForwards = []
frame = tkvt100.VT100Frame(root, callback=None)
root.geometry(
"%dx%d"
% (tkvt100.fontWidth * frame.width + 3, tkvt100.fontHeight * frame.height + 3)
)
frame.pack(side=Tkinter.TOP)
tksupport.install(root)
root.withdraw()
if (options["host"] and options["user"]) or "@" in options["host"]:
menu.doConnect()
else:
top.mainloop()
reactor.run()
sys.exit(exitStatus)
def handleError():
from twisted.python import failure
global exitStatus
exitStatus = 2
log.err(failure.Failure())
reactor.stop()
raise
class SSHClientFactory(protocol.ClientFactory):
noisy = True
def stopFactory(self):
reactor.stop()
def buildProtocol(self, addr):
return SSHClientTransport()
def clientConnectionFailed(self, connector, reason):
tkMessageBox.showwarning(
"TkConch",
f"Connection Failed, Reason:\n {reason.type}: {reason.value}",
)
class SSHClientTransport(transport.SSHClientTransport):
def receiveError(self, code, desc):
global exitStatus
exitStatus = (
"conch:\tRemote side disconnected with error code %i\nconch:\treason: %s"
% (code, desc)
)
def sendDisconnect(self, code, reason):
global exitStatus
exitStatus = (
"conch:\tSending disconnect with error code %i\nconch:\treason: %s"
% (code, reason)
)
transport.SSHClientTransport.sendDisconnect(self, code, reason)
def receiveDebug(self, alwaysDisplay, message, lang):
global options
if alwaysDisplay or options["log"]:
log.msg("Received Debug Message: %s" % message)
def verifyHostKey(self, pubKey, fingerprint):
# d = defer.Deferred()
# d.addCallback(lambda x:defer.succeed(1))
# d.callback(2)
# return d
goodKey = isInKnownHosts(options["host"], pubKey, {"known-hosts": None})
if goodKey == 1: # good key
return defer.succeed(1)
elif goodKey == 2: # AAHHHHH changed
return defer.fail(error.ConchError("bad host key"))
else:
if options["host"] == self.transport.getPeer().host:
host = options["host"]
khHost = options["host"]
else:
host = "{} ({})".format(options["host"], self.transport.getPeer().host)
khHost = "{},{}".format(options["host"], self.transport.getPeer().host)
keyType = common.getNS(pubKey)[0]
ques = """The authenticity of host '{}' can't be established.\r
{} key fingerprint is {}.""".format(
host,
{b"ssh-dss": "DSA", b"ssh-rsa": "RSA"}[keyType],
fingerprint,
)
ques += "\r\nAre you sure you want to continue connecting (yes/no)? "
return deferredAskFrame(ques, 1).addCallback(
self._cbVerifyHostKey, pubKey, khHost, keyType
)
def _cbVerifyHostKey(self, ans, pubKey, khHost, keyType):
if ans.lower() not in ("yes", "no"):
return deferredAskFrame("Please type 'yes' or 'no': ", 1).addCallback(
self._cbVerifyHostKey, pubKey, khHost, keyType
)
if ans.lower() == "no":
frame.write("Host key verification failed.\r\n")
raise error.ConchError("bad host key")
try:
frame.write(
"Warning: Permanently added '%s' (%s) to the list of "
"known hosts.\r\n"
% (khHost, {b"ssh-dss": "DSA", b"ssh-rsa": "RSA"}[keyType])
)
with open(os.path.expanduser("~/.ssh/known_hosts"), "a") as known_hosts:
encodedKey = base64.b64encode(pubKey)
known_hosts.write(f"\n{khHost} {keyType} {encodedKey}")
except BaseException:
log.deferr()
raise error.ConchError
def connectionSecure(self):
if options["user"]:
user = options["user"]
else:
user = getpass.getuser()
self.requestService(SSHUserAuthClient(user, SSHConnection()))
class SSHUserAuthClient(userauth.SSHUserAuthClient):
usedFiles: List[str] = []
def getPassword(self, prompt=None):
if not prompt:
prompt = "{}@{}'s password: ".format(self.user, options["host"])
return deferredAskFrame(prompt, 0)
def getPublicKey(self):
files = [x for x in options.identitys if x not in self.usedFiles]
if not files:
return None
file = files[0]
log.msg(file)
self.usedFiles.append(file)
file = os.path.expanduser(file)
file += ".pub"
if not os.path.exists(file):
return
try:
return keys.Key.fromFile(file).blob()
except BaseException:
return self.getPublicKey() # try again
def getPrivateKey(self):
file = os.path.expanduser(self.usedFiles[-1])
if not os.path.exists(file):
return None
try:
return defer.succeed(keys.Key.fromFile(file).keyObject)
except keys.BadKeyError as e:
if e.args[0] == "encrypted key with no password":
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, 0)
def _cbGetPrivateKey(self, ans, count):
file = os.path.expanduser(self.usedFiles[-1])
try:
return keys.Key.fromFile(file, password=ans).keyObject
except keys.BadKeyError:
if count == 2:
raise
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
return deferredAskFrame(prompt, 0).addCallback(
self._cbGetPrivateKey, count + 1
)
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
if not options["noshell"]:
self.openChannel(SSHSession())
if options.localForwards:
for localPort, hostport in options.localForwards:
reactor.listenTCP(
localPort,
forwarding.SSHListenForwardingFactory(
self, hostport, forwarding.SSHListenClientForwardingChannel
),
)
if options.remoteForwards:
for remotePort, hostport in options.remoteForwards:
log.msg(
"asking for remote forwarding for {}:{}".format(
remotePort, hostport
)
)
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
self.sendGlobalRequest("tcpip-forward", data)
self.remoteForwards[remotePort] = hostport
class SSHSession(channel.SSHChannel):
name = b"session"
def channelOpen(self, foo):
# global globalSession
# globalSession = self
# turn off local echo
self.escapeMode = 1
c = session.SSHSessionClient()
if options["escape"]:
c.dataReceived = self.handleInput
else:
c.dataReceived = self.write
c.connectionLost = self.sendEOF
frame.callback = c.dataReceived
frame.canvas.focus_force()
if options["subsystem"]:
self.conn.sendRequest(self, b"subsystem", common.NS(options["command"]))
elif options["command"]:
if options["tty"]:
term = os.environ.get("TERM", "xterm")
# winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = (25, 80, 0, 0) # struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, "")
self.conn.sendRequest(self, b"pty-req", ptyReqData)
self.conn.sendRequest(self, "exec", common.NS(options["command"]))
else:
if not options["notty"]:
term = os.environ.get("TERM", "xterm")
# winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = (25, 80, 0, 0) # struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, "")
self.conn.sendRequest(self, b"pty-req", ptyReqData)
self.conn.sendRequest(self, b"shell", b"")
self.conn.transport.transport.setTcpNoDelay(1)
def handleInput(self, char):
# log.msg('handling %s' % repr(char))
if char in ("\n", "\r"):
self.escapeMode = 1
self.write(char)
elif self.escapeMode == 1 and char == options["escape"]:
self.escapeMode = 2
elif self.escapeMode == 2:
self.escapeMode = 1 # so we can chain escapes together
if char == ".": # disconnect
log.msg("disconnecting from escape")
reactor.stop()
return
elif char == "\x1a": # ^Z, suspend
# following line courtesy of Erwin@freenode
os.kill(os.getpid(), signal.SIGSTOP)
return
elif char == "R": # rekey connection
log.msg("rekeying connection")
self.conn.transport.sendKexInit()
return
self.write("~" + char)
else:
self.escapeMode = 0
self.write(char)
def dataReceived(self, data):
data = data.decode("utf-8")
if options["ansilog"]:
print(repr(data))
frame.write(data)
def extReceived(self, t, data):
if t == connection.EXTENDED_DATA_STDERR:
log.msg("got %s stderr data" % len(data))
sys.stderr.write(data)
sys.stderr.flush()
def eofReceived(self):
log.msg("got eof")
sys.stdin.close()
def closed(self):
log.msg("closed %s" % self)
if len(self.conn.channels) == 1: # just us left
reactor.stop()
def request_exit_status(self, data):
global exitStatus
exitStatus = int(struct.unpack(">L", data)[0])
log.msg("exit status: %s" % exitStatus)
def sendEOF(self):
self.conn.sendEOF(self)
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,10 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
An SSHv2 implementation for Twisted. Part of the Twisted.Conch package.
Maintainer: Paul Swartz
"""

View File

@@ -0,0 +1,293 @@
# -*- test-case-name: twisted.conch.test.test_transport -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
SSH key exchange handling.
"""
from hashlib import sha1, sha256, sha384, sha512
from zope.interface import Attribute, Interface, implementer
from twisted.conch import error
class _IKexAlgorithm(Interface):
"""
An L{_IKexAlgorithm} describes a key exchange algorithm.
"""
preference = Attribute(
"An L{int} giving the preference of the algorithm when negotiating "
"key exchange. Algorithms with lower precedence values are more "
"preferred."
)
hashProcessor = Attribute(
"A callable hash algorithm constructor (e.g. C{hashlib.sha256}) "
"suitable for use with this key exchange algorithm."
)
class _IFixedGroupKexAlgorithm(_IKexAlgorithm):
"""
An L{_IFixedGroupKexAlgorithm} describes a key exchange algorithm with a
fixed prime / generator group.
"""
prime = Attribute(
"An L{int} giving the prime number used in Diffie-Hellman key "
"exchange, or L{None} if not applicable."
)
generator = Attribute(
"An L{int} giving the generator number used in Diffie-Hellman key "
"exchange, or L{None} if not applicable. (This is not related to "
"Python generator functions.)"
)
class _IEllipticCurveExchangeKexAlgorithm(_IKexAlgorithm):
"""
An L{_IEllipticCurveExchangeKexAlgorithm} describes a key exchange algorithm
that uses an elliptic curve exchange between the client and server.
"""
class _IGroupExchangeKexAlgorithm(_IKexAlgorithm):
"""
An L{_IGroupExchangeKexAlgorithm} describes a key exchange algorithm
that uses group exchange between the client and server.
A prime / generator group should be chosen at run time based on the
requested size. See RFC 4419.
"""
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _Curve25519SHA256:
"""
Elliptic Curve Key Exchange using Curve25519 and SHA256. Defined in
U{https://datatracker.ietf.org/doc/draft-ietf-curdle-ssh-curves/}.
"""
preference = 1
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _Curve25519SHA256LibSSH:
"""
As L{_Curve25519SHA256}, but with a pre-standardized algorithm name.
"""
preference = 2
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH256:
"""
Elliptic Curve Key Exchange with SHA-256 as HASH. Defined in
RFC 5656.
Note that C{ecdh-sha2-nistp256} takes priority over nistp384 or nistp512.
This is the same priority from OpenSSH.
C{ecdh-sha2-nistp256} is considered preety good cryptography.
If you need something better consider using C{curve25519-sha256}.
"""
preference = 3
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH384:
"""
Elliptic Curve Key Exchange with SHA-384 as HASH. Defined in
RFC 5656.
"""
preference = 4
hashProcessor = sha384
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH512:
"""
Elliptic Curve Key Exchange with SHA-512 as HASH. Defined in
RFC 5656.
"""
preference = 5
hashProcessor = sha512
@implementer(_IGroupExchangeKexAlgorithm)
class _DHGroupExchangeSHA256:
"""
Diffie-Hellman Group and Key Exchange with SHA-256 as HASH. Defined in
RFC 4419, 4.2.
"""
preference = 6
hashProcessor = sha256
@implementer(_IGroupExchangeKexAlgorithm)
class _DHGroupExchangeSHA1:
"""
Diffie-Hellman Group and Key Exchange with SHA-1 as HASH. Defined in
RFC 4419, 4.1.
"""
preference = 7
hashProcessor = sha1
@implementer(_IFixedGroupKexAlgorithm)
class _DHGroup14SHA1:
"""
Diffie-Hellman key exchange with SHA-1 as HASH and Oakley Group 14
(2048-bit MODP Group). Defined in RFC 4253, 8.2.
"""
preference = 8
hashProcessor = sha1
# Diffie-Hellman primes from Oakley Group 14 (RFC 3526, 3).
prime = int(
"323170060713110073003389139264238282488179412411402391128420"
"097514007417066343542226196894173635693471179017379097041917"
"546058732091950288537589861856221532121754125149017745202702"
"357960782362488842461894775876411059286460994117232454266225"
"221932305409190376805242355191256797158701170010580558776510"
"388618472802579760549035697325615261670813393617995413364765"
"591603683178967290731783845896806396719009772021941686472258"
"710314113364293195361934716365332097170774482279885885653692"
"086452966360772502689555059283627511211740969729980684105543"
"595848665832916421362182310789909994486524682624169720359118"
"52507045361090559"
)
generator = 2
# Which ECDH hash function to use is dependent on the size.
_kexAlgorithms = {
b"curve25519-sha256": _Curve25519SHA256(),
b"curve25519-sha256@libssh.org": _Curve25519SHA256LibSSH(),
b"diffie-hellman-group-exchange-sha256": _DHGroupExchangeSHA256(),
b"diffie-hellman-group-exchange-sha1": _DHGroupExchangeSHA1(),
b"diffie-hellman-group14-sha1": _DHGroup14SHA1(),
b"ecdh-sha2-nistp256": _ECDH256(),
b"ecdh-sha2-nistp384": _ECDH384(),
b"ecdh-sha2-nistp521": _ECDH512(),
}
def getKex(kexAlgorithm):
"""
Get a description of a named key exchange algorithm.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A description of the key exchange algorithm named by
C{kexAlgorithm}.
@rtype: L{_IKexAlgorithm}
@raises ConchError: if the key exchange algorithm is not found.
"""
if kexAlgorithm not in _kexAlgorithms:
raise error.ConchError(f"Unsupported key exchange algorithm: {kexAlgorithm}")
return _kexAlgorithms[kexAlgorithm]
def isEllipticCurve(kexAlgorithm):
"""
Returns C{True} if C{kexAlgorithm} is an elliptic curve.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: C{str}
@return: C{True} if C{kexAlgorithm} is an elliptic curve,
otherwise C{False}.
@rtype: C{bool}
"""
return _IEllipticCurveExchangeKexAlgorithm.providedBy(getKex(kexAlgorithm))
def isFixedGroup(kexAlgorithm):
"""
Returns C{True} if C{kexAlgorithm} has a fixed prime / generator group.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: C{True} if C{kexAlgorithm} has a fixed prime / generator group,
otherwise C{False}.
@rtype: L{bool}
"""
return _IFixedGroupKexAlgorithm.providedBy(getKex(kexAlgorithm))
def getHashProcessor(kexAlgorithm):
"""
Get the hash algorithm callable to use in key exchange.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A callable hash algorithm constructor (e.g. C{hashlib.sha256}).
@rtype: C{callable}
"""
kex = getKex(kexAlgorithm)
return kex.hashProcessor
def getDHGeneratorAndPrime(kexAlgorithm):
"""
Get the generator and the prime to use in key exchange.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A L{tuple} containing L{int} generator and L{int} prime.
@rtype: L{tuple}
"""
kex = getKex(kexAlgorithm)
return kex.generator, kex.prime
def getSupportedKeyExchanges():
"""
Get a list of supported key exchange algorithm names in order of
preference.
@return: A C{list} of supported key exchange algorithm names.
@rtype: C{list} of L{bytes}
"""
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from twisted.conch.ssh.keys import _curveTable
backend = default_backend()
kexAlgorithms = _kexAlgorithms.copy()
for keyAlgorithm in list(kexAlgorithms):
if keyAlgorithm.startswith(b"ecdh"):
keyAlgorithmDsa = keyAlgorithm.replace(b"ecdh", b"ecdsa")
supported = backend.elliptic_curve_exchange_algorithm_supported(
ec.ECDH(), _curveTable[keyAlgorithmDsa]
)
elif keyAlgorithm.startswith(b"curve25519-sha256"):
supported = backend.x25519_supported()
else:
supported = True
if not supported:
kexAlgorithms.pop(keyAlgorithm)
return sorted(
kexAlgorithms, key=lambda kexAlgorithm: kexAlgorithms[kexAlgorithm].preference
)

View File

@@ -0,0 +1,43 @@
# -*- test-case-name: twisted.conch.test.test_address -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Address object for SSH network connections.
Maintainer: Paul Swartz
@since: 12.1
"""
from zope.interface import implementer
from twisted.internet.interfaces import IAddress
from twisted.python import util
@implementer(IAddress)
class SSHTransportAddress(util.FancyEqMixin):
"""
Object representing an SSH Transport endpoint.
This is used to ensure that any code inspecting this address and
attempting to construct a similar connection based upon it is not
mislead into creating a transport which is not similar to the one it is
indicating.
@ivar address: An instance of an object which implements I{IAddress} to
which this transport address is connected.
"""
compareAttributes = ("address",)
def __init__(self, address):
self.address = address
def __repr__(self) -> str:
return f"SSHTransportAddress({self.address!r})"
def __hash__(self):
return hash(("SSH", self.address))

View File

@@ -0,0 +1,278 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implements the SSH v2 key agent protocol. This protocol is documented in the
SSH source code, in the file
U{PROTOCOL.agent<http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent>}.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch.error import ConchError, MissingKeyStoreError
from twisted.conch.ssh import keys
from twisted.conch.ssh.common import NS, getMP, getNS
from twisted.internet import defer, protocol
class SSHAgentClient(protocol.Protocol):
"""
The client side of the SSH agent protocol. This is equivalent to
ssh-add(1) and can be used with either ssh-agent(1) or the SSHAgentServer
protocol, also in this package.
"""
def __init__(self):
self.buf = b""
self.deferreds = []
def dataReceived(self, data):
self.buf += data
while 1:
if len(self.buf) <= 4:
return
packLen = struct.unpack("!L", self.buf[:4])[0]
if len(self.buf) < 4 + packLen:
return
packet, self.buf = self.buf[4 : 4 + packLen], self.buf[4 + packLen :]
reqType = ord(packet[0:1])
d = self.deferreds.pop(0)
if reqType == AGENT_FAILURE:
d.errback(ConchError("agent failure"))
elif reqType == AGENT_SUCCESS:
d.callback(b"")
else:
d.callback(packet)
def sendRequest(self, reqType, data):
pack = struct.pack("!LB", len(data) + 1, reqType) + data
self.transport.write(pack)
d = defer.Deferred()
self.deferreds.append(d)
return d
def requestIdentities(self):
"""
@return: A L{Deferred} which will fire with a list of all keys found in
the SSH agent. The list of keys is comprised of (public key blob,
comment) tuples.
"""
d = self.sendRequest(AGENTC_REQUEST_IDENTITIES, b"")
d.addCallback(self._cbRequestIdentities)
return d
def _cbRequestIdentities(self, data):
"""
Unpack a collection of identities into a list of tuples comprised of
public key blobs and comments.
"""
if ord(data[0:1]) != AGENT_IDENTITIES_ANSWER:
raise ConchError("unexpected response: %i" % ord(data[0:1]))
numKeys = struct.unpack("!L", data[1:5])[0]
result = []
data = data[5:]
for i in range(numKeys):
blob, data = getNS(data)
comment, data = getNS(data)
result.append((blob, comment))
return result
def addIdentity(self, blob, comment=b""):
"""
Add a private key blob to the agent's collection of keys.
"""
req = blob
req += NS(comment)
return self.sendRequest(AGENTC_ADD_IDENTITY, req)
def signData(self, blob, data):
"""
Request that the agent sign the given C{data} with the private key
which corresponds to the public key given by C{blob}. The private
key should have been added to the agent already.
@type blob: L{bytes}
@type data: L{bytes}
@return: A L{Deferred} which fires with a signature for given data
created with the given key.
"""
req = NS(blob)
req += NS(data)
req += b"\000\000\000\000" # flags
return self.sendRequest(AGENTC_SIGN_REQUEST, req).addCallback(self._cbSignData)
def _cbSignData(self, data):
if ord(data[0:1]) != AGENT_SIGN_RESPONSE:
raise ConchError("unexpected data: %i" % ord(data[0:1]))
signature = getNS(data[1:])[0]
return signature
def removeIdentity(self, blob):
"""
Remove the private key corresponding to the public key in blob from the
running agent.
"""
req = NS(blob)
return self.sendRequest(AGENTC_REMOVE_IDENTITY, req)
def removeAllIdentities(self):
"""
Remove all keys from the running agent.
"""
return self.sendRequest(AGENTC_REMOVE_ALL_IDENTITIES, b"")
class SSHAgentServer(protocol.Protocol):
"""
The server side of the SSH agent protocol. This is equivalent to
ssh-agent(1) and can be used with either ssh-add(1) or the SSHAgentClient
protocol, also in this package.
"""
def __init__(self):
self.buf = b""
def dataReceived(self, data):
self.buf += data
while 1:
if len(self.buf) <= 4:
return
packLen = struct.unpack("!L", self.buf[:4])[0]
if len(self.buf) < 4 + packLen:
return
packet, self.buf = self.buf[4 : 4 + packLen], self.buf[4 + packLen :]
reqType = ord(packet[0:1])
reqName = messages.get(reqType, None)
if not reqName:
self.sendResponse(AGENT_FAILURE, b"")
else:
f = getattr(self, "agentc_%s" % reqName)
if getattr(self.factory, "keys", None) is None:
self.sendResponse(AGENT_FAILURE, b"")
raise MissingKeyStoreError()
f(packet[1:])
def sendResponse(self, reqType, data):
pack = struct.pack("!LB", len(data) + 1, reqType) + data
self.transport.write(pack)
def agentc_REQUEST_IDENTITIES(self, data):
"""
Return all of the identities that have been added to the server
"""
assert data == b""
numKeys = len(self.factory.keys)
resp = []
resp.append(struct.pack("!L", numKeys))
for key, comment in self.factory.keys.values():
resp.append(NS(key.blob())) # yes, wrapped in an NS
resp.append(NS(comment))
self.sendResponse(AGENT_IDENTITIES_ANSWER, b"".join(resp))
def agentc_SIGN_REQUEST(self, data):
"""
Data is a structure with a reference to an already added key object and
some data that the clients wants signed with that key. If the key
object wasn't loaded, return AGENT_FAILURE, else return the signature.
"""
blob, data = getNS(data)
if blob not in self.factory.keys:
return self.sendResponse(AGENT_FAILURE, b"")
signData, data = getNS(data)
assert data == b"\000\000\000\000"
self.sendResponse(
AGENT_SIGN_RESPONSE, NS(self.factory.keys[blob][0].sign(signData))
)
def agentc_ADD_IDENTITY(self, data):
"""
Adds a private key to the agent's collection of identities. On
subsequent interactions, the private key can be accessed using only the
corresponding public key.
"""
# need to pre-read the key data so we can get past it to the comment string
keyType, rest = getNS(data)
if keyType == b"ssh-rsa":
nmp = 6
elif keyType == b"ssh-dss":
nmp = 5
else:
raise keys.BadKeyError("unknown blob type: %s" % keyType)
rest = getMP(rest, nmp)[
-1
] # ignore the key data for now, we just want the comment
comment, rest = getNS(rest) # the comment, tacked onto the end of the key blob
k = keys.Key.fromString(data, type="private_blob") # not wrapped in NS here
self.factory.keys[k.blob()] = (k, comment)
self.sendResponse(AGENT_SUCCESS, b"")
def agentc_REMOVE_IDENTITY(self, data):
"""
Remove a specific key from the agent's collection of identities.
"""
blob, _ = getNS(data)
k = keys.Key.fromString(blob, type="blob")
del self.factory.keys[k.blob()]
self.sendResponse(AGENT_SUCCESS, b"")
def agentc_REMOVE_ALL_IDENTITIES(self, data):
"""
Remove all keys from the agent's collection of identities.
"""
assert data == b""
self.factory.keys = {}
self.sendResponse(AGENT_SUCCESS, b"")
# v1 messages that we ignore because we don't keep v1 keys
# open-ssh sends both v1 and v2 commands, so we have to
# do no-ops for v1 commands or we'll get "bad request" errors
def agentc_REQUEST_RSA_IDENTITIES(self, data):
"""
v1 message for listing RSA1 keys; superseded by
agentc_REQUEST_IDENTITIES, which handles different key types.
"""
self.sendResponse(AGENT_RSA_IDENTITIES_ANSWER, struct.pack("!L", 0))
def agentc_REMOVE_RSA_IDENTITY(self, data):
"""
v1 message for removing RSA1 keys; superseded by
agentc_REMOVE_IDENTITY, which handles different key types.
"""
self.sendResponse(AGENT_SUCCESS, b"")
def agentc_REMOVE_ALL_RSA_IDENTITIES(self, data):
"""
v1 message for removing all RSA1 keys; superseded by
agentc_REMOVE_ALL_IDENTITIES, which handles different key types.
"""
self.sendResponse(AGENT_SUCCESS, b"")
AGENTC_REQUEST_RSA_IDENTITIES = 1
AGENT_RSA_IDENTITIES_ANSWER = 2
AGENT_FAILURE = 5
AGENT_SUCCESS = 6
AGENTC_REMOVE_RSA_IDENTITY = 8
AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
AGENTC_REQUEST_IDENTITIES = 11
AGENT_IDENTITIES_ANSWER = 12
AGENTC_SIGN_REQUEST = 13
AGENT_SIGN_RESPONSE = 14
AGENTC_ADD_IDENTITY = 17
AGENTC_REMOVE_IDENTITY = 18
AGENTC_REMOVE_ALL_IDENTITIES = 19
messages = {}
for name, value in locals().copy().items():
if name[:7] == "AGENTC_":
messages[value] = name[7:] # doesn't handle doubles

View File

@@ -0,0 +1,312 @@
# -*- test-case-name: twisted.conch.test.test_channel -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The parent class for all the SSH Channels. Currently implemented channels
are session, direct-tcp, and forwarded-tcp.
Maintainer: Paul Swartz
"""
from zope.interface import implementer
from twisted.internet import interfaces
from twisted.logger import Logger
from twisted.python import log
@implementer(interfaces.ITransport)
class SSHChannel(log.Logger):
"""
A class that represents a multiplexed channel over an SSH connection.
The channel has a local window which is the maximum amount of data it will
receive, and a remote which is the maximum amount of data the remote side
will accept. There is also a maximum packet size for any individual data
packet going each way.
@ivar name: the name of the channel.
@type name: L{bytes}
@ivar localWindowSize: the maximum size of the local window in bytes.
@type localWindowSize: L{int}
@ivar localWindowLeft: how many bytes are left in the local window.
@type localWindowLeft: L{int}
@ivar localMaxPacket: the maximum size of packet we will accept in bytes.
@type localMaxPacket: L{int}
@ivar remoteWindowLeft: how many bytes are left in the remote window.
@type remoteWindowLeft: L{int}
@ivar remoteMaxPacket: the maximum size of a packet the remote side will
accept in bytes.
@type remoteMaxPacket: L{int}
@ivar conn: the connection this channel is multiplexed through.
@type conn: L{SSHConnection}
@ivar data: any data to send to the other side when the channel is
requested.
@type data: L{bytes}
@ivar avatar: an avatar for the logged-in user (if a server channel)
@ivar localClosed: True if we aren't accepting more data.
@type localClosed: L{bool}
@ivar remoteClosed: True if the other side isn't accepting more data.
@type remoteClosed: L{bool}
"""
_log = Logger()
name: bytes = None # type: ignore[assignment] # only needed for client channels
def __init__(
self,
localWindow=0,
localMaxPacket=0,
remoteWindow=0,
remoteMaxPacket=0,
conn=None,
data=None,
avatar=None,
):
self.localWindowSize = localWindow or 131072
self.localWindowLeft = self.localWindowSize
self.localMaxPacket = localMaxPacket or 32768
self.remoteWindowLeft = remoteWindow
self.remoteMaxPacket = remoteMaxPacket
self.areWriting = 1
self.conn = conn
self.data = data
self.avatar = avatar
self.specificData = b""
self.buf = b""
self.extBuf = []
self.closing = 0
self.localClosed = 0
self.remoteClosed = 0
self.id = None # gets set later by SSHConnection
def __str__(self) -> str:
return self.__bytes__().decode("ascii")
def __bytes__(self) -> bytes:
"""
Return a byte string representation of the channel
"""
name = self.name
if not name:
name = b"None"
return b"<SSHChannel %b (lw %d rw %d)>" % (
name,
self.localWindowLeft,
self.remoteWindowLeft,
)
def logPrefix(self):
id = (self.id is not None and str(self.id)) or "unknown"
if self.name:
name = self.name.decode("ascii")
else:
name = "None"
return f"SSHChannel {name} ({id}) on {self.conn.logPrefix()}"
def channelOpen(self, specificData):
"""
Called when the channel is opened. specificData is any data that the
other side sent us when opening the channel.
@type specificData: L{bytes}
"""
self._log.info("channel open")
def openFailed(self, reason):
"""
Called when the open failed for some reason.
reason.desc is a string descrption, reason.code the SSH error code.
@type reason: L{error.ConchError}
"""
self._log.error("other side refused open\nreason: {reason}", reason=reason)
def addWindowBytes(self, data):
"""
Called when bytes are added to the remote window. By default it clears
the data buffers.
@type data: L{bytes}
"""
self.remoteWindowLeft = self.remoteWindowLeft + data
if not self.areWriting and not self.closing:
self.areWriting = True
self.startWriting()
if self.buf:
b = self.buf
self.buf = b""
self.write(b)
if self.extBuf:
b = self.extBuf
self.extBuf = []
for type, data in b:
self.writeExtended(type, data)
def requestReceived(self, requestType, data):
"""
Called when a request is sent to this channel. By default it delegates
to self.request_<requestType>.
If this function returns true, the request succeeded, otherwise it
failed.
@type requestType: L{bytes}
@type data: L{bytes}
@rtype: L{bool}
"""
foo = requestType.replace(b"-", b"_").decode("ascii")
f = getattr(self, "request_" + foo, None)
if f:
return f(data)
self._log.info("unhandled request for {requestType}", requestType=requestType)
return 0
def dataReceived(self, data):
"""
Called when we receive data.
@type data: L{bytes}
"""
self._log.debug("got data {data}", data=data)
def extReceived(self, dataType, data):
"""
Called when we receive extended data (usually standard error).
@type dataType: L{int}
@type data: L{str}
"""
self._log.debug(
"got extended data {dataType} {data!r}", dataType=dataType, data=data
)
def eofReceived(self):
"""
Called when the other side will send no more data.
"""
self._log.info("remote eof")
def closeReceived(self):
"""
Called when the other side has closed the channel.
"""
self._log.info("remote close")
self.loseConnection()
def closed(self):
"""
Called when the channel is closed. This means that both our side and
the remote side have closed the channel.
"""
self._log.info("closed")
def write(self, data):
"""
Write some data to the channel. If there is not enough remote window
available, buffer until it is. Otherwise, split the data into
packets of length remoteMaxPacket and send them.
@type data: L{bytes}
"""
if self.buf:
self.buf += data
return
top = len(data)
if top > self.remoteWindowLeft:
data, self.buf = (
data[: self.remoteWindowLeft],
data[self.remoteWindowLeft :],
)
self.areWriting = 0
self.stopWriting()
top = self.remoteWindowLeft
rmp = self.remoteMaxPacket
write = self.conn.sendData
r = range(0, top, rmp)
for offset in r:
write(self, data[offset : offset + rmp])
self.remoteWindowLeft -= top
if self.closing and not self.buf:
self.loseConnection() # try again
def writeExtended(self, dataType, data):
"""
Send extended data to this channel. If there is not enough remote
window available, buffer until there is. Otherwise, split the data
into packets of length remoteMaxPacket and send them.
@type dataType: L{int}
@type data: L{bytes}
"""
if self.extBuf:
if self.extBuf[-1][0] == dataType:
self.extBuf[-1][1] += data
else:
self.extBuf.append([dataType, data])
return
if len(data) > self.remoteWindowLeft:
data, self.extBuf = (
data[: self.remoteWindowLeft],
[[dataType, data[self.remoteWindowLeft :]]],
)
self.areWriting = 0
self.stopWriting()
while len(data) > self.remoteMaxPacket:
self.conn.sendExtendedData(self, dataType, data[: self.remoteMaxPacket])
data = data[self.remoteMaxPacket :]
self.remoteWindowLeft -= self.remoteMaxPacket
if data:
self.conn.sendExtendedData(self, dataType, data)
self.remoteWindowLeft -= len(data)
if self.closing:
self.loseConnection() # try again
def writeSequence(self, data):
"""
Part of the Transport interface. Write a list of strings to the
channel.
@type data: C{list} of L{str}
"""
self.write(b"".join(data))
def loseConnection(self):
"""
Close the channel if there is no buferred data. Otherwise, note the
request and return.
"""
self.closing = 1
if not self.buf and not self.extBuf:
self.conn.sendClose(self)
def getPeer(self):
"""
See: L{ITransport.getPeer}
@return: The remote address of this connection.
@rtype: L{SSHTransportAddress}.
"""
return self.conn.transport.getPeer()
def getHost(self):
"""
See: L{ITransport.getHost}
@return: An address describing this side of the connection.
@rtype: L{SSHTransportAddress}.
"""
return self.conn.transport.getHost()
def stopWriting(self):
"""
Called when the remote buffer is full, as a hint to stop writing.
This can be ignored, but it can be helpful.
"""
def startWriting(self):
"""
Called when the remote buffer has more room, as a hint to continue
writing.
"""

View File

@@ -0,0 +1,85 @@
# -*- test-case-name: twisted.conch.test.test_ssh -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Common functions for the SSH classes.
Maintainer: Paul Swartz
"""
import struct
from cryptography.utils import int_to_bytes
from twisted.python.deprecate import deprecated
from twisted.python.versions import Version
__all__ = ["NS", "getNS", "MP", "getMP", "ffs"]
def NS(t):
"""
net string
"""
if isinstance(t, str):
t = t.encode("utf-8")
return struct.pack("!L", len(t)) + t
def getNS(s, count=1):
"""
get net string
"""
ns = []
c = 0
for i in range(count):
(l,) = struct.unpack("!L", s[c : c + 4])
ns.append(s[c + 4 : 4 + l + c])
c += 4 + l
return tuple(ns) + (s[c:],)
def MP(number):
if number == 0:
return b"\000" * 4
assert number > 0
bn = int_to_bytes(number)
if ord(bn[0:1]) & 128:
bn = b"\000" + bn
return struct.pack(">L", len(bn)) + bn
def getMP(data, count=1):
"""
Get multiple precision integer out of the string. A multiple precision
integer is stored as a 4-byte length followed by length bytes of the
integer. If count is specified, get count integers out of the string.
The return value is a tuple of count integers followed by the rest of
the data.
"""
mp = []
c = 0
for i in range(count):
(length,) = struct.unpack(">L", data[c : c + 4])
mp.append(int.from_bytes(data[c + 4 : c + 4 + length], "big"))
c += 4 + length
return tuple(mp) + (data[c:],)
def ffs(c, s):
"""
first from second
goes through the first list, looking for items in the second, returns the first one
"""
for i in c:
if i in s:
return i
@deprecated(Version("Twisted", 16, 5, 0))
def install():
# This used to install gmpy, but is technically public API, so just do
# nothing.
pass

View File

@@ -0,0 +1,679 @@
# -*- test-case-name: twisted.conch.test.test_connection -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of the ssh-connection service, which
allows access to the shell and port-forwarding.
Maintainer: Paul Swartz
"""
import string
import struct
import twisted.internet.error
from twisted.conch import error
from twisted.conch.ssh import common, service
from twisted.internet import defer
from twisted.logger import Logger
from twisted.python.compat import nativeString, networkString
class SSHConnection(service.SSHService):
"""
An implementation of the 'ssh-connection' service. It is used to
multiplex multiple channels over the single SSH connection.
@ivar localChannelID: the next number to use as a local channel ID.
@type localChannelID: L{int}
@ivar channels: a L{dict} mapping a local channel ID to C{SSHChannel}
subclasses.
@type channels: L{dict}
@ivar localToRemoteChannel: a L{dict} mapping a local channel ID to a
remote channel ID.
@type localToRemoteChannel: L{dict}
@ivar channelsToRemoteChannel: a L{dict} mapping a C{SSHChannel} subclass
to remote channel ID.
@type channelsToRemoteChannel: L{dict}
@ivar deferreds: a L{dict} mapping a local channel ID to a C{list} of
C{Deferreds} for outstanding channel requests. Also, the 'global'
key stores the C{list} of pending global request C{Deferred}s.
"""
name = b"ssh-connection"
_log = Logger()
def __init__(self):
self.localChannelID = 0 # this is the current # to use for channel ID
# local channel ID -> remote channel ID
self.localToRemoteChannel = {}
# local channel ID -> subclass of SSHChannel
self.channels = {}
# subclass of SSHChannel -> remote channel ID
self.channelsToRemoteChannel = {}
# local channel -> list of deferreds for pending requests
# or 'global' -> list of deferreds for global requests
self.deferreds = {"global": []}
self.transport = None # gets set later
def serviceStarted(self):
if hasattr(self.transport, "avatar"):
self.transport.avatar.conn = self
def serviceStopped(self):
"""
Called when the connection is stopped.
"""
# Close any fully open channels
for channel in list(self.channelsToRemoteChannel.keys()):
self.channelClosed(channel)
# Indicate failure to any channels that were in the process of
# opening but not yet open.
while self.channels:
(_, channel) = self.channels.popitem()
channel.openFailed(twisted.internet.error.ConnectionLost())
# Errback any unfinished global requests.
self._cleanupGlobalDeferreds()
def _cleanupGlobalDeferreds(self):
"""
All pending requests that have returned a deferred must be errbacked
when this service is stopped, otherwise they might be left uncalled and
uncallable.
"""
for d in self.deferreds["global"]:
d.errback(error.ConchError("Connection stopped."))
del self.deferreds["global"][:]
# packet methods
def ssh_GLOBAL_REQUEST(self, packet):
"""
The other side has made a global request. Payload::
string request type
bool want reply
<request specific data>
This dispatches to self.gotGlobalRequest.
"""
requestType, rest = common.getNS(packet)
wantReply, rest = ord(rest[0:1]), rest[1:]
ret = self.gotGlobalRequest(requestType, rest)
if wantReply:
reply = MSG_REQUEST_FAILURE
data = b""
if ret:
reply = MSG_REQUEST_SUCCESS
if isinstance(ret, (tuple, list)):
data = ret[1]
self.transport.sendPacket(reply, data)
def ssh_REQUEST_SUCCESS(self, packet):
"""
Our global request succeeded. Get the appropriate Deferred and call
it back with the packet we received.
"""
self._log.debug("global request success")
self.deferreds["global"].pop(0).callback(packet)
def ssh_REQUEST_FAILURE(self, packet):
"""
Our global request failed. Get the appropriate Deferred and errback
it with the packet we received.
"""
self._log.debug("global request failure")
self.deferreds["global"].pop(0).errback(
error.ConchError("global request failed", packet)
)
def ssh_CHANNEL_OPEN(self, packet):
"""
The other side wants to get a channel. Payload::
string channel name
uint32 remote channel number
uint32 remote window size
uint32 remote maximum packet size
<channel specific data>
We get a channel from self.getChannel(), give it a local channel number
and notify the other side. Then notify the channel by calling its
channelOpen method.
"""
channelType, rest = common.getNS(packet)
senderChannel, windowSize, maxPacket = struct.unpack(">3L", rest[:12])
packet = rest[12:]
try:
channel = self.getChannel(channelType, windowSize, maxPacket, packet)
localChannel = self.localChannelID
self.localChannelID += 1
channel.id = localChannel
self.channels[localChannel] = channel
self.channelsToRemoteChannel[channel] = senderChannel
self.localToRemoteChannel[localChannel] = senderChannel
openConfirmPacket = (
struct.pack(
">4L",
senderChannel,
localChannel,
channel.localWindowSize,
channel.localMaxPacket,
)
+ channel.specificData
)
self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION, openConfirmPacket)
channel.channelOpen(packet)
except Exception as e:
self._log.failure("channel open failed")
if isinstance(e, error.ConchError):
textualInfo, reason = e.args
if isinstance(textualInfo, int):
# See #3657 and #3071
textualInfo, reason = reason, textualInfo
else:
reason = OPEN_CONNECT_FAILED
textualInfo = "unknown failure"
self.transport.sendPacket(
MSG_CHANNEL_OPEN_FAILURE,
struct.pack(">2L", senderChannel, reason)
+ common.NS(networkString(textualInfo))
+ common.NS(b""),
)
def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
"""
The other side accepted our MSG_CHANNEL_OPEN request. Payload::
uint32 local channel number
uint32 remote channel number
uint32 remote window size
uint32 remote maximum packet size
<channel specific data>
Find the channel using the local channel number and notify its
channelOpen method.
"""
(localChannel, remoteChannel, windowSize, maxPacket) = struct.unpack(
">4L", packet[:16]
)
specificData = packet[16:]
channel = self.channels[localChannel]
channel.conn = self
self.localToRemoteChannel[localChannel] = remoteChannel
self.channelsToRemoteChannel[channel] = remoteChannel
channel.remoteWindowLeft = windowSize
channel.remoteMaxPacket = maxPacket
channel.channelOpen(specificData)
def ssh_CHANNEL_OPEN_FAILURE(self, packet):
"""
The other side did not accept our MSG_CHANNEL_OPEN request. Payload::
uint32 local channel number
uint32 reason code
string reason description
Find the channel using the local channel number and notify it by
calling its openFailed() method.
"""
localChannel, reasonCode = struct.unpack(">2L", packet[:8])
reasonDesc = common.getNS(packet[8:])[0]
channel = self.channels[localChannel]
del self.channels[localChannel]
channel.conn = self
reason = error.ConchError(reasonDesc, reasonCode)
channel.openFailed(reason)
def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
"""
The other side is adding bytes to its window. Payload::
uint32 local channel number
uint32 bytes to add
Call the channel's addWindowBytes() method to add new bytes to the
remote window.
"""
localChannel, bytesToAdd = struct.unpack(">2L", packet[:8])
channel = self.channels[localChannel]
channel.addWindowBytes(bytesToAdd)
def ssh_CHANNEL_DATA(self, packet):
"""
The other side is sending us data. Payload::
uint32 local channel number
string data
Check to make sure the other side hasn't sent too much data (more
than what's in the window, or more than the maximum packet size). If
they have, close the channel. Otherwise, decrease the available
window and pass the data to the channel's dataReceived().
"""
localChannel, dataLength = struct.unpack(">2L", packet[:8])
channel = self.channels[localChannel]
# XXX should this move to dataReceived to put client in charge?
if (
dataLength > channel.localWindowLeft or dataLength > channel.localMaxPacket
): # more data than we want
self._log.error("too much data")
self.sendClose(channel)
return
# packet = packet[:channel.localWindowLeft+4]
data = common.getNS(packet[4:])[0]
channel.localWindowLeft -= dataLength
if channel.localWindowLeft < channel.localWindowSize // 2:
self.adjustWindow(
channel, channel.localWindowSize - channel.localWindowLeft
)
channel.dataReceived(data)
def ssh_CHANNEL_EXTENDED_DATA(self, packet):
"""
The other side is sending us exteneded data. Payload::
uint32 local channel number
uint32 type code
string data
Check to make sure the other side hasn't sent too much data (more
than what's in the window, or than the maximum packet size). If
they have, close the channel. Otherwise, decrease the available
window and pass the data and type code to the channel's
extReceived().
"""
localChannel, typeCode, dataLength = struct.unpack(">3L", packet[:12])
channel = self.channels[localChannel]
if dataLength > channel.localWindowLeft or dataLength > channel.localMaxPacket:
self._log.error("too much extdata")
self.sendClose(channel)
return
data = common.getNS(packet[8:])[0]
channel.localWindowLeft -= dataLength
if channel.localWindowLeft < channel.localWindowSize // 2:
self.adjustWindow(
channel, channel.localWindowSize - channel.localWindowLeft
)
channel.extReceived(typeCode, data)
def ssh_CHANNEL_EOF(self, packet):
"""
The other side is not sending any more data. Payload::
uint32 local channel number
Notify the channel by calling its eofReceived() method.
"""
localChannel = struct.unpack(">L", packet[:4])[0]
channel = self.channels[localChannel]
channel.eofReceived()
def ssh_CHANNEL_CLOSE(self, packet):
"""
The other side is closing its end; it does not want to receive any
more data. Payload::
uint32 local channel number
Notify the channnel by calling its closeReceived() method. If
the channel has also sent a close message, call self.channelClosed().
"""
localChannel = struct.unpack(">L", packet[:4])[0]
channel = self.channels[localChannel]
channel.closeReceived()
channel.remoteClosed = True
if channel.localClosed and channel.remoteClosed:
self.channelClosed(channel)
def ssh_CHANNEL_REQUEST(self, packet):
"""
The other side is sending a request to a channel. Payload::
uint32 local channel number
string request name
bool want reply
<request specific data>
Pass the message to the channel's requestReceived method. If the
other side wants a reply, add callbacks which will send the
reply.
"""
localChannel = struct.unpack(">L", packet[:4])[0]
requestType, rest = common.getNS(packet[4:])
wantReply = ord(rest[0:1])
channel = self.channels[localChannel]
d = defer.maybeDeferred(channel.requestReceived, requestType, rest[1:])
if wantReply:
d.addCallback(self._cbChannelRequest, localChannel)
d.addErrback(self._ebChannelRequest, localChannel)
return d
def _cbChannelRequest(self, result, localChannel):
"""
Called back if the other side wanted a reply to a channel request. If
the result is true, send a MSG_CHANNEL_SUCCESS. Otherwise, raise
a C{error.ConchError}
@param result: the value returned from the channel's requestReceived()
method. If it's False, the request failed.
@type result: L{bool}
@param localChannel: the local channel ID of the channel to which the
request was made.
@type localChannel: L{int}
@raises ConchError: if the result is False.
"""
if not result:
raise error.ConchError("failed request")
self.transport.sendPacket(
MSG_CHANNEL_SUCCESS,
struct.pack(">L", self.localToRemoteChannel[localChannel]),
)
def _ebChannelRequest(self, result, localChannel):
"""
Called if the other wisde wanted a reply to the channel requeset and
the channel request failed.
@param result: a Failure, but it's not used.
@param localChannel: the local channel ID of the channel to which the
request was made.
@type localChannel: L{int}
"""
self.transport.sendPacket(
MSG_CHANNEL_FAILURE,
struct.pack(">L", self.localToRemoteChannel[localChannel]),
)
def ssh_CHANNEL_SUCCESS(self, packet):
"""
Our channel request to the other side succeeded. Payload::
uint32 local channel number
Get the C{Deferred} out of self.deferreds and call it back.
"""
localChannel = struct.unpack(">L", packet[:4])[0]
if self.deferreds.get(localChannel):
d = self.deferreds[localChannel].pop(0)
d.callback("")
def ssh_CHANNEL_FAILURE(self, packet):
"""
Our channel request to the other side failed. Payload::
uint32 local channel number
Get the C{Deferred} out of self.deferreds and errback it with a
C{error.ConchError}.
"""
localChannel = struct.unpack(">L", packet[:4])[0]
if self.deferreds.get(localChannel):
d = self.deferreds[localChannel].pop(0)
d.errback(error.ConchError("channel request failed"))
# methods for users of the connection to call
def sendGlobalRequest(self, request, data, wantReply=0):
"""
Send a global request for this connection. Current this is only used
for remote->local TCP forwarding.
@type request: L{bytes}
@type data: L{bytes}
@type wantReply: L{bool}
@rtype: C{Deferred}/L{None}
"""
self.transport.sendPacket(
MSG_GLOBAL_REQUEST,
common.NS(request) + (wantReply and b"\xff" or b"\x00") + data,
)
if wantReply:
d = defer.Deferred()
self.deferreds["global"].append(d)
return d
def openChannel(self, channel, extra=b""):
"""
Open a new channel on this connection.
@type channel: subclass of C{SSHChannel}
@type extra: L{bytes}
"""
self._log.info(
"opening channel {id} with {localWindowSize} {localMaxPacket}",
id=self.localChannelID,
localWindowSize=channel.localWindowSize,
localMaxPacket=channel.localMaxPacket,
)
self.transport.sendPacket(
MSG_CHANNEL_OPEN,
common.NS(channel.name)
+ struct.pack(
">3L",
self.localChannelID,
channel.localWindowSize,
channel.localMaxPacket,
)
+ extra,
)
channel.id = self.localChannelID
self.channels[self.localChannelID] = channel
self.localChannelID += 1
def sendRequest(self, channel, requestType, data, wantReply=0):
"""
Send a request to a channel.
@type channel: subclass of C{SSHChannel}
@type requestType: L{bytes}
@type data: L{bytes}
@type wantReply: L{bool}
@rtype: C{Deferred}/L{None}
"""
if channel.localClosed:
return
self._log.debug("sending request {requestType}", requestType=requestType)
self.transport.sendPacket(
MSG_CHANNEL_REQUEST,
struct.pack(">L", self.channelsToRemoteChannel[channel])
+ common.NS(requestType)
+ (b"\1" if wantReply else b"\0")
+ data,
)
if wantReply:
d = defer.Deferred()
self.deferreds.setdefault(channel.id, []).append(d)
return d
def adjustWindow(self, channel, bytesToAdd):
"""
Tell the other side that we will receive more data. This should not
normally need to be called as it is managed automatically.
@type channel: subclass of L{SSHChannel}
@type bytesToAdd: L{int}
"""
if channel.localClosed:
return # we're already closed
packet = struct.pack(">2L", self.channelsToRemoteChannel[channel], bytesToAdd)
self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST, packet)
self._log.debug(
"adding {bytesToAdd} to {localWindowLeft} in channel {id}",
bytesToAdd=bytesToAdd,
localWindowLeft=channel.localWindowLeft,
id=channel.id,
)
channel.localWindowLeft += bytesToAdd
def sendData(self, channel, data):
"""
Send data to a channel. This should not normally be used: instead use
channel.write(data) as it manages the window automatically.
@type channel: subclass of L{SSHChannel}
@type data: L{bytes}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(
MSG_CHANNEL_DATA,
struct.pack(">L", self.channelsToRemoteChannel[channel]) + common.NS(data),
)
def sendExtendedData(self, channel, dataType, data):
"""
Send extended data to a channel. This should not normally be used:
instead use channel.writeExtendedData(data, dataType) as it manages
the window automatically.
@type channel: subclass of L{SSHChannel}
@type dataType: L{int}
@type data: L{bytes}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(
MSG_CHANNEL_EXTENDED_DATA,
struct.pack(">2L", self.channelsToRemoteChannel[channel], dataType)
+ common.NS(data),
)
def sendEOF(self, channel):
"""
Send an EOF (End of File) for a channel.
@type channel: subclass of L{SSHChannel}
"""
if channel.localClosed:
return # we're already closed
self._log.debug("sending eof")
self.transport.sendPacket(
MSG_CHANNEL_EOF, struct.pack(">L", self.channelsToRemoteChannel[channel])
)
def sendClose(self, channel):
"""
Close a channel.
@type channel: subclass of L{SSHChannel}
"""
if channel.localClosed:
return # we're already closed
self._log.info("sending close {id}", id=channel.id)
self.transport.sendPacket(
MSG_CHANNEL_CLOSE, struct.pack(">L", self.channelsToRemoteChannel[channel])
)
channel.localClosed = True
if channel.localClosed and channel.remoteClosed:
self.channelClosed(channel)
# methods to override
def getChannel(self, channelType, windowSize, maxPacket, data):
"""
The other side requested a channel of some sort.
channelType is the type of channel being requested,
windowSize is the initial size of the remote window,
maxPacket is the largest packet we should send,
data is any other packet data (often nothing).
We return a subclass of L{SSHChannel}.
By default, this dispatches to a method 'channel_channelType' with any
non-alphanumerics in the channelType replace with _'s. If it cannot
find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
The method is called with arguments of windowSize, maxPacket, data.
@type channelType: L{bytes}
@type windowSize: L{int}
@type maxPacket: L{int}
@type data: L{bytes}
@rtype: subclass of L{SSHChannel}/L{tuple}
"""
self._log.debug("got channel {channelType!r} request", channelType=channelType)
if hasattr(self.transport, "avatar"): # this is a server!
chan = self.transport.avatar.lookupChannel(
channelType, windowSize, maxPacket, data
)
else:
channelType = channelType.translate(TRANSLATE_TABLE)
attr = "channel_%s" % nativeString(channelType)
f = getattr(self, attr, None)
if f is not None:
chan = f(windowSize, maxPacket, data)
else:
chan = None
if chan is None:
raise error.ConchError("unknown channel", OPEN_UNKNOWN_CHANNEL_TYPE)
else:
chan.conn = self
return chan
def gotGlobalRequest(self, requestType, data):
"""
We got a global request. pretty much, this is just used by the client
to request that we forward a port from the server to the client.
Returns either:
- 1: request accepted
- 1, <data>: request accepted with request specific data
- 0: request denied
By default, this dispatches to a method 'global_requestType' with
-'s in requestType replaced with _'s. The found method is passed data.
If this method cannot be found, this method returns 0. Otherwise, it
returns the return value of that method.
@type requestType: L{bytes}
@type data: L{bytes}
@rtype: L{int}/L{tuple}
"""
self._log.debug("got global {requestType} request", requestType=requestType)
if hasattr(self.transport, "avatar"): # this is a server!
return self.transport.avatar.gotGlobalRequest(requestType, data)
requestType = nativeString(requestType.replace(b"-", b"_"))
f = getattr(self, "global_%s" % requestType, None)
if not f:
return 0
return f(data)
def channelClosed(self, channel):
"""
Called when a channel is closed.
It clears the local state related to the channel, and calls
channel.closed().
MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
If you don't, things will break mysteriously.
@type channel: L{SSHChannel}
"""
if channel in self.channelsToRemoteChannel: # actually open
channel.localClosed = channel.remoteClosed = True
del self.localToRemoteChannel[channel.id]
del self.channels[channel.id]
del self.channelsToRemoteChannel[channel]
for d in self.deferreds.pop(channel.id, []):
d.errback(error.ConchError("Channel closed."))
channel.closed()
MSG_GLOBAL_REQUEST = 80
MSG_REQUEST_SUCCESS = 81
MSG_REQUEST_FAILURE = 82
MSG_CHANNEL_OPEN = 90
MSG_CHANNEL_OPEN_CONFIRMATION = 91
MSG_CHANNEL_OPEN_FAILURE = 92
MSG_CHANNEL_WINDOW_ADJUST = 93
MSG_CHANNEL_DATA = 94
MSG_CHANNEL_EXTENDED_DATA = 95
MSG_CHANNEL_EOF = 96
MSG_CHANNEL_CLOSE = 97
MSG_CHANNEL_REQUEST = 98
MSG_CHANNEL_SUCCESS = 99
MSG_CHANNEL_FAILURE = 100
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
OPEN_CONNECT_FAILED = 2
OPEN_UNKNOWN_CHANNEL_TYPE = 3
OPEN_RESOURCE_SHORTAGE = 4
# From RFC 4254
EXTENDED_DATA_STDERR = 1
messages = {}
for name, value in locals().copy().items():
if name[:4] == "MSG_":
messages[value] = name # Doesn't handle doubles
alphanums = networkString(string.ascii_letters + string.digits)
TRANSLATE_TABLE = bytes(i if i in alphanums else ord("_") for i in range(256))
SSHConnection.protocolMessages = messages

View File

@@ -0,0 +1,129 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A Factory for SSH servers.
See also L{twisted.conch.openssh_compat.factory} for OpenSSH compatibility.
Maintainer: Paul Swartz
"""
import random
from itertools import chain
from typing import Dict, List, Optional, Tuple
from twisted.conch import error
from twisted.conch.ssh import _kex, connection, transport, userauth
from twisted.internet import protocol
from twisted.logger import Logger
class SSHFactory(protocol.Factory):
"""
A Factory for SSH servers.
"""
primes: Optional[Dict[int, List[Tuple[int, int]]]]
_log = Logger()
protocol = transport.SSHServerTransport
services = {
b"ssh-userauth": userauth.SSHUserAuthServer,
b"ssh-connection": connection.SSHConnection,
}
def startFactory(self) -> None:
"""
Check for public and private keys.
"""
if not hasattr(self, "publicKeys"):
self.publicKeys = self.getPublicKeys()
if not hasattr(self, "privateKeys"):
self.privateKeys = self.getPrivateKeys()
if not self.publicKeys or not self.privateKeys:
raise error.ConchError("no host keys, failing")
if not hasattr(self, "primes"):
self.primes = self.getPrimes()
def buildProtocol(self, addr):
"""
Create an instance of the server side of the SSH protocol.
@type addr: L{twisted.internet.interfaces.IAddress} provider
@param addr: The address at which the server will listen.
@rtype: L{twisted.conch.ssh.transport.SSHServerTransport}
@return: The built transport.
"""
t = protocol.Factory.buildProtocol(self, addr)
t.supportedPublicKeys = list(
chain.from_iterable(
key.supportedSignatureAlgorithms() for key in self.privateKeys.values()
)
)
if not self.primes:
self._log.info(
"disabling non-fixed-group key exchange algorithms "
"because we cannot find moduli file"
)
t.supportedKeyExchanges = [
kexAlgorithm
for kexAlgorithm in t.supportedKeyExchanges
if _kex.isFixedGroup(kexAlgorithm) or _kex.isEllipticCurve(kexAlgorithm)
]
return t
def getPublicKeys(self):
"""
Called when the factory is started to get the public portions of the
servers host keys. Returns a dictionary mapping SSH key types to
public key strings.
@rtype: L{dict}
"""
raise NotImplementedError("getPublicKeys unimplemented")
def getPrivateKeys(self):
"""
Called when the factory is started to get the private portions of the
servers host keys. Returns a dictionary mapping SSH key types to
L{twisted.conch.ssh.keys.Key} objects.
@rtype: L{dict}
"""
raise NotImplementedError("getPrivateKeys unimplemented")
def getPrimes(self) -> Optional[Dict[int, List[Tuple[int, int]]]]:
"""
Called when the factory is started to get Diffie-Hellman generators and
primes to use. Returns a dictionary mapping number of bits to lists of
tuple of (generator, prime).
"""
def getDHPrime(self, bits: int) -> Tuple[int, int]:
"""
Return a tuple of (g, p) for a Diffe-Hellman process, with p being as
close to C{bits} bits as possible.
"""
def keyfunc(i: int) -> int:
return abs(i - bits)
assert self.primes is not None, "Factory should have been started by now."
primesKeys = sorted(self.primes.keys(), key=keyfunc)
realBits = primesKeys[0]
return random.choice(self.primes[realBits])
def getService(self, transport, service):
"""
Return a class to use as a service for the given transport.
@type transport: L{transport.SSHServerTransport}
@type service: L{bytes}
@rtype: subclass of L{service.SSHService}
"""
if service == b"ssh-userauth" or hasattr(transport, "avatar"):
return self.services[service]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,272 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of the TCP forwarding, which allows
clients and servers to forward arbitrary TCP data across the connection.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch.ssh import channel, common
from twisted.internet import protocol, reactor
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
class SSHListenForwardingFactory(protocol.Factory):
def __init__(self, connection, hostport, klass):
self.conn = connection
self.hostport = hostport # tuple
self.klass = klass
def buildProtocol(self, addr):
channel = self.klass(conn=self.conn)
client = SSHForwardingClient(channel)
channel.client = client
addrTuple = (addr.host, addr.port)
channelOpenData = packOpen_direct_tcpip(self.hostport, addrTuple)
self.conn.openChannel(channel, channelOpenData)
return client
class SSHListenForwardingChannel(channel.SSHChannel):
def channelOpen(self, specificData):
self._log.info("opened forwarding channel {id}", id=self.id)
if len(self.client.buf) > 1:
b = self.client.buf[1:]
self.write(b)
self.client.buf = b""
def openFailed(self, reason):
self.closed()
def dataReceived(self, data):
self.client.transport.write(data)
def eofReceived(self):
self.client.transport.loseConnection()
def closed(self):
if hasattr(self, "client"):
self._log.info("closing local forwarding channel {id}", id=self.id)
self.client.transport.loseConnection()
del self.client
class SSHListenClientForwardingChannel(SSHListenForwardingChannel):
name = b"direct-tcpip"
class SSHListenServerForwardingChannel(SSHListenForwardingChannel):
name = b"forwarded-tcpip"
class SSHConnectForwardingChannel(channel.SSHChannel):
"""
Channel used for handling server side forwarding request.
It acts as a client for the remote forwarding destination.
@ivar hostport: C{(host, port)} requested by client as forwarding
destination.
@type hostport: L{tuple} or a C{sequence}
@ivar client: Protocol connected to the forwarding destination.
@type client: L{protocol.Protocol}
@ivar clientBuf: Data received while forwarding channel is not yet
connected.
@type clientBuf: L{bytes}
@var _reactor: Reactor used for TCP connections.
@type _reactor: A reactor.
@ivar _channelOpenDeferred: Deferred used in testing to check the
result of C{channelOpen}.
@type _channelOpenDeferred: L{twisted.internet.defer.Deferred}
"""
_reactor = reactor
def __init__(self, hostport, *args, **kw):
channel.SSHChannel.__init__(self, *args, **kw)
self.hostport = hostport
self.client = None
self.clientBuf = b""
def channelOpen(self, specificData):
"""
See: L{channel.SSHChannel}
"""
self._log.info(
"connecting to {host}:{port}", host=self.hostport[0], port=self.hostport[1]
)
ep = HostnameEndpoint(self._reactor, self.hostport[0], self.hostport[1])
d = connectProtocol(ep, SSHForwardingClient(self))
d.addCallbacks(self._setClient, self._close)
self._channelOpenDeferred = d
def _setClient(self, client):
"""
Called when the connection was established to the forwarding
destination.
@param client: Client protocol connected to the forwarding destination.
@type client: L{protocol.Protocol}
"""
self.client = client
self._log.info(
"connected to {host}:{port}", host=self.hostport[0], port=self.hostport[1]
)
if self.clientBuf:
self.client.transport.write(self.clientBuf)
self.clientBuf = None
if self.client.buf[1:]:
self.write(self.client.buf[1:])
self.client.buf = b""
def _close(self, reason):
"""
Called when failed to connect to the forwarding destination.
@param reason: Reason why connection failed.
@type reason: L{twisted.python.failure.Failure}
"""
self._log.error(
"failed to connect to {host}:{port}: {reason}",
host=self.hostport[0],
port=self.hostport[1],
reason=reason,
)
self.loseConnection()
def dataReceived(self, data):
"""
See: L{channel.SSHChannel}
"""
if self.client:
self.client.transport.write(data)
else:
self.clientBuf += data
def closed(self):
"""
See: L{channel.SSHChannel}
"""
if self.client:
self._log.info("closed remote forwarding channel {id}", id=self.id)
if self.client.channel:
self.loseConnection()
self.client.transport.loseConnection()
del self.client
def openConnectForwardingClient(remoteWindow, remoteMaxPacket, data, avatar):
remoteHP, origHP = unpackOpen_direct_tcpip(data)
return SSHConnectForwardingChannel(
remoteHP,
remoteWindow=remoteWindow,
remoteMaxPacket=remoteMaxPacket,
avatar=avatar,
)
class SSHForwardingClient(protocol.Protocol):
def __init__(self, channel):
self.channel = channel
self.buf = b"\000"
def dataReceived(self, data):
if self.buf:
self.buf += data
else:
self.channel.write(data)
def connectionLost(self, reason):
if self.channel:
self.channel.loseConnection()
self.channel = None
def packOpen_direct_tcpip(destination, source):
"""
Pack the data suitable for sending in a CHANNEL_OPEN packet.
@type destination: L{tuple}
@param destination: A tuple of the (host, port) of the destination host.
@type source: L{tuple}
@param source: A tuple of the (host, port) of the source host.
"""
(connHost, connPort) = destination
(origHost, origPort) = source
if isinstance(connHost, str):
connHost = connHost.encode("utf-8")
if isinstance(origHost, str):
origHost = origHost.encode("utf-8")
conn = common.NS(connHost) + struct.pack(">L", connPort)
orig = common.NS(origHost) + struct.pack(">L", origPort)
return conn + orig
packOpen_forwarded_tcpip = packOpen_direct_tcpip
def unpackOpen_direct_tcpip(data):
"""Unpack the data to a usable format."""
connHost, rest = common.getNS(data)
if isinstance(connHost, bytes):
connHost = connHost.decode("utf-8")
connPort = int(struct.unpack(">L", rest[:4])[0])
origHost, rest = common.getNS(rest[4:])
if isinstance(origHost, bytes):
origHost = origHost.decode("utf-8")
origPort = int(struct.unpack(">L", rest[:4])[0])
return (connHost, connPort), (origHost, origPort)
unpackOpen_forwarded_tcpip = unpackOpen_direct_tcpip
def packGlobal_tcpip_forward(peer):
"""
Pack the data for tcpip forwarding.
@param peer: A tuple of the (host, port) .
@type peer: L{tuple}
"""
(host, port) = peer
return common.NS(host) + struct.pack(">L", port)
def unpackGlobal_tcpip_forward(data):
host, rest = common.getNS(data)
if isinstance(host, bytes):
host = host.decode("utf-8")
port = int(struct.unpack(">L", rest[:4])[0])
return host, port
"""This is how the data -> eof -> close stuff /should/ work.
debug3: channel 1: waiting for connection
debug1: channel 1: connected
debug1: channel 1: read<=0 rfd 7 len 0
debug1: channel 1: read failed
debug1: channel 1: close_read
debug1: channel 1: input open -> drain
debug1: channel 1: ibuf empty
debug1: channel 1: send eof
debug1: channel 1: input drain -> closed
debug1: channel 1: rcvd eof
debug1: channel 1: output open -> drain
debug1: channel 1: obuf empty
debug1: channel 1: close_write
debug1: channel 1: output drain -> closed
debug1: channel 1: rcvd close
debug3: channel 1: will not send data after close
debug1: channel 1: send close
debug1: channel 1: is dead
"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,56 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The parent class for all the SSH services. Currently implemented services
are ssh-userauth and ssh-connection.
Maintainer: Paul Swartz
"""
from typing import Dict
from twisted.logger import Logger
class SSHService:
# this is the ssh name for the service:
name: bytes = None # type:ignore[assignment]
protocolMessages: Dict[int, str] = {} # map #'s -> protocol names
transport = None # gets set later
_log = Logger()
def serviceStarted(self):
"""
called when the service is active on the transport.
"""
def serviceStopped(self):
"""
called when the service is stopped, either by the connection ending
or by another service being started
"""
def logPrefix(self):
return "SSHService {!r} on {}".format(
self.name, self.transport.transport.logPrefix()
)
def packetReceived(self, messageNum, packet):
"""
called when we receive a packet on the transport
"""
# print self.protocolMessages
if messageNum in self.protocolMessages:
messageType = self.protocolMessages[messageNum]
f = getattr(self, "ssh_%s" % messageType[4:], None)
if f is not None:
return f(packet)
self._log.info(
"couldn't handle {messageNum} {packet!r}",
messageNum=messageNum,
packet=packet,
)
self.transport.sendUnimplemented()

View File

@@ -0,0 +1,440 @@
# -*- test-case-name: twisted.conch.test.test_session -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of SSHSession, which (by default)
allows access to a shell and a python interpreter over SSH.
Maintainer: Paul Swartz
"""
import os
import signal
import struct
import sys
from zope.interface import implementer
from twisted.conch.interfaces import (
EnvironmentVariableNotPermitted,
ISession,
ISessionSetEnv,
)
from twisted.conch.ssh import channel, common, connection
from twisted.internet import interfaces, protocol
from twisted.logger import Logger
from twisted.python.compat import networkString
log = Logger()
class SSHSession(channel.SSHChannel):
"""
A generalized implementation of an SSH session.
See RFC 4254, section 6.
The precise implementation of the various operations that the remote end
can send is left up to the avatar, usually via an adapter to an
interface such as L{ISession}.
@ivar buf: a buffer for data received before making a connection to a
client.
@type buf: L{bytes}
@ivar client: a protocol for communication with a shell, an application
program, or a subsystem (see RFC 4254, section 6.5).
@type client: L{SSHSessionProcessProtocol}
@ivar session: an object providing concrete implementations of session
operations.
@type session: L{ISession}
"""
name = b"session"
def __init__(self, *args, **kw):
channel.SSHChannel.__init__(self, *args, **kw)
self.buf = b""
self.client = None
self.session = None
def request_subsystem(self, data):
subsystem, ignored = common.getNS(data)
log.info('Asking for subsystem "{subsystem}"', subsystem=subsystem)
client = self.avatar.lookupSubsystem(subsystem, data)
if client:
pp = SSHSessionProcessProtocol(self)
proto = wrapProcessProtocol(pp)
client.makeConnection(proto)
pp.makeConnection(wrapProtocol(client))
self.client = pp
return 1
else:
log.error("Failed to get subsystem")
return 0
def request_shell(self, data):
log.info("Getting shell")
if not self.session:
self.session = ISession(self.avatar)
try:
pp = SSHSessionProcessProtocol(self)
self.session.openShell(pp)
except Exception:
log.failure("Error getting shell")
return 0
else:
self.client = pp
return 1
def request_exec(self, data):
if not self.session:
self.session = ISession(self.avatar)
f, data = common.getNS(data)
log.info('Executing command "{f}"', f=f)
try:
pp = SSHSessionProcessProtocol(self)
self.session.execCommand(pp, f)
except Exception:
log.failure('Error executing command "{f}"', f=f)
return 0
else:
self.client = pp
return 1
def request_pty_req(self, data):
if not self.session:
self.session = ISession(self.avatar)
term, windowSize, modes = parseRequest_pty_req(data)
log.info(
"Handling pty request: {term!r} {windowSize!r}",
term=term,
windowSize=windowSize,
)
try:
self.session.getPty(term, windowSize, modes)
except Exception:
log.failure("Error handling pty request")
return 0
else:
return 1
def request_env(self, data):
"""
Process a request to pass an environment variable.
@param data: The environment variable name and value, each encoded
as an SSH protocol string and concatenated.
@type data: L{bytes}
@return: A true value if the request to pass this environment
variable was accepted, otherwise a false value.
"""
if not self.session:
self.session = ISession(self.avatar)
if not ISessionSetEnv.providedBy(self.session):
return 0
name, value, data = common.getNS(data, 2)
try:
self.session.setEnv(name, value)
except EnvironmentVariableNotPermitted:
return 0
except Exception:
log.failure("Error setting environment variable {name}", name=name)
return 0
else:
return 1
def request_window_change(self, data):
if not self.session:
self.session = ISession(self.avatar)
winSize = parseRequest_window_change(data)
try:
self.session.windowChanged(winSize)
except Exception:
log.failure("Error changing window size")
return 0
else:
return 1
def dataReceived(self, data):
if not self.client:
# self.conn.sendClose(self)
self.buf += data
return
self.client.transport.write(data)
def extReceived(self, dataType, data):
if dataType == connection.EXTENDED_DATA_STDERR:
if self.client and hasattr(self.client.transport, "writeErr"):
self.client.transport.writeErr(data)
else:
log.warn("Weird extended data: {dataType}", dataType=dataType)
def eofReceived(self):
# If we have a session, tell it that EOF has been received and
# expect it to send a close message (it may need to send other
# messages such as exit-status or exit-signal first). If we don't
# have a session, then just send a close message directly.
if self.session:
self.session.eofReceived()
elif self.client:
self.conn.sendClose(self)
def closed(self):
if self.client and self.client.transport:
self.client.transport.loseConnection()
if self.session:
self.session.closed()
# def closeReceived(self):
# self.loseConnection() # don't know what to do with this
def loseConnection(self):
if self.client:
self.client.transport.loseConnection()
channel.SSHChannel.loseConnection(self)
class _ProtocolWrapper(protocol.ProcessProtocol):
"""
This class wraps a L{Protocol} instance in a L{ProcessProtocol} instance.
"""
def __init__(self, proto):
self.proto = proto
def connectionMade(self):
self.proto.connectionMade()
def outReceived(self, data):
self.proto.dataReceived(data)
def processEnded(self, reason):
self.proto.connectionLost(reason)
class _DummyTransport:
def __init__(self, proto):
self.proto = proto
def dataReceived(self, data):
self.proto.transport.write(data)
def write(self, data):
self.proto.dataReceived(data)
def writeSequence(self, seq):
self.write(b"".join(seq))
def loseConnection(self):
self.proto.connectionLost(protocol.connectionDone)
def wrapProcessProtocol(inst):
if isinstance(inst, protocol.Protocol):
return _ProtocolWrapper(inst)
else:
return inst
def wrapProtocol(proto):
return _DummyTransport(proto)
# SUPPORTED_SIGNALS is a list of signals that every session channel is supposed
# to accept. See RFC 4254
SUPPORTED_SIGNALS = [
"ABRT",
"ALRM",
"FPE",
"HUP",
"ILL",
"INT",
"KILL",
"PIPE",
"QUIT",
"SEGV",
"TERM",
"USR1",
"USR2",
]
@implementer(interfaces.ITransport)
class SSHSessionProcessProtocol(protocol.ProcessProtocol):
"""I am both an L{IProcessProtocol} and an L{ITransport}.
I am a transport to the remote endpoint and a process protocol to the
local subsystem.
"""
# once initialized, a dictionary mapping signal values to strings
# that follow RFC 4254.
_signalValuesToNames = None
def __init__(self, session):
self.session = session
self.lostOutOrErrFlag = False
def connectionMade(self):
if self.session.buf:
self.transport.write(self.session.buf)
self.session.buf = None
def outReceived(self, data):
self.session.write(data)
def errReceived(self, err):
self.session.writeExtended(connection.EXTENDED_DATA_STDERR, err)
def outConnectionLost(self):
"""
EOF should only be sent when both STDOUT and STDERR have been closed.
"""
if self.lostOutOrErrFlag:
self.session.conn.sendEOF(self.session)
else:
self.lostOutOrErrFlag = True
def errConnectionLost(self):
"""
See outConnectionLost().
"""
self.outConnectionLost()
def connectionLost(self, reason=None):
self.session.loseConnection()
def _getSignalName(self, signum):
"""
Get a signal name given a signal number.
"""
if self._signalValuesToNames is None:
self._signalValuesToNames = {}
# make sure that the POSIX ones are the defaults
for signame in SUPPORTED_SIGNALS:
signame = "SIG" + signame
sigvalue = getattr(signal, signame, None)
if sigvalue is not None:
self._signalValuesToNames[sigvalue] = signame
for k, v in signal.__dict__.items():
# Check for platform specific signals, ignoring Python specific
# SIG_DFL and SIG_IGN
if k.startswith("SIG") and not k.startswith("SIG_"):
if v not in self._signalValuesToNames:
self._signalValuesToNames[v] = k + "@" + sys.platform
return self._signalValuesToNames[signum]
def processEnded(self, reason=None):
"""
When we are told the process ended, try to notify the other side about
how the process ended using the exit-signal or exit-status requests.
Also, close the channel.
"""
if reason is not None:
err = reason.value
if err.signal is not None:
signame = self._getSignalName(err.signal)
if getattr(os, "WCOREDUMP", None) is not None and os.WCOREDUMP(
err.status
):
log.info("exitSignal: {signame} (core dumped)", signame=signame)
coreDumped = True
else:
log.info("exitSignal: {}", signame=signame)
coreDumped = False
self.session.conn.sendRequest(
self.session,
b"exit-signal",
common.NS(networkString(signame[3:]))
+ (b"\1" if coreDumped else b"\0")
+ common.NS(b"")
+ common.NS(b""),
)
elif err.exitCode is not None:
log.info("exitCode: {exitCode!r}", exitCode=err.exitCode)
self.session.conn.sendRequest(
self.session, b"exit-status", struct.pack(">L", err.exitCode)
)
self.session.loseConnection()
def getHost(self):
"""
Return the host from my session's transport.
"""
return self.session.conn.transport.getHost()
def getPeer(self):
"""
Return the peer from my session's transport.
"""
return self.session.conn.transport.getPeer()
def write(self, data):
self.session.write(data)
def writeSequence(self, seq):
self.session.write(b"".join(seq))
def loseConnection(self):
self.session.loseConnection()
class SSHSessionClient(protocol.Protocol):
def dataReceived(self, data):
if self.transport:
self.transport.write(data)
# methods factored out to make live easier on server writers
def parseRequest_pty_req(data):
"""Parse the data from a pty-req request into usable data.
@returns: a tuple of (terminal type, (rows, cols, xpixel, ypixel), modes)
"""
term, rest = common.getNS(data)
cols, rows, xpixel, ypixel = struct.unpack(">4L", rest[:16])
modes, ignored = common.getNS(rest[16:])
winSize = (rows, cols, xpixel, ypixel)
modes = [
(ord(modes[i : i + 1]), struct.unpack(">L", modes[i + 1 : i + 5])[0])
for i in range(0, len(modes) - 1, 5)
]
return term, winSize, modes
def packRequest_pty_req(term, geometry, modes):
"""
Pack a pty-req request so that it is suitable for sending.
NOTE: modes must be packed before being sent here.
@type geometry: L{tuple}
@param geometry: A tuple of (rows, columns, xpixel, ypixel)
"""
(rows, cols, xpixel, ypixel) = geometry
termPacked = common.NS(term)
winSizePacked = struct.pack(">4L", cols, rows, xpixel, ypixel)
modesPacked = common.NS(modes) # depend on the client packing modes
return termPacked + winSizePacked + modesPacked
def parseRequest_window_change(data):
"""Parse the data from a window-change request into usuable data.
@returns: a tuple of (rows, cols, xpixel, ypixel)
"""
cols, rows, xpixel, ypixel = struct.unpack(">4L", data)
return rows, cols, xpixel, ypixel
def packRequest_window_change(geometry):
"""
Pack a window-change request so that it is suitable for sending.
@type geometry: L{tuple}
@param geometry: A tuple of (rows, columns, xpixel, ypixel)
"""
(rows, cols, xpixel, ypixel) = geometry
return struct.pack(">4L", cols, rows, xpixel, ypixel)

View File

@@ -0,0 +1,40 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
def parse(s):
s = s.strip()
expr = []
while s:
if s[0:1] == b"(":
newSexp = []
if expr:
expr[-1].append(newSexp)
expr.append(newSexp)
s = s[1:]
continue
if s[0:1] == b")":
aList = expr.pop()
s = s[1:]
if not expr:
assert not s
return aList
continue
i = 0
while s[i : i + 1].isdigit():
i += 1
assert i
length = int(s[:i])
data = s[i + 1 : i + 1 + length]
expr[-1].append(data)
s = s[i + 1 + length :]
assert False, "this should not happen"
def pack(sexp):
return b"".join(
b"(%b)" % (pack(o),)
if type(o) in (type(()), type([]))
else b"%d:%b" % (len(o), o)
for o in sexp
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,764 @@
# -*- test-case-name: twisted.conch.test.test_userauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of the ssh-userauth service.
Currently implemented authentication types are public-key and password.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch import error, interfaces
from twisted.conch.ssh import keys, service, transport
from twisted.conch.ssh.common import NS, getNS
from twisted.cred import credentials
from twisted.cred.error import UnauthorizedLogin
from twisted.internet import defer, reactor
from twisted.logger import Logger
from twisted.python import failure
from twisted.python.compat import nativeString
class SSHUserAuthServer(service.SSHService):
"""
A service implementing the server side of the 'ssh-userauth' service. It
is used to authenticate the user on the other side as being able to access
this server.
@ivar name: the name of this service: 'ssh-userauth'
@type name: L{bytes}
@ivar authenticatedWith: a list of authentication methods that have
already been used.
@type authenticatedWith: L{list}
@ivar loginTimeout: the number of seconds we wait before disconnecting
the user for taking too long to authenticate
@type loginTimeout: L{int}
@ivar attemptsBeforeDisconnect: the number of failed login attempts we
allow before disconnecting.
@type attemptsBeforeDisconnect: L{int}
@ivar loginAttempts: the number of login attempts that have been made
@type loginAttempts: L{int}
@ivar passwordDelay: the number of seconds to delay when the user gives
an incorrect password
@type passwordDelay: L{int}
@ivar interfaceToMethod: a L{dict} mapping credential interfaces to
authentication methods. The server checks to see which of the
cred interfaces have checkers and tells the client that those methods
are valid for authentication.
@type interfaceToMethod: L{dict}
@ivar supportedAuthentications: A list of the supported authentication
methods.
@type supportedAuthentications: L{list} of L{bytes}
@ivar user: the last username the client tried to authenticate with
@type user: L{bytes}
@ivar method: the current authentication method
@type method: L{bytes}
@ivar nextService: the service the user wants started after authentication
has been completed.
@type nextService: L{bytes}
@ivar portal: the L{twisted.cred.portal.Portal} we are using for
authentication
@type portal: L{twisted.cred.portal.Portal}
@ivar clock: an object with a callLater method. Stubbed out for testing.
"""
name = b"ssh-userauth"
loginTimeout = 10 * 60 * 60
# 10 minutes before we disconnect them
attemptsBeforeDisconnect = 20
# 20 login attempts before a disconnect
passwordDelay = 1 # number of seconds to delay on a failed password
clock = reactor
interfaceToMethod = {
credentials.ISSHPrivateKey: b"publickey",
credentials.IUsernamePassword: b"password",
}
_log = Logger()
def serviceStarted(self):
"""
Called when the userauth service is started. Set up instance
variables, check if we should allow password authentication (only
allow if the outgoing connection is encrypted) and set up a login
timeout.
"""
self.authenticatedWith = []
self.loginAttempts = 0
self.user = None
self.nextService = None
self.portal = self.transport.factory.portal
self.supportedAuthentications = []
for i in self.portal.listCredentialsInterfaces():
if i in self.interfaceToMethod:
self.supportedAuthentications.append(self.interfaceToMethod[i])
if not self.transport.isEncrypted("in"):
# don't let us transport password in plaintext
if b"password" in self.supportedAuthentications:
self.supportedAuthentications.remove(b"password")
self._cancelLoginTimeout = self.clock.callLater(
self.loginTimeout, self.timeoutAuthentication
)
def serviceStopped(self):
"""
Called when the userauth service is stopped. Cancel the login timeout
if it's still going.
"""
if self._cancelLoginTimeout:
self._cancelLoginTimeout.cancel()
self._cancelLoginTimeout = None
def timeoutAuthentication(self):
"""
Called when the user has timed out on authentication. Disconnect
with a DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE message.
"""
self._cancelLoginTimeout = None
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, b"you took too long"
)
def tryAuth(self, kind, user, data):
"""
Try to authenticate the user with the given method. Dispatches to a
auth_* method.
@param kind: the authentication method to try.
@type kind: L{bytes}
@param user: the username the client is authenticating with.
@type user: L{bytes}
@param data: authentication specific data sent by the client.
@type data: L{bytes}
@return: A Deferred called back if the method succeeded, or erred back
if it failed.
@rtype: C{defer.Deferred}
"""
self._log.debug("{user!r} trying auth {kind!r}", user=user, kind=kind)
if kind not in self.supportedAuthentications:
return defer.fail(error.ConchError("unsupported authentication, failing"))
kind = nativeString(kind.replace(b"-", b"_"))
f = getattr(self, f"auth_{kind}", None)
if f:
ret = f(data)
if not ret:
return defer.fail(
error.ConchError(f"{kind} return None instead of a Deferred")
)
else:
return ret
return defer.fail(error.ConchError(f"bad auth type: {kind}"))
def ssh_USERAUTH_REQUEST(self, packet):
"""
The client has requested authentication. Payload::
string user
string next service
string method
<authentication specific data>
@type packet: L{bytes}
"""
user, nextService, method, rest = getNS(packet, 3)
if user != self.user or nextService != self.nextService:
self.authenticatedWith = [] # clear auth state
self.user = user
self.nextService = nextService
self.method = method
d = self.tryAuth(method, user, rest)
if not d:
self._ebBadAuth(failure.Failure(error.ConchError("auth returned none")))
return
d.addCallback(self._cbFinishedAuth)
d.addErrback(self._ebMaybeBadAuth)
d.addErrback(self._ebBadAuth)
return d
def _cbFinishedAuth(self, result):
"""
The callback when user has successfully been authenticated. For a
description of the arguments, see L{twisted.cred.portal.Portal.login}.
We start the service requested by the user.
"""
(interface, avatar, logout) = result
self.transport.avatar = avatar
self.transport.logoutFunction = logout
service = self.transport.factory.getService(self.transport, self.nextService)
if not service:
raise error.ConchError(f"could not get next service: {self.nextService}")
self._log.debug(
"{user!r} authenticated with {method!r}", user=self.user, method=self.method
)
self.transport.sendPacket(MSG_USERAUTH_SUCCESS, b"")
self.transport.setService(service())
def _ebMaybeBadAuth(self, reason):
"""
An intermediate errback. If the reason is
error.NotEnoughAuthentication, we send a MSG_USERAUTH_FAILURE, but
with the partial success indicator set.
@type reason: L{twisted.python.failure.Failure}
"""
reason.trap(error.NotEnoughAuthentication)
self.transport.sendPacket(
MSG_USERAUTH_FAILURE, NS(b",".join(self.supportedAuthentications)) + b"\xff"
)
def _ebBadAuth(self, reason):
"""
The final errback in the authentication chain. If the reason is
error.IgnoreAuthentication, we simply return; the authentication
method has sent its own response. Otherwise, send a failure message
and (if the method is not 'none') increment the number of login
attempts.
@type reason: L{twisted.python.failure.Failure}
"""
if reason.check(error.IgnoreAuthentication):
return
if self.method != b"none":
self._log.debug(
"{user!r} failed auth {method!r}", user=self.user, method=self.method
)
if reason.check(UnauthorizedLogin):
self._log.debug(
"unauthorized login: {message}", message=reason.getErrorMessage()
)
elif reason.check(error.ConchError):
self._log.debug("reason: {reason}", reason=reason.getErrorMessage())
else:
self._log.failure(
"Error checking auth for user {user}",
failure=reason,
user=self.user,
)
self.loginAttempts += 1
if self.loginAttempts > self.attemptsBeforeDisconnect:
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
b"too many bad auths",
)
return
self.transport.sendPacket(
MSG_USERAUTH_FAILURE, NS(b",".join(self.supportedAuthentications)) + b"\x00"
)
def auth_publickey(self, packet):
"""
Public key authentication. Payload::
byte has signature
string algorithm name
string key blob
[string signature] (if has signature is True)
Create a SSHPublicKey credential and verify it using our portal.
"""
hasSig = ord(packet[0:1])
algName, blob, rest = getNS(packet[1:], 2)
try:
keys.Key.fromString(blob)
except keys.BadKeyError:
error = "Unsupported key type {} or bad key".format(algName.decode("ascii"))
self._log.error(error)
return defer.fail(UnauthorizedLogin(error))
signature = hasSig and getNS(rest)[0] or None
if hasSig:
b = (
NS(self.transport.sessionID)
+ bytes((MSG_USERAUTH_REQUEST,))
+ NS(self.user)
+ NS(self.nextService)
+ NS(b"publickey")
+ bytes((hasSig,))
+ NS(algName)
+ NS(blob)
)
c = credentials.SSHPrivateKey(self.user, algName, blob, b, signature)
return self.portal.login(c, None, interfaces.IConchUser)
else:
c = credentials.SSHPrivateKey(self.user, algName, blob, None, None)
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
self._ebCheckKey, packet[1:]
)
def _ebCheckKey(self, reason, packet):
"""
Called back if the user did not sent a signature. If reason is
error.ValidPublicKey then this key is valid for the user to
authenticate with. Send MSG_USERAUTH_PK_OK.
"""
reason.trap(error.ValidPublicKey)
# if we make it here, it means that the publickey is valid
self.transport.sendPacket(MSG_USERAUTH_PK_OK, packet)
return failure.Failure(error.IgnoreAuthentication())
def auth_password(self, packet):
"""
Password authentication. Payload::
string password
Make a UsernamePassword credential and verify it with our portal.
"""
password = getNS(packet[1:])[0]
c = credentials.UsernamePassword(self.user, password)
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
self._ebPassword
)
def _ebPassword(self, f):
"""
If the password is invalid, wait before sending the failure in order
to delay brute-force password guessing.
"""
d = defer.Deferred()
self.clock.callLater(self.passwordDelay, d.callback, f)
return d
class SSHUserAuthClient(service.SSHService):
"""
A service implementing the client side of 'ssh-userauth'.
This service will try all authentication methods provided by the server,
making callbacks for more information when necessary.
@ivar name: the name of this service: 'ssh-userauth'
@type name: L{str}
@ivar preferredOrder: a list of authentication methods that should be used
first, in order of preference, if supported by the server
@type preferredOrder: L{list}
@ivar user: the name of the user to authenticate as
@type user: L{bytes}
@ivar instance: the service to start after authentication has finished
@type instance: L{service.SSHService}
@ivar authenticatedWith: a list of strings of authentication methods we've tried
@type authenticatedWith: L{list} of L{bytes}
@ivar triedPublicKeys: a list of public key objects that we've tried to
authenticate with
@type triedPublicKeys: L{list} of L{Key}
@ivar lastPublicKey: the last public key object we've tried to authenticate
with
@type lastPublicKey: L{Key}
"""
name = b"ssh-userauth"
preferredOrder = [b"publickey", b"password", b"keyboard-interactive"]
def __init__(self, user, instance):
self.user = user
self.instance = instance
def serviceStarted(self):
self.authenticatedWith = []
self.triedPublicKeys = []
self.lastPublicKey = None
self.askForAuth(b"none", b"")
def askForAuth(self, kind, extraData):
"""
Send a MSG_USERAUTH_REQUEST.
@param kind: the authentication method to try.
@type kind: L{bytes}
@param extraData: method-specific data to go in the packet
@type extraData: L{bytes}
"""
self.lastAuth = kind
self.transport.sendPacket(
MSG_USERAUTH_REQUEST,
NS(self.user) + NS(self.instance.name) + NS(kind) + extraData,
)
def tryAuth(self, kind):
"""
Dispatch to an authentication method.
@param kind: the authentication method
@type kind: L{bytes}
"""
kind = nativeString(kind.replace(b"-", b"_"))
self._log.debug("trying to auth with {kind}", kind=kind)
f = getattr(self, "auth_" + kind, None)
if f:
return f()
def _ebAuth(self, ignored, *args):
"""
Generic callback for a failed authentication attempt. Respond by
asking for the list of accepted methods (the 'none' method)
"""
self.askForAuth(b"none", b"")
def ssh_USERAUTH_SUCCESS(self, packet):
"""
We received a MSG_USERAUTH_SUCCESS. The server has accepted our
authentication, so start the next service.
"""
self.transport.setService(self.instance)
def ssh_USERAUTH_FAILURE(self, packet):
"""
We received a MSG_USERAUTH_FAILURE. Payload::
string methods
byte partial success
If partial success is C{True}, then the previous method succeeded but is
not sufficient for authentication. C{methods} is a comma-separated list
of accepted authentication methods.
We sort the list of methods by their position in C{self.preferredOrder},
removing methods that have already succeeded. We then call
C{self.tryAuth} with the most preferred method.
@param packet: the C{MSG_USERAUTH_FAILURE} payload.
@type packet: L{bytes}
@return: a L{defer.Deferred} that will be callbacked with L{None} as
soon as all authentication methods have been tried, or L{None} if no
more authentication methods are available.
@rtype: C{defer.Deferred} or L{None}
"""
canContinue, partial = getNS(packet)
partial = ord(partial)
if partial:
self.authenticatedWith.append(self.lastAuth)
def orderByPreference(meth):
"""
Invoked once per authentication method in order to extract a
comparison key which is then used for sorting.
@param meth: the authentication method.
@type meth: L{bytes}
@return: the comparison key for C{meth}.
@rtype: L{int}
"""
if meth in self.preferredOrder:
return self.preferredOrder.index(meth)
else:
# put the element at the end of the list.
return len(self.preferredOrder)
canContinue = sorted(
(
meth
for meth in canContinue.split(b",")
if meth not in self.authenticatedWith
),
key=orderByPreference,
)
self._log.debug("can continue with: {methods}", methods=canContinue)
return self._cbUserauthFailure(None, iter(canContinue))
def _cbUserauthFailure(self, result, iterator):
if result:
return
try:
method = next(iterator)
except StopIteration:
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
b"no more authentication methods available",
)
else:
d = defer.maybeDeferred(self.tryAuth, method)
d.addCallback(self._cbUserauthFailure, iterator)
return d
def ssh_USERAUTH_PK_OK(self, packet):
"""
This message (number 60) can mean several different messages depending
on the current authentication type. We dispatch to individual methods
in order to handle this request.
"""
func = getattr(
self,
"ssh_USERAUTH_PK_OK_%s" % nativeString(self.lastAuth.replace(b"-", b"_")),
None,
)
if func is not None:
return func(packet)
else:
self.askForAuth(b"none", b"")
def ssh_USERAUTH_PK_OK_publickey(self, packet):
"""
This is MSG_USERAUTH_PK. Our public key is valid, so we create a
signature and try to authenticate with it.
"""
publicKey = self.lastPublicKey
b = (
NS(self.transport.sessionID)
+ bytes((MSG_USERAUTH_REQUEST,))
+ NS(self.user)
+ NS(self.instance.name)
+ NS(b"publickey")
+ b"\x01"
+ NS(publicKey.sshType())
+ NS(publicKey.blob())
)
d = self.signData(publicKey, b)
if not d:
self.askForAuth(b"none", b"")
# this will fail, we'll move on
return
d.addCallback(self._cbSignedData)
d.addErrback(self._ebAuth)
def ssh_USERAUTH_PK_OK_password(self, packet):
"""
This is MSG_USERAUTH_PASSWD_CHANGEREQ. The password given has expired.
We ask for an old password and a new password, then send both back to
the server.
"""
prompt, language, rest = getNS(packet, 2)
self._oldPass = self._newPass = None
d = self.getPassword(b"Old Password: ")
d = d.addCallbacks(self._setOldPass, self._ebAuth)
d.addCallback(lambda ignored: self.getPassword(prompt))
d.addCallbacks(self._setNewPass, self._ebAuth)
def ssh_USERAUTH_PK_OK_keyboard_interactive(self, packet):
"""
This is MSG_USERAUTH_INFO_RESPONSE. The server has sent us the
questions it wants us to answer, so we ask the user and sent the
responses.
"""
name, instruction, lang, data = getNS(packet, 3)
numPrompts = struct.unpack("!L", data[:4])[0]
data = data[4:]
prompts = []
for i in range(numPrompts):
prompt, data = getNS(data)
echo = bool(ord(data[0:1]))
data = data[1:]
prompts.append((prompt, echo))
d = self.getGenericAnswers(name, instruction, prompts)
d.addCallback(self._cbGenericAnswers)
d.addErrback(self._ebAuth)
def _cbSignedData(self, signedData):
"""
Called back out of self.signData with the signed data. Send the
authentication request with the signature.
@param signedData: the data signed by the user's private key.
@type signedData: L{bytes}
"""
publicKey = self.lastPublicKey
self.askForAuth(
b"publickey",
b"\x01" + NS(publicKey.sshType()) + NS(publicKey.blob()) + NS(signedData),
)
def _setOldPass(self, op):
"""
Called back when we are choosing a new password. Simply store the old
password for now.
@param op: the old password as entered by the user
@type op: L{bytes}
"""
self._oldPass = op
def _setNewPass(self, np):
"""
Called back when we are choosing a new password. Get the old password
and send the authentication message with both.
@param np: the new password as entered by the user
@type np: L{bytes}
"""
op = self._oldPass
self._oldPass = None
self.askForAuth(b"password", b"\xff" + NS(op) + NS(np))
def _cbGenericAnswers(self, responses):
"""
Called back when we are finished answering keyboard-interactive
questions. Send the info back to the server in a
MSG_USERAUTH_INFO_RESPONSE.
@param responses: a list of L{bytes} responses
@type responses: L{list}
"""
data = struct.pack("!L", len(responses))
for r in responses:
data += NS(r.encode("UTF8"))
self.transport.sendPacket(MSG_USERAUTH_INFO_RESPONSE, data)
def auth_publickey(self):
"""
Try to authenticate with a public key. Ask the user for a public key;
if the user has one, send the request to the server and return True.
Otherwise, return False.
@rtype: L{bool}
"""
d = defer.maybeDeferred(self.getPublicKey)
d.addBoth(self._cbGetPublicKey)
return d
def _cbGetPublicKey(self, publicKey):
if not isinstance(publicKey, keys.Key): # failure or None
publicKey = None
if publicKey is not None:
self.lastPublicKey = publicKey
self.triedPublicKeys.append(publicKey)
self._log.debug("using key of type {keyType}", keyType=publicKey.type())
self.askForAuth(
b"publickey", b"\x00" + NS(publicKey.sshType()) + NS(publicKey.blob())
)
return True
else:
return False
def auth_password(self):
"""
Try to authenticate with a password. Ask the user for a password.
If the user will return a password, return True. Otherwise, return
False.
@rtype: L{bool}
"""
d = self.getPassword()
if d:
d.addCallbacks(self._cbPassword, self._ebAuth)
return True
else: # returned None, don't do password auth
return False
def auth_keyboard_interactive(self):
"""
Try to authenticate with keyboard-interactive authentication. Send
the request to the server and return True.
@rtype: L{bool}
"""
self._log.debug("authing with keyboard-interactive")
self.askForAuth(b"keyboard-interactive", NS(b"") + NS(b""))
return True
def _cbPassword(self, password):
"""
Called back when the user gives a password. Send the request to the
server.
@param password: the password the user entered
@type password: L{bytes}
"""
self.askForAuth(b"password", b"\x00" + NS(password))
def signData(self, publicKey, signData):
"""
Sign the given data with the given public key.
By default, this will call getPrivateKey to get the private key,
then sign the data using Key.sign().
This method is factored out so that it can be overridden to use
alternate methods, such as a key agent.
@param publicKey: The public key object returned from L{getPublicKey}
@type publicKey: L{keys.Key}
@param signData: the data to be signed by the private key.
@type signData: L{bytes}
@return: a Deferred that's called back with the signature
@rtype: L{defer.Deferred}
"""
key = self.getPrivateKey()
if not key:
return
return key.addCallback(self._cbSignData, signData)
def _cbSignData(self, privateKey, signData):
"""
Called back when the private key is returned. Sign the data and
return the signature.
@param privateKey: the private key object
@type privateKey: L{keys.Key}
@param signData: the data to be signed by the private key.
@type signData: L{bytes}
@return: the signature
@rtype: L{bytes}
"""
return privateKey.sign(signData)
def getPublicKey(self):
"""
Return a public key for the user. If no more public keys are
available, return L{None}.
This implementation always returns L{None}. Override it in a
subclass to actually find and return a public key object.
@rtype: L{Key} or L{None}
"""
return None
def getPrivateKey(self):
"""
Return a L{Deferred} that will be called back with the private key
object corresponding to the last public key from getPublicKey().
If the private key is not available, errback on the Deferred.
@rtype: L{Deferred} called back with L{Key}
"""
return defer.fail(NotImplementedError())
def getPassword(self, prompt=None):
"""
Return a L{Deferred} that will be called back with a password.
prompt is a string to display for the password, or None for a generic
'user@hostname's password: '.
@type prompt: L{bytes}/L{None}
@rtype: L{defer.Deferred}
"""
return defer.fail(NotImplementedError())
def getGenericAnswers(self, name, instruction, prompts):
"""
Returns a L{Deferred} with the responses to the promopts.
@param name: The name of the authentication currently in progress.
@param instruction: Describes what the authentication wants.
@param prompts: A list of (prompt, echo) pairs, where prompt is a
string to display and echo is a boolean indicating whether the
user's response should be echoed as they type it.
"""
return defer.fail(NotImplementedError())
MSG_USERAUTH_REQUEST = 50
MSG_USERAUTH_FAILURE = 51
MSG_USERAUTH_SUCCESS = 52
MSG_USERAUTH_BANNER = 53
MSG_USERAUTH_INFO_RESPONSE = 61
MSG_USERAUTH_PK_OK = 60
messages = {}
for k, v in list(locals().items()):
if k[:4] == "MSG_":
messages[v] = k
SSHUserAuthServer.protocolMessages = messages
SSHUserAuthClient.protocolMessages = messages
del messages
del v
# Doubles, not included in the protocols' mappings
MSG_USERAUTH_PASSWD_CHANGEREQ = 60
MSG_USERAUTH_INFO_REQUEST = 60

View File

@@ -0,0 +1,114 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Asynchronous local terminal input handling
@author: Jp Calderone
"""
import os
import sys
import termios
import tty
from twisted.conch.insults.insults import ServerProtocol
from twisted.conch.manhole import ColoredManhole
from twisted.internet import defer, protocol, reactor, stdio
from twisted.python import failure, log, reflect
class UnexpectedOutputError(Exception):
pass
class TerminalProcessProtocol(protocol.ProcessProtocol):
def __init__(self, proto):
self.proto = proto
self.onConnection = defer.Deferred()
def connectionMade(self):
self.proto.makeConnection(self)
self.onConnection.callback(None)
self.onConnection = None
def write(self, data):
"""
Write to the terminal.
@param data: Data to write.
@type data: L{bytes}
"""
self.transport.write(data)
def outReceived(self, data):
"""
Receive data from the terminal.
@param data: Data received.
@type data: L{bytes}
"""
self.proto.dataReceived(data)
def errReceived(self, data):
"""
Report an error.
@param data: Data to include in L{Failure}.
@type data: L{bytes}
"""
self.transport.loseConnection()
if self.proto is not None:
self.proto.connectionLost(failure.Failure(UnexpectedOutputError(data)))
self.proto = None
def childConnectionLost(self, childFD):
if self.proto is not None:
self.proto.childConnectionLost(childFD)
def processEnded(self, reason):
if self.proto is not None:
self.proto.connectionLost(reason)
self.proto = None
class ConsoleManhole(ColoredManhole):
"""
A manhole protocol specifically for use with L{stdio.StandardIO}.
"""
def connectionLost(self, reason):
"""
When the connection is lost, there is nothing more to do. Stop the
reactor so that the process can exit.
"""
reactor.stop()
def runWithProtocol(klass):
fd = sys.__stdin__.fileno()
oldSettings = termios.tcgetattr(fd)
tty.setraw(fd)
try:
stdio.StandardIO(ServerProtocol(klass))
reactor.run()
finally:
termios.tcsetattr(fd, termios.TCSANOW, oldSettings)
os.write(fd, b"\r\x1bc\r")
def main(argv=None):
log.startLogging(open("child.log", "w"))
if argv is None:
argv = sys.argv[1:]
if argv:
klass = reflect.namedClass(argv[0])
else:
klass = ConsoleManhole
runWithProtocol(klass)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,91 @@
# -*- test-case-name: twisted.conch.test.test_tap -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Support module for making SSH servers with twistd.
"""
from twisted.application import strports
from twisted.conch import checkers as conch_checkers, unix
from twisted.conch.openssh_compat import factory
from twisted.cred import portal, strcred
from twisted.python import usage
class Options(usage.Options, strcred.AuthOptionMixin):
synopsis = "[-i <interface>] [-p <port>] [-d <dir>] "
longdesc = (
"Makes a Conch SSH server. If no authentication methods are "
"specified, the default authentication methods are UNIX passwords "
"and SSH public keys. If --auth options are "
"passed, only the measures specified will be used."
)
optParameters = [
["interface", "i", "", "local interface to which we listen"],
["port", "p", "tcp:22", "Port on which to listen"],
["data", "d", "/etc", "directory to look for host keys in"],
[
"moduli",
"",
None,
"directory to look for moduli in " "(if different from --data)",
],
]
compData = usage.Completions(
optActions={
"data": usage.CompleteDirs(descr="data directory"),
"moduli": usage.CompleteDirs(descr="moduli directory"),
"interface": usage.CompleteNetInterfaces(),
}
)
def __init__(self, *a, **kw):
usage.Options.__init__(self, *a, **kw)
# Call the default addCheckers (for backwards compatibility) that will
# be used if no --auth option is provided - note that conch's
# UNIXPasswordDatabase is used, instead of twisted.plugins.cred_unix's
# checker
super().addChecker(conch_checkers.UNIXPasswordDatabase())
super().addChecker(
conch_checkers.SSHPublicKeyChecker(conch_checkers.UNIXAuthorizedKeysFiles())
)
self._usingDefaultAuth = True
def addChecker(self, checker):
"""
Add the checker specified. If any checkers are added, the default
checkers are automatically cleared and the only checkers will be the
specified one(s).
"""
if self._usingDefaultAuth:
self["credCheckers"] = []
self["credInterfaces"] = {}
self._usingDefaultAuth = False
super().addChecker(checker)
def makeService(config):
"""
Construct a service for operating a SSH server.
@param config: An L{Options} instance specifying server options, including
where server keys are stored and what authentication methods to use.
@return: A L{twisted.application.service.IService} provider which contains
the requested SSH server.
"""
t = factory.OpenSSHFactory()
r = unix.UnixSSHRealm()
t.portal = portal.Portal(r, config.get("credCheckers", []))
t.dataRoot = config["data"]
t.moduliRoot = config["moduli"] or config["data"]
port = config["port"]
if config["interface"]:
# Add warning here
port += ":interface=" + config["interface"]
return strports.service(port, t)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
"conch tests"

View File

@@ -0,0 +1,671 @@
# -*- test-case-name: twisted.conch.test.test_keys -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# pylint: disable=I0011,C0103,W9401,W9402
"""
Data used by test_keys as well as others.
"""
from base64 import decodebytes
RSAData = {
"n": int(
"269413617238113438198661010376758399219880277968382122687862697"
"296942471209955603071120391975773283844560230371884389952067978"
"789684135947515341209478065209455427327369102356204259106807047"
"964139525310539133073743116175821417513079706301100600025815509"
"786721808719302671068052414466483676821987505720384645561708425"
"794379383191274856941628512616355437197560712892001107828247792"
"561858327085521991407807015047750218508971611590850575870321007"
"991909043252470730134547038841839367764074379439843108550888709"
"430958143271417044750314742880542002948053835745429446485015316"
"60749404403945254975473896534482849256068133525751"
),
"e": 65537,
"d": int(
"420335724286999695680502438485489819800002417295071059780489811"
"840828351636754206234982682752076205397047218449504537476523960"
"987613148307573487322720481066677105211155388802079519869249746"
"774085882219244493290663802569201213676433159425782937159766786"
"329742053214957933941260042101377175565683849732354700525628975"
"239000548651346620826136200952740446562751690924335365940810658"
"931238410612521441739702170503547025018016868116037053013935451"
"477930426013703886193016416453215950072147440344656137718959053"
"897268663969428680144841987624962928576808352739627262941675617"
"7724661940425316604626522633351193810751757014073"
),
"p": int(
"152689878451107675391723141129365667732639179427453246378763774"
"448531436802867910180261906924087589684175595016060014593521649"
"964959248408388984465569934780790357826811592229318702991401054"
"226302790395714901636384511513449977061729214247279176398290513"
"085108930550446985490864812445551198848562639933888780317"
),
"q": int(
"176444974592327996338888725079951900172097062203378367409936859"
"072670162290963119826394224277287608693818012745872307600855894"
"647300295516866118620024751601329775653542084052616260193174546"
"400544176890518564317596334518015173606460860373958663673307503"
"231977779632583864454001476729233959405710696795574874403"
),
"u": int(
"936018002388095842969518498561007090965136403384715613439364803"
"229386793506402222847415019772053080458257034241832795210460612"
"924445085372678524176842007912276654532773301546269997020970818"
"155956828553418266110329867222673040098885651348225673298948529"
"93885224775891490070400861134282266967852120152546563278"
),
}
DSAData = {
"g": int(
"10253261326864117157640690761723586967382334319435778695"
"29171533815411392477819921538350732400350395446211982054"
"96512489289702949127531056893725702005035043292195216541"
"11525058911428414042792836395195432445511200566318251789"
"10575695836669396181746841141924498545494149998282951407"
"18645344764026044855941864175"
),
"p": int(
"10292031726231756443208850082191198787792966516790381991"
"77502076899763751166291092085666022362525614129374702633"
"26262930887668422949051881895212412718444016917144560705"
"45675251775747156453237145919794089496168502517202869160"
"78674893099371444940800865897607102159386345313384716752"
"18590012064772045092956919481"
),
"q": 1393384845225358996250882900535419012502712821577,
"x": 1220877188542930584999385210465204342686893855021,
"y": int(
"14604423062661947579790240720337570315008549983452208015"
"39426429789435409684914513123700756086453120500041882809"
"10283610277194188071619191739512379408443695946763554493"
"86398594314468629823767964702559709430618263927529765769"
"10270265745700231533660131769648708944711006508965764877"
"684264272082256183140297951"
),
}
ECDatanistp256 = {
"x": int(
"762825130203920963171185031449647317742997734817505505433829043"
"45687059013883"
),
"y": int(
"815431978646028526322656647694416475343443758943143196810611371"
"59310646683104"
),
"privateValue": int(
"3463874347721034170096400845565569825355565567882605"
"9678074967909361042656500"
),
"curve": b"ecdsa-sha2-nistp256",
}
SKECDatanistp256 = {
"x": int(
"239399367768747020111880335553299826848360860410053166887934464"
"83115637049597"
),
"y": int(
"114119006635761413192818806701564910719235784173643448381780025"
"223832906554748"
),
"curve": b"sk-ecdsa-sha2-nistp256@openssh.com",
}
ECDatanistp384 = {
"privateValue": int(
"280814107134858470598753916394807521398239633534281"
"633982576099083357871098966021020900021966162732114"
"95718603965098"
),
"x": int(
"10036914308591746758780165503819213553101287571902957054148542"
"504671046744460374996612408381962208627004841444205030"
),
"y": int(
"17337335659928075994560513699823544906448896792102247714689323"
"575406618073069185107088229463828921069465902299522926"
),
"curve": b"ecdsa-sha2-nistp384",
}
ECDatanistp521 = {
"x": int(
"12944742826257420846659527752683763193401384271391513286022917"
"29910013082920512632908350502247952686156279140016049549948975"
"670668730618745449113644014505462"
),
"y": int(
"10784108810271976186737587749436295782985563640368689081052886"
"16296815984553198866894145509329328086635278430266482551941240"
"591605833440825557820439734509311"
),
"privateValue": int(
"662751235215460886290293902658128847495347691199214"
"706697089140769672273950767961331442265530524063943"
"548846724348048614239791498442599782310681891569896"
"0565"
),
"curve": b"ecdsa-sha2-nistp521",
}
Ed25519Data = {
"a": (
b"\xf1\x16\xd1\x15J\x1e\x15\x0e\x19^\x19F\xb5\xf2D\r\xb2R\xa0\xae*k"
b"#\x13sE\xfd@\xd9W{\x8b"
),
"k": (
b"7/%\xda\x8d\xd4\xa8\x9ax|a\xf0\x98\x01\xc6\xf4^mg\x05i17Li\r\x05U"
b"\xbb\xc9DX"
),
}
SKEd25519Data = {
"a": (
b"\x08}'U\xd2i\x04\x11\xea\x01~+\x165iRM\xdd\xe6R\x7f\xd3\xaf\\\xa8p"
b"\xa0LL\xe5\x8a\xa0"
),
"k": (
b"7/%\xda\x8d\xd4\xa8\x9ax|a\xf0\x98\x01\xc6\xf4^mg\x05i17Li\r\x05U"
b"\xbb\xc9DX"
),
}
privateECDSA_openssh521 = b"""-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIAjn0lSVF6QweS4bjOGP9RHwqxUiTastSE0MVuLtFvkxygZqQ712oZ
ewMvqKkxthMQgxzSpGtRBcmkL7RqZ94+18qgBwYFK4EEACOhgYkDgYYABAFpX/6B
mxxglwD+VpEvw0hcyxVzLxNnMGzxZGF7xmNj8nlF7M+TQctdlR2Xv/J+AgIeVGmB
j2p84bkV9jBzrUNJEACsJjttZw8NbUrhxjkLT/3rMNtuwjE4vLja0P7DMTE0EV8X
f09ETdku/z/1tOSSrSvRwmUcM9nQUJtHHAZlr5Q0fw==
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh521_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAArAAAABNlY2RzYS
1zaGEyLW5pc3RwNTIxAAAACG5pc3RwNTIxAAAAhQQBaV/+gZscYJcA/laRL8NIXMsVcy8T
ZzBs8WRhe8ZjY/J5RezPk0HLXZUdl7/yfgICHlRpgY9qfOG5FfYwc61DSRAArCY7bWcPDW
1K4cY5C0/96zDbbsIxOLy42tD+wzExNBFfF39PRE3ZLv8/9bTkkq0r0cJlHDPZ0FCbRxwG
Za+UNH8AAAEAeRISlnkSEpYAAAATZWNkc2Etc2hhMi1uaXN0cDUyMQAAAAhuaXN0cDUyMQ
AAAIUEAWlf/oGbHGCXAP5WkS/DSFzLFXMvE2cwbPFkYXvGY2PyeUXsz5NBy12VHZe/8n4C
Ah5UaYGPanzhuRX2MHOtQ0kQAKwmO21nDw1tSuHGOQtP/esw227CMTi8uNrQ/sMxMTQRXx
d/T0RN2S7/P/W05JKtK9HCZRwz2dBQm0ccBmWvlDR/AAAAQgCOfSVJUXpDB5LhuM4Y/1Ef
CrFSJNqy1ITQxW4u0W+THKBmpDvXahl7Ay+oqTG2ExCDHNKka1EFyaQvtGpn3j7XygAAAA
ABAg==
-----END OPENSSH PRIVATE KEY-----
"""
publicECDSA_openssh521 = (
b"ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACF"
b"BAFpX/6BmxxglwD+VpEvw0hcyxVzLxNnMGzxZGF7xmNj8nlF7M+TQctdlR2Xv/J+AgIeVGmB"
b"j2p84bkV9jBzrUNJEACsJjttZw8NbUrhxjkLT/3rMNtuwjE4vLja0P7DMTE0EV8Xf09ETdku"
b"/z/1tOSSrSvRwmUcM9nQUJtHHAZlr5Q0fw== comment"
)
privateECDSA_openssh384 = b"""-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAtAi7I8j73WCX20qUM5hhHwHuFzYWYYILs2Sh8UZ+awNkARZ/Fu2LU
LLl5RtOQpbWgBwYFK4EEACKhZANiAATU17sA9P5FRwSknKcFsjjsk0+E3CeXPYX0
Tk/M0HK3PpWQWgrO8JdRHP9eFE9O/23P8BumwFt7F/AvPlCzVd35VfraFT0o4cCW
G0RqpQ+np31aKmeJshkcYALEchnU+tQ=
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh384_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS
1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQTU17sA9P5FRwSknKcFsjjsk0+E3CeX
PYX0Tk/M0HK3PpWQWgrO8JdRHP9eFE9O/23P8BumwFt7F/AvPlCzVd35VfraFT0o4cCWG0
RqpQ+np31aKmeJshkcYALEchnU+tQAAADIiktpWIpLaVgAAAATZWNkc2Etc2hhMi1uaXN0
cDM4NAAAAAhuaXN0cDM4NAAAAGEE1Ne7APT+RUcEpJynBbI47JNPhNwnlz2F9E5PzNBytz
6VkFoKzvCXURz/XhRPTv9tz/AbpsBbexfwLz5Qs1Xd+VX62hU9KOHAlhtEaqUPp6d9Wipn
ibIZHGACxHIZ1PrUAAAAMC0CLsjyPvdYJfbSpQzmGEfAe4XNhZhgguzZKHxRn5rA2QBFn8
W7YtQsuXlG05CltQAAAAA=
-----END OPENSSH PRIVATE KEY-----
"""
publicECDSA_openssh384 = (
b"ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABh"
b"BNTXuwD0/kVHBKScpwWyOOyTT4TcJ5c9hfROT8zQcrc+lZBaCs7wl1Ec/14UT07/bc/wG6bA"
b"W3sX8C8+ULNV3flV+toVPSjhwJYbRGqlD6enfVoqZ4myGRxgAsRyGdT61A== comment"
)
publicECDSA_openssh = (
b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABB"
b"BKimX1DZ7+Qj0SpfePMbo1pb6yGkAb5l7duC1l855yD7tEfQfqk7bc7v46We1hLMyz6ObUBY"
b"gkN/34n42F4vpeA= comment"
)
publicSKECDSA_openssh = (
b"sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3"
b"BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBDTthidmBSzlQiO8aZPfLmUDOS2TSRevW8IrHPK"
b"IhYj9/E0RnTyvPIB1eWQx4rQl5iO1mihuBz+u4LkjwVEU3XwAAAAUc3NoOmVjZHNhLWZpZG8y"
b"LXRlc3Q= comment"
)
publicSKEd25519_openssh = (
b"sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIA"
b"h9J1XSaQQR6gF+KxY1aVJN3eZSf9OvXKhwoExM5YqgAAAABHNzaDo= comment"
)
publicSKECDSA_cert_openssh = (
b"sk-ecdsa-sha2-nistp256-cert-v01@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3"
b"BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBDTthidmBSzlQiO8aZPfLmUDOS2TSRevW8IrHPK"
b"IhYj9/E0RnTyvPIB1eWQx4rQl5iO1mihuBz+u4LkjwVEU3XwAAAAUc3NoOmVjZHNhLWZpZG8y"
b"LXRlc3Q= comment"
)
publicSKEd25519_cert_openssh = (
b"sk-ssh-ed25519-cert-v01@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIA"
b"h9J1XSaQQR6gF+KxY1aVJN3eZSf9OvXKhwoExM5YqgAAAABHNzaDo= comment"
)
privateECDSA_openssh = b"""-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEyU1YOT2JxxofwbJXIjGftdNcJK55aQdNrhIt2xYQz0oAoGCCqGSM49
AwEHoUQDQgAEqKZfUNnv5CPRKl948xujWlvrIaQBvmXt24LWXznnIPu0R9B+qTtt
zu/jpZ7WEszLPo5tQFiCQ3/fifjYXi+l4A==
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS
1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQSopl9Q2e/kI9EqX3jzG6NaW+shpAG+
Ze3bgtZfOecg+7RH0H6pO23O7+OlntYSzMs+jm1AWIJDf9+J+NheL6XgAAAAmCKU4hcilO
IXAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBKimX1DZ7+Qj0Spf
ePMbo1pb6yGkAb5l7duC1l855yD7tEfQfqk7bc7v46We1hLMyz6ObUBYgkN/34n42F4vpe
AAAAAgTJTVg5PYnHGh/BslciMZ+101wkrnlpB02uEi3bFhDPQAAAAA
-----END OPENSSH PRIVATE KEY-----
"""
publicEd25519_openssh = (
b"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPEW0RVKHhUOGV4ZRrXyRA2yUqCuKmsjE3NF"
b"/UDZV3uL comment"
)
# OpenSSH has only ever supported the "new" (v1) format for Ed25519.
privateEd25519_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
QyNTUxOQAAACDxFtEVSh4VDhleGUa18kQNslKgriprIxNzRf1A2Vd7iwAAAJA61eMLOtXj
CwAAAAtzc2gtZWQyNTUxOQAAACDxFtEVSh4VDhleGUa18kQNslKgriprIxNzRf1A2Vd7iw
AAAEA3LyXajdSomnh8YfCYAcb0Xm1nBWkxN0xpDQVVu8lEWPEW0RVKHhUOGV4ZRrXyRA2y
UqCuKmsjE3NF/UDZV3uLAAAAB2NvbW1lbnQBAgMEBQY=
-----END OPENSSH PRIVATE KEY-----"""
publicRSA_openssh = (
b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDVaqx4I9bWG+wloVDEd2NQhEUBVUIUKirg"
b"0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHjVyqgYwBGTJAkMgUyP"
b"95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auIJNm/9NNN9b0b/h9qp"
b"KSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY6RKXCpCnd1bqcPUWz"
b"xiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvbw7KW4CC1ffdOgTtDc1"
b"foNfICZgptyti8ZseZj3 comment"
)
privateRSA_openssh = b"""-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEA1WqseCPW1hvsJaFQxHdjUIRFAVVCFCoq4NBg7tTpo61K+jkG
XoRVdV8ANr9vqio/gyY3wWkuW/3w89J91pjNOkB41cqoGMARkyQJDIFMj/ec7RMW
aqQE6Ul3w+RVZLN5aJ4sCOus6AQtIXcFp47vUzANpeW7PWriCTZv/TTTfW9G/4fa
qSknqv+t9YXmPhq4eh1KserAWvcw3x/CpOTvP5FJlkDVGXctN8Ne7J2mOkSlwqQp
3dW6nD1Fs8YsGGTVuj3fq3/NQqyn8RgLoFgVYgukKm5Dw+QEnzWjR45G7TOlZb28
OyluAgtX33ToE7Q3NX6DXyAmYKbcrYvGbHmY9wIDAQABAoIBACFMCGaiKNW0+44P
chuFCQC58k438BxXS+NRf54jp+Q6mFUb6ot6mB682Lqx+YkSGGCs6MwLTglaQGq6
L5n4syRghLnOaZWa+eL8H1FNJxXbKyet77RprL59EOuGR3BztACHlRU7N/nnFOeA
u2geG+bdu3NjuWfmsid/z88wm8KY/dkYNi82LvE9gXqf4QMtR9s0UWI53U/prKiL
2dbzhMQXuXGdBghCeE27xSr0w1jNVSvtvjNfBOp75gQkY/It1z0bbNWcY0MvkoiN
Pm7aGDfYDyVniR25RjReyc7Ei+2SWjMHD9+GCPmS6dvrOAg2yc3NCgFIWzk+esrG
gKnc1DkCgYEA2XAG2OK81HiRUJTUwRuJOGxGZFpRoJoHPUiPA1HMaxKOfRqxZedx
dTngMgV1jRhMr5OxSbFmX3hietEMyuZNQ7Oc9Gt95gyY3M8hYo7VLhLeBK7XJG6D
MaIVokQ9IqliJiK5su1UCp0Ig6cHDf8ZGI7Yqx3aSJwxaBGhZm3j2B0CgYEA+0QX
i6Q2vh43Haf2YWwExKrdeD4HjB4zAq4DFIeDeuWefQhnqPKqvxJwz3Kpp8cLHYjV
IP2cY8pHMFVOi8TP9H8WpJISdKEJwsRunIwz76Xl9+ArrU9cEaoahDdb/Xrqw818
sMjkH1Rjtcev3/QJp/zHJfxc6ZHXksWYHlbTsSMCgYBRr+mSn5QLSoRlPpSzO5IQ
tXS4jMnvyQ4BMvovaBKhAyauz1FoFEwmmyikAjMIX+GncJgBNHleUo7Ezza8H0tV
rOvBU4TH4WGoStSi/0ANgB8SqVDAKhh1lAwGmxZQqEvsQc177/dLyXUCaMSYuIaI
GFpD5wIzlyJkk4MMRSp87QKBgGlmN8ZA3SHFBPOwuD5HlHx2/C3rPzk8lcNDAVHE
Qpfz6Bakxu7s1EkQUDgE7jvN19DMzDJpkAegG1qf/jHNHjp+cR4ZlBpOTwzfX1LV
0Rdu7NectlWd244hX7wkiLb8r6vw76QssNyfhrADEriL4t0PwO4jPUpQ/i+4KUZY
v7YnAoGAZhb5IDTQVCW8YTGsgvvvnDUefkpVAmiVDQqTvh6/4UD6kKdUcDHpePzg
Zrcid5rr3dXSMEbK4tdeQZvPtUg1Uaol3N7bNClIIdvWdPx+5S9T95wJcLnkoHam
rXp0IjScTxfLP+Cq5V6lJ94/pX8Ppoj1FdZfNxeS4NYFSRA7kvY=
-----END RSA PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateRSA_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAQEA1WqseCPW1hvsJaFQxHdjUIRFAVVCFCoq4NBg7tTpo61K+jkGXoRV
dV8ANr9vqio/gyY3wWkuW/3w89J91pjNOkB41cqoGMARkyQJDIFMj/ec7RMWaqQE6Ul3w+
RVZLN5aJ4sCOus6AQtIXcFp47vUzANpeW7PWriCTZv/TTTfW9G/4faqSknqv+t9YXmPhq4
eh1KserAWvcw3x/CpOTvP5FJlkDVGXctN8Ne7J2mOkSlwqQp3dW6nD1Fs8YsGGTVuj3fq3
/NQqyn8RgLoFgVYgukKm5Dw+QEnzWjR45G7TOlZb28OyluAgtX33ToE7Q3NX6DXyAmYKbc
rYvGbHmY9wAAA7gXkBoMF5AaDAAAAAdzc2gtcnNhAAABAQDVaqx4I9bWG+wloVDEd2NQhE
UBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHjVyqgY
wBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auIJNm
/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY6
RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvb
w7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3AAAAAwEAAQAAAQAhTAhmoijVtPuOD3Ib
hQkAufJON/AcV0vjUX+eI6fkOphVG+qLepgevNi6sfmJEhhgrOjMC04JWkBqui+Z+LMkYI
S5zmmVmvni/B9RTScV2ysnre+0aay+fRDrhkdwc7QAh5UVOzf55xTngLtoHhvm3btzY7ln
5rInf8/PMJvCmP3ZGDYvNi7xPYF6n+EDLUfbNFFiOd1P6ayoi9nW84TEF7lxnQYIQnhNu8
Uq9MNYzVUr7b4zXwTqe+YEJGPyLdc9G2zVnGNDL5KIjT5u2hg32A8lZ4kduUY0XsnOxIvt
klozBw/fhgj5kunb6zgINsnNzQoBSFs5PnrKxoCp3NQ5AAAAgQCFSxt6mxIQN54frV7a/s
aW/t81a7k04haXkiYJvb1wIAOnNb0tG6DSB0cr1N6oqAcHG7gEIKcnQTxsOTnpQc7nFx3R
TFy8PdImJv5q1v1Icq5G+nvD0xlgRB2lE6eA9WMp1HpdBgcWXfaLPctkOuKEWk2MBi0tnR
zrg0x4PXlUzgAAAIEA2XAG2OK81HiRUJTUwRuJOGxGZFpRoJoHPUiPA1HMaxKOfRqxZedx
dTngMgV1jRhMr5OxSbFmX3hietEMyuZNQ7Oc9Gt95gyY3M8hYo7VLhLeBK7XJG6DMaIVok
Q9IqliJiK5su1UCp0Ig6cHDf8ZGI7Yqx3aSJwxaBGhZm3j2B0AAACBAPtEF4ukNr4eNx2n
9mFsBMSq3Xg+B4weMwKuAxSHg3rlnn0IZ6jyqr8ScM9yqafHCx2I1SD9nGPKRzBVTovEz/
R/FqSSEnShCcLEbpyMM++l5ffgK61PXBGqGoQ3W/166sPNfLDI5B9UY7XHr9/0Caf8xyX8
XOmR15LFmB5W07EjAAAAAAEC
-----END OPENSSH PRIVATE KEY-----
"""
# Encrypted with the passphrase 'encrypted'
privateRSA_openssh_encrypted = b"""-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,FFFFFFFFFFFFFFFF
p2A1YsHLXkpMVcsEqhh/nCYb5AqL0uMzfEIqc8hpZ/Ub8PtLsypilMkqzYTnZIGS
ouyPjU/WgtR4VaDnutPWdgYaKdixSEmGhKghCtXFySZqCTJ4O8NCczsktYjUK3D4
Jtl90zL6O81WBY6xP76PBQo9lrI/heAetATeyqutc18bwQIGU+gKk32qvfo15DfS
VYiY0Ds4D7F7fd9pz+f5+UbFUCgU+tfDvBrqodYrUgmH7jKoW/CRDCHHyeEIZDbF
mcMwdcKOyw1sRLaPdihRSVx3kOMvIotHKVTkIDMp+0RTNeXzQnp5U2qzsxzTcG/M
UyJN38XXkuvq5VMj2zmmjHzx34w3NK3ZxpZcoaFUqUBlNp2C8hkCLrAa/DWobKqN
5xA1ElrQvli9XXkT/RIuy4Gc10bbGEoJjuxNRibtSxxWd5Bd1E40ocOd4l1ebI8+
w69XvMTnsmHvkBEADGF2zfRszKnMelg+W5NER1UDuNT03i+1cuhp+2AZg8z7niTO
M17XP3ScGVxrQAEYgtxPrPeIpFJvOx2j5Yt78U9Y2WlaAG6DrubbYv2RsMIibhOG
yk139vMdD8FwCey6yMkkhFAJwnBtC22MAWgjmC5c6AF3SRQSjjQXepPsJcLgpOjy
YwjhnL8w56x9kVDUNPw9A9Cqgxo2sty34ATnKrh4h59PsP83LOL6OC5WjbASgZRd
OIBD8RloQPISo+RUF7X0i4kdaHVNPlR0KyapR+3M5BwhQuvEO99IArDV2LNKGzfc
W4ssugm8iyAJlmwmb2yRXIDHXabInWY7XCdGk8J2qPFbDTvnPbiagJBimjVjgpWw
tV3sVlJYqmOqmCDP78J6he04l0vaHtiOWTDEmNCrK7oFMXIIp3XWjOZGPSOJFdPs
6Go3YB+EGWfOQxqkFM28gcqmYfVPF2sa1FbZLz0ffO11Ma/rliZxZu7WdrAXe/tc
BgIQ8etp2PwAK4jCwwVwjIO8FzqQGpS23Y9NY3rfi97ckgYXKESFtXPsMMA+drZd
ThbXvccfh4EPmaqQXKf4WghHiVJ+/yuY1kUIDEl/O0jRZWT7STgBim/Aha1m6qRs
zl1H7hkDbU4solb1GM5oPzbgGTzyBc+z0XxM9iFRM+fMzPB8+yYHTr4kPbVmKBjy
SCovjQQVsHE4YeUGTq6k/NF5cVIRKTW/RlHvzxsky1Zj31MC736jrxGw4KG7VSLZ
fP6F5jj+mXwS7m0v5to42JBZmRJdKUD88QaGE3ncyQ4yleW5bn9Lf9SuzQg1Dhao
3rSA1RuexsHlIAHvGxx/17X+pyygl8DJbt6TBfbLQk9wc707DJTfh5M/bnk9wwIX
l/Hsa1WtylAMW/2MzgiVy83MbYz4+Ss6GQ5W66okWji+NxrnrYEy6q+WgVQanp7X
D+D7oKykqE1Cdvvulvtfl5fh8wlAs8mrUnKPBBUru348u++2lfacLkxRXyT1ooqY
uSNE5nlwFt08N2Ou/bl7yq6QNRMYrRkn+UEfHWCNYDoGMHln2/i6Z1RapQzNarik
tJf7radBz5nBwBjP08YAEACNSQvpsUgdqiuYjLwX7efFXQva2RzqaQ==
-----END RSA PRIVATE KEY-----"""
# Encrypted with the passphrase 'encrypted', and using the new format
# introduced in OpenSSH 6.5
privateRSA_openssh_encrypted_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABD0f9WAof
DTbmwztb8pdrSeAAAAEAAAAAEAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQDVaqx4I9bW
G+wloVDEd2NQhEUBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n
3WmM06QHjVyqgYwBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9T
MA2l5bs9auIJNm/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQN
UZdy03w17snaY6RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASf
NaNHjkbtM6Vlvbw7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3AAADwPQaac8s1xX3af
hQTQexj0vEAWDQsLYzDHN9G7W+UP5WHUu7igeu2GqAC/TOnjUXDP73I+EN3n7T3JFeDRfs
U1Z6Zqb0NKHSRVYwDIdIi8qVohFv85g6+xQ01OpaoOzz+vI34OUvCRHQGTgR6L9fQShZyC
McopYMYfbIse6KcqkfxX3KSdG1Pao6Njx/ShFRbgvmALpR/z0EaGCzHCDxpfUyAdnxm621
Jzaf+LverWdN7sfrfMptaS9//9iJb70sL67K+YIB64qhDnA/w9UOQfXGQFL+AEtdM0BPv8
thP1bs7T0yucBl+ZXdrDKVLZfaS3S/w85Jlgfu+a1DG73pOBOuag435iEJ9EnspjXiiydx
GrfSRk2C+/c4fBDZVGFscK5bfQuUUZyU1qOagekxX7WLHFKk9xajnud+nrAN070SeNwlX8
FZ2CI4KGlQfDvVUpKanYn8Kkj3fZ+YBGyx4M+19clF65FKSM0x1Rrh5tAmNT/SNDbSc28m
ASxrBhztzxUFTrIn3tp+uqkJniFLmFsUtiAUmj8fNyE9blykU7dqq+CqpLA872nQ9bOHHA
JsS1oBYmQ0n6AJz8WrYMdcepqWVld6Q8QSD1zdrY/sAWUovuBA1s4oIEXZhpXSS4ZJiMfh
PVktKBwj5bmoG/mmwYLbo0JHntK8N3TGTzTGLq5TpSBBdVvWSWo7tnfEkrFObmhi1uJSrQ
3zfPVP6BguboxBv+oxhaUBK8UOANe6ZwM4vfiu+QN+sZqWymHIfAktz7eWzwlToe4cKpdG
Uv+e3/7Lo2dyMl3nke5HsSUrlsMGPREuGkBih8+o85ii6D+cuCiVtus3f5c78Cir80zLIr
Z0wWvEAjciEvml00DWaA+JIaOrWwvXySaOzFGpCqC9SQjao379bvn9P3b7kVZsy6zBfHqm
bNEJUOuhBZaY8Okz36chh1xqh4sz7m3nsZ3GYGcvM+3mvRY72QnqsQEG0Sp1XYIn2bHa29
tqp7CG9X8J6dqMcPeoPRDWIX9gw7EPl/M0LP6xgewGJ9bgxwle6Mnr9kNITIswjAJqrLec
zx7dfixjAPc42ADqrw/tEdFQcSqxigcfJNKO1LbDBjh+Hk/cSBou2PoxbIcl0qfQfbGcqI
Dbpd695IEuiW9pYR22txNoIi+7cbMsuFHxQ/OqbrX/jCsprGNNJLAjgGsVEI1JnHWDH0db
3UbqbOHAeY3ufoYXNY1utVOIACpW3r9wBw3FjRi04d70VcKr16OXvOAHGN2G++Y+kMya84
Hl/Kt/gA==
-----END OPENSSH PRIVATE KEY-----
"""
# Encrypted with the passphrase 'testxp'. NB: this key was generated by
# OpenSSH, so it doesn't use the same key data as the other keys here.
privateRSA_openssh_encrypted_aes = b"""-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,0673309A6ACCAB4B77DEE1C1E536AC26
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----"""
publicRSA_lsh = (
b"{KDEwOnB1YmxpYy1rZXkoMTQ6cnNhLXBrY3MxLXNoYTEoMTpuMjU3OgDVaqx4I9bWG+wloVD"
b"Ed2NQhEUBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHj"
b"VyqgYwBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auI"
b"JNm/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY"
b"6RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvbw"
b"7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3KSgxOmUzOgEAASkpKQ==}"
)
privateRSA_lsh = (
b"(11:private-key(9:rsa-pkcs1(1:n257:\x00\xd5j\xacx#\xd6\xd6\x1b\xec%\xa1P"
b"\xc4wcP\x84E\x01UB\x14**\xe0\xd0`\xee\xd4\xe9\xa3\xadJ\xfa9\x06^\x84Uu_"
b"\x006\xbfo\xaa*?\x83&7\xc1i.[\xfd\xf0\xf3\xd2}\xd6\x98\xcd:@x\xd5\xca"
b"\xa8\x18\xc0\x11\x93$\t\x0c\x81L\x8f\xf7\x9c\xed\x13\x16j\xa4\x04\xe9Iw"
b"\xc3\xe4Ud\xb3yh\x9e,\x08\xeb\xac\xe8\x04-!w\x05\xa7\x8e\xefS0\r\xa5\xe5"
b"\xbb=j\xe2\t6o\xfd4\xd3}oF\xff\x87\xda\xa9)'\xaa\xff\xad\xf5\x85\xe6>"
b"\x1a\xb8z\x1dJ\xb1\xea\xc0Z\xf70\xdf\x1f\xc2\xa4\xe4\xef?\x91I\x96@\xd5"
b"\x19w-7\xc3^\xec\x9d\xa6:D\xa5\xc2\xa4)\xdd\xd5\xba\x9c=E\xb3\xc6,\x18d"
b"\xd5\xba=\xdf\xab\x7f\xcdB\xac\xa7\xf1\x18\x0b\xa0X\x15b\x0b\xa4*nC\xc3"
b"\xe4\x04\x9f5\xa3G\x8eF\xed3\xa5e\xbd\xbc;)n\x02\x0bW\xdft\xe8\x13\xb475"
b"~\x83_ &`\xa6\xdc\xad\x8b\xc6ly\x98\xf7)(1:e3:\x01\x00\x01)(1:d256:!L"
b"\x08f\xa2(\xd5\xb4\xfb\x8e\x0fr\x1b\x85\t\x00\xb9\xf2N7\xf0\x1cWK\xe3Q"
b"\x7f\x9e#\xa7\xe4:\x98U\x1b\xea\x8bz\x98\x1e\xbc\xd8\xba\xb1\xf9\x89\x12"
b"\x18`\xac\xe8\xcc\x0bN\tZ@j\xba/\x99\xf8\xb3$`\x84\xb9\xcei\x95\x9a\xf9"
b"\xe2\xfc\x1fQM'\x15\xdb+'\xad\xef\xb4i\xac\xbe}\x10\xeb\x86Gps\xb4\x00"
b"\x87\x95\x15;7\xf9\xe7\x14\xe7\x80\xbbh\x1e\x1b\xe6\xdd\xbbsc\xb9g\xe6"
b"\xb2'\x7f\xcf\xcf0\x9b\xc2\x98\xfd\xd9\x186/6.\xf1=\x81z\x9f\xe1\x03-G"
b"\xdb4Qb9\xddO\xe9\xac\xa8\x8b\xd9\xd6\xf3\x84\xc4\x17\xb9q\x9d\x06\x08Bx"
b"M\xbb\xc5*\xf4\xc3X\xcdU+\xed\xbe3_\x04\xea{\xe6\x04$c\xf2-\xd7=\x1bl"
b"\xd5\x9ccC/\x92\x88\x8d>n\xda\x187\xd8\x0f%g\x89\x1d\xb9F4^\xc9\xce\xc4"
b"\x8b\xed\x92Z3\x07\x0f\xdf\x86\x08\xf9\x92\xe9\xdb\xeb8\x086\xc9\xcd\xcd"
b"\n\x01H[9>z\xca\xc6\x80\xa9\xdc\xd49)(1:p129:\x00\xfbD\x17\x8b\xa46\xbe"
b"\x1e7\x1d\xa7\xf6al\x04\xc4\xaa\xddx>\x07\x8c\x1e3\x02\xae\x03\x14\x87"
b"\x83z\xe5\x9e}\x08g\xa8\xf2\xaa\xbf\x12p\xcfr\xa9\xa7\xc7\x0b\x1d\x88"
b"\xd5 \xfd\x9cc\xcaG0UN\x8b\xc4\xcf\xf4\x7f\x16\xa4\x92\x12t\xa1\t\xc2"
b"\xc4n\x9c\x8c3\xef\xa5\xe5\xf7\xe0+\xadO\\\x11\xaa\x1a\x847[\xfdz\xea"
b"\xc3\xcd|\xb0\xc8\xe4\x1fTc\xb5\xc7\xaf\xdf\xf4\t\xa7\xfc\xc7%\xfc\\\xe9"
b"\x91\xd7\x92\xc5\x98\x1eV\xd3\xb1#)(1:q129:\x00\xd9p\x06\xd8\xe2\xbc\xd4"
b"x\x91P\x94\xd4\xc1\x1b\x898lFdZQ\xa0\x9a\x07=H\x8f\x03Q\xcck\x12\x8e}"
b"\x1a\xb1e\xe7qu9\xe02\x05u\x8d\x18L\xaf\x93\xb1I\xb1f_xbz\xd1\x0c\xca"
b"\xe6MC\xb3\x9c\xf4k}\xe6\x0c\x98\xdc\xcf!b\x8e\xd5.\x12\xde\x04\xae\xd7$"
b'n\x831\xa2\x15\xa2D="\xa9b&"\xb9\xb2\xedT\n\x9d\x08\x83\xa7\x07\r\xff'
b"\x19\x18\x8e\xd8\xab\x1d\xdaH\x9c1h\x11\xa1fm\xe3\xd8\x1d)(1:a128:if7"
b"\xc6@\xdd!\xc5\x04\xf3\xb0\xb8>G\x94|v\xfc-\xeb?9<\x95\xc3C\x01Q\xc4B"
b"\x97\xf3\xe8\x16\xa4\xc6\xee\xec\xd4I\x10P8\x04\xee;\xcd\xd7\xd0\xcc\xcc"
b"2i\x90\x07\xa0\x1bZ\x9f\xfe1\xcd\x1e:~q\x1e\x19\x94\x1aNO\x0c\xdf_R\xd5"
b"\xd1\x17n\xec\xd7\x9c\xb6U\x9d\xdb\x8e!_\xbc$\x88\xb6\xfc\xaf\xab\xf0"
b"\xef\xa4,\xb0\xdc\x9f\x86\xb0\x03\x12\xb8\x8b\xe2\xdd\x0f\xc0\xee#=JP"
b"\xfe/\xb8)FX\xbf\xb6')(1:b128:Q\xaf\xe9\x92\x9f\x94\x0bJ\x84e>\x94\xb3;"
b"\x92\x10\xb5t\xb8\x8c\xc9\xef\xc9\x0e\x012\xfa/h\x12\xa1\x03&\xae\xcfQh"
b"\x14L&\x9b(\xa4\x023\x08_\xe1\xa7p\x98\x014y^R\x8e\xc4\xcf6\xbc\x1fKU"
b"\xac\xeb\xc1S\x84\xc7\xe1a\xa8J\xd4\xa2\xff@\r\x80\x1f\x12\xa9P\xc0*\x18"
b"u\x94\x0c\x06\x9b\x16P\xa8K\xecA\xcd{\xef\xf7K\xc9u\x02h\xc4\x98\xb8\x86"
b'\x88\x18ZC\xe7\x023\x97"d\x93\x83\x0cE*|\xed)(1:c128:f\x16\xf9 4\xd0T%'
b"\xbca1\xac\x82\xfb\xef\x9c5\x1e~JU\x02h\x95\r\n\x93\xbe\x1e\xbf\xe1@\xfa"
b'\x90\xa7Tp1\xe9x\xfc\xe0f\xb7"w\x9a\xeb\xdd\xd5\xd20F\xca\xe2\xd7^A\x9b'
b"\xcf\xb5H5Q\xaa%\xdc\xde\xdb4)H!\xdb\xd6t\xfc~\xe5/S\xf7\x9c\tp\xb9\xe4"
b"\xa0v\xa6\xadzt\"4\x9cO\x17\xcb?\xe0\xaa\xe5^\xa5'\xde?\xa5\x7f\x0f\xa6"
b"\x88\xf5\x15\xd6_7\x17\x92\xe0\xd6\x05I\x10;\x92\xf6)))"
)
privateRSA_agentv3 = (
b"\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x03\x01\x00\x01\x00\x00\x01\x00!L"
b"\x08f\xa2(\xd5\xb4\xfb\x8e\x0fr\x1b\x85\t\x00\xb9\xf2N7\xf0\x1cWK\xe3Q"
b"\x7f\x9e#\xa7\xe4:\x98U\x1b\xea\x8bz\x98\x1e\xbc\xd8\xba\xb1\xf9\x89\x12"
b"\x18`\xac\xe8\xcc\x0bN\tZ@j\xba/\x99\xf8\xb3$`\x84\xb9\xcei\x95\x9a\xf9"
b"\xe2\xfc\x1fQM'\x15\xdb+'\xad\xef\xb4i\xac\xbe}\x10\xeb\x86Gps\xb4\x00"
b"\x87\x95\x15;7\xf9\xe7\x14\xe7\x80\xbbh\x1e\x1b\xe6\xdd\xbbsc\xb9g\xe6"
b"\xb2'\x7f\xcf\xcf0\x9b\xc2\x98\xfd\xd9\x186/6.\xf1=\x81z\x9f\xe1\x03-G"
b"\xdb4Qb9\xddO\xe9\xac\xa8\x8b\xd9\xd6\xf3\x84\xc4\x17\xb9q\x9d\x06\x08Bx"
b"M\xbb\xc5*\xf4\xc3X\xcdU+\xed\xbe3_\x04\xea{\xe6\x04$c\xf2-\xd7=\x1bl"
b"\xd5\x9ccC/\x92\x88\x8d>n\xda\x187\xd8\x0f%g\x89\x1d\xb9F4^\xc9\xce\xc4"
b"\x8b\xed\x92Z3\x07\x0f\xdf\x86\x08\xf9\x92\xe9\xdb\xeb8\x086\xc9\xcd\xcd"
b"\n\x01H[9>z\xca\xc6\x80\xa9\xdc\xd49\x00\x00\x01\x01\x00\xd5j\xacx#\xd6"
b"\xd6\x1b\xec%\xa1P\xc4wcP\x84E\x01UB\x14**\xe0\xd0`\xee\xd4\xe9\xa3\xadJ"
b"\xfa9\x06^\x84Uu_\x006\xbfo\xaa*?\x83&7\xc1i.[\xfd\xf0\xf3\xd2}\xd6\x98"
b"\xcd:@x\xd5\xca\xa8\x18\xc0\x11\x93$\t\x0c\x81L\x8f\xf7\x9c\xed\x13\x16j"
b"\xa4\x04\xe9Iw\xc3\xe4Ud\xb3yh\x9e,\x08\xeb\xac\xe8\x04-!w\x05\xa7\x8e"
b"\xefS0\r\xa5\xe5\xbb=j\xe2\t6o\xfd4\xd3}oF\xff\x87\xda\xa9)'\xaa\xff\xad"
b"\xf5\x85\xe6>\x1a\xb8z\x1dJ\xb1\xea\xc0Z\xf70\xdf\x1f\xc2\xa4\xe4\xef?"
b"\x91I\x96@\xd5\x19w-7\xc3^\xec\x9d\xa6:D\xa5\xc2\xa4)\xdd\xd5\xba\x9c=E"
b"\xb3\xc6,\x18d\xd5\xba=\xdf\xab\x7f\xcdB\xac\xa7\xf1\x18\x0b\xa0X\x15b"
b"\x0b\xa4*nC\xc3\xe4\x04\x9f5\xa3G\x8eF\xed3\xa5e\xbd\xbc;)n\x02\x0bW\xdf"
b"t\xe8\x13\xb475~\x83_ &`\xa6\xdc\xad\x8b\xc6ly\x98\xf7\x00\x00\x00\x81"
b"\x00\x85K\x1bz\x9b\x12\x107\x9e\x1f\xad^\xda\xfe\xc6\x96\xfe\xdf5k\xb94"
b"\xe2\x16\x97\x92&\t\xbd\xbdp \x03\xa75\xbd-\x1b\xa0\xd2\x07G+\xd4\xde"
b"\xa8\xa8\x07\x07\x1b\xb8\x04 \xa7'A<l99\xe9A\xce\xe7\x17\x1d\xd1L\\\xbc="
b"\xd2&&\xfej\xd6\xfdHr\xaeF\xfa{\xc3\xd3\x19`D\x1d\xa5\x13\xa7\x80\xf5c)"
b"\xd4z]\x06\x07\x16]\xf6\x8b=\xcbd:\xe2\x84ZM\x8c\x06--\x9d\x1c\xeb\x83Lx"
b"=yT\xce\x00\x00\x00\x81\x00\xd9p\x06\xd8\xe2\xbc\xd4x\x91P\x94\xd4\xc1"
b"\x1b\x898lFdZQ\xa0\x9a\x07=H\x8f\x03Q\xcck\x12\x8e}\x1a\xb1e\xe7qu9\xe02"
b"\x05u\x8d\x18L\xaf\x93\xb1I\xb1f_xbz\xd1\x0c\xca\xe6MC\xb3\x9c\xf4k}\xe6"
b'\x0c\x98\xdc\xcf!b\x8e\xd5.\x12\xde\x04\xae\xd7$n\x831\xa2\x15\xa2D="'
b'\xa9b&"\xb9\xb2\xedT\n\x9d\x08\x83\xa7\x07\r\xff\x19\x18\x8e\xd8\xab'
b"\x1d\xdaH\x9c1h\x11\xa1fm\xe3\xd8\x1d\x00\x00\x00\x81\x00\xfbD\x17\x8b"
b"\xa46\xbe\x1e7\x1d\xa7\xf6al\x04\xc4\xaa\xddx>\x07\x8c\x1e3\x02\xae\x03"
b"\x14\x87\x83z\xe5\x9e}\x08g\xa8\xf2\xaa\xbf\x12p\xcfr\xa9\xa7\xc7\x0b"
b"\x1d\x88\xd5 \xfd\x9cc\xcaG0UN\x8b\xc4\xcf\xf4\x7f\x16\xa4\x92\x12t\xa1"
b"\t\xc2\xc4n\x9c\x8c3\xef\xa5\xe5\xf7\xe0+\xadO\\\x11\xaa\x1a\x847[\xfdz"
b"\xea\xc3\xcd|\xb0\xc8\xe4\x1fTc\xb5\xc7\xaf\xdf\xf4\t\xa7\xfc\xc7%\xfc\\"
b"\xe9\x91\xd7\x92\xc5\x98\x1eV\xd3\xb1#"
)
publicDSA_openssh = b"""\
ssh-dss AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9\
LvFYmFFVMIuHFGlZpIL7sh3IMkqy+cssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+Gu48G\
+yFuE8l0fVVUivos/MmYVJ66qT99htcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+Sx9a5AAAAFQD0\
EYmTNaFJ8CS0+vFSF4nYcyEnSQAAAIEAkgLjxHJAE7qFWdTqf7EZngu7jAGmdB9k3YzMHe1ldMxEB\
7zNw5aOnxjhoYLtiHeoEcOk2XOyvnE+VfhIWwWAdOiKRTEZlmizkvhGbq0DCe2EPMXirjqWACI5nD\
ioQX1oEMonR8N3AEO5v9SfBqS2Q9R6OBr6lf04RvwpHZ0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQT\
NEpklRZqeBGo1gotJggNmVaIQNIClGlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2G\
gdgMQWC7S6WFIXePGGXqNQDdWxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8= \
comment\
"""
privateDSA_openssh = b"""\
-----BEGIN DSA PRIVATE KEY-----
MIIBvAIBAAKBgQCSkDrFREVQ0CKQDRx/iQAMpaPS7xWJhRVTCLhxRpWaSC+7IdyD
JKsvnLLCDTP5Zxw935rAMi5VF2bbejw/S4GUWs7DzoKbbh/hruPBvshbhPJdH1VV
Ir6LPzJmFSeuqk/fYbXGSmra01mZ6VVu4BQAaKsPoXti2dIJHlNPksfWuQIVAPQR
iZM1oUnwJLT68VIXidhzISdJAoGBAJIC48RyQBO6hVnU6n+xGZ4Lu4wBpnQfZN2M
zB3tZXTMRAe8zcOWjp8Y4aGC7Yh3qBHDpNlzsr5xPlX4SFsFgHToikUxGZZos5L4
Rm6tAwnthDzF4q46lgAiOZw4qEF9aBDKJ0fDdwBDub/UnwaktkPUejga+pX9OEb8
KR2dFBrvAoGAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQNICl
GlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXeP
GGXqNQDdWxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8CFQDV2gbL
czUdxCus0pfEP1bddaXRLQ==
-----END DSA PRIVATE KEY-----\
"""
privateDSA_openssh_new = b"""\
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABsgAAAAdzc2gtZH
NzAAAAgQCSkDrFREVQ0CKQDRx/iQAMpaPS7xWJhRVTCLhxRpWaSC+7IdyDJKsvnLLCDTP5
Zxw935rAMi5VF2bbejw/S4GUWs7DzoKbbh/hruPBvshbhPJdH1VVIr6LPzJmFSeuqk/fYb
XGSmra01mZ6VVu4BQAaKsPoXti2dIJHlNPksfWuQAAABUA9BGJkzWhSfAktPrxUheJ2HMh
J0kAAACBAJIC48RyQBO6hVnU6n+xGZ4Lu4wBpnQfZN2MzB3tZXTMRAe8zcOWjp8Y4aGC7Y
h3qBHDpNlzsr5xPlX4SFsFgHToikUxGZZos5L4Rm6tAwnthDzF4q46lgAiOZw4qEF9aBDK
J0fDdwBDub/UnwaktkPUejga+pX9OEb8KR2dFBrvAAAAgAIUacRjCFhMmhIfGJ44ms0EzR
KZJUWangRqNYKLSYIDZlWiEDSApRpS8got+fXnxFLkHGfUl8TOfT/oXnHPxlPxh2pFuWFh
OHT9hoHYDEFgu0ulhSF3jxhl6jUA3VsZV/LpoXp70KmtT5yqxUYQ6ycPGexo3R8X5bMQhJ
lz6CzfAAAB2MVcBjzFXAY8AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9Lv
FYmFFVMIuHFGlZpIL7sh3IMkqy+cssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+
Gu48G+yFuE8l0fVVUivos/MmYVJ66qT99htcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+S
x9a5AAAAFQD0EYmTNaFJ8CS0+vFSF4nYcyEnSQAAAIEAkgLjxHJAE7qFWdTqf7EZngu7jA
GmdB9k3YzMHe1ldMxEB7zNw5aOnxjhoYLtiHeoEcOk2XOyvnE+VfhIWwWAdOiKRTEZlmiz
kvhGbq0DCe2EPMXirjqWACI5nDioQX1oEMonR8N3AEO5v9SfBqS2Q9R6OBr6lf04RvwpHZ
0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQNIClGlLyCi35
9efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXePGGXqNQDdWxlX8u
mhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8AAAAVANXaBstzNR3EK6zSl8Q/Vt11
pdEtAAAAAAE=
-----END OPENSSH PRIVATE KEY-----
"""
publicDSA_lsh = decodebytes(
b"""\
e0tERXdPbkIxWW14cFl5MXJaWGtvTXpwa2MyRW9NVHB3TVRJNU9nQ1NrRHJGUkVWUTBDS1FEUngv
aVFBTXBhUFM3eFdKaFJWVENMaHhScFdhU0MrN0lkeURKS3N2bkxMQ0RUUDVaeHc5MzVyQU1pNVZG
MmJiZWp3L1M0R1VXczdEem9LYmJoL2hydVBCdnNoYmhQSmRIMVZWSXI2TFB6Sm1GU2V1cWsvZlli
WEdTbXJhMDFtWjZWVnU0QlFBYUtzUG9YdGkyZElKSGxOUGtzZld1U2tvTVRweE1qRTZBUFFSaVpN
MW9VbndKTFQ2OFZJWGlkaHpJU2RKS1NneE9tY3hNams2QUpJQzQ4UnlRQk82aFZuVTZuK3hHWjRM
dTR3QnBuUWZaTjJNekIzdFpYVE1SQWU4emNPV2pwOFk0YUdDN1loM3FCSERwTmx6c3I1eFBsWDRT
RnNGZ0hUb2lrVXhHWlpvczVMNFJtNnRBd250aER6RjRxNDZsZ0FpT1p3NHFFRjlhQkRLSjBmRGR3
QkR1Yi9Vbndha3RrUFVlamdhK3BYOU9FYjhLUjJkRkJydktTZ3hPbmt4TWpnNkFoUnB4R01JV0V5
YUVoOFluamlhelFUTkVwa2xSWnFlQkdvMWdvdEpnZ05tVmFJUU5JQ2xHbEx5Q2kzNTllZkVVdVFj
WjlTWHhNNTlQK2hlY2MvR1UvR0hha1c1WVdFNGRQMkdnZGdNUVdDN1M2V0ZJWGVQR0dYcU5RRGRX
eGxYOHVtaGVudlFxYTFQbktyRlJoRHJKdzhaN0dqZEh4ZmxzeENFbVhQb0xOOHBLU2s9fQ==
"""
)
privateDSA_lsh = decodebytes(
b"""\
KDExOnByaXZhdGUta2V5KDM6ZHNhKDE6cDEyOToAkpA6xURFUNAikA0cf4kADKWj0u8ViYUVUwi4
cUaVmkgvuyHcgySrL5yywg0z+WccPd+awDIuVRdm23o8P0uBlFrOw86Cm24f4a7jwb7IW4TyXR9V
VSK+iz8yZhUnrqpP32G1xkpq2tNZmelVbuAUAGirD6F7YtnSCR5TT5LH1rkpKDE6cTIxOgD0EYmT
NaFJ8CS0+vFSF4nYcyEnSSkoMTpnMTI5OgCSAuPEckATuoVZ1Op/sRmeC7uMAaZ0H2TdjMwd7WV0
zEQHvM3Dlo6fGOGhgu2Id6gRw6TZc7K+cT5V+EhbBYB06IpFMRmWaLOS+EZurQMJ7YQ8xeKuOpYA
IjmcOKhBfWgQyidHw3cAQ7m/1J8GpLZD1Ho4GvqV/ThG/CkdnRQa7ykoMTp5MTI4OgIUacRjCFhM
mhIfGJ44ms0EzRKZJUWangRqNYKLSYIDZlWiEDSApRpS8got+fXnxFLkHGfUl8TOfT/oXnHPxlPx
h2pFuWFhOHT9hoHYDEFgu0ulhSF3jxhl6jUA3VsZV/LpoXp70KmtT5yqxUYQ6ycPGexo3R8X5bMQ
hJlz6CzfKSgxOngyMToA1doGy3M1HcQrrNKXxD9W3XWl0S0pKSk=
"""
)
privateDSA_agentv3 = decodebytes(
b"""\
AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9LvFYmFFVMIuHFGlZpIL7sh3IMkqy+c
ssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+Gu48G+yFuE8l0fVVUivos/MmYVJ66qT99h
tcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+Sx9a5AAAAFQD0EYmTNaFJ8CS0+vFSF4nYcyEnSQAA
AIEAkgLjxHJAE7qFWdTqf7EZngu7jAGmdB9k3YzMHe1ldMxEB7zNw5aOnxjhoYLtiHeoEcOk2XOy
vnE+VfhIWwWAdOiKRTEZlmizkvhGbq0DCe2EPMXirjqWACI5nDioQX1oEMonR8N3AEO5v9SfBqS2
Q9R6OBr6lf04RvwpHZ0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQ
NIClGlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXePGGXqNQDd
WxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8AAAAVANXaBstzNR3EK6zSl8Q/Vt11
pdEt
"""
)
__all__ = [
"DSAData",
"RSAData",
"privateDSA_agentv3",
"privateDSA_lsh",
"privateDSA_openssh",
"privateRSA_agentv3",
"privateRSA_lsh",
"privateRSA_openssh",
"publicDSA_lsh",
"publicDSA_openssh",
"publicRSA_lsh",
"publicRSA_openssh",
]

View File

@@ -0,0 +1,28 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Loopback helper used in test_ssh and test_recvline
"""
from twisted.protocols import loopback
class LoopbackRelay(loopback.LoopbackRelay):
clearCall = None
def logPrefix(self):
return f"LoopbackRelay({self.target.__class__.__name__!r})"
def write(self, data):
loopback.LoopbackRelay.write(self, data)
if self.clearCall is not None:
self.clearCall.cancel()
from twisted.internet import reactor
self.clearCall = reactor.callLater(0, self._clearBuffer)
def _clearBuffer(self):
self.clearCall = None
loopback.LoopbackRelay.clearBuffer(self)

View File

@@ -0,0 +1,45 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{SSHTransportAddrress} in ssh/address.py
"""
from __future__ import annotations
from typing import Callable
from twisted.conch.ssh.address import SSHTransportAddress
from twisted.internet.address import IPv4Address
from twisted.internet.test.test_address import AddressTestCaseMixin
from twisted.trial import unittest
class SSHTransportAddressTests(unittest.TestCase, AddressTestCaseMixin):
"""
L{twisted.conch.ssh.address.SSHTransportAddress} is what Conch transports
use to represent the other side of the SSH connection. This tests the
basic functionality of that class (string representation, comparison, &c).
"""
def _stringRepresentation(self, stringFunction: Callable[[object], str]) -> None:
"""
The string representation of C{SSHTransportAddress} should be
"SSHTransportAddress(<stringFunction on address>)".
"""
addr = self.buildAddress()
stringValue = stringFunction(addr)
addressValue = stringFunction(addr.address)
self.assertEqual(stringValue, "SSHTransportAddress(%s)" % addressValue)
def buildAddress(self) -> SSHTransportAddress:
"""
Create an arbitrary new C{SSHTransportAddress}. A new instance is
created for each call, but always for the same address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))
def buildDifferentAddress(self) -> SSHTransportAddress:
"""
Like C{buildAddress}, but with a different fixed address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.2", 22))

View File

@@ -0,0 +1,398 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh.agent}.
"""
import struct
from twisted.test import iosim
from twisted.trial import unittest
try:
import cryptography as _cryptography
except ImportError:
cryptography = None
else:
cryptography = _cryptography
try:
from twisted.conch.ssh import agent as _agent, keys as _keys
except ImportError:
keys = agent = None
else:
keys, agent = _keys, _agent
from twisted.conch.error import ConchError, MissingKeyStoreError
from twisted.conch.test import keydata
class StubFactory:
"""
Mock factory that provides the keys attribute required by the
SSHAgentServerProtocol
"""
def __init__(self):
self.keys = {}
class AgentTestBase(unittest.TestCase):
"""
Tests for SSHAgentServer/Client.
"""
if agent is None or keys is None:
skip = "Cannot run without cryptography"
def setUp(self):
# wire up our client <-> server
self.client, self.server, self.pump = iosim.connectedServerAndClient(
agent.SSHAgentServer, agent.SSHAgentClient
)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
# pub/priv keys of each kind
self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
class ServerProtocolContractWithFactoryTests(AgentTestBase):
"""
The server protocol is stateful and so uses its factory to track state
across requests. This test asserts that the protocol raises if its factory
doesn't provide the necessary storage for that state.
"""
def test_factorySuppliesKeyStorageForServerProtocol(self):
# need a message to send into the server
msg = struct.pack("!LB", 1, agent.AGENTC_REQUEST_IDENTITIES)
del self.server.factory.__dict__["keys"]
self.assertRaises(MissingKeyStoreError, self.server.dataReceived, msg)
class UnimplementedVersionOneServerTests(AgentTestBase):
"""
Tests for methods with no-op implementations on the server. We need these
for clients, such as openssh, that try v1 methods before going to v2.
Because the client doesn't expose these operations with nice method names,
we invoke sendRequest directly with an op code.
"""
def test_agentc_REQUEST_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA identities request
"""
d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, b"")
self.pump.flush()
def _cb(packet):
self.assertEqual(agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0:1]))
return d.addCallback(_cb)
def test_agentc_REMOVE_RSA_IDENTITY(self):
"""
assert that we get the correct op code for an RSA remove identity request
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, b"")
self.pump.flush()
return d.addCallback(self.assertEqual, b"")
def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA remove all identities
request.
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, b"")
self.pump.flush()
return d.addCallback(self.assertEqual, b"")
if agent is not None:
class CorruptServer(agent.SSHAgentServer): # type: ignore[name-defined]
"""
A misbehaving server that returns bogus response op codes so that we can
verify that our callbacks that deal with these op codes handle such
miscreants.
"""
def agentc_REQUEST_IDENTITIES(self, data):
self.sendResponse(254, b"")
def agentc_SIGN_REQUEST(self, data):
self.sendResponse(254, b"")
class ClientWithBrokenServerTests(AgentTestBase):
"""
verify error handling code in the client using a misbehaving server
"""
def setUp(self):
AgentTestBase.setUp(self)
self.client, self.server, self.pump = iosim.connectedServerAndClient(
CorruptServer, agent.SSHAgentClient
)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
def test_signDataCallbackErrorHandling(self):
"""
Assert that L{SSHAgentClient.signData} raises a ConchError
if we get a response from the server whose opcode doesn't match
the protocol for data signing requests.
"""
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentitiesCallbackErrorHandling(self):
"""
Assert that L{SSHAgentClient.requestIdentities} raises a ConchError
if we get a response from the server whose opcode doesn't match
the protocol for identity requests.
"""
d = self.client.requestIdentities()
self.pump.flush()
return self.assertFailure(d, ConchError)
class AgentKeyAdditionTests(AgentTestBase):
"""
Test adding different flavors of keys to an agent.
"""
def test_addRSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that omitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.rsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual(b"", serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that omitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.dsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual(b"", serverKey[1])
return d.addCallback(_check)
def test_addRSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.rsaPrivate.privateBlob(), comment=b"My special key"
)
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual(b"My special key", serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.dsaPrivate.privateBlob(), comment=b"My special key"
)
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual(b"My special key", serverKey[1])
return d.addCallback(_check)
class AgentClientFailureTests(AgentTestBase):
def test_agentFailure(self):
"""
verify that the client raises ConchError on AGENT_FAILURE
"""
d = self.client.sendRequest(254, b"")
self.pump.flush()
return self.assertFailure(d, ConchError)
class AgentIdentityRequestsTests(AgentTestBase):
"""
Test operations against a server with identities already loaded.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate,
b"a comment",
)
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate,
b"another comment",
)
def test_signDataRSA(self):
"""
Sign data with an RSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
signature = self.successResultOf(d)
expected = self.rsaPrivate.sign(b"John Hancock")
self.assertEqual(expected, signature)
self.assertTrue(self.rsaPublic.verify(signature, b"John Hancock"))
def test_signDataDSA(self):
"""
Sign data with a DSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.dsaPublic.blob(), b"John Hancock")
self.pump.flush()
def _check(sig):
# Cannot do this b/c DSA uses random numbers when signing
# expected = self.dsaPrivate.sign("John Hancock")
# self.assertEqual(expected, sig)
self.assertTrue(self.dsaPublic.verify(sig, b"John Hancock"))
return d.addCallback(_check)
def test_signDataRSAErrbackOnUnknownBlob(self):
"""
Assert that we get an errback if we try to sign data using a key that
wasn't added.
"""
del self.server.factory.keys[self.rsaPublic.blob()]
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentities(self):
"""
Assert that we get all of the keys/comments that we add when we issue a
request for all identities.
"""
d = self.client.requestIdentities()
self.pump.flush()
def _check(keyt):
expected = {}
expected[self.dsaPublic.blob()] = b"a comment"
expected[self.rsaPublic.blob()] = b"another comment"
received = {}
for k in keyt:
received[keys.Key.fromString(k[0], type="blob").blob()] = k[1]
self.assertEqual(expected, received)
return d.addCallback(_check)
class AgentKeyRemovalTests(AgentTestBase):
"""
Test support for removing keys in a remote server.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate,
b"a comment",
)
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate,
b"another comment",
)
def test_removeRSAIdentity(self):
"""
Assert that we can remove an RSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.rsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeDSAIdentity(self):
"""
Assert that we can remove a DSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.dsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeAllIdentities(self):
"""
Assert that we can remove all identities.
"""
d = self.client.removeAllIdentities()
self.pump.flush()
def _check(ignored):
self.assertEqual(0, len(self.server.factory.keys))
return d.addCallback(_check)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,358 @@
# Copyright Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test ssh/channel.py.
"""
from __future__ import annotations
from unittest import skipIf
from zope.interface.verify import verifyObject
try:
from twisted.conch.ssh import channel
from twisted.conch.ssh.address import SSHTransportAddress
from twisted.conch.ssh.service import SSHService
from twisted.conch.ssh.transport import SSHServerTransport
from twisted.internet import interfaces
from twisted.internet.address import IPv4Address
from twisted.internet.testing import StringTransport
skipTest = ""
except ImportError:
skipTest = "Conch SSH not supported."
SSHService = object # type: ignore[assignment,misc]
from twisted.trial.unittest import TestCase
class MockConnection(SSHService):
"""
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
that channels send, and when they try to close the connection.
@ivar data: a L{dict} mapping channel id #s to lists of data sent by that
channel.
@ivar extData: a L{dict} mapping channel id #s to lists of 2-tuples
(extended data type, data) sent by that channel.
@ivar closes: a L{dict} mapping channel id #s to True if that channel sent
a close message.
"""
def __init__(self) -> None:
self.data: dict[channel.SSHChannel, list[bytes]] = {}
self.extData: dict[channel.SSHChannel, list[tuple[int, bytes]]] = {}
self.closes: dict[channel.SSHChannel, bool] = {}
def logPrefix(self) -> str:
"""
Return our logging prefix.
"""
return "MockConnection"
def sendData(self, channel: channel.SSHChannel, data: bytes) -> None:
"""
Record the sent data.
"""
self.data.setdefault(channel, []).append(data)
def sendExtendedData(
self, channel: channel.SSHChannel, type: int, data: bytes
) -> None:
"""
Record the sent extended data.
"""
self.extData.setdefault(channel, []).append((type, data))
def sendClose(self, channel: channel.SSHChannel) -> None:
"""
Record that the channel sent a close message.
"""
self.closes[channel] = True
def connectSSHTransport(
service: SSHService,
hostAddress: interfaces.IAddress | None = None,
peerAddress: interfaces.IAddress | None = None,
) -> None:
"""
Connect a SSHTransport which is already connected to a remote peer to
the channel under test.
@param service: Service used over the connected transport.
@type service: L{SSHService}
@param hostAddress: Local address of the connected transport.
@type hostAddress: L{interfaces.IAddress}
@param peerAddress: Remote address of the connected transport.
@type peerAddress: L{interfaces.IAddress}
"""
transport = SSHServerTransport()
transport.makeConnection(
StringTransport(hostAddress=hostAddress, peerAddress=peerAddress)
)
transport.setService(service)
@skipIf(skipTest, skipTest)
class ChannelTests(TestCase):
"""
Tests for L{SSHChannel}.
"""
def setUp(self) -> None:
"""
Initialize the channel. remoteMaxPacket is 10 so that data is able
to be sent (the default of 0 means no data is sent because no packets
are made).
"""
self.conn = MockConnection()
self.channel = channel.SSHChannel(conn=self.conn, remoteMaxPacket=10)
self.channel.name = b"channel"
def test_interface(self) -> None:
"""
L{SSHChannel} instances provide L{interfaces.ITransport}.
"""
self.assertTrue(verifyObject(interfaces.ITransport, self.channel))
def test_init(self) -> None:
"""
Test that SSHChannel initializes correctly. localWindowSize defaults
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
defaults (what OpenSSH uses for those variables).
The values in the second set of assertions are meaningless; they serve
only to verify that the instance variables are assigned in the correct
order.
"""
c = channel.SSHChannel(conn=self.conn)
self.assertEqual(c.localWindowSize, 131072)
self.assertEqual(c.localWindowLeft, 131072)
self.assertEqual(c.localMaxPacket, 32768)
self.assertEqual(c.remoteWindowLeft, 0)
self.assertEqual(c.remoteMaxPacket, 0)
self.assertEqual(c.conn, self.conn)
self.assertIsNone(c.data)
self.assertIsNone(c.avatar)
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
self.assertEqual(c2.localWindowSize, 1)
self.assertEqual(c2.localWindowLeft, 1)
self.assertEqual(c2.localMaxPacket, 2)
self.assertEqual(c2.remoteWindowLeft, 3)
self.assertEqual(c2.remoteMaxPacket, 4)
self.assertEqual(c2.conn, 5)
self.assertEqual(c2.data, 6)
self.assertEqual(c2.avatar, 7)
def test_str(self) -> None:
"""
Test that str(SSHChannel) works gives the channel name and local and
remote windows at a glance..
"""
self.assertEqual(str(self.channel), "<SSHChannel channel (lw 131072 rw 0)>")
self.assertEqual(
str(channel.SSHChannel(localWindow=1)), "<SSHChannel None (lw 1 rw 0)>"
)
def test_bytes(self) -> None:
"""
Test that bytes(SSHChannel) works, gives the channel name and
local and remote windows at a glance..
"""
self.assertEqual(
self.channel.__bytes__(), b"<SSHChannel channel (lw 131072 rw 0)>"
)
self.assertEqual(
channel.SSHChannel(localWindow=1).__bytes__(),
b"<SSHChannel None (lw 1 rw 0)>",
)
def test_logPrefix(self) -> None:
"""
Test that SSHChannel.logPrefix gives the name of the channel, the
local channel ID and the underlying connection.
"""
self.assertEqual(
self.channel.logPrefix(), "SSHChannel channel (unknown) on MockConnection"
)
def test_addWindowBytes(self) -> None:
"""
Test that addWindowBytes adds bytes to the window and resumes writing
if it was paused.
"""
cb = [False]
def stubStartWriting() -> None:
cb[0] = True
self.channel.startWriting = stubStartWriting # type: ignore[method-assign]
self.channel.write(b"test")
self.channel.writeExtended(1, b"test")
self.channel.addWindowBytes(50)
self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
self.assertTrue(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(self.channel.buf, b"")
self.assertEqual(self.conn.data[self.channel], [b"test"])
self.assertEqual(self.channel.extBuf, [])
self.assertEqual(self.conn.extData[self.channel], [(1, b"test")])
cb[0] = False
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
self.channel.write(b"a" * 80)
self.channel.loseConnection()
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
def test_requestReceived(self) -> None:
"""
Test that requestReceived handles requests by dispatching them to
request_* methods.
"""
self.channel.request_test_method = lambda data: data == b"" # type: ignore[attr-defined]
self.assertTrue(self.channel.requestReceived(b"test-method", b""))
self.assertFalse(self.channel.requestReceived(b"test-method", b"a"))
self.assertFalse(self.channel.requestReceived(b"bad-method", b""))
def test_closeReceieved(self) -> None:
"""
Test that the default closeReceieved closes the connection.
"""
self.assertFalse(self.channel.closing)
self.channel.closeReceived()
self.assertTrue(self.channel.closing)
def test_write(self) -> None:
"""
Test that write handles data correctly. Send data up to the size
of the remote window, splitting the data into packets of length
remoteMaxPacket.
"""
cb = [False]
def stubStopWriting() -> None:
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting # type: ignore[method-assign]
self.channel.write(b"d")
self.channel.write(b"a")
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.write(b"ta")
data = self.conn.data[self.channel]
self.assertEqual(data, [b"da", b"ta"])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.write(b"12345678901")
self.assertEqual(data, [b"da", b"ta", b"1234567890", b"1"])
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.write(b"123456")
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(data, [b"da", b"ta", b"1234567890", b"1", b"12345"])
self.assertEqual(self.channel.buf, b"6")
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeExtended(self) -> None:
"""
Test that writeExtended handles data correctly. Send extended data
up to the size of the window, splitting the extended data into packets
of length remoteMaxPacket.
"""
cb = [False]
def stubStopWriting() -> None:
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting # type: ignore[method-assign]
self.channel.writeExtended(1, b"d")
self.channel.writeExtended(1, b"a")
self.channel.writeExtended(2, b"t")
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.writeExtended(2, b"a")
data = self.conn.extData[self.channel]
self.assertEqual(data, [(1, b"da"), (2, b"t"), (2, b"a")])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.writeExtended(3, b"12345678901")
self.assertEqual(
data, [(1, b"da"), (2, b"t"), (2, b"a"), (3, b"1234567890"), (3, b"1")]
)
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.writeExtended(4, b"123456")
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(
data,
[
(1, b"da"),
(2, b"t"),
(2, b"a"),
(3, b"1234567890"),
(3, b"1"),
(4, b"12345"),
],
)
self.assertEqual(self.channel.extBuf, [[4, b"6"]])
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeSequence(self) -> None:
"""
Test that writeSequence is equivalent to write(''.join(sequece)).
"""
self.channel.addWindowBytes(20)
self.channel.writeSequence(b"%d" % (i,) for i in range(10))
self.assertEqual(self.conn.data[self.channel], [b"0123456789"])
def test_loseConnection(self) -> None:
"""
Tesyt that loseConnection() doesn't close the channel until all
the data is sent.
"""
self.channel.write(b"data")
self.channel.writeExtended(1, b"datadata")
self.channel.loseConnection()
self.assertIsNone(self.conn.closes.get(self.channel))
self.channel.addWindowBytes(4) # send regular data
self.assertIsNone(self.conn.closes.get(self.channel))
self.channel.addWindowBytes(8) # send extended data
self.assertTrue(self.conn.closes.get(self.channel))
def test_getPeer(self) -> None:
"""
L{SSHChannel.getPeer} returns the same object as the underlying
transport's C{getPeer} method returns.
"""
peer = IPv4Address("TCP", "192.168.0.1", 54321)
connectSSHTransport(service=self.channel.conn, peerAddress=peer)
self.assertEqual(SSHTransportAddress(peer), self.channel.getPeer())
def test_getHost(self) -> None:
"""
L{SSHChannel.getHost} returns the same object as the underlying
transport's C{getHost} method returns.
"""
host = IPv4Address("TCP", "127.0.0.1", 12345)
connectSSHTransport(service=self.channel.conn, hostAddress=host)
self.assertEqual(SSHTransportAddress(host), self.channel.getHost())

View File

@@ -0,0 +1,886 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.checkers}.
"""
import os
from base64 import encodebytes
from collections import namedtuple
from io import BytesIO
from typing import Optional
cryptSkip: Optional[str]
try:
import crypt
except ImportError:
cryptSkip = "cannot run without crypt module"
else:
cryptSkip = None
from zope.interface.verify import verifyObject
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
from twisted.cred.credentials import (
ISSHPrivateKey,
IUsernamePassword,
SSHPrivateKey,
UsernamePassword,
)
from twisted.cred.error import UnauthorizedLogin, UnhandledCredentials
from twisted.internet.defer import Deferred
from twisted.python import util
from twisted.python.fakepwd import ShadowDatabase, UserDatabase
from twisted.python.filepath import FilePath
from twisted.python.reflect import requireModule
from twisted.test.test_process import MockOS
from twisted.trial.unittest import TestCase
if requireModule("cryptography"):
dependencySkip = None
from twisted.conch import checkers
from twisted.conch.error import NotEnoughAuthentication, ValidPublicKey
from twisted.conch.ssh import keys
from twisted.conch.test import keydata
else:
dependencySkip = "can't run without cryptography"
if getattr(os, "geteuid", None) is not None:
euidSkip = None
else:
euidSkip = "Cannot run without effective UIDs (questionable)"
class HelperTests(TestCase):
"""
Tests for helper functions L{verifyCryptedPassword}, L{_pwdGetByName} and
L{_shadowGetByName}.
"""
skip = cryptSkip or dependencySkip
def setUp(self):
self.mockos = MockOS()
def test_verifyCryptedPassword(self):
"""
L{verifyCryptedPassword} returns C{True} if the plaintext password
passed to it matches the encrypted password passed to it.
"""
password = "secret string"
salt = "salty"
crypted = crypt.crypt(password, salt)
self.assertTrue(
checkers.verifyCryptedPassword(crypted, password),
"{!r} supposed to be valid encrypted password for {!r}".format(
crypted, password
),
)
def test_verifyCryptedPasswordMD5(self):
"""
L{verifyCryptedPassword} returns True if the provided cleartext password
matches the provided MD5 password hash.
"""
password = "password"
salt = "$1$salt"
crypted = crypt.crypt(password, salt)
self.assertTrue(
checkers.verifyCryptedPassword(crypted, password),
"{!r} supposed to be valid encrypted password for {}".format(
crypted, password
),
)
def test_refuteCryptedPassword(self):
"""
L{verifyCryptedPassword} returns C{False} if the plaintext password
passed to it does not match the encrypted password passed to it.
"""
password = "string secret"
wrong = "secret string"
crypted = crypt.crypt(password, password)
self.assertFalse(
checkers.verifyCryptedPassword(crypted, wrong),
"{!r} not supposed to be valid encrypted password for {}".format(
crypted, wrong
),
)
def test_pwdGetByName(self):
"""
L{_pwdGetByName} returns a tuple of items from the UNIX /etc/passwd
database if the L{pwd} module is present.
"""
userdb = UserDatabase()
userdb.addUser("alice", "secrit", 1, 2, "first last", "/foo", "/bin/sh")
self.patch(checkers, "pwd", userdb)
self.assertEqual(checkers._pwdGetByName("alice"), userdb.getpwnam("alice"))
def test_pwdGetByNameWithoutPwd(self):
"""
If the C{pwd} module isn't present, L{_pwdGetByName} returns L{None}.
"""
self.patch(checkers, "pwd", None)
self.assertIsNone(checkers._pwdGetByName("alice"))
def test_shadowGetByName(self):
"""
L{_shadowGetByName} returns a tuple of items from the UNIX /etc/shadow
database if the L{spwd} is present.
"""
userdb = ShadowDatabase()
userdb.addUser("bob", "passphrase", 1, 2, 3, 4, 5, 6, 7)
self.patch(checkers, "spwd", userdb)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.patch(util, "os", self.mockos)
self.assertEqual(checkers._shadowGetByName("bob"), userdb.getspnam("bob"))
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
def test_shadowGetByNameWithoutSpwd(self):
"""
L{_shadowGetByName} returns L{None} if C{spwd} is not present.
"""
self.patch(checkers, "spwd", None)
self.assertIsNone(checkers._shadowGetByName("bob"))
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
class SSHPublicKeyDatabaseTests(TestCase):
"""
Tests for L{SSHPublicKeyDatabase}.
"""
skip = euidSkip or dependencySkip
def setUp(self) -> None:
self.checker = checkers.SSHPublicKeyDatabase()
self.key1 = encodebytes(b"foobar")
self.key2 = encodebytes(b"eggspam")
self.content = b"t1 " + self.key1 + b" foo\nt2 " + self.key2 + b" egg\n"
self.mockos = MockOS()
self.patch(util, "os", self.mockos)
self.path = FilePath(self.mktemp())
assert isinstance(self.path.path, str) # text mode
self.sshDir = self.path.child(".ssh")
self.sshDir.makedirs()
userdb = UserDatabase()
userdb.addUser(
"user",
"password",
1,
2,
"first last",
self.path.path,
"/bin/shell",
)
self.checker._userdb = userdb # type: ignore
def test_deprecated(self):
"""
L{SSHPublicKeyDatabase} is deprecated as of version 15.0
"""
warningsShown = self.flushWarnings(offendingFunctions=[self.setUp])
self.assertEqual(warningsShown[0]["category"], DeprecationWarning)
self.assertEqual(
warningsShown[0]["message"],
"twisted.conch.checkers.SSHPublicKeyDatabase "
"was deprecated in Twisted 15.0.0: Please use "
"twisted.conch.checkers.SSHPublicKeyChecker, "
"initialized with an instance of "
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead.",
)
self.assertEqual(len(warningsShown), 1)
def _testCheckKey(self, filename):
self.sshDir.child(filename).setContent(self.content)
user = UsernamePassword(b"user", b"password")
user.blob = b"foobar"
self.assertTrue(self.checker.checkKey(user))
user.blob = b"eggspam"
self.assertTrue(self.checker.checkKey(user))
user.blob = b"notallowed"
self.assertFalse(self.checker.checkKey(user))
def test_checkKey(self):
"""
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
authorized_keys file and check the keys against that file.
"""
self._testCheckKey("authorized_keys")
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_checkKey2(self):
"""
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
authorized_keys2 file and check the keys against that file.
"""
self._testCheckKey("authorized_keys2")
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_checkKeyAsRoot(self):
"""
If the key file is readable, L{SSHPublicKeyDatabase.checkKey} should
switch its uid/gid to the ones of the authenticated user.
"""
keyFile = self.sshDir.child("authorized_keys")
keyFile.setContent(self.content)
# Fake permission error by changing the mode
keyFile.chmod(0o000)
self.addCleanup(keyFile.chmod, 0o777)
# And restore the right mode when seteuid is called
savedSeteuid = self.mockos.seteuid
def seteuid(euid):
keyFile.chmod(0o777)
return savedSeteuid(euid)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.patch(self.mockos, "seteuid", seteuid)
self.patch(util, "os", self.mockos)
user = UsernamePassword(b"user", b"password")
user.blob = b"foobar"
self.assertTrue(self.checker.checkKey(user))
self.assertEqual(self.mockos.seteuidCalls, [0, 1, 0, 2345])
self.assertEqual(self.mockos.setegidCalls, [2, 1234])
def test_requestAvatarId(self):
"""
L{SSHPublicKeyDatabase.requestAvatarId} should return the avatar id
passed in if its C{_checkKey} method returns True.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, "checkKey", _checkKey)
credentials = SSHPrivateKey(
b"test",
b"ssh-rsa",
keydata.publicRSA_openssh,
b"foo",
keys.Key.fromString(keydata.privateRSA_openssh).sign(b"foo"),
)
d = self.checker.requestAvatarId(credentials)
def _verify(avatarId):
self.assertEqual(avatarId, b"test")
return d.addCallback(_verify)
def test_requestAvatarIdWithoutSignature(self):
"""
L{SSHPublicKeyDatabase.requestAvatarId} should raise L{ValidPublicKey}
if the credentials represent a valid key without a signature. This
tells the user that the key is valid for login, but does not actually
allow that user to do so without a signature.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, "checkKey", _checkKey)
credentials = SSHPrivateKey(
b"test", b"ssh-rsa", keydata.publicRSA_openssh, None, None
)
d = self.checker.requestAvatarId(credentials)
return self.assertFailure(d, ValidPublicKey)
def test_requestAvatarIdInvalidKey(self):
"""
If L{SSHPublicKeyDatabase.checkKey} returns False,
C{_cbRequestAvatarId} should raise L{UnauthorizedLogin}.
"""
def _checkKey(ignored):
return False
self.patch(self.checker, "checkKey", _checkKey)
d = self.checker.requestAvatarId(None)
return self.assertFailure(d, UnauthorizedLogin)
def test_requestAvatarIdInvalidSignature(self):
"""
Valid keys with invalid signatures should cause
L{SSHPublicKeyDatabase.requestAvatarId} to return a {UnauthorizedLogin}
failure
"""
def _checkKey(ignored):
return True
self.patch(self.checker, "checkKey", _checkKey)
credentials = SSHPrivateKey(
b"test",
b"ssh-rsa",
keydata.publicRSA_openssh,
b"foo",
keys.Key.fromString(keydata.privateDSA_openssh).sign(b"foo"),
)
d = self.checker.requestAvatarId(credentials)
return self.assertFailure(d, UnauthorizedLogin)
def test_requestAvatarIdNormalizeException(self):
"""
Exceptions raised while verifying the key should be normalized into an
C{UnauthorizedLogin} failure.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, "checkKey", _checkKey)
credentials = SSHPrivateKey(b"test", None, b"blob", b"sigData", b"sig")
d = self.checker.requestAvatarId(credentials)
def _verifyLoggedException(failure):
errors = self.flushLoggedErrors(keys.BadKeyError)
self.assertEqual(len(errors), 1)
return failure
d.addErrback(_verifyLoggedException)
return self.assertFailure(d, UnauthorizedLogin)
class SSHProtocolCheckerTests(TestCase):
"""
Tests for L{SSHProtocolChecker}.
"""
skip = dependencySkip
def test_registerChecker(self):
"""
L{SSHProcotolChecker.registerChecker} should add the given checker to
the list of registered checkers.
"""
checker = checkers.SSHProtocolChecker()
self.assertEqual(checker.credentialInterfaces, [])
checker.registerChecker(
checkers.SSHPublicKeyDatabase(),
)
self.assertEqual(checker.credentialInterfaces, [ISSHPrivateKey])
self.assertIsInstance(
checker.checkers[ISSHPrivateKey], checkers.SSHPublicKeyDatabase
)
def test_registerCheckerWithInterface(self):
"""
If a specific interface is passed into
L{SSHProtocolChecker.registerChecker}, that interface should be
registered instead of what the checker specifies in
credentialIntefaces.
"""
checker = checkers.SSHProtocolChecker()
self.assertEqual(checker.credentialInterfaces, [])
checker.registerChecker(checkers.SSHPublicKeyDatabase(), IUsernamePassword)
self.assertEqual(checker.credentialInterfaces, [IUsernamePassword])
self.assertIsInstance(
checker.checkers[IUsernamePassword], checkers.SSHPublicKeyDatabase
)
def test_requestAvatarId(self):
"""
L{SSHProtocolChecker.requestAvatarId} should defer to one if its
registered checkers to authenticate a user.
"""
checker = checkers.SSHProtocolChecker()
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
passwordDatabase.addUser(b"test", b"test")
checker.registerChecker(passwordDatabase)
d = checker.requestAvatarId(UsernamePassword(b"test", b"test"))
def _callback(avatarId):
self.assertEqual(avatarId, b"test")
return d.addCallback(_callback)
def test_requestAvatarIdWithNotEnoughAuthentication(self):
"""
If the client indicates that it is never satisfied, by always returning
False from _areDone, then L{SSHProtocolChecker} should raise
L{NotEnoughAuthentication}.
"""
checker = checkers.SSHProtocolChecker()
def _areDone(avatarId):
return False
self.patch(checker, "areDone", _areDone)
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
passwordDatabase.addUser(b"test", b"test")
checker.registerChecker(passwordDatabase)
d = checker.requestAvatarId(UsernamePassword(b"test", b"test"))
return self.assertFailure(d, NotEnoughAuthentication)
def test_requestAvatarIdInvalidCredential(self):
"""
If the passed credentials aren't handled by any registered checker,
L{SSHProtocolChecker} should raise L{UnhandledCredentials}.
"""
checker = checkers.SSHProtocolChecker()
d = checker.requestAvatarId(UsernamePassword(b"test", b"test"))
return self.assertFailure(d, UnhandledCredentials)
def test_areDone(self):
"""
The default L{SSHProcotolChecker.areDone} should simply return True.
"""
self.assertTrue(checkers.SSHProtocolChecker().areDone(None))
class UNIXPasswordDatabaseTests(TestCase):
"""
Tests for L{UNIXPasswordDatabase}.
"""
skip = cryptSkip or dependencySkip
def assertLoggedIn(self, d: Deferred[bytes], username: bytes) -> None:
"""
Assert that the L{Deferred} passed in is called back with the value
'username'. This represents a valid login for this TestCase.
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
"""
self.assertEqual(self.successResultOf(d), username)
def test_defaultCheckers(self):
"""
L{UNIXPasswordDatabase} with no arguments has checks the C{pwd} database
and then the C{spwd} database.
"""
checker = checkers.UNIXPasswordDatabase()
def crypted(username, password):
salt = crypt.crypt(password, username)
crypted = crypt.crypt(password, "$1$" + salt)
return crypted
pwd = UserDatabase()
pwd.addUser(
"alice", crypted("alice", "password"), 1, 2, "foo", "/foo", "/bin/sh"
)
# x and * are convention for "look elsewhere for the password"
pwd.addUser("bob", "x", 1, 2, "bar", "/bar", "/bin/sh")
spwd = ShadowDatabase()
spwd.addUser("alice", "wrong", 1, 2, 3, 4, 5, 6, 7)
spwd.addUser("bob", crypted("bob", "password"), 8, 9, 10, 11, 12, 13, 14)
self.patch(checkers, "pwd", pwd)
self.patch(checkers, "spwd", spwd)
mockos = MockOS()
self.patch(util, "os", mockos)
mockos.euid = 2345
mockos.egid = 1234
cred = UsernamePassword(b"alice", b"password")
self.assertLoggedIn(checker.requestAvatarId(cred), b"alice")
self.assertEqual(mockos.seteuidCalls, [])
self.assertEqual(mockos.setegidCalls, [])
cred.username = b"bob"
self.assertLoggedIn(checker.requestAvatarId(cred), b"bob")
self.assertEqual(mockos.seteuidCalls, [0, 2345])
self.assertEqual(mockos.setegidCalls, [0, 1234])
def assertUnauthorizedLogin(self, d):
"""
Asserts that the L{Deferred} passed in is erred back with an
L{UnauthorizedLogin} L{Failure}. This reprsents an invalid login for
this TestCase.
NOTE: To work, this method's return value must be returned from the
test method, or otherwise hooked up to the test machinery.
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
@type d: L{Deferred}
@rtype: L{None}
"""
self.failureResultOf(d, checkers.UnauthorizedLogin)
def test_passInCheckers(self):
"""
L{UNIXPasswordDatabase} takes a list of functions to check for UNIX
user information.
"""
password = crypt.crypt("secret", "secret")
userdb = UserDatabase()
userdb.addUser("anybody", password, 1, 2, "foo", "/bar", "/bin/sh")
checker = checkers.UNIXPasswordDatabase([userdb.getpwnam])
self.assertLoggedIn(
checker.requestAvatarId(UsernamePassword(b"anybody", b"secret")), b"anybody"
)
def test_verifyPassword(self):
"""
If the encrypted password provided by the getpwnam function is valid
(verified by the L{verifyCryptedPassword} function), we callback the
C{requestAvatarId} L{Deferred} with the username.
"""
def verifyCryptedPassword(crypted, pw):
return crypted == pw
def getpwnam(username):
return [username, username]
self.patch(checkers, "verifyCryptedPassword", verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword(b"username", b"username")
self.assertLoggedIn(checker.requestAvatarId(credential), b"username")
def test_failOnKeyError(self):
"""
If the getpwnam function raises a KeyError, the login fails with an
L{UnauthorizedLogin} exception.
"""
def getpwnam(username):
raise KeyError(username)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword(b"username", b"password")
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
def test_failOnBadPassword(self):
"""
If the verifyCryptedPassword function doesn't verify the password, the
login fails with an L{UnauthorizedLogin} exception.
"""
def verifyCryptedPassword(crypted, pw):
return False
def getpwnam(username):
return [username, b"password"]
self.patch(checkers, "verifyCryptedPassword", verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword(b"username", b"password")
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
def test_loopThroughFunctions(self):
"""
UNIXPasswordDatabase.requestAvatarId loops through each getpwnam
function associated with it and returns a L{Deferred} which fires with
the result of the first one which returns a value other than None.
ones do not verify the password.
"""
def verifyCryptedPassword(crypted, pw):
return crypted == pw
def getpwnam1(username):
return [username, "not the password"]
def getpwnam2(username):
return [username, "password"]
self.patch(checkers, "verifyCryptedPassword", verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam1, getpwnam2])
credential = UsernamePassword(b"username", b"password")
self.assertLoggedIn(checker.requestAvatarId(credential), b"username")
def test_failOnSpecial(self):
"""
If the password returned by any function is C{""}, C{"x"}, or C{"*"} it
is not compared against the supplied password. Instead it is skipped.
"""
pwd = UserDatabase()
pwd.addUser("alice", "", 1, 2, "", "foo", "bar")
pwd.addUser("bob", "x", 1, 2, "", "foo", "bar")
pwd.addUser("carol", "*", 1, 2, "", "foo", "bar")
self.patch(checkers, "pwd", pwd)
checker = checkers.UNIXPasswordDatabase([checkers._pwdGetByName])
cred = UsernamePassword(b"alice", b"")
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
cred = UsernamePassword(b"bob", b"x")
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
cred = UsernamePassword(b"carol", b"*")
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
class AuthorizedKeyFileReaderTests(TestCase):
"""
Tests for L{checkers.readAuthorizedKeyFile}
"""
skip = dependencySkip
def test_ignoresComments(self):
"""
L{checkers.readAuthorizedKeyFile} does not attempt to turn comments
into keys
"""
fileobj = BytesIO(
b"# this comment is ignored\n"
b"this is not\n"
b"# this is again\n"
b"and this is not"
)
result = checkers.readAuthorizedKeyFile(fileobj, lambda x: x)
self.assertEqual([b"this is not", b"and this is not"], list(result))
def test_ignoresLeadingWhitespaceAndEmptyLines(self):
"""
L{checkers.readAuthorizedKeyFile} ignores leading whitespace in
lines, as well as empty lines
"""
fileobj = BytesIO(
b"""
# ignore
not ignored
"""
)
result = checkers.readAuthorizedKeyFile(fileobj, parseKey=lambda x: x)
self.assertEqual([b"not ignored"], list(result))
def test_ignoresUnparsableKeys(self):
"""
L{checkers.readAuthorizedKeyFile} does not raise an exception
when a key fails to parse (raises a
L{twisted.conch.ssh.keys.BadKeyError}), but rather just keeps going
"""
def failOnSome(line):
if line.startswith(b"f"):
raise keys.BadKeyError("failed to parse")
return line
fileobj = BytesIO(b"failed key\ngood key")
result = checkers.readAuthorizedKeyFile(fileobj, parseKey=failOnSome)
self.assertEqual([b"good key"], list(result))
class InMemorySSHKeyDBTests(TestCase):
"""
Tests for L{checkers.InMemorySSHKeyDB}
"""
skip = dependencySkip
def test_implementsInterface(self):
"""
L{checkers.InMemorySSHKeyDB} implements
L{checkers.IAuthorizedKeysDB}
"""
keydb = checkers.InMemorySSHKeyDB({b"alice": [b"key"]})
verifyObject(checkers.IAuthorizedKeysDB, keydb)
def test_noKeysForUnauthorizedUser(self):
"""
If the user is not in the mapping provided to
L{checkers.InMemorySSHKeyDB}, an empty iterator is returned
by L{checkers.InMemorySSHKeyDB.getAuthorizedKeys}
"""
keydb = checkers.InMemorySSHKeyDB({b"alice": [b"keys"]})
self.assertEqual([], list(keydb.getAuthorizedKeys(b"bob")))
def test_allKeysForAuthorizedUser(self):
"""
If the user is in the mapping provided to
L{checkers.InMemorySSHKeyDB}, an iterator with all the keys
is returned by L{checkers.InMemorySSHKeyDB.getAuthorizedKeys}
"""
keydb = checkers.InMemorySSHKeyDB({b"alice": [b"a", b"b"]})
self.assertEqual([b"a", b"b"], list(keydb.getAuthorizedKeys(b"alice")))
class UNIXAuthorizedKeysFilesTests(TestCase):
"""
Tests for L{checkers.UNIXAuthorizedKeysFiles}.
"""
skip = dependencySkip
def setUp(self) -> None:
self.path = FilePath(self.mktemp())
assert isinstance(self.path.path, str)
self.path.makedirs()
self.userdb = UserDatabase()
self.userdb.addUser(
"alice",
"password",
1,
2,
"alice lastname",
self.path.path,
"/bin/shell",
)
self.sshDir = self.path.child(".ssh")
self.sshDir.makedirs()
authorizedKeys = self.sshDir.child("authorized_keys")
authorizedKeys.setContent(b"key 1\nkey 2")
self.expectedKeys = [b"key 1", b"key 2"]
def test_implementsInterface(self):
"""
L{checkers.UNIXAuthorizedKeysFiles} implements
L{checkers.IAuthorizedKeysDB}.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb)
verifyObject(checkers.IAuthorizedKeysDB, keydb)
def test_noKeysForUnauthorizedUser(self):
"""
If the user is not in the user database provided to
L{checkers.UNIXAuthorizedKeysFiles}, an empty iterator is returned
by L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys}.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb, parseKey=lambda x: x)
self.assertEqual([], list(keydb.getAuthorizedKeys(b"bob")))
def test_allKeysInAllAuthorizedFilesForAuthorizedUser(self):
"""
If the user is in the user database provided to
L{checkers.UNIXAuthorizedKeysFiles}, an iterator with all the keys in
C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2} is returned
by L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys}.
"""
self.sshDir.child("authorized_keys2").setContent(b"key 3")
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb, parseKey=lambda x: x)
self.assertEqual(
self.expectedKeys + [b"key 3"], list(keydb.getAuthorizedKeys(b"alice"))
)
def test_ignoresNonexistantFile(self):
"""
L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys} returns only
the keys in C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2}
if they exist.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb, parseKey=lambda x: x)
self.assertEqual(self.expectedKeys, list(keydb.getAuthorizedKeys(b"alice")))
def test_ignoresUnreadableFile(self):
"""
L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys} returns only
the keys in C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2}
if they are readable.
"""
self.sshDir.child("authorized_keys2").makedirs()
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb, parseKey=lambda x: x)
self.assertEqual(self.expectedKeys, list(keydb.getAuthorizedKeys(b"alice")))
_KeyDB = namedtuple("_KeyDB", ["getAuthorizedKeys"])
class _DummyException(Exception):
"""
Fake exception to be used for testing.
"""
pass
class SSHPublicKeyCheckerTests(TestCase):
"""
Tests for L{checkers.SSHPublicKeyChecker}.
"""
skip = dependencySkip
def setUp(self):
self.credentials = SSHPrivateKey(
b"alice",
b"ssh-rsa",
keydata.publicRSA_openssh,
b"foo",
keys.Key.fromString(keydata.privateRSA_openssh).sign(b"foo"),
)
self.keydb = _KeyDB(lambda _: [keys.Key.fromString(keydata.publicRSA_openssh)])
self.checker = checkers.SSHPublicKeyChecker(self.keydb)
def test_credentialsWithoutSignature(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that do not have a signature fails with L{ValidPublicKey}.
"""
self.credentials.signature = None
self.failureResultOf(
self.checker.requestAvatarId(self.credentials), ValidPublicKey
)
def test_credentialsWithBadKey(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that have a bad key fails with L{keys.BadKeyError}.
"""
self.credentials.blob = b""
self.failureResultOf(
self.checker.requestAvatarId(self.credentials), keys.BadKeyError
)
def test_credentialsNoMatchingKey(self):
"""
If L{checkers.IAuthorizedKeysDB.getAuthorizedKeys} returns no keys
that match the credentials,
L{checkers.SSHPublicKeyChecker.requestAvatarId} fails with
L{UnauthorizedLogin}.
"""
self.credentials.blob = keydata.publicDSA_openssh
self.failureResultOf(
self.checker.requestAvatarId(self.credentials), UnauthorizedLogin
)
def test_credentialsInvalidSignature(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that are incorrectly signed fails with
L{UnauthorizedLogin}.
"""
self.credentials.signature = keys.Key.fromString(
keydata.privateDSA_openssh
).sign(b"foo")
self.failureResultOf(
self.checker.requestAvatarId(self.credentials), UnauthorizedLogin
)
def test_failureVerifyingKey(self):
"""
If L{keys.Key.verify} raises an exception,
L{checkers.SSHPublicKeyChecker.requestAvatarId} fails with
L{UnauthorizedLogin}.
"""
def fail(*args, **kwargs):
raise _DummyException()
self.patch(keys.Key, "verify", fail)
self.failureResultOf(
self.checker.requestAvatarId(self.credentials), UnauthorizedLogin
)
self.flushLoggedErrors(_DummyException)
def test_usernameReturnedOnSuccess(self):
"""
L{checker.SSHPublicKeyChecker.requestAvatarId}, if successful,
callbacks with the username.
"""
d = self.checker.requestAvatarId(self.credentials)
self.assertEqual(b"alice", self.successResultOf(d))

View File

@@ -0,0 +1,725 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.scripts.ckeygen}.
"""
from __future__ import annotations
import getpass
import os
import subprocess
import sys
from io import StringIO
from typing import Callable
from typing_extensions import NoReturn
from twisted.conch.test.keydata import (
privateECDSA_openssh,
privateEd25519_openssh_new,
privateRSA_openssh,
privateRSA_openssh_encrypted,
publicRSA_openssh,
)
from twisted.python.filepath import FilePath
from twisted.python.reflect import requireModule
from twisted.trial.unittest import TestCase
if requireModule("cryptography"):
from twisted.conch.scripts.ckeygen import (
_getKeyOrDefault,
_saveKey,
changePassPhrase,
displayPublicKey,
enumrepresentation,
printFingerprint,
)
from twisted.conch.ssh.keys import (
BadFingerPrintFormat,
BadKeyError,
FingerprintFormats,
Key,
)
else:
skip = "cryptography required for twisted.conch.scripts.ckeygen"
def makeGetpass(*passphrases: str) -> Callable[[object], str]:
"""
Return a callable to patch C{getpass.getpass}. Yields a passphrase each
time called. Use case is to provide an old, then new passphrase(s) as if
requested interactively.
@param passphrases: The list of passphrases returned, one per each call.
@return: A callable to patch C{getpass.getpass}.
"""
passphrasesIter = iter(passphrases)
def fakeGetpass(_: object) -> str:
return next(passphrasesIter)
return fakeGetpass
class KeyGenTests(TestCase):
"""
Tests for various functions used to implement the I{ckeygen} script.
"""
def setUp(self) -> None:
"""
Patch C{sys.stdout} so tests can make assertions about what's printed.
"""
self.stdout = StringIO()
self.patch(sys, "stdout", self.stdout)
def _testrun(
self,
keyType: str,
keySize: str | None = None,
privateKeySubtype: str | None = None,
) -> None:
filename = self.mktemp()
args = ["ckeygen", "-t", keyType, "-f", filename, "--no-passphrase"]
if keySize is not None:
args.extend(["-b", keySize])
if privateKeySubtype is not None:
args.extend(["--private-key-subtype", privateKeySubtype])
subprocess.call(args)
privKey = Key.fromFile(filename)
pubKey = Key.fromFile(filename + ".pub")
if keyType == "ecdsa":
self.assertEqual(privKey.type(), "EC")
elif keyType == "ed25519":
self.assertEqual(privKey.type(), "Ed25519")
else:
self.assertEqual(privKey.type(), keyType.upper())
self.assertTrue(pubKey.isPublic())
def test_keygeneration(self) -> None:
self._testrun("ecdsa", "384")
self._testrun("ecdsa", "384", privateKeySubtype="v1")
self._testrun("ecdsa")
self._testrun("ecdsa", privateKeySubtype="v1")
self._testrun("ed25519")
self._testrun("dsa", "2048")
self._testrun("dsa", "2048", privateKeySubtype="v1")
self._testrun("dsa")
self._testrun("dsa", privateKeySubtype="v1")
self._testrun("rsa", "2048")
self._testrun("rsa", "2048", privateKeySubtype="v1")
self._testrun("rsa")
self._testrun("rsa", privateKeySubtype="v1")
def test_runBadKeytype(self) -> None:
filename = self.mktemp()
with self.assertRaises(subprocess.CalledProcessError):
subprocess.check_call(["ckeygen", "-t", "foo", "-f", filename])
def test_enumrepresentation(self) -> None:
"""
L{enumrepresentation} takes a dictionary as input and returns a
dictionary with its attributes changed to enum representation.
"""
options = enumrepresentation({"format": "md5-hex"})
self.assertIs(options["format"], FingerprintFormats.MD5_HEX)
def test_enumrepresentationsha256(self) -> None:
"""
Test for format L{FingerprintFormats.SHA256-BASE64}.
"""
options = enumrepresentation({"format": "sha256-base64"})
self.assertIs(options["format"], FingerprintFormats.SHA256_BASE64)
def test_enumrepresentationBadFormat(self) -> None:
"""
Test for unsupported fingerprint format
"""
with self.assertRaises(BadFingerPrintFormat) as em:
enumrepresentation({"format": "sha-base64"})
self.assertEqual(
"Unsupported fingerprint format: sha-base64", em.exception.args[0]
)
def test_printFingerprint(self) -> None:
"""
L{printFingerprint} writes a line to standard out giving the number of
bits of the key, its fingerprint, and the basename of the file from it
was read.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
printFingerprint({"filename": filename, "format": "md5-hex"})
self.assertEqual(
self.stdout.getvalue(),
"2048 85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da temp\n",
)
def test_printFingerprintsha256(self) -> None:
"""
L{printFigerprint} will print key fingerprint in
L{FingerprintFormats.SHA256-BASE64} format if explicitly specified.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
printFingerprint({"filename": filename, "format": "sha256-base64"})
self.assertEqual(
self.stdout.getvalue(),
"2048 FBTCOoknq0mHy+kpfnY9tDdcAJuWtCpuQMaV3EsvbUI= temp\n",
)
def test_printFingerprintBadFingerPrintFormat(self) -> None:
"""
L{printFigerprint} raises C{keys.BadFingerprintFormat} when unsupported
formats are requested.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
with self.assertRaises(BadFingerPrintFormat) as em:
printFingerprint({"filename": filename, "format": "sha-base64"})
self.assertEqual(
"Unsupported fingerprint format: sha-base64", em.exception.args[0]
)
def test_printFingerprintSuffixAppended(self) -> None:
"""
L{printFingerprint} checks if the filename with the '.pub' suffix
exists in ~/.ssh.
"""
filename = self.mktemp()
FilePath(filename + ".pub").setContent(publicRSA_openssh)
printFingerprint({"filename": filename, "format": "md5-hex"})
self.assertEqual(
self.stdout.getvalue(),
"2048 85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da temp.pub\n",
)
def test_saveKey(self) -> None:
"""
L{_saveKey} writes the private and public parts of a key to two
different files and writes a report of this to standard out.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_rsa").path
key = Key.fromString(privateRSA_openssh)
_saveKey(key, {"filename": filename, "pass": "passphrase", "format": "md5-hex"})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
"85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da\n" % (filename, filename),
)
self.assertEqual(
key.fromString(base.child("id_rsa").getContent(), None, "passphrase"), key
)
self.assertEqual(
Key.fromString(base.child("id_rsa.pub").getContent()), key.public()
)
def test_saveKeyECDSA(self) -> None:
"""
L{_saveKey} writes the private and public parts of a key to two
different files and writes a report of this to standard out.
Test with ECDSA key.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_ecdsa").path
key = Key.fromString(privateECDSA_openssh)
_saveKey(key, {"filename": filename, "pass": "passphrase", "format": "md5-hex"})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
"1e:ab:83:a6:f2:04:22:99:7c:64:14:d2:ab:fa:f5:16\n" % (filename, filename),
)
self.assertEqual(
key.fromString(base.child("id_ecdsa").getContent(), None, "passphrase"), key
)
self.assertEqual(
Key.fromString(base.child("id_ecdsa.pub").getContent()), key.public()
)
def test_saveKeyEd25519(self) -> None:
"""
L{_saveKey} writes the private and public parts of a key to two
different files and writes a report of this to standard out.
Test with Ed25519 key.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_ed25519").path
key = Key.fromString(privateEd25519_openssh_new)
_saveKey(key, {"filename": filename, "pass": "passphrase", "format": "md5-hex"})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
"ab:ee:c8:ed:e5:01:1b:45:b7:8d:b2:f0:8f:61:1c:14\n" % (filename, filename),
)
self.assertEqual(
key.fromString(base.child("id_ed25519").getContent(), None, "passphrase"),
key,
)
self.assertEqual(
Key.fromString(base.child("id_ed25519.pub").getContent()), key.public()
)
def test_saveKeysha256(self) -> None:
"""
L{_saveKey} will generate key fingerprint in
L{FingerprintFormats.SHA256-BASE64} format if explicitly specified.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_rsa").path
key = Key.fromString(privateRSA_openssh)
_saveKey(
key, {"filename": filename, "pass": "passphrase", "format": "sha256-base64"}
)
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=SHA256_BASE64> is:\n"
"FBTCOoknq0mHy+kpfnY9tDdcAJuWtCpuQMaV3EsvbUI=\n" % (filename, filename),
)
self.assertEqual(
key.fromString(base.child("id_rsa").getContent(), None, "passphrase"), key
)
self.assertEqual(
Key.fromString(base.child("id_rsa.pub").getContent()), key.public()
)
def test_saveKeyBadFingerPrintformat(self) -> None:
"""
L{_saveKey} raises C{keys.BadFingerprintFormat} when unsupported
formats are requested.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_rsa").path
key = Key.fromString(privateRSA_openssh)
with self.assertRaises(BadFingerPrintFormat) as em:
_saveKey(
key,
{"filename": filename, "pass": "passphrase", "format": "sha-base64"},
)
self.assertEqual(
"Unsupported fingerprint format: sha-base64", em.exception.args[0]
)
def test_saveKeyEmptyPassphrase(self) -> None:
"""
L{_saveKey} will choose an empty string for the passphrase if
no-passphrase is C{True}.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_rsa").path
key = Key.fromString(privateRSA_openssh)
_saveKey(
key, {"filename": filename, "no-passphrase": True, "format": "md5-hex"}
)
self.assertEqual(
key.fromString(base.child("id_rsa").getContent(), None, b""), key
)
def test_saveKeyECDSAEmptyPassphrase(self) -> None:
"""
L{_saveKey} will choose an empty string for the passphrase if
no-passphrase is C{True}.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_ecdsa").path
key = Key.fromString(privateECDSA_openssh)
_saveKey(
key, {"filename": filename, "no-passphrase": True, "format": "md5-hex"}
)
self.assertEqual(key.fromString(base.child("id_ecdsa").getContent(), None), key)
def test_saveKeyEd25519EmptyPassphrase(self) -> None:
"""
L{_saveKey} will choose an empty string for the passphrase if
no-passphrase is C{True}.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_ed25519").path
key = Key.fromString(privateEd25519_openssh_new)
_saveKey(
key, {"filename": filename, "no-passphrase": True, "format": "md5-hex"}
)
self.assertEqual(
key.fromString(base.child("id_ed25519").getContent(), None), key
)
def test_saveKeyNoFilename(self) -> None:
"""
When no path is specified, it will ask for the path used to store the
key.
"""
base = FilePath(self.mktemp())
base.makedirs()
keyPath = base.child("custom_key").path
input_prompts: list[str] = []
import twisted.conch.scripts.ckeygen
def mock_input(*args: object) -> str:
input_prompts.append("")
return ""
self.patch(twisted.conch.scripts.ckeygen, "_inputSaveFile", lambda _: keyPath)
key = Key.fromString(privateRSA_openssh)
_saveKey(
key,
{"filename": None, "no-passphrase": True, "format": "md5-hex"},
mock_input,
)
persistedKeyContent = base.child("custom_key").getContent()
persistedKey = key.fromString(persistedKeyContent, None, b"")
self.assertEqual(key, persistedKey)
def test_saveKeyFileExists(self) -> None:
"""
When the specified file exists, it will ask the user for confirmation
before overwriting.
"""
def mock_input(*args: object) -> list[str]:
return ["n"]
base = FilePath(self.mktemp())
base.makedirs()
keyPath = base.child("custom_key").path
self.patch(os.path, "exists", lambda _: True)
key = Key.fromString(privateRSA_openssh)
options = {"filename": keyPath, "no-passphrase": True, "format": "md5-hex"}
self.assertRaises(SystemExit, _saveKey, key, options, mock_input)
def test_saveKeySubtypeV1(self) -> None:
"""
L{_saveKey} can be told to write the new private key file in OpenSSH
v1 format.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child("id_rsa").path
key = Key.fromString(privateRSA_openssh)
_saveKey(
key,
{
"filename": filename,
"pass": "passphrase",
"format": "md5-hex",
"private-key-subtype": "v1",
},
)
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
"85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da\n" % (filename, filename),
)
privateKeyContent = base.child("id_rsa").getContent()
self.assertEqual(key.fromString(privateKeyContent, None, "passphrase"), key)
self.assertTrue(
privateKeyContent.startswith(b"-----BEGIN OPENSSH PRIVATE KEY-----\n")
)
self.assertEqual(
Key.fromString(base.child("id_rsa.pub").getContent()), key.public()
)
def test_displayPublicKey(self) -> None:
"""
L{displayPublicKey} prints out the public key associated with a given
private key.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh)
displayPublicKey({"filename": filename})
displayed = self.stdout.getvalue().strip("\n").encode("ascii")
self.assertEqual(displayed, pubKey.toString("openssh"))
def test_displayPublicKeyEncrypted(self) -> None:
"""
L{displayPublicKey} prints out the public key associated with a given
private key using the given passphrase when it's encrypted.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh_encrypted)
displayPublicKey({"filename": filename, "pass": "encrypted"})
displayed = self.stdout.getvalue().strip("\n").encode("ascii")
self.assertEqual(displayed, pubKey.toString("openssh"))
def test_displayPublicKeyEncryptedPassphrasePrompt(self) -> None:
"""
L{displayPublicKey} prints out the public key associated with a given
private key, asking for the passphrase when it's encrypted.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh_encrypted)
self.patch(getpass, "getpass", lambda x: "encrypted")
displayPublicKey({"filename": filename})
displayed = self.stdout.getvalue().strip("\n").encode("ascii")
self.assertEqual(displayed, pubKey.toString("openssh"))
def test_displayPublicKeyWrongPassphrase(self) -> None:
"""
L{displayPublicKey} fails with a L{BadKeyError} when trying to decrypt
an encrypted key with the wrong password.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
self.assertRaises(
BadKeyError, displayPublicKey, {"filename": filename, "pass": "wrong"}
)
def test_changePassphrase(self) -> None:
"""
L{changePassPhrase} allows a user to change the passphrase of a
private key interactively.
"""
oldNewConfirm = makeGetpass("encrypted", "newpass", "newpass")
self.patch(getpass, "getpass", oldNewConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({"filename": filename})
self.assertEqual(
self.stdout.getvalue().strip("\n"),
"Your identification has been saved with the new passphrase.",
)
self.assertNotEqual(
privateRSA_openssh_encrypted, FilePath(filename).getContent()
)
def test_changePassphraseWithOld(self) -> None:
"""
L{changePassPhrase} allows a user to change the passphrase of a
private key, providing the old passphrase and prompting for new one.
"""
newConfirm = makeGetpass("newpass", "newpass")
self.patch(getpass, "getpass", newConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({"filename": filename, "pass": "encrypted"})
self.assertEqual(
self.stdout.getvalue().strip("\n"),
"Your identification has been saved with the new passphrase.",
)
self.assertNotEqual(
privateRSA_openssh_encrypted, FilePath(filename).getContent()
)
def test_changePassphraseWithBoth(self) -> None:
"""
L{changePassPhrase} allows a user to change the passphrase of a private
key by providing both old and new passphrases without prompting.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase(
{"filename": filename, "pass": "encrypted", "newpass": "newencrypt"}
)
self.assertEqual(
self.stdout.getvalue().strip("\n"),
"Your identification has been saved with the new passphrase.",
)
self.assertNotEqual(
privateRSA_openssh_encrypted, FilePath(filename).getContent()
)
def test_changePassphraseWrongPassphrase(self) -> None:
"""
L{changePassPhrase} exits if passed an invalid old passphrase when
trying to change the passphrase of a private key.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
error = self.assertRaises(
SystemExit, changePassPhrase, {"filename": filename, "pass": "wrong"}
)
self.assertEqual(
"Could not change passphrase: old passphrase error", str(error)
)
self.assertEqual(privateRSA_openssh_encrypted, FilePath(filename).getContent())
def test_changePassphraseEmptyGetPass(self) -> None:
"""
L{changePassPhrase} exits if no passphrase is specified for the
C{getpass} call and the key is encrypted.
"""
self.patch(getpass, "getpass", makeGetpass(""))
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
error = self.assertRaises(SystemExit, changePassPhrase, {"filename": filename})
self.assertEqual(
"Could not change passphrase: Passphrase must be provided "
"for an encrypted key",
str(error),
)
self.assertEqual(privateRSA_openssh_encrypted, FilePath(filename).getContent())
def test_changePassphraseBadKey(self) -> None:
"""
L{changePassPhrase} exits if the file specified points to an invalid
key.
"""
filename = self.mktemp()
FilePath(filename).setContent(b"foobar")
error = self.assertRaises(SystemExit, changePassPhrase, {"filename": filename})
expected = "Could not change passphrase: cannot " "guess the type of b'foobar'"
self.assertEqual(expected, str(error))
self.assertEqual(b"foobar", FilePath(filename).getContent())
def test_changePassphraseCreateError(self) -> None:
"""
L{changePassPhrase} doesn't modify the key file if an unexpected error
happens when trying to create the key with the new passphrase.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh)
def toString(*args: object, **kwargs: object) -> NoReturn:
raise RuntimeError("oops")
self.patch(Key, "toString", toString)
error = self.assertRaises(
SystemExit,
changePassPhrase,
{"filename": filename, "newpass": "newencrypt"},
)
self.assertEqual("Could not change passphrase: oops", str(error))
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
def test_changePassphraseEmptyStringError(self) -> None:
"""
L{changePassPhrase} doesn't modify the key file if C{toString} returns
an empty string.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh)
def toString(*args: object, **kwargs: object) -> str:
return ""
self.patch(Key, "toString", toString)
error = self.assertRaises(
SystemExit,
changePassPhrase,
{"filename": filename, "newpass": "newencrypt"},
)
expected = "Could not change passphrase: cannot guess the type of b''"
self.assertEqual(expected, str(error))
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
def test_changePassphrasePublicKey(self) -> None:
"""
L{changePassPhrase} exits when trying to change the passphrase on a
public key, and doesn't change the file.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
error = self.assertRaises(
SystemExit, changePassPhrase, {"filename": filename, "newpass": "pass"}
)
self.assertEqual("Could not change passphrase: key not encrypted", str(error))
self.assertEqual(publicRSA_openssh, FilePath(filename).getContent())
def test_changePassphraseSubtypeV1(self) -> None:
"""
L{changePassPhrase} can be told to write the new private key file in
OpenSSH v1 format.
"""
oldNewConfirm = makeGetpass("encrypted", "newpass", "newpass")
self.patch(getpass, "getpass", oldNewConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({"filename": filename, "private-key-subtype": "v1"})
self.assertEqual(
self.stdout.getvalue().strip("\n"),
"Your identification has been saved with the new passphrase.",
)
privateKeyContent = FilePath(filename).getContent()
self.assertNotEqual(privateRSA_openssh_encrypted, privateKeyContent)
self.assertTrue(
privateKeyContent.startswith(b"-----BEGIN OPENSSH PRIVATE KEY-----\n")
)
def test_useDefaultForKey(self) -> None:
"""
L{options} will default to "~/.ssh/id_rsa" if the user doesn't
specify a key.
"""
input_prompts: list[str] = []
def mock_input(*args: object) -> str:
input_prompts.append("")
return ""
options = {"filename": ""}
filename = _getKeyOrDefault(options, mock_input)
self.assertEqual(
options["filename"],
"",
)
# Resolved path is an RSA key inside .ssh dir.
self.assertTrue(filename.endswith(os.path.join(".ssh", "id_rsa")))
# The user is prompted once to enter the path, since no path was
# provided via CLI.
self.assertEqual(1, len(input_prompts))
self.assertEqual([""], input_prompts)
def test_displayPublicKeyHandleFileNotFound(self) -> None:
"""
Ensure FileNotFoundError is handled, whether the user has supplied
a bad path, or has no key at the default path.
"""
options = {"filename": "/foo/bar"}
exc = self.assertRaises(SystemExit, displayPublicKey, options)
self.assertIn("could not be opened, please specify a file.", exc.args[0])
def test_changePassPhraseHandleFileNotFound(self) -> None:
"""
Ensure FileNotFoundError is handled for an invalid filename.
"""
options = {"filename": "/foo/bar"}
exc = self.assertRaises(SystemExit, changePassPhrase, options)
self.assertIn("could not be opened, please specify a file.", exc.args[0])
def test_printFingerprintHandleFileNotFound(self) -> None:
"""
Ensure FileNotFoundError is handled for an invalid filename.
"""
options = {"filename": "/foo/bar", "format": "md5-hex"}
exc = self.assertRaises(SystemExit, printFingerprint, options)
self.assertIn("could not be opened, please specify a file.", exc.args[0])

View File

@@ -0,0 +1,766 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import os
import socket
import subprocess
import sys
from itertools import count
from zope.interface import implementer
from twisted.conch.error import ConchError
from twisted.conch.test.keydata import privateRSA_openssh, publicRSA_openssh
from twisted.conch.test.test_ssh import ConchTestRealm
from twisted.cred import portal
from twisted.internet import defer, protocol, reactor
from twisted.internet.error import ProcessExitedAlready
from twisted.internet.task import LoopingCall
from twisted.python import filepath, log, runtime
from twisted.python.filepath import FilePath
from twisted.python.procutils import which
from twisted.python.reflect import requireModule
from twisted.test.testutils import HAS_IPV6, skipWithoutIPv6
from twisted.trial.unittest import SkipTest, TestCase
try:
from twisted.conch.test.test_ssh import (
ConchTestServerFactory,
conchTestPublicKeyChecker,
)
except ImportError:
pass
cryptography = requireModule("cryptography")
if cryptography:
from twisted.conch.avatar import ConchUser
from twisted.conch.ssh.session import ISession, SSHSession, wrapProtocol
else:
from twisted.conch.interfaces import ISession
class ConchUser: # type: ignore[no-redef]
pass
try:
from twisted.conch.scripts.conch import SSHSession as _StdioInteractingSession
except ImportError as e:
StdioInteractingSession = None
_reason = str(e)
del e
else:
StdioInteractingSession = _StdioInteractingSession
class FakeStdio:
"""
A fake for testing L{twisted.conch.scripts.conch.SSHSession.eofReceived} and
L{twisted.conch.scripts.cftp.SSHSession.eofReceived}.
@ivar writeConnLost: A flag which records whether L{loserWriteConnection}
has been called.
"""
writeConnLost = False
def loseWriteConnection(self):
"""
Record the call to loseWriteConnection.
"""
self.writeConnLost = True
class StdioInteractingSessionTests(TestCase):
"""
Tests for L{twisted.conch.scripts.conch.SSHSession}.
"""
if StdioInteractingSession is None:
skip = _reason
def test_eofReceived(self):
"""
L{twisted.conch.scripts.conch.SSHSession.eofReceived} loses the
write half of its stdio connection.
"""
stdio = FakeStdio()
channel = StdioInteractingSession()
channel.stdio = stdio
channel.eofReceived()
self.assertTrue(stdio.writeConnLost)
class Echo(protocol.Protocol):
def connectionMade(self):
log.msg("ECHO CONNECTION MADE")
def connectionLost(self, reason):
log.msg("ECHO CONNECTION DONE")
def dataReceived(self, data):
self.transport.write(data)
if b"\n" in data:
self.transport.loseConnection()
class EchoFactory(protocol.Factory):
protocol = Echo
class ConchTestOpenSSHProcess(protocol.ProcessProtocol):
"""
Test protocol for launching an OpenSSH client process.
@ivar deferred: Set by whatever uses this object. Accessed using
L{_getDeferred}, which destroys the value so the Deferred is not
fired twice. Fires when the process is terminated.
"""
deferred = None
buf = b""
problems = b""
def _getDeferred(self):
d, self.deferred = self.deferred, None
return d
def outReceived(self, data):
self.buf += data
def errReceived(self, data):
self.problems += data
def processEnded(self, reason):
"""
Called when the process has ended.
@param reason: a Failure giving the reason for the process' end.
"""
if reason.value.exitCode != 0:
self._getDeferred().errback(
ConchError(
"exit code was not 0: {} ({})".format(
reason.value.exitCode,
self.problems.decode("charmap"),
)
)
)
else:
buf = self.buf.replace(b"\r\n", b"\n")
self._getDeferred().callback(buf)
class ConchTestForwardingProcess(protocol.ProcessProtocol):
"""
Manages a third-party process which launches a server.
Uses L{ConchTestForwardingPort} to connect to the third-party server.
Once L{ConchTestForwardingPort} has disconnected, kill the process and fire
a Deferred with the data received by the L{ConchTestForwardingPort}.
@ivar deferred: Set by whatever uses this object. Accessed using
L{_getDeferred}, which destroys the value so the Deferred is not
fired twice. Fires when the process is terminated.
"""
deferred = None
def __init__(self, port, data):
"""
@type port: L{int}
@param port: The port on which the third-party server is listening.
(it is assumed that the server is running on localhost).
@type data: L{str}
@param data: This is sent to the third-party server. Must end with '\n'
in order to trigger a disconnect.
"""
self.port = port
self.buffer = None
self.data = data
def _getDeferred(self):
d, self.deferred = self.deferred, None
return d
def connectionMade(self):
self._connect()
def _connect(self):
"""
Connect to the server, which is often a third-party process.
Tries to reconnect if it fails because we have no way of determining
exactly when the port becomes available for listening -- we can only
know when the process starts.
"""
cc = protocol.ClientCreator(reactor, ConchTestForwardingPort, self, self.data)
d = cc.connectTCP("127.0.0.1", self.port)
d.addErrback(self._ebConnect)
return d
def _ebConnect(self, f):
reactor.callLater(0.1, self._connect)
def forwardingPortDisconnected(self, buffer):
"""
The network connection has died; save the buffer of output
from the network and attempt to quit the process gracefully,
and then (after the reactor has spun) send it a KILL signal.
"""
self.buffer = buffer
self.transport.write(b"\x03")
self.transport.loseConnection()
reactor.callLater(0, self._reallyDie)
def _reallyDie(self):
try:
self.transport.signalProcess("KILL")
except ProcessExitedAlready:
pass
def processEnded(self, reason):
"""
Fire the Deferred at self.deferred with the data collected
from the L{ConchTestForwardingPort} connection, if any.
"""
self._getDeferred().callback(self.buffer)
class ConchTestForwardingPort(protocol.Protocol):
"""
Connects to server launched by a third-party process (managed by
L{ConchTestForwardingProcess}) sends data, then reports whatever it
received back to the L{ConchTestForwardingProcess} once the connection
is ended.
"""
def __init__(self, protocol, data):
"""
@type protocol: L{ConchTestForwardingProcess}
@param protocol: The L{ProcessProtocol} which made this connection.
@type data: str
@param data: The data to be sent to the third-party server.
"""
self.protocol = protocol
self.data = data
def connectionMade(self):
self.buffer = b""
self.transport.write(self.data)
def dataReceived(self, data):
self.buffer += data
def connectionLost(self, reason):
self.protocol.forwardingPortDisconnected(self.buffer)
def _makeArgs(args, mod="conch"):
start = [
sys.executable,
"-c"
"""
### Twisted Preamble
import sys, os
path = os.path.abspath(sys.argv[0])
while os.path.dirname(path) != path:
if os.path.basename(path).startswith('Twisted'):
sys.path.insert(0, path)
break
path = os.path.dirname(path)
from twisted.conch.scripts.%s import run
run()"""
% mod,
]
madeArgs = []
for arg in start + list(args):
if isinstance(arg, str):
arg = arg.encode("utf-8")
madeArgs.append(arg)
return madeArgs
class ConchServerSetupMixin:
if not cryptography:
skip = "can't run without cryptography"
@staticmethod
def realmFactory():
return ConchTestRealm(b"testuser")
def _createFiles(self):
for f in ["rsa_test", "rsa_test.pub", "kh_test"]:
if os.path.exists(f):
os.remove(f)
with open("rsa_test", "wb") as f:
f.write(privateRSA_openssh)
with open("rsa_test.pub", "wb") as f:
f.write(publicRSA_openssh)
os.chmod("rsa_test", 0o600)
permissions = FilePath("rsa_test").getPermissions()
if permissions.group.read or permissions.other.read:
raise SkipTest(
"private key readable by others despite chmod;"
" possible windows permission issue?"
" see https://tm.tl/9767"
)
with open("kh_test", "wb") as f:
f.write(b"127.0.0.1 " + publicRSA_openssh)
def _getFreePort(self):
s = socket.socket()
s.bind(("", 0))
port = s.getsockname()[1]
s.close()
return port
def _makeConchFactory(self):
"""
Make a L{ConchTestServerFactory}, which allows us to start a
L{ConchTestServer} -- i.e. an actually listening conch.
"""
realm = self.realmFactory()
p = portal.Portal(realm)
p.registerChecker(conchTestPublicKeyChecker())
factory = ConchTestServerFactory()
factory.portal = p
return factory
def setUp(self):
self._createFiles()
self.conchFactory = self._makeConchFactory()
self.conchFactory.expectedLoseConnection = 1
self.conchServer = reactor.listenTCP(
0, self.conchFactory, interface="127.0.0.1"
)
self.echoServer = reactor.listenTCP(0, EchoFactory())
self.echoPort = self.echoServer.getHost().port
if HAS_IPV6:
self.echoServerV6 = reactor.listenTCP(0, EchoFactory(), interface="::1")
self.echoPortV6 = self.echoServerV6.getHost().port
def tearDown(self):
try:
self.conchFactory.proto.done = 1
except AttributeError:
pass
else:
self.conchFactory.proto.transport.loseConnection()
deferreds = [
defer.maybeDeferred(self.conchServer.stopListening),
defer.maybeDeferred(self.echoServer.stopListening),
]
if HAS_IPV6:
deferreds.append(defer.maybeDeferred(self.echoServerV6.stopListening))
return defer.gatherResults(deferreds)
class ForwardingMixin(ConchServerSetupMixin):
"""
Template class for tests of the Conch server's ability to forward arbitrary
protocols over SSH.
These tests are integration tests, not unit tests. They launch a Conch
server, a custom TCP server (just an L{EchoProtocol}) and then call
L{execute}.
L{execute} is implemented by subclasses of L{ForwardingMixin}. It should
cause an SSH client to connect to the Conch server, asking it to forward
data to the custom TCP server.
"""
def test_exec(self):
"""
Test that we can use whatever client to send the command "echo goodbye"
to the Conch server. Make sure we receive "goodbye" back from the
server.
"""
d = self.execute("echo goodbye", ConchTestOpenSSHProcess())
return d.addCallback(self.assertEqual, b"goodbye\n")
def test_localToRemoteForwarding(self):
"""
Test that we can use whatever client to forward a local port to a
specified port on the server.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, b"test\n")
d = self.execute(
"", process, sshArgs="-N -L%i:127.0.0.1:%i" % (localPort, self.echoPort)
)
d.addCallback(self.assertEqual, b"test\n")
return d
def test_remoteToLocalForwarding(self):
"""
Test that we can use whatever client to forward a port from the server
to a port locally.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, b"test\n")
d = self.execute(
"", process, sshArgs="-N -R %i:127.0.0.1:%i" % (localPort, self.echoPort)
)
d.addCallback(self.assertEqual, b"test\n")
return d
# Conventionally there is a separate adapter object which provides ISession for
# the user, but making the user provide ISession directly works too. This isn't
# a full implementation of ISession though, just enough to make these tests
# pass.
@implementer(ISession)
class RekeyAvatar(ConchUser):
"""
This avatar implements a shell which sends 60 numbered lines to whatever
connects to it, then closes the session with a 0 exit status.
60 lines is selected as being enough to send more than 2kB of traffic, the
amount the client is configured to initiate a rekey after.
"""
def __init__(self):
ConchUser.__init__(self)
self.channelLookup[b"session"] = SSHSession
def openShell(self, transport):
"""
Write 60 lines of data to the transport, then exit.
"""
proto = protocol.Protocol()
proto.makeConnection(transport)
transport.makeConnection(wrapProtocol(proto))
# Send enough bytes to the connection so that a rekey is triggered in
# the client.
def write(counter):
i = next(counter)
if i == 60:
call.stop()
transport.session.conn.sendRequest(
transport.session, b"exit-status", b"\x00\x00\x00\x00"
)
transport.loseConnection()
else:
line = "line #%02d\n" % (i,)
line = line.encode("utf-8")
transport.write(line)
# The timing for this loop is an educated guess (and/or the result of
# experimentation) to exercise the case where a packet is generated
# mid-rekey. Since the other side of the connection is (so far) the
# OpenSSH command line client, there's no easy way to determine when the
# rekey has been initiated. If there were, then generating a packet
# immediately at that time would be a better way to test the
# functionality being tested here.
call = LoopingCall(write, count())
call.start(0.01)
def closed(self):
"""
Ignore the close of the session.
"""
def eofReceived(self):
# ISession.eofReceived
pass
def execCommand(self, proto, command):
# ISession.execCommand
pass
def getPty(self, term, windowSize, modes):
# ISession.getPty
pass
def windowChanged(self, newWindowSize):
# ISession.windowChanged
pass
class RekeyRealm:
"""
This realm gives out new L{RekeyAvatar} instances for any avatar request.
"""
def requestAvatar(self, avatarID, mind, *interfaces):
return interfaces[0], RekeyAvatar(), lambda: None
class RekeyTestsMixin(ConchServerSetupMixin):
"""
TestCase mixin which defines tests exercising L{SSHTransportBase}'s handling
of rekeying messages.
"""
realmFactory = RekeyRealm
def test_clientRekey(self):
"""
After a client-initiated rekey is completed, application data continues
to be passed over the SSH connection.
"""
process = ConchTestOpenSSHProcess()
d = self.execute("", process, "-o RekeyLimit=2K")
def finished(result):
expectedResult = "\n".join(["line #%02d" % (i,) for i in range(60)]) + "\n"
expectedResult = expectedResult.encode("utf-8")
self.assertEqual(result, expectedResult)
d.addCallback(finished)
return d
class OpenSSHClientMixin:
if not which("ssh"):
skip = "no ssh command-line client available"
def execute(self, remoteCommand, process, sshArgs=""):
"""
Connects to the SSH server started in L{ConchServerSetupMixin.setUp} by
running the 'ssh' command line tool.
@type remoteCommand: str
@param remoteCommand: The command (with arguments) to run on the
remote end.
@type process: L{ConchTestOpenSSHProcess}
@type sshArgs: str
@param sshArgs: Arguments to pass to the 'ssh' process.
@return: L{defer.Deferred}
"""
process.deferred = defer.Deferred()
# Pass -F /dev/null to avoid the user's configuration file from
# being loaded, as it may contain settings that cause our tests to
# fail or hang.
cmdline = (
(
"ssh -2 -l testuser -p %i "
"-F /dev/null "
"-oIdentitiesOnly=yes "
"-oUserKnownHostsFile=kh_test "
"-oPasswordAuthentication=no "
# Always use the RSA key, since that's the one in kh_test.
"-oHostKeyAlgorithms=ssh-rsa "
"-a "
"-i rsa_test "
)
+ sshArgs
+ " 127.0.0.1 "
+ remoteCommand
)
port = self.conchServer.getHost().port
cmds = (cmdline % port).split()
encodedCmds = []
for cmd in cmds:
encodedCmds.append(cmd.encode("utf-8"))
reactor.spawnProcess(process, which("ssh")[0], encodedCmds)
return process.deferred
class OpenSSHKeyExchangeTests(ConchServerSetupMixin, OpenSSHClientMixin, TestCase):
"""
Tests L{SSHTransportBase}'s key exchange algorithm compatibility with
OpenSSH.
"""
def assertExecuteWithKexAlgorithm(self, keyExchangeAlgo):
"""
Call execute() method of L{OpenSSHClientMixin} with an ssh option that
forces the exclusive use of the key exchange algorithm specified by
keyExchangeAlgo
@type keyExchangeAlgo: L{str}
@param keyExchangeAlgo: The key exchange algorithm to use
@return: L{defer.Deferred}
"""
kexAlgorithms = []
try:
output = subprocess.check_output(
[which("ssh")[0], "-Q", "kex"], stderr=subprocess.STDOUT
)
if not isinstance(output, str):
output = output.decode("utf-8")
kexAlgorithms = output.split()
except BaseException:
pass
if keyExchangeAlgo not in kexAlgorithms:
raise SkipTest(f"{keyExchangeAlgo} not supported by ssh client")
d = self.execute(
"echo hello",
ConchTestOpenSSHProcess(),
"-oKexAlgorithms=" + keyExchangeAlgo,
)
return d.addCallback(self.assertEqual, b"hello\n")
def test_ECDHSHA256(self):
"""
The ecdh-sha2-nistp256 key exchange algorithm is compatible with
OpenSSH
"""
return self.assertExecuteWithKexAlgorithm("ecdh-sha2-nistp256")
def test_ECDHSHA384(self):
"""
The ecdh-sha2-nistp384 key exchange algorithm is compatible with
OpenSSH
"""
return self.assertExecuteWithKexAlgorithm("ecdh-sha2-nistp384")
def test_ECDHSHA521(self):
"""
The ecdh-sha2-nistp521 key exchange algorithm is compatible with
OpenSSH
"""
return self.assertExecuteWithKexAlgorithm("ecdh-sha2-nistp521")
def test_DH_GROUP14(self):
"""
The diffie-hellman-group14-sha1 key exchange algorithm is compatible
with OpenSSH.
"""
return self.assertExecuteWithKexAlgorithm("diffie-hellman-group14-sha1")
def test_DH_GROUP_EXCHANGE_SHA1(self):
"""
The diffie-hellman-group-exchange-sha1 key exchange algorithm is
compatible with OpenSSH.
"""
return self.assertExecuteWithKexAlgorithm("diffie-hellman-group-exchange-sha1")
def test_DH_GROUP_EXCHANGE_SHA256(self):
"""
The diffie-hellman-group-exchange-sha256 key exchange algorithm is
compatible with OpenSSH.
"""
return self.assertExecuteWithKexAlgorithm(
"diffie-hellman-group-exchange-sha256"
)
def test_unsupported_algorithm(self):
"""
The list of key exchange algorithms supported
by OpenSSH client is obtained with C{ssh -Q kex}.
"""
self.assertRaises(
SkipTest, self.assertExecuteWithKexAlgorithm, "unsupported-algorithm"
)
class OpenSSHClientForwardingTests(ForwardingMixin, OpenSSHClientMixin, TestCase):
"""
Connection forwarding tests run against the OpenSSL command line client.
"""
@skipWithoutIPv6
def test_localToRemoteForwardingV6(self):
"""
Forwarding of arbitrary IPv6 TCP connections via SSH.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, b"test\n")
d = self.execute(
"", process, sshArgs="-N -L%i:[::1]:%i" % (localPort, self.echoPortV6)
)
d.addCallback(self.assertEqual, b"test\n")
return d
class OpenSSHClientRekeyTests(RekeyTestsMixin, OpenSSHClientMixin, TestCase):
"""
Rekeying tests run against the OpenSSL command line client.
"""
class CmdLineClientTests(ForwardingMixin, TestCase):
"""
Connection forwarding tests run against the Conch command line client.
"""
if runtime.platformType == "win32":
skip = "can't run cmdline client on win32"
def execute(self, remoteCommand, process, sshArgs="", conchArgs=None):
"""
As for L{OpenSSHClientTestCase.execute}, except it runs the 'conch'
command line tool, not 'ssh'.
"""
if conchArgs is None:
conchArgs = []
process.deferred = defer.Deferred()
port = self.conchServer.getHost().port
cmd = (
"-p {} -l testuser "
"--known-hosts kh_test "
"--user-authentications publickey "
"-a "
"-i rsa_test "
"-v ".format(port) + sshArgs + " 127.0.0.1 " + remoteCommand
)
cmds = _makeArgs(conchArgs + cmd.split())
env = os.environ.copy()
env["PYTHONPATH"] = os.pathsep.join(sys.path)
encodedCmds = []
encodedEnv = {}
for cmd in cmds:
if isinstance(cmd, str):
cmd = cmd.encode("utf-8")
encodedCmds.append(cmd)
for var in env:
val = env[var]
if isinstance(var, str):
var = var.encode("utf-8")
if isinstance(val, str):
val = val.encode("utf-8")
encodedEnv[var] = val
reactor.spawnProcess(process, sys.executable, encodedCmds, env=encodedEnv)
return process.deferred
def test_runWithLogFile(self):
"""
It can store logs to a local file.
"""
def cb_check_log(result):
logContent = logPath.getContent()
self.assertIn(b"Log opened.", logContent)
logPath = filepath.FilePath(self.mktemp())
d = self.execute(
remoteCommand="echo goodbye",
process=ConchTestOpenSSHProcess(),
conchArgs=[
"--log",
"--logfile",
logPath.path,
"--host-key-algorithms",
"ssh-rsa",
],
)
d.addCallback(self.assertEqual, b"goodbye\n")
d.addCallback(cb_check_log)
return d
def test_runWithNoHostAlgorithmsSpecified(self):
"""
Do not use --host-key-algorithms flag on command line.
"""
d = self.execute(
remoteCommand="echo goodbye", process=ConchTestOpenSSHProcess()
)
d.addCallback(self.assertEqual, b"goodbye\n")
return d

View File

@@ -0,0 +1,846 @@
# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
# See LICENSE for details
"""
This module tests twisted.conch.ssh.connection.
"""
import struct
from twisted.conch.ssh import channel
from twisted.conch.test import test_userauth
from twisted.python.reflect import requireModule
from twisted.trial import unittest
cryptography = requireModule("cryptography")
from twisted.conch import error
if cryptography:
from twisted.conch.ssh import common, connection
else:
class connection: # type: ignore[no-redef]
class SSHConnection:
pass
class TestChannel(channel.SSHChannel):
"""
A mocked-up version of twisted.conch.ssh.channel.SSHChannel.
@ivar gotOpen: True if channelOpen has been called.
@type gotOpen: L{bool}
@ivar specificData: the specific channel open data passed to channelOpen.
@type specificData: L{bytes}
@ivar openFailureReason: the reason passed to openFailed.
@type openFailed: C{error.ConchError}
@ivar inBuffer: a C{list} of strings received by the channel.
@type inBuffer: C{list}
@ivar extBuffer: a C{list} of 2-tuples (type, extended data) of received by
the channel.
@type extBuffer: C{list}
@ivar numberRequests: the number of requests that have been made to this
channel.
@type numberRequests: L{int}
@ivar gotEOF: True if the other side sent EOF.
@type gotEOF: L{bool}
@ivar gotOneClose: True if the other side closed the connection.
@type gotOneClose: L{bool}
@ivar gotClosed: True if the channel is closed.
@type gotClosed: L{bool}
"""
name = b"TestChannel"
gotOpen = False
gotClosed = False
def logPrefix(self):
return "TestChannel %i" % self.id
def channelOpen(self, specificData):
"""
The channel is open. Set up the instance variables.
"""
self.gotOpen = True
self.specificData = specificData
self.inBuffer = []
self.extBuffer = []
self.numberRequests = 0
self.gotEOF = False
self.gotOneClose = False
self.gotClosed = False
def openFailed(self, reason):
"""
Opening the channel failed. Store the reason why.
"""
self.openFailureReason = reason
def request_test(self, data):
"""
A test request. Return True if data is 'data'.
@type data: L{bytes}
"""
self.numberRequests += 1
return data == b"data"
def dataReceived(self, data):
"""
Data was received. Store it in the buffer.
"""
self.inBuffer.append(data)
def extReceived(self, code, data):
"""
Extended data was received. Store it in the buffer.
"""
self.extBuffer.append((code, data))
def eofReceived(self):
"""
EOF was received. Remember it.
"""
self.gotEOF = True
def closeReceived(self):
"""
Close was received. Remember it.
"""
self.gotOneClose = True
def closed(self):
"""
The channel is closed. Rembember it.
"""
self.gotClosed = True
class TestAvatar:
"""
A mocked-up version of twisted.conch.avatar.ConchUser
"""
_ARGS_ERROR_CODE = 123
def lookupChannel(self, channelType, windowSize, maxPacket, data):
"""
The server wants us to return a channel. If the requested channel is
our TestChannel, return it, otherwise return None.
"""
if channelType == TestChannel.name:
return TestChannel(
remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
data=data,
avatar=self,
)
elif channelType == b"conch-error-args":
# Raise a ConchError with backwards arguments to make sure the
# connection fixes it for us. This case should be deprecated and
# deleted eventually, but only after all of Conch gets the argument
# order right.
raise error.ConchError(self._ARGS_ERROR_CODE, "error args in wrong order")
def gotGlobalRequest(self, requestType, data):
"""
The client has made a global request. If the global request is
'TestGlobal', return True. If the global request is 'TestData',
return True and the request-specific data we received. Otherwise,
return False.
"""
if requestType == b"TestGlobal":
return True
elif requestType == b"TestData":
return True, data
else:
return False
class TestConnection(connection.SSHConnection):
"""
A subclass of SSHConnection for testing.
@ivar channel: the current channel.
@type channel. C{TestChannel}
"""
if not cryptography:
skip = "Cannot run without cryptography"
def logPrefix(self):
return "TestConnection"
def global_TestGlobal(self, data):
"""
The other side made the 'TestGlobal' global request. Return True.
"""
return True
def global_Test_Data(self, data):
"""
The other side made the 'Test-Data' global request. Return True and
the data we received.
"""
return True, data
def channel_TestChannel(self, windowSize, maxPacket, data):
"""
The other side is requesting the TestChannel. Create a C{TestChannel}
instance, store it, and return it.
"""
self.channel = TestChannel(
remoteWindow=windowSize, remoteMaxPacket=maxPacket, data=data
)
return self.channel
def channel_ErrorChannel(self, windowSize, maxPacket, data):
"""
The other side is requesting the ErrorChannel. Raise an exception.
"""
raise AssertionError("no such thing")
class ConnectionTests(unittest.TestCase):
if not cryptography:
skip = "Cannot run without cryptography"
def setUp(self):
self.transport = test_userauth.FakeTransport(None)
self.transport.avatar = TestAvatar()
self.conn = TestConnection()
self.conn.transport = self.transport
self.conn.serviceStarted()
def _openChannel(self, channel):
"""
Open the channel with the default connection.
"""
self.conn.openChannel(channel)
self.transport.packets = self.transport.packets[:-1]
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(
struct.pack(">2L", channel.id, 255) + b"\x00\x02\x00\x00\x00\x00\x80\x00"
)
def tearDown(self):
self.conn.serviceStopped()
def test_linkAvatar(self):
"""
Test that the connection links itself to the avatar in the
transport.
"""
self.assertIs(self.transport.avatar.conn, self.conn)
def test_serviceStopped(self):
"""
Test that serviceStopped() closes any open channels.
"""
channel1 = TestChannel()
channel2 = TestChannel()
self.conn.openChannel(channel1)
self.conn.openChannel(channel2)
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(b"\x00\x00\x00\x00" * 4)
self.assertTrue(channel1.gotOpen)
self.assertFalse(channel1.gotClosed)
self.assertFalse(channel2.gotOpen)
self.assertFalse(channel2.gotClosed)
self.conn.serviceStopped()
self.assertTrue(channel1.gotClosed)
self.assertFalse(channel2.gotOpen)
self.assertFalse(channel2.gotClosed)
from twisted.internet.error import ConnectionLost
self.assertIsInstance(channel2.openFailureReason, ConnectionLost)
def test_GLOBAL_REQUEST(self):
"""
Test that global request packets are dispatched to the global_*
methods and the return values are translated into success or failure
messages.
"""
self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestGlobal") + b"\xff")
self.assertEqual(
self.transport.packets, [(connection.MSG_REQUEST_SUCCESS, b"")]
)
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestData") + b"\xff" + b"test data")
self.assertEqual(
self.transport.packets, [(connection.MSG_REQUEST_SUCCESS, b"test data")]
)
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestBad") + b"\xff")
self.assertEqual(
self.transport.packets, [(connection.MSG_REQUEST_FAILURE, b"")]
)
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestGlobal") + b"\x00")
self.assertEqual(self.transport.packets, [])
def test_REQUEST_SUCCESS(self):
"""
Test that global request success packets cause the Deferred to be
called back.
"""
d = self.conn.sendGlobalRequest(b"request", b"data", True)
self.conn.ssh_REQUEST_SUCCESS(b"data")
def check(data):
self.assertEqual(data, b"data")
d.addCallback(check)
d.addErrback(self.fail)
return d
def test_REQUEST_FAILURE(self):
"""
Test that global request failure packets cause the Deferred to be
erred back.
"""
d = self.conn.sendGlobalRequest(b"request", b"data", True)
self.conn.ssh_REQUEST_FAILURE(b"data")
def check(f):
self.assertEqual(f.value.data, b"data")
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_CHANNEL_OPEN(self):
"""
Test that open channel packets cause a channel to be created and
opened or a failure message to be returned.
"""
del self.transport.avatar
self.conn.ssh_CHANNEL_OPEN(common.NS(b"TestChannel") + b"\x00\x00\x00\x01" * 4)
self.assertTrue(self.conn.channel.gotOpen)
self.assertEqual(self.conn.channel.conn, self.conn)
self.assertEqual(self.conn.channel.data, b"\x00\x00\x00\x01")
self.assertEqual(self.conn.channel.specificData, b"\x00\x00\x00\x01")
self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_OPEN_CONFIRMATION,
b"\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00"
b"\x00\x00\x80\x00",
)
],
)
self.transport.packets = []
self.conn.ssh_CHANNEL_OPEN(common.NS(b"BadChannel") + b"\x00\x00\x00\x02" * 4)
self.flushLoggedErrors()
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_OPEN_FAILURE,
b"\x00\x00\x00\x02\x00\x00\x00\x03"
+ common.NS(b"unknown channel")
+ common.NS(b""),
)
],
)
self.transport.packets = []
self.conn.ssh_CHANNEL_OPEN(common.NS(b"ErrorChannel") + b"\x00\x00\x00\x02" * 4)
self.flushLoggedErrors()
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_OPEN_FAILURE,
b"\x00\x00\x00\x02\x00\x00\x00\x02"
+ common.NS(b"unknown failure")
+ common.NS(b""),
)
],
)
def _lookupChannelErrorTest(self, code):
"""
Deliver a request for a channel open which will result in an exception
being raised during channel lookup. Assert that an error response is
delivered as a result.
"""
self.transport.avatar._ARGS_ERROR_CODE = code
self.conn.ssh_CHANNEL_OPEN(
common.NS(b"conch-error-args") + b"\x00\x00\x00\x01" * 4
)
errors = self.flushLoggedErrors(error.ConchError)
self.assertEqual(len(errors), 1, f"Expected one error, got: {errors!r}")
self.assertEqual(errors[0].value.args, (123, "error args in wrong order"))
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_OPEN_FAILURE,
# The response includes some bytes which identifying the
# associated request, as well as the error code (7b in hex) and
# the error message.
b"\x00\x00\x00\x01\x00\x00\x00\x7b"
+ common.NS(b"error args in wrong order")
+ common.NS(b""),
)
],
)
def test_lookupChannelError(self):
"""
If a C{lookupChannel} implementation raises L{error.ConchError} with the
arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
sent in response to the message.
This is a temporary work-around until L{error.ConchError} is given
better attributes and all of the Conch code starts constructing
instances of it properly. Eventually this functionality should be
deprecated and then removed.
"""
self._lookupChannelErrorTest(123)
def test_CHANNEL_OPEN_CONFIRMATION(self):
"""
Test that channel open confirmation packets cause the channel to be
notified that it's open.
"""
channel = TestChannel()
self.conn.openChannel(channel)
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(b"\x00\x00\x00\x00" * 5)
self.assertEqual(channel.remoteWindowLeft, 0)
self.assertEqual(channel.remoteMaxPacket, 0)
self.assertEqual(channel.specificData, b"\x00\x00\x00\x00")
self.assertEqual(self.conn.channelsToRemoteChannel[channel], 0)
self.assertEqual(self.conn.localToRemoteChannel[0], 0)
def test_CHANNEL_OPEN_FAILURE(self):
"""
Test that channel open failure packets cause the channel to be
notified that its opening failed.
"""
channel = TestChannel()
self.conn.openChannel(channel)
self.conn.ssh_CHANNEL_OPEN_FAILURE(
b"\x00\x00\x00\x00\x00\x00\x00" b"\x01" + common.NS(b"failure!")
)
self.assertEqual(channel.openFailureReason.args, (b"failure!", 1))
self.assertIsNone(self.conn.channels.get(channel))
def test_CHANNEL_WINDOW_ADJUST(self):
"""
Test that channel window adjust messages add bytes to the channel
window.
"""
channel = TestChannel()
self._openChannel(channel)
oldWindowSize = channel.remoteWindowLeft
self.conn.ssh_CHANNEL_WINDOW_ADJUST(b"\x00\x00\x00\x00\x00\x00\x00" b"\x01")
self.assertEqual(channel.remoteWindowLeft, oldWindowSize + 1)
def test_CHANNEL_DATA(self):
"""
Test that channel data messages are passed up to the channel, or
cause the channel to be closed if the data is too large.
"""
channel = TestChannel(localWindow=6, localMaxPacket=5)
self._openChannel(channel)
self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x00" + common.NS(b"data"))
self.assertEqual(channel.inBuffer, [b"data"])
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_WINDOW_ADJUST,
b"\x00\x00\x00\xff" b"\x00\x00\x00\x04",
)
],
)
self.transport.packets = []
longData = b"a" * (channel.localWindowLeft + 1)
self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x00" + common.NS(longData))
self.assertEqual(channel.inBuffer, [b"data"])
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
)
channel = TestChannel()
self._openChannel(channel)
bigData = b"a" * (channel.localMaxPacket + 1)
self.transport.packets = []
self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x01" + common.NS(bigData))
self.assertEqual(channel.inBuffer, [])
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
)
def test_CHANNEL_EXTENDED_DATA(self):
"""
Test that channel extended data messages are passed up to the channel,
or cause the channel to be closed if they're too big.
"""
channel = TestChannel(localWindow=6, localMaxPacket=5)
self._openChannel(channel)
self.conn.ssh_CHANNEL_EXTENDED_DATA(
b"\x00\x00\x00\x00\x00\x00\x00" b"\x00" + common.NS(b"data")
)
self.assertEqual(channel.extBuffer, [(0, b"data")])
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_WINDOW_ADJUST,
b"\x00\x00\x00\xff" b"\x00\x00\x00\x04",
)
],
)
self.transport.packets = []
longData = b"a" * (channel.localWindowLeft + 1)
self.conn.ssh_CHANNEL_EXTENDED_DATA(
b"\x00\x00\x00\x00\x00\x00\x00" b"\x00" + common.NS(longData)
)
self.assertEqual(channel.extBuffer, [(0, b"data")])
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
)
channel = TestChannel()
self._openChannel(channel)
bigData = b"a" * (channel.localMaxPacket + 1)
self.transport.packets = []
self.conn.ssh_CHANNEL_EXTENDED_DATA(
b"\x00\x00\x00\x01\x00\x00\x00" b"\x00" + common.NS(bigData)
)
self.assertEqual(channel.extBuffer, [])
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
)
def test_CHANNEL_EOF(self):
"""
Test that channel eof messages are passed up to the channel.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.ssh_CHANNEL_EOF(b"\x00\x00\x00\x00")
self.assertTrue(channel.gotEOF)
def test_CHANNEL_CLOSE(self):
"""
Test that channel close messages are passed up to the channel. Also,
test that channel.close() is called if both sides are closed when this
message is received.
"""
channel = TestChannel()
self._openChannel(channel)
self.assertTrue(channel.gotOpen)
self.assertFalse(channel.gotOneClose)
self.assertFalse(channel.gotClosed)
self.conn.sendClose(channel)
self.conn.ssh_CHANNEL_CLOSE(b"\x00\x00\x00\x00")
self.assertTrue(channel.gotOneClose)
self.assertTrue(channel.gotClosed)
def test_CHANNEL_REQUEST_success(self):
"""
Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.ssh_CHANNEL_REQUEST(
b"\x00\x00\x00\x00" + common.NS(b"test") + b"\x00"
)
self.assertEqual(channel.numberRequests, 1)
d = self.conn.ssh_CHANNEL_REQUEST(
b"\x00\x00\x00\x00" + common.NS(b"test") + b"\xff" + b"data"
)
def check(result):
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_SUCCESS, b"\x00\x00\x00\xff")],
)
d.addCallback(check)
return d
def test_CHANNEL_REQUEST_failure(self):
"""
Test that channel requests that fail send MSG_CHANNEL_FAILURE.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.ssh_CHANNEL_REQUEST(
b"\x00\x00\x00\x00" + common.NS(b"test") + b"\xff"
)
def check(result):
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_FAILURE, b"\x00\x00\x00\xff")],
)
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_CHANNEL_REQUEST_SUCCESS(self):
"""
Test that channel request success messages cause the Deferred to be
called back.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, b"test", b"data", True)
self.conn.ssh_CHANNEL_SUCCESS(b"\x00\x00\x00\x00")
def check(result):
self.assertTrue(result)
return d
def test_CHANNEL_REQUEST_FAILURE(self):
"""
Test that channel request failure messages cause the Deferred to be
erred back.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, b"test", b"", True)
self.conn.ssh_CHANNEL_FAILURE(b"\x00\x00\x00\x00")
def check(result):
self.assertEqual(result.value.value, "channel request failed")
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_sendGlobalRequest(self):
"""
Test that global request messages are sent in the right format.
"""
d = self.conn.sendGlobalRequest(b"wantReply", b"data", True)
# must be added to prevent errbacking during teardown
d.addErrback(lambda failure: None)
self.conn.sendGlobalRequest(b"noReply", b"", False)
self.assertEqual(
self.transport.packets,
[
(connection.MSG_GLOBAL_REQUEST, common.NS(b"wantReply") + b"\xffdata"),
(connection.MSG_GLOBAL_REQUEST, common.NS(b"noReply") + b"\x00"),
],
)
self.assertEqual(self.conn.deferreds, {"global": [d]})
def test_openChannel(self):
"""
Test that open channel messages are sent in the right format.
"""
channel = TestChannel()
self.conn.openChannel(channel, b"aaaa")
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_OPEN,
common.NS(b"TestChannel")
+ b"\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa",
)
],
)
self.assertEqual(channel.id, 0)
self.assertEqual(self.conn.localChannelID, 1)
def test_sendRequest(self):
"""
Test that channel request messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, b"test", b"test", True)
# needed to prevent errbacks during teardown.
d.addErrback(lambda failure: None)
self.conn.sendRequest(channel, b"test2", b"", False)
channel.localClosed = True # emulate sending a close message
self.conn.sendRequest(channel, b"test3", b"", True)
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_REQUEST,
b"\x00\x00\x00\xff" + common.NS(b"test") + b"\x01test",
),
(
connection.MSG_CHANNEL_REQUEST,
b"\x00\x00\x00\xff" + common.NS(b"test2") + b"\x00",
),
],
)
self.assertEqual(self.conn.deferreds[0], [d])
def test_adjustWindow(self):
"""
Test that channel window adjust messages cause bytes to be added
to the window.
"""
channel = TestChannel(localWindow=5)
self._openChannel(channel)
channel.localWindowLeft = 0
self.conn.adjustWindow(channel, 1)
self.assertEqual(channel.localWindowLeft, 1)
channel.localClosed = True
self.conn.adjustWindow(channel, 2)
self.assertEqual(channel.localWindowLeft, 1)
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_WINDOW_ADJUST,
b"\x00\x00\x00\xff" b"\x00\x00\x00\x01",
)
],
)
def test_sendData(self):
"""
Test that channel data messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendData(channel, b"a")
channel.localClosed = True
self.conn.sendData(channel, b"b")
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_DATA, b"\x00\x00\x00\xff" + common.NS(b"a"))],
)
def test_sendExtendedData(self):
"""
Test that channel extended data messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendExtendedData(channel, 1, b"test")
channel.localClosed = True
self.conn.sendExtendedData(channel, 2, b"test2")
self.assertEqual(
self.transport.packets,
[
(
connection.MSG_CHANNEL_EXTENDED_DATA,
b"\x00\x00\x00\xff" + b"\x00\x00\x00\x01" + common.NS(b"test"),
)
],
)
def test_sendEOF(self):
"""
Test that channel EOF messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendEOF(channel)
self.assertEqual(
self.transport.packets, [(connection.MSG_CHANNEL_EOF, b"\x00\x00\x00\xff")]
)
channel.localClosed = True
self.conn.sendEOF(channel)
self.assertEqual(
self.transport.packets, [(connection.MSG_CHANNEL_EOF, b"\x00\x00\x00\xff")]
)
def test_sendClose(self):
"""
Test that channel close messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendClose(channel)
self.assertTrue(channel.localClosed)
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
)
self.conn.sendClose(channel)
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
)
channel2 = TestChannel()
self._openChannel(channel2)
self.assertTrue(channel2.gotOpen)
self.assertFalse(channel2.gotClosed)
channel2.remoteClosed = True
self.conn.sendClose(channel2)
self.assertTrue(channel2.gotClosed)
def test_getChannelWithAvatar(self):
"""
Test that getChannel dispatches to the avatar when an avatar is
present. Correct functioning without the avatar is verified in
test_CHANNEL_OPEN.
"""
channel = self.conn.getChannel(b"TestChannel", 50, 30, b"data")
self.assertEqual(channel.data, b"data")
self.assertEqual(channel.remoteWindowLeft, 50)
self.assertEqual(channel.remoteMaxPacket, 30)
self.assertRaises(
error.ConchError, self.conn.getChannel, b"BadChannel", 50, 30, b"data"
)
def test_gotGlobalRequestWithoutAvatar(self):
"""
Test that gotGlobalRequests dispatches to global_* without an avatar.
"""
del self.transport.avatar
self.assertTrue(self.conn.gotGlobalRequest(b"TestGlobal", b"data"))
self.assertEqual(
self.conn.gotGlobalRequest(b"Test-Data", b"data"), (True, b"data")
)
self.assertFalse(self.conn.gotGlobalRequest(b"BadGlobal", b"data"))
def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
"""
Whenever an SSH channel gets closed any Deferred that was returned by a
sendRequest() on its parent connection must be errbacked.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, b"dummyrequest", b"dummydata", wantReply=1)
d = self.assertFailure(d, error.ConchError)
self.conn.channelClosed(channel)
return d
class CleanConnectionShutdownTests(unittest.TestCase):
"""
Check whether correct cleanup is performed on connection shutdown.
"""
if not cryptography:
skip = "Cannot run without cryptography"
def setUp(self):
self.transport = test_userauth.FakeTransport(None)
self.transport.avatar = TestAvatar()
self.conn = TestConnection()
self.conn.transport = self.transport
def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
"""
Once the service is stopped any leftover global deferred returned by
a sendGlobalRequest() call must be errbacked.
"""
self.conn.serviceStarted()
d = self.conn.sendGlobalRequest(b"dummyrequest", b"dummydata", wantReply=1)
d = self.assertFailure(d, error.ConchError)
self.conn.serviceStopped()
return d

View File

@@ -0,0 +1,326 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.client.default}.
"""
import sys
from unittest import skipIf
from twisted.conch.error import ConchError
from twisted.conch.test import keydata
from twisted.internet.testing import StringTransport
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
from twisted.python.reflect import requireModule
from twisted.python.runtime import platform
from twisted.trial.unittest import TestCase
doSkip = False
skipReason = ""
if requireModule("cryptography"):
from twisted.conch.client import default
from twisted.conch.client.agent import SSHAgentClient
from twisted.conch.client.default import SSHUserAuthClient
from twisted.conch.client.options import ConchOptions
from twisted.conch.ssh.keys import Key
else:
doSkip = True
skipReason = "cryptography required for twisted.conch.client.default."
skip = skipReason # no SSL available, skip the entire module
if platform.isWindows():
doSkip = True
skipReason = (
"genericAnswers and getPassword does not work on Windows."
" Should be fixed as part of fixing bug 6409 and 6410"
)
if not sys.stdin.isatty():
doSkip = True
skipReason = "sys.stdin is not an interactive tty"
if not sys.stdout.isatty():
doSkip = True
skipReason = "sys.stdout is not an interactive tty"
class SSHUserAuthClientTests(TestCase):
"""
Tests for L{SSHUserAuthClient}.
@type rsaPublic: L{Key}
@ivar rsaPublic: A public RSA key.
"""
def setUp(self):
self.rsaPublic = Key.fromString(keydata.publicRSA_openssh)
self.tmpdir = FilePath(self.mktemp())
self.tmpdir.makedirs()
self.rsaFile = self.tmpdir.child("id_rsa")
self.rsaFile.setContent(keydata.privateRSA_openssh)
self.tmpdir.child("id_rsa.pub").setContent(keydata.publicRSA_openssh)
def test_signDataWithAgent(self):
"""
When connected to an agent, L{SSHUserAuthClient} can use it to
request signatures of particular data with a particular L{Key}.
"""
client = SSHUserAuthClient(b"user", ConchOptions(), None)
agent = SSHAgentClient()
transport = StringTransport()
agent.makeConnection(transport)
client.keyAgent = agent
cleartext = b"Sign here"
client.signData(self.rsaPublic, cleartext)
self.assertEqual(
transport.value(),
b"\x00\x00\x01\x2d\r\x00\x00\x01\x17"
+ self.rsaPublic.blob()
+ b"\x00\x00\x00\t"
+ cleartext
+ b"\x00\x00\x00\x00",
)
def test_agentGetPublicKey(self):
"""
L{SSHUserAuthClient} looks up public keys from the agent using the
L{SSHAgentClient} class. That L{SSHAgentClient.getPublicKey} returns a
L{Key} object with one of the public keys in the agent. If no more
keys are present, it returns L{None}.
"""
agent = SSHAgentClient()
agent.blobs = [self.rsaPublic.blob()]
key = agent.getPublicKey()
self.assertTrue(key.isPublic())
self.assertEqual(key, self.rsaPublic)
self.assertIsNone(agent.getPublicKey())
def test_getPublicKeyFromFile(self):
"""
L{SSHUserAuthClient.getPublicKey()} is able to get a public key from
the first file described by its options' C{identitys} list, and return
the corresponding public L{Key} object.
"""
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient(b"user", options, None)
key = client.getPublicKey()
self.assertTrue(key.isPublic())
self.assertEqual(key, self.rsaPublic)
def test_getPublicKeyAgentFallback(self):
"""
If an agent is present, but doesn't return a key,
L{SSHUserAuthClient.getPublicKey} continue with the normal key lookup.
"""
options = ConchOptions()
options.identitys = [self.rsaFile.path]
agent = SSHAgentClient()
client = SSHUserAuthClient(b"user", options, None)
client.keyAgent = agent
key = client.getPublicKey()
self.assertTrue(key.isPublic())
self.assertEqual(key, self.rsaPublic)
def test_getPublicKeyBadKeyError(self):
"""
If L{keys.Key.fromFile} raises a L{keys.BadKeyError}, the
L{SSHUserAuthClient.getPublicKey} tries again to get a public key by
calling itself recursively.
"""
options = ConchOptions()
self.tmpdir.child("id_dsa.pub").setContent(keydata.publicDSA_openssh)
dsaFile = self.tmpdir.child("id_dsa")
dsaFile.setContent(keydata.privateDSA_openssh)
options.identitys = [self.rsaFile.path, dsaFile.path]
self.tmpdir.child("id_rsa.pub").setContent(b"not a key!")
client = SSHUserAuthClient(b"user", options, None)
key = client.getPublicKey()
self.assertTrue(key.isPublic())
self.assertEqual(key, Key.fromString(keydata.publicDSA_openssh))
self.assertEqual(client.usedFiles, [self.rsaFile.path, dsaFile.path])
def test_getPrivateKey(self):
"""
L{SSHUserAuthClient.getPrivateKey} will load a private key from the
last used file populated by L{SSHUserAuthClient.getPublicKey}, and
return a L{Deferred} which fires with the corresponding private L{Key}.
"""
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient(b"user", options, None)
# Populate the list of used files
client.getPublicKey()
def _cbGetPrivateKey(key):
self.assertFalse(key.isPublic())
self.assertEqual(key, rsaPrivate)
return client.getPrivateKey().addCallback(_cbGetPrivateKey)
def test_getPrivateKeyPassphrase(self):
"""
L{SSHUserAuthClient} can get a private key from a file, and return a
Deferred called back with a private L{Key} object, even if the key is
encrypted.
"""
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
passphrase = b"this is the passphrase"
self.rsaFile.setContent(rsaPrivate.toString("openssh", passphrase=passphrase))
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient(b"user", options, None)
# Populate the list of used files
client.getPublicKey()
def _getPassword(prompt):
self.assertEqual(
prompt, f"Enter passphrase for key '{self.rsaFile.path}': "
)
return nativeString(passphrase)
def _cbGetPrivateKey(key):
self.assertFalse(key.isPublic())
self.assertEqual(key, rsaPrivate)
self.patch(client, "_getPassword", _getPassword)
return client.getPrivateKey().addCallback(_cbGetPrivateKey)
@skipIf(doSkip, skipReason)
def test_getPassword(self):
"""
Get the password using
L{twisted.conch.client.default.SSHUserAuthClient.getPassword}
"""
class FakeTransport:
def __init__(self, host):
self.transport = self
self.host = host
def getPeer(self):
return self
options = ConchOptions()
client = SSHUserAuthClient(b"user", options, None)
client.transport = FakeTransport("127.0.0.1")
def getpass(prompt):
self.assertEqual(prompt, "user@127.0.0.1's password: ")
return "bad password"
self.patch(default.getpass, "getpass", getpass)
d = client.getPassword()
d.addCallback(self.assertEqual, b"bad password")
return d
@skipIf(doSkip, skipReason)
def test_getPasswordPrompt(self):
"""
Get the password using
L{twisted.conch.client.default.SSHUserAuthClient.getPassword}
using a different prompt.
"""
options = ConchOptions()
client = SSHUserAuthClient(b"user", options, None)
prompt = b"Give up your password"
def getpass(p):
self.assertEqual(p, nativeString(prompt))
return "bad password"
self.patch(default.getpass, "getpass", getpass)
d = client.getPassword(prompt)
d.addCallback(self.assertEqual, b"bad password")
return d
@skipIf(doSkip, skipReason)
def test_getPasswordConchError(self):
"""
Get the password using
L{twisted.conch.client.default.SSHUserAuthClient.getPassword}
and trigger a {twisted.conch.error import ConchError}.
"""
options = ConchOptions()
client = SSHUserAuthClient(b"user", options, None)
def getpass(prompt):
raise KeyboardInterrupt("User pressed CTRL-C")
self.patch(default.getpass, "getpass", getpass)
stdout, stdin = sys.stdout, sys.stdin
d = client.getPassword(b"?")
@d.addErrback
def check_sys(fail):
self.assertEqual([stdout, stdin], [sys.stdout, sys.stdin])
return fail
self.assertFailure(d, ConchError)
@skipIf(doSkip, skipReason)
def test_getGenericAnswers(self):
"""
L{twisted.conch.client.default.SSHUserAuthClient.getGenericAnswers}
"""
options = ConchOptions()
client = SSHUserAuthClient(b"user", options, None)
def getpass(prompt):
self.assertEqual(prompt, "pass prompt")
return "getpass"
self.patch(default.getpass, "getpass", getpass)
def raw_input(prompt):
self.assertEqual(prompt, "raw_input prompt")
return "raw_input"
self.patch(default, "_input", raw_input)
d = client.getGenericAnswers(
b"Name",
b"Instruction",
[(b"pass prompt", False), (b"raw_input prompt", True)],
)
d.addCallback(self.assertListEqual, ["getpass", "raw_input"])
return d
class ConchOptionsParsing(TestCase):
"""
Options parsing.
"""
def test_macs(self):
"""
Specify MAC algorithms.
"""
opts = ConchOptions()
e = self.assertRaises(SystemExit, opts.opt_macs, "invalid-mac")
self.assertIn("Unknown mac type", e.code)
opts = ConchOptions()
opts.opt_macs("hmac-sha2-512")
self.assertEqual(opts["macs"], [b"hmac-sha2-512"])
opts.opt_macs(b"hmac-sha2-512")
self.assertEqual(opts["macs"], [b"hmac-sha2-512"])
opts.opt_macs("hmac-sha2-256,hmac-sha1,hmac-md5")
self.assertEqual(opts["macs"], [b"hmac-sha2-256", b"hmac-sha1", b"hmac-md5"])
def test_host_key_algorithms(self):
"""
Specify host key algorithms.
"""
opts = ConchOptions()
e = self.assertRaises(SystemExit, opts.opt_host_key_algorithms, "invalid-key")
self.assertIn("Unknown host key type", e.code)
opts = ConchOptions()
opts.opt_host_key_algorithms("ssh-rsa")
self.assertEqual(opts["host-key-algorithms"], [b"ssh-rsa"])
opts.opt_host_key_algorithms(b"ssh-dss")
self.assertEqual(opts["host-key-algorithms"], [b"ssh-dss"])
opts.opt_host_key_algorithms("ssh-rsa,ssh-dss")
self.assertEqual(opts["host-key-algorithms"], [b"ssh-rsa", b"ssh-dss"])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,964 @@
# -*- test-case-name: twisted.conch.test.test_filetransfer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE file for details.
"""
Tests for L{twisted.conch.ssh.filetransfer}.
"""
import os
import re
import struct
from unittest import skipIf
from hamcrest import assert_that, equal_to
from twisted.internet import defer
from twisted.internet.error import ConnectionLost
from twisted.internet.testing import StringTransport
from twisted.protocols import loopback
from twisted.python import components
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
try:
from twisted.conch import unix as _unix
except ImportError:
unix = None
else:
unix = _unix
try:
from twisted.conch.unix import (
SFTPServerForUnixConchUser as _SFTPServerForUnixConchUser,
)
except ImportError:
SFTPServerForUnixConchUser = None
else:
SFTPServerForUnixConchUser = _SFTPServerForUnixConchUser
try:
import cryptography as _cryptography
except ImportError:
cryptography = None
else:
cryptography = _cryptography
try:
from twisted.conch.avatar import ConchUser as _ConchUser
except ImportError:
ConchUser = object
else:
ConchUser = _ConchUser # type: ignore[misc]
try:
from twisted.conch.ssh import common, connection, filetransfer, session
except ImportError:
pass
class TestAvatar(ConchUser):
def __init__(self):
ConchUser.__init__(self)
self.channelLookup[b"session"] = session.SSHSession
self.subsystemLookup[b"sftp"] = filetransfer.FileTransferServer
def _runAsUser(self, f, *args, **kw):
try:
f = iter(f)
except TypeError:
f = [(f, args, kw)]
for i in f:
func = i[0]
args = len(i) > 1 and i[1] or ()
kw = len(i) > 2 and i[2] or {}
r = func(*args, **kw)
return r
class FileTransferTestAvatar(TestAvatar):
def __init__(self, homeDir):
TestAvatar.__init__(self)
self.homeDir = homeDir
def getHomeDir(self):
return FilePath(os.getcwd()).preauthChild(self.homeDir.path)
class ConchSessionForTestAvatar:
def __init__(self, avatar):
self.avatar = avatar
if SFTPServerForUnixConchUser is None:
# unix should either be a fully working module, or None. I'm not sure
# how this happens, but on win32 it does. Try to cope. --spiv.
import warnings
warnings.warn(
(
"twisted.conch.unix imported %r, "
"but doesn't define SFTPServerForUnixConchUser'"
)
% (unix,)
)
else:
class FileTransferForTestAvatar(SFTPServerForUnixConchUser): # type: ignore[misc,valid-type]
def gotVersion(self, version, otherExt):
return {b"conchTest": b"ext data"}
def extendedRequest(self, extName, extData):
if extName == b"testExtendedRequest":
return b"bar"
raise NotImplementedError
components.registerAdapter(
FileTransferForTestAvatar, TestAvatar, filetransfer.ISFTPServer
)
class SFTPTestBase(TestCase):
def setUp(self):
self.testDir = FilePath(self.mktemp())
# Give the testDir another level so we can safely "cd .." from it in
# tests.
self.testDir = self.testDir.child("extra")
self.testDir.child("testDirectory").makedirs(True)
with self.testDir.child("testfile1").open(mode="wb") as f:
f.write(b"a" * 10 + b"b" * 10)
with open("/dev/urandom", "rb") as f2:
f.write(f2.read(1024 * 64)) # random data
self.testDir.child("testfile1").chmod(0o644)
with self.testDir.child("testRemoveFile").open(mode="wb") as f:
f.write(b"a")
with self.testDir.child("testRenameFile").open(mode="wb") as f:
f.write(b"a")
with self.testDir.child(".testHiddenFile").open(mode="wb") as f:
f.write(b"a")
@skipIf(not unix, "can't run on non-posix computers")
class OurServerOurClientTests(SFTPTestBase):
def setUp(self):
SFTPTestBase.setUp(self)
self.avatar = FileTransferTestAvatar(self.testDir)
self.server = filetransfer.FileTransferServer(avatar=self.avatar)
clientTransport = loopback.LoopbackRelay(self.server)
self.client = filetransfer.FileTransferClient()
self._serverVersion = None
self._extData = None
def _(serverVersion, extData):
self._serverVersion = serverVersion
self._extData = extData
self.client.gotServerVersion = _
serverTransport = loopback.LoopbackRelay(self.client)
self.client.makeConnection(clientTransport)
self.server.makeConnection(serverTransport)
self.clientTransport = clientTransport
self.serverTransport = serverTransport
self._emptyBuffers()
def _emptyBuffers(self):
while self.serverTransport.buffer or self.clientTransport.buffer:
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
def tearDown(self):
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
def test_serverVersion(self):
self.assertEqual(self._serverVersion, 3)
self.assertEqual(self._extData, {b"conchTest": b"ext data"})
def test_interface_implementation(self):
"""
It implements the ISFTPServer interface.
"""
self.assertTrue(
filetransfer.ISFTPServer.providedBy(self.server.client),
f"ISFTPServer not provided by {self.server.client!r}",
)
def test_openedFileClosedWithConnection(self):
"""
A file opened with C{openFile} is closed when the connection is lost.
"""
d = self.client.openFile(
b"testfile1", filetransfer.FXF_READ | filetransfer.FXF_WRITE, {}
)
self._emptyBuffers()
oldClose = os.close
closed = []
def close(fd):
closed.append(fd)
oldClose(fd)
self.patch(os, "close", close)
def _fileOpened(openFile):
fd = self.server.openFiles[openFile.handle[4:]].fd
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
self.assertEqual(self.server.openFiles, {})
self.assertIn(fd, closed)
d.addCallback(_fileOpened)
return d
def test_openedDirectoryClosedWithConnection(self):
"""
A directory opened with C{openDirectory} is close when the connection
is lost.
"""
d = self.client.openDirectory("")
self._emptyBuffers()
def _getFiles(openDir):
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
self.assertEqual(self.server.openDirs, {})
d.addCallback(_getFiles)
return d
def test_openFileIO(self):
d = self.client.openFile(
b"testfile1", filetransfer.FXF_READ | filetransfer.FXF_WRITE, {}
)
self._emptyBuffers()
def _fileOpened(openFile):
self.assertEqual(openFile, filetransfer.ISFTPFile(openFile))
d = _readChunk(openFile)
d.addCallback(_writeChunk, openFile)
return d
def _readChunk(openFile):
d = openFile.readChunk(0, 20)
self._emptyBuffers()
d.addCallback(self.assertEqual, b"a" * 10 + b"b" * 10)
return d
def _writeChunk(_, openFile):
d = openFile.writeChunk(20, b"c" * 10)
self._emptyBuffers()
d.addCallback(_readChunk2, openFile)
return d
def _readChunk2(_, openFile):
d = openFile.readChunk(0, 30)
self._emptyBuffers()
d.addCallback(self.assertEqual, b"a" * 10 + b"b" * 10 + b"c" * 10)
return d
d.addCallback(_fileOpened)
return d
def test_closedFileGetAttrs(self):
d = self.client.openFile(
b"testfile1", filetransfer.FXF_READ | filetransfer.FXF_WRITE, {}
)
self._emptyBuffers()
def _getAttrs(_, openFile):
d = openFile.getAttrs()
self._emptyBuffers()
return d
def _err(f):
self.flushLoggedErrors()
return f
def _close(openFile):
d = openFile.close()
self._emptyBuffers()
d.addCallback(_getAttrs, openFile)
d.addErrback(_err)
return self.assertFailure(d, filetransfer.SFTPError)
d.addCallback(_close)
return d
def test_openFileAttributes(self):
d = self.client.openFile(
b"testfile1", filetransfer.FXF_READ | filetransfer.FXF_WRITE, {}
)
self._emptyBuffers()
def _getAttrs(openFile):
d = openFile.getAttrs()
self._emptyBuffers()
d.addCallback(_getAttrs2)
return d
def _getAttrs2(attrs1):
d = self.client.getAttrs(b"testfile1")
self._emptyBuffers()
d.addCallback(self.assertEqual, attrs1)
return d
return d.addCallback(_getAttrs)
def test_openFileSetAttrs(self):
# XXX test setAttrs
# Ok, how about this for a start? It caught a bug :) -- spiv.
d = self.client.openFile(
b"testfile1", filetransfer.FXF_READ | filetransfer.FXF_WRITE, {}
)
self._emptyBuffers()
def _getAttrs(openFile):
d = openFile.getAttrs()
self._emptyBuffers()
d.addCallback(_setAttrs)
return d
def _setAttrs(attrs):
attrs["atime"] = 0
d = self.client.setAttrs(b"testfile1", attrs)
self._emptyBuffers()
d.addCallback(_getAttrs2)
d.addCallback(self.assertEqual, attrs)
return d
def _getAttrs2(_):
d = self.client.getAttrs(b"testfile1")
self._emptyBuffers()
return d
d.addCallback(_getAttrs)
return d
def test_openFileExtendedAttributes(self):
"""
Check that L{filetransfer.FileTransferClient.openFile} can send
extended attributes, that should be extracted server side. By default,
they are ignored, so we just verify they are correctly parsed.
"""
savedAttributes = {}
oldOpenFile = self.server.client.openFile
def openFile(filename, flags, attrs):
savedAttributes.update(attrs)
return oldOpenFile(filename, flags, attrs)
self.server.client.openFile = openFile
d = self.client.openFile(
b"testfile1",
filetransfer.FXF_READ | filetransfer.FXF_WRITE,
{"ext_foo": b"bar"},
)
self._emptyBuffers()
def check(ign):
self.assertEqual(savedAttributes, {"ext_foo": b"bar"})
return d.addCallback(check)
def test_removeFile(self):
d = self.client.getAttrs(b"testRemoveFile")
self._emptyBuffers()
def _removeFile(ignored):
d = self.client.removeFile(b"testRemoveFile")
self._emptyBuffers()
return d
d.addCallback(_removeFile)
d.addCallback(_removeFile)
return self.assertFailure(d, filetransfer.SFTPError)
def test_renameFile(self):
d = self.client.getAttrs(b"testRenameFile")
self._emptyBuffers()
def _rename(attrs):
d = self.client.renameFile(b"testRenameFile", b"testRenamedFile")
self._emptyBuffers()
d.addCallback(_testRenamed, attrs)
return d
def _testRenamed(_, attrs):
d = self.client.getAttrs(b"testRenamedFile")
self._emptyBuffers()
d.addCallback(self.assertEqual, attrs)
return d.addCallback(_rename)
def test_directoryBad(self):
d = self.client.getAttrs(b"testMakeDirectory")
self._emptyBuffers()
return self.assertFailure(d, filetransfer.SFTPError)
def test_directoryCreation(self):
d = self.client.makeDirectory(b"testMakeDirectory", {})
self._emptyBuffers()
def _getAttrs(_):
d = self.client.getAttrs(b"testMakeDirectory")
self._emptyBuffers()
return d
# XXX not until version 4/5
# self.assertEqual(filetransfer.FILEXFER_TYPE_DIRECTORY&attrs['type'],
# filetransfer.FILEXFER_TYPE_DIRECTORY)
def _removeDirectory(_):
d = self.client.removeDirectory(b"testMakeDirectory")
self._emptyBuffers()
return d
d.addCallback(_getAttrs)
d.addCallback(_removeDirectory)
d.addCallback(_getAttrs)
return self.assertFailure(d, filetransfer.SFTPError)
def test_openDirectory(self):
d = self.client.openDirectory(b"")
self._emptyBuffers()
files = []
def _getFiles(openDir):
def append(f):
files.append(f)
return openDir
d = defer.maybeDeferred(openDir.next)
self._emptyBuffers()
d.addCallback(append)
d.addCallback(_getFiles)
d.addErrback(_close, openDir)
return d
def _checkFiles(ignored):
fs = list(list(zip(*files))[0])
fs.sort()
self.assertEqual(
fs,
[
b".testHiddenFile",
b"testDirectory",
b"testRemoveFile",
b"testRenameFile",
b"testfile1",
],
)
def _close(_, openDir):
d = openDir.close()
self._emptyBuffers()
return d
d.addCallback(_getFiles)
d.addCallback(_checkFiles)
return d
def test_linkDoesntExist(self):
d = self.client.getAttrs(b"testLink")
self._emptyBuffers()
return self.assertFailure(d, filetransfer.SFTPError)
def test_linkSharesAttrs(self):
d = self.client.makeLink(b"testLink", b"testfile1")
self._emptyBuffers()
def _getFirstAttrs(_):
d = self.client.getAttrs(b"testLink", 1)
self._emptyBuffers()
return d
def _getSecondAttrs(firstAttrs):
d = self.client.getAttrs(b"testfile1")
self._emptyBuffers()
d.addCallback(self.assertEqual, firstAttrs)
return d
d.addCallback(_getFirstAttrs)
return d.addCallback(_getSecondAttrs)
def test_linkPath(self):
d = self.client.makeLink(b"testLink", b"testfile1")
self._emptyBuffers()
def _readLink(_):
d = self.client.readLink(b"testLink")
self._emptyBuffers()
testFile = FilePath(os.getcwd()).preauthChild(self.testDir.path)
testFile = testFile.child("testfile1")
d.addCallback(self.assertEqual, testFile.path)
return d
def _realPath(_):
d = self.client.realPath(b"testLink")
self._emptyBuffers()
testLink = FilePath(os.getcwd()).preauthChild(self.testDir.path)
testLink = testLink.child("testfile1")
d.addCallback(self.assertEqual, testLink.path)
return d
d.addCallback(_readLink)
d.addCallback(_realPath)
return d
def test_extendedRequest(self):
d = self.client.extendedRequest(b"testExtendedRequest", b"foo")
self._emptyBuffers()
d.addCallback(self.assertEqual, b"bar")
d.addCallback(self._cbTestExtendedRequest)
return d
def _cbTestExtendedRequest(self, ignored):
d = self.client.extendedRequest(b"testBadRequest", b"")
self._emptyBuffers()
return self.assertFailure(d, NotImplementedError)
@defer.inlineCallbacks
def test_openDirectoryIteratorDeprecated(self):
"""
Using client.openDirectory as an iterator is deprecated.
"""
d = self.client.openDirectory(b"")
self._emptyBuffers()
openDir = yield d
oneFile = openDir.next()
self._emptyBuffers()
yield oneFile
warnings = self.flushWarnings()
message = (
"Using twisted.conch.ssh.filetransfer.ClientDirectory"
" as an iterator was deprecated in Twisted 18.9.0."
)
self.assertEqual(1, len(warnings))
self.assertEqual(DeprecationWarning, warnings[0]["category"])
self.assertEqual(message, warnings[0]["message"])
@defer.inlineCallbacks
def test_closedConnectionCancelsRequests(self):
"""
If there are requests outstanding when the connection
is closed for any reason, they should fail.
"""
d = self.client.openFile(b"testfile1", filetransfer.FXF_READ, {})
self._emptyBuffers()
fh = yield d
# Intercept the handling of the read request on the server side
gotReadRequest = []
def _slowRead(offset, length):
self.assertEqual(gotReadRequest, [])
d = defer.Deferred()
gotReadRequest.append(offset)
return d
[serverSideFh] = self.server.openFiles.values()
serverSideFh.readChunk = _slowRead
del serverSideFh
# Make a read request, dropping the connection before the reply
# is sent
d = fh.readChunk(100, 200)
self._emptyBuffers()
self.assertEqual(len(gotReadRequest), 1)
self.assertNoResult(d)
# Lost connection should cause an errback
self.serverTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
self._emptyBuffers()
self.assertFalse(self.client.connected)
self.failureResultOf(d, ConnectionLost)
# Further attempts to use the filetransfer session should fail
# immediately
d = fh.getAttrs()
self.failureResultOf(d, ConnectionLost)
class FakeConn:
def sendClose(self, channel):
pass
@skipIf(not unix, "can't run on non-posix computers")
class FileTransferCloseTests(TestCase):
def setUp(self):
self.avatar = TestAvatar()
def buildServerConnection(self):
# make a server connection
conn = connection.SSHConnection()
# server connections have a 'self.transport.avatar'.
class DummyTransport:
def __init__(self):
self.transport = self
def sendPacket(self, kind, data):
pass
def logPrefix(self):
return "dummy transport"
conn.transport = DummyTransport()
conn.transport.avatar = self.avatar
return conn
def interceptConnectionLost(self, sftpServer):
self.connectionLostFired = False
origConnectionLost = sftpServer.connectionLost
def connectionLost(reason):
self.connectionLostFired = True
origConnectionLost(reason)
sftpServer.connectionLost = connectionLost
def assertSFTPConnectionLost(self):
self.assertTrue(
self.connectionLostFired, "sftpServer's connectionLost was not called"
)
def test_sessionClose(self):
"""
Closing a session should notify an SFTP subsystem launched by that
session.
"""
# make a session
testSession = session.SSHSession(conn=FakeConn(), avatar=self.avatar)
# start an SFTP subsystem on the session
testSession.request_subsystem(common.NS(b"sftp"))
sftpServer = testSession.client.transport.proto
# intercept connectionLost so we can check that it's called
self.interceptConnectionLost(sftpServer)
# close session
testSession.closeReceived()
self.assertSFTPConnectionLost()
def test_clientClosesChannelOnConnnection(self):
"""
A client sending CHANNEL_CLOSE should trigger closeReceived on the
associated channel instance.
"""
conn = self.buildServerConnection()
# somehow get a session
packet = common.NS(b"session") + struct.pack(">L", 0) * 3
conn.ssh_CHANNEL_OPEN(packet)
sessionChannel = conn.channels[0]
sessionChannel.request_subsystem(common.NS(b"sftp"))
sftpServer = sessionChannel.client.transport.proto
self.interceptConnectionLost(sftpServer)
# intercept closeReceived
self.interceptConnectionLost(sftpServer)
# close the connection
conn.ssh_CHANNEL_CLOSE(struct.pack(">L", 0))
self.assertSFTPConnectionLost()
def test_stopConnectionServiceClosesChannel(self):
"""
Closing an SSH connection should close all sessions within it.
"""
conn = self.buildServerConnection()
# somehow get a session
packet = common.NS(b"session") + struct.pack(">L", 0) * 3
conn.ssh_CHANNEL_OPEN(packet)
sessionChannel = conn.channels[0]
sessionChannel.request_subsystem(common.NS(b"sftp"))
sftpServer = sessionChannel.client.transport.proto
self.interceptConnectionLost(sftpServer)
# close the connection
conn.serviceStopped()
self.assertSFTPConnectionLost()
@skipIf(not cryptography, "Cannot run without cryptography")
class ConstantsTests(TestCase):
"""
Tests for the constants used by the SFTP protocol implementation.
@ivar filexferSpecExcerpts: Excerpts from the
draft-ietf-secsh-filexfer-02.txt (draft) specification of the SFTP
protocol. There are more recent drafts of the specification, but this
one describes version 3, which is what conch (and OpenSSH) implements.
"""
filexferSpecExcerpts = [
"""
The following values are defined for packet types.
#define SSH_FXP_INIT 1
#define SSH_FXP_VERSION 2
#define SSH_FXP_OPEN 3
#define SSH_FXP_CLOSE 4
#define SSH_FXP_READ 5
#define SSH_FXP_WRITE 6
#define SSH_FXP_LSTAT 7
#define SSH_FXP_FSTAT 8
#define SSH_FXP_SETSTAT 9
#define SSH_FXP_FSETSTAT 10
#define SSH_FXP_OPENDIR 11
#define SSH_FXP_READDIR 12
#define SSH_FXP_REMOVE 13
#define SSH_FXP_MKDIR 14
#define SSH_FXP_RMDIR 15
#define SSH_FXP_REALPATH 16
#define SSH_FXP_STAT 17
#define SSH_FXP_RENAME 18
#define SSH_FXP_READLINK 19
#define SSH_FXP_SYMLINK 20
#define SSH_FXP_STATUS 101
#define SSH_FXP_HANDLE 102
#define SSH_FXP_DATA 103
#define SSH_FXP_NAME 104
#define SSH_FXP_ATTRS 105
#define SSH_FXP_EXTENDED 200
#define SSH_FXP_EXTENDED_REPLY 201
Additional packet types should only be defined if the protocol
version number (see Section ``Protocol Initialization'') is
incremented, and their use MUST be negotiated using the version
number. However, the SSH_FXP_EXTENDED and SSH_FXP_EXTENDED_REPLY
packets can be used to implement vendor-specific extensions. See
Section ``Vendor-Specific-Extensions'' for more details.
""",
"""
The flags bits are defined to have the following values:
#define SSH_FILEXFER_ATTR_SIZE 0x00000001
#define SSH_FILEXFER_ATTR_UIDGID 0x00000002
#define SSH_FILEXFER_ATTR_PERMISSIONS 0x00000004
#define SSH_FILEXFER_ATTR_ACMODTIME 0x00000008
#define SSH_FILEXFER_ATTR_EXTENDED 0x80000000
""",
"""
The `pflags' field is a bitmask. The following bits have been
defined.
#define SSH_FXF_READ 0x00000001
#define SSH_FXF_WRITE 0x00000002
#define SSH_FXF_APPEND 0x00000004
#define SSH_FXF_CREAT 0x00000008
#define SSH_FXF_TRUNC 0x00000010
#define SSH_FXF_EXCL 0x00000020
""",
"""
Currently, the following values are defined (other values may be
defined by future versions of this protocol):
#define SSH_FX_OK 0
#define SSH_FX_EOF 1
#define SSH_FX_NO_SUCH_FILE 2
#define SSH_FX_PERMISSION_DENIED 3
#define SSH_FX_FAILURE 4
#define SSH_FX_BAD_MESSAGE 5
#define SSH_FX_NO_CONNECTION 6
#define SSH_FX_CONNECTION_LOST 7
#define SSH_FX_OP_UNSUPPORTED 8
""",
]
def test_constantsAgainstSpec(self):
"""
The constants used by the SFTP protocol implementation match those
found by searching through the spec.
"""
constants = {}
for excerpt in self.filexferSpecExcerpts:
for line in excerpt.splitlines():
m = re.match(r"^\s*#define SSH_([A-Z_]+)\s+([0-9x]*)\s*$", line)
if m:
constants[m.group(1)] = int(m.group(2), 0)
self.assertTrue(
len(constants) > 0, "No constants found (the test must be buggy)."
)
for k, v in constants.items():
self.assertEqual(v, getattr(filetransfer, k))
# We don't run on Windows, as we don't have an SFTP file server implemented in conch.ssh for Windows.
# As soon as there is such an implementation, we can run these tests on Windows.
@skipIf(not unix, "can't run on non-posix computers")
@skipIf(not cryptography, "Cannot run without cryptography")
class RawPacketDataServerTests(TestCase):
"""
Tests for L{filetransfer.FileTransferServer} which explicitly craft
certain less common situations to exercise their handling.
"""
def setUp(self):
self.fts = filetransfer.FileTransferServer(avatar=TestAvatar())
def test_closeInvalidHandle(self):
"""
A close request with an unknown handle receives an FX_NO_SUCH_FILE error
response.
"""
transport = StringTransport()
self.fts.makeConnection(transport)
# any four bytes
requestId = b"1234"
# The handle to close, arbitrary bytes.
handle = b"invalid handle"
# Construct a message packet
# https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-13#section-4
close = common.NS(
# Packet type - SSH_FXP_CLOSE
# https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-13#section-4.3
bytes([4])
+ requestId
+ common.NS(handle)
)
self.fts.dataReceived(close)
# An SSH_FXP_STATUS message
# https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-13#section-9.1
expected = common.NS(
# Packet type SSH_FXP_STATUS
# https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-13#section-4.3
bytes([101])
+
# The same request id
requestId
+
# A four byte status code. SSH_FX_NO_SUCH_FILE in this case.
bytes([0, 0, 0, 2])
+
# Error message
common.NS(b"No such file or directory")
+
# error message language tag - conch doesn't send one at all,
# though maybe it should
common.NS(b"")
)
assert_that(
transport.value(),
equal_to(expected),
)
@skipIf(not cryptography, "Cannot run without cryptography")
class RawPacketDataTests(TestCase):
"""
Tests for L{filetransfer.FileTransferClient} which explicitly craft certain
less common protocol messages to exercise their handling.
"""
def setUp(self):
self.ftc = filetransfer.FileTransferClient()
def test_packetSTATUS(self):
"""
A STATUS packet containing a result code, a message, and a language is
parsed to produce the result of an outstanding request L{Deferred}.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUS)
self.ftc.openRequests[1] = d
data = (
struct.pack("!LL", 1, filetransfer.FX_OK)
+ common.NS(b"msg")
+ common.NS(b"lang")
)
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUS(self, result):
"""
Assert that the result is a two-tuple containing the message and
language from the STATUS packet.
"""
self.assertEqual(result[0], b"msg")
self.assertEqual(result[1], b"lang")
def test_packetSTATUSShort(self):
"""
A STATUS packet containing only a result code can also be parsed to
produce the result of an outstanding request L{Deferred}. Such packets
are sent by some SFTP implementations, though not strictly legal.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUSShort)
self.ftc.openRequests[1] = d
data = struct.pack("!LL", 1, filetransfer.FX_OK)
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUSShort(self, result):
"""
Assert that the result is a two-tuple containing empty strings, since
the STATUS packet had neither a message nor a language.
"""
self.assertEqual(result[0], b"")
self.assertEqual(result[1], b"")
def test_packetSTATUSWithoutLang(self):
"""
A STATUS packet containing a result code and a message but no language
can also be parsed to produce the result of an outstanding request
L{Deferred}. Such packets are sent by some SFTP implementations, though
not strictly legal.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUSWithoutLang)
self.ftc.openRequests[1] = d
data = struct.pack("!LL", 1, filetransfer.FX_OK) + common.NS(b"msg")
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUSWithoutLang(self, result):
"""
Assert that the result is a two-tuple containing the message from the
STATUS packet and an empty string, since the language was missing.
"""
self.assertEqual(result[0], b"msg")
self.assertEqual(result[1], b"")

View File

@@ -0,0 +1,61 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh.forwarding}.
"""
from twisted.python.reflect import requireModule
cryptography = requireModule("cryptography")
if cryptography:
from twisted.conch.ssh import forwarding
from twisted.internet.address import IPv6Address
from twisted.internet.test.test_endpoints import deterministicResolvingReactor
from twisted.internet.testing import MemoryReactorClock, StringTransport
from twisted.trial import unittest
class TestSSHConnectForwardingChannel(unittest.TestCase):
"""
Unit and integration tests for L{SSHConnectForwardingChannel}.
"""
if not cryptography:
skip = "Cannot run without cryptography"
def makeTCPConnection(self, reactor: MemoryReactorClock) -> None:
"""
Fake that connection was established for first connectTCP request made
on C{reactor}.
@param reactor: Reactor on which to fake the connection.
@type reactor: A reactor.
"""
factory = reactor.tcpClients[0][2]
connector = reactor.connectors[0]
protocol = factory.buildProtocol(None)
transport = StringTransport(peerAddress=connector.getDestination())
protocol.makeConnection(transport)
def test_channelOpenHostnameRequests(self) -> None:
"""
When a hostname is sent as part of forwarding requests, it
is resolved using HostnameEndpoint's resolver.
"""
sut = forwarding.SSHConnectForwardingChannel(hostport=("fwd.example.org", 1234))
# Patch channel and resolver to not touch the network.
memoryReactor = MemoryReactorClock()
sut._reactor = deterministicResolvingReactor(memoryReactor, ["::1"])
sut.channelOpen(None)
self.makeTCPConnection(memoryReactor)
self.successResultOf(sut._channelOpenDeferred)
# Channel is connected using a forwarding client to the resolved
# address of the requested host.
self.assertIsInstance(sut.client, forwarding.SSHForwardingClient)
self.assertEqual(
IPv6Address("TCP", "::1", 1234), sut.client.transport.getPeer()
)

View File

@@ -0,0 +1,619 @@
# -*- test-case-name: twisted.conch.test.test_helper -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import annotations
import re
from typing import Callable
from twisted.conch.insults import helper
from twisted.conch.insults.insults import (
BLINK,
BOLD,
G0,
G1,
G2,
G3,
NORMAL,
REVERSE_VIDEO,
UNDERLINE,
modes,
privateModes,
)
from twisted.python import failure
from twisted.trial import unittest
WIDTH = 80
HEIGHT = 24
class BufferTests(unittest.TestCase):
def setUp(self) -> None:
self.term = helper.TerminalBuffer()
self.term.connectionMade()
def testInitialState(self) -> None:
self.assertEqual(self.term.width, WIDTH)
self.assertEqual(self.term.height, HEIGHT)
self.assertEqual(self.term.__bytes__(), b"\n" * (HEIGHT - 1))
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def test_initialPrivateModes(self) -> None:
"""
Verify that only DEC Auto Wrap Mode (DECAWM) and DEC Text Cursor Enable
Mode (DECTCEM) are initially in the Set Mode (SM) state.
"""
self.assertEqual(
{privateModes.AUTO_WRAP: True, privateModes.CURSOR_MODE: True},
self.term.privateModes,
)
def test_carriageReturn(self) -> None:
"""
C{"\r"} moves the cursor to the first column in the current row.
"""
self.term.cursorForward(5)
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
self.term.insertAtCursor(b"\r")
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
def test_linefeed(self) -> None:
"""
C{"\n"} moves the cursor to the next row without changing the column.
"""
self.term.cursorForward(5)
self.assertEqual(self.term.reportCursorPosition(), (5, 0))
self.term.insertAtCursor(b"\n")
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
def test_newline(self) -> None:
"""
C{write} transforms C{"\n"} into C{"\r\n"}.
"""
self.term.cursorForward(5)
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
self.term.write(b"\n")
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
def test_setPrivateModes(self) -> None:
"""
Verify that L{helper.TerminalBuffer.setPrivateModes} changes the Set
Mode (SM) state to "set" for the private modes it is passed.
"""
expected = self.term.privateModes.copy()
self.term.setPrivateModes([privateModes.SCROLL, privateModes.SCREEN])
expected[privateModes.SCROLL] = True
expected[privateModes.SCREEN] = True
self.assertEqual(expected, self.term.privateModes)
def test_resetPrivateModes(self) -> None:
"""
Verify that L{helper.TerminalBuffer.resetPrivateModes} changes the Set
Mode (SM) state to "reset" for the private modes it is passed.
"""
expected = self.term.privateModes.copy()
self.term.resetPrivateModes([privateModes.AUTO_WRAP, privateModes.CURSOR_MODE])
del expected[privateModes.AUTO_WRAP]
del expected[privateModes.CURSOR_MODE]
self.assertEqual(expected, self.term.privateModes)
def testCursorDown(self) -> None:
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
self.term.cursorDown()
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
self.term.cursorDown(HEIGHT)
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
def testCursorUp(self) -> None:
self.term.cursorUp(5)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorDown(20)
self.term.cursorUp(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 19))
self.term.cursorUp(19)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def testCursorForward(self) -> None:
self.term.cursorForward(2)
self.assertEqual(self.term.reportCursorPosition(), (2, 0))
self.term.cursorForward(2)
self.assertEqual(self.term.reportCursorPosition(), (4, 0))
self.term.cursorForward(WIDTH)
self.assertEqual(self.term.reportCursorPosition(), (WIDTH, 0))
def testCursorBackward(self) -> None:
self.term.cursorForward(10)
self.term.cursorBackward(2)
self.assertEqual(self.term.reportCursorPosition(), (8, 0))
self.term.cursorBackward(7)
self.assertEqual(self.term.reportCursorPosition(), (1, 0))
self.term.cursorBackward(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorBackward(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def testCursorPositioning(self) -> None:
self.term.cursorPosition(3, 9)
self.assertEqual(self.term.reportCursorPosition(), (3, 9))
def testSimpleWriting(self) -> None:
s = b"Hello, world."
self.term.write(s)
self.assertEqual(self.term.__bytes__(), s + b"\n" + b"\n" * (HEIGHT - 2))
def testOvertype(self) -> None:
s = b"hello, world."
self.term.write(s)
self.term.cursorBackward(len(s))
self.term.resetModes([modes.IRM])
self.term.write(b"H")
self.assertEqual(
self.term.__bytes__(), (b"H" + s[1:]) + b"\n" + b"\n" * (HEIGHT - 2)
)
def testInsert(self) -> None:
s = b"ello, world."
self.term.write(s)
self.term.cursorBackward(len(s))
self.term.setModes([modes.IRM])
self.term.write(b"H")
self.assertEqual(
self.term.__bytes__(), (b"H" + s) + b"\n" + b"\n" * (HEIGHT - 2)
)
def testWritingInTheMiddle(self) -> None:
s = b"Hello, world."
self.term.cursorDown(5)
self.term.cursorForward(5)
self.term.write(s)
self.assertEqual(
self.term.__bytes__(),
b"\n" * 5 + (self.term.fill * 5) + s + b"\n" + b"\n" * (HEIGHT - 7),
)
def testWritingWrappedAtEndOfLine(self) -> None:
s = b"Hello, world."
self.term.cursorForward(WIDTH - 5)
self.term.write(s)
self.assertEqual(
self.term.__bytes__(),
s[:5].rjust(WIDTH) + b"\n" + s[5:] + b"\n" + b"\n" * (HEIGHT - 3),
)
def testIndex(self) -> None:
self.term.index()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
self.term.cursorDown(HEIGHT)
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
self.term.index()
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
def testReverseIndex(self) -> None:
self.term.reverseIndex()
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorDown(2)
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
self.term.reverseIndex()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
def test_nextLine(self) -> None:
"""
C{nextLine} positions the cursor at the beginning of the row below the
current row.
"""
self.term.nextLine()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
self.term.cursorForward(5)
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
self.term.nextLine()
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
def testSaveCursor(self) -> None:
self.term.cursorDown(5)
self.term.cursorForward(7)
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
self.term.saveCursor()
self.term.cursorDown(7)
self.term.cursorBackward(3)
self.assertEqual(self.term.reportCursorPosition(), (4, 12))
self.term.restoreCursor()
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
def testSingleShifts(self) -> None:
self.term.singleShift2()
self.term.write(b"Hi")
ch = self.term.getCharacter(0, 0)
self.assertEqual(ch[0], b"H")
self.assertEqual(ch[1].charset, G2)
ch = self.term.getCharacter(1, 0)
self.assertEqual(ch[0], b"i")
self.assertEqual(ch[1].charset, G0)
self.term.singleShift3()
self.term.write(b"!!")
ch = self.term.getCharacter(2, 0)
self.assertEqual(ch[0], b"!")
self.assertEqual(ch[1].charset, G3)
ch = self.term.getCharacter(3, 0)
self.assertEqual(ch[0], b"!")
self.assertEqual(ch[1].charset, G0)
def testShifting(self) -> None:
s1 = b"Hello"
s2 = b"World"
s3 = b"Bye!"
self.term.write(b"Hello\n")
self.term.shiftOut()
self.term.write(b"World\n")
self.term.shiftIn()
self.term.write(b"Bye!\n")
g = G0
h = 0
for s in (s1, s2, s3):
for i in range(len(s)):
ch = self.term.getCharacter(i, h)
self.assertEqual(ch[0], s[i : i + 1])
self.assertEqual(ch[1].charset, g)
g = g == G0 and G1 or G0
h += 1
def testGraphicRendition(self) -> None:
self.term.selectGraphicRendition(BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
self.term.write(b"W")
self.term.selectGraphicRendition(NORMAL)
self.term.write(b"X")
self.term.selectGraphicRendition(BLINK)
self.term.write(b"Y")
self.term.selectGraphicRendition(BOLD)
self.term.write(b"Z")
ch = self.term.getCharacter(0, 0)
self.assertEqual(ch[0], b"W")
self.assertTrue(ch[1].bold)
self.assertTrue(ch[1].underline)
self.assertTrue(ch[1].blink)
self.assertTrue(ch[1].reverseVideo)
ch = self.term.getCharacter(1, 0)
self.assertEqual(ch[0], b"X")
self.assertFalse(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].blink)
self.assertFalse(ch[1].reverseVideo)
ch = self.term.getCharacter(2, 0)
self.assertEqual(ch[0], b"Y")
self.assertTrue(ch[1].blink)
self.assertFalse(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].reverseVideo)
ch = self.term.getCharacter(3, 0)
self.assertEqual(ch[0], b"Z")
self.assertTrue(ch[1].blink)
self.assertTrue(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].reverseVideo)
def testColorAttributes(self) -> None:
s1 = b"Merry xmas"
s2 = b"Just kidding"
self.term.selectGraphicRendition(
helper.FOREGROUND + helper.RED, helper.BACKGROUND + helper.GREEN
)
self.term.write(s1 + b"\n")
self.term.selectGraphicRendition(NORMAL)
self.term.write(s2 + b"\n")
for i in range(len(s1)):
ch = self.term.getCharacter(i, 0)
self.assertEqual(ch[0], s1[i : i + 1])
self.assertEqual(ch[1].charset, G0)
self.assertFalse(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].blink)
self.assertFalse(ch[1].reverseVideo)
self.assertEqual(ch[1].foreground, helper.RED)
self.assertEqual(ch[1].background, helper.GREEN)
for i in range(len(s2)):
ch = self.term.getCharacter(i, 1)
self.assertEqual(ch[0], s2[i : i + 1])
self.assertEqual(ch[1].charset, G0)
self.assertFalse(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].blink)
self.assertFalse(ch[1].reverseVideo)
self.assertEqual(ch[1].foreground, helper.WHITE)
self.assertEqual(ch[1].background, helper.BLACK)
def testEraseLine(self) -> None:
s1 = b"line 1"
s2 = b"line 2"
s3 = b"line 3"
self.term.write(b"\n".join((s1, s2, s3)) + b"\n")
self.term.cursorPosition(1, 1)
self.term.eraseLine()
self.assertEqual(
self.term.__bytes__(),
s1 + b"\n" + b"\n" + s3 + b"\n" + b"\n" * (HEIGHT - 4),
)
def testEraseToLineEnd(self) -> None:
s = b"Hello, world."
self.term.write(s)
self.term.cursorBackward(5)
self.term.eraseToLineEnd()
self.assertEqual(self.term.__bytes__(), s[:-5] + b"\n" + b"\n" * (HEIGHT - 2))
def testEraseToLineBeginning(self) -> None:
s = b"Hello, world."
self.term.write(s)
self.term.cursorBackward(5)
self.term.eraseToLineBeginning()
self.assertEqual(
self.term.__bytes__(), s[-4:].rjust(len(s)) + b"\n" + b"\n" * (HEIGHT - 2)
)
def testEraseDisplay(self) -> None:
self.term.write(b"Hello world\n")
self.term.write(b"Goodbye world\n")
self.term.eraseDisplay()
self.assertEqual(self.term.__bytes__(), b"\n" * (HEIGHT - 1))
def testEraseToDisplayEnd(self) -> None:
s1 = b"Hello world"
s2 = b"Goodbye world"
self.term.write(b"\n".join((s1, s2, b"")))
self.term.cursorPosition(5, 1)
self.term.eraseToDisplayEnd()
self.assertEqual(
self.term.__bytes__(), s1 + b"\n" + s2[:5] + b"\n" + b"\n" * (HEIGHT - 3)
)
def testEraseToDisplayBeginning(self) -> None:
s1 = b"Hello world"
s2 = b"Goodbye world"
self.term.write(b"\n".join((s1, s2)))
self.term.cursorPosition(5, 1)
self.term.eraseToDisplayBeginning()
self.assertEqual(
self.term.__bytes__(),
b"\n" + s2[6:].rjust(len(s2)) + b"\n" + b"\n" * (HEIGHT - 3),
)
def testLineInsertion(self) -> None:
s1 = b"Hello world"
s2 = b"Goodbye world"
self.term.write(b"\n".join((s1, s2)))
self.term.cursorPosition(7, 1)
self.term.insertLine()
self.assertEqual(
self.term.__bytes__(),
s1 + b"\n" + b"\n" + s2 + b"\n" + b"\n" * (HEIGHT - 4),
)
def testLineDeletion(self) -> None:
s1 = b"Hello world"
s2 = b"Middle words"
s3 = b"Goodbye world"
self.term.write(b"\n".join((s1, s2, s3)))
self.term.cursorPosition(9, 1)
self.term.deleteLine()
self.assertEqual(
self.term.__bytes__(), s1 + b"\n" + s3 + b"\n" + b"\n" * (HEIGHT - 3)
)
class FakeDelayedCall:
called = False
cancelled = False
def __init__(
self,
fs: FakeScheduler,
timeout: float,
f: Callable[..., None],
a: tuple[object, ...],
kw: dict[str, object],
) -> None:
self.fs = fs
self.timeout = timeout
self.f = f
self.a = a
self.kw = kw
def active(self) -> bool:
return not (self.cancelled or self.called)
def cancel(self) -> None:
self.cancelled = True
# self.fs.calls.remove(self)
def call(self) -> None:
self.called = True
self.f(*self.a, **self.kw)
class FakeScheduler:
def __init__(self) -> None:
self.calls: list[FakeDelayedCall] = []
def callLater(
self, timeout: float, f: Callable[..., None], *a: object, **kw: object
) -> FakeDelayedCall:
self.calls.append(FakeDelayedCall(self, timeout, f, a, kw))
return self.calls[-1]
class ExpectTests(unittest.TestCase):
def setUp(self) -> None:
self.term = helper.ExpectableBuffer()
self.term.connectionMade()
self.fs = FakeScheduler()
def testSimpleString(self) -> None:
result: list[re.Match[bytes]] = []
d = self.term.expect(b"hello world", timeout=1, scheduler=self.fs)
d.addCallback(result.append)
self.term.write(b"greeting puny earthlings\n")
self.assertFalse(result)
self.term.write(b"hello world\n")
self.assertTrue(result)
self.assertEqual(result[0].group(), b"hello world")
self.assertEqual(len(self.fs.calls), 1)
self.assertFalse(self.fs.calls[0].active())
def testBrokenUpString(self) -> None:
result: list[re.Match[bytes]] = []
d = self.term.expect(b"hello world")
d.addCallback(result.append)
self.assertFalse(result)
self.term.write(b"hello ")
self.assertFalse(result)
self.term.write(b"worl")
self.assertFalse(result)
self.term.write(b"d")
self.assertTrue(result)
self.assertEqual(result[0].group(), b"hello world")
def testMultiple(self) -> None:
result: list[re.Match[bytes]] = []
d1 = self.term.expect(b"hello ")
d1.addCallback(result.append)
d2 = self.term.expect(b"world")
d2.addCallback(result.append)
self.assertFalse(result)
self.term.write(b"hello")
self.assertFalse(result)
self.term.write(b" ")
self.assertEqual(len(result), 1)
self.term.write(b"world")
self.assertEqual(len(result), 2)
self.assertEqual(result[0].group(), b"hello ")
self.assertEqual(result[1].group(), b"world")
def testSynchronous(self) -> None:
self.term.write(b"hello world")
result: list[re.Match[bytes]] = []
d = self.term.expect(b"hello world")
d.addCallback(result.append)
self.assertTrue(result)
self.assertEqual(result[0].group(), b"hello world")
def testMultipleSynchronous(self) -> None:
self.term.write(b"goodbye world")
result: list[re.Match[bytes]] = []
d1 = self.term.expect(b"bye")
d1.addCallback(result.append)
d2 = self.term.expect(b"world")
d2.addCallback(result.append)
self.assertEqual(len(result), 2)
self.assertEqual(result[0].group(), b"bye")
self.assertEqual(result[1].group(), b"world")
def _cbTestTimeoutFailure(self, res: failure.Failure) -> None:
self.assertTrue(hasattr(res, "type"))
self.assertEqual(res.type, helper.ExpectationTimeout)
def testTimeoutFailure(self) -> None:
d = self.term.expect(b"hello world", timeout=1, scheduler=self.fs)
d.addBoth(self._cbTestTimeoutFailure)
self.fs.calls[0].call()
def testOverlappingTimeout(self) -> None:
self.term.write(b"not zoomtastic")
result: list[re.Match[bytes]] = []
d1 = self.term.expect(b"hello world", timeout=1, scheduler=self.fs)
d1.addBoth(self._cbTestTimeoutFailure)
d2 = self.term.expect(b"zoom")
d2.addCallback(result.append)
self.fs.calls[0].call()
self.assertEqual(len(result), 1)
self.assertEqual(result[0].group(), b"zoom")
class CharacterAttributeTests(unittest.TestCase):
"""
Tests for L{twisted.conch.insults.helper.CharacterAttribute}.
"""
def test_equality(self) -> None:
"""
L{CharacterAttribute}s must have matching character attribute values
(bold, blink, underline, etc) with the same values to be considered
equal.
"""
self.assertEqual(helper.CharacterAttribute(), helper.CharacterAttribute())
self.assertEqual(
helper.CharacterAttribute(), helper.CharacterAttribute(charset=G0)
)
self.assertEqual(
helper.CharacterAttribute(
bold=True,
underline=True,
blink=False,
reverseVideo=True,
foreground=helper.BLUE,
),
helper.CharacterAttribute(
bold=True,
underline=True,
blink=False,
reverseVideo=True,
foreground=helper.BLUE,
),
)
self.assertNotEqual(
helper.CharacterAttribute(), helper.CharacterAttribute(charset=G1)
)
self.assertNotEqual(
helper.CharacterAttribute(bold=True), helper.CharacterAttribute(bold=False)
)
def test_wantOneDeprecated(self) -> None:
"""
L{twisted.conch.insults.helper.CharacterAttribute.wantOne} emits
a deprecation warning when invoked.
"""
# Trigger the deprecation warning.
helper._FormattingState().wantOne(bold=True)
warningsShown = self.flushWarnings([self.test_wantOneDeprecated])
self.assertEqual(len(warningsShown), 1)
self.assertEqual(warningsShown[0]["category"], DeprecationWarning)
deprecatedClass = "twisted.conch.insults.helper._FormattingState.wantOne"
self.assertEqual(
warningsShown[0]["message"],
"%s was deprecated in Twisted 13.1.0" % (deprecatedClass),
)

View File

@@ -0,0 +1,958 @@
# -*- test-case-name: twisted.conch.test.test_insults -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import textwrap
from typing import Optional, Type
from twisted.conch.insults.insults import (
BLINK,
CS_ALTERNATE,
CS_ALTERNATE_SPECIAL,
CS_DRAWING,
CS_UK,
CS_US,
G0,
G1,
UNDERLINE,
ClientProtocol,
ServerProtocol,
modes,
privateModes,
)
from twisted.internet.protocol import Protocol
from twisted.internet.testing import StringTransport
from twisted.python.compat import iterbytes
from twisted.python.constants import ValueConstant, Values
from twisted.trial import unittest
def _getattr(mock, name):
return super(Mock, mock).__getattribute__(name)
def occurrences(mock):
return _getattr(mock, "occurrences")
def methods(mock):
return _getattr(mock, "methods")
def _append(mock, obj):
occurrences(mock).append(obj)
default = object()
def _ecmaCodeTableCoordinate(column, row):
"""
Return the byte in 7- or 8-bit code table identified by C{column}
and C{row}.
"An 8-bit code table consists of 256 positions arranged in 16
columns and 16 rows. The columns and rows are numbered 00 to 15."
"A 7-bit code table consists of 128 positions arranged in 8
columns and 16 rows. The columns are numbered 00 to 07 and the
rows 00 to 15 (see figure 1)."
p.5 of "Standard ECMA-35: Character Code Structure and Extension
Techniques", 6th Edition (December 1994).
"""
# 8 and 15 both happen to take up 4 bits, so the first number
# should be shifted by 4 for both the 7- and 8-bit tables.
return bytes(bytearray([(column << 4) | row]))
def _makeControlFunctionSymbols(name, colOffset, names, doc):
# the value for each name is the concatenation of the bit values
# of its x, y locations, with an offset of 4 added to its x value.
# so CUP is (0 + 4, 8) = (4, 8) = 4||8 = 1001000 = 72 = b"H"
# this is how it's defined in the standard!
attrs = {
name: ValueConstant(_ecmaCodeTableCoordinate(i + colOffset, j))
for j, row in enumerate(names)
for i, name in enumerate(row)
if name
}
attrs["__doc__"] = doc
return type(name, (Values,), attrs)
CSFinalByte = _makeControlFunctionSymbols(
"CSFinalByte",
colOffset=4,
names=[
# 4, 5, 6
["ICH", "DCH", "HPA"],
["CUU", "SSE", "HPR"],
["CUD", "CPR", "REP"],
["CUF", "SU", "DA"],
["CUB", "SD", "VPA"],
["CNL", "NP", "VPR"],
["CPL", "PP", "HVP"],
["CHA", "CTC", "TBC"],
["CUP", "ECH", "SM"],
["CHT", "CVT", "MC"],
["ED", "CBT", "HPB"],
["EL", "SRS", "VPB"],
["IL", "PTX", "RM"],
["DL", "SDS", "SGR"],
["EF", "SIMD", "DSR"],
["EA", None, "DAQ"],
],
doc=textwrap.dedent(
"""
Symbolic constants for all control sequence final bytes
that do not imply intermediate bytes. This happens to cover
movement control sequences.
See page 11 of "Standard ECMA 48: Control Functions for Coded
Character Sets", 5th Edition (June 1991).
Each L{ValueConstant} maps a control sequence name to L{bytes}
"""
),
)
C1SevenBit = _makeControlFunctionSymbols(
"C1SevenBit",
colOffset=4,
names=[
[None, "DCS"],
[None, "PU1"],
["BPH", "PU2"],
["NBH", "STS"],
[None, "CCH"],
["NEL", "MW"],
["SSA", "SPA"],
["ESA", "EPA"],
["HTS", "SOS"],
["HTJ", None],
["VTS", "SCI"],
["PLD", "CSI"],
["PLU", "ST"],
["RI", "OSC"],
["SS2", "PM"],
["SS3", "APC"],
],
doc=textwrap.dedent(
"""
Symbolic constants for all 7 bit versions of the C1 control functions
See page 9 "Standard ECMA 48: Control Functions for Coded
Character Sets", 5th Edition (June 1991).
Each L{ValueConstant} maps a control sequence name to L{bytes}
"""
),
)
class Mock:
callReturnValue = default
def __init__(self, methods=None, callReturnValue=default):
"""
@param methods: Mapping of names to return values
@param callReturnValue: object __call__ should return
"""
self.occurrences = []
if methods is None:
methods = {}
self.methods = methods
if callReturnValue is not default:
self.callReturnValue = callReturnValue
def __call__(self, *a, **kw):
returnValue = _getattr(self, "callReturnValue")
if returnValue is default:
returnValue = Mock()
# _getattr(self, 'occurrences').append(('__call__', returnValue, a, kw))
_append(self, ("__call__", returnValue, a, kw))
return returnValue
def __getattribute__(self, name):
methods = _getattr(self, "methods")
if name in methods:
attrValue = Mock(callReturnValue=methods[name])
else:
attrValue = Mock()
# _getattr(self, 'occurrences').append((name, attrValue))
_append(self, (name, attrValue))
return attrValue
class MockMixin:
def assertCall(
self, occurrence, methodName, expectedPositionalArgs=(), expectedKeywordArgs={}
):
attr, mock = occurrence
self.assertEqual(attr, methodName)
self.assertEqual(len(occurrences(mock)), 1)
[(call, result, args, kw)] = occurrences(mock)
self.assertEqual(call, "__call__")
self.assertEqual(args, expectedPositionalArgs)
self.assertEqual(kw, expectedKeywordArgs)
return result
_byteGroupingTestTemplate = """\
def testByte%(groupName)s(self):
transport = StringTransport()
proto = Mock()
parser = self.protocolFactory(lambda: proto)
parser.factory = self
parser.makeConnection(transport)
bytes = self.TEST_BYTES
while bytes:
chunk = bytes[:%(bytesPer)d]
bytes = bytes[%(bytesPer)d:]
parser.dataReceived(chunk)
self.verifyResults(transport, proto, parser)
"""
class ByteGroupingsMixin(MockMixin):
protocolFactory: Optional[Type[Protocol]] = None
for word, n in [
("Pairs", 2),
("Triples", 3),
("Quads", 4),
("Quints", 5),
("Sexes", 6),
]:
exec(_byteGroupingTestTemplate % {"groupName": word, "bytesPer": n})
del word, n
def verifyResults(self, transport, proto, parser):
result = self.assertCall(occurrences(proto).pop(0), "makeConnection", (parser,))
self.assertEqual(occurrences(result), [])
del _byteGroupingTestTemplate
class ServerArrowKeysTests(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ServerProtocol
# All the arrow keys once
TEST_BYTES = b"\x1b[A\x1b[B\x1b[C\x1b[D"
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for arrow in (
parser.UP_ARROW,
parser.DOWN_ARROW,
parser.RIGHT_ARROW,
parser.LEFT_ARROW,
):
result = self.assertCall(
occurrences(proto).pop(0), "keystrokeReceived", (arrow, None)
)
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class PrintableCharactersTests(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ServerProtocol
# Some letters and digits, first on their own, then capitalized,
# then modified with alt
TEST_BYTES = b"abc123ABC!@#\x1ba\x1bb\x1bc\x1b1\x1b2\x1b3"
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for char in iterbytes(b"abc123ABC!@#"):
result = self.assertCall(
occurrences(proto).pop(0), "keystrokeReceived", (char, None)
)
self.assertEqual(occurrences(result), [])
for char in iterbytes(b"abc123"):
result = self.assertCall(
occurrences(proto).pop(0), "keystrokeReceived", (char, parser.ALT)
)
self.assertEqual(occurrences(result), [])
occs = occurrences(proto)
self.assertFalse(occs, f"{occs!r} should have been []")
class ServerFunctionKeysTests(ByteGroupingsMixin, unittest.TestCase):
"""Test for parsing and dispatching function keys (F1 - F12)"""
protocolFactory = ServerProtocol
byteList = []
for byteCodes in (
b"OP",
b"OQ",
b"OR",
b"OS", # F1 - F4
b"15~",
b"17~",
b"18~",
b"19~", # F5 - F8
b"20~",
b"21~",
b"23~",
b"24~",
): # F9 - F12
byteList.append(b"\x1b[" + byteCodes)
TEST_BYTES = b"".join(byteList)
del byteList, byteCodes
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for funcNum in range(1, 13):
funcArg = getattr(parser, "F%d" % (funcNum,))
result = self.assertCall(
occurrences(proto).pop(0), "keystrokeReceived", (funcArg, None)
)
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class ClientCursorMovementTests(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ClientProtocol
d2 = b"\x1b[2B"
r4 = b"\x1b[4C"
u1 = b"\x1b[A"
l2 = b"\x1b[2D"
# Move the cursor down two, right four, up one, left two, up one, left two
TEST_BYTES = d2 + r4 + u1 + l2 + u1 + l2
del d2, r4, u1, l2
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for method, count in [
("Down", 2),
("Forward", 4),
("Up", 1),
("Backward", 2),
("Up", 1),
("Backward", 2),
]:
result = self.assertCall(
occurrences(proto).pop(0), "cursor" + method, (count,)
)
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class ClientControlSequencesTests(unittest.TestCase, MockMixin):
def setUp(self):
self.transport = StringTransport()
self.proto = Mock()
self.parser = ClientProtocol(lambda: self.proto)
self.parser.factory = self
self.parser.makeConnection(self.transport)
result = self.assertCall(
occurrences(self.proto).pop(0), "makeConnection", (self.parser,)
)
self.assertFalse(occurrences(result))
def testSimpleCardinals(self):
self.parser.dataReceived(
b"".join(
b"\x1b[" + n + ch
for ch in iterbytes(b"BACD")
for n in (b"", b"2", b"20", b"200")
)
)
occs = occurrences(self.proto)
for meth in ("Down", "Up", "Forward", "Backward"):
for count in (1, 2, 20, 200):
result = self.assertCall(occs.pop(0), "cursor" + meth, (count,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testScrollRegion(self):
self.parser.dataReceived(b"\x1b[5;22r\x1b[r")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "setScrollRegion", (5, 22))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "setScrollRegion", (None, None))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testHeightAndWidth(self):
self.parser.dataReceived(b"\x1b#3\x1b#4\x1b#5\x1b#6")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "doubleHeightLine", (True,))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "doubleHeightLine", (False,))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "singleWidthLine")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "doubleWidthLine")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCharacterSet(self):
self.parser.dataReceived(
b"".join(
[
b"".join([b"\x1b" + g + n for n in iterbytes(b"AB012")])
for g in iterbytes(b"()")
]
)
)
occs = occurrences(self.proto)
for which in (G0, G1):
for charset in (
CS_UK,
CS_US,
CS_DRAWING,
CS_ALTERNATE,
CS_ALTERNATE_SPECIAL,
):
result = self.assertCall(
occs.pop(0), "selectCharacterSet", (charset, which)
)
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testShifting(self):
self.parser.dataReceived(b"\x15\x14")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "shiftIn")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "shiftOut")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testSingleShifts(self):
self.parser.dataReceived(b"\x1bN\x1bO")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "singleShift2")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "singleShift3")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testKeypadMode(self):
self.parser.dataReceived(b"\x1b=\x1b>")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "applicationKeypadMode")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "numericKeypadMode")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCursor(self):
self.parser.dataReceived(b"\x1b7\x1b8")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "saveCursor")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "restoreCursor")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testReset(self):
self.parser.dataReceived(b"\x1bc")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "reset")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testIndex(self):
self.parser.dataReceived(b"\x1bD\x1bM\x1bE")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "index")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "reverseIndex")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "nextLine")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testModes(self):
self.parser.dataReceived(
b"\x1b["
+ b";".join(b"%d" % (m,) for m in [modes.KAM, modes.IRM, modes.LNM])
+ b"h"
)
self.parser.dataReceived(
b"\x1b["
+ b";".join(b"%d" % (m,) for m in [modes.KAM, modes.IRM, modes.LNM])
+ b"l"
)
occs = occurrences(self.proto)
result = self.assertCall(
occs.pop(0), "setModes", ([modes.KAM, modes.IRM, modes.LNM],)
)
self.assertFalse(occurrences(result))
result = self.assertCall(
occs.pop(0), "resetModes", ([modes.KAM, modes.IRM, modes.LNM],)
)
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testErasure(self):
self.parser.dataReceived(b"\x1b[K\x1b[1K\x1b[2K\x1b[J\x1b[1J\x1b[2J\x1b[3P")
occs = occurrences(self.proto)
for meth in (
"eraseToLineEnd",
"eraseToLineBeginning",
"eraseLine",
"eraseToDisplayEnd",
"eraseToDisplayBeginning",
"eraseDisplay",
):
result = self.assertCall(occs.pop(0), meth)
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "deleteCharacter", (3,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testLineDeletion(self):
self.parser.dataReceived(b"\x1b[M\x1b[3M")
occs = occurrences(self.proto)
for arg in (1, 3):
result = self.assertCall(occs.pop(0), "deleteLine", (arg,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testLineInsertion(self):
self.parser.dataReceived(b"\x1b[L\x1b[3L")
occs = occurrences(self.proto)
for arg in (1, 3):
result = self.assertCall(occs.pop(0), "insertLine", (arg,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCursorPosition(self):
methods(self.proto)["reportCursorPosition"] = (6, 7)
self.parser.dataReceived(b"\x1b[6n")
self.assertEqual(self.transport.value(), b"\x1b[7;8R")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "reportCursorPosition")
# This isn't really an interesting assert, since it only tests that
# our mock setup is working right, but I'll include it anyway.
self.assertEqual(result, (6, 7))
def test_applicationDataBytes(self):
"""
Contiguous non-control bytes are passed to a single call to the
C{write} method of the terminal to which the L{ClientProtocol} is
connected.
"""
occs = occurrences(self.proto)
self.parser.dataReceived(b"a")
self.assertCall(occs.pop(0), "write", (b"a",))
self.parser.dataReceived(b"bc")
self.assertCall(occs.pop(0), "write", (b"bc",))
def _applicationDataTest(self, data, calls):
occs = occurrences(self.proto)
self.parser.dataReceived(data)
while calls:
self.assertCall(occs.pop(0), *calls.pop(0))
self.assertFalse(occs, f"No other calls should happen: {occs!r}")
def test_shiftInAfterApplicationData(self):
"""
Application data bytes followed by a shift-in command are passed to a
call to C{write} before the terminal's C{shiftIn} method is called.
"""
self._applicationDataTest(b"ab\x15", [("write", (b"ab",)), ("shiftIn",)])
def test_shiftOutAfterApplicationData(self):
"""
Application data bytes followed by a shift-out command are passed to a
call to C{write} before the terminal's C{shiftOut} method is called.
"""
self._applicationDataTest(b"ab\x14", [("write", (b"ab",)), ("shiftOut",)])
def test_cursorBackwardAfterApplicationData(self):
"""
Application data bytes followed by a cursor-backward command are passed
to a call to C{write} before the terminal's C{cursorBackward} method is
called.
"""
self._applicationDataTest(b"ab\x08", [("write", (b"ab",)), ("cursorBackward",)])
def test_escapeAfterApplicationData(self):
"""
Application data bytes followed by an escape character are passed to a
call to C{write} before the terminal's handler method for the escape is
called.
"""
# Test a short escape
self._applicationDataTest(b"ab\x1bD", [("write", (b"ab",)), ("index",)])
# And a long escape
self._applicationDataTest(
b"ab\x1b[4h", [("write", (b"ab",)), ("setModes", ([4],))]
)
# There's some other cases too, but they're all handled by the same
# codepaths as above.
class ServerProtocolOutputTests(unittest.TestCase):
"""
Tests for the bytes L{ServerProtocol} writes to its transport when its
methods are called.
"""
# From ECMA 48: CSI is represented by bit combinations 01/11
# (representing ESC) and 05/11 in a 7-bit code or by bit
# combination 09/11 in an 8-bit code
ESC = _ecmaCodeTableCoordinate(1, 11)
CSI = ESC + _ecmaCodeTableCoordinate(5, 11)
def setUp(self):
self.protocol = ServerProtocol()
self.transport = StringTransport()
self.protocol.makeConnection(self.transport)
def test_cursorUp(self):
"""
L{ServerProtocol.cursorUp} writes the control sequence
ending with L{CSFinalByte.CUU} to its transport.
"""
self.protocol.cursorUp(1)
self.assertEqual(
self.transport.value(), self.CSI + b"1" + CSFinalByte.CUU.value
)
def test_cursorDown(self):
"""
L{ServerProtocol.cursorDown} writes the control sequence
ending with L{CSFinalByte.CUD} to its transport.
"""
self.protocol.cursorDown(1)
self.assertEqual(
self.transport.value(), self.CSI + b"1" + CSFinalByte.CUD.value
)
def test_cursorForward(self):
"""
L{ServerProtocol.cursorForward} writes the control sequence
ending with L{CSFinalByte.CUF} to its transport.
"""
self.protocol.cursorForward(1)
self.assertEqual(
self.transport.value(), self.CSI + b"1" + CSFinalByte.CUF.value
)
def test_cursorBackward(self):
"""
L{ServerProtocol.cursorBackward} writes the control sequence
ending with L{CSFinalByte.CUB} to its transport.
"""
self.protocol.cursorBackward(1)
self.assertEqual(
self.transport.value(), self.CSI + b"1" + CSFinalByte.CUB.value
)
def test_cursorPosition(self):
"""
L{ServerProtocol.cursorPosition} writes a control sequence
ending with L{CSFinalByte.CUP} and containing the expected
coordinates to its transport.
"""
self.protocol.cursorPosition(0, 0)
self.assertEqual(
self.transport.value(), self.CSI + b"1;1" + CSFinalByte.CUP.value
)
def test_cursorHome(self):
"""
L{ServerProtocol.cursorHome} writes a control sequence ending
with L{CSFinalByte.CUP} and no parameters, so that the client
defaults to (1, 1).
"""
self.protocol.cursorHome()
self.assertEqual(self.transport.value(), self.CSI + CSFinalByte.CUP.value)
def test_index(self):
"""
L{ServerProtocol.index} writes the control sequence ending in
the 8-bit code table coordinates 4, 4.
Note that ECMA48 5th Edition removes C{IND}.
"""
self.protocol.index()
self.assertEqual(
self.transport.value(), self.ESC + _ecmaCodeTableCoordinate(4, 4)
)
def test_reverseIndex(self):
"""
L{ServerProtocol.reverseIndex} writes the control sequence
ending in the L{C1SevenBit.RI}.
"""
self.protocol.reverseIndex()
self.assertEqual(self.transport.value(), self.ESC + C1SevenBit.RI.value)
def test_nextLine(self):
"""
L{ServerProtocol.nextLine} writes C{"\r\n"} to its transport.
"""
# Why doesn't it write ESC E? Because ESC E is poorly supported. For
# example, gnome-terminal (many different versions) fails to scroll if
# it receives ESC E and the cursor is already on the last row.
self.protocol.nextLine()
self.assertEqual(self.transport.value(), b"\r\n")
def test_setModes(self):
"""
L{ServerProtocol.setModes} writes a control sequence
containing the requested modes and ending in the
L{CSFinalByte.SM}.
"""
modesToSet = [modes.KAM, modes.IRM, modes.LNM]
self.protocol.setModes(modesToSet)
self.assertEqual(
self.transport.value(),
self.CSI
+ b";".join(b"%d" % (m,) for m in modesToSet)
+ CSFinalByte.SM.value,
)
def test_setPrivateModes(self):
"""
L{ServerProtocol.setPrivatesModes} writes a control sequence
containing the requested private modes and ending in the
L{CSFinalByte.SM}.
"""
privateModesToSet = [
privateModes.ERROR,
privateModes.COLUMN,
privateModes.ORIGIN,
]
self.protocol.setModes(privateModesToSet)
self.assertEqual(
self.transport.value(),
self.CSI
+ b";".join(b"%d" % (m,) for m in privateModesToSet)
+ CSFinalByte.SM.value,
)
def test_resetModes(self):
"""
L{ServerProtocol.resetModes} writes the control sequence
ending in the L{CSFinalByte.RM}.
"""
modesToSet = [modes.KAM, modes.IRM, modes.LNM]
self.protocol.resetModes(modesToSet)
self.assertEqual(
self.transport.value(),
self.CSI
+ b";".join(b"%d" % (m,) for m in modesToSet)
+ CSFinalByte.RM.value,
)
def test_singleShift2(self):
"""
L{ServerProtocol.singleShift2} writes an escape sequence
followed by L{C1SevenBit.SS2}
"""
self.protocol.singleShift2()
self.assertEqual(self.transport.value(), self.ESC + C1SevenBit.SS2.value)
def test_singleShift3(self):
"""
L{ServerProtocol.singleShift3} writes an escape sequence
followed by L{C1SevenBit.SS3}
"""
self.protocol.singleShift3()
self.assertEqual(self.transport.value(), self.ESC + C1SevenBit.SS3.value)
def test_selectGraphicRendition(self):
"""
L{ServerProtocol.selectGraphicRendition} writes a control
sequence containing the requested attributes and ending with
L{CSFinalByte.SGR}
"""
self.protocol.selectGraphicRendition(str(BLINK), str(UNDERLINE))
self.assertEqual(
self.transport.value(),
self.CSI + b"%d;%d" % (BLINK, UNDERLINE) + CSFinalByte.SGR.value,
)
def test_horizontalTabulationSet(self):
"""
L{ServerProtocol.horizontalTabulationSet} writes the escape
sequence ending in L{C1SevenBit.HTS}
"""
self.protocol.horizontalTabulationSet()
self.assertEqual(self.transport.value(), self.ESC + C1SevenBit.HTS.value)
def test_eraseToLineEnd(self):
"""
L{ServerProtocol.eraseToLineEnd} writes the control sequence
sequence ending in L{CSFinalByte.EL} and no parameters,
forcing the client to default to 0 (from the active present
position's current location to the end of the line.)
"""
self.protocol.eraseToLineEnd()
self.assertEqual(self.transport.value(), self.CSI + CSFinalByte.EL.value)
def test_eraseToLineBeginning(self):
"""
L{ServerProtocol.eraseToLineBeginning} writes the control
sequence sequence ending in L{CSFinalByte.EL} and a parameter
of 1 (from the beginning of the line up to and include the
active present position's current location.)
"""
self.protocol.eraseToLineBeginning()
self.assertEqual(self.transport.value(), self.CSI + b"1" + CSFinalByte.EL.value)
def test_eraseLine(self):
"""
L{ServerProtocol.eraseLine} writes the control
sequence sequence ending in L{CSFinalByte.EL} and a parameter
of 2 (the entire line.)
"""
self.protocol.eraseLine()
self.assertEqual(self.transport.value(), self.CSI + b"2" + CSFinalByte.EL.value)
def test_eraseToDisplayEnd(self):
"""
L{ServerProtocol.eraseToDisplayEnd} writes the control
sequence sequence ending in L{CSFinalByte.ED} and no parameters,
forcing the client to default to 0 (from the active present
position's current location to the end of the page.)
"""
self.protocol.eraseToDisplayEnd()
self.assertEqual(self.transport.value(), self.CSI + CSFinalByte.ED.value)
def test_eraseToDisplayBeginning(self):
"""
L{ServerProtocol.eraseToDisplayBeginning} writes the control
sequence sequence ending in L{CSFinalByte.ED} a parameter of 1
(from the beginning of the page up to and include the active
present position's current location.)
"""
self.protocol.eraseToDisplayBeginning()
self.assertEqual(self.transport.value(), self.CSI + b"1" + CSFinalByte.ED.value)
def test_eraseToDisplay(self):
"""
L{ServerProtocol.eraseDisplay} writes the control sequence
sequence ending in L{CSFinalByte.ED} a parameter of 2 (the
entire page)
"""
self.protocol.eraseDisplay()
self.assertEqual(self.transport.value(), self.CSI + b"2" + CSFinalByte.ED.value)
def test_deleteCharacter(self):
"""
L{ServerProtocol.deleteCharacter} writes the control sequence
containing the number of characters to delete and ending in
L{CSFinalByte.DCH}
"""
self.protocol.deleteCharacter(4)
self.assertEqual(
self.transport.value(), self.CSI + b"4" + CSFinalByte.DCH.value
)
def test_insertLine(self):
"""
L{ServerProtocol.insertLine} writes the control sequence
containing the number of lines to insert and ending in
L{CSFinalByte.IL}
"""
self.protocol.insertLine(5)
self.assertEqual(self.transport.value(), self.CSI + b"5" + CSFinalByte.IL.value)
def test_deleteLine(self):
"""
L{ServerProtocol.deleteLine} writes the control sequence
containing the number of lines to delete and ending in
L{CSFinalByte.DL}
"""
self.protocol.deleteLine(6)
self.assertEqual(self.transport.value(), self.CSI + b"6" + CSFinalByte.DL.value)
def test_setScrollRegionNoArgs(self):
"""
With no arguments, L{ServerProtocol.setScrollRegion} writes a
control sequence with no parameters, but a parameter
separator, and ending in C{b'r'}.
"""
self.protocol.setScrollRegion()
self.assertEqual(self.transport.value(), self.CSI + b";" + b"r")
def test_setScrollRegionJustFirst(self):
"""
With just a value for its C{first} argument,
L{ServerProtocol.setScrollRegion} writes a control sequence with
that parameter, a parameter separator, and finally a C{b'r'}.
"""
self.protocol.setScrollRegion(first=1)
self.assertEqual(self.transport.value(), self.CSI + b"1;" + b"r")
def test_setScrollRegionJustLast(self):
"""
With just a value for its C{last} argument,
L{ServerProtocol.setScrollRegion} writes a control sequence with
a parameter separator, that parameter, and finally a C{b'r'}.
"""
self.protocol.setScrollRegion(last=1)
self.assertEqual(self.transport.value(), self.CSI + b";1" + b"r")
def test_setScrollRegionFirstAndLast(self):
"""
When given both C{first} and C{last}
L{ServerProtocol.setScrollRegion} writes a control sequence with
the first parameter, a parameter separator, the last
parameter, and finally a C{b'r'}.
"""
self.protocol.setScrollRegion(first=1, last=2)
self.assertEqual(self.transport.value(), self.CSI + b"1;2" + b"r")
def test_reportCursorPosition(self):
"""
L{ServerProtocol.reportCursorPosition} writes a control
sequence ending in L{CSFinalByte.DSR} with a parameter of 6
(the Device Status Report returns the current active
position.)
"""
self.protocol.reportCursorPosition()
self.assertEqual(
self.transport.value(), self.CSI + b"6" + CSFinalByte.DSR.value
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,462 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# pylint: disable=I0011,W9401,W9402
"""
Tests for L{twisted.conch.manhole}.
"""
import sys
import traceback
from typing import Optional
ssh: Optional[bool] = None
from twisted.conch import manhole
from twisted.conch.insults import insults
from twisted.conch.test.test_recvline import (
_SSHMixin,
_StdioMixin,
_TelnetMixin,
ssh,
stdio,
)
from twisted.internet import defer, error
from twisted.internet.testing import StringTransport
from twisted.trial import unittest
def determineDefaultFunctionName():
"""
Return the string used by Python as the name for code objects which are
compiled from interactive input or at the top-level of modules.
"""
try:
1 // 0
except BaseException:
# The last frame is this function. The second to last frame is this
# function's caller, which is module-scope, which is what we want,
# so -2.
return traceback.extract_stack()[-2][2]
defaultFunctionName = determineDefaultFunctionName()
class ManholeInterpreterTests(unittest.TestCase):
"""
Tests for L{manhole.ManholeInterpreter}.
"""
def test_resetBuffer(self):
"""
L{ManholeInterpreter.resetBuffer} should empty the input buffer.
"""
interpreter = manhole.ManholeInterpreter(None)
interpreter.buffer.extend(["1", "2"])
interpreter.resetBuffer()
self.assertFalse(interpreter.buffer)
class ManholeProtocolTests(unittest.TestCase):
"""
Tests for L{manhole.Manhole}.
"""
def test_interruptResetsInterpreterBuffer(self):
"""
L{manhole.Manhole.handle_INT} should cause the interpreter input buffer
to be reset.
"""
transport = StringTransport()
terminal = insults.ServerProtocol(manhole.Manhole)
terminal.makeConnection(transport)
protocol = terminal.terminalProtocol
interpreter = protocol.interpreter
interpreter.buffer.extend(["1", "2"])
protocol.handle_INT()
self.assertFalse(interpreter.buffer)
class WriterTests(unittest.TestCase):
def test_Integer(self):
"""
Colorize an integer.
"""
manhole.lastColorizedLine("1")
def test_DoubleQuoteString(self):
"""
Colorize an integer in double quotes.
"""
manhole.lastColorizedLine('"1"')
def test_SingleQuoteString(self):
"""
Colorize an integer in single quotes.
"""
manhole.lastColorizedLine("'1'")
def test_TripleSingleQuotedString(self):
"""
Colorize an integer in triple quotes.
"""
manhole.lastColorizedLine("'''1'''")
def test_TripleDoubleQuotedString(self):
"""
Colorize an integer in triple and double quotes.
"""
manhole.lastColorizedLine('"""1"""')
def test_FunctionDefinition(self):
"""
Colorize a function definition.
"""
manhole.lastColorizedLine("def foo():")
def test_ClassDefinition(self):
"""
Colorize a class definition.
"""
manhole.lastColorizedLine("class foo:")
def test_unicode(self):
"""
Colorize a Unicode string.
"""
res = manhole.lastColorizedLine("\u0438")
self.assertTrue(isinstance(res, bytes))
def test_bytes(self):
"""
Colorize a UTF-8 byte string.
"""
res = manhole.lastColorizedLine(b"\xd0\xb8")
self.assertTrue(isinstance(res, bytes))
def test_identicalOutput(self):
"""
The output of UTF-8 bytestrings and Unicode strings are identical.
"""
self.assertEqual(
manhole.lastColorizedLine(b"\xd0\xb8"), manhole.lastColorizedLine("\u0438")
)
class ManholeLoopbackMixin:
serverProtocol = manhole.ColoredManhole
def test_SimpleExpression(self):
"""
Evaluate simple expression.
"""
done = self.recvlineClient.expect(b"done")
self._testwrite(b"1 + 1\n" b"done")
def finished(ign):
self._assertBuffer([b">>> 1 + 1", b"2", b">>> done"])
return done.addCallback(finished)
def test_TripleQuoteLineContinuation(self):
"""
Evaluate line continuation in triple quotes.
"""
done = self.recvlineClient.expect(b"done")
self._testwrite(b"'''\n'''\n" b"done")
def finished(ign):
self._assertBuffer([b">>> '''", b"... '''", b"'\\n'", b">>> done"])
return done.addCallback(finished)
def test_FunctionDefinition(self):
"""
Evaluate function definition.
"""
done = self.recvlineClient.expect(b"done")
self._testwrite(b"def foo(bar):\n" b"\tprint(bar)\n\n" b"foo(42)\n" b"done")
def finished(ign):
self._assertBuffer(
[
b">>> def foo(bar):",
b"... print(bar)",
b"... ",
b">>> foo(42)",
b"42",
b">>> done",
]
)
return done.addCallback(finished)
def test_ClassDefinition(self):
"""
Evaluate class definition.
"""
done = self.recvlineClient.expect(b"done")
self._testwrite(
b"class Foo:\n"
b"\tdef bar(self):\n"
b"\t\tprint('Hello, world!')\n\n"
b"Foo().bar()\n"
b"done"
)
def finished(ign):
self._assertBuffer(
[
b">>> class Foo:",
b"... def bar(self):",
b"... print('Hello, world!')",
b"... ",
b">>> Foo().bar()",
b"Hello, world!",
b">>> done",
]
)
return done.addCallback(finished)
def test_Exception(self):
"""
Evaluate raising an exception.
"""
done = self.recvlineClient.expect(b"done")
self._testwrite(b"raise Exception('foo bar baz')\n" b"done")
def finished(ign):
self._assertBuffer(
[
b">>> raise Exception('foo bar baz')",
b"Traceback (most recent call last):",
b' File "<console>", line 1, in '
+ defaultFunctionName.encode("utf-8"),
b"Exception: foo bar baz",
b">>> done",
],
)
done.addCallback(finished)
return done
def test_ExceptionWithCustomExcepthook(
self,
):
"""
Raised exceptions are handled the same way even if L{sys.excepthook}
has been modified from its original value.
"""
self.patch(sys, "excepthook", lambda *args: None)
return self.test_Exception()
def test_ControlC(self):
"""
Evaluate interrupting with CTRL-C.
"""
done = self.recvlineClient.expect(b"done")
self._testwrite(b"cancelled line" + manhole.CTRL_C + b"done")
def finished(ign):
self._assertBuffer(
[b">>> cancelled line", b"KeyboardInterrupt", b">>> done"]
)
return done.addCallback(finished)
def test_interruptDuringContinuation(self):
"""
Sending ^C to Manhole while in a state where more input is required to
complete a statement should discard the entire ongoing statement and
reset the input prompt to the non-continuation prompt.
"""
continuing = self.recvlineClient.expect(b"things")
self._testwrite(b"(\nthings")
def gotContinuation(ignored):
self._assertBuffer([b">>> (", b"... things"])
interrupted = self.recvlineClient.expect(b">>> ")
self._testwrite(manhole.CTRL_C)
return interrupted
continuing.addCallback(gotContinuation)
def gotInterruption(ignored):
self._assertBuffer([b">>> (", b"... things", b"KeyboardInterrupt", b">>> "])
continuing.addCallback(gotInterruption)
return continuing
def test_ControlBackslash(self):
r"""
Evaluate cancelling with CTRL-\.
"""
self._testwrite(b"cancelled line")
partialLine = self.recvlineClient.expect(b"cancelled line")
def gotPartialLine(ign):
self._assertBuffer([b">>> cancelled line"])
self._testwrite(manhole.CTRL_BACKSLASH)
d = self.recvlineClient.onDisconnection
return self.assertFailure(d, error.ConnectionDone)
def gotClearedLine(ign):
self._assertBuffer([b""])
return partialLine.addCallback(gotPartialLine).addCallback(gotClearedLine)
@defer.inlineCallbacks
def test_controlD(self):
"""
A CTRL+D in the middle of a line doesn't close a connection,
but at the beginning of a line it does.
"""
self._testwrite(b"1 + 1")
yield self.recvlineClient.expect(rb"\+ 1")
self._assertBuffer([b">>> 1 + 1"])
self._testwrite(manhole.CTRL_D + b" + 1")
yield self.recvlineClient.expect(rb"\+ 1")
self._assertBuffer([b">>> 1 + 1 + 1"])
self._testwrite(b"\n")
yield self.recvlineClient.expect(b"3\n>>> ")
self._testwrite(manhole.CTRL_D)
d = self.recvlineClient.onDisconnection
yield self.assertFailure(d, error.ConnectionDone)
@defer.inlineCallbacks
def test_ControlL(self):
"""
CTRL+L is generally used as a redraw-screen command in terminal
applications. Manhole doesn't currently respect this usage of it,
but it should at least do something reasonable in response to this
event (rather than, say, eating your face).
"""
# Start off with a newline so that when we clear the display we can
# tell by looking for the missing first empty prompt line.
self._testwrite(b"\n1 + 1")
yield self.recvlineClient.expect(rb"\+ 1")
self._assertBuffer([b">>> ", b">>> 1 + 1"])
self._testwrite(manhole.CTRL_L + b" + 1")
yield self.recvlineClient.expect(rb"1 \+ 1 \+ 1")
self._assertBuffer([b">>> 1 + 1 + 1"])
def test_controlA(self):
"""
CTRL-A can be used as HOME - returning cursor to beginning of
current line buffer.
"""
self._testwrite(b'rint "hello"' + b"\x01" + b"p")
d = self.recvlineClient.expect(b'print "hello"')
def cb(ignore):
self._assertBuffer([b'>>> print "hello"'])
return d.addCallback(cb)
def test_controlE(self):
"""
CTRL-E can be used as END - setting cursor to end of current
line buffer.
"""
self._testwrite(b'rint "hello' + b"\x01" + b"p" + b"\x05" + b'"')
d = self.recvlineClient.expect(b'print "hello"')
def cb(ignore):
self._assertBuffer([b'>>> print "hello"'])
return d.addCallback(cb)
@defer.inlineCallbacks
def test_deferred(self):
"""
When a deferred is returned to the manhole REPL, it is displayed with
a sequence number, and when the deferred fires, the result is printed.
"""
self._testwrite(
b"from twisted.internet import defer, reactor\n"
b"d = defer.Deferred()\n"
b"d\n"
)
yield self.recvlineClient.expect(b"<Deferred #0>")
self._testwrite(b"c = reactor.callLater(0.1, d.callback, 'Hi!')\n")
yield self.recvlineClient.expect(b">>> ")
yield self.recvlineClient.expect(b"Deferred #0 called back: 'Hi!'\n>>> ")
self._assertBuffer(
[
b">>> from twisted.internet import defer, reactor",
b">>> d = defer.Deferred()",
b">>> d",
b"<Deferred #0>",
b">>> c = reactor.callLater(0.1, d.callback, 'Hi!')",
b"Deferred #0 called back: 'Hi!'",
b">>> ",
]
)
class ManholeLoopbackTelnetTests(_TelnetMixin, unittest.TestCase, ManholeLoopbackMixin):
"""
Test manhole loopback over Telnet.
"""
pass
class ManholeLoopbackSSHTests(_SSHMixin, unittest.TestCase, ManholeLoopbackMixin):
"""
Test manhole loopback over SSH.
"""
if ssh is None:
skip = "cryptography requirements missing"
class ManholeLoopbackStdioTests(_StdioMixin, unittest.TestCase, ManholeLoopbackMixin):
"""
Test manhole loopback over standard IO.
"""
if stdio is None:
skip = "Terminal requirements missing"
else:
serverProtocol = stdio.ConsoleManhole
class ManholeMainTests(unittest.TestCase):
"""
Test the I{main} method from the I{manhole} module.
"""
if stdio is None:
skip = "Terminal requirements missing"
def test_mainClassNotFound(self):
"""
Will raise an exception when called with an argument which is a
dotted patch which can not be imported..
"""
exception = self.assertRaises(
ValueError,
stdio.main,
argv=["no-such-class"],
)
self.assertEqual("Empty module name", exception.args[0])

View File

@@ -0,0 +1,122 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.manhole_tap}.
"""
from twisted.application.internet import StreamServerEndpointService
from twisted.application.service import MultiService
from twisted.conch import telnet
from twisted.cred import error
from twisted.cred.credentials import UsernamePassword
from twisted.python import usage
from twisted.python.reflect import requireModule
from twisted.trial.unittest import TestCase
cryptography = requireModule("cryptography")
if cryptography:
from twisted.conch import manhole_ssh, manhole_tap
class MakeServiceTests(TestCase):
"""
Tests for L{manhole_tap.makeService}.
"""
if not cryptography:
skip = "can't run without cryptography"
usernamePassword = (b"iamuser", b"thisispassword")
def setUp(self) -> None:
"""
Create a passwd-like file with a user.
"""
self.filename = self.mktemp()
with open(self.filename, "wb") as f:
f.write(b":".join(self.usernamePassword))
self.options = manhole_tap.Options()
def test_requiresPort(self) -> None:
"""
L{manhole_tap.makeService} requires either 'telnetPort' or 'sshPort' to
be given.
"""
with self.assertRaises(usage.UsageError) as e:
manhole_tap.Options().parseOptions([])
self.assertEqual(
e.exception.args[0],
("At least one of --telnetPort " "and --sshPort must be specified"),
)
def test_telnetPort(self) -> None:
"""
L{manhole_tap.makeService} will make a telnet service on the port
defined by C{--telnetPort}. It will not make a SSH service.
"""
self.options.parseOptions(["--telnetPort", "tcp:222"])
service = manhole_tap.makeService(self.options)
self.assertIsInstance(service, MultiService)
self.assertEqual(len(service.services), 1)
self.assertIsInstance(service.services[0], StreamServerEndpointService)
self.assertIsInstance(
service.services[0].factory.protocol, manhole_tap.makeTelnetProtocol
)
self.assertEqual(service.services[0].endpoint._port, 222)
def test_sshPort(self) -> None:
"""
L{manhole_tap.makeService} will make a SSH service on the port
defined by C{--sshPort}. It will not make a telnet service.
"""
# Why the sshKeyDir and sshKeySize params? To prevent it stomping over
# (or using!) the user's private key, we just make a super small one
# which will never be used in a temp directory.
self.options.parseOptions(
[
"--sshKeyDir",
self.mktemp(),
"--sshKeySize",
"1024",
"--sshPort",
"tcp:223",
]
)
service = manhole_tap.makeService(self.options)
self.assertIsInstance(service, MultiService)
self.assertEqual(len(service.services), 1)
self.assertIsInstance(service.services[0], StreamServerEndpointService)
self.assertIsInstance(service.services[0].factory, manhole_ssh.ConchFactory)
self.assertEqual(service.services[0].endpoint._port, 223)
def test_passwd(self) -> None:
"""
The C{--passwd} command-line option will load a passwd-like file.
"""
self.options.parseOptions(["--telnetPort", "tcp:22", "--passwd", self.filename])
service = manhole_tap.makeService(self.options)
portal = service.services[0].factory.protocol.portal
self.assertEqual(len(portal.checkers.keys()), 2)
# Ensure it's the passwd file we wanted by trying to authenticate
self.assertTrue(
self.successResultOf(
portal.login(
UsernamePassword(*self.usernamePassword),
None,
telnet.ITelnetProtocol,
)
)
)
self.assertIsInstance(
self.failureResultOf(
portal.login(
UsernamePassword(b"wrong", b"user"), None, telnet.ITelnetProtocol
)
).value,
error.UnauthorizedLogin,
)

View File

@@ -0,0 +1,44 @@
# -*- twisted.conch.test.test_mixin -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import annotations
from twisted.conch import mixin
from twisted.internet.testing import StringTransport
from twisted.trial import unittest
class TestBufferingProto(mixin.BufferingMixin):
scheduled = False
rescheduled = 0
transport: StringTransport
def schedule(self) -> object:
self.scheduled = True
return object()
def reschedule(self, token: object) -> None:
self.rescheduled += 1
class BufferingTests(unittest.TestCase):
def testBuffering(self) -> None:
p = TestBufferingProto()
t = p.transport = StringTransport()
self.assertFalse(p.scheduled)
L = [b"foo", b"bar", b"baz", b"quux"]
p.write(b"foo")
self.assertTrue(p.scheduled)
self.assertFalse(p.rescheduled)
for s in L:
n = p.rescheduled
p.write(s)
self.assertEqual(p.rescheduled, n + 1)
self.assertEqual(t.value(), b"")
p.flush()
self.assertEqual(t.value(), b"foo" + b"".join(L))

View File

@@ -0,0 +1,131 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.openssh_compat}.
"""
import os
from unittest import skipIf
from twisted.conch.ssh._kex import getDHGeneratorAndPrime
from twisted.conch.test import keydata
from twisted.python.filepath import FilePath
from twisted.python.reflect import requireModule
from twisted.test.test_process import MockOS
from twisted.trial.unittest import TestCase
doSkip = False
skipReason = ""
if requireModule("cryptography"):
from twisted.conch.openssh_compat.factory import OpenSSHFactory
else:
doSkip = True
skipReason = "Cannot run without cryptography"
if not hasattr(os, "geteuid"):
doSkip = True
skipReason = "geteuid/seteuid not available"
@skipIf(doSkip, skipReason)
class OpenSSHFactoryTests(TestCase):
"""
Tests for L{OpenSSHFactory}.
"""
def setUp(self) -> None:
self.factory = OpenSSHFactory()
self.keysDir = FilePath(self.mktemp())
self.keysDir.makedirs()
self.factory.dataRoot = self.keysDir.path
self.moduliDir = FilePath(self.mktemp())
self.moduliDir.makedirs()
self.factory.moduliRoot = self.moduliDir.path
self.keysDir.child("ssh_host_foo").setContent(b"foo")
self.keysDir.child("bar_key").setContent(b"foo")
self.keysDir.child("ssh_host_one_key").setContent(keydata.privateRSA_openssh)
self.keysDir.child("ssh_host_two_key").setContent(keydata.privateDSA_openssh)
self.keysDir.child("ssh_host_three_key").setContent(b"not a key content")
self.keysDir.child("ssh_host_one_key.pub").setContent(keydata.publicRSA_openssh)
self.moduliDir.child("moduli").setContent(
b"\n"
b"# $OpenBSD: moduli,v 1.xx 2016/07/26 12:34:56 jhacker Exp $i\n"
b"# Time Type Tests Tries Size Generator Modulus\n"
b"20030501000000 2 6 100 2047 2 "
b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74"
b"020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F1437"
b"4FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED"
b"EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF05"
b"98DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB"
b"9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B"
b"E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF695581718"
b"3995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF"
b"\n"
)
self.mockos = MockOS()
self.patch(os, "seteuid", self.mockos.seteuid)
self.patch(os, "setegid", self.mockos.setegid)
def test_getPublicKeys(self) -> None:
"""
L{OpenSSHFactory.getPublicKeys} should return the available public keys
in the data directory
"""
keys = self.factory.getPublicKeys()
self.assertEqual(len(keys), 1)
keyTypes = keys.keys()
self.assertEqual(list(keyTypes), [b"ssh-rsa"])
def test_getPrivateKeys(self) -> None:
"""
Will return the available private keys in the data directory, ignoring
key files which failed to be loaded.
"""
keys = self.factory.getPrivateKeys()
self.assertEqual(len(keys), 2)
keyTypes = keys.keys()
self.assertEqual(set(keyTypes), {b"ssh-rsa", b"ssh-dss"})
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_getPrivateKeysAsRoot(self) -> None:
"""
L{OpenSSHFactory.getPrivateKeys} should switch to root if the keys
aren't readable by the current user.
"""
keyFile = self.keysDir.child("ssh_host_two_key")
# Fake permission error by changing the mode
keyFile.chmod(0000)
self.addCleanup(keyFile.chmod, 0o777)
# And restore the right mode when seteuid is called
savedSeteuid = os.seteuid
def seteuid(euid: int) -> None:
keyFile.chmod(0o777)
return savedSeteuid(euid)
self.patch(os, "seteuid", seteuid)
keys = self.factory.getPrivateKeys()
self.assertEqual(len(keys), 2)
keyTypes = keys.keys()
self.assertEqual(set(keyTypes), {b"ssh-rsa", b"ssh-dss"})
self.assertEqual(self.mockos.seteuidCalls, [0, os.geteuid()])
self.assertEqual(self.mockos.setegidCalls, [0, os.getegid()])
def test_getPrimes(self) -> None:
"""
L{OpenSSHFactory.getPrimes} should return the available primes
in the moduli directory.
"""
primes = self.factory.getPrimes()
self.assertEqual(
primes,
{
2048: [getDHGeneratorAndPrime(b"diffie-hellman-group14-sha1")],
},
)

View File

@@ -0,0 +1,801 @@
# -*- test-case-name: twisted.conch.test.test_recvline -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.recvline} and fixtures for testing related
functionality.
"""
import os
import sys
from unittest import skipIf
from twisted.conch import recvline
from twisted.conch.insults import insults
from twisted.cred import portal
from twisted.internet import defer, error
from twisted.internet.testing import StringTransport
from twisted.python import components, filepath, reflect
from twisted.python.compat import iterbytes
from twisted.python.reflect import requireModule
from twisted.trial.unittest import SkipTest, TestCase
stdio = requireModule("twisted.conch.stdio")
properEnv = dict(os.environ)
properEnv["PYTHONPATH"] = os.pathsep.join(sys.path)
class ArrowsTests(TestCase):
def setUp(self):
self.underlyingTransport = StringTransport()
self.pt = insults.ServerProtocol()
self.p = recvline.HistoricRecvLine()
self.pt.protocolFactory = lambda: self.p
self.pt.factory = self
self.pt.makeConnection(self.underlyingTransport)
def test_printableCharacters(self):
"""
When L{HistoricRecvLine} receives a printable character,
it adds it to the current line buffer.
"""
self.p.keystrokeReceived(b"x", None)
self.p.keystrokeReceived(b"y", None)
self.p.keystrokeReceived(b"z", None)
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
def test_horizontalArrows(self):
"""
When L{HistoricRecvLine} receives a LEFT_ARROW or
RIGHT_ARROW keystroke it moves the cursor left or right
in the current line buffer, respectively.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"xyz"):
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
kR(self.pt.LEFT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"xy", b"z"))
kR(self.pt.LEFT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"x", b"yz"))
kR(self.pt.LEFT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"", b"xyz"))
kR(self.pt.LEFT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"", b"xyz"))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"x", b"yz"))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"xy", b"z"))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
kR(self.pt.RIGHT_ARROW)
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
def test_newline(self):
"""
When {HistoricRecvLine} receives a newline, it adds the current
line buffer to the end of its history buffer.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"xyz\nabc\n123\n"):
kR(ch)
self.assertEqual(self.p.currentHistoryBuffer(), ((b"xyz", b"abc", b"123"), ()))
kR(b"c")
kR(b"b")
kR(b"a")
self.assertEqual(self.p.currentHistoryBuffer(), ((b"xyz", b"abc", b"123"), ()))
kR(b"\n")
self.assertEqual(
self.p.currentHistoryBuffer(), ((b"xyz", b"abc", b"123", b"cba"), ())
)
def test_verticalArrows(self):
"""
When L{HistoricRecvLine} receives UP_ARROW or DOWN_ARROW
keystrokes it move the current index in the current history
buffer up or down, and resets the current line buffer to the
previous or next line in history, respectively for each.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"xyz\nabc\n123\n"):
kR(ch)
self.assertEqual(self.p.currentHistoryBuffer(), ((b"xyz", b"abc", b"123"), ()))
self.assertEqual(self.p.currentLineBuffer(), (b"", b""))
kR(self.pt.UP_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(), ((b"xyz", b"abc"), (b"123",)))
self.assertEqual(self.p.currentLineBuffer(), (b"123", b""))
kR(self.pt.UP_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(), ((b"xyz",), (b"abc", b"123")))
self.assertEqual(self.p.currentLineBuffer(), (b"abc", b""))
kR(self.pt.UP_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(), ((), (b"xyz", b"abc", b"123")))
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
kR(self.pt.UP_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(), ((), (b"xyz", b"abc", b"123")))
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
for i in range(4):
kR(self.pt.DOWN_ARROW)
self.assertEqual(self.p.currentHistoryBuffer(), ((b"xyz", b"abc", b"123"), ()))
def test_home(self):
"""
When L{HistoricRecvLine} receives a HOME keystroke it moves the
cursor to the beginning of the current line buffer.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"hello, world"):
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), (b"hello, world", b""))
kR(self.pt.HOME)
self.assertEqual(self.p.currentLineBuffer(), (b"", b"hello, world"))
def test_end(self):
"""
When L{HistoricRecvLine} receives an END keystroke it moves the cursor
to the end of the current line buffer.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"hello, world"):
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), (b"hello, world", b""))
kR(self.pt.HOME)
kR(self.pt.END)
self.assertEqual(self.p.currentLineBuffer(), (b"hello, world", b""))
def test_backspace(self):
"""
When L{HistoricRecvLine} receives a BACKSPACE keystroke it deletes
the character immediately before the cursor.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"xyz"):
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
kR(self.pt.BACKSPACE)
self.assertEqual(self.p.currentLineBuffer(), (b"xy", b""))
kR(self.pt.LEFT_ARROW)
kR(self.pt.BACKSPACE)
self.assertEqual(self.p.currentLineBuffer(), (b"", b"y"))
kR(self.pt.BACKSPACE)
self.assertEqual(self.p.currentLineBuffer(), (b"", b"y"))
def test_delete(self):
"""
When L{HistoricRecvLine} receives a DELETE keystroke, it
delets the character immediately after the cursor.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"xyz"):
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), (b"xyz", b""))
kR(self.pt.LEFT_ARROW)
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), (b"xy", b""))
kR(self.pt.LEFT_ARROW)
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), (b"x", b""))
kR(self.pt.LEFT_ARROW)
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), (b"", b""))
kR(self.pt.DELETE)
self.assertEqual(self.p.currentLineBuffer(), (b"", b""))
def test_insert(self):
"""
When not in INSERT mode, L{HistoricRecvLine} inserts the typed
character at the cursor before the next character.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"xyz"):
kR(ch)
kR(self.pt.LEFT_ARROW)
kR(b"A")
self.assertEqual(self.p.currentLineBuffer(), (b"xyA", b"z"))
kR(self.pt.LEFT_ARROW)
kR(b"B")
self.assertEqual(self.p.currentLineBuffer(), (b"xyB", b"Az"))
def test_typeover(self):
"""
When in INSERT mode and upon receiving a keystroke with a printable
character, L{HistoricRecvLine} replaces the character at
the cursor with the typed character rather than inserting before.
Ah, the ironies of INSERT mode.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
for ch in iterbytes(b"xyz"):
kR(ch)
kR(self.pt.INSERT)
kR(self.pt.LEFT_ARROW)
kR(b"A")
self.assertEqual(self.p.currentLineBuffer(), (b"xyA", b""))
kR(self.pt.LEFT_ARROW)
kR(b"B")
self.assertEqual(self.p.currentLineBuffer(), (b"xyB", b""))
def test_unprintableCharacters(self):
"""
When L{HistoricRecvLine} receives a keystroke for an unprintable
function key with no assigned behavior, the line buffer is unmodified.
"""
kR = lambda ch: self.p.keystrokeReceived(ch, None)
pt = self.pt
for ch in (
pt.F1,
pt.F2,
pt.F3,
pt.F4,
pt.F5,
pt.F6,
pt.F7,
pt.F8,
pt.F9,
pt.F10,
pt.F11,
pt.F12,
pt.PGUP,
pt.PGDN,
):
kR(ch)
self.assertEqual(self.p.currentLineBuffer(), (b"", b""))
from twisted.conch import telnet
from twisted.conch.insults import helper
from twisted.conch.test.loopback import LoopbackRelay
class EchoServer(recvline.HistoricRecvLine):
def lineReceived(self, line):
self.terminal.write(line + b"\n" + self.ps[self.pn])
# An insults API for this would be nice.
left = b"\x1b[D"
right = b"\x1b[C"
up = b"\x1b[A"
down = b"\x1b[B"
insert = b"\x1b[2~"
home = b"\x1b[1~"
delete = b"\x1b[3~"
end = b"\x1b[4~"
backspace = b"\x7f"
from twisted.cred import checkers
try:
from twisted.conch.manhole_ssh import (
ConchFactory,
TerminalRealm,
TerminalSession,
TerminalSessionTransport,
TerminalUser,
)
from twisted.conch.ssh import (
channel,
connection,
keys,
session,
transport,
userauth,
)
except ImportError:
ssh = False
else:
ssh = True
class SessionChannel(channel.SSHChannel):
name = b"session"
def __init__(
self, protocolFactory, protocolArgs, protocolKwArgs, width, height, *a, **kw
):
channel.SSHChannel.__init__(self, *a, **kw)
self.protocolFactory = protocolFactory
self.protocolArgs = protocolArgs
self.protocolKwArgs = protocolKwArgs
self.width = width
self.height = height
def channelOpen(self, data):
term = session.packRequest_pty_req(
b"vt102", (self.height, self.width, 0, 0), b""
)
self.conn.sendRequest(self, b"pty-req", term)
self.conn.sendRequest(self, b"shell", b"")
self._protocolInstance = self.protocolFactory(
*self.protocolArgs, **self.protocolKwArgs
)
self._protocolInstance.factory = self
self._protocolInstance.makeConnection(self)
def closed(self):
self._protocolInstance.connectionLost(error.ConnectionDone())
def dataReceived(self, data):
self._protocolInstance.dataReceived(data)
class TestConnection(connection.SSHConnection):
def __init__(
self, protocolFactory, protocolArgs, protocolKwArgs, width, height, *a, **kw
):
connection.SSHConnection.__init__(self, *a, **kw)
self.protocolFactory = protocolFactory
self.protocolArgs = protocolArgs
self.protocolKwArgs = protocolKwArgs
self.width = width
self.height = height
def serviceStarted(self):
self.__channel = SessionChannel(
self.protocolFactory,
self.protocolArgs,
self.protocolKwArgs,
self.width,
self.height,
)
self.openChannel(self.__channel)
def write(self, data):
return self.__channel.write(data)
class TestAuth(userauth.SSHUserAuthClient):
def __init__(self, username, password, *a, **kw):
userauth.SSHUserAuthClient.__init__(self, username, *a, **kw)
self.password = password
def getPassword(self):
return defer.succeed(self.password)
class TestTransport(transport.SSHClientTransport):
def __init__(
self,
protocolFactory,
protocolArgs,
protocolKwArgs,
username,
password,
width,
height,
*a,
**kw,
):
self.protocolFactory = protocolFactory
self.protocolArgs = protocolArgs
self.protocolKwArgs = protocolKwArgs
self.username = username
self.password = password
self.width = width
self.height = height
def verifyHostKey(self, hostKey, fingerprint):
return defer.succeed(True)
def connectionSecure(self):
self.__connection = TestConnection(
self.protocolFactory,
self.protocolArgs,
self.protocolKwArgs,
self.width,
self.height,
)
self.requestService(
TestAuth(self.username, self.password, self.__connection)
)
def write(self, data):
return self.__connection.write(data)
class TestSessionTransport(TerminalSessionTransport):
def protocolFactory(self):
return self.avatar.conn.transport.factory.serverProtocol()
class TestSession(TerminalSession):
transportFactory = TestSessionTransport
class TestUser(TerminalUser):
pass
components.registerAdapter(TestSession, TestUser, session.ISession)
class NotifyingExpectableBuffer(helper.ExpectableBuffer):
def __init__(self):
self.onConnection = defer.Deferred()
self.onDisconnection = defer.Deferred()
def connectionMade(self):
helper.ExpectableBuffer.connectionMade(self)
self.onConnection.callback(self)
def connectionLost(self, reason):
self.onDisconnection.errback(reason)
class _BaseMixin:
WIDTH = 80
HEIGHT = 24
def _assertBuffer(self, lines):
receivedLines = self.recvlineClient.__bytes__().splitlines()
expectedLines = lines + ([b""] * (self.HEIGHT - len(lines) - 1))
self.assertEqual(receivedLines, expectedLines)
def _trivialTest(self, inputLine, output):
done = self.recvlineClient.expect(b"done")
self._testwrite(inputLine)
def finished(ign):
self._assertBuffer(output)
return done.addCallback(finished)
class _SSHMixin(_BaseMixin):
def setUp(self):
if not ssh:
raise SkipTest(
"cryptography requirements missing, can't run historic "
"recvline tests over ssh"
)
u, p = b"testuser", b"testpass"
rlm = TerminalRealm()
rlm.userFactory = TestUser
rlm.chainedProtocolFactory = lambda: insultsServer
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
checker.addUser(u, p)
ptl = portal.Portal(rlm)
ptl.registerChecker(checker)
sshFactory = ConchFactory(ptl)
sshKey = keys._getPersistentRSAKey(
filepath.FilePath(self.mktemp()), keySize=1024
)
sshFactory.publicKeys[b"ssh-rsa"] = sshKey
sshFactory.privateKeys[b"ssh-rsa"] = sshKey
sshFactory.serverProtocol = self.serverProtocol
sshFactory.startFactory()
recvlineServer = self.serverProtocol()
insultsServer = insults.ServerProtocol(lambda: recvlineServer)
sshServer = sshFactory.buildProtocol(None)
clientTransport = LoopbackRelay(sshServer)
recvlineClient = NotifyingExpectableBuffer()
insultsClient = insults.ClientProtocol(lambda: recvlineClient)
sshClient = TestTransport(
lambda: insultsClient, (), {}, u, p, self.WIDTH, self.HEIGHT
)
serverTransport = LoopbackRelay(sshClient)
sshClient.makeConnection(clientTransport)
sshServer.makeConnection(serverTransport)
self.recvlineClient = recvlineClient
self.sshClient = sshClient
self.sshServer = sshServer
self.clientTransport = clientTransport
self.serverTransport = serverTransport
return recvlineClient.onConnection
def _testwrite(self, data):
self.sshClient.write(data)
from twisted.conch.test import test_telnet
class TestInsultsClientProtocol(insults.ClientProtocol, test_telnet.TestProtocol):
pass
class TestInsultsServerProtocol(insults.ServerProtocol, test_telnet.TestProtocol):
pass
class _TelnetMixin(_BaseMixin):
def setUp(self):
recvlineServer = self.serverProtocol()
insultsServer = TestInsultsServerProtocol(lambda: recvlineServer)
telnetServer = telnet.TelnetTransport(lambda: insultsServer)
clientTransport = LoopbackRelay(telnetServer)
recvlineClient = NotifyingExpectableBuffer()
insultsClient = TestInsultsClientProtocol(lambda: recvlineClient)
telnetClient = telnet.TelnetTransport(lambda: insultsClient)
serverTransport = LoopbackRelay(telnetClient)
telnetClient.makeConnection(clientTransport)
telnetServer.makeConnection(serverTransport)
serverTransport.clearBuffer()
clientTransport.clearBuffer()
self.recvlineClient = recvlineClient
self.telnetClient = telnetClient
self.clientTransport = clientTransport
self.serverTransport = serverTransport
return recvlineClient.onConnection
def _testwrite(self, data):
self.telnetClient.write(data)
class _StdioMixin(_BaseMixin):
def setUp(self):
# A memory-only terminal emulator, into which the server will
# write things and make other state changes. What ends up
# here is basically what a user would have seen on their
# screen.
testTerminal = NotifyingExpectableBuffer()
# An insults client protocol which will translate bytes
# received from the child process into keystroke commands for
# an ITerminalProtocol.
insultsClient = insults.ClientProtocol(lambda: testTerminal)
# A process protocol which will translate stdout and stderr
# received from the child process to dataReceived calls and
# error reporting on an insults client protocol.
processClient = stdio.TerminalProcessProtocol(insultsClient)
# Run twisted/conch/stdio.py with the name of a class
# implementing ITerminalProtocol. This class will be used to
# handle bytes we send to the child process.
exe = sys.executable
module = stdio.__file__
if module.endswith(".pyc") or module.endswith(".pyo"):
module = module[:-1]
args = [exe, module, reflect.qual(self.serverProtocol)]
from twisted.internet import reactor
clientTransport = reactor.spawnProcess(
processClient, exe, args, env=properEnv, usePTY=True
)
self.recvlineClient = self.testTerminal = testTerminal
self.processClient = processClient
self.clientTransport = clientTransport
# Wait for the process protocol and test terminal to become
# connected before proceeding. The former should always
# happen first, but it doesn't hurt to be safe.
return defer.gatherResults(
filter(None, [processClient.onConnection, testTerminal.expect(b">>> ")])
)
def tearDown(self):
# Kill the child process. We're done with it.
try:
self.clientTransport.signalProcess("KILL")
except (error.ProcessExitedAlready, OSError):
pass
def trap(failure):
failure.trap(error.ProcessTerminated)
self.assertIsNone(failure.value.exitCode)
self.assertEqual(failure.value.status, 9)
return self.testTerminal.onDisconnection.addErrback(trap)
def _testwrite(self, data):
self.clientTransport.write(data)
class RecvlineLoopbackMixin:
serverProtocol = EchoServer
def testSimple(self):
return self._trivialTest(
b"first line\ndone", [b">>> first line", b"first line", b">>> done"]
)
def testLeftArrow(self):
return self._trivialTest(
insert + b"first line" + left * 4 + b"xxxx\ndone",
[b">>> first xxxx", b"first xxxx", b">>> done"],
)
def testRightArrow(self):
return self._trivialTest(
insert + b"right line" + left * 4 + right * 2 + b"xx\ndone",
[b">>> right lixx", b"right lixx", b">>> done"],
)
def testBackspace(self):
return self._trivialTest(
b"second line" + backspace * 4 + b"xxxx\ndone",
[b">>> second xxxx", b"second xxxx", b">>> done"],
)
def testDelete(self):
return self._trivialTest(
b"delete xxxx" + left * 4 + delete * 4 + b"line\ndone",
[b">>> delete line", b"delete line", b">>> done"],
)
def testInsert(self):
return self._trivialTest(
b"third ine" + left * 3 + b"l\ndone",
[b">>> third line", b"third line", b">>> done"],
)
def testTypeover(self):
return self._trivialTest(
b"fourth xine" + left * 4 + insert + b"l\ndone",
[b">>> fourth line", b"fourth line", b">>> done"],
)
def testHome(self):
return self._trivialTest(
insert + b"blah line" + home + b"home\ndone",
[b">>> home line", b"home line", b">>> done"],
)
def testEnd(self):
return self._trivialTest(
b"end " + left * 4 + end + b"line\ndone",
[b">>> end line", b"end line", b">>> done"],
)
class RecvlineLoopbackTelnetTests(_TelnetMixin, TestCase, RecvlineLoopbackMixin):
pass
class RecvlineLoopbackSSHTests(_SSHMixin, TestCase, RecvlineLoopbackMixin):
pass
@skipIf(not stdio, "Terminal requirements missing, can't run recvline tests over stdio")
class RecvlineLoopbackStdioTests(_StdioMixin, TestCase, RecvlineLoopbackMixin):
pass
class HistoricRecvlineLoopbackMixin:
serverProtocol = EchoServer
def testUpArrow(self):
return self._trivialTest(
b"first line\n" + up + b"\ndone",
[
b">>> first line",
b"first line",
b">>> first line",
b"first line",
b">>> done",
],
)
def test_DownArrowToPartialLineInHistory(self):
"""
Pressing down arrow to visit an entry that was added to the
history by pressing the up arrow instead of return does not
raise a L{TypeError}.
@see: U{http://twistedmatrix.com/trac/ticket/9031}
@return: A L{defer.Deferred} that fires when C{b"done"} is
echoed back.
"""
return self._trivialTest(
b"first line\n" + b"partial line" + up + down + b"\ndone",
[
b">>> first line",
b"first line",
b">>> partial line",
b"partial line",
b">>> done",
],
)
def testDownArrow(self):
return self._trivialTest(
b"first line\nsecond line\n" + up * 2 + down + b"\ndone",
[
b">>> first line",
b"first line",
b">>> second line",
b"second line",
b">>> second line",
b"second line",
b">>> done",
],
)
class HistoricRecvlineLoopbackTelnetTests(
_TelnetMixin, TestCase, HistoricRecvlineLoopbackMixin
):
pass
class HistoricRecvlineLoopbackSSHTests(
_SSHMixin, TestCase, HistoricRecvlineLoopbackMixin
):
pass
@skipIf(
not stdio,
"Terminal requirements missing, " "can't run historic recvline tests over stdio",
)
class HistoricRecvlineLoopbackStdioTests(
_StdioMixin, TestCase, HistoricRecvlineLoopbackMixin
):
pass
class TransportSequenceTests(TestCase):
"""
L{twisted.conch.recvline.TransportSequence}
"""
def test_invalidSequence(self):
"""
Initializing a L{recvline.TransportSequence} with no args
raises an assertion.
"""
self.assertRaises(AssertionError, recvline.TransportSequence)

View File

@@ -0,0 +1,70 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the command-line interfaces to conch.
"""
from unittest import skipIf
from twisted.python.reflect import requireModule
from twisted.python.test.test_shellcomp import ZshScriptTestMixin
from twisted.scripts.test.test_scripts import ScriptTestsMixin
from twisted.trial.unittest import TestCase
doSkip = False
skipReason = ""
if not requireModule("cryptography"):
doSkip = True
cryptoSkip = "can't run w/o cryptography"
if not requireModule("tty"):
doSkip = True
ttySkip = "can't run w/o tty"
try:
import tkinter
except ImportError:
doSkip = True
skipReason = "can't run w/o tkinter"
else:
try:
tkinter.Tk().destroy()
except (tkinter.TclError, RuntimeError) as e:
# On GitHub Action the macOS Python might not support the version of TK
# provided by the OS and it will raise a RuntimeError
# See: https://github.com/actions/setup-python/issues/649
doSkip = True
skipReason = "Can't test Tkinter: " + str(e)
@skipIf(doSkip, skipReason)
class ScriptTests(TestCase, ScriptTestsMixin):
"""
Tests for the Conch scripts.
"""
def test_conch(self) -> None:
self.scriptTest("conch/conch")
def test_cftp(self) -> None:
self.scriptTest("conch/cftp")
def test_ckeygen(self) -> None:
self.scriptTest("conch/ckeygen")
def test_tkconch(self) -> None:
self.scriptTest("conch/tkconch")
class ZshIntegrationTests(TestCase, ZshScriptTestMixin):
"""
Test that zsh completion functions are generated without error
"""
generateFor = [
("conch", "twisted.conch.scripts.conch.ClientOptions"),
("cftp", "twisted.conch.scripts.cftp.ClientOptions"),
("ckeygen", "twisted.conch.scripts.ckeygen.GeneralOptions"),
("tkconch", "twisted.conch.scripts.tkconch.GeneralOptions"),
]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,150 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.tap}.
"""
from typing import Any, Tuple, Union
from twisted.application.internet import StreamServerEndpointService
from twisted.cred import error
from twisted.cred.checkers import FilePasswordDB, ICredentialsChecker
from twisted.cred.credentials import ISSHPrivateKey, IUsernamePassword, UsernamePassword
from twisted.internet.defer import Deferred
from twisted.python.reflect import requireModule
from twisted.trial.unittest import TestCase
cryptography = requireModule("cryptography")
unix = requireModule("twisted.conch.unix")
if cryptography and unix:
from twisted.conch import tap
from twisted.conch.openssh_compat.factory import OpenSSHFactory
class MakeServiceTests(TestCase):
"""
Tests for L{tap.makeService}.
"""
if not cryptography:
skip = "can't run without cryptography"
if not unix:
skip = "can't run on non-posix computers"
usernamePassword = (b"iamuser", b"thisispassword")
def setUp(self) -> None:
"""
Create a file with two users.
"""
self.filename = self.mktemp()
with open(self.filename, "wb+") as f:
f.write(b":".join(self.usernamePassword))
self.options = tap.Options()
def test_basic(self) -> None:
"""
L{tap.makeService} returns a L{StreamServerEndpointService} instance
running on TCP port 22, and the linked protocol factory is an instance
of L{OpenSSHFactory}.
"""
config = tap.Options()
service = tap.makeService(config)
self.assertIsInstance(service, StreamServerEndpointService)
self.assertEqual(service.endpoint._port, 22)
self.assertIsInstance(service.factory, OpenSSHFactory)
def test_defaultAuths(self) -> None:
"""
Make sure that if the C{--auth} command-line option is not passed,
the default checkers are (for backwards compatibility): SSH and UNIX
"""
numCheckers = 2
self.assertIn(
ISSHPrivateKey,
self.options["credInterfaces"],
"SSH should be one of the default checkers",
)
self.assertIn(
IUsernamePassword,
self.options["credInterfaces"],
"UNIX should be one of the default checkers",
)
self.assertEqual(
numCheckers,
len(self.options["credCheckers"]),
"There should be %d checkers by default" % (numCheckers,),
)
def test_authAdded(self) -> None:
"""
The C{--auth} command-line option will add a checker to the list of
checkers, and it should be the only auth checker
"""
self.options.parseOptions(["--auth", "file:" + self.filename])
self.assertEqual(len(self.options["credCheckers"]), 1)
def test_multipleAuthAdded(self) -> None:
"""
Multiple C{--auth} command-line options will add all checkers specified
to the list ofcheckers, and there should only be the specified auth
checkers (no default checkers).
"""
self.options.parseOptions(
[
"--auth",
"file:" + self.filename,
"--auth",
"memory:testuser:testpassword",
]
)
self.assertEqual(len(self.options["credCheckers"]), 2)
def test_authFailure(self) -> Any:
"""
The checker created by the C{--auth} command-line option returns a
L{Deferred} that fails with L{UnauthorizedLogin} when
presented with credentials that are unknown to that checker.
"""
self.options.parseOptions(["--auth", "file:" + self.filename])
checker: FilePasswordDB = self.options["credCheckers"][-1]
self.assertIsInstance(checker, FilePasswordDB)
invalid = UsernamePassword(self.usernamePassword[0], b"fake")
# Wrong password should raise error
return self.assertFailure(
checker.requestAvatarId(invalid), error.UnauthorizedLogin
)
def test_authSuccess(self) -> Deferred[None]:
"""
The checker created by the C{--auth} command-line option returns a
L{Deferred} that returns the avatar id when presented with credentials
that are known to that checker.
"""
self.options.parseOptions(["--auth", "file:" + self.filename])
checker: ICredentialsChecker = self.options["credCheckers"][-1]
correct = UsernamePassword(*self.usernamePassword)
d = checker.requestAvatarId(correct)
def checkSuccess(username: Union[bytes, Tuple[()]]) -> None:
self.assertEqual(username, correct.username)
return d.addCallback(checkSuccess)
def test_checkers(self) -> None:
"""
The L{OpenSSHFactory} built by L{tap.makeService} has a portal with
L{ISSHPrivateKey} and L{IUsernamePassword} interfaces registered as
checkers.
"""
config = tap.Options()
service = tap.makeService(config)
portal = service.factory.portal
self.assertEqual(
set(portal.checkers.keys()), {ISSHPrivateKey, IUsernamePassword}
)

View File

@@ -0,0 +1,778 @@
# -*- test-case-name: twisted.conch.test.test_telnet -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.telnet}.
"""
from zope.interface import implementer
from zope.interface.verify import verifyObject
from twisted.conch import telnet
from twisted.internet import defer
from twisted.python.compat import iterbytes
from twisted.test import proto_helpers
from twisted.trial import unittest
@implementer(telnet.ITelnetProtocol)
class TestProtocol:
localEnableable = ()
remoteEnableable = ()
def __init__(self):
self.data = b""
self.subcmd = []
self.calls = []
self.enabledLocal = []
self.enabledRemote = []
self.disabledLocal = []
self.disabledRemote = []
def makeConnection(self, transport):
d = transport.negotiationMap = {}
d[b"\x12"] = self.neg_TEST_COMMAND
d = transport.commandMap = transport.commandMap.copy()
for cmd in ("EOR", "NOP", "DM", "BRK", "IP", "AO", "AYT", "EC", "EL", "GA"):
d[getattr(telnet, cmd)] = lambda arg, cmd=cmd: self.calls.append(cmd)
def dataReceived(self, data):
self.data += data
def connectionLost(self, reason):
pass
def neg_TEST_COMMAND(self, payload):
self.subcmd = payload
def enableLocal(self, option):
if option in self.localEnableable:
self.enabledLocal.append(option)
return True
return False
def disableLocal(self, option):
self.disabledLocal.append(option)
def enableRemote(self, option):
if option in self.remoteEnableable:
self.enabledRemote.append(option)
return True
return False
def disableRemote(self, option):
self.disabledRemote.append(option)
def connectionMade(self):
# IProtocol.connectionMade
pass
def unhandledCommand(self, command, argument):
# ITelnetProtocol.unhandledCommand
pass
def unhandledSubnegotiation(self, command, data):
# ITelnetProtocol.unhandledSubnegotiation
pass
class InterfacesTests(unittest.TestCase):
def test_interface(self):
"""
L{telnet.TelnetProtocol} implements L{telnet.ITelnetProtocol}
"""
p = telnet.TelnetProtocol()
verifyObject(telnet.ITelnetProtocol, p)
class TelnetTransportTests(unittest.TestCase):
"""
Tests for L{telnet.TelnetTransport}.
"""
def setUp(self):
self.p = telnet.TelnetTransport(TestProtocol)
self.t = proto_helpers.StringTransport()
self.p.makeConnection(self.t)
def testRegularBytes(self):
# Just send a bunch of bytes. None of these do anything
# with telnet. They should pass right through to the
# application layer.
h = self.p.protocol
L = [
b"here are some bytes la la la",
b"some more arrive here",
b"lots of bytes to play with",
b"la la la",
b"ta de da",
b"dum",
]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.data, b"".join(L))
def testNewlineHandling(self):
# Send various kinds of newlines and make sure they get translated
# into \n.
h = self.p.protocol
L = [
b"here is the first line\r\n",
b"here is the second line\r\0",
b"here is the third line\r\n",
b"here is the last line\r\0",
]
for b in L:
self.p.dataReceived(b)
self.assertEqual(
h.data,
L[0][:-2]
+ b"\n"
+ L[1][:-2]
+ b"\r"
+ L[2][:-2]
+ b"\n"
+ L[3][:-2]
+ b"\r",
)
def testIACEscape(self):
# Send a bunch of bytes and a couple quoted \xFFs. Unquoted,
# \xFF is a telnet command. Quoted, one of them from each pair
# should be passed through to the application layer.
h = self.p.protocol
L = [
b"here are some bytes\xff\xff with an embedded IAC",
b"and here is a test of a border escape\xff",
b"\xff did you get that IAC?",
]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.data, b"".join(L).replace(b"\xff\xff", b"\xff"))
def _simpleCommandTest(self, cmdName):
# Send a single simple telnet command and make sure
# it gets noticed and the appropriate method gets
# called.
h = self.p.protocol
cmd = telnet.IAC + getattr(telnet, cmdName)
L = [b"Here's some bytes, tra la la", b"But ono!" + cmd + b" an interrupt"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.calls, [cmdName])
self.assertEqual(h.data, b"".join(L).replace(cmd, b""))
def testInterrupt(self):
self._simpleCommandTest("IP")
def testEndOfRecord(self):
self._simpleCommandTest("EOR")
def testNoOperation(self):
self._simpleCommandTest("NOP")
def testDataMark(self):
self._simpleCommandTest("DM")
def testBreak(self):
self._simpleCommandTest("BRK")
def testAbortOutput(self):
self._simpleCommandTest("AO")
def testAreYouThere(self):
self._simpleCommandTest("AYT")
def testEraseCharacter(self):
self._simpleCommandTest("EC")
def testEraseLine(self):
self._simpleCommandTest("EL")
def testGoAhead(self):
self._simpleCommandTest("GA")
def testSubnegotiation(self):
# Send a subnegotiation command and make sure it gets
# parsed and that the correct method is called.
h = self.p.protocol
cmd = telnet.IAC + telnet.SB + b"\x12hello world" + telnet.IAC + telnet.SE
L = [b"These are some bytes but soon" + cmd, b"there will be some more"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.data, b"".join(L).replace(cmd, b""))
self.assertEqual(h.subcmd, list(iterbytes(b"hello world")))
def testSubnegotiationWithEmbeddedSE(self):
# Send a subnegotiation command with an embedded SE. Make sure
# that SE gets passed to the correct method.
h = self.p.protocol
cmd = telnet.IAC + telnet.SB + b"\x12" + telnet.SE + telnet.IAC + telnet.SE
L = [b"Some bytes are here" + cmd + b"and here", b"and here"]
for b in L:
self.p.dataReceived(b)
self.assertEqual(h.data, b"".join(L).replace(cmd, b""))
self.assertEqual(h.subcmd, [telnet.SE])
def testBoundarySubnegotiation(self):
# Send a subnegotiation command. Split it at every possible byte boundary
# and make sure it always gets parsed and that it is passed to the correct
# method.
cmd = (
telnet.IAC
+ telnet.SB
+ b"\x12"
+ telnet.SE
+ b"hello"
+ telnet.IAC
+ telnet.SE
)
for i in range(len(cmd)):
h = self.p.protocol = TestProtocol()
h.makeConnection(self.p)
a, b = cmd[:i], cmd[i:]
L = [b"first part" + a, b + b"last part"]
for data in L:
self.p.dataReceived(data)
self.assertEqual(h.data, b"".join(L).replace(cmd, b""))
self.assertEqual(h.subcmd, [telnet.SE] + list(iterbytes(b"hello")))
def _enabledHelper(self, o, eL=[], eR=[], dL=[], dR=[]):
self.assertEqual(o.enabledLocal, eL)
self.assertEqual(o.enabledRemote, eR)
self.assertEqual(o.disabledLocal, dL)
self.assertEqual(o.disabledRemote, dR)
def testRefuseWill(self):
# Try to enable an option. The server should refuse to enable it.
cmd = telnet.IAC + telnet.WILL + b"\x12"
data = b"surrounding bytes" + cmd + b"to spice things up"
self.p.dataReceived(data)
self.assertEqual(self.p.protocol.data, data.replace(cmd, b""))
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + b"\x12")
self._enabledHelper(self.p.protocol)
def testRefuseDo(self):
# Try to enable an option. The server should refuse to enable it.
cmd = telnet.IAC + telnet.DO + b"\x12"
data = b"surrounding bytes" + cmd + b"to spice things up"
self.p.dataReceived(data)
self.assertEqual(self.p.protocol.data, data.replace(cmd, b""))
self.assertEqual(self.t.value(), telnet.IAC + telnet.WONT + b"\x12")
self._enabledHelper(self.p.protocol)
def testAcceptDo(self):
# Try to enable an option. The option is in our allowEnable
# list, so we will allow it to be enabled.
cmd = telnet.IAC + telnet.DO + b"\x19"
data = b"padding" + cmd + b"trailer"
h = self.p.protocol
h.localEnableable = (b"\x19",)
self.p.dataReceived(data)
self.assertEqual(self.t.value(), telnet.IAC + telnet.WILL + b"\x19")
self._enabledHelper(h, eL=[b"\x19"])
def testAcceptWill(self):
# Same as testAcceptDo, but reversed.
cmd = telnet.IAC + telnet.WILL + b"\x91"
data = b"header" + cmd + b"padding"
h = self.p.protocol
h.remoteEnableable = (b"\x91",)
self.p.dataReceived(data)
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + b"\x91")
self._enabledHelper(h, eR=[b"\x91"])
def testAcceptWont(self):
# Try to disable an option. The server must allow any option to
# be disabled at any time. Make sure it disables it and sends
# back an acknowledgement of this.
cmd = telnet.IAC + telnet.WONT + b"\x29"
# Jimmy it - after these two lines, the server will be in a state
# such that it believes the option to have been previously enabled
# via normal negotiation.
s = self.p.getOptionState(b"\x29")
s.him.state = "yes"
data = b"fiddle dee" + cmd
self.p.dataReceived(data)
self.assertEqual(self.p.protocol.data, data.replace(cmd, b""))
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + b"\x29")
self.assertEqual(s.him.state, "no")
self._enabledHelper(self.p.protocol, dR=[b"\x29"])
def testAcceptDont(self):
# Try to disable an option. The server must allow any option to
# be disabled at any time. Make sure it disables it and sends
# back an acknowledgement of this.
cmd = telnet.IAC + telnet.DONT + b"\x29"
# Jimmy it - after these two lines, the server will be in a state
# such that it believes the option to have beenp previously enabled
# via normal negotiation.
s = self.p.getOptionState(b"\x29")
s.us.state = "yes"
data = b"fiddle dum " + cmd
self.p.dataReceived(data)
self.assertEqual(self.p.protocol.data, data.replace(cmd, b""))
self.assertEqual(self.t.value(), telnet.IAC + telnet.WONT + b"\x29")
self.assertEqual(s.us.state, "no")
self._enabledHelper(self.p.protocol, dL=[b"\x29"])
def testIgnoreWont(self):
# Try to disable an option. The option is already disabled. The
# server should send nothing in response to this.
cmd = telnet.IAC + telnet.WONT + b"\x47"
data = b"dum de dum" + cmd + b"tra la la"
self.p.dataReceived(data)
self.assertEqual(self.p.protocol.data, data.replace(cmd, b""))
self.assertEqual(self.t.value(), b"")
self._enabledHelper(self.p.protocol)
def testIgnoreDont(self):
# Try to disable an option. The option is already disabled. The
# server should send nothing in response to this. Doing so could
# lead to a negotiation loop.
cmd = telnet.IAC + telnet.DONT + b"\x47"
data = b"dum de dum" + cmd + b"tra la la"
self.p.dataReceived(data)
self.assertEqual(self.p.protocol.data, data.replace(cmd, b""))
self.assertEqual(self.t.value(), b"")
self._enabledHelper(self.p.protocol)
def testIgnoreWill(self):
# Try to enable an option. The option is already enabled. The
# server should send nothing in response to this. Doing so could
# lead to a negotiation loop.
cmd = telnet.IAC + telnet.WILL + b"\x56"
# Jimmy it - after these two lines, the server will be in a state
# such that it believes the option to have been previously enabled
# via normal negotiation.
s = self.p.getOptionState(b"\x56")
s.him.state = "yes"
data = b"tra la la" + cmd + b"dum de dum"
self.p.dataReceived(data)
self.assertEqual(self.p.protocol.data, data.replace(cmd, b""))
self.assertEqual(self.t.value(), b"")
self._enabledHelper(self.p.protocol)
def testIgnoreDo(self):
# Try to enable an option. The option is already enabled. The
# server should send nothing in response to this. Doing so could
# lead to a negotiation loop.
cmd = telnet.IAC + telnet.DO + b"\x56"
# Jimmy it - after these two lines, the server will be in a state
# such that it believes the option to have been previously enabled
# via normal negotiation.
s = self.p.getOptionState(b"\x56")
s.us.state = "yes"
data = b"tra la la" + cmd + b"dum de dum"
self.p.dataReceived(data)
self.assertEqual(self.p.protocol.data, data.replace(cmd, b""))
self.assertEqual(self.t.value(), b"")
self._enabledHelper(self.p.protocol)
def testAcceptedEnableRequest(self):
# Try to enable an option through the user-level API. This
# returns a Deferred that fires when negotiation about the option
# finishes. Make sure it fires, make sure state gets updated
# properly, make sure the result indicates the option was enabled.
d = self.p.do(b"\x42")
h = self.p.protocol
h.remoteEnableable = (b"\x42",)
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + b"\x42")
self.p.dataReceived(telnet.IAC + telnet.WILL + b"\x42")
d.addCallback(self.assertEqual, True)
d.addCallback(lambda _: self._enabledHelper(h, eR=[b"\x42"]))
return d
def test_refusedEnableRequest(self):
"""
If the peer refuses to enable an option we request it to enable, the
L{Deferred} returned by L{TelnetProtocol.do} fires with an
L{OptionRefused} L{Failure}.
"""
# Try to enable an option through the user-level API. This returns a
# Deferred that fires when negotiation about the option finishes. Make
# sure it fires, make sure state gets updated properly, make sure the
# result indicates the option was enabled.
self.p.protocol.remoteEnableable = (b"\x42",)
d = self.p.do(b"\x42")
self.assertEqual(self.t.value(), telnet.IAC + telnet.DO + b"\x42")
s = self.p.getOptionState(b"\x42")
self.assertEqual(s.him.state, "no")
self.assertEqual(s.us.state, "no")
self.assertTrue(s.him.negotiating)
self.assertFalse(s.us.negotiating)
self.p.dataReceived(telnet.IAC + telnet.WONT + b"\x42")
d = self.assertFailure(d, telnet.OptionRefused)
d.addCallback(lambda ignored: self._enabledHelper(self.p.protocol))
d.addCallback(lambda ignored: self.assertFalse(s.him.negotiating))
return d
def test_refusedEnableOffer(self):
"""
If the peer refuses to allow us to enable an option, the L{Deferred}
returned by L{TelnetProtocol.will} fires with an L{OptionRefused}
L{Failure}.
"""
# Try to offer an option through the user-level API. This returns a
# Deferred that fires when negotiation about the option finishes. Make
# sure it fires, make sure state gets updated properly, make sure the
# result indicates the option was enabled.
self.p.protocol.localEnableable = (b"\x42",)
d = self.p.will(b"\x42")
self.assertEqual(self.t.value(), telnet.IAC + telnet.WILL + b"\x42")
s = self.p.getOptionState(b"\x42")
self.assertEqual(s.him.state, "no")
self.assertEqual(s.us.state, "no")
self.assertFalse(s.him.negotiating)
self.assertTrue(s.us.negotiating)
self.p.dataReceived(telnet.IAC + telnet.DONT + b"\x42")
d = self.assertFailure(d, telnet.OptionRefused)
d.addCallback(lambda ignored: self._enabledHelper(self.p.protocol))
d.addCallback(lambda ignored: self.assertFalse(s.us.negotiating))
return d
def testAcceptedDisableRequest(self):
# Try to disable an option through the user-level API. This
# returns a Deferred that fires when negotiation about the option
# finishes. Make sure it fires, make sure state gets updated
# properly, make sure the result indicates the option was enabled.
s = self.p.getOptionState(b"\x42")
s.him.state = "yes"
d = self.p.dont(b"\x42")
self.assertEqual(self.t.value(), telnet.IAC + telnet.DONT + b"\x42")
self.p.dataReceived(telnet.IAC + telnet.WONT + b"\x42")
d.addCallback(self.assertEqual, True)
d.addCallback(lambda _: self._enabledHelper(self.p.protocol, dR=[b"\x42"]))
return d
def testNegotiationBlocksFurtherNegotiation(self):
# Try to disable an option, then immediately try to enable it, then
# immediately try to disable it. Ensure that the 2nd and 3rd calls
# fail quickly with the right exception.
s = self.p.getOptionState(b"\x24")
s.him.state = "yes"
self.p.dont(b"\x24") # fires after the first line of _final
def _do(x):
d = self.p.do(b"\x24")
return self.assertFailure(d, telnet.AlreadyNegotiating)
def _dont(x):
d = self.p.dont(b"\x24")
return self.assertFailure(d, telnet.AlreadyNegotiating)
def _final(x):
self.p.dataReceived(telnet.IAC + telnet.WONT + b"\x24")
# an assertion that only passes if d2 has fired
self._enabledHelper(self.p.protocol, dR=[b"\x24"])
# Make sure we allow this
self.p.protocol.remoteEnableable = (b"\x24",)
d = self.p.do(b"\x24")
self.p.dataReceived(telnet.IAC + telnet.WILL + b"\x24")
d.addCallback(self.assertEqual, True)
d.addCallback(
lambda _: self._enabledHelper(
self.p.protocol, eR=[b"\x24"], dR=[b"\x24"]
)
)
return d
d = _do(None)
d.addCallback(_dont)
d.addCallback(_final)
return d
def testSuperfluousDisableRequestRaises(self):
# Try to disable a disabled option. Make sure it fails properly.
d = self.p.dont(b"\xab")
return self.assertFailure(d, telnet.AlreadyDisabled)
def testSuperfluousEnableRequestRaises(self):
# Try to disable a disabled option. Make sure it fails properly.
s = self.p.getOptionState(b"\xab")
s.him.state = "yes"
d = self.p.do(b"\xab")
return self.assertFailure(d, telnet.AlreadyEnabled)
def testLostConnectionFailsDeferreds(self):
d1 = self.p.do(b"\x12")
d2 = self.p.do(b"\x23")
d3 = self.p.do(b"\x34")
class TestException(Exception):
pass
self.p.connectionLost(TestException("Total failure!"))
d1 = self.assertFailure(d1, TestException)
d2 = self.assertFailure(d2, TestException)
d3 = self.assertFailure(d3, TestException)
return defer.gatherResults([d1, d2, d3])
class TestTelnet(telnet.Telnet):
"""
A trivial extension of the telnet protocol class useful to unit tests.
"""
def __init__(self):
telnet.Telnet.__init__(self)
self.events = []
def applicationDataReceived(self, data):
"""
Record the given data in C{self.events}.
"""
self.events.append(("bytes", data))
def unhandledCommand(self, command, data):
"""
Record the given command in C{self.events}.
"""
self.events.append(("command", command, data))
def unhandledSubnegotiation(self, command, data):
"""
Record the given subnegotiation command in C{self.events}.
"""
self.events.append(("negotiate", command, data))
class TelnetTests(unittest.TestCase):
"""
Tests for L{telnet.Telnet}.
L{telnet.Telnet} implements the TELNET protocol (RFC 854), including option
and suboption negotiation, and option state tracking.
"""
def setUp(self):
"""
Create an unconnected L{telnet.Telnet} to be used by tests.
"""
self.protocol = TestTelnet()
def test_enableLocal(self):
"""
L{telnet.Telnet.enableLocal} should reject all options, since
L{telnet.Telnet} does not know how to implement any options.
"""
self.assertFalse(self.protocol.enableLocal(b"\0"))
def test_enableRemote(self):
"""
L{telnet.Telnet.enableRemote} should reject all options, since
L{telnet.Telnet} does not know how to implement any options.
"""
self.assertFalse(self.protocol.enableRemote(b"\0"))
def test_disableLocal(self):
"""
It is an error for L{telnet.Telnet.disableLocal} to be called, since
L{telnet.Telnet.enableLocal} will never allow any options to be enabled
locally. If a subclass overrides enableLocal, it must also override
disableLocal.
"""
self.assertRaises(NotImplementedError, self.protocol.disableLocal, b"\0")
def test_disableRemote(self):
"""
It is an error for L{telnet.Telnet.disableRemote} to be called, since
L{telnet.Telnet.enableRemote} will never allow any options to be
enabled remotely. If a subclass overrides enableRemote, it must also
override disableRemote.
"""
self.assertRaises(NotImplementedError, self.protocol.disableRemote, b"\0")
def test_requestNegotiation(self):
"""
L{telnet.Telnet.requestNegotiation} formats the feature byte and the
payload bytes into the subnegotiation format and sends them.
See RFC 855.
"""
transport = proto_helpers.StringTransport()
self.protocol.makeConnection(transport)
self.protocol.requestNegotiation(b"\x01", b"\x02\x03")
self.assertEqual(
transport.value(),
# IAC SB feature bytes IAC SE
b"\xff\xfa\x01\x02\x03\xff\xf0",
)
def test_requestNegotiationEscapesIAC(self):
"""
If the payload for a subnegotiation includes I{IAC}, it is escaped by
L{telnet.Telnet.requestNegotiation} with another I{IAC}.
See RFC 855.
"""
transport = proto_helpers.StringTransport()
self.protocol.makeConnection(transport)
self.protocol.requestNegotiation(b"\x01", b"\xff")
self.assertEqual(transport.value(), b"\xff\xfa\x01\xff\xff\xff\xf0")
def _deliver(self, data, *expected):
"""
Pass the given bytes to the protocol's C{dataReceived} method and
assert that the given events occur.
"""
received = self.protocol.events = []
self.protocol.dataReceived(data)
self.assertEqual(received, list(expected))
def test_oneApplicationDataByte(self):
"""
One application-data byte in the default state gets delivered right
away.
"""
self._deliver(b"a", ("bytes", b"a"))
def test_twoApplicationDataBytes(self):
"""
Two application-data bytes in the default state get delivered
together.
"""
self._deliver(b"bc", ("bytes", b"bc"))
def test_threeApplicationDataBytes(self):
"""
Three application-data bytes followed by a control byte get
delivered, but the control byte doesn't.
"""
self._deliver(b"def" + telnet.IAC, ("bytes", b"def"))
def test_escapedControl(self):
"""
IAC in the escaped state gets delivered and so does another
application-data byte following it.
"""
self._deliver(telnet.IAC)
self._deliver(telnet.IAC + b"g", ("bytes", telnet.IAC + b"g"))
def test_carriageReturn(self):
"""
A carriage return only puts the protocol into the newline state. A
linefeed in the newline state causes just the newline to be
delivered. A nul in the newline state causes a carriage return to
be delivered. An IAC in the newline state causes a carriage return
to be delivered and puts the protocol into the escaped state.
Anything else causes a carriage return and that thing to be
delivered.
"""
self._deliver(b"\r")
self._deliver(b"\n", ("bytes", b"\n"))
self._deliver(b"\r\n", ("bytes", b"\n"))
self._deliver(b"\r")
self._deliver(b"\0", ("bytes", b"\r"))
self._deliver(b"\r\0", ("bytes", b"\r"))
self._deliver(b"\r")
self._deliver(b"a", ("bytes", b"\ra"))
self._deliver(b"\ra", ("bytes", b"\ra"))
self._deliver(b"\r")
self._deliver(
telnet.IAC + telnet.IAC + b"x", ("bytes", b"\r" + telnet.IAC + b"x")
)
def test_applicationDataBeforeSimpleCommand(self):
"""
Application bytes received before a command are delivered before the
command is processed.
"""
self._deliver(
b"x" + telnet.IAC + telnet.NOP,
("bytes", b"x"),
("command", telnet.NOP, None),
)
def test_applicationDataBeforeCommand(self):
"""
Application bytes received before a WILL/WONT/DO/DONT are delivered
before the command is processed.
"""
self.protocol.commandMap = {}
self._deliver(
b"y" + telnet.IAC + telnet.WILL + b"\x00",
("bytes", b"y"),
("command", telnet.WILL, b"\x00"),
)
def test_applicationDataBeforeSubnegotiation(self):
"""
Application bytes received before a subnegotiation command are
delivered before the negotiation is processed.
"""
self._deliver(
b"z" + telnet.IAC + telnet.SB + b"Qx" + telnet.IAC + telnet.SE,
("bytes", b"z"),
("negotiate", b"Q", [b"x"]),
)

View File

@@ -0,0 +1,118 @@
# -*- test-case-name: twisted.conch.test.test_text -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.conch.insults import text
from twisted.conch.insults.text import attributes as A
from twisted.trial import unittest
class FormattedTextTests(unittest.TestCase):
"""
Tests for assembling formatted text.
"""
def test_trivial(self) -> None:
"""
Using no formatting attributes produces no VT102 control sequences in
the flattened output.
"""
self.assertEqual(
text.assembleFormattedText(A.normal["Hello, world."]), "Hello, world."
)
def test_bold(self) -> None:
"""
The bold formatting attribute, L{A.bold}, emits the VT102 control
sequence to enable bold when flattened.
"""
self.assertEqual(
text.assembleFormattedText(A.bold["Hello, world."]), "\x1b[1mHello, world."
)
def test_underline(self) -> None:
"""
The underline formatting attribute, L{A.underline}, emits the VT102
control sequence to enable underlining when flattened.
"""
self.assertEqual(
text.assembleFormattedText(A.underline["Hello, world."]),
"\x1b[4mHello, world.",
)
def test_blink(self) -> None:
"""
The blink formatting attribute, L{A.blink}, emits the VT102 control
sequence to enable blinking when flattened.
"""
self.assertEqual(
text.assembleFormattedText(A.blink["Hello, world."]), "\x1b[5mHello, world."
)
def test_reverseVideo(self) -> None:
"""
The reverse-video formatting attribute, L{A.reverseVideo}, emits the
VT102 control sequence to enable reversed video when flattened.
"""
self.assertEqual(
text.assembleFormattedText(A.reverseVideo["Hello, world."]),
"\x1b[7mHello, world.",
)
def test_minus(self) -> None:
"""
Formatting attributes prefixed with a minus (C{-}) temporarily disable
the prefixed attribute, emitting no VT102 control sequence to enable
it in the flattened output.
"""
self.assertEqual(
text.assembleFormattedText(
A.bold[A.blink["Hello", -A.bold[" world"], "."]]
),
"\x1b[1;5mHello\x1b[0;5m world\x1b[1;5m.",
)
def test_foreground(self) -> None:
"""
The foreground color formatting attribute, L{A.fg}, emits the VT102
control sequence to set the selected foreground color when flattened.
"""
self.assertEqual(
text.assembleFormattedText(
A.normal[A.fg.red["Hello, "], A.fg.green["world!"]]
),
"\x1b[31mHello, \x1b[32mworld!",
)
def test_background(self) -> None:
"""
The background color formatting attribute, L{A.bg}, emits the VT102
control sequence to set the selected background color when flattened.
"""
self.assertEqual(
text.assembleFormattedText(
A.normal[A.bg.red["Hello, "], A.bg.green["world!"]]
),
"\x1b[41mHello, \x1b[42mworld!",
)
def test_flattenDeprecated(self) -> None:
"""
L{twisted.conch.insults.text.flatten} emits a deprecation warning when
imported or accessed.
"""
warningsShown = self.flushWarnings([self.test_flattenDeprecated])
self.assertEqual(len(warningsShown), 0)
# Trigger the deprecation warning.
text.flatten
warningsShown = self.flushWarnings([self.test_flattenDeprecated])
self.assertEqual(len(warningsShown), 1)
self.assertEqual(warningsShown[0]["category"], DeprecationWarning)
self.assertEqual(
warningsShown[0]["message"],
"twisted.conch.insults.text.flatten was deprecated in Twisted "
"13.1.0: Use twisted.conch.insults.text.assembleFormattedText "
"instead.",
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,142 @@
# -*- test-case-name: twisted.conch.test.test_unix -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from zope.interface import implementer
from twisted.conch.interfaces import IConchUser
from twisted.cred.checkers import (
AllowAnonymousAccess,
InMemoryUsernamePasswordDatabaseDontUse,
)
from twisted.cred.credentials import (
Anonymous,
IAnonymous,
IUsernamePassword,
UsernamePassword,
)
from twisted.cred.error import LoginDenied
from twisted.cred.portal import Portal
from twisted.internet.interfaces import IReactorProcess
from twisted.python.fakepwd import UserDatabase
from twisted.python.reflect import requireModule
from twisted.trial import unittest
from .test_session import StubClient, StubConnection
cryptography = requireModule("cryptography")
unix = requireModule("twisted.conch.unix")
if unix is not None:
from twisted.conch.unix import UnixConchUser, UnixSSHRealm
@implementer(IReactorProcess)
class MockProcessSpawner:
"""
An L{IReactorProcess} that logs calls to C{spawnProcess}.
"""
def __init__(self):
self._spawnProcessCalls = []
def spawnProcess(
self,
processProtocol,
executable,
args=(),
env={},
path=None,
uid=None,
gid=None,
usePTY=0,
childFDs=None,
):
"""
Log a call to C{spawnProcess}. Do not actually spawn a process.
"""
self._spawnProcessCalls.append(
{
"processProtocol": processProtocol,
"executable": executable,
"args": args,
"env": env,
"path": path,
"uid": uid,
"gid": gid,
"usePTY": usePTY,
"childFDs": childFDs,
}
)
shouldSkip = (
"Cannot run without cryptography"
if cryptography is None
else "Unix system required"
if unix is None
else None
)
class TestSSHSessionForUnixConchUser(unittest.TestCase):
skip = shouldSkip
def testExecCommandEnvironment(self) -> None:
"""
C{execCommand} sets the C{HOME} environment variable to the avatar's home
directory.
"""
userdb = UserDatabase()
homeDirectory = "/made/up/path/"
userName = "user"
userdb.addUser(userName, home=homeDirectory)
self.patch(unix, "pwd", userdb)
mockReactor = MockProcessSpawner()
avatar = UnixConchUser(userName)
avatar.conn = StubConnection(transport=StubClient())
session = unix.SSHSessionForUnixConchUser(avatar, reactor=mockReactor)
protocol = None
command = ["not-actually-executed"]
session.execCommand(protocol, command)
[call] = mockReactor._spawnProcessCalls
self.assertEqual(homeDirectory, call["env"]["HOME"])
class TestUnixSSHRealm(unittest.TestCase):
"""
Tests for L{UnixSSHRealm}.
"""
skip = shouldSkip
def test_unixSSHRealm(self) -> None:
"""
L{UnixSSHRealm} is an L{IRealm} whose C{.requestAvatar} method returns
a L{UnixConchUser}.
"""
userdb = UserDatabase()
home = "/testing/home/value"
userdb.addUser("user", home=home)
self.patch(unix, "pwd", userdb)
pwdb = InMemoryUsernamePasswordDatabaseDontUse(user=b"password")
p = Portal(UnixSSHRealm(), [pwdb])
# there seems to be a bug in mypy-zope where sometimes things don't
# implement their superinterfaces; 0.3.11, when we upgrade to 0.9.0
# this type declaration will be extraneous
creds: IUsernamePassword = UsernamePassword(b"user", b"password")
result = p.login(creds, None, IConchUser)
resultInterface, avatar, logout = self.successResultOf(result)
self.assertIsInstance(avatar, UnixConchUser)
assert isinstance(avatar, UnixConchUser) # legibility for mypy
self.assertEqual(avatar.getHomeDir(), home)
def test_unixSSHRefusesAnonymousLogins(self) -> None:
"""
L{UnixSSHRealm} will refuse anonymous logins.
"""
p = Portal(UnixSSHRealm(), [AllowAnonymousAccess()])
result = p.login(IAnonymous(Anonymous()), None, IConchUser)
loginDenied = self.failureResultOf(result)
self.assertIsInstance(loginDenied.value, LoginDenied)

View File

@@ -0,0 +1,974 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the implementation of the ssh-userauth service.
Maintainer: Paul Swartz
"""
from types import ModuleType
from typing import Optional
from zope.interface import implementer
from twisted.conch.error import ConchError, ValidPublicKey
from twisted.cred.checkers import ICredentialsChecker
from twisted.cred.credentials import IAnonymous, ISSHPrivateKey, IUsernamePassword
from twisted.cred.error import UnauthorizedLogin
from twisted.cred.portal import IRealm, Portal
from twisted.internet import defer, task
from twisted.protocols import loopback
from twisted.python.reflect import requireModule
from twisted.trial import unittest
keys: Optional[ModuleType] = None
if requireModule("cryptography"):
from twisted.conch.checkers import SSHProtocolChecker
from twisted.conch.ssh import keys, transport, userauth
from twisted.conch.ssh.common import NS
from twisted.conch.test import keydata
else:
class transport: # type: ignore[no-redef]
class SSHTransportBase:
"""
A stub class so that later class definitions won't die.
"""
class userauth: # type: ignore[no-redef]
class SSHUserAuthClient:
"""
A stub class so that later class definitions won't die.
"""
class ClientUserAuth(userauth.SSHUserAuthClient):
"""
A mock user auth client.
"""
def getPublicKey(self):
"""
If this is the first time we've been called, return a blob for
the DSA key. Otherwise, return a blob
for the RSA key.
"""
if self.lastPublicKey:
return keys.Key.fromString(keydata.publicRSA_openssh)
else:
return defer.succeed(keys.Key.fromString(keydata.publicDSA_openssh))
def getPrivateKey(self):
"""
Return the private key object for the RSA key.
"""
return defer.succeed(keys.Key.fromString(keydata.privateRSA_openssh))
def getPassword(self, prompt=None):
"""
Return 'foo' as the password.
"""
return defer.succeed(b"foo")
def getGenericAnswers(self, name, information, answers):
"""
Return 'foo' as the answer to two questions.
"""
return defer.succeed(("foo", "foo"))
class OldClientAuth(userauth.SSHUserAuthClient):
"""
The old SSHUserAuthClient returned a cryptography key object from
getPrivateKey() and a string from getPublicKey
"""
def getPrivateKey(self):
return defer.succeed(keys.Key.fromString(keydata.privateRSA_openssh).keyObject)
def getPublicKey(self):
return keys.Key.fromString(keydata.publicRSA_openssh).blob()
class ClientAuthWithoutPrivateKey(userauth.SSHUserAuthClient):
"""
This client doesn't have a private key, but it does have a public key.
"""
def getPrivateKey(self):
return
def getPublicKey(self):
return keys.Key.fromString(keydata.publicRSA_openssh)
class FakeTransport(transport.SSHTransportBase):
"""
L{userauth.SSHUserAuthServer} expects an SSH transport which has a factory
attribute which has a portal attribute. Because the portal is important for
testing authentication, we need to be able to provide an interesting portal
object to the L{SSHUserAuthServer}.
In addition, we want to be able to capture any packets sent over the
transport.
@ivar packets: a list of 2-tuples: (messageType, data). Each 2-tuple is
a sent packet.
@type packets: C{list}
@param lostConnecion: True if loseConnection has been called on us.
@type lostConnection: L{bool}
"""
class Service:
"""
A mock service, representing the other service offered by the server.
"""
name = b"nancy"
def serviceStarted(self):
pass
class Factory:
"""
A mock factory, representing the factory that spawned this user auth
service.
"""
def getService(self, transport, service):
"""
Return our fake service.
"""
if service == b"none":
return FakeTransport.Service
def __init__(self, portal):
self.factory = self.Factory()
self.factory.portal = portal
self.lostConnection = False
self.transport = self
self.packets = []
def sendPacket(self, messageType, message):
"""
Record the packet sent by the service.
"""
self.packets.append((messageType, message))
def isEncrypted(self, direction):
"""
Pretend that this transport encrypts traffic in both directions. The
SSHUserAuthServer disables password authentication if the transport
isn't encrypted.
"""
return True
def loseConnection(self):
self.lostConnection = True
@implementer(IRealm)
class Realm:
"""
A mock realm for testing L{userauth.SSHUserAuthServer}.
This realm is not actually used in the course of testing, so it returns the
simplest thing that could possibly work.
"""
def requestAvatar(self, avatarId, mind, *interfaces):
return defer.succeed((interfaces[0], None, lambda: None))
@implementer(ICredentialsChecker)
class PasswordChecker:
"""
A very simple username/password checker which authenticates anyone whose
password matches their username and rejects all others.
"""
credentialInterfaces = (IUsernamePassword,)
def requestAvatarId(self, creds):
if creds.username == creds.password:
return defer.succeed(creds.username)
return defer.fail(UnauthorizedLogin("Invalid username/password pair"))
@implementer(ICredentialsChecker)
class PrivateKeyChecker:
"""
A very simple public key checker which authenticates anyone whose
public/private keypair is the same keydata.public/privateRSA_openssh.
"""
credentialInterfaces = (ISSHPrivateKey,)
def requestAvatarId(self, creds):
if creds.blob == keys.Key.fromString(keydata.publicRSA_openssh).blob():
if creds.signature is not None:
obj = keys.Key.fromString(creds.blob)
if obj.verify(creds.signature, creds.sigData):
return creds.username
else:
raise ValidPublicKey()
raise UnauthorizedLogin()
@implementer(ICredentialsChecker)
class AnonymousChecker:
"""
A simple checker which isn't supported by L{SSHUserAuthServer}.
"""
credentialInterfaces = (IAnonymous,)
def requestAvatarId(self, credentials):
# ICredentialsChecker.requestAvatarId
pass
class SSHUserAuthServerTests(unittest.TestCase):
"""
Tests for SSHUserAuthServer.
"""
if keys is None:
skip = "cannot run without cryptography"
def setUp(self):
self.realm = Realm()
self.portal = Portal(self.realm)
self.portal.registerChecker(PasswordChecker())
self.portal.registerChecker(PrivateKeyChecker())
self.authServer = userauth.SSHUserAuthServer()
self.authServer.transport = FakeTransport(self.portal)
self.authServer.serviceStarted()
self.authServer.supportedAuthentications.sort() # give a consistent
# order
def tearDown(self):
self.authServer.serviceStopped()
self.authServer = None
def _checkFailed(self, ignored):
"""
Check that the authentication has failed.
"""
self.assertEqual(
self.authServer.transport.packets[-1],
(userauth.MSG_USERAUTH_FAILURE, NS(b"password,publickey") + b"\x00"),
)
def test_noneAuthentication(self):
"""
A client may request a list of authentication 'method name' values
that may continue by using the "none" authentication 'method name'.
See RFC 4252 Section 5.2.
"""
d = self.authServer.ssh_USERAUTH_REQUEST(
NS(b"foo") + NS(b"service") + NS(b"none")
)
return d.addCallback(self._checkFailed)
def test_successfulPasswordAuthentication(self):
"""
When provided with correct password authentication information, the
server should respond by sending a MSG_USERAUTH_SUCCESS message with
no other data.
See RFC 4252, Section 5.1.
"""
packet = b"".join([NS(b"foo"), NS(b"none"), NS(b"password"), b"\0", NS(b"foo")])
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
def check(ignored):
self.assertEqual(
self.authServer.transport.packets,
[(userauth.MSG_USERAUTH_SUCCESS, b"")],
)
return d.addCallback(check)
def test_failedPasswordAuthentication(self):
"""
When provided with invalid authentication details, the server should
respond by sending a MSG_USERAUTH_FAILURE message which states whether
the authentication was partially successful, and provides other, open
options for authentication.
See RFC 4252, Section 5.1.
"""
# packet = username, next_service, authentication type, FALSE, password
packet = b"".join([NS(b"foo"), NS(b"none"), NS(b"password"), b"\0", NS(b"bar")])
self.authServer.clock = task.Clock()
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
self.assertEqual(self.authServer.transport.packets, [])
self.authServer.clock.advance(2)
return d.addCallback(self._checkFailed)
def test_successfulPrivateKeyAuthentication(self):
"""
Test that private key authentication completes successfully,
"""
blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
obj = keys.Key.fromString(keydata.privateRSA_openssh)
packet = (
NS(b"foo")
+ NS(b"none")
+ NS(b"publickey")
+ b"\xff"
+ NS(obj.sshType())
+ NS(blob)
)
self.authServer.transport.sessionID = b"test"
signature = obj.sign(
NS(b"test") + bytes((userauth.MSG_USERAUTH_REQUEST,)) + packet
)
packet += NS(signature)
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
def check(ignored):
self.assertEqual(
self.authServer.transport.packets,
[(userauth.MSG_USERAUTH_SUCCESS, b"")],
)
return d.addCallback(check)
def test_requestRaisesConchError(self):
"""
ssh_USERAUTH_REQUEST should raise a ConchError if tryAuth returns
None. Added to catch a bug noticed by pyflakes.
"""
d = defer.Deferred()
def mockCbFinishedAuth(self, ignored):
self.fail("request should have raised ConochError")
def mockTryAuth(kind, user, data):
return None
def mockEbBadAuth(reason):
d.errback(reason.value)
self.patch(self.authServer, "tryAuth", mockTryAuth)
self.patch(self.authServer, "_cbFinishedAuth", mockCbFinishedAuth)
self.patch(self.authServer, "_ebBadAuth", mockEbBadAuth)
packet = NS(b"user") + NS(b"none") + NS(b"public-key") + NS(b"data")
# If an error other than ConchError is raised, this will trigger an
# exception.
self.authServer.ssh_USERAUTH_REQUEST(packet)
return self.assertFailure(d, ConchError)
def test_verifyValidPrivateKey(self):
"""
Test that verifying a valid private key works.
"""
blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
packet = (
NS(b"foo")
+ NS(b"none")
+ NS(b"publickey")
+ b"\x00"
+ NS(b"ssh-rsa")
+ NS(blob)
)
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
def check(ignored):
self.assertEqual(
self.authServer.transport.packets,
[(userauth.MSG_USERAUTH_PK_OK, NS(b"ssh-rsa") + NS(blob))],
)
return d.addCallback(check)
def test_failedPrivateKeyAuthenticationWithoutSignature(self):
"""
Test that private key authentication fails when the public key
is invalid.
"""
blob = keys.Key.fromString(keydata.publicDSA_openssh).blob()
packet = (
NS(b"foo")
+ NS(b"none")
+ NS(b"publickey")
+ b"\x00"
+ NS(b"ssh-dsa")
+ NS(blob)
)
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
return d.addCallback(self._checkFailed)
def test_failedPrivateKeyAuthenticationWithSignature(self):
"""
Test that private key authentication fails when the public key
is invalid.
"""
blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
obj = keys.Key.fromString(keydata.privateRSA_openssh)
packet = (
NS(b"foo")
+ NS(b"none")
+ NS(b"publickey")
+ b"\xff"
+ NS(b"ssh-rsa")
+ NS(blob)
+ NS(obj.sign(blob))
)
self.authServer.transport.sessionID = b"test"
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
return d.addCallback(self._checkFailed)
def test_unsupported_publickey(self):
"""
Private key authentication fails when the public key type is
unsupported or the public key is corrupt.
"""
blob = keys.Key.fromString(keydata.publicDSA_openssh).blob()
# Change the blob to a bad type
blob = NS(b"ssh-bad-type") + blob[11:]
packet = (
NS(b"foo")
+ NS(b"none")
+ NS(b"publickey")
+ b"\x00"
+ NS(b"ssh-rsa")
+ NS(blob)
)
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
return d.addCallback(self._checkFailed)
def test_ignoreUnknownCredInterfaces(self):
"""
L{SSHUserAuthServer} sets up
C{SSHUserAuthServer.supportedAuthentications} by checking the portal's
credentials interfaces and mapping them to SSH authentication method
strings. If the Portal advertises an interface that
L{SSHUserAuthServer} can't map, it should be ignored. This is a white
box test.
"""
server = userauth.SSHUserAuthServer()
server.transport = FakeTransport(self.portal)
self.portal.registerChecker(AnonymousChecker())
server.serviceStarted()
server.serviceStopped()
server.supportedAuthentications.sort() # give a consistent order
self.assertEqual(server.supportedAuthentications, [b"password", b"publickey"])
def test_removePasswordIfUnencrypted(self):
"""
Test that the userauth service does not advertise password
authentication if the password would be send in cleartext.
"""
self.assertIn(b"password", self.authServer.supportedAuthentications)
# no encryption
clearAuthServer = userauth.SSHUserAuthServer()
clearAuthServer.transport = FakeTransport(self.portal)
clearAuthServer.transport.isEncrypted = lambda x: False
clearAuthServer.serviceStarted()
clearAuthServer.serviceStopped()
self.assertNotIn(b"password", clearAuthServer.supportedAuthentications)
# only encrypt incoming (the direction the password is sent)
halfAuthServer = userauth.SSHUserAuthServer()
halfAuthServer.transport = FakeTransport(self.portal)
halfAuthServer.transport.isEncrypted = lambda x: x == "in"
halfAuthServer.serviceStarted()
halfAuthServer.serviceStopped()
self.assertIn(b"password", halfAuthServer.supportedAuthentications)
def test_unencryptedConnectionWithoutPasswords(self):
"""
If the L{SSHUserAuthServer} is not advertising passwords, then an
unencrypted connection should not cause any warnings or exceptions.
This is a white box test.
"""
# create a Portal without password authentication
portal = Portal(self.realm)
portal.registerChecker(PrivateKeyChecker())
# no encryption
clearAuthServer = userauth.SSHUserAuthServer()
clearAuthServer.transport = FakeTransport(portal)
clearAuthServer.transport.isEncrypted = lambda x: False
clearAuthServer.serviceStarted()
clearAuthServer.serviceStopped()
self.assertEqual(clearAuthServer.supportedAuthentications, [b"publickey"])
# only encrypt incoming (the direction the password is sent)
halfAuthServer = userauth.SSHUserAuthServer()
halfAuthServer.transport = FakeTransport(portal)
halfAuthServer.transport.isEncrypted = lambda x: x == "in"
halfAuthServer.serviceStarted()
halfAuthServer.serviceStopped()
self.assertEqual(clearAuthServer.supportedAuthentications, [b"publickey"])
def test_loginTimeout(self):
"""
Test that the login times out.
"""
timeoutAuthServer = userauth.SSHUserAuthServer()
timeoutAuthServer.clock = task.Clock()
timeoutAuthServer.transport = FakeTransport(self.portal)
timeoutAuthServer.serviceStarted()
timeoutAuthServer.clock.advance(11 * 60 * 60)
timeoutAuthServer.serviceStopped()
self.assertEqual(
timeoutAuthServer.transport.packets,
[
(
transport.MSG_DISCONNECT,
b"\x00" * 3
+ bytes((transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,))
+ NS(b"you took too long")
+ NS(b""),
)
],
)
self.assertTrue(timeoutAuthServer.transport.lostConnection)
def test_cancelLoginTimeout(self):
"""
Test that stopping the service also stops the login timeout.
"""
timeoutAuthServer = userauth.SSHUserAuthServer()
timeoutAuthServer.clock = task.Clock()
timeoutAuthServer.transport = FakeTransport(self.portal)
timeoutAuthServer.serviceStarted()
timeoutAuthServer.serviceStopped()
timeoutAuthServer.clock.advance(11 * 60 * 60)
self.assertEqual(timeoutAuthServer.transport.packets, [])
self.assertFalse(timeoutAuthServer.transport.lostConnection)
def test_tooManyAttempts(self):
"""
Test that the server disconnects if the client fails authentication
too many times.
"""
packet = b"".join([NS(b"foo"), NS(b"none"), NS(b"password"), b"\0", NS(b"bar")])
self.authServer.clock = task.Clock()
for i in range(21):
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
self.authServer.clock.advance(2)
def check(ignored):
self.assertEqual(
self.authServer.transport.packets[-1],
(
transport.MSG_DISCONNECT,
b"\x00" * 3
+ bytes((transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,))
+ NS(b"too many bad auths")
+ NS(b""),
),
)
return d.addCallback(check)
def test_failIfUnknownService(self):
"""
If the user requests a service that we don't support, the
authentication should fail.
"""
packet = NS(b"foo") + NS(b"") + NS(b"password") + b"\0" + NS(b"foo")
self.authServer.clock = task.Clock()
d = self.authServer.ssh_USERAUTH_REQUEST(packet)
return d.addCallback(self._checkFailed)
def test_tryAuthEdgeCases(self):
"""
tryAuth() has two edge cases that are difficult to reach.
1) an authentication method auth_* returns None instead of a Deferred.
2) an authentication type that is defined does not have a matching
auth_* method.
Both these cases should return a Deferred which fails with a
ConchError.
"""
def mockAuth(packet):
return None
self.patch(self.authServer, "auth_publickey", mockAuth) # first case
self.patch(self.authServer, "auth_password", None) # second case
def secondTest(ignored):
d2 = self.authServer.tryAuth(b"password", None, None)
return self.assertFailure(d2, ConchError)
d1 = self.authServer.tryAuth(b"publickey", None, None)
return self.assertFailure(d1, ConchError).addCallback(secondTest)
class SSHUserAuthClientTests(unittest.TestCase):
"""
Tests for SSHUserAuthClient.
"""
if keys is None:
skip = "cannot run without cryptography"
def setUp(self):
self.authClient = ClientUserAuth(b"foo", FakeTransport.Service())
self.authClient.transport = FakeTransport(None)
self.authClient.transport.sessionID = b"test"
self.authClient.serviceStarted()
def tearDown(self):
self.authClient.serviceStopped()
self.authClient = None
def test_init(self):
"""
Test that client is initialized properly.
"""
self.assertEqual(self.authClient.user, b"foo")
self.assertEqual(self.authClient.instance.name, b"nancy")
self.assertEqual(
self.authClient.transport.packets,
[(userauth.MSG_USERAUTH_REQUEST, NS(b"foo") + NS(b"nancy") + NS(b"none"))],
)
def test_USERAUTH_SUCCESS(self):
"""
Test that the client succeeds properly.
"""
instance = [None]
def stubSetService(service):
instance[0] = service
self.authClient.transport.setService = stubSetService
self.authClient.ssh_USERAUTH_SUCCESS(b"")
self.assertEqual(instance[0], self.authClient.instance)
def test_publickey(self):
"""
Test that the client can authenticate with a public key.
"""
self.authClient.ssh_USERAUTH_FAILURE(NS(b"publickey") + b"\x00")
self.assertEqual(
self.authClient.transport.packets[-1],
(
userauth.MSG_USERAUTH_REQUEST,
NS(b"foo")
+ NS(b"nancy")
+ NS(b"publickey")
+ b"\x00"
+ NS(b"ssh-dss")
+ NS(keys.Key.fromString(keydata.publicDSA_openssh).blob()),
),
)
# that key isn't good
self.authClient.ssh_USERAUTH_FAILURE(NS(b"publickey") + b"\x00")
blob = NS(keys.Key.fromString(keydata.publicRSA_openssh).blob())
self.assertEqual(
self.authClient.transport.packets[-1],
(
userauth.MSG_USERAUTH_REQUEST,
(
NS(b"foo")
+ NS(b"nancy")
+ NS(b"publickey")
+ b"\x00"
+ NS(b"ssh-rsa")
+ blob
),
),
)
self.authClient.ssh_USERAUTH_PK_OK(
NS(b"ssh-rsa") + NS(keys.Key.fromString(keydata.publicRSA_openssh).blob())
)
sigData = (
NS(self.authClient.transport.sessionID)
+ bytes((userauth.MSG_USERAUTH_REQUEST,))
+ NS(b"foo")
+ NS(b"nancy")
+ NS(b"publickey")
+ b"\x01"
+ NS(b"ssh-rsa")
+ blob
)
obj = keys.Key.fromString(keydata.privateRSA_openssh)
self.assertEqual(
self.authClient.transport.packets[-1],
(
userauth.MSG_USERAUTH_REQUEST,
NS(b"foo")
+ NS(b"nancy")
+ NS(b"publickey")
+ b"\x01"
+ NS(b"ssh-rsa")
+ blob
+ NS(obj.sign(sigData)),
),
)
def test_publickey_without_privatekey(self):
"""
If the SSHUserAuthClient doesn't return anything from signData,
the client should start the authentication over again by requesting
'none' authentication.
"""
authClient = ClientAuthWithoutPrivateKey(b"foo", FakeTransport.Service())
authClient.transport = FakeTransport(None)
authClient.transport.sessionID = b"test"
authClient.serviceStarted()
authClient.tryAuth(b"publickey")
authClient.transport.packets = []
self.assertIsNone(authClient.ssh_USERAUTH_PK_OK(b""))
self.assertEqual(
authClient.transport.packets,
[(userauth.MSG_USERAUTH_REQUEST, NS(b"foo") + NS(b"nancy") + NS(b"none"))],
)
def test_no_publickey(self):
"""
If there's no public key, auth_publickey should return a Deferred
called back with a False value.
"""
self.authClient.getPublicKey = lambda x: None
d = self.authClient.tryAuth(b"publickey")
def check(result):
self.assertFalse(result)
return d.addCallback(check)
def test_password(self):
"""
Test that the client can authentication with a password. This
includes changing the password.
"""
self.authClient.ssh_USERAUTH_FAILURE(NS(b"password") + b"\x00")
self.assertEqual(
self.authClient.transport.packets[-1],
(
userauth.MSG_USERAUTH_REQUEST,
NS(b"foo") + NS(b"nancy") + NS(b"password") + b"\x00" + NS(b"foo"),
),
)
self.authClient.ssh_USERAUTH_PK_OK(NS(b"") + NS(b""))
self.assertEqual(
self.authClient.transport.packets[-1],
(
userauth.MSG_USERAUTH_REQUEST,
NS(b"foo") + NS(b"nancy") + NS(b"password") + b"\xff" + NS(b"foo") * 2,
),
)
def test_no_password(self):
"""
If getPassword returns None, tryAuth should return False.
"""
self.authClient.getPassword = lambda: None
self.assertFalse(self.authClient.tryAuth(b"password"))
def test_keyboardInteractive(self):
"""
Make sure that the client can authenticate with the keyboard
interactive method.
"""
self.authClient.ssh_USERAUTH_PK_OK_keyboard_interactive(
NS(b"")
+ NS(b"")
+ NS(b"")
+ b"\x00\x00\x00\x01"
+ NS(b"Password: ")
+ b"\x00"
)
self.assertEqual(
self.authClient.transport.packets[-1],
(
userauth.MSG_USERAUTH_INFO_RESPONSE,
b"\x00\x00\x00\x02" + NS(b"foo") + NS(b"foo"),
),
)
def test_USERAUTH_PK_OK_unknown_method(self):
"""
If C{SSHUserAuthClient} gets a MSG_USERAUTH_PK_OK packet when it's not
expecting it, it should fail the current authentication and move on to
the next type.
"""
self.authClient.lastAuth = b"unknown"
self.authClient.transport.packets = []
self.authClient.ssh_USERAUTH_PK_OK(b"")
self.assertEqual(
self.authClient.transport.packets,
[(userauth.MSG_USERAUTH_REQUEST, NS(b"foo") + NS(b"nancy") + NS(b"none"))],
)
def test_USERAUTH_FAILURE_sorting(self):
"""
ssh_USERAUTH_FAILURE should sort the methods by their position
in SSHUserAuthClient.preferredOrder. Methods that are not in
preferredOrder should be sorted at the end of that list.
"""
def auth_firstmethod():
self.authClient.transport.sendPacket(255, b"here is data")
def auth_anothermethod():
self.authClient.transport.sendPacket(254, b"other data")
return True
self.authClient.auth_firstmethod = auth_firstmethod
self.authClient.auth_anothermethod = auth_anothermethod
# although they shouldn't get called, method callbacks auth_* MUST
# exist in order for the test to work properly.
self.authClient.ssh_USERAUTH_FAILURE(NS(b"anothermethod,password") + b"\x00")
# should send password packet
self.assertEqual(
self.authClient.transport.packets[-1],
(
userauth.MSG_USERAUTH_REQUEST,
NS(b"foo") + NS(b"nancy") + NS(b"password") + b"\x00" + NS(b"foo"),
),
)
self.authClient.ssh_USERAUTH_FAILURE(
NS(b"firstmethod,anothermethod,password") + b"\xff"
)
self.assertEqual(
self.authClient.transport.packets[-2:],
[(255, b"here is data"), (254, b"other data")],
)
def test_disconnectIfNoMoreAuthentication(self):
"""
If there are no more available user authentication messages,
the SSHUserAuthClient should disconnect with code
DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE.
"""
self.authClient.ssh_USERAUTH_FAILURE(NS(b"password") + b"\x00")
self.authClient.ssh_USERAUTH_FAILURE(NS(b"password") + b"\xff")
self.assertEqual(
self.authClient.transport.packets[-1],
(
transport.MSG_DISCONNECT,
b"\x00\x00\x00\x0e"
+ NS(b"no more authentication methods available")
+ b"\x00\x00\x00\x00",
),
)
def test_ebAuth(self):
"""
_ebAuth (the generic authentication error handler) should send
a request for the 'none' authentication method.
"""
self.authClient.transport.packets = []
self.authClient._ebAuth(None)
self.assertEqual(
self.authClient.transport.packets,
[(userauth.MSG_USERAUTH_REQUEST, NS(b"foo") + NS(b"nancy") + NS(b"none"))],
)
def test_defaults(self):
"""
getPublicKey() should return None. getPrivateKey() should return a
failed Deferred. getPassword() should return a failed Deferred.
getGenericAnswers() should return a failed Deferred.
"""
authClient = userauth.SSHUserAuthClient(b"foo", FakeTransport.Service())
self.assertIsNone(authClient.getPublicKey())
def check(result):
result.trap(NotImplementedError)
d = authClient.getPassword()
return d.addCallback(self.fail).addErrback(check2)
def check2(result):
result.trap(NotImplementedError)
d = authClient.getGenericAnswers(None, None, None)
return d.addCallback(self.fail).addErrback(check3)
def check3(result):
result.trap(NotImplementedError)
d = authClient.getPrivateKey()
return d.addCallback(self.fail).addErrback(check)
class LoopbackTests(unittest.TestCase):
if keys is None:
skip = "cannot run without cryptography"
class Factory:
class Service:
name = b"TestService"
def serviceStarted(self):
self.transport.loseConnection()
def serviceStopped(self):
pass
def getService(self, avatar, name):
return self.Service
def test_loopback(self):
"""
Test that the userauth server and client play nicely with each other.
"""
server = userauth.SSHUserAuthServer()
client = ClientUserAuth(b"foo", self.Factory.Service())
# set up transports
server.transport = transport.SSHTransportBase()
server.transport.service = server
server.transport.isEncrypted = lambda x: True
client.transport = transport.SSHTransportBase()
client.transport.service = client
server.transport.sessionID = client.transport.sessionID = b""
# don't send key exchange packet
server.transport.sendKexInit = client.transport.sendKexInit = lambda: None
# set up server authentication
server.transport.factory = self.Factory()
server.passwordDelay = 0 # remove bad password delay
realm = Realm()
portal = Portal(realm)
checker = SSHProtocolChecker()
checker.registerChecker(PasswordChecker())
checker.registerChecker(PrivateKeyChecker())
checker.areDone = lambda aId: (len(checker.successfulCredentials[aId]) == 2)
portal.registerChecker(checker)
server.transport.factory.portal = portal
d = loopback.loopbackAsync(server.transport, client.transport)
server.transport.transport.logPrefix = lambda: "_ServerLoopback"
client.transport.transport.logPrefix = lambda: "_ClientLoopback"
server.serviceStarted()
client.serviceStarted()
def check(ignored):
self.assertEqual(server.transport.service.name, b"TestService")
return d.addCallback(check)
class ModuleInitializationTests(unittest.TestCase):
if keys is None:
skip = "cannot run without cryptography"
def test_messages(self):
# Several message types have value 60, check that MSG_USERAUTH_PK_OK
# is always the one which is mapped.
self.assertEqual(
userauth.SSHUserAuthServer.protocolMessages[60], "MSG_USERAUTH_PK_OK"
)
self.assertEqual(
userauth.SSHUserAuthClient.protocolMessages[60], "MSG_USERAUTH_PK_OK"
)

View File

@@ -0,0 +1,172 @@
"""
Tests for the insults windowing module, L{twisted.conch.insults.window}.
"""
from __future__ import annotations
from typing import Callable
from twisted.conch.insults.insults import ServerProtocol
from twisted.conch.insults.window import (
ScrolledArea,
Selection,
TextOutput,
TopWindow,
Widget,
)
from twisted.trial.unittest import TestCase
class TopWindowTests(TestCase):
"""
Tests for L{TopWindow}, the root window container class.
"""
def test_paintScheduling(self) -> None:
"""
Verify that L{TopWindow.repaint} schedules an actual paint to occur
using the scheduling object passed to its initializer.
"""
paints: list[None] = []
scheduled: list[Callable[[], object]] = []
root = TopWindow(lambda: paints.append(None), scheduled.append)
# Nothing should have happened yet.
self.assertEqual(paints, [])
self.assertEqual(scheduled, [])
# Cause a paint to be scheduled.
root.repaint()
self.assertEqual(paints, [])
self.assertEqual(len(scheduled), 1)
# Do another one to verify nothing else happens as long as the previous
# one is still pending.
root.repaint()
self.assertEqual(paints, [])
self.assertEqual(len(scheduled), 1)
# Run the actual paint call.
scheduled.pop()()
self.assertEqual(len(paints), 1)
self.assertEqual(scheduled, [])
# Do one more to verify that now that the previous one is finished
# future paints will succeed.
root.repaint()
self.assertEqual(len(paints), 1)
self.assertEqual(len(scheduled), 1)
class ScrolledAreaTests(TestCase):
"""
Tests for L{ScrolledArea}, a widget which creates a viewport containing
another widget and can reposition that viewport using scrollbars.
"""
def test_parent(self) -> None:
"""
The parent of the widget passed to L{ScrolledArea} is set to a new
L{Viewport} created by the L{ScrolledArea} which itself has the
L{ScrolledArea} instance as its parent.
"""
widget = TextOutput()
scrolled = ScrolledArea(widget)
self.assertIs(widget.parent, scrolled._viewport)
self.assertIs(scrolled._viewport.parent, scrolled)
class SelectionTests(TestCase):
"""
Change focused entry in L{Selection} using function keys.
"""
def setUp(self) -> None:
"""
Create L{ScrolledArea} widget with 10 elements and position selection to 5th element.
"""
seq: list[bytes] = [f"{_num}".encode("ascii") for _num in range(10)]
self.widget = Selection(seq, None)
self.widget.height = 10
self.widget.focusedIndex = 5
def test_selectionDownArrow(self) -> None:
"""
Send DOWN_ARROW to select element just below the current one.
"""
self.widget.keystrokeReceived(ServerProtocol.DOWN_ARROW, None) # type: ignore[attr-defined]
self.assertIs(self.widget.focusedIndex, 6)
def test_selectionUpArrow(self) -> None:
"""
Send UP_ARROW to select element just above the current one.
"""
self.widget.keystrokeReceived(ServerProtocol.UP_ARROW, None) # type: ignore[attr-defined]
self.assertIs(self.widget.focusedIndex, 4)
def test_selectionPGDN(self) -> None:
"""
Send PGDN to select element one page down (here: last element).
"""
self.widget.keystrokeReceived(ServerProtocol.PGDN, None) # type: ignore[attr-defined]
self.assertIs(self.widget.focusedIndex, 9)
def test_selectionPGUP(self) -> None:
"""
Send PGUP to select element one page up (here: first element).
"""
self.widget.keystrokeReceived(ServerProtocol.PGUP, None) # type: ignore[attr-defined]
self.assertIs(self.widget.focusedIndex, 0)
class RecordingWidget(Widget):
"""
A dummy Widget implementation to test handling of function keys by
recording keyReceived events.
"""
def __init__(self) -> None:
Widget.__init__(self)
self.triggered: list[str] = []
def func_F1(self, modifier: str) -> None:
self.triggered.append("F1")
def func_HOME(self, modifier: str) -> None:
self.triggered.append("HOME")
def func_DOWN_ARROW(self, modifier: str) -> None:
self.triggered.append("DOWN_ARROW")
def func_UP_ARROW(self, modifier: str) -> None:
self.triggered.append("UP_ARROW")
def func_PGDN(self, modifier: str) -> None:
self.triggered.append("PGDN")
def func_PGUP(self, modifier: str) -> None:
self.triggered.append("PGUP")
class WidgetFunctionKeyTests(TestCase):
"""
Call functionKeyReceived with key values from insults.ServerProtocol
"""
def test_functionKeyReceivedDispatch(self) -> None:
"""
L{Widget.functionKeyReceived} dispatches its input, a constant on
ServerProtocol, to a matched C{func_KEY} method.
"""
widget = RecordingWidget()
def checkOneKey(key: str) -> None:
widget.functionKeyReceived(getattr(ServerProtocol, key), None)
self.assertEqual([key], widget.triggered)
widget.triggered.clear()
checkOneKey("F1")
checkOneKey("HOME")
checkOneKey("DOWN_ARROW")
checkOneKey("UP_ARROW")
checkOneKey("PGDN")
checkOneKey("PGUP")

View File

@@ -0,0 +1,122 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
import tty
# this module was autogenerated.
VINTR = 1
VQUIT = 2
VERASE = 3
VKILL = 4
VEOF = 5
VEOL = 6
VEOL2 = 7
VSTART = 8
VSTOP = 9
VSUSP = 10
VDSUSP = 11
VREPRINT = 12
VWERASE = 13
VLNEXT = 14
VFLUSH = 15
VSWTCH = 16
VSTATUS = 17
VDISCARD = 18
IGNPAR = 30
PARMRK = 31
INPCK = 32
ISTRIP = 33
INLCR = 34
IGNCR = 35
ICRNL = 36
IUCLC = 37
IXON = 38
IXANY = 39
IXOFF = 40
IMAXBEL = 41
ISIG = 50
ICANON = 51
XCASE = 52
ECHO = 53
ECHOE = 54
ECHOK = 55
ECHONL = 56
NOFLSH = 57
TOSTOP = 58
IEXTEN = 59
ECHOCTL = 60
ECHOKE = 61
PENDIN = 62
OPOST = 70
OLCUC = 71
ONLCR = 72
OCRNL = 73
ONOCR = 74
ONLRET = 75
CS7 = 90
CS8 = 91
PARENB = 92
PARODD = 93
TTY_OP_ISPEED = 128
TTY_OP_OSPEED = 129
TTYMODES = {
1: "VINTR",
2: "VQUIT",
3: "VERASE",
4: "VKILL",
5: "VEOF",
6: "VEOL",
7: "VEOL2",
8: "VSTART",
9: "VSTOP",
10: "VSUSP",
11: "VDSUSP",
12: "VREPRINT",
13: "VWERASE",
14: "VLNEXT",
15: "VFLUSH",
16: "VSWTCH",
17: "VSTATUS",
18: "VDISCARD",
30: (tty.IFLAG, "IGNPAR"),
31: (tty.IFLAG, "PARMRK"),
32: (tty.IFLAG, "INPCK"),
33: (tty.IFLAG, "ISTRIP"),
34: (tty.IFLAG, "INLCR"),
35: (tty.IFLAG, "IGNCR"),
36: (tty.IFLAG, "ICRNL"),
37: (tty.IFLAG, "IUCLC"),
38: (tty.IFLAG, "IXON"),
39: (tty.IFLAG, "IXANY"),
40: (tty.IFLAG, "IXOFF"),
41: (tty.IFLAG, "IMAXBEL"),
50: (tty.LFLAG, "ISIG"),
51: (tty.LFLAG, "ICANON"),
52: (tty.LFLAG, "XCASE"),
53: (tty.LFLAG, "ECHO"),
54: (tty.LFLAG, "ECHOE"),
55: (tty.LFLAG, "ECHOK"),
56: (tty.LFLAG, "ECHONL"),
57: (tty.LFLAG, "NOFLSH"),
58: (tty.LFLAG, "TOSTOP"),
59: (tty.LFLAG, "IEXTEN"),
60: (tty.LFLAG, "ECHOCTL"),
61: (tty.LFLAG, "ECHOKE"),
62: (tty.LFLAG, "PENDIN"),
70: (tty.OFLAG, "OPOST"),
71: (tty.OFLAG, "OLCUC"),
72: (tty.OFLAG, "ONLCR"),
73: (tty.OFLAG, "OCRNL"),
74: (tty.OFLAG, "ONOCR"),
75: (tty.OFLAG, "ONLRET"),
# 90 : (tty.CFLAG, 'CS7'),
# 91 : (tty.CFLAG, 'CS8'),
92: (tty.CFLAG, "PARENB"),
93: (tty.CFLAG, "PARODD"),
128: "ISPEED",
129: "OSPEED",
}

View File

@@ -0,0 +1,11 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
twisted.conch.ui is home to the UI elements for tkconch.
Maintainer: Paul Swartz
"""

View File

@@ -0,0 +1,253 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""Module to parse ANSI escape sequences
Maintainer: Jean-Paul Calderone
"""
import string
# Twisted imports
from twisted.logger import Logger
_log = Logger()
class ColorText:
"""
Represents an element of text along with the texts colors and
additional attributes.
"""
# The colors to use
COLORS = ("b", "r", "g", "y", "l", "m", "c", "w")
BOLD_COLORS = tuple(x.upper() for x in COLORS)
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(len(COLORS))
# Color names
COLOR_NAMES = (
"Black",
"Red",
"Green",
"Yellow",
"Blue",
"Magenta",
"Cyan",
"White",
)
def __init__(self, text, fg, bg, display, bold, underline, flash, reverse):
self.text, self.fg, self.bg = text, fg, bg
self.display = display
self.bold = bold
self.underline = underline
self.flash = flash
self.reverse = reverse
if self.reverse:
self.fg, self.bg = self.bg, self.fg
class AnsiParser:
"""
Parser class for ANSI codes.
"""
# Terminators for cursor movement ansi controls - unsupported
CURSOR_SET = ("H", "f", "A", "B", "C", "D", "R", "s", "u", "d", "G")
# Terminators for erasure ansi controls - unsupported
ERASE_SET = ("J", "K", "P")
# Terminators for mode change ansi controls - unsupported
MODE_SET = ("h", "l")
# Terminators for keyboard assignment ansi controls - unsupported
ASSIGN_SET = ("p",)
# Terminators for color change ansi controls - supported
COLOR_SET = ("m",)
SETS = (CURSOR_SET, ERASE_SET, MODE_SET, ASSIGN_SET, COLOR_SET)
def __init__(self, defaultFG, defaultBG):
self.defaultFG, self.defaultBG = defaultFG, defaultBG
self.currentFG, self.currentBG = self.defaultFG, self.defaultBG
self.bold, self.flash, self.underline, self.reverse = 0, 0, 0, 0
self.display = 1
self.prepend = ""
def stripEscapes(self, string):
"""
Remove all ANSI color escapes from the given string.
"""
result = ""
show = 1
i = 0
L = len(string)
while i < L:
if show == 0 and string[i] in _sets:
show = 1
elif show:
n = string.find("\x1B", i)
if n == -1:
return result + string[i:]
else:
result = result + string[i:n]
i = n
show = 0
i = i + 1
return result
def writeString(self, colorstr):
pass
def parseString(self, str):
"""
Turn a string input into a list of L{ColorText} elements.
"""
if self.prepend:
str = self.prepend + str
self.prepend = ""
parts = str.split("\x1B")
if len(parts) == 1:
self.writeString(self.formatText(parts[0]))
else:
self.writeString(self.formatText(parts[0]))
for s in parts[1:]:
L = len(s)
i = 0
type = None
while i < L:
if s[i] not in string.digits + "[;?":
break
i += 1
if not s:
self.prepend = "\x1b"
return
if s[0] != "[":
self.writeString(self.formatText(s[i + 1 :]))
continue
else:
s = s[1:]
i -= 1
if i == L - 1:
self.prepend = "\x1b["
return
type = _setmap.get(s[i], None)
if type is None:
continue
if type == AnsiParser.COLOR_SET:
self.parseColor(s[: i + 1])
s = s[i + 1 :]
self.writeString(self.formatText(s))
elif type == AnsiParser.CURSOR_SET:
cursor, s = s[: i + 1], s[i + 1 :]
self.parseCursor(cursor)
self.writeString(self.formatText(s))
elif type == AnsiParser.ERASE_SET:
erase, s = s[: i + 1], s[i + 1 :]
self.parseErase(erase)
self.writeString(self.formatText(s))
elif type == AnsiParser.MODE_SET:
s = s[i + 1 :]
# self.parseErase('2J')
self.writeString(self.formatText(s))
elif i == L:
self.prepend = "\x1B[" + s
else:
_log.warn(
"Unhandled ANSI control type: {control_type}", control_type=s[i]
)
s = s[i + 1 :]
self.writeString(self.formatText(s))
def parseColor(self, str):
"""
Handle a single ANSI color sequence
"""
# Drop the trailing 'm'
str = str[:-1]
if not str:
str = "0"
try:
parts = map(int, str.split(";"))
except ValueError:
_log.error("Invalid ANSI color sequence: {sequence!r}", sequence=str)
self.currentFG, self.currentBG = self.defaultFG, self.defaultBG
return
for x in parts:
if x == 0:
self.currentFG, self.currentBG = self.defaultFG, self.defaultBG
self.bold, self.flash, self.underline, self.reverse = 0, 0, 0, 0
self.display = 1
elif x == 1:
self.bold = 1
elif 30 <= x <= 37:
self.currentFG = x - 30
elif 40 <= x <= 47:
self.currentBG = x - 40
elif x == 39:
self.currentFG = self.defaultFG
elif x == 49:
self.currentBG = self.defaultBG
elif x == 4:
self.underline = 1
elif x == 5:
self.flash = 1
elif x == 7:
self.reverse = 1
elif x == 8:
self.display = 0
elif x == 22:
self.bold = 0
elif x == 24:
self.underline = 0
elif x == 25:
self.blink = 0
elif x == 27:
self.reverse = 0
elif x == 28:
self.display = 1
else:
_log.error("Unrecognised ANSI color command: {command}", command=x)
def parseCursor(self, cursor):
pass
def parseErase(self, erase):
pass
def pickColor(self, value, mode, BOLD=ColorText.BOLD_COLORS):
if mode:
return ColorText.COLORS[value]
else:
return self.bold and BOLD[value] or ColorText.COLORS[value]
def formatText(self, text):
return ColorText(
text,
self.pickColor(self.currentFG, 0),
self.pickColor(self.currentBG, 1),
self.display,
self.bold,
self.underline,
self.flash,
self.reverse,
)
_sets = "".join(map("".join, AnsiParser.SETS))
_setmap = {}
for s in AnsiParser.SETS:
for r in s:
_setmap[r] = s
del s

View File

@@ -0,0 +1,249 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""Module to emulate a VT100 terminal in Tkinter.
Maintainer: Paul Swartz
"""
import string
import tkinter as Tkinter
import tkinter.font as tkFont
from . import ansi
ttyFont = None # tkFont.Font(family = 'Courier', size = 10)
fontWidth, fontHeight = (
None,
None,
) # max(map(ttyFont.measure, string.letters+string.digits)), int(ttyFont.metrics()['linespace'])
colorKeys = (
"b",
"r",
"g",
"y",
"l",
"m",
"c",
"w",
"B",
"R",
"G",
"Y",
"L",
"M",
"C",
"W",
)
colorMap = {
"b": "#000000",
"r": "#c40000",
"g": "#00c400",
"y": "#c4c400",
"l": "#000080",
"m": "#c400c4",
"c": "#00c4c4",
"w": "#c4c4c4",
"B": "#626262",
"R": "#ff0000",
"G": "#00ff00",
"Y": "#ffff00",
"L": "#0000ff",
"M": "#ff00ff",
"C": "#00ffff",
"W": "#ffffff",
}
class VT100Frame(Tkinter.Frame):
def __init__(self, *args, **kw):
global ttyFont, fontHeight, fontWidth
ttyFont = tkFont.Font(family="Courier", size=10)
fontWidth = max(map(ttyFont.measure, string.ascii_letters + string.digits))
fontHeight = int(ttyFont.metrics()["linespace"])
self.width = kw.get("width", 80)
self.height = kw.get("height", 25)
self.callback = kw["callback"]
del kw["callback"]
kw["width"] = w = fontWidth * self.width
kw["height"] = h = fontHeight * self.height
Tkinter.Frame.__init__(self, *args, **kw)
self.canvas = Tkinter.Canvas(bg="#000000", width=w, height=h)
self.canvas.pack(side=Tkinter.TOP, fill=Tkinter.BOTH, expand=1)
self.canvas.bind("<Key>", self.keyPressed)
self.canvas.bind("<1>", lambda x: "break")
self.canvas.bind("<Up>", self.upPressed)
self.canvas.bind("<Down>", self.downPressed)
self.canvas.bind("<Left>", self.leftPressed)
self.canvas.bind("<Right>", self.rightPressed)
self.canvas.focus()
self.ansiParser = ansi.AnsiParser(ansi.ColorText.WHITE, ansi.ColorText.BLACK)
self.ansiParser.writeString = self.writeString
self.ansiParser.parseCursor = self.parseCursor
self.ansiParser.parseErase = self.parseErase
# for (a, b) in colorMap.items():
# self.canvas.tag_config(a, foreground=b)
# self.canvas.tag_config('b'+a, background=b)
# self.canvas.tag_config('underline', underline=1)
self.x = 0
self.y = 0
self.cursor = self.canvas.create_rectangle(
0, 0, fontWidth - 1, fontHeight - 1, fill="green", outline="green"
)
def _delete(self, sx, sy, ex, ey):
csx = sx * fontWidth + 1
csy = sy * fontHeight + 1
cex = ex * fontWidth + 3
cey = ey * fontHeight + 3
items = self.canvas.find_overlapping(csx, csy, cex, cey)
for item in items:
self.canvas.delete(item)
def _write(self, ch, fg, bg):
if self.x == self.width:
self.x = 0
self.y += 1
if self.y == self.height:
[self.canvas.move(x, 0, -fontHeight) for x in self.canvas.find_all()]
self.y -= 1
canvasX = self.x * fontWidth + 1
canvasY = self.y * fontHeight + 1
items = self.canvas.find_overlapping(canvasX, canvasY, canvasX + 2, canvasY + 2)
if items:
[self.canvas.delete(item) for item in items]
if bg:
self.canvas.create_rectangle(
canvasX,
canvasY,
canvasX + fontWidth - 1,
canvasY + fontHeight - 1,
fill=bg,
outline=bg,
)
self.canvas.create_text(
canvasX, canvasY, anchor=Tkinter.NW, font=ttyFont, text=ch, fill=fg
)
self.x += 1
def write(self, data):
self.ansiParser.parseString(data)
self.canvas.delete(self.cursor)
canvasX = self.x * fontWidth + 1
canvasY = self.y * fontHeight + 1
self.cursor = self.canvas.create_rectangle(
canvasX,
canvasY,
canvasX + fontWidth - 1,
canvasY + fontHeight - 1,
fill="green",
outline="green",
)
self.canvas.lower(self.cursor)
def writeString(self, i):
if not i.display:
return
fg = colorMap[i.fg]
bg = i.bg != "b" and colorMap[i.bg]
for ch in i.text:
b = ord(ch)
if b == 7: # bell
self.bell()
elif b == 8: # BS
if self.x:
self.x -= 1
elif b == 9: # TAB
[self._write(" ", fg, bg) for index in range(8)]
elif b == 10:
if self.y == self.height - 1:
self._delete(0, 0, self.width, 0)
[
self.canvas.move(x, 0, -fontHeight)
for x in self.canvas.find_all()
]
else:
self.y += 1
elif b == 13:
self.x = 0
elif 32 <= b < 127:
self._write(ch, fg, bg)
def parseErase(self, erase):
if ";" in erase:
end = erase[-1]
parts = erase[:-1].split(";")
[self.parseErase(x + end) for x in parts]
return
start = 0
x, y = self.x, self.y
if len(erase) > 1:
start = int(erase[:-1])
if erase[-1] == "J":
if start == 0:
self._delete(x, y, self.width, self.height)
else:
self._delete(0, 0, self.width, self.height)
self.x = 0
self.y = 0
elif erase[-1] == "K":
if start == 0:
self._delete(x, y, self.width, y)
elif start == 1:
self._delete(0, y, x, y)
self.x = 0
else:
self._delete(0, y, self.width, y)
self.x = 0
elif erase[-1] == "P":
self._delete(x, y, x + start, y)
def parseCursor(self, cursor):
# if ';' in cursor and cursor[-1]!='H':
# end = cursor[-1]
# parts = cursor[:-1].split(';')
# [self.parseCursor(x+end) for x in parts]
# return
start = 1
if len(cursor) > 1 and cursor[-1] != "H":
start = int(cursor[:-1])
if cursor[-1] == "C":
self.x += start
elif cursor[-1] == "D":
self.x -= start
elif cursor[-1] == "d":
self.y = start - 1
elif cursor[-1] == "G":
self.x = start - 1
elif cursor[-1] == "H":
if len(cursor) > 1:
y, x = map(int, cursor[:-1].split(";"))
y -= 1
x -= 1
else:
x, y = 0, 0
self.x = x
self.y = y
def keyPressed(self, event):
if self.callback and event.char:
self.callback(event.char)
return "break"
def upPressed(self, event):
self.callback("\x1bOA")
def downPressed(self, event):
self.callback("\x1bOB")
def rightPressed(self, event):
self.callback("\x1bOC")
def leftPressed(self, event):
self.callback("\x1bOD")

View File

@@ -0,0 +1,524 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A UNIX SSH server.
"""
from __future__ import annotations
import fcntl
import grp
import os
import pty
import pwd
import socket
import struct
import time
import tty
from typing import Callable, Dict, Tuple
from zope.interface import implementer
from twisted.conch import ttymodes
from twisted.conch.avatar import ConchUser
from twisted.conch.error import ConchError
from twisted.conch.interfaces import ISession, ISFTPFile, ISFTPServer
from twisted.conch.ls import lsLine
from twisted.conch.ssh import filetransfer, forwarding, session
from twisted.conch.ssh.filetransfer import (
FXF_APPEND,
FXF_CREAT,
FXF_EXCL,
FXF_READ,
FXF_TRUNC,
FXF_WRITE,
)
from twisted.cred import portal
from twisted.cred.error import LoginDenied
from twisted.internet.error import ProcessExitedAlready
from twisted.internet.interfaces import IListeningPort
from twisted.logger import Logger
from twisted.python import components
from twisted.python.compat import nativeString
try:
import utmp
except ImportError:
utmp = None
@implementer(portal.IRealm)
class UnixSSHRealm:
def requestAvatar(
self,
username: bytes | Tuple[()],
mind: object,
*interfaces: portal._InterfaceItself,
) -> Tuple[portal._InterfaceItself, UnixConchUser, Callable[[], None]]:
if not isinstance(username, bytes):
raise LoginDenied("UNIX SSH realm does not authorize anonymous sessions.")
user = UnixConchUser(username.decode())
return interfaces[0], user, user.logout
class UnixConchUser(ConchUser):
def __init__(self, username: str) -> None:
ConchUser.__init__(self)
self.username = username
self.pwdData = pwd.getpwnam(self.username)
l = [self.pwdData[3]]
for groupname, password, gid, userlist in grp.getgrall():
if username in userlist:
l.append(gid)
self.otherGroups = l
self.listeners: Dict[
str, IListeningPort
] = {} # Dict mapping (interface, port) -> listener
self.channelLookup.update(
{
b"session": session.SSHSession,
b"direct-tcpip": forwarding.openConnectForwardingClient,
}
)
self.subsystemLookup.update({b"sftp": filetransfer.FileTransferServer})
def getUserGroupId(self):
return self.pwdData[2:4]
def getOtherGroups(self):
return self.otherGroups
def getHomeDir(self):
return self.pwdData[5]
def getShell(self):
return self.pwdData[6]
def global_tcpip_forward(self, data):
hostToBind, portToBind = forwarding.unpackGlobal_tcpip_forward(data)
from twisted.internet import reactor
try:
listener = self._runAsUser(
reactor.listenTCP,
portToBind,
forwarding.SSHListenForwardingFactory(
self.conn,
(hostToBind, portToBind),
forwarding.SSHListenServerForwardingChannel,
),
interface=hostToBind,
)
except BaseException:
return 0
else:
self.listeners[(hostToBind, portToBind)] = listener
if portToBind == 0:
portToBind = listener.getHost()[2] # The port
return 1, struct.pack(">L", portToBind)
else:
return 1
def global_cancel_tcpip_forward(self, data):
hostToBind, portToBind = forwarding.unpackGlobal_tcpip_forward(data)
listener = self.listeners.get((hostToBind, portToBind), None)
if not listener:
return 0
del self.listeners[(hostToBind, portToBind)]
self._runAsUser(listener.stopListening)
return 1
def logout(self) -> None:
# Remove all listeners.
for listener in self.listeners.values():
self._runAsUser(listener.stopListening)
self._log.info(
"avatar {username} logging out ({nlisteners})",
username=self.username,
nlisteners=len(self.listeners),
)
def _runAsUser(self, f, *args, **kw):
euid = os.geteuid()
egid = os.getegid()
groups = os.getgroups()
uid, gid = self.getUserGroupId()
os.setegid(0)
os.seteuid(0)
os.setgroups(self.getOtherGroups())
os.setegid(gid)
os.seteuid(uid)
try:
f = iter(f)
except TypeError:
f = [(f, args, kw)]
try:
for i in f:
func = i[0]
args = len(i) > 1 and i[1] or ()
kw = len(i) > 2 and i[2] or {}
r = func(*args, **kw)
finally:
os.setegid(0)
os.seteuid(0)
os.setgroups(groups)
os.setegid(egid)
os.seteuid(euid)
return r
@implementer(ISession)
class SSHSessionForUnixConchUser:
_log = Logger()
def __init__(self, avatar, reactor=None):
"""
Construct an C{SSHSessionForUnixConchUser}.
@param avatar: The L{UnixConchUser} for whom this is an SSH session.
@param reactor: An L{IReactorProcess} used to handle shell and exec
requests. Uses the default reactor if None.
"""
if reactor is None:
from twisted.internet import reactor
self._reactor = reactor
self.avatar = avatar
self.environ = {"PATH": "/bin:/usr/bin:/usr/local/bin"}
self.pty = None
self.ptyTuple = 0
def addUTMPEntry(self, loggedIn=1):
if not utmp:
return
ipAddress = self.avatar.conn.transport.transport.getPeer().host
(packedIp,) = struct.unpack("L", socket.inet_aton(ipAddress))
ttyName = self.ptyTuple[2][5:]
t = time.time()
t1 = int(t)
t2 = int((t - t1) * 1e6)
entry = utmp.UtmpEntry()
entry.ut_type = loggedIn and utmp.USER_PROCESS or utmp.DEAD_PROCESS
entry.ut_pid = self.pty.pid
entry.ut_line = ttyName
entry.ut_id = ttyName[-4:]
entry.ut_tv = (t1, t2)
if loggedIn:
entry.ut_user = self.avatar.username
entry.ut_host = socket.gethostbyaddr(ipAddress)[0]
entry.ut_addr_v6 = (packedIp, 0, 0, 0)
a = utmp.UtmpRecord(utmp.UTMP_FILE)
a.pututline(entry)
a.endutent()
b = utmp.UtmpRecord(utmp.WTMP_FILE)
b.pututline(entry)
b.endutent()
def getPty(self, term, windowSize, modes):
self.environ["TERM"] = term
self.winSize = windowSize
self.modes = modes
master, slave = pty.openpty()
ttyname = os.ttyname(slave)
self.environ["SSH_TTY"] = ttyname
self.ptyTuple = (master, slave, ttyname)
def openShell(self, proto):
if not self.ptyTuple: # We didn't get a pty-req.
self._log.error("tried to get shell without pty, failing")
raise ConchError("no pty")
uid, gid = self.avatar.getUserGroupId()
homeDir = self.avatar.getHomeDir()
shell = self.avatar.getShell()
self.environ["USER"] = self.avatar.username
self.environ["HOME"] = homeDir
self.environ["SHELL"] = shell
shellExec = os.path.basename(shell)
peer = self.avatar.conn.transport.transport.getPeer()
host = self.avatar.conn.transport.transport.getHost()
self.environ["SSH_CLIENT"] = f"{peer.host} {peer.port} {host.port}"
self.getPtyOwnership()
self.pty = self._reactor.spawnProcess(
proto,
shell,
[f"-{shellExec}"],
self.environ,
homeDir,
uid,
gid,
usePTY=self.ptyTuple,
)
self.addUTMPEntry()
fcntl.ioctl(self.pty.fileno(), tty.TIOCSWINSZ, struct.pack("4H", *self.winSize))
if self.modes:
self.setModes()
self.oldWrite = proto.transport.write
proto.transport.write = self._writeHack
self.avatar.conn.transport.transport.setTcpNoDelay(1)
def execCommand(self, proto, cmd):
uid, gid = self.avatar.getUserGroupId()
homeDir = self.avatar.getHomeDir()
shell = self.avatar.getShell() or "/bin/sh"
self.environ["HOME"] = homeDir
command = (shell, "-c", cmd)
peer = self.avatar.conn.transport.transport.getPeer()
host = self.avatar.conn.transport.transport.getHost()
self.environ["SSH_CLIENT"] = f"{peer.host} {peer.port} {host.port}"
if self.ptyTuple:
self.getPtyOwnership()
self.pty = self._reactor.spawnProcess(
proto,
shell,
command,
self.environ,
homeDir,
uid,
gid,
usePTY=self.ptyTuple or 0,
)
if self.ptyTuple:
self.addUTMPEntry()
if self.modes:
self.setModes()
self.avatar.conn.transport.transport.setTcpNoDelay(1)
def getPtyOwnership(self):
ttyGid = os.stat(self.ptyTuple[2])[5]
uid, gid = self.avatar.getUserGroupId()
euid, egid = os.geteuid(), os.getegid()
os.setegid(0)
os.seteuid(0)
try:
os.chown(self.ptyTuple[2], uid, ttyGid)
finally:
os.setegid(egid)
os.seteuid(euid)
def setModes(self):
pty = self.pty
attr = tty.tcgetattr(pty.fileno())
for mode, modeValue in self.modes:
if mode not in ttymodes.TTYMODES:
continue
ttyMode = ttymodes.TTYMODES[mode]
if len(ttyMode) == 2: # Flag.
flag, ttyAttr = ttyMode
if not hasattr(tty, ttyAttr):
continue
ttyval = getattr(tty, ttyAttr)
if modeValue:
attr[flag] = attr[flag] | ttyval
else:
attr[flag] = attr[flag] & ~ttyval
elif ttyMode == "OSPEED":
attr[tty.OSPEED] = getattr(tty, f"B{modeValue}")
elif ttyMode == "ISPEED":
attr[tty.ISPEED] = getattr(tty, f"B{modeValue}")
else:
if not hasattr(tty, ttyMode):
continue
ttyval = getattr(tty, ttyMode)
attr[tty.CC][ttyval] = bytes((modeValue,))
tty.tcsetattr(pty.fileno(), tty.TCSANOW, attr)
def eofReceived(self):
if self.pty:
self.pty.closeStdin()
def closed(self):
if self.ptyTuple and os.path.exists(self.ptyTuple[2]):
ttyGID = os.stat(self.ptyTuple[2])[5]
os.chown(self.ptyTuple[2], 0, ttyGID)
if self.pty:
try:
self.pty.signalProcess("HUP")
except (OSError, ProcessExitedAlready):
pass
self.pty.loseConnection()
self.addUTMPEntry(0)
self._log.info("shell closed")
def windowChanged(self, winSize):
self.winSize = winSize
fcntl.ioctl(self.pty.fileno(), tty.TIOCSWINSZ, struct.pack("4H", *self.winSize))
def _writeHack(self, data):
"""
Hack to send ignore messages when we aren't echoing.
"""
if self.pty is not None:
attr = tty.tcgetattr(self.pty.fileno())[3]
if not attr & tty.ECHO and attr & tty.ICANON: # No echo.
self.avatar.conn.transport.sendIgnore("\x00" * (8 + len(data)))
self.oldWrite(data)
@implementer(ISFTPServer)
class SFTPServerForUnixConchUser:
def __init__(self, avatar):
self.avatar = avatar
def _setAttrs(self, path, attrs):
"""
NOTE: this function assumes it runs as the logged-in user:
i.e. under _runAsUser()
"""
if "uid" in attrs and "gid" in attrs:
os.chown(path, attrs["uid"], attrs["gid"])
if "permissions" in attrs:
os.chmod(path, attrs["permissions"])
if "atime" in attrs and "mtime" in attrs:
os.utime(path, (attrs["atime"], attrs["mtime"]))
def _getAttrs(self, s):
return {
"size": s.st_size,
"uid": s.st_uid,
"gid": s.st_gid,
"permissions": s.st_mode,
"atime": int(s.st_atime),
"mtime": int(s.st_mtime),
}
def _absPath(self, path):
home = self.avatar.getHomeDir()
return os.path.join(nativeString(home.path), nativeString(path))
def gotVersion(self, otherVersion, extData):
return {}
def openFile(self, filename, flags, attrs):
return UnixSFTPFile(self, self._absPath(filename), flags, attrs)
def removeFile(self, filename):
filename = self._absPath(filename)
return self.avatar._runAsUser(os.remove, filename)
def renameFile(self, oldpath, newpath):
oldpath = self._absPath(oldpath)
newpath = self._absPath(newpath)
return self.avatar._runAsUser(os.rename, oldpath, newpath)
def makeDirectory(self, path, attrs):
path = self._absPath(path)
return self.avatar._runAsUser(
[(os.mkdir, (path,)), (self._setAttrs, (path, attrs))]
)
def removeDirectory(self, path):
path = self._absPath(path)
self.avatar._runAsUser(os.rmdir, path)
def openDirectory(self, path):
return UnixSFTPDirectory(self, self._absPath(path))
def getAttrs(self, path, followLinks):
path = self._absPath(path)
if followLinks:
s = self.avatar._runAsUser(os.stat, path)
else:
s = self.avatar._runAsUser(os.lstat, path)
return self._getAttrs(s)
def setAttrs(self, path, attrs):
path = self._absPath(path)
self.avatar._runAsUser(self._setAttrs, path, attrs)
def readLink(self, path):
path = self._absPath(path)
return self.avatar._runAsUser(os.readlink, path)
def makeLink(self, linkPath, targetPath):
linkPath = self._absPath(linkPath)
targetPath = self._absPath(targetPath)
return self.avatar._runAsUser(os.symlink, targetPath, linkPath)
def realPath(self, path):
return os.path.realpath(self._absPath(path))
def extendedRequest(self, extName, extData):
raise NotImplementedError
@implementer(ISFTPFile)
class UnixSFTPFile:
def __init__(self, server, filename, flags, attrs):
self.server = server
openFlags = 0
if flags & FXF_READ == FXF_READ and flags & FXF_WRITE == 0:
openFlags = os.O_RDONLY
if flags & FXF_WRITE == FXF_WRITE and flags & FXF_READ == 0:
openFlags = os.O_WRONLY
if flags & FXF_WRITE == FXF_WRITE and flags & FXF_READ == FXF_READ:
openFlags = os.O_RDWR
if flags & FXF_APPEND == FXF_APPEND:
openFlags |= os.O_APPEND
if flags & FXF_CREAT == FXF_CREAT:
openFlags |= os.O_CREAT
if flags & FXF_TRUNC == FXF_TRUNC:
openFlags |= os.O_TRUNC
if flags & FXF_EXCL == FXF_EXCL:
openFlags |= os.O_EXCL
if "permissions" in attrs:
mode = attrs["permissions"]
del attrs["permissions"]
else:
mode = 0o777
fd = server.avatar._runAsUser(os.open, filename, openFlags, mode)
if attrs:
server.avatar._runAsUser(server._setAttrs, filename, attrs)
self.fd = fd
def close(self):
return self.server.avatar._runAsUser(os.close, self.fd)
def readChunk(self, offset, length):
return self.server.avatar._runAsUser(
[(os.lseek, (self.fd, offset, 0)), (os.read, (self.fd, length))]
)
def writeChunk(self, offset, data):
return self.server.avatar._runAsUser(
[(os.lseek, (self.fd, offset, 0)), (os.write, (self.fd, data))]
)
def getAttrs(self):
s = self.server.avatar._runAsUser(os.fstat, self.fd)
return self.server._getAttrs(s)
def setAttrs(self, attrs):
raise NotImplementedError
class UnixSFTPDirectory:
def __init__(self, server, directory):
self.server = server
self.files = server.avatar._runAsUser(os.listdir, directory)
self.dir = directory
def __iter__(self):
return self
def __next__(self):
try:
f = self.files.pop(0)
except IndexError:
raise StopIteration
else:
s = self.server.avatar._runAsUser(os.lstat, os.path.join(self.dir, f))
longname = lsLine(f, s)
attrs = self.server._getAttrs(s)
return (f, longname, attrs)
next = __next__
def close(self):
self.files = []
components.registerAdapter(
SFTPServerForUnixConchUser, UnixConchUser, filetransfer.ISFTPServer
)
components.registerAdapter(SSHSessionForUnixConchUser, UnixConchUser, session.ISession)